Compare commits


49 Commits

Author SHA1 Message Date
Pascal Seitz
3ddca31292 simplify fetch block in column_block_accessor 2025-12-29 10:56:50 +01:00
Pascal Seitz
87fe3a311f share column block accessor 2025-12-12 16:54:44 +08:00
Pascal Seitz
71dc08424c add comment 2025-12-12 16:19:06 +08:00
Pascal Seitz
15913446b8 cleanup
remove clone
move data in term req, single doc opt for stats
2025-12-10 14:34:43 +08:00
Pascal Seitz
78bd3826dc remove stacktrace bloat, use &mut helper
increase cache to 2048
2025-12-08 10:35:23 +08:00
Pascal Seitz
1b56487307 specialize columntype in stats 2025-12-08 10:20:09 +08:00
Pascal Seitz
030554d544 use radix map, fix prepare_max_bucket
use paged term map in term agg
use special no sub agg term map impl
2025-12-08 10:20:09 +08:00
Pascal Seitz
c852bac532 reduce dynamic dispatch, faster term agg 2025-12-08 10:20:09 +08:00
Pascal Seitz
2ce4da8b66 one collector per agg request instead of per bucket
In this refactoring, a collector knows which bucket of the parent
its data belongs to. This allows converting the previous approach of one
collector per bucket into one collector per request.

low card bucket optimization
2025-12-08 10:20:09 +08:00
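A minimal sketch of the idea described above (hypothetical types, not tantivy's actual collector API): a single per-request collector keeps its accumulators keyed by the parent bucket id, instead of allocating one collector per bucket.

```rust
// Illustrative sketch only; the types and names here are hypothetical.
use std::collections::HashMap;

/// A single per-request collector: one accumulator per (parent bucket, term) pair.
#[derive(Default)]
struct TermCountCollector {
    counts: HashMap<(u32, u64), u64>, // (parent bucket id, term ordinal) -> count
}

impl TermCountCollector {
    /// The caller tells the collector which parent bucket the doc fell into.
    fn collect(&mut self, parent_bucket: u32, term_ord: u64) {
        *self.counts.entry((parent_bucket, term_ord)).or_insert(0) += 1;
    }
}

fn main() {
    let mut collector = TermCountCollector::default();
    collector.collect(0, 42);
    collector.collect(1, 42);
    collector.collect(0, 42);
    assert_eq!(collector.counts[&(0, 42)], 2);
}
```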
Pascal Seitz
0dd6a958f8 add more tests for new collection type 2025-12-08 10:20:09 +08:00
Pascal Seitz
254314a4a3 improve bench 2025-12-07 10:05:30 +08:00
PSeitz
b2f99c6217 add term->histogram benchmark (#2758)
* add term->histogram benchmark

* add more term aggs

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-12-04 02:29:37 +01:00
PSeitz
76de5bab6f fix unsafe warnings (#2757) 2025-12-03 20:15:21 +08:00
rustmailer
b7eb31162b docs: add usage example to README (#2743) 2025-12-02 21:56:57 +01:00
Paul Masurel
63c66005db Lazy scorers (#2726)
* Refactoring of the score tweaker into `SortKeyComputer`s to unlock two features.

- Allow lazy evaluation of the score. As soon as we identify that a doc won't
reach the top-K threshold, we can stop the evaluation.
- Allow for a segment-level score that differs from the final score, and a conversion between them.

This PR breaks the public API, but fixing affected code is straightforward.

* Bumping tantivy version

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-01 15:38:57 +01:00
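A minimal sketch of the lazy-evaluation idea (a generic top-K helper, not the actual `SortKeyComputer` API): once the top-K buffer is full, a document whose cheap upper bound cannot beat the current threshold is skipped without computing the expensive sort key.

```rust
// Illustrative sketch; `cheap_upper_bound` and `expensive_key` are hypothetical closures.
use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn top_k_lazy(
    docs: &[u32],
    k: usize,
    cheap_upper_bound: impl Fn(u32) -> u64,
    expensive_key: impl Fn(u32) -> u64,
) -> Vec<u64> {
    // Min-heap holding the k largest keys seen so far; its minimum is the threshold.
    let mut heap: BinaryHeap<Reverse<u64>> = BinaryHeap::new();
    for &doc in docs {
        if heap.len() == k {
            let threshold = heap.peek().unwrap().0;
            // Lazy evaluation: the upper bound already rules this doc out.
            if cheap_upper_bound(doc) <= threshold {
                continue;
            }
        }
        heap.push(Reverse(expensive_key(doc)));
        if heap.len() > k {
            heap.pop();
        }
    }
    heap.into_iter().map(|Reverse(key)| key).collect()
}

fn main() {
    let docs: Vec<u32> = (0..1000).collect();
    let top = top_k_lazy(&docs, 3, |d| (d % 100) as u64, |d| (d % 100) as u64);
    assert_eq!(top.len(), 3);
}
```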
Paul Masurel
7d513a44c5 Added some benchmark for top K by a fast field (#2754)
Also removed query parsing from the bench code.

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-01 14:58:29 +01:00
Stu Hood
ca87fcd454 Implement collect_block for Collectors which wrap other Collectors (#2727)
* Implement `collect_block` for tuple Collectors, and for MultiCollector.

* Two more.
2025-12-01 12:26:29 +01:00
Ang
08a92675dc Fix typos again (#2753)
Found via `codespell -S benches,stopwords.rs -L
womens,parth,abd,childs,ond,ser,ue,mot,hel,atleast,pris,claus,allo`
2025-12-01 12:15:41 +01:00
Raphaël Cohen
f7f4b354d6 fix: Handle phrase prefixed with star (#2751)
Signed-off-by: Darkheir <raphael.cohen@sekoia.io>
2025-12-01 11:43:25 +01:00
Paul Masurel
25d44fcec8 Revert "remove unused columnar api (#2742)" (#2748)
* Revert "remove unused columnar api (#2742)"

This reverts commit 8725594d47.

* Clippy comment + removing fill_vals

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-11-26 17:44:02 +01:00
PSeitz-dd
842fe9295f split Term in Term and IndexingTerm (#2744)
* split Term in Term and IndexingTerm

* add append_json_path to JsonTermSerializer
2025-11-26 16:48:59 +01:00
Paul Masurel
f88b7200b2 Optimization when posting lists are saturated. (#2745)
* Optimization when posting lists are saturated.

If a posting list's doc freq equals the segment reader's
max_doc and scoring does not matter, we can replace it
with an AllScorer.

In turn, in a boolean query, we can dismiss such all-scorers and
empty scorers to accelerate the request.

* Added range query optimization

* CR comment

* CR comments

* CR comment

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-11-26 15:50:57 +01:00
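A minimal sketch of the saturation check (hypothetical types, not tantivy's scorer API): a clause whose doc frequency equals max_doc matches every document, so without scoring it can be treated as an all-scorer and dropped from a conjunction.

```rust
// Illustrative sketch only.
#[derive(Debug, PartialEq)]
enum SimplifiedClause {
    All,                        // matches every doc in the segment
    Postings { doc_freq: u32 }, // a regular posting-list clause
}

fn simplify(doc_freq: u32, max_doc: u32, needs_scores: bool) -> SimplifiedClause {
    if !needs_scores && doc_freq == max_doc {
        SimplifiedClause::All
    } else {
        SimplifiedClause::Postings { doc_freq }
    }
}

/// In a pure conjunction, `All` clauses do not restrict the result and can be removed.
fn prune_conjunction(clauses: Vec<SimplifiedClause>) -> Vec<SimplifiedClause> {
    clauses
        .into_iter()
        .filter(|clause| !matches!(clause, SimplifiedClause::All))
        .collect()
}

fn main() {
    let clauses = vec![simplify(1_000, 1_000, false), simplify(10, 1_000, false)];
    assert_eq!(
        prune_conjunction(clauses),
        vec![SimplifiedClause::Postings { doc_freq: 10 }]
    );
}
```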
PSeitz-dd
8725594d47 remove unused columnar api (#2742) 2025-11-21 18:07:25 +01:00
PSeitz
43a784671a clippy (#2741)
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-11-21 18:07:03 +01:00
Paul Masurel
c363bbd23d Optimize term aggregation with low cardinality + some refactoring (#2740)
This introduces an optimization for top-level term aggregations on fields with low cardinality.

We then use a Vec as the underlying map.
In addition, we buffer subaggregations.

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
Co-authored-by: Paul Masurel <paul@quickwit.io>
2025-11-21 14:46:29 +01:00
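A minimal sketch of the low-cardinality fast path (not the actual implementation): when term ordinals are dense and the cardinality is small, a Vec indexed by ordinal can replace a hash map as the bucket store.

```rust
// Illustrative sketch only.
fn count_terms_low_cardinality(term_ords: &[u32], cardinality: usize) -> Vec<u64> {
    // One slot per possible term ordinal instead of a hash map.
    let mut counts = vec![0u64; cardinality];
    for &ord in term_ords {
        counts[ord as usize] += 1;
    }
    counts
}

fn main() {
    // Four distinct ordinals (e.g. log levels) over a handful of docs.
    let ords = [0u32, 1, 0, 3, 2, 0, 1];
    assert_eq!(count_terms_low_cardinality(&ords, 4), vec![3, 2, 1, 1]);
}
```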
Moe
70e591e230 feat: added filter aggregation (#2711)
* Initial impl

* Added `Filter` impl in `build_single_agg_segment_collector_with_reader` + Added tests

* Added `Filter(FilterBucketResult)` + Made tests work.

* Fixed type issues.

* Fixed a test.

* 8a7a73a: Pass `segment_reader`

* Added more tests.

* Improved parsing + tests

* refactoring

* Added more tests.

* refactoring: moved parsing code under QueryParser

* Use Tantivy syntax instead of ES

* Added a sanity check test.

* Simplified impl + tests

* Added back tests in a more maintainable way

* nitz.

* nitz

* implemented very simple fast-path

* improved a comment

* implemented fast field support

* Used `BoundsRange`

* Improved fast field impl + tests

* Simplified execution.

* Fixed exports + nitz

* Improved the tests to check to the expected result.

* Improved test by checking the whole result JSON

* Removed brittle perf checks.

* Added efficiency verification tests.

* Added one more efficiency check test.

* Improved the efficiency tests.

* Removed unnecessary parsing code + added direct Query obj

* Fixed tests.

* Improved tests

* Fixed code structure

* Fixed lint issues

* nitz.

* nitz

* nitz.

* nitz.

* nitz.

* Added an example

* Fixed PR comments.

* Applied PR comments + nitz

* nitz.

* Improved the code.

* Fixed a perf issue.

* Added batch processing.

* Made the example more interesting

* Fixed bucket count

* Renamed Direct to CustomQuery

* Fixed lint issues.

* No need for scorer to be an `Option`

* nitz

* Used BitSet

* Added an optimization for AllQuery

* Fixed merge issues.

* Fixed lint issues.

* Added benchmark for FILTER

* Removed the Option wrapper.

* nitz.

* Applied PR comments.

* Fixed the AllQuery optimization

* Applied PR comments.

* feat: used `erased_serde` to allow filter query to be serialized

* further improved a comment

* Added back tests.

* removed an unused method

* removed an unused method

* Added documentation

* nitz.

* Added query builder.

* Fixed a comment.

* Applied PR comments.

* Fixed doctest issues.

* Added ser/de

* Removed bench in test

* Fixed a lint issue.
2025-11-18 20:54:31 +01:00
Arthur
5277367cb0 remove duplicated call to index_writer.commit() in example (#2732) 2025-11-12 14:52:44 +01:00
Paul Masurel
8b02bff9b8 Removing obsolete benchmark screenshot (#2730)
Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-11-05 09:55:13 +01:00
PSeitz
60225bdd45 cleanup (#2724)
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-23 10:23:34 +02:00
PSeitz
938bfec8b7 use FxHashMap for Aggregations Request (#2722)
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-21 15:59:18 +02:00
PSeitz
dabcaa5809 fix merge intermediate aggregation results (#2719)
Previously the merging relied on the order of the results, which has been invalid since https://github.com/quickwit-oss/tantivy/pull/2035.
The bug is only hit in specific scenarios, when the aggregation collectors are built in a different order on different segments.

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-17 12:41:31 +02:00
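A minimal sketch of the principle behind the fix (hypothetical data shapes, not the actual intermediate-result types): merge per-segment results by aggregation name rather than by position, so segments whose collectors were built in a different order still merge correctly.

```rust
// Illustrative sketch only.
use std::collections::HashMap;

fn merge_by_name(
    mut left: HashMap<String, u64>,
    right: HashMap<String, u64>,
) -> HashMap<String, u64> {
    for (name, count) in right {
        *left.entry(name).or_insert(0) += count;
    }
    left
}

fn main() {
    // The two segments list the same aggregations in different orders.
    let seg_a = HashMap::from([("terms".to_string(), 10u64), ("count".to_string(), 3)]);
    let seg_b = HashMap::from([("count".to_string(), 7u64), ("terms".to_string(), 5)]);
    let merged = merge_by_name(seg_a, seg_b);
    assert_eq!(merged["terms"], 15);
    assert_eq!(merged["count"], 10);
}
```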
PSeitz
d410a3b0c0 Add Filtering for Term Aggregations (#2717)
* Add Filtering for Term Aggregations

Closes #2702

* add AggregationsSegmentCtx memory consumption

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-15 17:39:53 +02:00
Remi
fc93391d0e Minor clarifications on the AggregationsWithAccessor refactor (#2716) 2025-10-14 19:59:33 +02:00
PSeitz
f8e79271ab Replace AggregationsWithAccessor (#2715)
* add nested histogram-termagg benchmark

* Replace AggregationsWithAccessor with AggData

With AggregationsWithAccessor, pre-computation and caching were done at the collector level.
If you have 10,000 sub-collectors (e.g. a term aggregation with sub-aggregations), this is very inefficient.
`AggData` instead moves the data from the collector to a node that reflects the cardinality of the request tree rather than the cardinality of the segment collectors.
It also moves the global struct shared by all aggregations into aggregation-specific structs, so each aggregation has its own space to store cached data and aggregation-specific information.

This also somewhat breaks up the dependency on the Elasticsearch aggregation structure.

Due to lifetime issues, we move the request-specific object out of `AggData` during collection and move it back at the end (for now). That is some unnecessary work, which costs CPU.

This allows better caching and also paves the way for another potential optimization: separating the collector from its storage. Currently we allocate a new collector for each sub-aggregation bucket (for nested aggregations), but ideally we would have just one collector instance.

* renames

* move request data to agg request files

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-14 09:22:11 +02:00
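A minimal sketch of the `AggData` idea (hypothetical names, not the actual structs): cached, aggregation-specific data lives on a request-tree node and is borrowed by every bucket collector, instead of being duplicated in each sub-collector.

```rust
// Illustrative sketch only.
struct TermAggData {
    // Data that previously lived in every one of the (possibly thousands of) sub-collectors.
    column_cache: Vec<u64>,
}

struct TermBucketCollector {
    count: u64,
}

impl TermBucketCollector {
    fn collect(&mut self, doc: u32, data: &TermAggData) {
        // The collector borrows the shared, request-level data.
        let _cached = data.column_cache.get(doc as usize);
        self.count += 1;
    }
}

fn main() {
    let data = TermAggData { column_cache: vec![7; 1024] };
    // Many bucket collectors share one data node.
    let mut buckets: Vec<TermBucketCollector> =
        (0..3).map(|_| TermBucketCollector { count: 0 }).collect();
    for (i, bucket) in buckets.iter_mut().enumerate() {
        bucket.collect(i as u32, &data);
    }
    assert!(buckets.iter().all(|b| b.count == 1));
}
```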
PSeitz
33835b6a01 Add DocSet::cost() (#2707)
* query: add DocSet cost hint and use it for intersection ordering

- Add DocSet::cost()
- Use cost() instead of size_hint() to order scorers in intersect_scorers

This isolates the cost-related changes from PR #2538,
without the new seek APIs.

* add comments

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-13 16:25:49 +02:00
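A minimal sketch of cost-based ordering (a hypothetical trait, not tantivy's `DocSet`): scorers are sorted by an explicit cost hint so the cheapest one drives the intersection.

```rust
// Illustrative sketch only.
trait CostHint {
    fn cost(&self) -> u64;
}

struct FakeScorer {
    name: &'static str,
    cost: u64,
}

impl CostHint for FakeScorer {
    fn cost(&self) -> u64 {
        self.cost
    }
}

fn order_for_intersection(mut scorers: Vec<FakeScorer>) -> Vec<FakeScorer> {
    // The most selective (cheapest) scorer leads the intersection.
    scorers.sort_by_key(|scorer| scorer.cost());
    scorers
}

fn main() {
    let scorers = vec![
        FakeScorer { name: "body:a", cost: 50_000 },
        FakeScorer { name: "body:b", cost: 1_000 },
    ];
    assert_eq!(order_for_intersection(scorers)[0].name, "body:b");
}
```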
PSeitz
270ca5123c refactor postings (#2709)
rename shallow_seek to seek_block
remove full_block from public postings API

This is in preparation for optionally handling Bitsets in the postings
2025-10-08 16:55:25 +02:00
Mustafa S. Moiz
714366d3b9 docs: correct grammar (#2704)
Correct phrasing for a single line in the docs (`one documents` -> `a document`).
2025-10-08 16:47:09 +02:00
PSeitz-dd
40659d4d07 improve naming in buffered_union (#2705) 2025-09-24 10:58:46 +02:00
PSeitz
e1e131a804 add and/or queries benchmark (#2701) 2025-09-22 16:32:49 +02:00
PSeitz-dd
70da310b2d perf: deduplicate queries (#2698)
* deduplicate queries

Deduplicate queries in the UserInputAst after parsing queries

* add return type
2025-09-22 12:16:58 +02:00
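A minimal sketch of the deduplication step (hypothetical clause representation, not the real `UserInputAst`): identical clauses produced by parsing are dropped so the same sub-query is not executed twice.

```rust
// Illustrative sketch only.
use std::collections::HashSet;

fn dedup_clauses(clauses: Vec<String>) -> Vec<String> {
    let mut seen = HashSet::new();
    // Keep the first occurrence of each clause, preserving order.
    clauses.into_iter().filter(|c| seen.insert(c.clone())).collect()
}

fn main() {
    let parsed = vec![
        "title:rust".to_string(),
        "body:search".to_string(),
        "title:rust".to_string(),
    ];
    assert_eq!(
        dedup_clauses(parsed),
        vec!["title:rust".to_string(), "body:search".to_string()]
    );
}
```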
PSeitz
85010b589a clippy (#2700)
* clippy

* clippy

* clippy

* clippy + fmt

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-09-19 18:04:25 +02:00
PSeitz-dd
2340dca628 fix compiler warnings (#2699)
* fix compiler warnings

* fix import
2025-09-19 15:55:04 +02:00
Remi
71a26d5b24 Fix CI with rust 1.90 (#2696)
* Empty commit

* Fix dead code lint error
2025-09-18 23:06:33 +02:00
PSeitz-dd
203751f2fe Optimize ExistsQuery for a high number of dynamic columns (#2694)
* Optimize ExistsQuery for a high number of dynamic columns

The previous algorithm checked _each_ doc in _each_ column for
existence. This incurs a huge cost on JSON fields with e.g. 100k columns.
Instead, we now compute a bitset when there is more than one column.

add `iter_docs` to the multivalued_index

* add benchmark

subfields=1
exists_json_union    Memory: 89.3 KB (+2.01%)    Avg: 0.4865ms (-26.03%)    Median: 0.4865ms (-26.03%)    [0.4865ms .. 0.4865ms]
subfields=2
exists_json_union    Memory: 68.1 KB     Avg: 1.7048ms (-0.46%)    Median: 1.7048ms (-0.46%)    [1.7048ms .. 1.7048ms]
subfields=3
exists_json_union    Memory: 61.8 KB     Avg: 2.0742ms (-2.22%)    Median: 2.0742ms (-2.22%)    [2.0742ms .. 2.0742ms]
subfields=4
exists_json_union    Memory: 119.8 KB (+103.44%)    Avg: 3.9500ms (+42.62%)    Median: 3.9500ms (+42.62%)    [3.9500ms .. 3.9500ms]
subfields=5
exists_json_union    Memory: 120.4 KB (+107.65%)    Avg: 3.9610ms (+20.65%)    Median: 3.9610ms (+20.65%)    [3.9610ms .. 3.9610ms]
subfields=6
exists_json_union    Memory: 120.6 KB (+107.49%)    Avg: 3.8903ms (+3.11%)    Median: 3.8903ms (+3.11%)    [3.8903ms .. 3.8903ms]
subfields=7
exists_json_union    Memory: 120.9 KB (+106.93%)    Avg: 3.6220ms (-16.22%)    Median: 3.6220ms (-16.22%)    [3.6220ms .. 3.6220ms]
subfields=8
exists_json_union    Memory: 121.3 KB (+106.23%)    Avg: 4.0981ms (-15.97%)    Median: 4.0981ms (-15.97%)    [4.0981ms .. 4.0981ms]
subfields=16
exists_json_union    Memory: 123.1 KB (+103.09%)    Avg: 4.3483ms (-92.26%)    Median: 4.3483ms (-92.26%)    [4.3483ms .. 4.3483ms]
subfields=256
exists_json_union    Memory: 204.6 KB (+19.85%)    Avg: 3.8874ms (-99.01%)    Median: 3.8874ms (-99.01%)    [3.8874ms .. 3.8874ms]
subfields=4096
exists_json_union    Memory: 2.0 MB     Avg: 3.5571ms (-99.90%)    Median: 3.5571ms (-99.90%)    [3.5571ms .. 3.5571ms]
subfields=65536
exists_json_union    Memory: 28.3 MB     Avg: 14.4417ms (-99.97%)    Median: 14.4417ms (-99.97%)    [14.4417ms .. 14.4417ms]
subfields=262144
exists_json_union    Memory: 113.3 MB     Avg: 66.2860ms (-99.95%)    Median: 66.2860ms (-99.95%)    [66.2860ms .. 66.2860ms]

* rename methods
2025-09-16 18:21:03 +02:00
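A minimal sketch of the bitset approach (a plain Vec<bool> standing in for the actual bitset type): the docs that have a value in any column are unioned up front, instead of probing every doc against every column.

```rust
// Illustrative sketch only; each inner Vec lists the doc ids present in one column.
fn exists_bitset(columns: &[Vec<u32>], max_doc: u32) -> Vec<bool> {
    let mut bitset = vec![false; max_doc as usize];
    for column_docs in columns {
        for &doc in column_docs {
            bitset[doc as usize] = true;
        }
    }
    bitset
}

fn main() {
    // Three sparse columns over 8 docs.
    let columns = vec![vec![0, 3], vec![3, 5], vec![7]];
    let bitset = exists_bitset(&columns, 8);
    assert_eq!(bitset.iter().filter(|&&b| b).count(), 4); // docs 0, 3, 5, 7
}
```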
PSeitz-dd
7963b0b4aa Add fast field fallback for term query if not indexed (#2693)
* Add fast field fallback for term query if not indexed

* only fallback without scores
2025-09-12 14:58:21 +02:00
Paul Masurel
d5eefca11d Merge pull request #2692 from quickwit-oss/paul.masurel/coerce-floats-too-in-search-too
This PR changes the logic used on the ingestion of floats.
2025-09-10 09:46:54 +02:00
Paul Masurel
5d6c8de23e Align the float search logic with the columnar coercion rules
It applies the same logic to floats as for u64 or i64.
In all cases, the idea is (for the inverted index) to coerce numbers
to their canonical representation, before indexing and before searching.

That way a document with the float 1.0 will be searchable when the user
searches for 1.

Note that, contrary to the columnar, we do not attempt to coerce all of the
terms associated with a given JSON path to a single numerical type.
We simply rely on this "point-wise" canonicalization.
2025-09-09 19:28:17 +02:00
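A minimal sketch of the point-wise canonicalization described above (simplified, not the actual `NumericalValue` code): a float with no fractional part is coerced to its integer form, so indexing 1.0 and searching for 1 yield the same term.

```rust
// Illustrative sketch only.
#[derive(Debug, PartialEq)]
enum CanonicalNumber {
    I64(i64),
    F64(f64),
}

fn canonicalize(val: f64) -> CanonicalNumber {
    // Whole-number floats within i64 range become integers; everything else stays f64.
    if val.fract() == 0.0 && val >= i64::MIN as f64 && val <= i64::MAX as f64 {
        CanonicalNumber::I64(val as i64)
    } else {
        CanonicalNumber::F64(val)
    }
}

fn main() {
    assert_eq!(canonicalize(1.0), CanonicalNumber::I64(1));
    assert_eq!(canonicalize(1.5), CanonicalNumber::F64(1.5));
}
```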
PSeitz
a06365f39f Update CHANGELOG.md for bugfixes (#2674)
* Update CHANGELOG.md

* Update CHANGELOG.md
2025-09-04 11:51:00 +02:00
Raphaël Cohen
f4b374110f feat: Regex query grammar (#2677)
* feat: Regex query grammar

* feat: Disable regexes by default

* chore: Apply formatting
2025-09-03 10:07:04 +02:00
167 changed files with 11518 additions and 3935 deletions


@@ -14,6 +14,18 @@ Tantivy 0.25
- Support mixed field types in query parser [#2676](https://github.com/quickwit-oss/tantivy/pull/2676)(@trinity-1686a)
- Add per-field size details [#2679](https://github.com/quickwit-oss/tantivy/pull/2679)(@fulmicoton)
Tantivy 0.24.2
================================
- Fix TopNComputer for reverse order. [#2672](https://github.com/quickwit-oss/tantivy/pull/2672)(@stuhood @PSeitz)
Affected queries are [order_by_fast_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_fast_field) and
[order_by_u64_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_u64_field)
for `Order::Asc`
Tantivy 0.24.1
================================
- Fix: bump required rust version to 1.81
Tantivy 0.24
================================
Tantivy 0.24 will be backwards compatible with indices created with v0.22 and v0.21. The new minimum rust version will be 1.75. Tantivy 0.23 will be skipped.
@@ -66,7 +78,7 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
- **Store DateTime as nanoseconds in doc store** DateTime in the doc store was truncated to microseconds previously. This removes this truncation, while still keeping backwards compatibility. [#2486](https://github.com/quickwit-oss/tantivy/pull/2486)(@PSeitz)
- **Performace/Memory**
- **Performance/Memory**
- lift clauses in LogicalAst for optimized ast during execution [#2449](https://github.com/quickwit-oss/tantivy/pull/2449)(@PSeitz)
- Use Vec instead of BTreeMap to back OwnedValue object [#2364](https://github.com/quickwit-oss/tantivy/pull/2364)(@fulmicoton)
- Replace TantivyDocument with CompactDoc. CompactDoc is much smaller and provides similar performance. [#2402](https://github.com/quickwit-oss/tantivy/pull/2402)(@PSeitz)
@@ -96,6 +108,14 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
- Fix trait bound of StoreReader::iter [#2360](https://github.com/quickwit-oss/tantivy/pull/2360)(@adamreichold)
- remove read_postings_no_deletes [#2526](https://github.com/quickwit-oss/tantivy/pull/2526)(@PSeitz)
Tantivy 0.22.1
================================
- Fix TopNComputer for reverse order. [#2672](https://github.com/quickwit-oss/tantivy/pull/2672)(@stuhood @PSeitz)
Affected queries are [order_by_fast_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_fast_field) and
[order_by_u64_field](https://docs.rs/tantivy/latest/tantivy/collector/struct.TopDocs.html#method.order_by_u64_field)
for `Order::Asc`
Tantivy 0.22
================================


@@ -1,6 +1,6 @@
[package]
name = "tantivy"
version = "0.25.0"
version = "0.26.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]
@@ -69,6 +69,7 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
typetag = "0.2.21"
[target.'cfg(windows)'.dependencies]
winapi = "0.3.9"
@@ -87,7 +88,7 @@ more-asserts = "0.3.1"
rand_distr = "0.4.3"
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
"use-std",
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
@@ -167,3 +168,11 @@ harness = false
[[bench]]
name = "agg_bench"
harness = false
[[bench]]
name = "exists_json"
harness = false
[[bench]]
name = "and_or_queries"
harness = false


@@ -23,8 +23,6 @@ performance for different types of queries/collections.
Your mileage WILL vary depending on the nature of queries and their load.
<img src="doc/assets/images/searchbenchmark.png">
Details about the benchmark can be found at this [repository](https://github.com/quickwit-oss/search-benchmark-game).
## Features
@@ -125,6 +123,7 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
- [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
- [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
- [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
- and [more](https://github.com/search?q=tantivy)!
### On average, how much faster is Tantivy compared to Lucene?


@@ -10,7 +10,7 @@ rename FastFieldReaders::open to load
remove fast field reader
find a way to unify the two DateTime.
readd type check in the filter wrapper
re-add type check in the filter wrapper
add unit test on columnar list columns.


@@ -1,5 +1,6 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use rand::distributions::WeightedIndex;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
@@ -53,26 +54,41 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, stats_f64);
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_few);
register!(group, terms_many);
register!(group, terms_7);
register!(group, terms_all_unique);
register!(group, terms_150_000);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_few);
register!(group, range_agg_with_term_agg_status);
register!(group, range_agg_with_term_agg_many);
register!(group, histogram);
register!(group, histogram_hard_bounds);
register!(group, histogram_with_avg_sub_agg);
register!(group, histogram_with_term_agg_status);
register!(group, avg_and_range_with_avg_sub_agg);
// Filter aggregation benchmarks
register!(group, filter_agg_all_query_count_agg);
register!(group, filter_agg_term_query_count_agg);
register!(group, filter_agg_all_query_with_sub_aggs);
register!(group, filter_agg_term_query_with_sub_aggs);
group.run();
}
@@ -123,12 +139,12 @@ fn extendedstats_f64(index: &Index) {
}
fn percentiles_f64(index: &Index) {
let agg_req = json!({
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
}
}
});
execute_agg(index, agg_req);
}
@@ -143,10 +159,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_few_with_cardinality_agg(index: &Index) {
fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"terms": { "field": "text_few_terms_status" },
"aggs": {
"cardinality": {
"cardinality": {
@@ -159,13 +175,20 @@ fn terms_few_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_few(index: &Index) {
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms" } },
"my_texts": { "terms": { "field": "text_few_terms_status" } },
});
execute_agg(index, agg_req);
}
fn terms_many(index: &Index) {
fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_150_000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
});
@@ -213,6 +236,72 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_all_unique_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -268,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn range_agg_with_term_agg_few(index: &Index) {
fn range_agg_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"range": {
@@ -283,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) {
]
},
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } },
"my_texts": { "terms": { "field": "text_few_terms_status" } },
}
},
});
@@ -339,6 +428,17 @@ fn histogram_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 10 },
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms_status" } }
}
}
});
execute_agg(index, agg_req);
}
fn avg_and_range_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"rangef64": {
@@ -378,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
// Flag to use existing index
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
if reuse_index && std::path::Path::new("agg_bench").exists() {
return Index::open_in_dir("agg_bench");
}
// create dir
std::fs::create_dir_all("agg_bench")?;
let mut schema_builder = Schema::builder();
let text_fieldtype = tantivy::schema::TextOptions::default()
.set_indexing_options(
@@ -386,20 +493,47 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
} else {
Index::create_from_tempdir(schema_builder.build())?
};
// Approximate log proportions
let status_field_data = [
("INFO", 8000),
("ERROR", 300),
("WARN", 1200),
("DEBUG", 500),
("OK", 500),
("CRITICAL", 20),
("EMERGENCY", 1),
];
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
.map(|num| format!("author{num}"))
.collect::<Vec<_>>();
// Prepare 1000 unique terms sampled using a Zipf distribution.
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
{
let mut rng = StdRng::from_seed([1u8; 32]);
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -409,15 +543,25 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
let term_1000_a = &terms_1000[idx_a];
let term_1000_b = &terms_1000[idx_b];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
text_field_1000_terms_zipf => term_1000_b.as_str(),
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -442,8 +586,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
@@ -460,3 +606,61 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
Ok(index)
}
// Filter aggregation benchmarks
fn filter_agg_all_query_count_agg(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "*",
"aggs": {
"count": { "value_count": { "field": "score" } }
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_term_query_count_agg(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "text:cool",
"aggs": {
"count": { "value_count": { "field": "score" } }
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_all_query_with_sub_aggs(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "*",
"aggs": {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms_status" }
}
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_term_query_with_sub_aggs(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "text:cool",
"aggs": {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms_status" }
}
}
}
});
execute_agg(index, agg_req);
}

benches/and_or_queries.rs (new file, 218 lines)

@@ -0,0 +1,218 @@
// Benchmarks boolean conjunction queries using binggan.
//
// What's measured:
// - Or and And queries with varying selectivity (only `Term` queries on leaves for now)
// - Nested AND/OR combinations (on multiple fields)
// - No-scoring path using the Count collector (focus on iterator/skip performance)
// - Top-K retrieval (k=10) using the TopDocs collector
//
// Corpus model:
// - Synthetic docs; each token a/b/c is independently included per doc
// - If none of a/b/c are included, emit a neutral filler token to keep doc length similar
//
// Notes:
// - After optimization, when scoring is disabled Tantivy reads doc-only postings
// (IndexRecordOption::Basic), avoiding frequency decoding overhead.
// - This bench isolates boolean iteration speed and intersection/union cost.
// - Use `cargo bench --bench and_or_queries` to run.
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::sort_key::SortByStaticFastValue;
use tantivy::collector::{Collector, Count, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
query_parser: QueryParser,
}
/// Build a single index containing both fields (title, body) and
/// return two BenchIndex views:
/// - single_field: QueryParser defaults to only "body"
/// - multi_field: QueryParser defaults to ["title", "body"]
fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (BenchIndex, BenchIndex) {
// Unified schema (two text fields)
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
let f_body = schema_builder.add_text_field("body", TEXT);
let f_score = schema_builder.add_u64_field("score", FAST);
let f_score2 = schema_builder.add_u64_field("score2", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
// Populate: spread each present token 90/10 to body/title
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..num_docs {
let has_a = rng.gen_bool(p_a as f64);
let has_b = rng.gen_bool(p_b as f64);
let has_c = rng.gen_bool(p_c as f64);
let score = rng.gen_range(0u64..100u64);
let score2 = rng.gen_range(0u64..100_000u64);
let mut title_tokens: Vec<&str> = Vec::new();
let mut body_tokens: Vec<&str> = Vec::new();
if has_a {
if rng.gen_bool(0.1) {
title_tokens.push("a");
} else {
body_tokens.push("a");
}
}
if has_b {
if rng.gen_bool(0.1) {
title_tokens.push("b");
} else {
body_tokens.push("b");
}
}
if has_c {
if rng.gen_bool(0.1) {
title_tokens.push("c");
} else {
body_tokens.push("c");
}
}
if title_tokens.is_empty() && body_tokens.is_empty() {
body_tokens.push("z");
}
writer
.add_document(doc!(
f_title=>title_tokens.join(" "),
f_body=>body_tokens.join(" "),
f_score=>score,
f_score2=>score2,
))
.unwrap();
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
// Build two query parsers with different default fields.
let qp_single = QueryParser::for_index(&index, vec![f_body]);
let qp_multi = QueryParser::for_index(&index, vec![f_title, f_body]);
let single_view = BenchIndex {
index: index.clone(),
searcher: searcher.clone(),
query_parser: qp_single,
};
let multi_view = BenchIndex {
index,
searcher,
query_parser: qp_multi,
};
(single_view, multi_view)
}
fn main() {
// Prepare corpora with varying selectivity. Build one index per corpus
// and derive two views (single-field vs multi-field) from it.
let scenarios = vec![
(
"N=1M, p(a)=5%, p(b)=1%, p(c)=15%".to_string(),
1_000_000,
0.05,
0.01,
0.15,
),
(
"N=1M, p(a)=1%, p(b)=1%, p(c)=15%".to_string(),
1_000_000,
0.01,
0.01,
0.15,
),
];
let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
let mut runner = BenchRunner::new();
for (label, n, pa, pb, pc) in scenarios {
let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);
for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
{
// Single-field group: default field is body only
let mut group = runner.new_group();
group.set_name(format!("{}{}", view_name, label));
for query_str in queries {
add_bench_task(&mut group, &bench_index, query_str, Count, "count");
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_score(),
"top10",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_fast_field::<u64>("score", Order::Asc),
"top10_by_ff",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by((
SortByStaticFastValue::<u64>::for_field("score"),
SortByStaticFastValue::<u64>::for_field("score2"),
)),
"top10_by_2ff",
);
}
group.run();
}
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &self.collector).unwrap();
1
}
}

benches/exists_json.rs (new file, 69 lines)

@@ -0,0 +1,69 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use serde_json::json;
use tantivy::collector::Count;
use tantivy::query::ExistsQuery;
use tantivy::schema::{Schema, FAST, TEXT};
use tantivy::{doc, Index};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
fn main() {
let doc_count: usize = 500_000;
let subfield_counts: &[usize] = &[1, 2, 3, 4, 5, 6, 7, 8, 16, 256, 4096, 65536, 262144];
let indices: Vec<(String, Index)> = subfield_counts
.iter()
.map(|&sub_fields| {
(
format!("subfields={sub_fields}"),
build_index_with_json_subfields(doc_count, sub_fields),
)
})
.collect();
let mut group = InputGroup::new_with_inputs(indices);
group.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
group.config().num_iter_group = Some(1);
group.config().num_iter_bench = Some(1);
group.register("exists_json", exists_json_union);
group.run();
}
fn exists_json_union(index: &Index) {
let reader = index.reader().expect("reader");
let searcher = reader.searcher();
let query = ExistsQuery::new("json".to_string(), true);
let count = searcher.search(&query, &Count).expect("exists search");
// Prevents optimizer from eliding the search
black_box(count);
}
fn build_index_with_json_subfields(num_docs: usize, num_subfields: usize) -> Index {
// Schema: single JSON field stored as FAST to support ExistsQuery.
let mut schema_builder = Schema::builder();
let json_field = schema_builder.add_json_field("json", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_from_tempdir(schema).expect("create index");
{
let mut index_writer = index
.writer_with_num_threads(1, 200_000_000)
.expect("writer");
for i in 0..num_docs {
let sub = i % num_subfields;
// Only one subpath set per document; rotate subpaths so that
// no single subpath is full, but the union covers all docs.
let v = json!({ format!("field_{sub}"): i as u64 });
index_writer
.add_document(doc!(json_field => v))
.expect("add_document");
}
index_writer.commit().expect("commit");
}
index
}


@@ -48,7 +48,7 @@ impl BitPacker {
pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
if self.mini_buffer_written > 0 {
let num_bytes = (self.mini_buffer_written + 7) / 8;
let num_bytes = self.mini_buffer_written.div_ceil(8);
let bytes = self.mini_buffer.to_le_bytes();
output.write_all(&bytes[..num_bytes])?;
self.mini_buffer_written = 0;
@@ -138,7 +138,7 @@ impl BitUnpacker {
// We use `usize` here to avoid overflow issues.
let end_bit_read = (end_idx as usize) * self.num_bits;
let end_byte_read = (end_bit_read + 7) / 8;
let end_byte_read = end_bit_read.div_ceil(8);
assert!(
end_byte_read <= data.len(),
"Requested index is out of bounds."
@@ -258,7 +258,7 @@ mod test {
bitpacker.write(val, num_bits, &mut data).unwrap();
}
bitpacker.close(&mut data).unwrap();
assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8);
assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
let bitunpacker = BitUnpacker::new(num_bits);
(bitunpacker, vals, data)
}
@@ -304,7 +304,7 @@ mod test {
bitpacker.write(val, num_bits, &mut buffer).unwrap();
}
bitpacker.flush(&mut buffer).unwrap();
assert_eq!(buffer.len(), (vals.len() * num_bits as usize + 7) / 8);
assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
let bitunpacker = BitUnpacker::new(num_bits);
let max_val = if num_bits == 64 {
u64::MAX


@@ -140,10 +140,10 @@ impl BlockedBitpacker {
pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
// todo performance: we could decompress a whole block and cache it instead
let bitpacked_elems = self.offset_and_bits.len() * BLOCK_SIZE;
let iter = (0..bitpacked_elems)
(0..bitpacked_elems)
.map(move |idx| self.get(idx))
.chain(self.buffer.iter().cloned());
iter
.chain(self.buffer.iter().cloned())
}
}


@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
@@ -66,17 +66,19 @@ unsafe fn filter_vec_avx2_aux(
]);
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
for _ in 0..num_words {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
unsafe {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
}
output_tail.offset_from(output) as usize
unsafe { output_tail.offset_from(output) as usize }
}
#[inline]
@@ -92,8 +94,7 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
let too_low = op_greater(*range.start(), val);
let too_high = op_greater(val, *range.end());
let inside = op_or(too_low, too_high);
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
as u8
255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
}
union U8x32 {


@@ -73,7 +73,7 @@ The crate introduces the following concepts.
`Columnar` is an equivalent of a dataframe.
It maps `column_key` to `Column`.
A `Column<T>` asssociates a `RowId` (u32) to any
A `Column<T>` associates a `RowId` (u32) to any
number of values.
This is made possible by wrapping a `ColumnIndex` and a `ColumnValue` object.


@@ -89,13 +89,6 @@ fn main() {
black_box(sum);
});
group.register("first_block_fetch", |column| {
let mut block: Vec<Option<u64>> = vec![None; 64];
let fetch_docids = (0..64).collect::<Vec<_>>();
column.first_vals(&fetch_docids, &mut block);
black_box(block[0]);
});
group.register("first_block_single_calls", |column| {
let mut block: Vec<Option<u64>> = vec![None; 64];
let fetch_docids = (0..64).collect::<Vec<_>>();


@@ -29,12 +29,20 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
#[inline]
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
pub fn fetch_block_with_missing(
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing else {
return;
};
// We can compare docid_cache length with docs to find missing docs
// For multi value columns we can't rely on the length and always need to scan


@@ -131,6 +131,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
self.index.docids_to_rowids(doc_ids, doc_ids_out, row_ids)
}
/// Get an iterator over the values for the provided docid.
#[inline]
pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator<Item = T> + '_ {
self.index
.value_row_ids(doc_id)
@@ -158,15 +160,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
.select_batch_in_place(selected_docid_range.start, doc_ids);
}
/// Fills the output vector with the (possibly multiple values that are associated_with
/// `row_id`.
///
/// This method clears the `output` vector.
pub fn fill_vals(&self, row_id: RowId, output: &mut Vec<T>) {
output.clear();
output.extend(self.values_for_doc(row_id));
}
pub fn first_or_default_col(self, default_value: T) -> Arc<dyn ColumnValues<T>> {
Arc::new(FirstValueWithDefault {
column: self,


@@ -56,7 +56,7 @@ fn get_doc_ids_with_values<'a>(
ColumnIndex::Full => Box::new(doc_range),
ColumnIndex::Optional(optional_index) => Box::new(
optional_index
.iter_docs()
.iter_non_null_docs()
.map(move |row| row + doc_range.start),
),
ColumnIndex::Multivalued(multivalued_index) => match multivalued_index {
@@ -73,7 +73,7 @@ fn get_doc_ids_with_values<'a>(
MultiValueIndex::MultiValueIndexV2(multivalued_index) => Box::new(
multivalued_index
.optional_index
.iter_docs()
.iter_non_null_docs()
.map(move |row| row + doc_range.start),
),
},
@@ -105,10 +105,11 @@ fn get_num_values_iterator<'a>(
) -> Box<dyn Iterator<Item = u32> + 'a> {
match column_index {
ColumnIndex::Empty { .. } => Box::new(std::iter::empty()),
ColumnIndex::Full => Box::new(std::iter::repeat(1u32).take(num_docs as usize)),
ColumnIndex::Optional(optional_index) => {
Box::new(std::iter::repeat(1u32).take(optional_index.num_non_nulls() as usize))
}
ColumnIndex::Full => Box::new(std::iter::repeat_n(1u32, num_docs as usize)),
ColumnIndex::Optional(optional_index) => Box::new(std::iter::repeat_n(
1u32,
optional_index.num_non_nulls() as usize,
)),
ColumnIndex::Multivalued(multivalued_index) => Box::new(
multivalued_index
.get_start_index_column()
@@ -177,7 +178,7 @@ impl<'a> Iterable<RowId> for StackedOptionalIndex<'a> {
ColumnIndex::Full => Box::new(columnar_row_range),
ColumnIndex::Optional(optional_index) => Box::new(
optional_index
.iter_docs()
.iter_non_null_docs()
.map(move |row_id: RowId| columnar_row_range.start + row_id),
),
ColumnIndex::Multivalued(_) => {


@@ -215,6 +215,32 @@ impl MultiValueIndex {
}
}
/// Returns an iterator over document ids that have at least one value.
pub fn iter_non_null_docs(&self) -> Box<dyn Iterator<Item = DocId> + '_> {
match self {
MultiValueIndex::MultiValueIndexV1(idx) => {
let mut doc: DocId = 0u32;
let num_docs = idx.num_docs();
Box::new(std::iter::from_fn(move || {
// This is not the most efficient way to do this, but it's legacy code.
while doc < num_docs {
let cur = doc;
doc += 1;
let start = idx.start_index_column.get_val(cur);
let end = idx.start_index_column.get_val(cur + 1);
if end > start {
return Some(cur);
}
}
None
}))
}
MultiValueIndex::MultiValueIndexV2(idx) => {
Box::new(idx.optional_index.iter_non_null_docs())
}
}
}
/// Converts a list of ranks (row ids of values) in a 1:n index to the corresponding list of
/// docids. Positions are converted inplace to docids.
///


@@ -1,4 +1,4 @@
use std::io::{self, Write};
use std::io;
use std::sync::Arc;
mod set;
@@ -11,7 +11,7 @@ use set_block::{
};
use crate::iterable::Iterable;
use crate::{DocId, InvalidData, RowId};
use crate::{DocId, RowId};
The threshold for the number of elements after which we switch to dense block encoding.
///
@@ -88,7 +88,7 @@ pub struct OptionalIndex {
impl Iterable<u32> for &OptionalIndex {
fn boxed_iter(&self) -> Box<dyn Iterator<Item = u32> + '_> {
Box::new(self.iter_docs())
Box::new(self.iter_non_null_docs())
}
}
@@ -280,8 +280,9 @@ impl OptionalIndex {
self.num_non_null_docs
}
pub fn iter_docs(&self) -> impl Iterator<Item = RowId> + '_ {
// TODO optimize
pub fn iter_non_null_docs(&self) -> impl Iterator<Item = RowId> + '_ {
// TODO optimize. We could iterate over the blocks directly.
// We use the dense value ids and retrieve the doc ids via select.
let mut select_batch = self.select_cursor();
(0..self.num_non_null_docs).map(move |rank| select_batch.select(rank))
}
@@ -334,38 +335,6 @@ enum Block<'a> {
Sparse(SparseBlock<'a>),
}
#[derive(Debug, Copy, Clone)]
enum OptionalIndexCodec {
Dense = 0,
Sparse = 1,
}
impl OptionalIndexCodec {
fn to_code(self) -> u8 {
self as u8
}
fn try_from_code(code: u8) -> Result<Self, InvalidData> {
match code {
0 => Ok(Self::Dense),
1 => Ok(Self::Sparse),
_ => Err(InvalidData),
}
}
}
impl BinarySerializable for OptionalIndexCodec {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&[self.to_code()])
}
fn deserialize<R: io::Read>(reader: &mut R) -> io::Result<Self> {
let optional_codec_code = u8::deserialize(reader)?;
let optional_codec = Self::try_from_code(optional_codec_code)?;
Ok(optional_codec)
}
}
fn serialize_optional_index_block(block_els: &[u16], out: &mut impl io::Write) -> io::Result<()> {
let is_sparse = is_sparse(block_els.len() as u32);
if is_sparse {


@@ -164,7 +164,11 @@ fn test_optional_index_large() {
fn test_optional_index_iter_aux(row_ids: &[RowId], num_rows: RowId) {
let optional_index = OptionalIndex::for_test(num_rows, row_ids);
assert_eq!(optional_index.num_docs(), num_rows);
assert!(optional_index.iter_docs().eq(row_ids.iter().copied()));
assert!(
optional_index
.iter_non_null_docs()
.eq(row_ids.iter().copied())
);
}
#[test]


@@ -1,7 +1,7 @@
use std::fmt::Debug;
use std::net::Ipv6Addr;
/// Montonic maps a value to u128 value space
/// Monotonic maps a value to u128 value space
/// Monotonic mapping enables `PartialOrd` on u128 space without conversion to original space.
pub trait MonotonicallyMappableToU128: 'static + PartialOrd + Copy + Debug + Send + Sync {
/// Converts a value to u128.


@@ -185,10 +185,10 @@ impl CompactSpaceBuilder {
let mut covered_space = Vec::with_capacity(self.blanks.len());
// beginning of the blanks
if let Some(first_blank_start) = self.blanks.first().map(RangeInclusive::start) {
if *first_blank_start != 0 {
covered_space.push(0..=first_blank_start - 1);
}
if let Some(first_blank_start) = self.blanks.first().map(RangeInclusive::start)
&& *first_blank_start != 0
{
covered_space.push(0..=first_blank_start - 1);
}
// Between the blanks
@@ -202,10 +202,10 @@ impl CompactSpaceBuilder {
covered_space.extend(between_blanks);
// end of the blanks
if let Some(last_blank_end) = self.blanks.last().map(RangeInclusive::end) {
if *last_blank_end != u128::MAX {
covered_space.push(last_blank_end + 1..=u128::MAX);
}
if let Some(last_blank_end) = self.blanks.last().map(RangeInclusive::end)
&& *last_blank_end != u128::MAX
{
covered_space.push(last_blank_end + 1..=u128::MAX);
}
if covered_space.is_empty() {


@@ -105,7 +105,7 @@ impl ColumnCodecEstimator for BitpackedCodecEstimator {
fn estimate(&self, stats: &ColumnStats) -> Option<u64> {
let num_bits_per_value = num_bits(stats);
Some(stats.num_bytes() + (stats.num_rows as u64 * (num_bits_per_value as u64) + 7) / 8)
Some(stats.num_bytes() + (stats.num_rows as u64 * (num_bits_per_value as u64)).div_ceil(8))
}
fn serialize(


@@ -8,7 +8,7 @@ use crate::column_values::ColumnValues;
const MID_POINT: u64 = (1u64 << 32) - 1u64;
/// `Line` describes a line function `y: ax + b` using integer
/// arithmetics.
/// arithmetic.
///
/// The slope is in fact a decimal split into a 32 bit integer value,
/// and a 32-bit decimal value.
@@ -94,7 +94,7 @@ impl Line {
// `(i, ys[])`.
//
// The best intercept therefore has the form
// `y[i] - line.eval(i)` (using wrapping arithmetics).
// `y[i] - line.eval(i)` (using wrapping arithmetic).
// In other words, the best intercept is one of the `y - Line::eval(ys[i])`
// and our task is just to pick the one that minimizes our error.
//


@@ -117,7 +117,7 @@ impl ColumnCodecEstimator for LinearCodecEstimator {
Some(
stats.num_bytes()
+ linear_params.num_bytes()
+ (num_bits as u64 * stats.num_rows as u64 + 7) / 8,
+ (num_bits as u64 * stats.num_rows as u64).div_ceil(8),
)
}


@@ -52,7 +52,7 @@ pub trait ColumnCodecEstimator<T = u64>: 'static {
) -> io::Result<()>;
}
/// A column codec describes a colunm serialization format.
/// A column codec describes a column serialization format.
pub trait ColumnCodec<T: PartialOrd = u64> {
/// Specialized `ColumnValues` type.
type ColumnValues: ColumnValues<T> + 'static;


@@ -367,7 +367,7 @@ fn is_empty_after_merge(
ColumnIndex::Empty { .. } => true,
ColumnIndex::Full => alive_bitset.len() == 0,
ColumnIndex::Optional(optional_index) => {
for doc in optional_index.iter_docs() {
for doc in optional_index.iter_non_null_docs() {
if alive_bitset.contains(doc) {
return false;
}


@@ -244,7 +244,7 @@ impl SymbolValue for UnorderedId {
fn compute_num_bytes_for_u64(val: u64) -> usize {
let msb = (64u32 - val.leading_zeros()) as usize;
(msb + 7) / 8
msb.div_ceil(8)
}
fn encode_zig_zag(n: i64) -> u64 {


@@ -1,3 +1,5 @@
use std::str::FromStr;
use common::DateTime;
use crate::InvalidData;
@@ -9,6 +11,23 @@ pub enum NumericalValue {
F64(f64),
}
impl FromStr for NumericalValue {
type Err = ();
fn from_str(s: &str) -> Result<Self, ()> {
if let Ok(val_i64) = s.parse::<i64>() {
return Ok(val_i64.into());
}
if let Ok(val_u64) = s.parse::<u64>() {
return Ok(val_u64.into());
}
if let Ok(val_f64) = s.parse::<f64>() {
return Ok(NumericalValue::from(val_f64).normalize());
}
Err(())
}
}
impl NumericalValue {
pub fn numerical_type(&self) -> NumericalType {
match self {
@@ -26,7 +45,7 @@ impl NumericalValue {
if val <= i64::MAX as u64 {
NumericalValue::I64(val as i64)
} else {
NumericalValue::F64(val as f64)
NumericalValue::U64(val)
}
}
NumericalValue::I64(val) => NumericalValue::I64(val),
@@ -141,6 +160,7 @@ impl Coerce for DateTime {
#[cfg(test)]
mod tests {
use super::NumericalType;
use crate::NumericalValue;
#[test]
fn test_numerical_type_code() {
@@ -153,4 +173,58 @@ mod tests {
}
assert_eq!(num_numerical_type, 3);
}
#[test]
fn test_parse_numerical() {
assert_eq!(
"123".parse::<NumericalValue>().unwrap(),
NumericalValue::I64(123)
);
assert_eq!(
"18446744073709551615".parse::<NumericalValue>().unwrap(),
NumericalValue::U64(18446744073709551615u64)
);
assert_eq!(
"1.0".parse::<NumericalValue>().unwrap(),
NumericalValue::I64(1i64)
);
assert_eq!(
"1.1".parse::<NumericalValue>().unwrap(),
NumericalValue::F64(1.1f64)
);
assert_eq!(
"-1.0".parse::<NumericalValue>().unwrap(),
NumericalValue::I64(-1i64)
);
}
#[test]
fn test_normalize_numerical() {
assert_eq!(
NumericalValue::from(1u64).normalize(),
NumericalValue::I64(1i64),
);
let limit_val = i64::MAX as u64 + 1u64;
assert_eq!(
NumericalValue::from(limit_val).normalize(),
NumericalValue::U64(limit_val),
);
assert_eq!(
NumericalValue::from(-1i64).normalize(),
NumericalValue::I64(-1i64),
);
assert_eq!(
NumericalValue::from(-2.0f64).normalize(),
NumericalValue::I64(-2i64),
);
assert_eq!(
NumericalValue::from(-2.1f64).normalize(),
NumericalValue::F64(-2.1f64),
);
let large_float = 2.0f64.powf(70.0f64);
assert_eq!(
NumericalValue::from(large_float).normalize(),
NumericalValue::F64(large_float),
);
}
}


@@ -181,9 +181,17 @@ pub struct BitSet {
len: u64,
max_value: u32,
}
impl std::fmt::Debug for BitSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BitSet")
.field("len", &self.len)
.field("max_value", &self.max_value)
.finish()
}
}
fn num_buckets(max_val: u32) -> u32 {
(max_val + 63u32) / 64u32
max_val.div_ceil(64u32)
}
impl BitSet {


@@ -28,7 +28,9 @@ impl BinarySerializable for VIntU128 {
writer.write_all(&buffer)
}
#[allow(clippy::unbuffered_bytes)]
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u128;
let mut shift = 0u64;
@@ -195,7 +197,9 @@ impl BinarySerializable for VInt {
writer.write_all(&buffer[0..num_bytes])
}
#[allow(clippy::unbuffered_bytes)]
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u64;
let mut shift = 0u64;

Binary file not shown (image removed; previous size 653 KiB).


@@ -208,7 +208,7 @@ fn main() -> tantivy::Result<()> {
// is the role of the `TopDocs` collector.
// We can now perform our query.
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
// The actual documents still need to be
// retrieved from Tantivy's store.
@@ -226,7 +226,7 @@ fn main() -> tantivy::Result<()> {
let query = query_parser.parse_query("title:sea^20 body:whale^70")?;
let (_score, doc_address) = searcher
.search(&query, &TopDocs::with_limit(1))?
.search(&query, &TopDocs::with_limit(1).order_by_score())?
.into_iter()
.next()
.unwrap();

View File

@@ -100,7 +100,7 @@ fn main() -> tantivy::Result<()> {
// here we want to get a hit on the 'ken' in Frankenstein
let query = query_parser.parse_query("ken")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (_, doc_address) in top_docs {
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;

View File

@@ -50,14 +50,14 @@ fn main() -> tantivy::Result<()> {
{
// Simple exact search on the date
let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
assert_eq!(count_docs.len(), 1);
}
{
// Range query on the date field
let query = query_parser
.parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
assert_eq!(count_docs.len(), 1);
for (_score, doc_address) in count_docs {
let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;

View File

@@ -28,7 +28,7 @@ fn extract_doc_given_isbn(
// The second argument is here to tell we don't care about decoding positions,
// or term frequencies.
let term_query = TermQuery::new(isbn_term.clone(), IndexRecordOption::Basic);
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1))?;
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1).order_by_score())?;
if let Some((_score, doc_address)) = top_docs.first() {
let doc = searcher.doc(*doc_address)?;

View File

@@ -0,0 +1,212 @@
// # Filter Aggregation Example
//
// This example demonstrates filter aggregations - creating buckets of documents
// matching specific queries, with nested aggregations computed on each bucket.
//
// Filter aggregations are useful for computing metrics on different subsets of
// your data in a single query, like "average price overall + average price for
// electronics + count of in-stock items".
use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index};
fn main() -> tantivy::Result<()> {
// Create a simple product schema
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("category", TEXT | FAST);
schema_builder.add_text_field("brand", TEXT | FAST);
schema_builder.add_u64_field("price", FAST);
schema_builder.add_f64_field("rating", FAST);
schema_builder.add_bool_field("in_stock", FAST | INDEXED);
let schema = schema_builder.build();
// Create index and add sample products
let index = Index::create_in_ram(schema.clone());
let mut writer = index.writer(50_000_000)?;
writer.add_document(doc!(
schema.get_field("category")? => "electronics",
schema.get_field("brand")? => "apple",
schema.get_field("price")? => 999u64,
schema.get_field("rating")? => 4.5f64,
schema.get_field("in_stock")? => true
))?;
writer.add_document(doc!(
schema.get_field("category")? => "electronics",
schema.get_field("brand")? => "samsung",
schema.get_field("price")? => 799u64,
schema.get_field("rating")? => 4.2f64,
schema.get_field("in_stock")? => true
))?;
writer.add_document(doc!(
schema.get_field("category")? => "clothing",
schema.get_field("brand")? => "nike",
schema.get_field("price")? => 120u64,
schema.get_field("rating")? => 4.1f64,
schema.get_field("in_stock")? => false
))?;
writer.add_document(doc!(
schema.get_field("category")? => "books",
schema.get_field("brand")? => "penguin",
schema.get_field("price")? => 25u64,
schema.get_field("rating")? => 4.8f64,
schema.get_field("in_stock")? => true
))?;
writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Example 1: Basic filter with metric aggregation
println!("=== Example 1: Electronics average price ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": {
"avg_price": { "avg": { "field": "price" } }
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"avg_price": { "value": 899.0 }
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 2: Multiple independent filters
println!("=== Example 2: Multiple filters in one query ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": { "avg_price": { "avg": { "field": "price" } } }
},
"in_stock": {
"filter": "in_stock:true",
"aggs": { "count": { "value_count": { "field": "brand" } } }
},
"high_rated": {
"filter": "rating:[4.5 TO *]",
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"avg_price": { "value": 899.0 }
},
"in_stock": {
"doc_count": 3,
"count": { "value": 3.0 }
},
"high_rated": {
"doc_count": 2,
"count": { "value": 2.0 }
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 3: Nested filters - progressive refinement
println!("=== Example 3: Nested filters ===");
let agg_req = json!({
"in_stock": {
"filter": "in_stock:true",
"aggs": {
"electronics": {
"filter": "category:electronics",
"aggs": {
"expensive": {
"filter": "price:[800 TO *]",
"aggs": {
"avg_rating": { "avg": { "field": "rating" } }
}
}
}
}
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"in_stock": {
"doc_count": 3, // apple, samsung, penguin
"electronics": {
"doc_count": 2, // apple, samsung
"expensive": {
"doc_count": 1, // only apple (999)
"avg_rating": { "value": 4.5 }
}
}
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 4: Filter with sub-aggregation (terms)
println!("=== Example 4: Filter with terms sub-aggregation ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": {
"by_brand": {
"terms": { "field": "brand" },
"aggs": {
"avg_price": { "avg": { "field": "price" } }
}
}
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"by_brand": {
"buckets": [
{
"key": "samsung",
"doc_count": 1,
"avg_price": { "value": 799.0 }
},
{
"key": "apple",
"doc_count": 1,
"avg_price": { "value": 999.0 }
}
],
"sum_other_doc_count": 0,
"doc_count_error_upper_bound": 0
}
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}", serde_json::to_string_pretty(&result)?);
Ok(())
}

View File

@@ -85,7 +85,6 @@ fn main() -> tantivy::Result<()> {
index_writer.add_document(doc!(
title => "The Diary of a Young Girl",
))?;
index_writer.commit()?;
// ### Committing
//
@@ -146,7 +145,7 @@ fn main() -> tantivy::Result<()> {
let query = FuzzyTermQuery::new(term, 2, true);
let (top_docs, count) = searcher
.search(&query, &(TopDocs::with_limit(5), Count))
.search(&query, &(TopDocs::with_limit(5).order_by_score(), Count))
.unwrap();
assert_eq!(count, 3);
assert_eq!(top_docs.len(), 3);

View File

@@ -69,25 +69,25 @@ fn main() -> tantivy::Result<()> {
{
// Inclusive range queries
let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
assert_eq!(count_docs.len(), 1);
}
{
// Exclusive range queries
let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 0);
}
{
// Find docs with IP addresses smaller equal 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
// Find docs with IP addresses smaller than 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}

View File

@@ -59,12 +59,12 @@ fn main() -> tantivy::Result<()> {
let query_parser = QueryParser::for_index(&index, vec![event_type, attributes]);
{
let query = query_parser.parse_query("target:submit-button")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
let query = query_parser.parse_query("target:submit")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
@@ -74,33 +74,33 @@ fn main() -> tantivy::Result<()> {
}
{
let query = query_parser.parse_query("click AND cart.product_id:133")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 1);
}
{
// The sub-fields in the json field marked as default field still need to be explicitly
// addressed
let query = query_parser.parse_query("click AND 133")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 0);
}
{
// Default json fields are ignored if they collide with the schema
let query = query_parser.parse_query("event_type:holiday-sale")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 0);
}
// # Query via full attribute path
{
// This only searches in our schema's `event_type` field
let query = query_parser.parse_query("event_type:click")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 2);
}
{
// Default json fields can still be accessed by full path
let query = query_parser.parse_query("attributes.event_type:holiday-sale")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 1);
}
Ok(())

View File

@@ -63,7 +63,7 @@ fn main() -> Result<()> {
// but not "in the Gulf Stream".
let query = query_parser.parse_query("\"in the su\"*")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let mut titles = top_docs
.into_iter()
.map(|(_score, doc_address)| {

View File

@@ -107,7 +107,8 @@ fn main() -> tantivy::Result<()> {
IndexRecordOption::Basic,
);
let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
let (top_docs, count) =
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
assert_eq!(count, 2);
@@ -128,7 +129,8 @@ fn main() -> tantivy::Result<()> {
IndexRecordOption::Basic,
);
let (_top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
let (_top_docs, count) =
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
assert_eq!(count, 0);

View File

@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
let query_parser = QueryParser::for_index(&index, vec![title, body]);
let query = query_parser.parse_query("sycamore spring")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;

View File

@@ -102,7 +102,7 @@ fn main() -> tantivy::Result<()> {
// stop words are applied on the query as well.
// The following will be equivalent to `title:frankenstein`
let query = query_parser.parse_query("title:\"the Frankenstein\"")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (score, doc_address) in top_docs {
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;

View File

@@ -164,7 +164,7 @@ fn main() -> tantivy::Result<()> {
move |doc_id: DocId| Reverse(price[doc_id as usize])
};
let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price);
let most_expensive_first = TopDocs::with_limit(10).order_by(score_by_price);
let hits = searcher.search(&query, &most_expensive_first)?;
assert_eq!(

View File

@@ -15,3 +15,5 @@ edition = "2024"
nom = "7"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
ordered-float = "5.0.0"
fnv = "1.0.7"

View File

@@ -117,6 +117,22 @@ where F: nom::Parser<I, (O, ErrorList), Infallible> {
}
}
pub(crate) fn terminated_infallible<I, O1, O2, F, G>(
mut first: F,
mut second: G,
) -> impl FnMut(I) -> JResult<I, O1>
where
F: nom::Parser<I, (O1, ErrorList), Infallible>,
G: nom::Parser<I, (O2, ErrorList), Infallible>,
{
move |input: I| {
let (input, (o1, mut err)) = first.parse(input)?;
let (input, (_, mut err2)) = second.parse(input)?;
err.append(&mut err2);
Ok((input, (o1, err)))
}
}
pub(crate) fn delimited_infallible<I, O1, O2, O3, F, G, H>(
mut first: F,
mut second: G,

View File

@@ -31,7 +31,17 @@ pub fn parse_query_lenient(query: &str) -> (UserInputAst, Vec<LenientError>) {
#[cfg(test)]
mod tests {
use crate::{parse_query, parse_query_lenient};
use crate::{UserInputAst, parse_query, parse_query_lenient};
#[test]
fn test_deduplication() {
let ast: UserInputAst = parse_query("a a").unwrap();
let json = serde_json::to_string(&ast).unwrap();
assert_eq!(
json,
r#"{"type":"bool","clauses":[[null,{"type":"literal","field_name":null,"phrase":"a","delimiter":"none","slop":0,"prefix":false}]]}"#
);
}
#[test]
fn test_parse_query_serialization() {

View File

@@ -1,6 +1,7 @@
use std::borrow::Cow;
use std::iter::once;
use fnv::FnvHashSet;
use nom::IResult;
use nom::branch::alt;
use nom::bytes::complete::tag;
@@ -68,7 +69,7 @@ fn interpret_escape(source: &str) -> String {
/// Consume a word outside of any context.
// TODO should support escape sequences
fn word(inp: &str) -> IResult<&str, Cow<str>> {
fn word(inp: &str) -> IResult<&str, Cow<'_, str>> {
map_res(
recognize(tuple((
alt((
@@ -305,15 +306,14 @@ fn term_group_infallible(inp: &str) -> JResult<&str, UserInputAst> {
let (inp, (field_name, _, _, _)) =
tuple((field_name, multispace0, char('('), multispace0))(inp).expect("precondition failed");
let res = delimited_infallible(
delimited_infallible(
nothing,
map(ast_infallible, |(mut ast, errors)| {
ast.set_default_field(field_name.to_string());
(ast, errors)
}),
opt_i_err(char(')'), "expected ')'"),
)(inp);
res
)(inp)
}
fn exists(inp: &str) -> IResult<&str, UserInputLeaf> {
@@ -367,7 +367,10 @@ fn literal(inp: &str) -> IResult<&str, UserInputAst> {
// something (a field name) got parsed before
alt((
map(
tuple((opt(field_name), alt((range, set, exists, term_or_phrase)))),
tuple((
opt(field_name),
alt((range, set, exists, regex, term_or_phrase)),
)),
|(field_name, leaf): (Option<String>, UserInputLeaf)| leaf.set_field(field_name).into(),
),
term_group,
@@ -389,6 +392,10 @@ fn literal_no_group_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>>
value((), peek(one_of("{[><"))),
map(range_infallible, |(range, errs)| (Some(range), errs)),
),
(
value((), peek(one_of("/"))),
map(regex_infallible, |(regex, errs)| (Some(regex), errs)),
),
),
delimited_infallible(space0_infallible, term_or_phrase_infallible, nothing),
),
@@ -689,6 +696,61 @@ fn set_infallible(mut inp: &str) -> JResult<&str, UserInputLeaf> {
}
}
fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
map(
terminated(
delimited(
char('/'),
many1(alt((preceded(char('\\'), char('/')), none_of("/")))),
char('/'),
),
peek(alt((multispace1, eof))),
),
|elements| UserInputLeaf::Regex {
field: None,
pattern: elements.into_iter().collect::<String>(),
},
)(inp)
}
fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
match terminated_infallible(
delimited_infallible(
opt_i_err(char('/'), "missing delimiter /"),
opt_i(many1(alt((preceded(char('\\'), char('/')), none_of("/"))))),
opt_i_err(char('/'), "missing delimiter /"),
),
opt_i_err(
peek(alt((multispace1, eof))),
"expected whitespace or end of input",
),
)(inp)
{
Ok((rest, (elements_part, errors))) => {
let pattern = match elements_part {
Some(elements_part) => elements_part.into_iter().collect(),
None => String::new(),
};
let res = UserInputLeaf::Regex {
field: None,
pattern,
};
Ok((rest, (res, errors)))
}
Err(e) => {
let errs = vec![LenientErrorInternal {
pos: inp.len(),
message: e.to_string(),
}];
let res = UserInputLeaf::Regex {
field: None,
pattern: String::new(),
};
Ok((inp, (res, errs)))
}
}
}
fn negate(expr: UserInputAst) -> UserInputAst {
expr.unary(Occur::MustNot)
}
@@ -696,7 +758,17 @@ fn negate(expr: UserInputAst) -> UserInputAst {
fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
alt((
delimited(char('('), ast, char(')')),
map(char('*'), |_| UserInputAst::from(UserInputLeaf::All)),
map(
terminated(
char('*'),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
|_| UserInputAst::from(UserInputLeaf::All),
),
map(preceded(tuple((tag("NOT"), multispace1)), leaf), negate),
literal,
))(inp)
@@ -717,7 +789,17 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
),
),
(
value((), char('*')),
value(
(),
terminated(
char('*'),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
),
map(nothing, |_| {
(Some(UserInputAst::from(UserInputLeaf::All)), Vec::new())
}),
@@ -753,7 +835,7 @@ fn boosted_leaf(inp: &str) -> IResult<&str, UserInputAst> {
tuple((leaf, fallible(boost))),
|(leaf, boost_opt)| match boost_opt {
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => {
UserInputAst::Boost(Box::new(leaf), boost)
UserInputAst::Boost(Box::new(leaf), boost.into())
}
_ => leaf,
},
@@ -765,7 +847,7 @@ fn boosted_leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
tuple_infallible((leaf_infallible, boost)),
|((leaf, boost_opt), error)| match boost_opt {
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => (
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost)),
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost.into())),
error,
),
_ => (leaf, error),
@@ -1016,12 +1098,25 @@ pub fn parse_to_ast_lenient(query_str: &str) -> (UserInputAst, Vec<LenientError>
(rewrite_ast(res), errors)
}
/// Removes unnecessary children clauses in AST
///
/// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
fn rewrite_ast(mut input: UserInputAst) -> UserInputAst {
if let UserInputAst::Clause(terms) = &mut input {
for term in terms {
if let UserInputAst::Clause(sub_clauses) = &mut input {
// call rewrite_ast recursively on children clauses if applicable
let mut new_clauses = Vec::with_capacity(sub_clauses.len());
for (occur, clause) in sub_clauses.drain(..) {
let rewritten_clause = rewrite_ast(clause);
new_clauses.push((occur, rewritten_clause));
}
*sub_clauses = new_clauses;
// remove duplicate child clauses
// e.g. (+a +b) OR (+c +d) OR (+a +b) => (+a +b) OR (+c +d)
let mut seen = FnvHashSet::default();
sub_clauses.retain(|term| seen.insert(term.clone()));
// Removes unnecessary children clauses in AST
//
// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
for term in sub_clauses {
rewrite_ast_clause(term);
}
}
@@ -1596,6 +1691,21 @@ mod test {
test_parse_query_to_ast_helper("abc:a b", "(*\"abc\":a *b)");
test_parse_query_to_ast_helper("abc:\"a b\"", "\"abc\":\"a b\"");
test_parse_query_to_ast_helper("foo:[1 TO 5]", "\"foo\":[\"1\" TO \"5\"]");
// Phrase prefixed with *
test_parse_query_to_ast_helper("foo:(*A)", "\"foo\":*A");
test_parse_query_to_ast_helper("*A", "*A");
test_parse_query_to_ast_helper("(*A)", "*A");
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
}
#[test]
fn test_parse_query_all() {
test_parse_query_to_ast_helper("*", "*");
test_parse_query_to_ast_helper("(*)", "*");
test_parse_query_to_ast_helper("(* )", "*");
}
#[test]
@@ -1694,6 +1804,63 @@ mod test {
test_is_parse_err(r#"!bc:def"#, "!bc:def");
}
#[test]
fn test_regex_parser() {
let r = parse_to_ast(r#"a:/joh?n(ath[oa]n)/"#);
assert!(r.is_ok(), "Failed to parse custom query: {r:?}");
let (_, input) = r.unwrap();
match input {
UserInputAst::Leaf(leaf) => match leaf.as_ref() {
UserInputLeaf::Regex { field, pattern } => {
assert_eq!(field, &Some("a".to_string()));
assert_eq!(pattern, "joh?n(ath[oa]n)");
}
_ => panic!("Expected a regex leaf, got {leaf:?}"),
},
_ => panic!("Expected a leaf"),
}
let r = parse_to_ast(r#"a:/\\/cgi-bin\\/luci.*/"#);
assert!(r.is_ok(), "Failed to parse custom query: {r:?}");
let (_, input) = r.unwrap();
match input {
UserInputAst::Leaf(leaf) => match leaf.as_ref() {
UserInputLeaf::Regex { field, pattern } => {
assert_eq!(field, &Some("a".to_string()));
assert_eq!(pattern, "\\/cgi-bin\\/luci.*");
}
_ => panic!("Expected a regex leaf, got {leaf:?}"),
},
_ => panic!("Expected a leaf"),
}
}
#[test]
fn test_regex_parser_lenient() {
let literal = |query| literal_infallible(query).unwrap().1;
let (res, errs) = literal(r#"a:/joh?n(ath[oa]n)/"#);
let expected = UserInputLeaf::Regex {
field: Some("a".to_string()),
pattern: "joh?n(ath[oa]n)".to_string(),
}
.into();
assert_eq!(res.unwrap(), expected);
assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
let (res, errs) = literal("title:/joh?n(ath[oa]n)");
let expected = UserInputLeaf::Regex {
field: Some("title".to_string()),
pattern: "joh?n(ath[oa]n)".to_string(),
}
.into();
assert_eq!(res.unwrap(), expected);
assert_eq!(errs.len(), 1, "Expected 1 error, got: {errs:?}");
assert_eq!(
errs[0].message, "missing delimiter /",
"Unexpected error message",
);
}
#[test]
fn test_space_before_value() {
test_parse_query_to_ast_helper("field : a", r#""field":a"#);

View File

@@ -5,7 +5,7 @@ use serde::Serialize;
use crate::Occur;
#[derive(PartialEq, Clone, Serialize)]
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum UserInputLeaf {
@@ -23,6 +23,10 @@ pub enum UserInputLeaf {
Exists {
field: String,
},
Regex {
field: Option<String>,
pattern: String,
},
}
impl UserInputLeaf {
@@ -46,6 +50,7 @@ impl UserInputLeaf {
UserInputLeaf::Exists { field: _ } => UserInputLeaf::Exists {
field: field.expect("Exist query without a field isn't allowed"),
},
UserInputLeaf::Regex { field: _, pattern } => UserInputLeaf::Regex { field, pattern },
}
}
@@ -103,11 +108,19 @@ impl Debug for UserInputLeaf {
UserInputLeaf::Exists { field } => {
write!(formatter, "$exists(\"{field}\")")
}
UserInputLeaf::Regex { field, pattern } => {
if let Some(field) = field {
// TODO properly escape field (in case of \")
write!(formatter, "\"{field}\":")?;
}
// TODO properly escape pattern (in case of \")
write!(formatter, "/{pattern}/")
}
}
}
}
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Delimiter {
SingleQuotes,
@@ -115,7 +128,7 @@ pub enum Delimiter {
None,
}
#[derive(PartialEq, Clone, Serialize)]
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct UserInputLiteral {
pub field_name: Option<String>,
@@ -154,7 +167,7 @@ impl fmt::Debug for UserInputLiteral {
}
}
#[derive(PartialEq, Debug, Clone, Serialize)]
#[derive(PartialEq, Eq, Hash, Debug, Clone, Serialize)]
#[serde(tag = "type", content = "value")]
#[serde(rename_all = "snake_case")]
pub enum UserInputBound {
@@ -191,11 +204,11 @@ impl UserInputBound {
}
}
#[derive(PartialEq, Clone, Serialize)]
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[serde(into = "UserInputAstSerde")]
pub enum UserInputAst {
Clause(Vec<(Option<Occur>, UserInputAst)>),
Boost(Box<UserInputAst>, f64),
Boost(Box<UserInputAst>, ordered_float::OrderedFloat<f64>),
Leaf(Box<UserInputLeaf>),
}
@@ -217,9 +230,10 @@ impl From<UserInputAst> for UserInputAstSerde {
fn from(ast: UserInputAst) -> Self {
match ast {
UserInputAst::Clause(clause) => UserInputAstSerde::Bool { clauses: clause },
UserInputAst::Boost(underlying, boost) => {
UserInputAstSerde::Boost { underlying, boost }
}
UserInputAst::Boost(underlying, boost) => UserInputAstSerde::Boost {
underlying,
boost: boost.into_inner(),
},
UserInputAst::Leaf(leaf) => UserInputAstSerde::Leaf(leaf),
}
}
@@ -378,7 +392,7 @@ mod tests {
#[test]
fn test_boost_serialization() {
let inner_ast = UserInputAst::Leaf(Box::new(UserInputLeaf::All));
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5);
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5.into());
let json = serde_json::to_string(&boost_ast).unwrap();
assert_eq!(
json,
@@ -405,7 +419,7 @@ mod tests {
}))),
),
])),
2.5,
2.5.into(),
);
let json = serde_json::to_string(&boost_ast).unwrap();
assert_eq!(

View File

@@ -20,17 +20,16 @@ Contains all metric aggregations, like average aggregation. Metric aggregations
#### agg_req
agg_req contains the user's aggregation request. Deserialization from json is compatible with elasticsearch aggregation requests.
#### agg_req_with_accessor
agg_req_with_accessor contains the user's aggregation request enriched with fast field accessors etc., which are
#### agg_data
agg_data contains the user's aggregation request enriched with fast field accessors etc., which are
used during collection.
#### segment_agg_result
segment_agg_result contains the aggregation result tree, which is used for collection of a segment.
The tree from agg_req_with_accessor is passed during collection.
agg_data is passed during collection.
#### intermediate_agg_result
intermediate_agg_result contains the aggregation tree for merging with other trees.
#### agg_result
agg_result contains the final aggregation tree.
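A minimal sketch of how these pieces fit together at query time; the request JSON and field name are illustrative, and the calls mirror the aggregation examples elsewhere in this diff:

use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::agg_result::AggregationResults;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::Searcher;

fn run_agg(searcher: &Searcher) -> tantivy::Result<AggregationResults> {
    // agg_req: the user's request, deserialized from Elasticsearch-compatible JSON.
    let agg: Aggregations = serde_json::from_value(serde_json::json!({
        "avg_price": { "avg": { "field": "price" } }
    }))?;
    // agg_data and segment_agg_result are built internally, per segment, by the collector.
    let collector = AggregationCollector::from_aggs(agg, Default::default());
    // The per-segment trees are merged as intermediate_agg_result,
    // and the final agg_result tree is what this search call returns.
    searcher.search(&AllQuery, &collector)
}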

View File

@@ -0,0 +1,105 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::io;
use columnar::{Column, ColumnType};
use crate::aggregation::{f64_to_fastfield_u64, Key};
use crate::index::SegmentReader;
/// Get the missing value as internal u64 representation
///
/// For terms we use u64::MAX as sentinel value
/// For numerical data we convert the value into the representation
/// we would get from the fast field, when we open it as u64_lenient_for_type.
///
/// That way we can use it the same way as if it would come from the fastfield.
pub(crate) fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
column_max_value: u64,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}
// NOTE: We may lose precision of the passed missing value by casting i64 and u64 to f64.
Key::I64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
Key::U64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
_ => {
return Err(crate::TantivyError::InvalidArgument(format!(
"Missing value {missing:?} for field {field_name} is not supported for column \
type {column_type:?}"
)));
}
};
Ok(missing_val)
}
pub(crate) fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
&[
ColumnType::F64,
ColumnType::U64,
ColumnType::I64,
ColumnType::DateTime,
]
}
/// Get fast field reader or empty as default.
pub(crate) fn get_ff_reader(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
let ff_fields = reader.fast_fields();
let ff_field_with_type = ff_fields
.u64_lenient_for_type(allowed_column_types, field_name)?
.unwrap_or_else(|| {
(
Column::build_empty_column(reader.num_docs()),
ColumnType::U64,
)
});
Ok(ff_field_with_type)
}
pub(crate) fn get_dynamic_columns(
reader: &SegmentReader,
field_name: &str,
) -> crate::Result<Vec<columnar::DynamicColumn>> {
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
let cols = ff_fields
.iter()
.map(|h| h.open())
.collect::<io::Result<_>>()?;
assert!(!ff_fields.is_empty(), "field {field_name} not found");
Ok(cols)
}
/// Get all fast field readers, or an empty one as default.
///
/// Is guaranteed to return at least one column.
pub(crate) fn get_all_ff_reader_or_empty(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
let ff_fields = reader.fast_fields();
let mut ff_field_with_type =
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
if ff_field_with_type.is_empty() {
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
}
Ok(ff_field_with_type)
}
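To make the sentinel convention above concrete, here is a small sketch (written as if inside the crate, since the helper is pub(crate); the max value 41 and the field name are made up):

fn missing_sentinel_example() -> crate::Result<()> {
    // Str column: the missing key maps to one past the column's max value,
    // a term ordinal that no real document can have.
    let sentinel = get_missing_val_as_u64_lenient(
        ColumnType::Str,
        41,                      // column_max_value (illustrative)
        &Key::Str("N/A".into()), // user-provided `missing` key
        "brand",                 // field name (illustrative)
    )?;
    assert_eq!(sentinel, Some(42));
    // Numerical columns are routed through f64_to_fastfield_u64 instead, so the
    // sentinel compares like any other value read from the fast field.
    Ok(())
}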

src/aggregation/agg_data.rs (new file, 1071 lines): diff suppressed because it is too large.

View File

@@ -35,6 +35,7 @@ pub struct AggregationLimitsGuard {
/// Allocated memory with this guard.
allocated_with_the_guard: u64,
}
impl Clone for AggregationLimitsGuard {
fn clone(&self) -> Self {
Self {
@@ -70,7 +71,7 @@ impl AggregationLimitsGuard {
/// *memory_limit*
/// memory_limit is defined in bytes.
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
/// memory_limit.
/// memory_limit.
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
///
/// *bucket_limit*

View File

@@ -26,12 +26,14 @@
//! let _agg_req: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
//! ```
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -43,7 +45,7 @@ use super::metric::{
/// defined names. It is also used in buckets aggregations to define sub-aggregations.
///
/// The key is the user defined name of the aggregation.
pub type Aggregations = HashMap<String, Aggregation>;
pub type Aggregations = FxHashMap<String, Aggregation>;
/// Aggregation request.
///
@@ -129,6 +131,9 @@ pub enum AggregationVariants {
/// Put data into buckets of terms.
#[serde(rename = "terms")]
Terms(TermsAggregation),
/// Filter documents into a single bucket.
#[serde(rename = "filter")]
Filter(FilterAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
@@ -174,6 +179,7 @@ impl AggregationVariants {
AggregationVariants::Range(range) => vec![range.field.as_str()],
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
@@ -208,13 +214,6 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregationReq> {
match &self {
AggregationVariants::TopHits(top_hits) => Some(top_hits),
_ => None,
}
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

View File

@@ -1,471 +0,0 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::collections::HashMap;
use std::io;
use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn};
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
MaxAggregation, MinAggregation, StatsAggregation, SumAggregation,
};
use super::segment_agg_result::AggregationLimitsGuard;
use super::VecWithNames;
use crate::aggregation::{f64_to_fastfield_u64, Key};
use crate::index::SegmentReader;
use crate::SegmentOrdinal;
#[derive(Default)]
pub(crate) struct AggregationsWithAccessor {
pub aggs: VecWithNames<AggregationWithAccessor>,
}
impl AggregationsWithAccessor {
fn from_data(aggs: VecWithNames<AggregationWithAccessor>) -> Self {
Self { aggs }
}
pub fn is_empty(&self) -> bool {
self.aggs.is_empty()
}
}
pub struct AggregationWithAccessor {
pub(crate) segment_ordinal: SegmentOrdinal,
/// In general there can be buckets without fast field access, e.g. buckets that are created
/// based on search terms. That is not that case currently, but eventually this needs to be
/// Option or moved.
pub(crate) accessor: Column<u64>,
/// Load insert u64 for missing use case
pub(crate) missing_value_for_accessor: Option<u64>,
pub(crate) str_dict_column: Option<StrColumn>,
pub(crate) field_type: ColumnType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) limits: AggregationLimitsGuard,
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// Used for missing term aggregation, which checks all columns for existence.
/// And also for `top_hits` aggregation, which may sort on multiple fields.
/// By convention the missing aggregation is chosen, when this property is set
/// (instead bein set in `agg`).
/// If this needs to used by other aggregations, we need to refactor this.
// NOTE: we can make all other aggregations use this instead of the `accessor` and `field_type`
// (making them obsolete) But will it have a performance impact?
pub(crate) accessors: Vec<(Column<u64>, ColumnType)>,
/// Map field names to all associated column accessors.
/// This field is used for `docvalue_fields`, which is currently only supported for `top_hits`.
pub(crate) value_accessors: HashMap<String, Vec<DynamicColumn>>,
pub(crate) agg: Aggregation,
}
impl AggregationWithAccessor {
/// May return multiple accessors if the aggregation is e.g. on mixed field types.
fn try_from_agg(
agg: &Aggregation,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
limits: AggregationLimitsGuard,
) -> crate::Result<Vec<AggregationWithAccessor>> {
let mut agg = agg.clone();
let add_agg_with_accessor = |agg: &Aggregation,
accessor: Column<u64>,
column_type: ColumnType,
aggs: &mut Vec<AggregationWithAccessor>|
-> crate::Result<()> {
let res = AggregationWithAccessor {
segment_ordinal,
accessor,
accessors: Default::default(),
value_accessors: Default::default(),
field_type: column_type,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
limits: limits.clone(),
missing_value_for_accessor: None,
str_dict_column: None,
column_block_accessor: Default::default(),
};
aggs.push(res);
Ok(())
};
let add_agg_with_accessors = |agg: &Aggregation,
accessors: Vec<(Column<u64>, ColumnType)>,
aggs: &mut Vec<AggregationWithAccessor>,
value_accessors: HashMap<String, Vec<DynamicColumn>>|
-> crate::Result<()> {
let (accessor, field_type) = accessors.first().expect("at least one accessor");
let limits = limits.clone();
let res = AggregationWithAccessor {
segment_ordinal,
// TODO: We should do away with the `accessor` field altogether
accessor: accessor.clone(),
value_accessors,
field_type: *field_type,
accessors,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
limits,
missing_value_for_accessor: None,
str_dict_column: None,
column_block_accessor: Default::default(),
};
aggs.push(res);
Ok(())
};
let mut res: Vec<AggregationWithAccessor> = Vec::new();
use AggregationVariants::*;
match agg.agg {
Range(RangeAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Histogram(HistogramAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
DateHistogram(DateHistogramAggregationReq {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
// Only DateTime is supported for DateHistogram
get_ff_reader(reader, field_name, Some(&[ColumnType::DateTime]))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Terms(TermsAggregation {
field: ref field_name,
ref missing,
..
})
| Cardinality(CardinalityAggregationReq {
field: ref field_name,
ref missing,
..
}) => {
let str_dict_column = reader.fast_fields().str(field_name)?;
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
// In case the column is empty we want the shim column to match the missing type
let fallback_type = missing
.as_ref()
.map(|missing| match missing {
Key::Str(_) => ColumnType::Str,
Key::F64(_) => ColumnType::F64,
Key::I64(_) => ColumnType::I64,
Key::U64(_) => ColumnType::U64,
})
.unwrap_or(ColumnType::U64);
let column_and_types = get_all_ff_reader_or_empty(
reader,
field_name,
Some(&allowed_column_types),
fallback_type,
)?;
let missing_and_more_than_one_col = column_and_types.len() > 1 && missing.is_some();
let text_on_non_text_col = column_and_types.len() == 1
&& column_and_types[0].1.numerical_type().is_some()
&& missing
.as_ref()
.map(|m| matches!(m, Key::Str(_)))
.unwrap_or(false);
// Actually we could convert the text to a number and have the fast path, if it is
// provided in Rfc3339 format. But this use case is probably common
// enough to justify the effort.
let text_on_date_col = column_and_types.len() == 1
&& column_and_types[0].1 == ColumnType::DateTime
&& missing
.as_ref()
.map(|m| matches!(m, Key::Str(_)))
.unwrap_or(false);
let use_special_missing_agg =
missing_and_more_than_one_col || text_on_non_text_col || text_on_date_col;
if use_special_missing_agg {
let column_and_types =
get_all_ff_reader_or_empty(reader, field_name, None, fallback_type)?;
let accessors = column_and_types
.iter()
.map(|c_t| (c_t.0.clone(), c_t.1))
.collect();
add_agg_with_accessors(&agg, accessors, &mut res, Default::default())?;
}
for (accessor, column_type) in column_and_types {
let missing_value_term_agg = if use_special_missing_agg {
None
} else {
missing.clone()
};
let missing_value_for_accessor =
if let Some(missing) = missing_value_term_agg.as_ref() {
get_missing_val_as_u64_lenient(
column_type,
missing,
agg.agg.get_fast_field_names()[0],
)?
} else {
None
};
let limits = limits.clone();
let agg = AggregationWithAccessor {
segment_ordinal,
missing_value_for_accessor,
accessor,
accessors: Default::default(),
value_accessors: Default::default(),
field_type: column_type,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
str_dict_column: str_dict_column.clone(),
limits,
column_block_accessor: Default::default(),
};
res.push(agg);
}
}
Average(AverageAggregation {
field: ref field_name,
..
})
| Max(MaxAggregation {
field: ref field_name,
..
})
| Min(MinAggregation {
field: ref field_name,
..
})
| Stats(StatsAggregation {
field: ref field_name,
..
})
| ExtendedStats(ExtendedStatsAggregation {
field: ref field_name,
..
})
| Sum(SumAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Count(CountAggregation {
field: ref field_name,
..
}) => {
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(&allowed_column_types))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Percentiles(ref percentiles) => {
let (accessor, column_type) = get_ff_reader(
reader,
percentiles.field_name(),
Some(get_numeric_or_date_column_types()),
)?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
TopHits(ref mut top_hits) => {
top_hits.validate_and_resolve_field_names(reader.fast_fields().columnar())?;
let accessors: Vec<(Column<u64>, ColumnType)> = top_hits
.field_names()
.iter()
.map(|field| {
get_ff_reader(reader, field, Some(get_numeric_or_date_column_types()))
})
.collect::<crate::Result<_>>()?;
let value_accessors = top_hits
.value_field_names()
.iter()
.map(|field_name| {
Ok((
field_name.to_string(),
get_dynamic_columns(reader, field_name)?,
))
})
.collect::<crate::Result<_>>()?;
add_agg_with_accessors(&agg, accessors, &mut res, value_accessors)?;
}
};
Ok(res)
}
}
/// Get the missing value as internal u64 representation
///
/// For terms we use u64::MAX as sentinel value
/// For numerical data we convert the value into the representation
/// we would get from the fast field, when we open it as u64_lenient_for_type.
///
/// That way we can use it the same way as if it would come from the fastfield.
fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}
// NOTE: We may loose precision of the passed missing value by casting i64 and u64 to f64.
Key::I64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
Key::U64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
_ => {
return Err(crate::TantivyError::InvalidArgument(format!(
"Missing value {missing:?} for field {field_name} is not supported for column \
type {column_type:?}"
)));
}
};
Ok(missing_val)
}
fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
&[
ColumnType::F64,
ColumnType::U64,
ColumnType::I64,
ColumnType::DateTime,
]
}
pub(crate) fn get_aggs_with_segment_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
limits: &AggregationLimitsGuard,
) -> crate::Result<AggregationsWithAccessor> {
let mut aggss = Vec::new();
for (key, agg) in aggs.iter() {
let aggs = AggregationWithAccessor::try_from_agg(
agg,
agg.sub_aggregation(),
reader,
segment_ordinal,
limits.clone(),
)?;
for agg in aggs {
aggss.push((key.to_string(), agg));
}
}
Ok(AggregationsWithAccessor::from_data(
VecWithNames::from_entries(aggss),
))
}
/// Get fast field reader or empty as default.
fn get_ff_reader(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
let ff_fields = reader.fast_fields();
let ff_field_with_type = ff_fields
.u64_lenient_for_type(allowed_column_types, field_name)?
.unwrap_or_else(|| {
(
Column::build_empty_column(reader.num_docs()),
ColumnType::U64,
)
});
Ok(ff_field_with_type)
}
fn get_dynamic_columns(
reader: &SegmentReader,
field_name: &str,
) -> crate::Result<Vec<columnar::DynamicColumn>> {
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
let cols = ff_fields
.iter()
.map(|h| h.open())
.collect::<io::Result<_>>()?;
assert!(!ff_fields.is_empty(), "field {field_name} not found");
Ok(cols)
}
/// Get all fast field reader or empty as default.
///
/// Is guaranteed to return at least one column.
fn get_all_ff_reader_or_empty(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
let ff_fields = reader.fast_fields();
let mut ff_field_with_type =
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
if ff_field_with_type.is_empty() {
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
}
Ok(ff_field_with_type)
}

View File

@@ -16,7 +16,7 @@ use super::{AggregationError, Key};
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
/// The final aggegation result.
/// The final aggregation result.
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
impl AggregationResults {
@@ -156,6 +156,8 @@ pub enum BucketResult {
/// The upper bound error for the doc count of each term.
doc_count_error_upper_bound: Option<u64>,
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
}
impl BucketResult {
@@ -172,6 +174,11 @@ impl BucketResult {
sum_other_doc_count: _,
doc_count_error_upper_bound: _,
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
BucketResult::Filter(filter_result) => {
// Filter doesn't add to bucket count - it's not a user-facing bucket
// Only count sub-aggregation buckets
filter_result.sub_aggregations.get_bucket_count()
}
}
}
}
@@ -308,3 +315,25 @@ impl RangeBucketEntry {
1 + self.sub_aggregation.get_bucket_count()
}
}
/// This is the filter bucket result, which contains the document count and sub-aggregations.
///
/// # JSON Format
/// ```json
/// {
/// "electronics_only": {
/// "doc_count": 2,
/// "avg_price": {
/// "value": 150.0
/// }
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FilterBucketResult {
/// Number of documents in the filter bucket
pub doc_count: u64,
/// Sub-aggregation results
#[serde(flatten)]
pub sub_aggregations: AggregationResults,
}

View File

@@ -2,16 +2,441 @@ use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::segment_agg_result::AggregationLimitsGuard;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};
// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);
// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);
Ok(())
}
#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 5 buckets
// Child: coarse range with 3 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 2, "from": 20.0}
]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: range with 4 buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several ranges
// Child: histogram with large interval (single bucket per parent)
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text -> 2 buckets
// Child: histogram with small interval -> multiple buckets including empties
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 4},
{"key": 10.0, "doc_count": 2},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 1},
{"key": 10.0, "doc_count": 0},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several buckets
// Child: date_histogram with 30d -> single bucket per parent
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
// Verify each parent bucket has exactly one child date bucket with matching doc_count
for bucket in buckets {
let parent_count = bucket["doc_count"].as_u64().unwrap();
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(child_buckets.len(), 1);
assert_eq!(child_buckets[0]["doc_count"], parent_count);
}
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: date_histogram with 1d -> multiple buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
// cool bucket
assert_eq!(buckets[0]["key"], "cool");
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(cool_buckets.len(), 3);
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
// nohit bucket
assert_eq!(buckets[1]["key"], "nohit");
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(nohit_buckets.len(), 2);
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
Ok(())
}
fn get_avg_req(field_name: &str) -> Aggregation {
serde_json::from_value(json!({
"avg": {
@@ -26,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushing part of these tests is outdated, since the buffering changed after converting
// the collection into one collector per request instead of one per bucket.
//
// However, they are still useful as they test complex aggregation requests.
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
@@ -38,8 +467,9 @@ fn test_aggregation_flushing(
let reader = index.reader()?;
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
// In the tree we cache up to COLLECT_BLOCK_BUFFER_LEN documents before passing them down as
// one block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)
@@ -128,10 +558,8 @@ fn test_aggregation_flushing(
.unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
let collector = DistributedAggregationCollector::from_aggs(
agg_req.clone(),
AggregationLimitsGuard::default(),
);
let collector =
DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
let searcher = reader.searcher();
let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();
@@ -155,7 +583,7 @@ fn test_aggregation_flushing(
searcher.search(&AllQuery, &collector).unwrap()
};
let res: Value = serde_json::to_value(&agg_res)?;
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(res["bucketsL1"]["buckets"][0]["doc_count"], 3);
assert_eq!(
@@ -270,7 +698,7 @@ fn test_aggregation_level1_simple() -> crate::Result<()> {
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
let res: Value = serde_json::to_value(&agg_res)?;
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(res["average"]["value"], 12.142857142857142);
assert_eq!(
res["range"]["buckets"],
@@ -304,29 +732,6 @@ fn test_aggregation_level1_simple() -> crate::Result<()> {
Ok(())
}
#[test]
fn test_aggregation_term_truncate_sum_other_doc_count() {
let index = get_test_index_2_segments(true).unwrap();
let reader = index.reader().unwrap();
let count_per_text: Aggregation = serde_json::from_value(json!({ "terms": { "field": "text", "size": 1 } })).unwrap();
let aggs: Aggregations = vec![("group_by_term_truncate".to_string(), count_per_text)]
.into_iter()
.collect();
let collector = get_collector(aggs);
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let res: Value = serde_json::to_value(&agg_res).unwrap();
assert_eq!(res, serde_json::json!({
"group_by_term_truncate": {
"buckets": [{ "doc_count": 7, "key": "cool" }],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 2,
},
}));
}
#[test]
fn test_aggregation_level1() -> crate::Result<()> {
let index = get_test_index_2_segments(true)?;
@@ -365,7 +770,7 @@ fn test_aggregation_level1() -> crate::Result<()> {
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
let res: Value = serde_json::to_value(&agg_res)?;
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(res["average"]["value"], 12.142857142857142);
assert_eq!(res["average_f64"]["value"], 12.214285714285714);
assert_eq!(res["average_i64"]["value"], 12.142857142857142);
@@ -420,7 +825,7 @@ fn test_aggregation_level2(
IndexRecordOption::Basic,
);
let elasticsearch_compatible_json_req = serde_json::json!(
let elasticsearch_compatible_json_req = r#"
{
"rangef64": {
"range": {
@@ -473,8 +878,9 @@ fn test_aggregation_level2(
"term_agg": { "terms": { "field": "text" } }
}
}
});
let agg_req: Aggregations = serde_json::from_value(elasticsearch_compatible_json_req).unwrap();
}
"#;
let agg_req: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
let collector =
@@ -491,7 +897,7 @@ fn test_aggregation_level2(
searcher.search(&term_query, &collector).unwrap()
};
let res: Value = serde_json::to_value(agg_res)?;
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(res["range"]["buckets"][1]["key"], "3-7");
assert_eq!(res["range"]["buckets"][1]["doc_count"], 2u64);

File diff suppressed because it is too large


@@ -1,25 +1,49 @@
use std::cmp::Ordering;
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;
/// Contains all information required by the SegmentHistogramCollector to perform the
/// histogram or date_histogram aggregation on a segment.
pub struct HistogramAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The name of the aggregation.
pub name: String,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
pub is_date_histogram: bool,
/// The bounds to limit the buckets to.
pub bounds: HistogramBounds,
/// The offset used to calculate the bucket position.
pub offset: f64,
}
impl HistogramAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
/// Each document value is rounded down to its bucket.
///
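// A minimal sketch of this bucketing rule (assumed from the `interval`/`offset`
// semantics used in this file; `get_bucket_pos_f64` and `get_bucket_key_from_pos`
// below are the real implementation):
fn sketch_bucket_key(val: f64, interval: f64, offset: f64) -> f64 {
    let bucket_pos = ((val - offset) / interval).floor();
    bucket_pos * interval + offset
}
// E.g. with interval = 10.0 and offset = 0.0, a value of 17.0 falls into the bucket with key 10.0.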
@@ -228,18 +252,24 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
agg_with_accessor: &AggregationsWithAccessor,
sub_aggregation: &mut Option<CachedSubAggs>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?;
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -249,31 +279,38 @@ impl SegmentHistogramBucketEntry {
}
}
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
column_type: ColumnType,
interval: f64,
offset: f64,
bounds: HistogramBounds,
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<CachedSubAggs>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?;
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -282,72 +319,77 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = self.bounds;
let interval = self.interval;
let offset = self.offset;
let get_bucket_pos = |val| (get_bucket_pos_f64(val, interval, offset) as i64);
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
bucket_agg_accessor
agg_data
.column_block_accessor
.fetch_block(docs, &bucket_agg_accessor.accessor);
for (doc, val) in bucket_agg_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &bucket_agg_accessor.accessor)
.iter_docid_vals(docs, &req.accessor)
{
let val = self.f64_from_fastfield_u64(val);
let val = f64_from_fastfield_u64(val, req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry { key, doc_count: 0 }
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
});
bucket.doc_count += 1;
if let Some(sub_aggregation_blueprint) = self.sub_aggregation_blueprint.as_mut() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
bucket_agg_accessor
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
for sub_aggregation in self.sub_aggregations.values_mut() {
sub_aggregation.flush(sub_aggregation_accessor)?;
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
}
@@ -355,72 +397,62 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
}
/// Converts the collector result into a intermediate bucket result.
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &AggregationWithAccessor,
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(self.buckets.len());
let mut buckets = Vec::with_capacity(histogram.buckets.len());
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
&agg_with_accessor.sub_aggregation,
);
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
buckets.push(bucket_res?);
}
buckets.sort_unstable_by(|b1, b2| b1.key.total_cmp(&b2.key));
let is_date_agg = agg_data
.get_histogram_req_data(self.accessor_idx)
.field_type
== ColumnType::DateTime;
Ok(IntermediateBucketResult::Histogram {
buckets,
is_date_agg: self.column_type == ColumnType::DateTime,
is_date_agg,
})
}
pub(crate) fn from_req_and_validate(
mut req: HistogramAggregation,
sub_aggregation: &mut AggregationsWithAccessor,
field_type: ColumnType,
accessor_idx: usize,
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
req.validate()?;
if field_type == ColumnType::DateTime {
req.normalize_date_time();
}
let sub_aggregation_blueprint = if sub_aggregation.is_empty() {
None
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
let sub_aggregation = build_segment_agg_collector(sub_aggregation)?;
Some(sub_aggregation)
None
};
let bounds = req.hard_bounds.unwrap_or(HistogramBounds {
let req_data = agg_data.get_histogram_req_data_mut(node.idx_in_req_data);
req_data.req.validate()?;
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
req_data.req.normalize_date_time();
}
req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds {
min: f64::MIN,
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(CachedSubAggs::new);
Ok(Self {
buckets: Default::default(),
column_type: field_type,
interval: req.interval,
offset: req.offset.unwrap_or(0.0),
bounds,
sub_aggregations: Default::default(),
sub_aggregation_blueprint,
accessor_idx,
parent_buckets: Default::default(),
sub_agg,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
#[inline]
fn f64_from_fastfield_u64(&self, val: u64) -> f64 {
f64_from_fastfield_u64(val, &self.column_type)
}
}
#[inline]


@@ -22,6 +22,7 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)
mod filter;
mod histogram;
mod range;
mod term_agg;
@@ -30,6 +31,7 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;
pub use filter::*;
pub use histogram::*;
pub use range::*;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};


@@ -1,20 +1,45 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;
/// Contains all information required by the SegmentRangeCollector to perform the
/// range aggregation on a segment.
pub struct RangeAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}
impl RangeAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Provide user-defined buckets to aggregate on.
///
/// Two special buckets will automatically be created to cover the whole range of values.
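/// E.g. in the range tests above, a request whose last range ends at 20.0 also receives a
/// "20-*" bucket in the response, covering the remaining value space.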
@@ -128,19 +153,47 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
pub struct SegmentRangeCollector<const LOWCARD: bool = false> {
/// The buckets containing the aggregation data.
buckets: Vec<SegmentRangeAndBucketEntry>,
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<LOWCARD>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This effectively flattens the bucket ids across all parent buckets.
/// E.g. in the nested aggregation
/// Term Agg -> Range aggregation -> Stats aggregation,
/// the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of which has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}
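// A minimal sketch of the flattened bucket-id scheme described in the comment above.
// The real `BucketIdProvider` lives in `segment_agg_result`; this stand-in only
// illustrates the intent and assumes plain u32 counters:
#[derive(Default)]
struct SketchBucketIdProvider {
    next: u32,
}
impl SketchBucketIdProvider {
    fn next_bucket_id(&mut self) -> u32 {
        let id = self.next;
        self.next += 1;
        id
    }
}
// With three parent term buckets and four ranges each, successive calls yield
// 0..=3 for "INFO", 4..=7 for "ERROR" and 8..=11 for "WARN".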
impl<const LOWCARD: bool> Debug for SegmentRangeCollector<LOWCARD> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -161,46 +214,50 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateRangeBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?
} else {
Default::default()
};
let sub_aggregation = IntermediateAggregationResults::default();
Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: sub_aggregation_res,
sub_aggregation_res: sub_aggregation,
from: self.from,
to: self.to,
})
}
}
impl SegmentAggregationCollector for SegmentRangeCollector {
impl<const LOWCARD: bool> SegmentAggregationCollector for SegmentRangeCollector<LOWCARD> {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let sub_agg = &agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
.into_iter()
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(sub_agg)?,
))
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
})
.collect::<crate::Result<_>>()?;
@@ -217,69 +274,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let req = agg_data.take_range_req_data(self.accessor_idx);
bucket_agg_accessor
agg_data
.column_block_accessor
.fetch_block(docs, &bucket_agg_accessor.accessor);
.fetch_block(docs, &req.accessor);
for (doc, val) in bucket_agg_accessor
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &bucket_agg_accessor.accessor)
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation {
sub_aggregation.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(sub_aggregation_accessor)?;
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &mut AggregationsWithAccessor,
limits: &mut AggregationLimitsGuard,
field_type: ColumnType,
accessor_idx: usize,
) -> crate::Result<Self> {
// TODO: A better metric than is_top_level would be the number of buckets expected.
// E.g. if the range agg is not top level, but the parent is a bucket agg with fewer than 10
// buckets, we are still in low cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector {
sub_agg: sub_agg.map(CachedSubAggs::<true>::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector {
sub_agg: sub_agg.map(CachedSubAggs::<false>::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
}
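// The low-cardinality heuristic above, isolated for clarity (the 64-range cut-off
// is taken from this diff; the helper itself is only illustrative):
fn sketch_is_low_card(is_top_level: bool, num_ranges: usize) -> bool {
    is_top_level && num_ranges <= 64
}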
impl<const LOWCARD: bool> SegmentRangeCollector<LOWCARD> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -288,24 +390,20 @@ impl SegmentRangeCollector {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
let sub_aggregation = if sub_aggregation.is_empty() {
None
} else {
Some(build_segment_agg_collector(sub_aggregation)?)
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
// let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation,
bucket_id,
key,
from,
to,
@@ -314,27 +412,20 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
limits.add_memory_consumed(
self.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
Ok(buckets)
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}
/// Converts the user provided f64 range value to fast field value space.
///
@@ -431,7 +522,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val);
format_date(val)
} else {
Ok(f64_from_fastfield_u64(val, field_type).to_string())
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
}
};
@@ -467,15 +558,48 @@ mod tests {
ranges,
..Default::default()
};
// Build buckets directly as in from_req_and_validate without AggregationsData
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)
.expect("unexpected error in extend_validate_ranges")
.iter()
.map(|range| {
let key = range
.key
.clone()
.map(|key| Ok(Key::Str(key)))
.unwrap_or_else(|| range_to_key(&range.range, &field_type))
.expect("unexpected error in range_to_key");
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
key,
from,
to,
bucket_id: 0,
},
}
})
.collect();
SegmentRangeCollector::from_req_and_validate(
&req,
&mut Default::default(),
&mut AggregationLimitsGuard::default(),
field_type,
0,
)
.expect("unexpected error")
SegmentRangeCollector {
parent_buckets: vec![buckets],
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
}
#[test]
@@ -721,7 +845,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -744,7 +868,7 @@ mod tests {
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -759,7 +883,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
@@ -768,7 +892,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -777,7 +901,7 @@ mod tests {
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| collector.get_bucket_pos(val);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -823,7 +947,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| collector.get_bucket_pos(val);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);
@@ -835,63 +959,3 @@ mod tests {
// the max value
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}
#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}
#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large


@@ -1,55 +1,93 @@
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
///
/// This is needed when:
/// - The field is multi-valued and we therefore have multiple columns
/// - The field is not text and missing is provided as string (we cannot use the numeric missing
/// value optimization)
#[derive(Default)]
pub struct MissingTermAggReqData {
/// The accessors to check for existence of a value.
pub accessors: Vec<(Column<u64>, ColumnType)>,
/// The name of the aggregation.
pub name: String,
/// The original terms aggregation request.
pub req: TermsAggregation,
}
impl MissingTermAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
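// A hedged illustration of a request that takes this path: a numeric fast field
// combined with a string `missing` value, so the numeric missing-value optimization
// cannot be used. Written in the same json! style as the aggregation tests; the
// field name is assumed:
#[cfg(test)]
fn example_missing_terms_request() -> crate::aggregation::agg_req::Aggregations {
    serde_json::from_value(serde_json::json!({
        "scores": { "terms": { "field": "score", "missing": "unknown" } }
    }))
    .expect("valid terms aggregation request")
}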
#[derive(Default, Debug, Clone)]
struct MissingCount {
missing_count: u32,
bucket_id: BucketId,
}
/// The specialized missing term aggregation.
#[derive(Default, Debug, Clone)]
#[derive(Default, Debug)]
pub struct TermMissingAgg {
missing_count: u32,
accessor_idx: usize,
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
sub_agg: Option<CachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
}
impl TermMissingAgg {
pub(crate) fn new(
accessor_idx: usize,
sub_aggregations: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let has_sub_aggregations = !sub_aggregations.is_empty();
let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data;
let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collector(sub_aggregations)?;
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
accessor_idx,
sub_agg,
..Default::default()
missing_count_per_bucket: Vec::new(),
bucket_id_provider,
})
}
}
impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let term_agg = agg_with_accessor
.agg
.agg
.as_term()
.expect("TermMissingAgg collector must be term agg req");
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req;
let missing = term_agg
.missing
.as_ref()
@@ -58,16 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg {
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
Default::default();
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
let mut missing_entry = IntermediateTermBucketEntry {
doc_count: self.missing_count,
doc_count: missing_count.missing_count,
sub_aggregation: Default::default(),
};
if let Some(sub_agg) = self.sub_agg {
if let Some(sub_agg) = &mut self.sub_agg {
let mut res = IntermediateAggregationResults::default();
sub_agg.add_intermediate_aggregation_result(
&agg_with_accessor.sub_aggregation,
&mut res,
)?;
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
missing_entry.sub_aggregation = res;
}
entries.insert(missing.into(), missing_entry);
@@ -80,37 +118,62 @@ impl SegmentAggregationCollector for TermMissingAgg {
},
};
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
results.push(
req_data.name.to_string(),
IntermediateAggregationResult::Bucket(bucket),
)?;
Ok(())
}
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let agg = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let has_value = agg
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
self.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, &mut agg.sub_aggregation)?;
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
for doc in docs {
let doc = *doc;
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
bucket.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for doc in docs {
self.collect(*doc, agg_with_accessor)?;
while self.missing_count_per_bucket.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.missing_count_per_bucket.push(MissingCount {
missing_count: 0,
bucket_id,
});
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}


@@ -1,83 +0,0 @@
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::DocId;
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_with_accessor, results)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
self.num_staged_docs = 0;
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_with_accessor)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
self.num_staged_docs = 0;
self.collector.flush(agg_with_accessor)?;
Ok(())
}
}


@@ -0,0 +1,185 @@
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
#[derive(Debug)]
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// In general, this data structure groups docs by bucket id.
pub(crate) struct CachedSubAggs<const LOWCARD: bool = false> {
/// Only used when LOWCARD is true.
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
/// Only used when LOWCARD is false.
///
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// Bucket ids are dense, so e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
///
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
/// buckets, and there may be nothing to group.
partitions: [PartitionEntry; NUM_PARTITIONS],
pub(crate) sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
const FLUSH_THRESHOLD: usize = 2048;
const NUM_PARTITIONS: usize = 16;
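// Sketch of the intended call flow (assumed from the usages in this diff, not a real test):
// a parent bucket aggregation resolves the child bucket for each doc, stages the doc id,
// and lets the cache decide when to hand blocks to the sub-collector:
//
//   sub_agg.push(bucket_id, doc);          // per matching doc
//   sub_agg.check_flush_local(agg_data)?;  // after each collect_block
//   ...
//   sub_agg.flush(agg_data)?;              // once at the end of the segment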
impl<const LOWCARD: bool> CachedSubAggs<LOWCARD> {
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
per_bucket_docs: Vec::new(),
num_docs: 0,
sub_agg_collector: sub_agg,
partitions: core::array::from_fn(|_| PartitionEntry::default()),
}
}
#[inline]
pub fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
v.clear();
}
for partition in &mut self.partitions {
partition.clear();
}
self.num_docs = 0;
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
if LOWCARD {
// TODO: We could flush single buckets here
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
}
self.per_bucket_docs[idx].push(doc_id);
} else {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
slot.bucket_ids.push(bucket_id);
slot.docs.push(doc_id);
}
self.num_docs += 1;
}
/// Check if we need to flush based on the number of documents cached.
/// If so, flushes the cache to the provided aggregation collector.
pub fn check_flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.flush_local(agg_data, false)?;
}
Ok(())
}
/// Note: this does _not_ flush the sub aggregations
fn flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()> {
if LOWCARD {
// Pre-aggregated: call collect per bucket.
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
self.sub_agg_collector
.prepare_max_bucket(max_bucket, agg_data)?;
// The threshold above which we flush a bucket individually.
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
// the FLUSH_THRESHOLD but never flush any buckets (except in the final flush).
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
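// E.g. with FLUSH_THRESHOLD = 2048 and 16 cached buckets this is 64: only buckets
// holding more than 64 docs are flushed early; `force` drops the threshold to 0.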
const _: () = {
// MAX_NUM_TERMS_FOR_VEC == LOWCARD threshold
let bucket_treshold = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
assert!(
bucket_treshold > 0,
"Bucket threshold must be greater than 0"
);
};
if force {
bucket_treshold = 0;
}
for (bucket_id, docs) in self
.per_bucket_docs
.iter()
.enumerate()
.filter(|(_, docs)| docs.len() > bucket_treshold)
{
self.sub_agg_collector
.collect(bucket_id as BucketId, docs, agg_data)?;
}
} else {
let mut max_bucket = 0u32;
for partition in &self.partitions {
if let Some(&local_max) = partition.bucket_ids.iter().max() {
max_bucket = max_bucket.max(local_max);
}
}
self.sub_agg_collector
.prepare_max_bucket(max_bucket, agg_data)?;
for slot in &self.partitions {
if !slot.bucket_ids.is_empty() {
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
self.sub_agg_collector.collect_multiple(
&slot.bucket_ids,
&slot.docs,
agg_data,
)?;
}
}
}
self.clear();
Ok(())
}
/// Note: this _does_ flush the sub aggregations
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.flush_local(agg_data, true)?;
}
self.sub_agg_collector.flush(agg_data)?;
Ok(())
}
}
#[derive(Debug, Clone, Default)]
struct PartitionEntry {
bucket_ids: Vec<BucketId>,
docs: Vec<DocId>,
}
impl PartitionEntry {
#[inline]
fn clear(&mut self) {
self.bucket_ids.clear();
self.docs.clear();
}
}


@@ -1,12 +1,12 @@
use super::agg_req::Aggregations;
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::cached_sub_aggs::CachedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::{
build_segment_agg_collector, AggregationLimitsGuard, SegmentAggregationCollector,
use super::AggContextParams;
// The grouping/buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate;
use crate::collector::{Collector, SegmentCollector};
use crate::index::SegmentReader;
use crate::{DocId, SegmentOrdinal, TantivyError};
@@ -22,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
/// The collector collects all aggregations by the underlying aggregation request.
pub struct AggregationCollector {
agg: Aggregations,
limits: AggregationLimitsGuard,
context: AggContextParams,
}
impl AggregationCollector {
@@ -30,8 +30,8 @@ impl AggregationCollector {
///
/// Aggregation fails when the limits in `AggregationLimits` are exceeded (memory limit and
/// bucket limit).
pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
Self { agg, limits }
pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
Self { agg, context }
}
}
@@ -45,7 +45,7 @@ impl AggregationCollector {
/// into the final `AggregationResults` via the `into_final_result()` method.
pub struct DistributedAggregationCollector {
agg: Aggregations,
limits: AggregationLimitsGuard,
context: AggContextParams,
}
impl DistributedAggregationCollector {
@@ -53,8 +53,8 @@ impl DistributedAggregationCollector {
///
/// Aggregation fails when the limits in `AggregationLimits` are exceeded (memory limit and
/// bucket limit).
pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
Self { agg, limits }
pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
Self { agg, context }
}
}
@@ -72,7 +72,7 @@ impl Collector for DistributedAggregationCollector {
&self.agg,
reader,
segment_local_id,
&self.limits,
&self.context,
)
}
@@ -102,7 +102,7 @@ impl Collector for AggregationCollector {
&self.agg,
reader,
segment_local_id,
&self.limits,
&self.context,
)
}
@@ -115,7 +115,7 @@ impl Collector for AggregationCollector {
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
let res = merge_fruits(segment_fruits)?;
res.into_final_result(self.agg.clone(), self.limits.clone())
res.into_final_result(self.agg.clone(), self.context.limits.clone())
}
}
@@ -135,8 +135,8 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsWithAccessor,
agg_collector: BufAggregationCollector,
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: CachedSubAggs<true>,
error: Option<TantivyError>,
}
@@ -147,14 +147,17 @@ impl AggregationSegmentCollector {
agg: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
limits: &AggregationLimitsGuard,
context: &AggContextParams,
) -> crate::Result<Self> {
let mut aggs_with_accessor =
get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?;
let result =
BufAggregationCollector::new(build_segment_agg_collector(&mut aggs_with_accessor)?);
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let mut result = CachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
Ok(AggregationSegmentCollector {
aggs_with_accessor,
aggs_with_accessor: agg_data,
agg_collector: result,
error: None,
})
@@ -169,26 +172,31 @@ impl SegmentCollector for AggregationSegmentCollector {
if self.error.is_some() {
return;
}
if let Err(err) = self
self.agg_collector.push(0, doc);
match self
.agg_collector
.collect(doc, &mut self.aggs_with_accessor)
.check_flush_local(&mut self.aggs_with_accessor)
{
self.error = Some(err);
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore scores
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &mut self.aggs_with_accessor)
{
self.error = Some(err);
match self.agg_collector.get_sub_agg_collector().collect(
0,
docs,
&mut self.aggs_with_accessor,
) {
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
}
}
@@ -199,10 +207,13 @@ impl SegmentCollector for AggregationSegmentCollector {
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
Box::new(self.agg_collector).add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
)?;
self.agg_collector
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
0,
)?;
Ok(sub_aggregation_res)
}


@@ -24,7 +24,9 @@ use super::metric::{
};
use super::segment_agg_result::AggregationLimitsGuard;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::agg_result::{
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError;
@@ -179,12 +181,17 @@ impl IntermediateAggregationResults {
}
/// Merge another intermediate aggregation result into this result.
///
/// The order of the values need to be the same on both results. This is ensured when the same
/// (key values) are present on the underlying `VecWithNames` struct.
pub fn merge_fruits(&mut self, other: IntermediateAggregationResults) -> crate::Result<()> {
for (left, right) in self.aggs_res.values_mut().zip(other.aggs_res.into_values()) {
left.merge_fruits(right)?;
pub fn merge_fruits(&mut self, mut other: IntermediateAggregationResults) -> crate::Result<()> {
for (key, left) in self.aggs_res.iter_mut() {
if let Some(key) = other.aggs_res.remove(key) {
left.merge_fruits(key)?;
}
}
// Move remainder of other aggs_res into self.
// Note: Currently we don't expect this to happen, as we create empty intermediate results
// via [IntermediateAggregationResults::empty_from_req].
for (key, value) in other.aggs_res {
self.aggs_res.insert(key, value);
}
Ok(())
}
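The merge above pairs entries by aggregation name and only then moves any leftover right-hand entries over. A standalone sketch of the same key-based merge pattern (illustrative only, using a plain map of counts rather than tantivy's intermediate results):

use std::collections::BTreeMap;

/// Merges `right` into `left` by key: shared keys are combined, leftover keys are moved over.
fn merge_counts(left: &mut BTreeMap<String, u64>, mut right: BTreeMap<String, u64>) {
    for (key, left_val) in left.iter_mut() {
        if let Some(right_val) = right.remove(key) {
            *left_val += right_val;
        }
    }
    // Move the remainder of `right` into `left` (not expected in practice, see the note above).
    for (key, val) in right {
        left.insert(key, val);
    }
}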
@@ -241,11 +248,16 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
Cardinality(_) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
),
Filter(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Filter {
doc_count: 0,
sub_aggregations: IntermediateAggregationResults::default(),
}),
}
}
/// An aggregation is either a bucket or a metric.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum IntermediateAggregationResult {
/// Bucket variant
Bucket(IntermediateBucketResult),
@@ -426,6 +438,13 @@ pub enum IntermediateBucketResult {
/// The term buckets
buckets: IntermediateTermBucketResult,
},
/// Filter aggregation - a single bucket with sub-aggregations
Filter {
/// Document count in the filter bucket
doc_count: u64,
/// Sub-aggregation results
sub_aggregations: IntermediateAggregationResults,
},
}
impl IntermediateBucketResult {
@@ -509,6 +528,18 @@ impl IntermediateBucketResult {
req.sub_aggregation(),
limits,
),
IntermediateBucketResult::Filter {
doc_count,
sub_aggregations,
} => {
// Convert sub-aggregation results to final format
let final_sub_aggregations = sub_aggregations
.into_final_result(req.sub_aggregation().clone(), limits.clone())?;
Ok(BucketResult::Filter(FilterBucketResult {
doc_count,
sub_aggregations: final_sub_aggregations,
}))
}
}
}
@@ -562,6 +593,19 @@ impl IntermediateBucketResult {
*buckets_left = buckets?;
}
(
IntermediateBucketResult::Filter {
doc_count: doc_count_left,
sub_aggregations: sub_aggs_left,
},
IntermediateBucketResult::Filter {
doc_count: doc_count_right,
sub_aggregations: sub_aggs_right,
},
) => {
*doc_count_left += doc_count_right;
sub_aggs_left.merge_fruits(sub_aggs_right)?;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -571,6 +615,9 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Terms { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Filter { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}
@@ -745,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
/// The number of documents in the bucket.
pub doc_count: u64,
/// The sub_aggregation in this bucket.
pub sub_aggregation: IntermediateAggregationResults,
pub sub_aggregation_res: IntermediateAggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -764,7 +811,7 @@ impl IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.sub_aggregation_res
.into_final_result_internal(req, limits)?,
to: self.to,
from: self.from,
@@ -810,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count;
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
self.sub_aggregation_res
.merge_fruits(other.sub_aggregation_res)?;
Ok(())
}
}
@@ -840,7 +888,7 @@ mod tests {
IntermediateRangeBucketEntry {
key: IntermediateKey::Str(key.to_string()),
doc_count: *doc_count,
sub_aggregation: Default::default(),
sub_aggregation_res: Default::default(),
from: None,
to: None,
},
@@ -873,7 +921,7 @@ mod tests {
doc_count: *doc_count,
from: None,
to: None,
sub_aggregation: get_sub_test_tree(&[(
sub_aggregation_res: get_sub_test_tree(&[(
sub_aggregation_key.to_string(),
*sub_aggregation_count,
)]),

View File

@@ -52,10 +52,8 @@ pub struct IntermediateAverage {
impl IntermediateAverage {
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {

View File

@@ -2,15 +2,13 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::Dictionary;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
@@ -97,6 +95,30 @@ pub struct CardinalityAggregationReq {
pub missing: Option<Key>,
}
/// Contains all information required by the SegmentCardinalityCollector to perform the
/// cardinality aggregation on a segment.
pub struct CardinalityAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The column_type of the field.
pub column_type: ColumnType,
/// The string dictionary column if the field is of type string.
pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>,
/// The name of the aggregation.
pub name: String,
/// The aggregation request.
pub req: CardinalityAggregationReq,
}
impl CardinalityAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
impl CardinalityAggregationReq {
/// Creates a new [`CardinalityAggregationReq`] instance from a field name.
pub fn from_field_name(field_name: String) -> Self {
@@ -111,51 +133,37 @@ impl CardinalityAggregationReq {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
column_type: ColumnType,
buckets: Vec<SegmentCardinalityCollectorBucket>,
accessor_idx: usize,
missing: Option<Key>,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
/// The column_type of the field.
column_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
}
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option<Key>) -> Self {
#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
}
impl SegmentCardinalityCollectorBucket {
pub fn new(column_type: ColumnType) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(),
column_type,
accessor_idx,
missing: missing.clone(),
entries: FxHashSet::default(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = agg_accessor.missing_value_for_accessor {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
agg_with_accessor: &AggregationWithAccessor,
req_data: &CardinalityAggReqData,
) -> crate::Result<IntermediateMetricResult> {
if self.column_type == ColumnType::Str {
if req_data.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = agg_with_accessor
let dict = req_data
.str_dict_column
.as_ref()
.map(|el| el.dictionary())
@@ -173,6 +181,7 @@ impl SegmentCardinalityCollector {
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
@@ -180,10 +189,10 @@ impl SegmentCardinalityCollector {
})?;
if has_missing {
// Replace missing with the actual value provided
let missing_key = self
.missing
.as_ref()
.expect("Found sentinel value u64::MAX for term_ord but `missing` is not set");
let missing_key =
req_data.req.missing.as_ref().expect(
"Found sentinel value u64::MAX for term_ord but `missing` is not set",
);
match missing_key {
Key::Str(missing) => {
self.cardinality.sketch.insert_any(&missing);
@@ -206,16 +215,49 @@ impl SegmentCardinalityCollector {
}
}
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
) -> Self {
Self {
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
column_type,
accessor_idx,
accessor,
missing_value_for_accessor,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_value_for_accessor,
);
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string();
// take the bucket in buckets and replace it with a new empty one
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?;
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
@@ -226,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.fetch_block_with_field(docs, bucket_agg_accessor);
self.fetch_block_with_field(docs, agg_data);
let bucket = &mut self.buckets[parent_bucket_id as usize];
let col_block_accessor = &bucket_agg_accessor.column_block_accessor;
let col_block_accessor = &agg_data.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord);
bucket.entries.insert(term_ord);
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = bucket_agg_accessor
let compact_space_accessor = self
.accessor
.values
.clone()
@@ -261,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
self.cardinality.sketch.insert_any(&val);
bucket.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
self.cardinality.sketch.insert_any(&val);
bucket.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
self.buckets.resize_with(max_bucket as usize + 1, || {
SegmentCardinalityCollectorBucket::new(self.column_type)
});
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateCount {
impl IntermediateCount {
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateCount) {

View File

@@ -4,15 +4,13 @@ use std::mem;
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -63,7 +61,7 @@ impl ExtendedStatsAggregation {
/// Extended stats contains a collection of statistics;
/// it extends stats by adding variance, standard deviation,
/// and bound informations
/// and bound information
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExtendedStats {
/// The number of documents.
@@ -319,51 +317,28 @@ impl IntermediateExtendedStats {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentExtendedStatsCollector {
name: String,
missing: Option<u64>,
field_type: ColumnType,
pub(crate) extended_stats: IntermediateExtendedStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
accessor: columnar::Column<u64>,
buckets: Vec<IntermediateExtendedStats>,
sigma: Option<f64>,
}
impl SegmentExtendedStatsCollector {
pub fn from_req(
field_type: ColumnType,
sigma: Option<f64>,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
Self {
field_type,
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
accessor_idx,
name: req.name.clone(),
field_type: req.field_type,
accessor: req.accessor.clone(),
missing,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
*missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
sigma,
}
}
}
@@ -371,15 +346,18 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
results.push(
name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
self.extended_stats,
extended_stats,
)),
)?;
@@ -389,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
if let Some(missing) = self.missing {
let mut has_val = false;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
has_val = true;
}
if !has_val {
self.extended_stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
agg_data
.column_block_accessor
.fetch_block_with_missing(docs, &self.accessor, self.missing);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
extended_stats.collect(val1);
}
// store back
self.buckets[parent_bucket_id as usize] = extended_stats;
Ok(())
}
#[inline]
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
if self.buckets.len() <= max_bucket as usize {
self.buckets.resize_with(max_bucket as usize + 1, || {
IntermediateExtendedStats::with_sigma(self.sigma)
});
}
Ok(())
}
}

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMax {
impl IntermediateMax {
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMax) {

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMin {
impl IntermediateMin {
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMin) {

View File

@@ -31,6 +31,7 @@ use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -44,6 +45,33 @@ pub use top_hits::*;
use crate::schema::OwnedValue;
/// Contains all information required by metric aggregations like avg, min, max, sum, stats,
/// extended_stats, count, percentiles.
#[repr(C)]
pub struct MetricAggReqData {
/// True if the field is of number or date type.
pub is_number_or_date_type: bool,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// Used when converting to intermediate result
pub collecting_for: StatsType,
/// The missing value
pub missing: Option<f64>,
/// The name of the aggregation.
pub name: String,
}
impl MetricAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Single-metric aggregations use this common result structure.
///
/// Main reason to wrap it in value is to match elasticsearch output structure.

View File

@@ -3,15 +3,13 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// # Percentiles
///
@@ -112,7 +110,8 @@ impl PercentilesAggregationReq {
&self.field
}
fn validate(&self) -> crate::Result<()> {
/// Validates the request parameters.
pub fn validate(&self) -> crate::Result<()> {
if let Some(percents) = self.percents.as_ref() {
let all_in_range = percents
.iter()
@@ -131,12 +130,16 @@ impl PercentilesAggregationReq {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentPercentilesCollector {
field_type: ColumnType,
pub(crate) percentiles: PercentilesCollector,
pub(crate) buckets: Vec<PercentilesCollector>,
pub(crate) accessor_idx: usize,
missing: Option<u64>,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -232,43 +235,17 @@ impl PercentilesCollector {
impl SegmentPercentilesCollector {
pub fn from_req_and_validate(
req: &PercentilesAggregationReq,
field_type: ColumnType,
missing_u64: Option<u64>,
accessor: Column<u64>,
accessor_idx: usize,
) -> crate::Result<Self> {
req.validate()?;
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Ok(Self {
) -> Self {
Self {
buckets: Vec::with_capacity(64),
field_type,
percentiles: PercentilesCollector::new(),
missing_u64,
accessor,
accessor_idx,
missing,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
*missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.percentiles.collect(val1);
}
}
}
@@ -276,12 +253,18 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
// Swap collector with an empty one to avoid cloning
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_metric_result =
IntermediateMetricResult::Percentiles(percentiles_collector);
results.push(
name,
@@ -294,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
let percentiles = &mut self.buckets[parent_bucket_id as usize];
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
if let Some(missing) = self.missing {
let mut has_val = false;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.percentiles.collect(val1);
}
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
percentiles.collect(val1);
}
Ok(())
}
#[inline]
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
while self.buckets.len() <= max_bucket as usize {
self.buckets.push(PercentilesCollector::new());
}
Ok(())
}
}

View File

@@ -1,17 +1,16 @@
use std::fmt::Debug;
use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -84,7 +83,7 @@ impl Stats {
/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
/// The number of extracted values.
pub(crate) count: u64,
@@ -166,106 +165,98 @@ impl IntermediateStats {
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum SegmentStatsType {
/// The type of stats aggregation to perform.
/// Note that not all stats types are supported in the stats aggregation.
#[derive(Clone, Copy, Debug)]
pub enum StatsType {
/// The average of the values.
Average,
/// The count of the values.
Count,
/// The maximum value.
Max,
/// The minimum value.
Min,
/// The stats (count, sum, min, max, avg) of the values.
Stats,
    /// The extended stats (count, sum, min, max, avg, sum_of_squares, variance, std_deviation and bounds) of the values.
ExtendedStats(Option<f64>), // sigma
/// The sum of the values.
Sum,
/// The percentiles of the values.
Percentiles,
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentStatsCollector {
missing: Option<u64>,
field_type: ColumnType,
pub(crate) collecting_for: SegmentStatsType,
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
fn create_collector<const TYPE_ID: u8>(
req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
Box::new(SegmentStatsCollector::<TYPE_ID> {
name: req.name.clone(),
collecting_for: req.collecting_for,
is_number_or_date_type: req.is_number_or_date_type,
missing_u64: req.missing_u64,
accessor: req.accessor.clone(),
buckets: vec![IntermediateStats::default()],
})
}
impl SegmentStatsCollector {
pub fn from_req(
field_type: ColumnType,
collecting_for: SegmentStatsType,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Self {
field_type,
collecting_for,
stats: IntermediateStats::default(),
accessor_idx,
missing,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
*missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
if [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::DateTime,
]
.contains(&self.field_type)
{
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
}
} else {
for _val in agg_accessor.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
}
/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match req.field_type {
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
}
}
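The dispatch above bakes the column type into a const generic parameter, so the per-value conversion is monomorphized per column type instead of branching at runtime. A standalone sketch of that pattern (illustrative constants and functions, not tantivy code):

// Const-generic specialization sketch: the type id is fixed at compile time,
// so the hot conversion loop is monomorphized per type.
const U64_ID: u8 = 0;
const I64_ID: u8 = 1;

fn convert<const TYPE_ID: u8>(val: u64) -> f64 {
    if TYPE_ID == U64_ID {
        val as f64
    } else if TYPE_ID == I64_ID {
        (val as i64) as f64
    } else {
        panic!("unsupported type id {TYPE_ID}")
    }
}

fn sum<const TYPE_ID: u8>(vals: &[u64]) -> f64 {
    // The `convert` call compiles down to the single branch selected by TYPE_ID.
    vals.iter().map(|&v| convert::<TYPE_ID>(v)).sum()
}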
impl SegmentAggregationCollector for SegmentStatsCollector {
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
pub(crate) missing_u64: Option<u64>,
pub(crate) accessor: Column<u64>,
pub(crate) is_number_or_date_type: bool,
pub(crate) buckets: Vec<IntermediateStats>,
pub(crate) name: String,
pub(crate) collecting_for: StatsType,
}
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
for SegmentStatsCollector<COLUMN_TYPE_ID>
{
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let stats = self.buckets[parent_bucket_id as usize];
let intermediate_metric_result = match self.collecting_for {
SegmentStatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
}
SegmentStatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
}
SegmentStatsType::Max => {
IntermediateMetricResult::Max(IntermediateMax::from_collector(*self))
}
SegmentStatsType::Min => {
IntermediateMetricResult::Min(IntermediateMin::from_collector(*self))
}
SegmentStatsType::Stats => IntermediateMetricResult::Stats(self.stats),
SegmentStatsType::Sum => {
IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self))
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
StatsType::Stats => IntermediateMetricResult::Stats(stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
_ => {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}",
self.collecting_for
)))
}
};
@@ -280,41 +271,67 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
if let Some(missing) = self.missing {
let mut has_val = false;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
// TODO: remove once we fetch all values for all bucket ids in one go
if docs.len() == 1 && self.missing_u64.is_none() {
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
self.accessor.values_for_doc(docs[0]),
self.is_number_or_date_type,
)?;
return Ok(());
}
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
agg_data.column_block_accessor.iter_vals(),
self.is_number_or_date_type,
)?;
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_buckets = (max_bucket as usize) + 1;
if self.buckets.len() < required_buckets {
self.buckets
.resize_with(required_buckets, IntermediateStats::default);
}
Ok(())
}
}
#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
stats: &mut IntermediateStats,
vals: impl Iterator<Item = u64>,
is_number_or_date_type: bool,
) -> crate::Result<()> {
if is_number_or_date_type {
for val in vals {
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
stats.collect(val1);
}
} else {
for _val in vals {
// we ignore the value and simply record that we got something
stats.collect(0.0);
}
}
Ok(())
}
#[cfg(test)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateSum {
impl IntermediateSum {
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateSum) {

View File

@@ -9,16 +9,41 @@ use serde::ser::SerializeMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::{TopHitsMetricResult, TopHitsVecEntry};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::Order;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::AggregationError;
use crate::aggregation::{AggregationError, BucketId};
use crate::collector::sort_key::ReverseComparator;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
#[derive(Default)]
pub struct TopHitsAggReqData {
/// The accessors to access the fast field values.
pub accessors: Vec<(Column<u64>, ColumnType)>,
/// The accessors to access the fast field values for retrieving document fields.
pub value_accessors: HashMap<String, Vec<DynamicColumn>>,
/// The ordinal of the segment this request data is for.
pub segment_ordinal: SegmentOrdinal,
/// The name of the aggregation.
pub name: String,
/// The top_hits aggregation request.
pub req: TopHitsAggregationReq,
}
impl TopHitsAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// # Top Hits
///
/// The top hits aggregation is a useful tool to answer questions like:
@@ -433,7 +458,7 @@ impl Eq for DocSortValuesAndFields {}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
req: TopHitsAggregationReq,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, ReverseComparator>,
}
impl std::cmp::PartialEq for TopHitsTopNComputer {
@@ -446,7 +471,10 @@ impl TopHitsTopNComputer {
/// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self {
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
top_n: TopNComputer::new_with_comparator(
req.size + req.from.unwrap_or(0),
ReverseComparator,
),
req: req.clone(),
}
}
@@ -457,7 +485,7 @@ impl TopHitsTopNComputer {
pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> {
for doc in other_fruit.top_n.into_vec() {
self.collect(doc.feature, doc.doc);
self.collect(doc.sort_key, doc.doc);
}
Ok(())
}
@@ -469,9 +497,9 @@ impl TopHitsTopNComputer {
.into_sorted_vec()
.into_iter()
.map(|doc| TopHitsVecEntry {
sort: doc.feature.sorts.iter().map(|f| f.value).collect(),
sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(),
doc_value_fields: doc
.feature
.sort_key
.doc_value_fields
.into_iter()
.map(|(k, v)| (k, v.into()))
@@ -492,7 +520,8 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>,
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
num_hits: usize,
}
impl TopHitsSegmentCollector {
@@ -501,25 +530,35 @@ impl TopHitsSegmentCollector {
accessor_idx: usize,
segment_ordinal: SegmentOrdinal,
) -> Self {
let num_hits = req.size + req.from.unwrap_or(0);
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
num_hits,
segment_ordinal,
accessor_idx,
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
}
}
fn into_top_hits_collector(
self,
fn get_top_hits_computer(
&mut self,
parent_bucket_id: BucketId,
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
if parent_bucket_id as usize >= self.buckets.len() {
return TopHitsTopNComputer::new(req);
}
let top_n = std::mem::replace(
&mut self.buckets[parent_bucket_id as usize],
TopNComputer::new(0),
);
let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = self.top_n.into_vec();
let top_results = top_n.into_vec();
for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
top_hits_computer.collect(
DocSortValuesAndFields {
sorts: res.feature,
sorts: res.sort_key,
doc_value_fields,
},
res.doc,
@@ -528,61 +567,26 @@ impl TopHitsSegmentCollector {
top_hits_computer
}
/// TODO add a specialized variant for a single sort field
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
self.top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
Ok(())
}
}
impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors;
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let value_accessors = &req_data.value_accessors;
let intermediate_result = IntermediateMetricResult::TopHits(
self.into_top_hits_collector(value_accessors, tophits_req),
);
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
parent_bucket_id,
value_accessors,
&req_data.req,
));
results.push(
name,
req_data.name.to_string(),
IntermediateAggregationResult::Metric(intermediate_result),
)
}
@@ -590,34 +594,55 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
/// TODO: Consider a caching layer to reduce the call overhead
fn collect(
&mut self,
doc_id: crate::DocId,
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors;
self.collect_with(doc_id, tophits_req, accessors)?;
let top_n = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let req = &req_data.req;
let accessors = &req_data.accessors;
for doc_id in docs {
let doc_id = *doc_id;
// TODO: this is terrible, a new vec is allocated for every doc
// We can fetch blocks instead
// We don't need to store the order for every value
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
}
Ok(())
}
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors;
// TODO: Consider getting fields with the column block accessor.
for doc in docs {
self.collect_with(*doc, tophits_req, accessors)?;
}
self.buckets.resize(
(max_bucket as usize) + 1,
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
);
Ok(())
}
}
@@ -635,6 +660,7 @@ mod tests {
use crate::aggregation::bucket::tests::get_test_index_from_docs;
use crate::aggregation::tests::get_test_index_from_values;
use crate::aggregation::AggregationCollector;
use crate::collector::sort_key::ReverseComparator;
use crate::collector::ComparableDoc;
use crate::query::AllQuery;
use crate::schema::OwnedValue;
@@ -650,7 +676,7 @@ mod tests {
fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer {
super::TopHitsTopNComputer {
top_n: super::TopNComputer::new(capacity),
top_n: super::TopNComputer::new_with_comparator(capacity, ReverseComparator),
req: Default::default(),
}
}
@@ -734,7 +760,7 @@ mod tests {
],
"from": 0,
}
}
}
}))
.unwrap();
@@ -764,12 +790,12 @@ mod tests {
#[test]
fn test_top_hits_collector_single_feature() -> crate::Result<()> {
let docs = vec![
ComparableDoc::<_, _, false> {
ComparableDoc::<_, _> {
doc: crate::DocAddress {
segment_ord: 0,
doc_id: 0,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(1),
order: Order::Asc,
@@ -782,7 +808,7 @@ mod tests {
segment_ord: 0,
doc_id: 2,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(3),
order: Order::Asc,
@@ -795,7 +821,7 @@ mod tests {
segment_ord: 0,
doc_id: 1,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(5),
order: Order::Asc,
@@ -807,7 +833,7 @@ mod tests {
let mut collector = collector_with_capacity(3);
for doc in docs.clone() {
collector.collect(doc.feature, doc.doc);
collector.collect(doc.sort_key, doc.doc);
}
let res = collector.into_final_result();
@@ -817,15 +843,15 @@ mod tests {
super::TopHitsMetricResult {
hits: vec![
super::TopHitsVecEntry {
sort: vec![docs[0].feature.sorts[0].value],
sort: vec![docs[0].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
super::TopHitsVecEntry {
sort: vec![docs[1].feature.sorts[0].value],
sort: vec![docs[1].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
super::TopHitsVecEntry {
sort: vec![docs[2].feature.sorts[0].value],
sort: vec![docs[2].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
]
@@ -863,7 +889,7 @@ mod tests {
"mixed.*",
],
}
}
}
}))?;
let collector = AggregationCollector::from_aggs(d, Default::default());

View File

@@ -127,12 +127,13 @@
//! [`AggregationResults`](agg_result::AggregationResults) via the
//! [`into_final_result`](intermediate_agg_result::IntermediateAggregationResults::into_final_result) method.
mod accessor_helpers;
mod agg_data;
mod agg_limits;
pub mod agg_req;
mod agg_req_with_accessor;
pub mod agg_result;
pub mod bucket;
mod buf_collector;
pub(crate) mod cached_sub_aggs;
mod collector;
mod date;
mod error;
@@ -140,7 +141,6 @@ pub mod intermediate_agg_result;
pub mod metric;
mod segment_agg_result;
use std::collections::HashMap;
use std::fmt::Display;
#[cfg(test)]
@@ -160,6 +160,41 @@ use itertools::Itertools;
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that holds per-bucket data.
///
/// For example, in a terms aggregation, each unique term is assigned an incremental BucketId.
/// This BucketId is forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows a single AggregationCollector instance per aggregation
/// to handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer: one Vec per bucket aggregation.
pub type BucketId = u32;
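To illustrate the per-bucket layout described above, here is a standalone sketch (hypothetical struct, not part of this diff) that grows its storage in a `prepare_max_bucket`-style step and then indexes it by `BucketId` while collecting:

type BucketId = u32;

#[derive(Default)]
struct PerBucketCounts {
    buckets: Vec<u64>, // one entry per BucketId
}

impl PerBucketCounts {
    // Mirrors the `prepare_max_bucket` idea: grow storage before collecting.
    fn prepare_max_bucket(&mut self, max_bucket: BucketId) {
        if self.buckets.len() <= max_bucket as usize {
            self.buckets.resize(max_bucket as usize + 1, 0);
        }
    }
    // Mirrors `collect`: all docs in one call belong to the same parent bucket.
    fn collect(&mut self, parent_bucket_id: BucketId, docs: &[u32]) {
        self.buckets[parent_bucket_id as usize] += docs.len() as u64;
    }
}

fn main() {
    let mut c = PerBucketCounts::default();
    c.prepare_max_bucket(2);
    c.collect(0, &[1, 2, 3]);
    c.collect(2, &[7]);
    assert_eq!(c.buckets, vec![3, 0, 1]);
}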
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
/// - `limits`: Memory and bucket limits for the aggregation
/// - `tokenizers`: TokenizerManager for parsing query strings in filter aggregations
#[derive(Clone, Default)]
pub struct AggContextParams {
/// Aggregation limits (memory and bucket count)
pub limits: AggregationLimitsGuard,
/// Tokenizer manager for query string parsing
pub tokenizers: TokenizerManager,
}
impl AggContextParams {
/// Create new aggregation context parameters
pub fn new(limits: AggregationLimitsGuard, tokenizers: TokenizerManager) -> Self {
Self { limits, tokenizers }
}
}
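A usage sketch of the new constructor, mirroring the test-helper change further down in this diff; `agg_req`, `limits`, and `index` are assumed to already be in scope:

// Bundle limits and the index's tokenizer manager, then build the collector.
let params = AggContextParams::new(limits, index.tokenizers().clone());
let collector = AggregationCollector::from_aggs(agg_req, params);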
fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
let parsed = value
.parse::<f64>()
@@ -257,80 +292,6 @@ where D: Deserializer<'de> {
deserializer.deserialize_any(StringOrFloatVisitor)
}
/// Represents an associative array `(key => values)` in a very efficient manner.
#[derive(PartialEq, Serialize, Deserialize)]
pub(crate) struct VecWithNames<T> {
pub(crate) values: Vec<T>,
keys: Vec<String>,
}
impl<T: Clone> Clone for VecWithNames<T> {
fn clone(&self) -> Self {
Self {
values: self.values.clone(),
keys: self.keys.clone(),
}
}
}
impl<T> Default for VecWithNames<T> {
fn default() -> Self {
Self {
values: Default::default(),
keys: Default::default(),
}
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for VecWithNames<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_map().entries(self.iter()).finish()
}
}
impl<T> From<HashMap<String, T>> for VecWithNames<T> {
fn from(map: HashMap<String, T>) -> Self {
VecWithNames::from_entries(map.into_iter().collect_vec())
}
}
impl<T> VecWithNames<T> {
fn from_entries(mut entries: Vec<(String, T)>) -> Self {
        // Sort to ensure the order of elements matches across multiple instances
entries.sort_by(|left, right| left.0.cmp(&right.0));
let mut data = Vec::with_capacity(entries.len());
let mut data_names = Vec::with_capacity(entries.len());
for entry in entries {
data_names.push(entry.0);
data.push(entry.1);
}
VecWithNames {
values: data,
keys: data_names,
}
}
fn iter(&self) -> impl Iterator<Item = (&str, &T)> + '_ {
self.keys().zip(self.values.iter())
}
fn keys(&self) -> impl Iterator<Item = &str> + '_ {
self.keys.iter().map(|key| key.as_str())
}
fn values_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
self.values.iter_mut()
}
fn is_empty(&self) -> bool {
self.keys.is_empty()
}
fn len(&self) -> usize {
self.keys.len()
}
fn get(&self, name: &str) -> Option<&T> {
self.keys()
.position(|key| key == name)
.map(|pos| &self.values[pos])
}
}
/// The serialized key is used in a `HashMap`.
pub type SerializedKey = String;
@@ -387,19 +348,37 @@ impl Display for Key {
}
}
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
val as f64
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
{
i64::from_u64(val) as f64
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
f64::from_u64(val)
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
val as f64
} else {
panic!(
"ColumnType ID {} cannot be converted to f64 metric",
COLUMN_TYPE_ID
)
}
}
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
match field_type {
ColumnType::U64 => val as f64,
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
ColumnType::F64 => f64::from_u64(val),
ColumnType::Bool => val as f64,
_ => {
panic!("unexpected type {field_type:?}. This should not happen")
}
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
_ => panic!("unexpected type {field_type:?}. This should not happen"),
}
}
@@ -464,7 +443,10 @@ mod tests {
query: Option<(&str, &str)>,
limits: AggregationLimitsGuard,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req, limits);
let collector = AggregationCollector::from_aggs(
agg_req,
AggContextParams::new(limits, index.tokenizers().clone()),
);
let reader = index.reader()?;
let searcher = reader.searcher();

View File

@@ -6,179 +6,79 @@
use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::agg_req::AggregationVariants;
use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::metric::{
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation,
SumAggregation,
};
use crate::aggregation::bucket::TermMissingAgg;
use crate::aggregation::metric::{
CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
TopHitsSegmentCollector,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
/// Get the next BucketId.
pub fn next_bucket_id(&mut self) -> BucketId {
let bucket_id = self.0;
self.0 += 1;
bucket_id
}
}
/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: Debug {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()>;
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>;
fn collect_block(
/// Collect docs for multiple buckets in one call.
/// Minimizes dynamic dispatch overhead when collecting many buckets.
///
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect_multiple(
&mut self,
bucket_ids: &[BucketId],
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(bucket_ids.len(), docs.len());
let mut start = 0;
while start < bucket_ids.len() {
let bucket_id = bucket_ids[start];
let mut end = start + 1;
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
end += 1;
}
self.collect(bucket_id, &docs[start..end], agg_data)?;
start = end;
}
Ok(())
}
/// Prepare the collector for collecting up to BucketId `max_bucket`.
    /// This is useful so allocations can be done ahead of time, before collecting.
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()>;
    /// Finalize method. Some aggregators stage blocks of docs before collecting them.
/// This method ensures those staged docs will be collected.
fn flush(&mut self, _agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
Ok(())
}
}
pub(crate) trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn SegmentAggregationCollector> {
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
self.clone_box()
}
}
pub(crate) fn build_segment_agg_collector(
req: &mut AggregationsWithAccessor,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
// Single collector special case
if req.aggs.len() == 1 {
let req = &mut req.aggs.values[0];
let accessor_idx = 0;
return build_single_agg_segment_collector(req, accessor_idx);
}
let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?;
Ok(Box::new(agg))
}
pub(crate) fn build_single_agg_segment_collector(
req: &mut AggregationWithAccessor,
accessor_idx: usize,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
use AggregationVariants::*;
match &req.agg.agg {
Terms(terms_req) => {
if req.accessors.is_empty() {
Ok(Box::new(SegmentTermCollector::from_req_and_validate(
terms_req,
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
} else {
Ok(Box::new(TermMissingAgg::new(
accessor_idx,
&mut req.sub_aggregation,
)?))
}
}
Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
range_req,
&mut req.sub_aggregation,
&mut req.limits,
req.field_type,
accessor_idx,
)?)),
Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram.clone(),
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram.to_histogram_req()?,
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
Average(AverageAggregation { missing, .. }) => {
Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Average,
accessor_idx,
*missing,
)))
}
Count(CountAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Count,
accessor_idx,
*missing,
))),
Max(MaxAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Max,
accessor_idx,
*missing,
))),
Min(MinAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Min,
accessor_idx,
*missing,
))),
Stats(StatsAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Stats,
accessor_idx,
*missing,
))),
ExtendedStats(ExtendedStatsAggregation { missing, sigma, .. }) => Ok(Box::new(
SegmentExtendedStatsCollector::from_req(req.field_type, *sigma, accessor_idx, *missing),
)),
Sum(SumAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Sum,
accessor_idx,
*missing,
))),
Percentiles(percentiles_req) => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
percentiles_req,
req.field_type,
accessor_idx,
)?,
)),
TopHits(top_hits_req) => Ok(Box::new(TopHitsSegmentCollector::from_req(
top_hits_req,
accessor_idx,
req.segment_ordinal,
))),
Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new(
SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing),
)),
}
}
- #[derive(Clone, Default)]
+ #[derive(Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrarily complex sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead that remove some of its overhead.
@@ -196,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result(
- self: Box<Self>,
- agg_with_accessor: &AggregationsWithAccessor,
+ &mut self,
+ agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
+ parent_bucket_id: BucketId,
) -> crate::Result<()> {
- for agg in self.aggs {
- agg.add_intermediate_aggregation_result(agg_with_accessor, results)?;
+ for agg in &mut self.aggs {
+ agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
}
Ok(())
@@ -209,43 +110,31 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
- doc: crate::DocId,
- agg_with_accessor: &mut AggregationsWithAccessor,
- ) -> crate::Result<()> {
- self.collect_block(&[doc], agg_with_accessor)?;
- Ok(())
- }
- fn collect_block(
- &mut self,
+ parent_bucket_id: BucketId,
docs: &[crate::DocId],
- agg_with_accessor: &mut AggregationsWithAccessor,
+ agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
- collector.collect_block(docs, agg_with_accessor)?;
+ collector.collect(parent_bucket_id, docs, agg_data)?;
}
Ok(())
}
- fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
+ fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for collector in &mut self.aggs {
- collector.flush(agg_with_accessor)?;
+ collector.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.prepare_max_bucket(max_bucket, agg_data)?;
}
Ok(())
}
}
impl GenericSegmentAggregationResultsCollector {
pub(crate) fn from_req_and_validate(req: &mut AggregationsWithAccessor) -> crate::Result<Self> {
let aggs = req
.aggs
.values_mut()
.enumerate()
.map(|(accessor_idx, req)| build_single_agg_segment_collector(req, accessor_idx))
.collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
Ok(GenericSegmentAggregationResultsCollector { aggs })
}
}

View File

@@ -1,121 +0,0 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Score, SegmentReader};
pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
custom_scorer: TCustomScorer,
collector: TopCollector<TScore>,
}
impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
where TScore: Clone + PartialOrd
{
pub(crate) fn new(
custom_scorer: TCustomScorer,
collector: TopCollector<TScore>,
) -> CustomScoreTopCollector<TCustomScorer, TScore> {
CustomScoreTopCollector {
custom_scorer,
collector,
}
}
}
/// A custom segment scorer makes it possible to define any kind of score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`CustomScorer`].
pub trait CustomSegmentScorer<TScore>: 'static {
/// Computes the score of a specific `doc`.
fn score(&mut self, doc: DocId) -> TScore;
}
/// `CustomScorer` makes it possible to define any kind of score.
///
/// The `CustomScorer` itself does not do much of the computation.
/// Instead, it helps construct `Self::Child` instances that will compute
/// the score at a segment scale.
pub trait CustomScorer<TScore>: Sync {
/// Type of the associated [`CustomSegmentScorer`].
type Child: CustomSegmentScorer<TScore>;
/// Builds a child scorer for a specific segment. The child scorer is associated with
/// a specific segment.
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>;
}
impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
where
TCustomScorer: CustomScorer<TScore> + Send + Sync,
TScore: 'static + PartialOrd + Clone + Send + Sync,
{
type Fruit = Vec<(TScore, DocAddress)>;
type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
) -> crate::Result<Self::Child> {
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
Ok(CustomScoreTopSegmentCollector {
segment_collector,
segment_scorer,
})
}
fn requires_scoring(&self) -> bool {
false
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> {
self.collector.merge_fruits(segment_fruits)
}
}
pub struct CustomScoreTopSegmentCollector<T, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
T: CustomSegmentScorer<TScore>,
{
segment_collector: TopSegmentCollector<TScore>,
segment_scorer: T,
}
impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync,
T: 'static + CustomSegmentScorer<TScore>,
{
type Fruit = Vec<(TScore, DocAddress)>;
fn collect(&mut self, doc: DocId, _score: Score) {
let score = self.segment_scorer.score(doc);
self.segment_collector.collect(doc, score);
}
fn harvest(self) -> Vec<(TScore, DocAddress)> {
self.segment_collector.harvest()
}
}
impl<F, TScore, T> CustomScorer<TScore> for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
T: CustomSegmentScorer<TScore>,
{
type Child = T;
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TScore> CustomSegmentScorer<TScore> for F
where F: 'static + FnMut(DocId) -> TScore
{
fn score(&mut self, doc: DocId) -> TScore {
(self)(doc)
}
}

View File

@@ -484,7 +484,6 @@ impl FacetCounts {
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use std::iter;
use columnar::Dictionary;
use rand::distributions::Uniform;

View File

@@ -12,6 +12,7 @@ use std::marker::PhantomData;
use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};
use crate::collector::{Collector, SegmentCollector};
use crate::schema::Schema;
use crate::{DocId, Score, SegmentReader};
/// The `FilterCollector` filters docs using a fast field value and a predicate.
@@ -49,13 +50,13 @@ use crate::{DocId, Score, SegmentReader};
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
- /// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2));
+ /// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2).order_by_score());
/// let top_docs = searcher.search(&query, &no_filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
///
- /// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2));
+ /// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2).order_by_score());
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
///
/// assert_eq!(filtered_top_docs.len(), 0);
@@ -104,6 +105,11 @@ where
type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.collector.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -120,6 +126,7 @@ where
segment_collector,
predicate: self.predicate.clone(),
t_predicate_value: PhantomData,
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
})
}
@@ -140,6 +147,7 @@ pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue
segment_collector: TSegmentCollector,
predicate: TPredicate,
t_predicate_value: PhantomData<TPredicateValue>,
filtered_docs: Vec<DocId>,
}
impl<TSegmentCollector, TPredicate, TPredicateValue>
@@ -176,6 +184,20 @@ where
}
}
fn collect_block(&mut self, docs: &[DocId]) {
self.filtered_docs.clear();
for &doc in docs {
// TODO: `accept_document` could be further optimized to do batch lookups of column
// values for single-valued columns.
if self.accept_document(doc) {
self.filtered_docs.push(doc);
}
}
if !self.filtered_docs.is_empty() {
self.segment_collector.collect_block(&self.filtered_docs);
}
}
fn harvest(self) -> TSegmentCollector::Fruit {
self.segment_collector.harvest()
}
@@ -218,7 +240,7 @@ where
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
- /// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
+ /// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2).order_by_score());
/// let top_docs = searcher.search(&query, &filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
@@ -258,6 +280,10 @@ where
type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.collector.check_schema(schema)
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -274,6 +300,7 @@ where
segment_collector,
predicate: self.predicate.clone(),
buffer: Vec::new(),
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
})
}
@@ -296,6 +323,7 @@ where TPredicate: 'static
segment_collector: TSegmentCollector,
predicate: TPredicate,
buffer: Vec<u8>,
filtered_docs: Vec<DocId>,
}
impl<TSegmentCollector, TPredicate> BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
@@ -334,6 +362,20 @@ where
}
}
fn collect_block(&mut self, docs: &[DocId]) {
self.filtered_docs.clear();
for &doc in docs {
// TODO: `accept_document` could be further optimized to do batch lookups of column
// values for single-valued columns.
if self.accept_document(doc) {
self.filtered_docs.push(doc);
}
}
if !self.filtered_docs.is_empty() {
self.segment_collector.collect_block(&self.filtered_docs);
}
}
fn harvest(self) -> TSegmentCollector::Fruit {
self.segment_collector.harvest()
}

View File

@@ -57,7 +57,7 @@
//! # let query_parser = QueryParser::for_index(&index, vec![title]);
//! # let query = query_parser.parse_query("diary")?;
//! let (doc_count, top_docs): (usize, Vec<(Score, DocAddress)>) =
- //! searcher.search(&query, &(Count, TopDocs::with_limit(2)))?;
+ //! searcher.search(&query, &(Count, TopDocs::with_limit(2).order_by_score()))?;
//! # Ok(())
//! # }
//! ```
@@ -83,11 +83,15 @@
use downcast_rs::impl_downcast;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
mod count_collector;
pub use self::count_collector::Count;
/// Sort keys
pub mod sort_key;
mod histogram_collector;
pub use histogram_collector::HistogramCollector;
@@ -95,16 +99,13 @@ mod multi_collector;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_collector;
pub use self::top_collector::ComparableDoc;
mod top_score_collector;
pub use self::top_collector::ComparableDoc;
pub use self::top_score_collector::{TopDocs, TopNComputer};
mod custom_score_top_collector;
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
mod tweak_score_top_collector;
pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
mod sort_key_top_collector;
pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer};
mod facet_collector;
pub use self::facet_collector::{FacetCollector, FacetCounts};
use crate::query::Weight;
@@ -145,6 +146,11 @@ pub trait Collector: Sync + Send {
/// Type of the `SegmentCollector` associated with this collector.
type Child: SegmentCollector;
/// Returns an error if the schema is not compatible with the collector.
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
Ok(())
}
/// `set_segment` is called before beginning to enumerate
/// on this segment.
fn for_segment(
@@ -170,41 +176,50 @@ pub trait Collector: Sync + Send {
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
+ let with_scoring = self.requires_scoring();
let mut segment_collector = self.for_segment(segment_ord, reader)?;
- match (reader.alive_bitset(), self.requires_scoring()) {
- (Some(alive_bitset), true) => {
- weight.for_each(reader, &mut |doc, score| {
- if alive_bitset.is_alive(doc) {
- segment_collector.collect(doc, score);
- }
- })?;
- }
- (Some(alive_bitset), false) => {
- weight.for_each_no_score(reader, &mut |docs| {
- for doc in docs.iter().cloned() {
- if alive_bitset.is_alive(doc) {
- segment_collector.collect(doc, 0.0);
- }
- }
- })?;
- }
- (None, true) => {
- weight.for_each(reader, &mut |doc, score| {
- segment_collector.collect(doc, score);
- })?;
- }
- (None, false) => {
- weight.for_each_no_score(reader, &mut |docs| {
- segment_collector.collect_block(docs);
- })?;
- }
- }
+ default_collect_segment_impl(&mut segment_collector, weight, reader, with_scoring)?;
Ok(segment_collector.harvest())
}
}
pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
segment_collector: &mut TSegmentCollector,
weight: &dyn Weight,
reader: &SegmentReader,
with_scoring: bool,
) -> crate::Result<()> {
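// The four arms below cover (segment has deleted docs) x (collector needs scores).
// Only the (no deletes, no scoring) arm can forward whole blocks via `collect_block`;
// scoring and/or an alive bitset force per-document collection.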
match (reader.alive_bitset(), with_scoring) {
(Some(alive_bitset), true) => {
weight.for_each(reader, &mut |doc, score| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
(None, true) => {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
Ok(())
}
impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCollector> {
type Fruit = Option<TSegmentCollector::Fruit>;
@@ -214,6 +229,12 @@ impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCo
}
}
fn collect_block(&mut self, docs: &[DocId]) {
if let Some(segment_collector) = self {
segment_collector.collect_block(docs);
}
}
fn harvest(self) -> Self::Fruit {
self.map(|segment_collector| segment_collector.harvest())
}
@@ -224,6 +245,13 @@ impl<TCollector: Collector> Collector for Option<TCollector> {
type Child = Option<<TCollector as Collector>::Child>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
if let Some(underlying_collector) = self {
underlying_collector.check_schema(schema)?;
}
Ok(())
}
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
@@ -299,6 +327,12 @@ where
type Fruit = (Left::Fruit, Right::Fruit);
type Child = (Left::Child, Right::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -342,6 +376,11 @@ where
self.1.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(self.0.harvest(), self.1.harvest())
}
@@ -358,6 +397,13 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
type Child = (One::Child, Two::Child, Three::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -407,6 +453,12 @@ where
self.2.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
self.2.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(self.0.harvest(), self.1.harvest(), self.2.harvest())
}
@@ -424,6 +476,14 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
type Child = (One::Child, Two::Child, Three::Child, Four::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
self.3.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -482,6 +542,13 @@ where
self.3.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
self.2.collect_block(docs);
self.3.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(
self.0.harvest(),

View File

@@ -3,6 +3,7 @@ use std::ops::Deref;
use super::{Collector, SegmentCollector};
use crate::collector::Fruit;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader, TantivyError};
/// MultiFruit keeps Fruits from every nested Collector
@@ -16,6 +17,10 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
type Fruit = Box<dyn Fruit>;
type Child = Box<dyn BoxableSegmentCollector>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -147,7 +152,7 @@ impl<TFruit: Fruit> FruitHandle<TFruit> {
/// let searcher = reader.searcher();
///
/// let mut collectors = MultiCollector::new();
- /// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2));
+ /// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
/// let count_handle = collectors.add_collector(Count);
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary").unwrap();
@@ -194,6 +199,13 @@ impl Collector for MultiCollector<'_> {
type Fruit = MultiFruit;
type Child = MultiCollectorChild;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
for collector in &self.collector_wrappers {
collector.check_schema(schema)?;
}
Ok(())
}
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
@@ -250,6 +262,12 @@ impl SegmentCollector for MultiCollectorChild {
}
}
fn collect_block(&mut self, docs: &[DocId]) {
for child in &mut self.children {
child.collect_block(docs);
}
}
fn harvest(self) -> MultiFruit {
MultiFruit {
sub_fruits: self
@@ -293,7 +311,7 @@ mod tests {
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut collectors = MultiCollector::new();
- let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2));
+ let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
let count_handler = collectors.add_collector(Count);
let mut multifruits = searcher.search(&query, &collectors).unwrap();

View File

@@ -0,0 +1,393 @@
mod order;
mod sort_by_score;
mod sort_by_static_fast_value;
mod sort_by_string;
mod sort_key_computer;
pub use order::*;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::ops::Range;
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, QueryParser};
use crate::schema::{Schema, FAST, TEXT};
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
fn make_index() -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let id = schema_builder.add_u64_field("id", FAST);
let city = schema_builder.add_text_field("city", TEXT | FAST);
let catchphrase = schema_builder.add_text_field("catchphrase", TEXT);
let altitude = schema_builder.add_f64_field("altitude", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
fn create_segment(index: &Index, docs: Vec<impl Document>) -> crate::Result<()> {
let mut index_writer = index.writer_for_tests()?;
index_writer.set_merge_policy(Box::new(NoMergePolicy));
for doc in docs {
index_writer.add_document(doc)?;
}
index_writer.commit()?;
Ok(())
}
create_segment(
&index,
vec![
doc!(
id => 0_u64,
city => "austin",
catchphrase => "Hills, Barbeque, Glow",
altitude => 149.0,
),
doc!(
id => 1_u64,
city => "greenville",
catchphrase => "Grow, Glow, Glow",
altitude => 27.0,
),
],
)?;
create_segment(
&index,
vec![doc!(
id => 2_u64,
city => "tokyo",
catchphrase => "Glow, Glow, Glow",
altitude => 40.0,
)],
)?;
create_segment(
&index,
vec![doc!(
id => 3_u64,
catchphrase => "No, No, No",
altitude => 0.0,
)],
)?;
Ok(index)
}
// NOTE: You cannot determine the SegmentIds that will be generated for Segments
// ahead of time, so DocAddresses must be mapped back to a unique id for each Searcher.
fn id_mapping(searcher: &Searcher) -> HashMap<DocAddress, u64> {
searcher
.search(&AllQuery, &DocSetCollector)
.unwrap()
.into_iter()
.map(|doc_address| {
let column = searcher.segment_readers()[doc_address.segment_ord as usize]
.fast_fields()
.u64("id")
.unwrap();
(doc_address, column.first(doc_address.doc_id).unwrap())
})
.collect()
}
#[test]
fn test_order_by_string() -> crate::Result<()> {
let index = make_index()?;
#[track_caller]
fn assert_query(
index: &Index,
order: Order,
doc_range: Range<usize>,
expected: Vec<(Option<String>, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
// Try as primitive.
let top_collector = TopDocs::for_doc_range(doc_range)
.order_by((SortByString::for_field("city"), order));
let actual = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(sort_key_opt, doc)| (sort_key_opt, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
}
assert_query(
&index,
Order::Asc,
0..4,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
(None, 3),
],
)?;
assert_query(
&index,
Order::Asc,
0..3,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
],
)?;
assert_query(
&index,
Order::Asc,
0..2,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
],
)?;
assert_query(
&index,
Order::Asc,
0..1,
vec![(Some("austin".to_string()), 0)],
)?;
assert_query(
&index,
Order::Asc,
1..3,
vec![
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
],
)?;
assert_query(
&index,
Order::Desc,
0..4,
vec![
(Some("tokyo".to_owned()), 2),
(Some("greenville".to_owned()), 1),
(Some("austin".to_owned()), 0),
(None, 3),
],
)?;
assert_query(
&index,
Order::Desc,
1..3,
vec![
(Some("greenville".to_owned()), 1),
(Some("austin".to_owned()), 0),
],
)?;
assert_query(
&index,
Order::Desc,
0..1,
vec![(Some("tokyo".to_owned()), 2)],
)?;
Ok(())
}
#[test]
fn test_order_by_f64() -> crate::Result<()> {
let index = make_index()?;
fn assert_query(
index: &Index,
order: Order,
expected: Vec<(Option<f64>, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
// Try as primitive.
let top_collector = TopDocs::with_limit(3)
.order_by((SortByStaticFastValue::<f64>::for_field("altitude"), order));
let actual = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(altitude_opt, doc)| (altitude_opt, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
}
assert_query(
&index,
Order::Asc,
vec![(Some(0.0), 3), (Some(27.0), 1), (Some(40.0), 2)],
)?;
assert_query(
&index,
Order::Desc,
vec![(Some(149.0), 0), (Some(40.0), 2), (Some(27.0), 1)],
)?;
Ok(())
}
#[test]
fn test_order_by_score() -> crate::Result<()> {
let index = make_index()?;
fn query(index: &Index, order: Order) -> crate::Result<Vec<(Score, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((SortBySimilarityScore, order));
let field = index.schema().get_field("catchphrase").unwrap();
let query_parser = QueryParser::for_index(index, vec![field]);
let text_query = query_parser.parse_query("glow")?;
Ok(searcher
.search(&text_query, &top_collector)?
.into_iter()
.map(|(score, doc)| (score, ids[&doc]))
.collect())
}
assert_eq!(
&query(&index, Order::Desc)?,
&[(0.5604893, 2), (0.4904281, 1), (0.35667497, 0),]
);
assert_eq!(
&query(&index, Order::Asc)?,
&[(0.35667497, 0), (0.4904281, 1), (0.5604893, 2),]
);
Ok(())
}
#[test]
fn test_order_by_score_then_string() -> crate::Result<()> {
let index = make_index()?;
type SortKey = (Score, Option<String>);
fn query(
index: &Index,
score_order: Order,
city_order: Order,
) -> crate::Result<Vec<(SortKey, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((
(SortBySimilarityScore, score_order),
(SortByString::for_field("city"), city_order),
));
Ok(searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(f, doc)| (f, ids[&doc]))
.collect())
}
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, Some("austin".to_owned())), 0),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("tokyo".to_owned())), 2),
((1.0, None), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, Some("tokyo".to_owned())), 2),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("austin".to_owned())), 0),
((1.0, None), 3),
]
);
Ok(())
}
use proptest::prelude::*;
proptest! {
#[test]
fn test_order_by_string_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..64_usize,
offset in 0..64_usize,
segments_terms in
proptest::collection::vec(
proptest::collection::vec(0..32_u8, 1..32_usize),
0..8_usize,
)
) {
let mut schema_builder = Schema::builder();
let city = schema_builder.add_text_field("city", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
// A Vec<Vec<u8>>, where the outer Vec represents segments, and the inner Vec
// represents terms.
for segment_terms in segments_terms.into_iter() {
for term in segment_terms.into_iter() {
let term = format!("{term:0>3}");
index_writer.add_document(doc!(
city => term,
))?;
}
index_writer.commit()?;
}
let searcher = index.reader()?.searcher();
let top_n_results = searcher.search(&AllQuery, &TopDocs::with_limit(limit)
.and_offset(offset)
.order_by_string_fast_field("city", order))?;
let all_results = searcher.search(&AllQuery, &DocSetCollector)?.into_iter().map(|doc_address| {
// Get the term for this address.
let column = searcher.segment_readers()[doc_address.segment_ord as usize].fast_fields().str("city").unwrap().unwrap();
let value = column.term_ords(doc_address.doc_id).next().map(|term_ord| {
let mut city = Vec::new();
column.dictionary().ord_to_term(term_ord, &mut city).unwrap();
String::try_from(city).unwrap()
});
(value, doc_address)
});
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
prop_assert_eq!(
expected_docs,
top_n_results
);
}
}
}

View File

@@ -0,0 +1,348 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Order, Score};
/// Comparator trait defining how documents should be ordered.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values.
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
}
/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap()
}
}
/// Sorts documents in reverse order.
///
/// If the sort key is None, it will be considered the lowest value, and will therefore appear
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In the presence of a tie, both versions will retain the higher doc ids.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;
impl<T> Comparator<T> for ReverseComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
}
/// Sorts documents in reverse order, but considers None as having the lowest value.
///
/// This is usually what is wanted when sorting by a field in ascending order.
/// For instance, on an e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;
impl<T> Comparator<Option<T>> for ReverseNoneIsLowerComparator
where ReverseComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
match (lhs_opt, rhs_opt) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
}
}
}
impl Comparator<u32> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<u64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<f64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<f32> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<i64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<String> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
/// An enum representing the different sort orders.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ComparatorEnum {
/// Natural order (See [NaturalComparator])
#[default]
Natural,
/// Reverse order (See [ReverseComparator])
Reverse,
/// Reverse order, treating None as the lowest value. (See [ReverseNoneIsLowerComparator])
ReverseNoneLower,
}
impl From<Order> for ComparatorEnum {
fn from(order: Order) -> Self {
match order {
Order::Asc => ComparatorEnum::ReverseNoneLower,
Order::Desc => ComparatorEnum::Natural,
}
}
}
impl<T> Comparator<T> for ComparatorEnum
where
ReverseNoneIsLowerComparator: Comparator<T>,
NaturalComparator: Comparator<T>,
ReverseComparator: Comparator<T>,
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
match self {
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
}
}
}
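// A small illustration (not part of the diff) of how `Order` maps to a comparator and how `None`
// behaves: with `Order::Asc` (i.e. `ReverseNoneLower`), a missing value compares as the lowest
// sort key, so it surfaces last when the top-k keeps the "greatest" keys first.
#[cfg(test)]
mod comparator_examples {
    use std::cmp::Ordering;

    use super::{Comparator, ComparatorEnum};
    use crate::Order;

    #[test]
    fn order_to_comparator_examples() {
        let asc: ComparatorEnum = Order::Asc.into();
        let desc: ComparatorEnum = Order::Desc.into();
        assert_eq!(asc, ComparatorEnum::ReverseNoneLower);
        assert_eq!(desc, ComparatorEnum::Natural);
        // Ascending price: 10.0 ranks above 20.0 ...
        assert_eq!(asc.compare(&Some(10.0), &Some(20.0)), Ordering::Greater);
        // ... and a missing price is the lowest possible key, so it ends up last.
        assert_eq!(asc.compare(&None::<f64>, &Some(10.0)), Ordering::Less);
    }
}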
impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
for (LeftComparator, RightComparator)
where
LeftComparator: Comparator<Head>,
RightComparator: Comparator<Tail>,
{
#[inline(always)]
fn compare(&self, lhs: &(Head, Tail), rhs: &(Head, Tail)) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
for (Comparator1, Comparator2, Comparator3)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
{
#[inline(always)]
fn compare(&self, lhs: &(Type1, (Type2, Type3)), rhs: &(Type1, (Type2, Type3))) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
for (Comparator1, Comparator2, Comparator3)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
{
#[inline(always)]
fn compare(&self, lhs: &(Type1, Type2, Type3), rhs: &(Type1, Type2, Type3)) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
Comparator<(Type1, (Type2, (Type3, Type4)))>
for (Comparator1, Comparator2, Comparator3, Comparator4)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
Comparator4: Comparator<Type4>,
{
#[inline(always)]
fn compare(
&self,
lhs: &(Type1, (Type2, (Type3, Type4))),
rhs: &(Type1, (Type2, (Type3, Type4))),
) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
Comparator<(Type1, Type2, Type3, Type4)>
for (Comparator1, Comparator2, Comparator3, Comparator4)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
Comparator4: Comparator<Type4>,
{
#[inline(always)]
fn compare(
&self,
lhs: &(Type1, Type2, Type3, Type4),
rhs: &(Type1, Type2, Type3, Type4),
) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
where
TSortKeyComputer: SortKeyComputer,
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
ComparatorEnum: Comparator<
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
>,
{
type SortKey = TSortKeyComputer::SortKey;
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
type Comparator = ComparatorEnum;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
}
fn comparator(&self) -> Self::Comparator {
self.1
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {
segment_sort_key_computer: child,
comparator: self.comparator(),
})
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, Order)
where
TSortKeyComputer: SortKeyComputer,
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
ComparatorEnum: Comparator<
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
>,
{
type SortKey = TSortKeyComputer::SortKey;
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
type Comparator = ComparatorEnum;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
}
fn comparator(&self) -> Self::Comparator {
self.1.into()
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {
segment_sort_key_computer: child,
comparator: self.comparator(),
})
}
}
/// A segment sort key computer with a custom ordering.
pub struct SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator> {
segment_sort_key_computer: TSegmentSortKeyComputer,
comparator: TComparator,
}
impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComputer
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
{
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score)
}
#[inline(always)]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.comparator.compare(left, right)
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
self.segment_sort_key_computer
.convert_segment_sort_key(sort_key)
}
}

View File

@@ -0,0 +1,77 @@
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
#[derive(Clone, Debug, Copy)]
pub struct SortBySimilarityScore;
impl SortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type Child = SortBySimilarityScore;
type Comparator = NaturalComparator;
fn requires_scoring(&self) -> bool {
true
}
fn segment_sort_key_computer(
&self,
_segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScore)
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
fn collect_segment_top_k(
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
TopNComputer::new_with_comparator(k, self.comparator());
if let Some(alive_bitset) = reader.alive_bitset() {
let mut threshold = Score::MIN;
top_n.threshold = Some(threshold);
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
if alive_bitset.is_deleted(doc) {
return threshold;
}
top_n.push(score, doc);
threshold = top_n.threshold.unwrap_or(Score::MIN);
threshold
})?;
} else {
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
top_n.push(score, doc);
top_n.threshold.unwrap_or(Score::MIN)
})?;
}
Ok(top_n
.into_vec()
.into_iter()
.map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
.collect())
}
}
impl SegmentSortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type SegmentSortKey = Score;
#[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
score
}
fn convert_segment_sort_key(&self, score: Score) -> Score {
score
}
}
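// Usage sketch (not part of the diff; it mirrors the tests in `collector::sort_key::tests`):
// `SortBySimilarityScore` plugs into the `TopDocs::order_by` API, with the direction chosen
// via `Order`.
#[allow(dead_code)]
fn top_10_by_score_example(
    searcher: &crate::Searcher,
    query: &dyn crate::query::Query,
) -> crate::Result<Vec<(Score, DocAddress)>> {
    let top_docs = crate::collector::TopDocs::with_limit(10)
        .order_by((SortBySimilarityScore, crate::Order::Desc));
    searcher.search(query, &top_docs)
}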

View File

@@ -0,0 +1,98 @@
use std::marker::PhantomData;
use columnar::Column;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::{DocId, Score, SegmentReader};
/// Sorts by a fast value (u64, i64, f64, bool).
///
/// The field must appear explicitly in the schema, with the right type, and be declared as
/// a fast field.
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByStaticFastValue<T: FastValue> {
field: String,
typ: PhantomData<T>,
}
impl<T: FastValue> SortByStaticFastValue<T> {
/// Creates a new `SortByStaticFastValue` instance for the given field.
pub fn for_field(column_name: impl ToString) -> SortByStaticFastValue<T> {
Self {
field: column_name.to_string(),
typ: PhantomData,
}
}
}
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
type Child = SortByFastValueSegmentSortKeyComputer<T>;
type SortKey = Option<T>;
type Comparator = NaturalComparator;
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
// At the segment sort key computer level, we rely on the u64 representation.
// The mapping is monotonic, so it is sufficient to compute our top-K docs.
let field = schema.get_field(&self.field)?;
let field_entry = schema.get_field_entry(field);
if !field_entry.is_fast() {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is not a fast field.",
self.field,
)));
}
let schema_type = field_entry.field_type().value_type();
if schema_type != T::to_type() {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {schema_type:?}, not of the type {:?}.",
&self.field,
T::to_type()
)));
}
Ok(())
}
fn segment_sort_key_computer(
&self,
segment_reader: &SegmentReader,
) -> crate::Result<Self::Child> {
let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
let (sort_column, _sort_column_type) =
sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
field_name: self.field.clone(),
})?;
Ok(SortByFastValueSegmentSortKeyComputer {
sort_column,
typ: PhantomData,
})
}
}
pub struct SortByFastValueSegmentSortKeyComputer<T> {
sort_column: Column<u64>,
typ: PhantomData<T>,
}
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
type SortKey = Option<T>;
type SegmentSortKey = Option<u64>;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
self.sort_column.first(doc)
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key.map(T::from_u64)
}
}
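// Usage sketch (not part of the diff; it mirrors `test_order_by_f64` in the sort_key tests):
// sort by a typed fast field, ascending. The field name "altitude" is assumed to be declared
// as a fast f64 field; documents without a value get a `None` sort key.
#[allow(dead_code)]
fn lowest_altitude_first_example(
    searcher: &crate::Searcher,
    query: &dyn crate::query::Query,
) -> crate::Result<Vec<(Option<f64>, crate::DocAddress)>> {
    let top_docs = crate::collector::TopDocs::with_limit(3)
        .order_by((SortByStaticFastValue::<f64>::for_field("altitude"), crate::Order::Asc));
    searcher.search(query, &top_docs)
}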

View File

@@ -0,0 +1,72 @@
use columnar::StrColumn;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
/// Sort by the first value of a string column.
///
/// The string can be dynamic (coming from a json field)
/// or static (being specifically defined in the configuration).
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByString {
column_name: String,
}
impl SortByString {
/// Creates a new sort by string sort key computer.
pub fn for_field(column_name: impl ToString) -> Self {
SortByString {
column_name: column_name.to_string(),
}
}
}
impl SortKeyComputer for SortByString {
type SortKey = Option<String>;
type Child = ByStringColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
}
}
pub struct ByStringColumnSegmentSortKeyComputer {
str_column_opt: Option<StrColumn>,
}
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
let str_column = self.str_column_opt.as_ref()?;
str_column.ords().first(doc)
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
let term_ord = term_ord_opt?;
let str_column = self.str_column_opt.as_ref()?;
let mut bytes = Vec::new();
str_column
.dictionary()
.ord_to_term(term_ord, &mut bytes)
.ok()?;
String::try_from(bytes).ok()
}
}
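// Usage sketch (not part of the diff; it mirrors `test_order_by_string` in the sort_key tests):
// `TopDocs::for_doc_range` combined with `SortByString` returns a window of documents ordered
// by the first value of the "city" string fast field; documents without a value sort with a
// `None` key.
#[allow(dead_code)]
fn city_window_example(
    searcher: &crate::Searcher,
    query: &dyn crate::query::Query,
) -> crate::Result<Vec<(Option<String>, crate::DocAddress)>> {
    let top_docs = crate::collector::TopDocs::for_doc_range(0..4)
        .order_by((SortByString::for_field("city"), crate::Order::Asc));
    searcher.search(query, &top_docs)
}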

View File

@@ -0,0 +1,631 @@
use std::cmp::Ordering;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
/// A `SegmentSortKeyComputer` makes it possible to modify the default score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static {
/// The final sort key being emitted.
type SortKey: 'static + PartialOrd + Send + Sync + Clone;
/// Sort key used at the segment level by the `SegmentSortKeyComputer`.
///
/// It is typically small, like a `u64`, and is meant to be converted
/// to the final sort key at the end of the segment's collection.
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync;
/// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
/// Computes the sort key and pushes the document in a TopN Computer.
///
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
let sort_key = self.segment_sort_key(doc, score);
top_n_computer.push(sort_key, doc);
}
/// A SegmentSortKeyComputer maps documents to a SegmentSortKey, but it can also decide on
/// its ordering.
///
/// This method must be consistent with the `SortKey` ordering.
#[inline(always)]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
NaturalComparator.compare(left, right)
}
/// Implementing this method makes it possible to avoid computing
/// a sort_key entirely if we can assess that it won't pass a threshold
/// with a partial computation.
///
/// This is currently used for lexicographic sorting.
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let sort_key = self.segment_sort_key(doc_id, score);
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
if cmp == Ordering::Less {
None
} else {
Some((cmp, sort_key))
}
}
/// Convert a segment level sort key into the global sort key.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
}
/// `SortKeyComputer` defines the sort key to be used by a TopK Collector.
///
/// The `SortKeyComputer` itself does not do much of the computation.
/// Instead, it helps construct `Self::Child` instances that will compute
/// the sort key at the segment scale.
pub trait SortKeyComputer: Sync {
/// The sort key type.
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
/// Type of the associated [`SegmentSortKeyComputer`].
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
/// Comparator type.
type Comparator: Comparator<Self::SortKey>
+ Comparator<<Self::Child as SegmentSortKeyComputer>::SegmentSortKey>
+ 'static;
/// Checks whether the schema is compatible with the sort key computer.
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
Ok(())
}
/// Returns the sort key comparator.
fn comparator(&self) -> Self::Comparator {
Self::Comparator::default()
}
/// Indicates whether the sort key actually uses the similarity score (by default BM25).
/// If set to false, the similarity score might not be computed (as an optimization),
/// and the score fed into the segment sort key computer could take any value.
fn requires_scoring(&self) -> bool {
false
}
/// Sorting by score has an overriding implementation for BM25 scores, using Block-WAND.
fn collect_segment_top_k(
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let with_scoring = self.requires_scoring();
let segment_sort_key_computer = self.segment_sort_key_computer(reader)?;
let topn_computer = TopNComputer::new_with_comparator(k, self.comparator());
let mut segment_top_key_collector = TopBySortKeySegmentCollector {
topn_computer,
segment_ord,
segment_sort_key_computer,
};
default_collect_segment_impl(&mut segment_top_key_collector, weight, reader, with_scoring)?;
Ok(segment_top_key_collector.harvest())
}
/// Builds a child sort key computer for a specific segment.
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}
impl<HeadSortKeyComputer, TailSortKeyComputer> SortKeyComputer
for (HeadSortKeyComputer, TailSortKeyComputer)
where
HeadSortKeyComputer: SortKeyComputer,
TailSortKeyComputer: SortKeyComputer,
{
type SortKey = (
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type Comparator = (
HeadSortKeyComputer::Comparator,
TailSortKeyComputer::Comparator,
);
fn comparator(&self) -> Self::Comparator {
(self.0.comparator(), self.1.comparator())
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
))
}
/// Checks whether the schema is compatible with the sort key computer.
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
Ok(())
}
/// Indicates whether the sort key actually uses the similarity score (by default BM25).
/// If set to false, the similarity score might not be computed (as an optimization),
/// and the score fed into the segment sort key computer could take any value.
fn requires_scoring(&self) -> bool {
self.0.requires_scoring() || self.1.requires_scoring()
}
}
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
where
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
{
type SortKey = (
HeadSegmentSortKeyComputer::SortKey,
TailSegmentSortKeyComputer::SortKey,
);
type SegmentSortKey = (
HeadSegmentSortKeyComputer::SegmentSortKey,
TailSegmentSortKeyComputer::SegmentSortKey,
);
/// A SegmentSortKeyComputer maps documents to a SegmentSortKey, but it can also decide on
/// its ordering.
///
/// By default, it uses the natural ordering.
#[inline]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.0
.compare_segment_sort_key(&left.0, &right.0)
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
}
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
let sort_key: Self::SegmentSortKey;
if let Some(threshold) = &top_n_computer.threshold {
if let Some((_cmp, lazy_sort_key)) = self.accept_sort_key_lazy(doc, score, threshold) {
sort_key = lazy_sort_key;
} else {
return;
}
} else {
sort_key = self.segment_sort_key(doc, score);
};
top_n_computer.append_doc(doc, sort_key);
}
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.0.segment_sort_key(doc, score);
let tail_sort_key = self.1.segment_sort_key(doc, score);
(head_sort_key, tail_sort_key)
}
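// Lazy evaluation for composite keys: the head is checked against the head threshold first.
// Only when the head ties with the threshold does the tail need to be evaluated lazily as
// well; if the head already wins outright, the tail is computed eagerly just to complete the
// full sort key.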
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let (head_cmp, head_sort_key) =
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
if head_cmp == Ordering::Equal {
let (tail_cmp, tail_sort_key) =
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
Some((tail_cmp, (head_sort_key, tail_sort_key)))
} else {
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
let (head_sort_key, tail_sort_key) = sort_key;
(
self.0.convert_segment_sort_key(head_sort_key),
self.1.convert_segment_sort_key(tail_sort_key),
)
}
}
/// This struct is used as an adapter to take a sort key computer and map its sort key to a
/// new sort key.
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
sort_key_computer: T,
map: fn(PreviousSortKey) -> NewSortKey,
}
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
where
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
NewScore: 'static + Clone + Send + Sync + PartialOrd,
{
type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score)
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
self.sort_key_computer
.accept_sort_key_lazy(doc_id, score, threshold)
}
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
self.sort_key_computer
.compute_sort_key_and_collect(doc, score, top_n_computer);
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
(self.map)(
self.sort_key_computer
.convert_segment_sort_key(segment_sort_key),
)
}
}
// We then re-use our (head, tail) implementation and our mapper by viewing any tuple (a, b, c,
// ...) as the chain (a, (b, (c, ...)))
impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3> SortKeyComputer
for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3)
where
SortKeyComputer1: SortKeyComputer,
SortKeyComputer2: SortKeyComputer,
SortKeyComputer3: SortKeyComputer,
{
type SortKey = (
SortKeyComputer1::SortKey,
SortKeyComputer2::SortKey,
SortKeyComputer3::SortKey,
);
type Child = MappedSegmentSortKeyComputer<
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
),
Self::SortKey,
>;
type Comparator = (
SortKeyComputer1::Comparator,
SortKeyComputer2::Comparator,
SortKeyComputer3::Comparator,
);
fn comparator(&self) -> Self::Comparator {
(
self.0.comparator(),
self.1.comparator(),
self.2.comparator(),
)
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
map,
})
}
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
Ok(())
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring() || self.1.requires_scoring() || self.2.requires_scoring()
}
}
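// Illustrative sketch (not part of the diff): the `map` stored in the
// `MappedSegmentSortKeyComputer` above is nothing more than the flattener from
// the nested (head, (tail, ...)) representation back to the user-facing tuple.
// Module and function names below are hypothetical.
#[cfg(test)]
mod chain_flattening_sketch {
    #[test]
    fn flatten_nested_triple() {
        // The 3-tuple impl stores its children as (c1, (c2, c3)) and flattens
        // the resulting sort keys with exactly this kind of closure.
        let map = |(a, (b, c)): (u32, (u32, u32))| (a, b, c);
        assert_eq!(map((1, (2, 3))), (1, 2, 3));
    }
}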
impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3, SortKeyComputer4> SortKeyComputer
for (
SortKeyComputer1,
SortKeyComputer2,
SortKeyComputer3,
SortKeyComputer4,
)
where
SortKeyComputer1: SortKeyComputer,
SortKeyComputer2: SortKeyComputer,
SortKeyComputer3: SortKeyComputer,
SortKeyComputer4: SortKeyComputer,
{
type Child = MappedSegmentSortKeyComputer<
<(
SortKeyComputer1,
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(
SortKeyComputer2::SortKey,
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
),
),
Self::SortKey,
>;
type SortKey = (
SortKeyComputer1::SortKey,
SortKeyComputer2::SortKey,
SortKeyComputer3::SortKey,
SortKeyComputer4::SortKey,
);
type Comparator = (
SortKeyComputer1::Comparator,
SortKeyComputer2::Comparator,
SortKeyComputer3::Comparator,
SortKeyComputer4::Comparator,
);
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (
sort_key_computer1,
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
),
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
(sort_key1, sort_key2, sort_key3, sort_key4)
},
})
}
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
self.3.check_schema(schema)?;
Ok(())
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
|| self.1.requires_scoring()
|| self.2.requires_scoring()
|| self.3.requires_scoring()
}
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
SegmentF: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
type Child = SegmentF;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TSortKey> SegmentSortKeyComputer for F
where
F: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self)(doc)
}
/// Converts a segment-level sort key into the global sort key.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key
}
}
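// Illustrative sketch (not part of the diff): thanks to the two blanket impls
// above, a plain closure factory can serve as a `SortKeyComputer`. The field
// name "popularity" and the exact fast field accessor are assumptions made for
// the sake of the example.
//
//     let by_popularity = |segment_reader: &SegmentReader| {
//         let popularity = segment_reader
//             .fast_fields()
//             .u64("popularity")
//             .expect("u64 fast field");
//         move |doc: DocId| popularity.first(doc).unwrap_or(0)
//     };
//
// Such a closure factory can be handed to any collector expecting a
// `SortKeyComputer`, just like the tuples of closures exercised in the tests below.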
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Index, Order, SegmentReader};
fn build_test_index() -> Index {
let schema = Schema::builder().build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::TantivyDocument::default())
.unwrap();
index_writer.commit().unwrap();
index
}
#[test]
fn test_lazy_score_computer() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
"b"
}
};
let lazy_score_computer = (score_computer_primary, score_computer_secondary);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let expected_sort_key = (200, "b");
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "a"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "c"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "a"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "c"));
assert!(sort_key_opt.is_none());
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "a"));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "c"));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
}
}
#[test]
fn test_lazy_score_computer_dynamic_ordering() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
2u32
}
};
let lazy_score_computer = (
(score_computer_primary, Order::Desc),
(score_computer_secondary, Order::Asc),
);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let expected_sort_key = (200, 2u32);
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 1u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 3u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 1u32));
assert!(sort_key_opt.is_none());
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 3u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 1u32));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 3u32));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
}
assert_eq!(
segment_sort_key_computer.convert_segment_sort_key(expected_sort_key),
(200u32, 2u32)
);
}
}


@@ -0,0 +1,193 @@
use std::ops::Range;
use crate::collector::sort_key::{Comparator, SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{Collector, SegmentCollector, TopNComputer};
use crate::query::Weight;
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
pub(crate) struct TopBySortKeyCollector<TSortKeyComputer> {
sort_key_computer: TSortKeyComputer,
doc_range: Range<usize>,
}
impl<TSortKeyComputer> TopBySortKeyCollector<TSortKeyComputer> {
pub fn new(sort_key_computer: TSortKeyComputer, doc_range: Range<usize>) -> Self {
TopBySortKeyCollector {
sort_key_computer,
doc_range,
}
}
}
impl<TSortKeyComputer> Collector for TopBySortKeyCollector<TSortKeyComputer>
where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
{
type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>;
type Child =
TopBySortKeySegmentCollector<TSortKeyComputer::Child, TSortKeyComputer::Comparator>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.sort_key_computer.check_schema(schema)
}
fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result<Self::Child> {
let segment_sort_key_computer = self
.sort_key_computer
.segment_sort_key_computer(segment_reader)?;
let topn_computer = TopNComputer::new_with_comparator(
self.doc_range.end,
self.sort_key_computer.comparator(),
);
Ok(TopBySortKeySegmentCollector {
topn_computer,
segment_ord,
segment_sort_key_computer,
})
}
fn requires_scoring(&self) -> bool {
self.sort_key_computer.requires_scoring()
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
Ok(merge_top_k(
segment_fruits.into_iter().flatten(),
self.doc_range.clone(),
self.sort_key_computer.comparator(),
))
}
fn collect_segment(
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<Vec<(TSortKeyComputer::SortKey, DocAddress)>> {
let k = self.doc_range.end;
let docs = self
.sort_key_computer
.collect_segment_top_k(k, weight, reader, segment_ord)?;
Ok(docs)
}
}
fn merge_top_k<D: Ord, TSortKey: Clone + std::fmt::Debug, C: Comparator<TSortKey>>(
sort_key_docs: impl Iterator<Item = (TSortKey, D)>,
doc_range: Range<usize>,
comparator: C,
) -> Vec<(TSortKey, D)> {
if doc_range.is_empty() {
return Vec::new();
}
let mut top_collector: TopNComputer<TSortKey, D, C> =
TopNComputer::new_with_comparator(doc_range.end, comparator);
for (sort_key, doc) in sort_key_docs {
top_collector.push(sort_key, doc);
}
top_collector
.into_sorted_vec()
.into_iter()
.skip(doc_range.start)
.map(|cdoc| (cdoc.sort_key, cdoc.doc))
.collect()
}
pub struct TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
{
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
pub(crate) segment_ord: u32,
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
}
impl<TSegmentSortKeyComputer, C> SegmentCollector
for TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
where
TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey> + 'static,
{
type Fruit = Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)>;
fn collect(&mut self, doc: DocId, score: Score) {
self.segment_sort_key_computer.compute_sort_key_and_collect(
doc,
score,
&mut self.topn_computer,
);
}
fn harvest(self) -> Self::Fruit {
let segment_ord = self.segment_ord;
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self
.topn_computer
.into_vec()
.into_iter()
.map(|comparable_doc| {
let sort_key = self
.segment_sort_key_computer
.convert_segment_sort_key(comparable_doc.sort_key);
(
sort_key,
DocAddress {
segment_ord,
doc_id: comparable_doc.doc,
},
)
})
.collect();
segment_hits
}
}
#[cfg(test)]
mod tests {
use std::ops::Range;
use rand;
use rand::seq::SliceRandom as _;
use super::merge_top_k;
use crate::collector::sort_key::ComparatorEnum;
use crate::Order;
fn test_merge_top_k_aux(
order: Order,
doc_range: Range<usize>,
expected: &[(crate::Score, usize)],
) {
let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
vals.shuffle(&mut rand::thread_rng());
let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
assert_eq!(&vals_merged, expected);
}
#[test]
fn test_merge_top_k() {
test_merge_top_k_aux(Order::Asc, 0..0, &[]);
test_merge_top_k_aux(Order::Asc, 3..3, &[]);
test_merge_top_k_aux(Order::Asc, 0..3, &[(0.0f32, 0), (1.0f32, 1), (2.0f32, 2)]);
test_merge_top_k_aux(
Order::Asc,
0..11,
&[
(0.0f32, 0),
(1.0f32, 1),
(2.0f32, 2),
(3.0f32, 3),
(4.0f32, 4),
(5.0f32, 5),
(6.0f32, 6),
(7.0f32, 7),
(8.0f32, 8),
(9.0f32, 9),
],
);
test_merge_top_k_aux(Order::Asc, 1..3, &[(1.0f32, 1), (2.0f32, 2)]);
test_merge_top_k_aux(Order::Desc, 0..2, &[(9.0f32, 9), (8.0f32, 8)]);
test_merge_top_k_aux(Order::Desc, 2..4, &[(7.0f32, 7), (6.0f32, 6)]);
}
}


@@ -40,7 +40,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
let filter_some_collector = FilterCollector::new(
"price".to_string(),
&|value: u64| value > 20_120u64,
TopDocs::with_limit(2),
TopDocs::with_limit(2).order_by_score(),
);
let top_docs = searcher.search(&query, &filter_some_collector)?;
@@ -50,7 +50,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(
"price".to_string(),
&|value| value < 5u64,
TopDocs::with_limit(2),
TopDocs::with_limit(2).order_by_score(),
);
let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
@@ -62,8 +62,11 @@ pub fn test_filter_collector() -> crate::Result<()> {
> 0
}
let filter_dates_collector =
FilterCollector::new("date".to_string(), &date_filter, TopDocs::with_limit(5));
let filter_dates_collector = FilterCollector::new(
"date".to_string(),
&date_filter,
TopDocs::with_limit(5).order_by_score(),
);
let filtered_date_docs = searcher.search(&query, &filter_dates_collector)?;
assert_eq!(filtered_date_docs.len(), 2);
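// Editorial note (not part of the diff): the hunks in this test file (and the
// date-time test further down) show call sites updated from
// `TopDocs::with_limit(n)` to `TopDocs::with_limit(n).order_by_score()`, i.e.
// the ordering is now stated explicitly:
//
//     // before
//     searcher.search(&query, &TopDocs::with_limit(2))?;
//     // after
//     searcher.search(&query, &TopDocs::with_limit(2).order_by_score())?;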


@@ -1,12 +1,7 @@
use std::cmp::Ordering;
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
use super::top_score_collector::TopNComputer;
use crate::index::SegmentReader;
use crate::{DocAddress, DocId, SegmentOrdinal};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// It guarantees stable sorting: in case of a tie on the feature, the document
@@ -19,7 +14,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal};
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// The feature of the document. In practice, this is
/// any type that implements `PartialOrd`.
pub feature: T,
pub sort_key: T,
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
@@ -28,9 +23,9 @@ pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct(format!("ComparableDoc<_, _, {R}>").as_str())
.field("feature", &self.feature)
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
@@ -46,8 +41,8 @@ impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R>
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.feature
.partial_cmp(&other.feature)
.sort_key
.partial_cmp(&other.sort_key)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
@@ -67,308 +62,3 @@ impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T,
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}
pub(crate) struct TopCollector<T> {
pub limit: usize,
pub offset: usize,
_marker: PhantomData<T>,
}
impl<T> TopCollector<T>
where T: PartialOrd + Clone
{
/// Creates a top collector, with a number of documents equal to "limit".
///
/// # Panics
/// The method panics if limit is 0
pub fn with_limit(limit: usize) -> TopCollector<T> {
assert!(limit >= 1, "Limit must be strictly greater than 0.");
Self {
limit,
offset: 0,
_marker: PhantomData,
}
}
/// Skip the first "offset" documents when collecting.
///
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
/// Lucene's TopDocsCollector.
pub fn and_offset(mut self, offset: usize) -> TopCollector<T> {
self.offset = offset;
self
}
pub fn merge_fruits(
&self,
children: Vec<Vec<(T, DocAddress)>>,
) -> crate::Result<Vec<(T, DocAddress)>> {
if self.limit == 0 {
return Ok(Vec::new());
}
let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset);
for child_fruit in children {
for (feature, doc) in child_fruit {
top_collector.push(feature, doc);
}
}
Ok(top_collector
.into_sorted_vec()
.into_iter()
.skip(self.offset)
.map(|cdoc| (cdoc.feature, cdoc.doc))
.collect())
}
pub(crate) fn for_segment<F: PartialOrd + Clone>(
&self,
segment_id: SegmentOrdinal,
_: &SegmentReader,
) -> TopSegmentCollector<F> {
TopSegmentCollector::new(segment_id, self.limit + self.offset)
}
/// Create a new TopCollector with the same limit and offset.
///
/// Ideally we would use `Into`, but the blanket implementation seems to cause the `Scorer`
/// traits to fail.
#[doc(hidden)]
pub(crate) fn into_tscore<TScore: PartialOrd + Clone>(self) -> TopCollector<TScore> {
TopCollector {
limit: self.limit,
offset: self.offset,
_marker: PhantomData,
}
}
}
/// The Top Collector keeps track of the K documents
/// sorted by type `T`.
///
/// The implementation is based on repeatedly truncating at the median once `K * 2` documents
/// have been collected (see the sketch below). The theoretical complexity for collecting the
/// top `K` out of `n` documents is `O(n + K)`.
pub(crate) struct TopSegmentCollector<T> {
/// We reverse the order of the feature in order to
/// have top semantics instead of bottom semantics.
topn_computer: TopNComputer<T, DocId>,
segment_ord: u32,
}
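// Illustrative sketch (not part of the diff) of the median-truncation idea
// referred to in the doc comment above: keep a buffer of up to `2 * K` entries
// and, whenever it fills up, keep only the top `K` with a selection pass.
// Amortized over `n` pushes this gives roughly the `O(n + K)` behaviour quoted
// above. Names below are hypothetical.
#[cfg(test)]
mod median_truncation_sketch {
    #[test]
    fn top_k_by_truncation() {
        let k = 3;
        let mut buffer: Vec<u32> = Vec::with_capacity(2 * k);
        for value in [5u32, 1, 9, 4, 7, 3, 8, 2, 6, 0] {
            if buffer.len() == 2 * k {
                // Selection pass: move the K largest values to the front, drop the rest.
                buffer.select_nth_unstable_by(k - 1, |a, b| b.cmp(a));
                buffer.truncate(k);
            }
            buffer.push(value);
        }
        // Final pass: sort what is left in descending order and keep the top K.
        buffer.sort_unstable_by(|a, b| b.cmp(a));
        buffer.truncate(k);
        assert_eq!(buffer, vec![9, 8, 7]);
    }
}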
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
TopSegmentCollector {
topn_computer: TopNComputer::new(limit),
segment_ord,
}
}
}
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
pub fn harvest(self) -> Vec<(T, DocAddress)> {
let segment_ord = self.segment_ord;
self.topn_computer
.into_sorted_vec()
.into_iter()
.map(|comparable_doc| {
(
comparable_doc.feature,
DocAddress {
segment_ord,
doc_id: comparable_doc.doc,
},
)
})
.collect()
}
/// Collects a document scored by the given feature
///
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
/// will compare the lowest scoring item with the given one and keep whichever is greater.
#[inline]
pub fn collect(&mut self, doc: DocId, feature: T) {
self.topn_computer.push(feature, doc);
}
}
#[cfg(test)]
mod tests {
use super::{TopCollector, TopSegmentCollector};
use crate::DocAddress;
#[test]
fn test_top_collector_not_at_capacity() {
let mut top_collector = TopSegmentCollector::new(0, 4);
top_collector.collect(1, 0.8);
top_collector.collect(3, 0.2);
top_collector.collect(5, 0.3);
assert_eq!(
top_collector.harvest(),
vec![
(0.8, DocAddress::new(0, 1)),
(0.3, DocAddress::new(0, 5)),
(0.2, DocAddress::new(0, 3))
]
);
}
#[test]
fn test_top_collector_at_capacity() {
let mut top_collector = TopSegmentCollector::new(0, 4);
top_collector.collect(1, 0.8);
top_collector.collect(3, 0.2);
top_collector.collect(5, 0.3);
top_collector.collect(7, 0.9);
top_collector.collect(9, -0.2);
assert_eq!(
top_collector.harvest(),
vec![
(0.9, DocAddress::new(0, 7)),
(0.8, DocAddress::new(0, 1)),
(0.3, DocAddress::new(0, 5)),
(0.2, DocAddress::new(0, 3))
]
);
}
#[test]
fn test_top_segment_collector_stable_ordering_for_equal_feature() {
// given that the documents are collected in ascending doc id order,
// when harvesting we have to guarantee stable sorting in case of a tie
// on the score
let doc_ids_collection = [4, 5, 6];
let score = 3.3f32;
let mut top_collector_limit_2 = TopSegmentCollector::new(0, 2);
for id in &doc_ids_collection {
top_collector_limit_2.collect(*id, score);
}
let mut top_collector_limit_3 = TopSegmentCollector::new(0, 3);
for id in &doc_ids_collection {
top_collector_limit_3.collect(*id, score);
}
assert_eq!(
top_collector_limit_2.harvest(),
top_collector_limit_3.harvest()[..2].to_vec(),
);
}
#[test]
fn test_top_collector_with_limit_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
(0.7, DocAddress::new(0, 3)),
(0.6, DocAddress::new(0, 4)),
(0.5, DocAddress::new(0, 5)),
]])
.unwrap();
assert_eq!(
results,
vec![(0.8, DocAddress::new(0, 2)), (0.7, DocAddress::new(0, 3)),]
);
}
#[test]
fn test_top_collector_with_limit_larger_than_set_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
]])
.unwrap();
assert_eq!(results, vec![(0.8, DocAddress::new(0, 2)),]);
}
#[test]
fn test_top_collector_with_limit_and_offset_larger_than_set() {
let collector = TopCollector::with_limit(2).and_offset(20);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
]])
.unwrap();
assert_eq!(results, vec![]);
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use test::Bencher;
use super::TopSegmentCollector;
#[bench]
fn bench_top_segment_collector_collect_not_at_capacity(b: &mut Bencher) {
let mut top_collector = TopSegmentCollector::new(0, 400);
b.iter(|| {
for i in 0..100 {
top_collector.collect(i, 0.8);
}
});
}
#[bench]
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
let mut top_collector = TopSegmentCollector::new(0, 100);
for i in 0..100 {
top_collector.collect(i, 0.8);
}
b.iter(|| {
for i in 0..100 {
top_collector.collect(i, 0.8);
}
});
}
#[bench]
fn bench_top_segment_collector_collect_and_harvest_many_ties(b: &mut Bencher) {
b.iter(|| {
let mut top_collector = TopSegmentCollector::new(0, 100);
for i in 0..100 {
top_collector.collect(i, 0.8);
}
// it would be nice to be able to do the setup N times but still
// measure only harvest(). We can't since harvest() consumes
// the top_collector.
top_collector.harvest()
});
}
#[bench]
fn bench_top_segment_collector_collect_and_harvest_no_tie(b: &mut Bencher) {
b.iter(|| {
let mut top_collector = TopSegmentCollector::new(0, 100);
let mut score = 1.0;
for i in 0..100 {
score += 1.0;
top_collector.collect(i, score);
}
// it would be nice to be able to do the setup N times but still
// measure only harvest(). We can't since harvest() consumes
// the top_collector.
top_collector.harvest()
});
}
}

File diff suppressed because it is too large


@@ -1,124 +0,0 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
pub(crate) struct TweakedScoreTopCollector<TScoreTweaker, TScore = Score> {
score_tweaker: TScoreTweaker,
collector: TopCollector<TScore>,
}
impl<TScoreTweaker, TScore> TweakedScoreTopCollector<TScoreTweaker, TScore>
where TScore: Clone + PartialOrd
{
pub fn new(
score_tweaker: TScoreTweaker,
collector: TopCollector<TScore>,
) -> TweakedScoreTopCollector<TScoreTweaker, TScore> {
TweakedScoreTopCollector {
score_tweaker,
collector,
}
}
}
/// A `ScoreSegmentTweaker` makes it possible to modify the default score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`ScoreTweaker`].
pub trait ScoreSegmentTweaker<TScore>: 'static {
/// Tweak the given `score` for the document `doc`.
fn score(&mut self, doc: DocId, score: Score) -> TScore;
}
/// `ScoreTweaker` makes it possible to tweak the score
/// emitted by the scorer into another one.
///
/// The `ScoreTweaker` itself does not do much of the computation.
/// Instead, it helps construct `Self::Child` instances that will compute
/// the score at the segment scale.
pub trait ScoreTweaker<TScore>: Sync {
/// Type of the associated [`ScoreSegmentTweaker`].
type Child: ScoreSegmentTweaker<TScore>;
/// Builds a child tweaker for a specific segment. The child scorer is associated with
/// a specific segment.
fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}
impl<TScoreTweaker, TScore> Collector for TweakedScoreTopCollector<TScoreTweaker, TScore>
where
TScoreTweaker: ScoreTweaker<TScore> + Send + Sync,
TScore: 'static + PartialOrd + Clone + Send + Sync,
{
type Fruit = Vec<(TScore, DocAddress)>;
type Child = TopTweakedScoreSegmentCollector<TScoreTweaker::Child, TScore>;
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
) -> Result<Self::Child> {
let segment_scorer = self.score_tweaker.segment_tweaker(segment_reader)?;
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
Ok(TopTweakedScoreSegmentCollector {
segment_collector,
segment_scorer,
})
}
fn requires_scoring(&self) -> bool {
true
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
self.collector.merge_fruits(segment_fruits)
}
}
pub struct TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
{
segment_collector: TopSegmentCollector<TScore>,
segment_scorer: TSegmentScoreTweaker,
}
impl<TSegmentScoreTweaker, TScore> SegmentCollector
for TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync,
TSegmentScoreTweaker: 'static + ScoreSegmentTweaker<TScore>,
{
type Fruit = Vec<(TScore, DocAddress)>;
fn collect(&mut self, doc: DocId, score: Score) {
let score = self.segment_scorer.score(doc, score);
self.segment_collector.collect(doc, score);
}
fn harvest(self) -> Vec<(TScore, DocAddress)> {
self.segment_collector.harvest()
}
}
impl<F, TScore, TSegmentScoreTweaker> ScoreTweaker<TScore> for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker,
TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
{
type Child = TSegmentScoreTweaker;
fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TScore> ScoreSegmentTweaker<TScore> for F
where F: 'static + FnMut(DocId, Score) -> TScore
{
fn score(&mut self, doc: DocId, score: Score) -> TScore {
(self)(doc, score)
}
}
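// Illustrative sketch (not part of the diff): with the two blanket impls above,
// a pair of nested closures was enough to act as a `ScoreTweaker`, e.g. scaling
// the base score by a constant factor (names hypothetical):
//
//     let boost_tweaker = |_segment_reader: &SegmentReader| {
//         move |_doc: DocId, score: Score| score * 2.0
//     };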


@@ -69,7 +69,7 @@ fn assert_date_time_precision(index: &Index, doc_store_precision: DateTimePrecis
.parse_query("dateformat")
.expect("Failed to parse query");
let top_docs = searcher
.search(&query, &TopDocs::with_limit(1))
.search(&query, &TopDocs::with_limit(1).order_by_score())
.expect("Search failed");
assert_eq!(top_docs.len(), 1, "Expected 1 search result");


@@ -48,7 +48,15 @@ impl Executor {
F: Sized + Sync + Fn(A) -> crate::Result<R>,
{
match self {
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
Executor::SingleThread => {
// Avoid `collect`, since it bloats the stack trace, which makes profiling
// harder.
let mut result = Vec::with_capacity(args.size_hint().0);
for arg in args {
result.push(f(arg)?);
}
Ok(result)
}
Executor::ThreadPool(pool) => {
let args: Vec<A> = args.collect();
let num_fruits = args.len();


@@ -1,7 +1,9 @@
use columnar::NumericalValue;
use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP};
use common::{replace_in_place, JsonPathWriter};
use rustc_hash::FxHashMap;
use crate::indexer::indexing_term::IndexingTerm;
use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::document::{ReferenceValue, ReferenceValueLeaf, Value};
use crate::schema::{Type, DATE_TIME_PRECISION_INDEXED};
@@ -76,7 +78,7 @@ fn index_json_object<'a, V: Value<'a>>(
doc: DocId,
json_visitor: V::ObjectIter,
text_analyzer: &mut TextAnalyzer,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
json_path_writer: &mut JsonPathWriter,
postings_writer: &mut dyn PostingsWriter,
ctx: &mut IndexingContext,
@@ -106,17 +108,17 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
doc: DocId,
json_value: V,
text_analyzer: &mut TextAnalyzer,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
json_path_writer: &mut JsonPathWriter,
postings_writer: &mut dyn PostingsWriter,
ctx: &mut IndexingContext,
positions_per_path: &mut IndexingPositionsPerPath,
) {
let set_path_id = |term_buffer: &mut Term, unordered_id: u32| {
let set_path_id = |term_buffer: &mut IndexingTerm, unordered_id: u32| {
term_buffer.truncate_value_bytes(0);
term_buffer.append_bytes(&unordered_id.to_be_bytes());
};
let set_type = |term_buffer: &mut Term, typ: Type| {
let set_type = |term_buffer: &mut IndexingTerm, typ: Type| {
term_buffer.append_bytes(&[typ.to_code()]);
};
@@ -152,7 +154,7 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
if let Ok(i64_val) = val.try_into() {
term_buffer.append_type_and_fast_value::<i64>(i64_val);
} else {
term_buffer.append_type_and_fast_value(val);
term_buffer.append_type_and_fast_value::<u64>(val);
}
postings_writer.subscribe(doc, 0u32, term_buffer, ctx);
}
@@ -166,12 +168,30 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
postings_writer.subscribe(doc, 0u32, term_buffer, ctx);
}
ReferenceValueLeaf::F64(val) => {
if !val.is_finite() {
return;
};
set_path_id(
term_buffer,
ctx.path_to_unordered_id
.get_or_allocate_unordered_id(json_path_writer.as_str()),
);
term_buffer.append_type_and_fast_value(val);
// Normalizing here is important.
// In the inverted index, we coerce all numerical values to their canonical
// representation.
//
// (We do the same thing on the query side)
match NumericalValue::F64(val).normalize() {
NumericalValue::I64(val_i64) => {
term_buffer.append_type_and_fast_value::<i64>(val_i64);
}
NumericalValue::U64(val_u64) => {
term_buffer.append_type_and_fast_value::<u64>(val_u64);
}
NumericalValue::F64(val_f64) => {
term_buffer.append_type_and_fast_value::<f64>(val_f64);
}
}
postings_writer.subscribe(doc, 0u32, term_buffer, ctx);
}
ReferenceValueLeaf::Bool(val) => {
@@ -241,8 +261,8 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
///
/// The term must be a JSON term that already contains the JSON path.
pub fn convert_to_fast_value_and_append_to_json_term(
mut term: Term,
phrase: &str,
term: &Term,
text: &str,
truncate_date_for_search: bool,
) -> Option<Term> {
assert_eq!(
@@ -254,31 +274,50 @@ pub fn convert_to_fast_value_and_append_to_json_term(
0,
"JSON value bytes should be empty"
);
if let Ok(dt) = OffsetDateTime::parse(phrase, &Rfc3339) {
let mut dt = DateTime::from_utc(dt.to_offset(UtcOffset::UTC));
if truncate_date_for_search {
dt = dt.truncate(DATE_TIME_PRECISION_INDEXED);
try_convert_to_datetime_and_append_to_json_term(term, text, truncate_date_for_search)
.or_else(|| try_convert_to_number_and_append_to_json_term(term, text))
.or_else(|| try_convert_to_bool_and_append_to_json_term_typed(term, text))
}
fn try_convert_to_datetime_and_append_to_json_term(
term: &Term,
text: &str,
truncate_date_for_search: bool,
) -> Option<Term> {
let dt = OffsetDateTime::parse(text, &Rfc3339).ok()?;
let mut dt = DateTime::from_utc(dt.to_offset(UtcOffset::UTC));
if truncate_date_for_search {
dt = dt.truncate(DATE_TIME_PRECISION_INDEXED);
}
let mut term_clone = term.clone();
term_clone.append_type_and_fast_value(dt);
Some(term_clone)
}
fn try_convert_to_number_and_append_to_json_term(term: &Term, text: &str) -> Option<Term> {
let numerical_value: NumericalValue = str::parse::<NumericalValue>(text).ok()?;
let mut term_clone = term.clone();
// `parse` actually returns normalized values already today, but let's not
// rely on that hidden contract.
match numerical_value.normalize() {
NumericalValue::I64(i64_value) => {
term_clone.append_type_and_fast_value::<i64>(i64_value);
}
NumericalValue::U64(u64_value) => {
term_clone.append_type_and_fast_value::<u64>(u64_value);
}
NumericalValue::F64(f64_value) => {
term_clone.append_type_and_fast_value::<f64>(f64_value);
}
term.append_type_and_fast_value(dt);
return Some(term);
}
if let Ok(i64_val) = str::parse::<i64>(phrase) {
term.append_type_and_fast_value(i64_val);
return Some(term);
}
if let Ok(u64_val) = str::parse::<u64>(phrase) {
term.append_type_and_fast_value(u64_val);
return Some(term);
}
if let Ok(f64_val) = str::parse::<f64>(phrase) {
term.append_type_and_fast_value(f64_val);
return Some(term);
}
if let Ok(bool_val) = str::parse::<bool>(phrase) {
term.append_type_and_fast_value(bool_val);
return Some(term);
}
None
Some(term_clone)
}
fn try_convert_to_bool_and_append_to_json_term_typed(term: &Term, text: &str) -> Option<Term> {
let val = str::parse::<bool>(text).ok()?;
let mut term_clone = term.clone();
term_clone.append_type_and_fast_value(val);
Some(term_clone)
}
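// Illustrative sketch (not part of the diff): the helpers above are tried in
// priority order, so conversion falls through from datetime to number to bool
// (return values paraphrased, exact byte layout elided):
//
//     convert_to_fast_value_and_append_to_json_term(&term, "2019-10-12T07:20:50.52Z", true)
//         // -> Some(term with a truncated DateTime fast value appended)
//     convert_to_fast_value_and_append_to_json_term(&term, "42", true)
//         // -> Some(term with the normalized numerical value appended)
//     convert_to_fast_value_and_append_to_json_term(&term, "true", true)
//         // -> Some(term with a bool fast value appended)
//     convert_to_fast_value_and_append_to_json_term(&term, "not a number", true)
//         // -> None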
/// Splits a json path supplied to the query parser in such a way that


@@ -225,6 +225,7 @@ impl Searcher {
enabled_scoring: EnableScoring,
) -> crate::Result<C::Fruit> {
let weight = query.weight(enabled_scoring)?;
collector.check_schema(self.schema())?;
let segment_readers = self.segment_readers();
let fruits = executor.map(
|(segment_ord, segment_reader)| {


@@ -108,7 +108,7 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// Opens a file and returns a boxed `FileHandle`.
///
/// Users of `Directory` should typically call `Directory::open_read(...)`,
/// while `Directory` implementor should implement `get_file_handle()`.
/// while `Directory` implementer should implement `get_file_handle()`.
fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError>;
/// Once a virtual file is open, its data may not

Some files were not shown because too many files have changed in this diff