Compare commits

..

29 Commits

Author SHA1 Message Date
trinity Pointard
32a8f8646f document 1 unused bit in skiplist 2025-12-19 10:34:20 +01:00
trinity Pointard
53c4b8346c add small doc on some queries using fast field when not indexed 2025-12-19 10:34:20 +01:00
Stu Hood
c0f21a45ae Use a strict comparison in TopNComputer (#2777)
* Remove `(Partial)Ord` from `ComparableDoc`, and unify comparison between `TopNComputer` and `Comparator`.

* Doc cleanups.

* Require Ord for `ComparableDoc`.

* Semantics are actually _ascending_ DocId order.

* Adjust docs again for ascending DocId order.

* minor change

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-18 12:13:23 +01:00
Moe
73657dff77 fix: fixed integer overflow in ExpUnrolledLinkedList for large datasets (#2735)
* Fixed the overflow issue.

* Fixed lint issues.

* Applied PR fixes.

* Fixed a lint issue.
2025-12-16 22:57:12 +01:00
Moe
e3c9be1f92 fix: boolean query incorrectly dropping documents when AllScorer is present (#2760)
* Fixed the range issue.

* Fixed the second all scorer issue

* Improved docs + tests

* Improved code.

* Fixed lint issues.

* Improved tests + logic based on PR comments.

* Fixed lint issues.

* Increase the document count.

* Improved the prop-tests

* Expand the index size, and remove unused parameter.

---------

Co-authored-by: Stu Hood <stuhood@gmail.com>
2025-12-16 22:52:02 +01:00
Ming
ba61ed6ef3 fix: vint buffer can overflow (#2778)
* fix vint overflow

* comment
2025-12-16 22:50:41 +01:00
trinity-1686a
d0e1600135 fix bug with minimum_should_match and AllScorer (#2774) 2025-12-14 10:10:45 +01:00
PSeitz-dd
e9020d17d4 fix coverage (#2769) 2025-12-11 11:35:58 +01:00
PSeitz-dd
5ba0031f7d move rand_distr to dev_dep (#2772) 2025-12-11 18:23:50 +08:00
Philippe Noël
22dde8f9ae chore: Make some delete-related functions public (#46) (#2766)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-11 01:22:15 +01:00
Philippe Noël
14cc24614e Make DeleteMeta pub (#2765)
Co-authored-by: Ming Ying <ming.ying.nyc@gmail.com>
2025-12-11 00:11:03 +01:00
Philippe Noël
8a1079b2dc expose AddOperation and with_max_doc (#7) (#2762)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-11 00:10:42 +01:00
Philippe Noël
794ff1ffc9 chore: Make Language hashable (#79) (#2763)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-10 15:38:43 +01:00
PSeitz-dd
c6912ce89a Handle JSON fields and columnar in space_usage (#2761)
return field names in space_usage instead of `Field`
more detailed info for columns
2025-12-10 20:33:33 +08:00
PSeitz
618e3bd11b Term and IndexingTerm cleanup (#2750)
* refactor term

* add deprecated functions

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-12-05 09:48:40 +08:00
PSeitz
b2f99c6217 add term->histogram benchmark (#2758)
* add term->histogram benchmark

* add more term aggs

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-12-04 02:29:37 +01:00
PSeitz
76de5bab6f fix unsafe warnings (#2757) 2025-12-03 20:15:21 +08:00
rustmailer
b7eb31162b docs: add usage example to README (#2743) 2025-12-02 21:56:57 +01:00
Paul Masurel
63c66005db Lazy scorers (#2726)
* Refactoring of the score tweaker into `SortKeyComputer`s to unlock two features.

- Allow lazy evaluation of score. As soon as we identified that a doc won't
reach the topK threshold, we can stop the evaluation.
- Allow for a different segment level score, segment level score and their conversion.

This PR breaks public API, but fixing code is straightforward.

* Bumping tantivy version

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-01 15:38:57 +01:00
Paul Masurel
7d513a44c5 Added some benchmark for top K by a fast field (#2754)
Also removed query parsing from the bench code.

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-01 14:58:29 +01:00
Stu Hood
ca87fcd454 Implement collect_block for Collectors which wrap other Collectors (#2727)
* Implement `collect_block` for tuple Collectors, and for MultiCollector.

* Two more.
2025-12-01 12:26:29 +01:00
Ang
08a92675dc Fix typos again (#2753)
Found via `codespell -S benches,stopwords.rs -L
womens,parth,abd,childs,ond,ser,ue,mot,hel,atleast,pris,claus,allo`
2025-12-01 12:15:41 +01:00
Raphaël Cohen
f7f4b354d6 fix: Handle phrase prefixed with star (#2751)
Signed-off-by: Darkheir <raphael.cohen@sekoia.io>
2025-12-01 11:43:25 +01:00
Paul Masurel
25d44fcec8 Revert "remove unused columnar api (#2742)" (#2748)
* Revert "remove unused columnar api (#2742)"

This reverts commit 8725594d47.

* Clippy comment + removing fill_vals

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-11-26 17:44:02 +01:00
PSeitz-dd
842fe9295f split Term in Term and IndexingTerm (#2744)
* split Term in Term and IndexingTerm

* add append_json_path to JsonTermSerializer
2025-11-26 16:48:59 +01:00
Paul Masurel
f88b7200b2 Optimization when posting list are saturated. (#2745)
* Optimization when posting list are saturated.

If a posting list doc freq is the segment reader's
max_doc, and if scoring does not matter, we can replace it
by a AllScorer.

In turn, in a boolean query, we can dismiss  all scorers and
empty scorers, to accelerate the request.

* Added range query optimization

* CR comment

* CR comments

* CR comment

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-11-26 15:50:57 +01:00
PSeitz-dd
8725594d47 remove unused columnar api (#2742) 2025-11-21 18:07:25 +01:00
PSeitz
43a784671a clippy (#2741)
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-11-21 18:07:03 +01:00
Paul Masurel
c363bbd23d Optimize term aggregation with low cardinality + some refactoring (#2740)
This introduce an optimization of top level term aggregation on field with a low cardinality.

We then use a Vec as the underlying map.
In addition, we buffer subaggregations.

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
Co-authored-by: Paul Masurel <paul@quickwit.io>
2025-11-21 14:46:29 +01:00
113 changed files with 5678 additions and 2776 deletions

View File

@@ -15,11 +15,11 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
continue-on-error: true

View File

@@ -78,7 +78,7 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
- **Store DateTime as nanoseconds in doc store** DateTime in the doc store was truncated to microseconds previously. This removes this truncation, while still keeping backwards compatibility. [#2486](https://github.com/quickwit-oss/tantivy/pull/2486)(@PSeitz)
- **Performace/Memory**
- **Performance/Memory**
- lift clauses in LogicalAst for optimized ast during execution [#2449](https://github.com/quickwit-oss/tantivy/pull/2449)(@PSeitz)
- Use Vec instead of BTreeMap to back OwnedValue object [#2364](https://github.com/quickwit-oss/tantivy/pull/2364)(@fulmicoton)
- Replace TantivyDocument with CompactDoc. CompactDoc is much smaller and provides similar performance. [#2402](https://github.com/quickwit-oss/tantivy/pull/2402)(@PSeitz)

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy"
version = "0.25.0"
version = "0.26.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]

View File

@@ -123,6 +123,7 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
- [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
- [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
- [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
- and [more](https://github.com/search?q=tantivy)!
### On average, how much faster is Tantivy compared to Lucene?

View File

@@ -10,7 +10,7 @@ rename FastFieldReaders::open to load
remove fast field reader
find a way to unify the two DateTime.
readd type check in the filter wrapper
re-add type check in the filter wrapper
add unit test on columnar list columns.

View File

@@ -1,5 +1,6 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use rand::distributions::WeightedIndex;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
@@ -54,11 +55,19 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_few);
register!(group, terms_all_unique);
register!(group, terms_many);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_few_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status);
register!(group, terms_few_with_histogram);
register!(group, terms_status_with_histogram);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
@@ -130,12 +139,12 @@ fn extendedstats_f64(index: &Index) {
}
fn percentiles_f64(index: &Index) {
let agg_req = json!({
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
}
}
});
execute_agg(index, agg_req);
}
@@ -172,6 +181,19 @@ fn terms_few(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
});
execute_agg(index, agg_req);
}
fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_many(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
@@ -220,6 +242,63 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_all_unique_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_few_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_few_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -404,14 +483,21 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
// Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare.
let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap();
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
@@ -427,15 +513,21 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)];
let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -460,8 +552,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)],
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,

View File

@@ -16,14 +16,15 @@
// - This bench isolates boolean iteration speed and intersection/union cost.
// - Use `cargo bench --bench boolean_conjunction` to run.
use binggan::{black_box, BenchRunner};
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, TopDocs};
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy, Searcher};
use tantivy::collector::sort_key::SortByStaticFastValue;
use tantivy::collector::{Collector, Count, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
#[derive(Clone)]
struct BenchIndex {
@@ -33,23 +34,6 @@ struct BenchIndex {
query_parser: QueryParser,
}
impl BenchIndex {
#[inline(always)]
fn count_query(&self, query_str: &str) -> usize {
let query = self.query_parser.parse_query(query_str).unwrap();
self.searcher.search(&query, &Count).unwrap()
}
#[inline(always)]
fn topk_len(&self, query_str: &str, k: usize) -> usize {
let query = self.query_parser.parse_query(query_str).unwrap();
self.searcher
.search(&query, &TopDocs::with_limit(k))
.unwrap()
.len()
}
}
/// Build a single index containing both fields (title, body) and
/// return two BenchIndex views:
/// - single_field: QueryParser defaults to only "body"
@@ -59,6 +43,8 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
let f_body = schema_builder.add_text_field("body", TEXT);
let f_score = schema_builder.add_u64_field("score", FAST);
let f_score2 = schema_builder.add_u64_field("score2", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
@@ -67,11 +53,13 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
// Populate: spread each present token 90/10 to body/title
{
let mut writer = index.writer(500_000_000).unwrap();
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..num_docs {
let has_a = rng.gen_bool(p_a as f64);
let has_b = rng.gen_bool(p_b as f64);
let has_c = rng.gen_bool(p_c as f64);
let score = rng.gen_range(0u64..100u64);
let score2 = rng.gen_range(0u64..100_000u64);
let mut title_tokens: Vec<&str> = Vec::new();
let mut body_tokens: Vec<&str> = Vec::new();
if has_a {
@@ -101,7 +89,9 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
writer
.add_document(doc!(
f_title=>title_tokens.join(" "),
f_body=>body_tokens.join(" ")
f_body=>body_tokens.join(" "),
f_score=>score,
f_score2=>score2,
))
.unwrap();
}
@@ -153,72 +143,76 @@ fn main() {
),
];
let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
let mut runner = BenchRunner::new();
for (label, n, pa, pb, pc) in scenarios {
let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);
// Single-field group: default field is body only
for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
{
// Single-field group: default field is body only
let mut group = runner.new_group();
group.set_name(format!("single_field — {}", label));
group.register_with_input("+a_+b_count", &single_view, |benv: &BenchIndex| {
black_box(benv.count_query("+a +b"))
});
group.register_with_input("+a_+b_+c_count", &single_view, |benv: &BenchIndex| {
black_box(benv.count_query("+a +b +c"))
});
group.register_with_input("+a_+b_top10", &single_view, |benv: &BenchIndex| {
black_box(benv.topk_len("+a +b", 10))
});
group.register_with_input("+a_+b_+c_top10", &single_view, |benv: &BenchIndex| {
black_box(benv.topk_len("+a +b +c", 10))
});
// OR queries
group.register_with_input("a_OR_b_count", &single_view, |benv: &BenchIndex| {
black_box(benv.count_query("a OR b"))
});
group.register_with_input("a_OR_b_OR_c_count", &single_view, |benv: &BenchIndex| {
black_box(benv.count_query("a OR b OR c"))
});
group.register_with_input("a_OR_b_top10", &single_view, |benv: &BenchIndex| {
black_box(benv.topk_len("a OR b", 10))
});
group.register_with_input("a_OR_b_OR_c_top10", &single_view, |benv: &BenchIndex| {
black_box(benv.topk_len("a OR b OR c", 10))
});
group.run();
}
// Multi-field group: default fields are [title, body]
{
let mut group = runner.new_group();
group.set_name(format!("multi_field — {}", label));
group.register_with_input("+a_+b_count", &multi_view, |benv: &BenchIndex| {
black_box(benv.count_query("+a +b"))
});
group.register_with_input("+a_+b_+c_count", &multi_view, |benv: &BenchIndex| {
black_box(benv.count_query("+a +b +c"))
});
group.register_with_input("+a_+b_top10", &multi_view, |benv: &BenchIndex| {
black_box(benv.topk_len("+a +b", 10))
});
group.register_with_input("+a_+b_+c_top10", &multi_view, |benv: &BenchIndex| {
black_box(benv.topk_len("+a +b +c", 10))
});
// OR queries
group.register_with_input("a_OR_b_count", &multi_view, |benv: &BenchIndex| {
black_box(benv.count_query("a OR b"))
});
group.register_with_input("a_OR_b_OR_c_count", &multi_view, |benv: &BenchIndex| {
black_box(benv.count_query("a OR b OR c"))
});
group.register_with_input("a_OR_b_top10", &multi_view, |benv: &BenchIndex| {
black_box(benv.topk_len("a OR b", 10))
});
group.register_with_input("a_OR_b_OR_c_top10", &multi_view, |benv: &BenchIndex| {
black_box(benv.topk_len("a OR b OR c", 10))
});
group.set_name(format!("{}{}", view_name, label));
for query_str in queries {
add_bench_task(&mut group, &bench_index, query_str, Count, "count");
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_score(),
"top10",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_fast_field::<u64>("score", Order::Asc),
"top10_by_ff",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by((
SortByStaticFastValue::<u64>::for_field("score"),
SortByStaticFastValue::<u64>::for_field("score2"),
)),
"top10_by_2ff",
);
}
group.run();
}
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &self.collector).unwrap();
1
}
}

View File

@@ -258,7 +258,7 @@ mod test {
bitpacker.write(val, num_bits, &mut data).unwrap();
}
bitpacker.close(&mut data).unwrap();
assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8);
assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
let bitunpacker = BitUnpacker::new(num_bits);
(bitunpacker, vals, data)
}
@@ -304,7 +304,7 @@ mod test {
bitpacker.write(val, num_bits, &mut buffer).unwrap();
}
bitpacker.flush(&mut buffer).unwrap();
assert_eq!(buffer.len(), (vals.len() * num_bits as usize + 7) / 8);
assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
let bitunpacker = BitUnpacker::new(num_bits);
let max_val = if num_bits == 64 {
u64::MAX

View File

@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
@@ -66,17 +66,19 @@ unsafe fn filter_vec_avx2_aux(
]);
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
for _ in 0..num_words {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
unsafe {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
}
output_tail.offset_from(output) as usize
unsafe { output_tail.offset_from(output) as usize }
}
#[inline]
@@ -92,8 +94,7 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
let too_low = op_greater(*range.start(), val);
let too_high = op_greater(val, *range.end());
let inside = op_or(too_low, too_high);
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
as u8
255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
}
union U8x32 {

View File

@@ -73,7 +73,7 @@ The crate introduces the following concepts.
`Columnar` is an equivalent of a dataframe.
It maps `column_key` to `Column`.
A `Column<T>` asssociates a `RowId` (u32) to any
A `Column<T>` associates a `RowId` (u32) to any
number of values.
This is made possible by wrapping a `ColumnIndex` and a `ColumnValue` object.

View File

@@ -89,13 +89,6 @@ fn main() {
black_box(sum);
});
group.register("first_block_fetch", |column| {
let mut block: Vec<Option<u64>> = vec![None; 64];
let fetch_docids = (0..64).collect::<Vec<_>>();
column.first_vals(&fetch_docids, &mut block);
black_box(block[0]);
});
group.register("first_block_single_calls", |column| {
let mut block: Vec<Option<u64>> = vec![None; 64];
let fetch_docids = (0..64).collect::<Vec<_>>();

View File

@@ -131,6 +131,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
self.index.docids_to_rowids(doc_ids, doc_ids_out, row_ids)
}
/// Get an iterator over the values for the provided docid.
#[inline]
pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator<Item = T> + '_ {
self.index
.value_row_ids(doc_id)
@@ -158,15 +160,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
.select_batch_in_place(selected_docid_range.start, doc_ids);
}
/// Fills the output vector with the (possibly multiple values that are associated_with
/// `row_id`.
///
/// This method clears the `output` vector.
pub fn fill_vals(&self, row_id: RowId, output: &mut Vec<T>) {
output.clear();
output.extend(self.values_for_doc(row_id));
}
pub fn first_or_default_col(self, default_value: T) -> Arc<dyn ColumnValues<T>> {
Arc::new(FirstValueWithDefault {
column: self,

View File

@@ -1,7 +1,7 @@
use std::fmt::Debug;
use std::net::Ipv6Addr;
/// Montonic maps a value to u128 value space
/// Monotonic maps a value to u128 value space
/// Monotonic mapping enables `PartialOrd` on u128 space without conversion to original space.
pub trait MonotonicallyMappableToU128: 'static + PartialOrd + Copy + Debug + Send + Sync {
/// Converts a value to u128.

View File

@@ -8,7 +8,7 @@ use crate::column_values::ColumnValues;
const MID_POINT: u64 = (1u64 << 32) - 1u64;
/// `Line` describes a line function `y: ax + b` using integer
/// arithmetics.
/// arithmetic.
///
/// The slope is in fact a decimal split into a 32 bit integer value,
/// and a 32-bit decimal value.
@@ -94,7 +94,7 @@ impl Line {
// `(i, ys[])`.
//
// The best intercept therefore has the form
// `y[i] - line.eval(i)` (using wrapping arithmetics).
// `y[i] - line.eval(i)` (using wrapping arithmetic).
// In other words, the best intercept is one of the `y - Line::eval(ys[i])`
// and our task is just to pick the one that minimizes our error.
//

View File

@@ -52,7 +52,7 @@ pub trait ColumnCodecEstimator<T = u64>: 'static {
) -> io::Result<()>;
}
/// A column codec describes a colunm serialization format.
/// A column codec describes a column serialization format.
pub trait ColumnCodec<T: PartialOrd = u64> {
/// Specialized `ColumnValues` type.
type ColumnValues: ColumnValues<T> + 'static;

View File

@@ -3,7 +3,8 @@ use std::sync::Arc;
use std::{fmt, io};
use common::file_slice::FileSlice;
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use common::{ByteCount, DateTime, OwnedBytes};
use serde::{Deserialize, Serialize};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
@@ -317,10 +318,89 @@ impl DynamicColumnHandle {
}
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.len().into()
self.file_slice.num_bytes()
}
/// Legacy helper returning the column space usage.
pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
self.space_usage()
}
/// Return the space usage of the column, optionally broken down by dictionary and column
/// values.
///
/// For dictionary encoded columns (strings and bytes), this splits the total footprint into
/// the dictionary and the remaining column data (including index and values).
/// For all other column types, the dictionary size is `None` and the column size
/// equals the total bytes.
pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
let total_num_bytes = self.num_bytes();
let dynamic_column = self.open()?;
let dictionary_num_bytes = match &dynamic_column {
DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
_ => {
return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
}
};
assert!(dictionary_num_bytes <= total_num_bytes);
let column_num_bytes =
ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
Ok(ColumnSpaceUsage::new(
column_num_bytes,
Some(dictionary_num_bytes),
))
}
pub fn column_type(&self) -> ColumnType {
self.column_type
}
}
/// Represents space usage of a column.
///
/// `column_num_bytes` tracks the column payload (index, values and footer).
/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ColumnSpaceUsage {
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
}
impl ColumnSpaceUsage {
pub(crate) fn new(
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
) -> Self {
ColumnSpaceUsage {
column_num_bytes,
dictionary_num_bytes,
}
}
pub fn column_num_bytes(&self) -> ByteCount {
self.column_num_bytes
}
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
self.dictionary_num_bytes
}
pub fn total_num_bytes(&self) -> ByteCount {
self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
}
/// Merge two space usage values by summing their components.
pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
(Some(val), None) | (None, Some(val)) => Some(val),
(None, None) => None,
};
ColumnSpaceUsage {
column_num_bytes: self.column_num_bytes + other.column_num_bytes,
dictionary_num_bytes,
}
}
}

View File

@@ -48,7 +48,7 @@ pub use columnar::{
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};
pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};
pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
pub type RowId = u32;
pub type DocId = u32;

View File

@@ -28,7 +28,9 @@ impl BinarySerializable for VIntU128 {
writer.write_all(&buffer)
}
#[allow(clippy::unbuffered_bytes)]
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u128;
let mut shift = 0u64;
@@ -195,7 +197,9 @@ impl BinarySerializable for VInt {
writer.write_all(&buffer[0..num_bytes])
}
#[allow(clippy::unbuffered_bytes)]
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u64;
let mut shift = 0u64;

View File

@@ -208,7 +208,7 @@ fn main() -> tantivy::Result<()> {
// is the role of the `TopDocs` collector.
// We can now perform our query.
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
// The actual documents still need to be
// retrieved from Tantivy's store.
@@ -226,7 +226,7 @@ fn main() -> tantivy::Result<()> {
let query = query_parser.parse_query("title:sea^20 body:whale^70")?;
let (_score, doc_address) = searcher
.search(&query, &TopDocs::with_limit(1))?
.search(&query, &TopDocs::with_limit(1).order_by_score())?
.into_iter()
.next()
.unwrap();

View File

@@ -100,7 +100,7 @@ fn main() -> tantivy::Result<()> {
// here we want to get a hit on the 'ken' in Frankenstein
let query = query_parser.parse_query("ken")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (_, doc_address) in top_docs {
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;

View File

@@ -50,14 +50,14 @@ fn main() -> tantivy::Result<()> {
{
// Simple exact search on the date
let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
assert_eq!(count_docs.len(), 1);
}
{
// Range query on the date field
let query = query_parser
.parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
assert_eq!(count_docs.len(), 1);
for (_score, doc_address) in count_docs {
let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;

View File

@@ -28,7 +28,7 @@ fn extract_doc_given_isbn(
// The second argument is here to tell we don't care about decoding positions,
// or term frequencies.
let term_query = TermQuery::new(isbn_term.clone(), IndexRecordOption::Basic);
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1))?;
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1).order_by_score())?;
if let Some((_score, doc_address)) = top_docs.first() {
let doc = searcher.doc(*doc_address)?;

View File

@@ -145,7 +145,7 @@ fn main() -> tantivy::Result<()> {
let query = FuzzyTermQuery::new(term, 2, true);
let (top_docs, count) = searcher
.search(&query, &(TopDocs::with_limit(5), Count))
.search(&query, &(TopDocs::with_limit(5).order_by_score(), Count))
.unwrap();
assert_eq!(count, 3);
assert_eq!(top_docs.len(), 3);

View File

@@ -69,25 +69,25 @@ fn main() -> tantivy::Result<()> {
{
// Inclusive range queries
let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
assert_eq!(count_docs.len(), 1);
}
{
// Exclusive range queries
let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 0);
}
{
// Find docs with IP addresses smaller equal 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
// Find docs with IP addresses smaller than 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}

View File

@@ -59,12 +59,12 @@ fn main() -> tantivy::Result<()> {
let query_parser = QueryParser::for_index(&index, vec![event_type, attributes]);
{
let query = query_parser.parse_query("target:submit-button")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
let query = query_parser.parse_query("target:submit")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
@@ -74,33 +74,33 @@ fn main() -> tantivy::Result<()> {
}
{
let query = query_parser.parse_query("click AND cart.product_id:133")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 1);
}
{
// The sub-fields in the json field marked as default field still need to be explicitly
// addressed
let query = query_parser.parse_query("click AND 133")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 0);
}
{
// Default json fields are ignored if they collide with the schema
let query = query_parser.parse_query("event_type:holiday-sale")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 0);
}
// # Query via full attribute path
{
// This only searches in our schema's `event_type` field
let query = query_parser.parse_query("event_type:click")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 2);
}
{
// Default json fields can still be accessed by full path
let query = query_parser.parse_query("attributes.event_type:holiday-sale")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 1);
}
Ok(())

View File

@@ -63,7 +63,7 @@ fn main() -> Result<()> {
// but not "in the Gulf Stream".
let query = query_parser.parse_query("\"in the su\"*")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let mut titles = top_docs
.into_iter()
.map(|(_score, doc_address)| {

View File

@@ -107,7 +107,8 @@ fn main() -> tantivy::Result<()> {
IndexRecordOption::Basic,
);
let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
let (top_docs, count) =
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
assert_eq!(count, 2);
@@ -128,7 +129,8 @@ fn main() -> tantivy::Result<()> {
IndexRecordOption::Basic,
);
let (_top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
let (_top_docs, count) =
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
assert_eq!(count, 0);

View File

@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
let query_parser = QueryParser::for_index(&index, vec![title, body]);
let query = query_parser.parse_query("sycamore spring")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;

View File

@@ -102,7 +102,7 @@ fn main() -> tantivy::Result<()> {
// stop words are applied on the query as well.
// The following will be equivalent to `title:frankenstein`
let query = query_parser.parse_query("title:\"the Frankenstein\"")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (score, doc_address) in top_docs {
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;

View File

@@ -164,7 +164,7 @@ fn main() -> tantivy::Result<()> {
move |doc_id: DocId| Reverse(price[doc_id as usize])
};
let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price);
let most_expensive_first = TopDocs::with_limit(10).order_by(score_by_price);
let hits = searcher.search(&query, &most_expensive_first)?;
assert_eq!(

View File

@@ -758,7 +758,17 @@ fn negate(expr: UserInputAst) -> UserInputAst {
fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
alt((
delimited(char('('), ast, char(')')),
map(char('*'), |_| UserInputAst::from(UserInputLeaf::All)),
map(
terminated(
char('*'),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
|_| UserInputAst::from(UserInputLeaf::All),
),
map(preceded(tuple((tag("NOT"), multispace1)), leaf), negate),
literal,
))(inp)
@@ -779,7 +789,17 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
),
),
(
value((), char('*')),
value(
(),
terminated(
char('*'),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
),
map(nothing, |_| {
(Some(UserInputAst::from(UserInputLeaf::All)), Vec::new())
}),
@@ -1671,6 +1691,21 @@ mod test {
test_parse_query_to_ast_helper("abc:a b", "(*\"abc\":a *b)");
test_parse_query_to_ast_helper("abc:\"a b\"", "\"abc\":\"a b\"");
test_parse_query_to_ast_helper("foo:[1 TO 5]", "\"foo\":[\"1\" TO \"5\"]");
// Phrase prefixed with *
test_parse_query_to_ast_helper("foo:(*A)", "\"foo\":*A");
test_parse_query_to_ast_helper("*A", "*A");
test_parse_query_to_ast_helper("(*A)", "*A");
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
}
#[test]
fn test_parse_query_all() {
test_parse_query_to_ast_helper("*", "*");
test_parse_query_to_ast_helper("(*)", "*");
test_parse_query_to_ast_helper("(* )", "*");
}
#[test]

View File

@@ -16,15 +16,16 @@ use crate::index::SegmentReader;
/// That way we can use it the same way as if it would come from the fastfield.
pub(crate) fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
column_max_value: u64,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}

View File

@@ -10,10 +10,10 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
build_segment_aggregation_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds,
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector,
SegmentHistogramCollector, SegmentRangeCollector, TermMissingAgg, TermsAggReqData,
TermsAggregation, TermsAggregationInternal,
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
@@ -373,7 +373,7 @@ pub(crate) fn build_segment_agg_collector(
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match node.kind {
AggKind::Terms => build_segment_aggregation_collector(req, node),
AggKind::Terms => crate::aggregation::bucket::build_segment_term_collector(req, node),
AggKind::MissingTerm => {
let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data];
if req_data.accessors.is_empty() {
@@ -496,7 +496,7 @@ pub(crate) fn build_aggregations_data_from_req(
};
for (name, agg) in aggs.iter() {
let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data)?;
let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data, true)?;
data.per_request.agg_tree.extend(nodes);
}
Ok(data)
@@ -508,6 +508,7 @@ fn build_nodes(
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
is_top_level: bool,
) -> crate::Result<Vec<AggRefNode>> {
use AggregationVariants::*;
match &req.agg {
@@ -594,6 +595,7 @@ fn build_nodes(
data,
&req.sub_aggregation,
TermsOrCardinalityRequest::Terms(terms_req.clone()),
is_top_level,
),
Cardinality(card_req) => build_terms_or_cardinality_nodes(
agg_name,
@@ -604,6 +606,7 @@ fn build_nodes(
data,
&req.sub_aggregation,
TermsOrCardinalityRequest::Cardinality(card_req.clone()),
is_top_level,
),
Average(AverageAggregation { field, missing, .. })
| Max(MaxAggregation { field, missing, .. })
@@ -732,7 +735,7 @@ fn build_nodes(
// Build the query and evaluator upfront
let schema = reader.schema();
let tokenizers = &data.context.tokenizers;
let query = filter_req.parse_query(&schema, tokenizers)?;
let query = filter_req.parse_query(schema, tokenizers)?;
let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
query,
schema.clone(),
@@ -769,7 +772,14 @@ fn build_children(
) -> crate::Result<Vec<AggRefNode>> {
let mut children = Vec::new();
for (name, agg) in aggs.iter() {
children.extend(build_nodes(name, agg, reader, segment_ordinal, data)?);
children.extend(build_nodes(
name,
agg,
reader,
segment_ordinal,
data,
false,
)?);
}
Ok(children)
}
@@ -833,6 +843,7 @@ fn build_terms_or_cardinality_nodes(
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,
req: TermsOrCardinalityRequest,
is_top_level: bool,
) -> crate::Result<Vec<AggRefNode>> {
let mut nodes = Vec::new();
@@ -889,7 +900,7 @@ fn build_terms_or_cardinality_nodes(
let missing_value_for_accessor = if use_special_missing_agg {
None
} else if let Some(m) = missing.as_ref() {
get_missing_val_as_u64_lenient(column_type, m, field_name)?
get_missing_val_as_u64_lenient(column_type, accessor.max_value(), m, field_name)?
} else {
None
};
@@ -922,6 +933,7 @@ fn build_terms_or_cardinality_nodes(
sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(),
allowed_term_ids,
is_top_level,
});
(idx_in_req_data, AggKind::Terms)
}

View File

@@ -35,6 +35,7 @@ pub struct AggregationLimitsGuard {
/// Allocated memory with this guard.
allocated_with_the_guard: u64,
}
impl Clone for AggregationLimitsGuard {
fn clone(&self) -> Self {
Self {

View File

@@ -16,7 +16,7 @@ use super::{AggregationError, Key};
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
/// The final aggegation result.
/// The final aggregation result.
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
impl AggregationResults {

View File

@@ -32,7 +32,7 @@ use crate::{DocId, SegmentReader, TantivyError};
///
/// # Implementation Requirements
///
/// Implementors must:
/// Implementers must:
/// 1. Derive `Debug`, `Clone`, `Serialize`, and `Deserialize`
/// 2. Use `#[typetag::serde]` attribute on the impl block
/// 3. Implement `build_query()` to construct the query from schema/tokenizers
@@ -639,16 +639,14 @@ pub struct IntermediateFilterBucketResult {
#[cfg(test)]
mod tests {
use std::time::Instant;
use serde_json::{json, Value};
use super::*;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::{AggContextParams, AggregationCollector};
use crate::query::{AllQuery, QueryParser, TermQuery};
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, STORED, TEXT};
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, TEXT};
use crate::{doc, Index, IndexWriter};
// Test helper functions
@@ -729,12 +727,13 @@ mod tests {
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer(50_000_000)?;
let mut writer: IndexWriter = index.writer_for_tests()?;
writer.add_document(doc!(
category => "electronics", brand => "apple",
price => 999u64, rating => 4.5f64, in_stock => true
))?;
writer.commit()?;
writer.add_document(doc!(
category => "electronics", brand => "samsung",
price => 799u64, rating => 4.2f64, in_stock => true
@@ -938,7 +937,7 @@ mod tests {
let index = create_standard_test_index()?;
let reader = index.reader()?;
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 2);
let agg = json!({
"premium_electronics": {
"filter": "category:electronics AND price:[800 TO *]",

View File

@@ -1,196 +0,0 @@
use std::fmt::Debug;
use columnar::ColumnType;
use rustc_hash::FxHashMap;
use super::OrderTarget;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::bucket::get_agg_name_and_property;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::TantivyError;
#[derive(Clone, Debug, Default)]
/// Container to store term_ids/or u64 values and their buckets.
struct TermBuckets {
pub(crate) entries: FxHashMap<u64, u32>,
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
}
impl TermBuckets {
fn get_memory_consumption(&self) -> usize {
let sub_aggs_mem = self.sub_aggs.memory_consumption();
let buckets_mem = self.entries.memory_consumption();
sub_aggs_mem + buckets_mem
}
fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregations in &mut self.sub_aggs.values_mut() {
sub_aggregations.as_mut().flush(agg_data)?;
}
Ok(())
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentTermCollector {
/// The buckets containing the aggregation data.
term_buckets: TermBuckets,
accessor_idx: usize,
}
impl SegmentAggregationCollector for SegmentTermCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
let entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect();
let bucket = super::into_intermediate_bucket_result(
self.accessor_idx,
entries,
self.term_buckets.sub_aggs,
agg_data,
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
if let Some(missing) = req_data.missing_value_for_accessor {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
}
// has subagg
if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() {
for (doc, term_id) in req_data
.column_block_accessor
.iter_docid_vals(docs, &req_data.accessor)
{
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let sub_aggregations = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(doc, agg_data)?;
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.term_buckets.force_flush(agg_data)?;
Ok(())
}
}
impl SegmentTermCollector {
pub fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
let column_type = terms_req_data.column_type;
let accessor_idx = node.idx_in_req_data;
if column_type == ColumnType::Bytes {
return Err(TantivyError::InvalidArgument(format!(
"terms aggregation is not supported for column type {column_type:?}"
)));
}
let term_buckets = TermBuckets::default();
// Validate sub aggregation exists
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
terms_req_data.sub_aggregation_blueprint = blueprint;
Ok(SegmentTermCollector {
term_buckets,
accessor_idx,
})
}
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let term_buckets_mem = self.term_buckets.get_memory_consumption();
self_mem + term_buckets_mem
}
}

View File

@@ -1,228 +0,0 @@
use std::vec;
use rustc_hash::FxHashMap;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::{get_agg_name_and_property, OrderTarget};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::{DocId, TantivyError};
const MAX_BATCH_SIZE: usize = 1_024;
#[derive(Debug, Clone)]
struct LowCardTermBuckets {
entries: Box<[u32]>,
sub_aggs: Vec<Box<dyn SegmentAggregationCollector>>,
doc_buffers: Box<[Vec<DocId>]>,
}
impl LowCardTermBuckets {
pub fn with_num_buckets(
num_buckets: usize,
sub_aggs_blueprint_opt: Option<&Box<dyn SegmentAggregationCollector>>,
) -> Self {
let sub_aggs = sub_aggs_blueprint_opt
.as_ref()
.map(|blueprint| {
std::iter::repeat_with(|| blueprint.clone_box())
.take(num_buckets)
.collect::<Vec<_>>()
})
.unwrap_or_default();
Self {
entries: vec![0; num_buckets].into_boxed_slice(),
sub_aggs,
doc_buffers: std::iter::repeat_with(|| Vec::with_capacity(MAX_BATCH_SIZE))
.take(num_buckets)
.collect::<Vec<_>>()
.into_boxed_slice(),
}
}
fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
+ self.entries.len() * std::mem::size_of::<u32>()
+ self.doc_buffers.len()
* (std::mem::size_of::<Vec<DocId>>()
+ std::mem::size_of::<DocId>() * MAX_BATCH_SIZE)
}
}
#[derive(Debug, Clone)]
pub struct LowCardSegmentTermCollector {
term_buckets: LowCardTermBuckets,
accessor_idx: usize,
}
impl LowCardSegmentTermCollector {
pub fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
let accessor_idx = node.idx_in_req_data;
let cardinality = terms_req_data
.accessor
.max_value()
.max(terms_req_data.missing_value_for_accessor.unwrap_or(0))
+ 1;
assert!(cardinality <= super::LOW_CARDINALITY_THRESHOLD);
// Validate sub aggregation exists
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
let term_buckets =
LowCardTermBuckets::with_num_buckets(cardinality as usize, blueprint.as_ref());
terms_req_data.sub_aggregation_blueprint = blueprint;
Ok(LowCardSegmentTermCollector {
term_buckets,
accessor_idx,
})
}
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let term_buckets_mem = self.term_buckets.get_memory_consumption();
self_mem + term_buckets_mem
}
}
impl SegmentAggregationCollector for LowCardSegmentTermCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
let sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>> = self
.term_buckets
.sub_aggs
.into_iter()
.enumerate()
.filter(|(bucket_id, _sub_agg)| self.term_buckets.entries[*bucket_id] > 0)
.map(|(bucket_id, sub_agg)| (bucket_id as u64, sub_agg))
.collect();
let entries: Vec<(u64, u32)> = self
.term_buckets
.entries
.iter()
.enumerate()
.filter(|(_, count)| **count > 0)
.map(|(bucket_id, count)| (bucket_id as u64, *count))
.collect();
let bucket =
super::into_intermediate_bucket_result(self.accessor_idx, entries, sub_aggs, agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.len() > MAX_BATCH_SIZE {
for batch in docs.chunks(MAX_BATCH_SIZE) {
self.collect_block(batch, agg_data)?;
}
}
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
if let Some(missing) = req_data.missing_value_for_accessor {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
// has subagg
if req_data.sub_aggregation_blueprint.is_some() {
for (doc, term_id) in req_data
.column_block_accessor
.iter_docid_vals(docs, &req_data.accessor)
{
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
self.term_buckets.doc_buffers[term_id as usize].push(doc);
}
for (bucket_id, docs) in self.term_buckets.doc_buffers.iter_mut().enumerate() {
self.term_buckets.entries[bucket_id] += docs.len() as u32;
self.term_buckets.sub_aggs[bucket_id].collect_block(&docs[..], agg_data)?;
docs.clear();
}
} else {
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
self.term_buckets.entries[term_id as usize] += 1;
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
Ok(())
}
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregations in &mut self.term_buckets.sub_aggs.iter_mut() {
sub_aggregations.as_mut().flush(agg_data)?;
}
Ok(())
}
}

View File

@@ -3,7 +3,12 @@ use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
#[cfg(test)]
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
#[cfg(not(test))]
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
@@ -15,7 +20,7 @@ pub(crate) struct BufAggregationCollector {
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
@@ -66,7 +71,6 @@ impl SegmentAggregationCollector for BufAggregationCollector {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_data)?;
Ok(())
}

View File

@@ -62,7 +62,7 @@ impl ExtendedStatsAggregation {
/// Extended stats contains a collection of statistics
/// they extends stats adding variance, standard deviation
/// and bound informations
/// and bound information
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExtendedStats {
/// The number of documents.

View File

@@ -16,6 +16,7 @@ use crate::aggregation::intermediate_agg_result::{
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::AggregationError;
use crate::collector::sort_key::ReverseComparator;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
@@ -458,7 +459,7 @@ impl Eq for DocSortValuesAndFields {}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
req: TopHitsAggregationReq,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, ReverseComparator>,
}
impl std::cmp::PartialEq for TopHitsTopNComputer {
@@ -482,7 +483,7 @@ impl TopHitsTopNComputer {
pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> {
for doc in other_fruit.top_n.into_vec() {
self.collect(doc.feature, doc.doc);
self.collect(doc.sort_key, doc.doc);
}
Ok(())
}
@@ -494,9 +495,9 @@ impl TopHitsTopNComputer {
.into_sorted_vec()
.into_iter()
.map(|doc| TopHitsVecEntry {
sort: doc.feature.sorts.iter().map(|f| f.value).collect(),
sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(),
doc_value_fields: doc
.feature
.sort_key
.doc_value_fields
.into_iter()
.map(|(k, v)| (k, v.into()))
@@ -517,7 +518,7 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>,
}
impl TopHitsSegmentCollector {
@@ -544,7 +545,7 @@ impl TopHitsSegmentCollector {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
top_hits_computer.collect(
DocSortValuesAndFields {
sorts: res.feature,
sorts: res.sort_key,
doc_value_fields,
},
res.doc,
@@ -645,6 +646,7 @@ mod tests {
use crate::aggregation::bucket::tests::get_test_index_from_docs;
use crate::aggregation::tests::get_test_index_from_values;
use crate::aggregation::AggregationCollector;
use crate::collector::sort_key::ReverseComparator;
use crate::collector::ComparableDoc;
use crate::query::AllQuery;
use crate::schema::OwnedValue;
@@ -660,7 +662,7 @@ mod tests {
fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer {
super::TopHitsTopNComputer {
top_n: super::TopNComputer::new(capacity),
top_n: super::TopNComputer::new_with_comparator(capacity, ReverseComparator),
req: Default::default(),
}
}
@@ -774,12 +776,12 @@ mod tests {
#[test]
fn test_top_hits_collector_single_feature() -> crate::Result<()> {
let docs = vec![
ComparableDoc::<_, _, false> {
ComparableDoc::<_, _> {
doc: crate::DocAddress {
segment_ord: 0,
doc_id: 0,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(1),
order: Order::Asc,
@@ -792,7 +794,7 @@ mod tests {
segment_ord: 0,
doc_id: 2,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(3),
order: Order::Asc,
@@ -805,7 +807,7 @@ mod tests {
segment_ord: 0,
doc_id: 1,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(5),
order: Order::Asc,
@@ -817,7 +819,7 @@ mod tests {
let mut collector = collector_with_capacity(3);
for doc in docs.clone() {
collector.collect(doc.feature, doc.doc);
collector.collect(doc.sort_key, doc.doc);
}
let res = collector.into_final_result();
@@ -827,15 +829,15 @@ mod tests {
super::TopHitsMetricResult {
hits: vec![
super::TopHitsVecEntry {
sort: vec![docs[0].feature.sorts[0].value],
sort: vec![docs[0].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
super::TopHitsVecEntry {
sort: vec![docs[1].feature.sorts[0].value],
sort: vec![docs[1].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
super::TopHitsVecEntry {
sort: vec![docs[2].feature.sorts[0].value],
sort: vec![docs[2].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
]

View File

@@ -17,14 +17,11 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
results: &mut IntermediateAggregationResults,
) -> crate::Result<()>;
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
) -> crate::Result<()>;
fn collect_block(
&mut self,

View File

@@ -1,121 +0,0 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Score, SegmentReader};
pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
custom_scorer: TCustomScorer,
collector: TopCollector<TScore>,
}
impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
where TScore: Clone + PartialOrd
{
pub(crate) fn new(
custom_scorer: TCustomScorer,
collector: TopCollector<TScore>,
) -> CustomScoreTopCollector<TCustomScorer, TScore> {
CustomScoreTopCollector {
custom_scorer,
collector,
}
}
}
/// A custom segment scorer makes it possible to define any kind of score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`CustomScorer`].
pub trait CustomSegmentScorer<TScore>: 'static {
/// Computes the score of a specific `doc`.
fn score(&mut self, doc: DocId) -> TScore;
}
/// `CustomScorer` makes it possible to define any kind of score.
///
/// The `CustomerScorer` itself does not make much of the computation itself.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the score at a segment scale.
pub trait CustomScorer<TScore>: Sync {
/// Type of the associated [`CustomSegmentScorer`].
type Child: CustomSegmentScorer<TScore>;
/// Builds a child scorer for a specific segment. The child scorer is associated with
/// a specific segment.
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>;
}
impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
where
TCustomScorer: CustomScorer<TScore> + Send + Sync,
TScore: 'static + PartialOrd + Clone + Send + Sync,
{
type Fruit = Vec<(TScore, DocAddress)>;
type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
) -> crate::Result<Self::Child> {
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
Ok(CustomScoreTopSegmentCollector {
segment_collector,
segment_scorer,
})
}
fn requires_scoring(&self) -> bool {
false
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> {
self.collector.merge_fruits(segment_fruits)
}
}
pub struct CustomScoreTopSegmentCollector<T, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
T: CustomSegmentScorer<TScore>,
{
segment_collector: TopSegmentCollector<TScore>,
segment_scorer: T,
}
impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync,
T: 'static + CustomSegmentScorer<TScore>,
{
type Fruit = Vec<(TScore, DocAddress)>;
fn collect(&mut self, doc: DocId, _score: Score) {
let score = self.segment_scorer.score(doc);
self.segment_collector.collect(doc, score);
}
fn harvest(self) -> Vec<(TScore, DocAddress)> {
self.segment_collector.harvest()
}
}
impl<F, TScore, T> CustomScorer<TScore> for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
T: CustomSegmentScorer<TScore>,
{
type Child = T;
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TScore> CustomSegmentScorer<TScore> for F
where F: 'static + FnMut(DocId) -> TScore
{
fn score(&mut self, doc: DocId) -> TScore {
(self)(doc)
}
}

View File

@@ -12,6 +12,7 @@ use std::marker::PhantomData;
use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};
use crate::collector::{Collector, SegmentCollector};
use crate::schema::Schema;
use crate::{DocId, Score, SegmentReader};
/// The `FilterCollector` filters docs using a fast field value and a predicate.
@@ -49,13 +50,13 @@ use crate::{DocId, Score, SegmentReader};
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2));
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2).order_by_score());
/// let top_docs = searcher.search(&query, &no_filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
///
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2));
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2).order_by_score());
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
///
/// assert_eq!(filtered_top_docs.len(), 0);
@@ -104,6 +105,11 @@ where
type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.collector.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -120,6 +126,7 @@ where
segment_collector,
predicate: self.predicate.clone(),
t_predicate_value: PhantomData,
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
})
}
@@ -140,6 +147,7 @@ pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue
segment_collector: TSegmentCollector,
predicate: TPredicate,
t_predicate_value: PhantomData<TPredicateValue>,
filtered_docs: Vec<DocId>,
}
impl<TSegmentCollector, TPredicate, TPredicateValue>
@@ -176,6 +184,20 @@ where
}
}
fn collect_block(&mut self, docs: &[DocId]) {
self.filtered_docs.clear();
for &doc in docs {
// TODO: `accept_document` could be further optimized to do batch lookups of column
// values for single-valued columns.
if self.accept_document(doc) {
self.filtered_docs.push(doc);
}
}
if !self.filtered_docs.is_empty() {
self.segment_collector.collect_block(&self.filtered_docs);
}
}
fn harvest(self) -> TSegmentCollector::Fruit {
self.segment_collector.harvest()
}
@@ -218,7 +240,7 @@ where
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2).order_by_score());
/// let top_docs = searcher.search(&query, &filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
@@ -258,6 +280,10 @@ where
type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.collector.check_schema(schema)
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -274,6 +300,7 @@ where
segment_collector,
predicate: self.predicate.clone(),
buffer: Vec::new(),
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
})
}
@@ -296,6 +323,7 @@ where TPredicate: 'static
segment_collector: TSegmentCollector,
predicate: TPredicate,
buffer: Vec<u8>,
filtered_docs: Vec<DocId>,
}
impl<TSegmentCollector, TPredicate> BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
@@ -334,6 +362,20 @@ where
}
}
fn collect_block(&mut self, docs: &[DocId]) {
self.filtered_docs.clear();
for &doc in docs {
// TODO: `accept_document` could be further optimized to do batch lookups of column
// values for single-valued columns.
if self.accept_document(doc) {
self.filtered_docs.push(doc);
}
}
if !self.filtered_docs.is_empty() {
self.segment_collector.collect_block(&self.filtered_docs);
}
}
fn harvest(self) -> TSegmentCollector::Fruit {
self.segment_collector.harvest()
}

View File

@@ -57,7 +57,7 @@
//! # let query_parser = QueryParser::for_index(&index, vec![title]);
//! # let query = query_parser.parse_query("diary")?;
//! let (doc_count, top_docs): (usize, Vec<(Score, DocAddress)>) =
//! searcher.search(&query, &(Count, TopDocs::with_limit(2)))?;
//! searcher.search(&query, &(Count, TopDocs::with_limit(2).order_by_score()))?;
//! # Ok(())
//! # }
//! ```
@@ -83,11 +83,15 @@
use downcast_rs::impl_downcast;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
mod count_collector;
pub use self::count_collector::Count;
/// Sort keys
pub mod sort_key;
mod histogram_collector;
pub use histogram_collector::HistogramCollector;
@@ -95,16 +99,13 @@ mod multi_collector;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_collector;
pub use self::top_collector::ComparableDoc;
mod top_score_collector;
pub use self::top_collector::ComparableDoc;
pub use self::top_score_collector::{TopDocs, TopNComputer};
mod custom_score_top_collector;
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
mod tweak_score_top_collector;
pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
mod sort_key_top_collector;
pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer};
mod facet_collector;
pub use self::facet_collector::{FacetCollector, FacetCounts};
use crate::query::Weight;
@@ -145,6 +146,11 @@ pub trait Collector: Sync + Send {
/// Type of the `SegmentCollector` associated with this collector.
type Child: SegmentCollector;
/// Returns an error if the schema is not compatible with the collector.
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
Ok(())
}
/// `set_segment` is called before beginning to enumerate
/// on this segment.
fn for_segment(
@@ -170,41 +176,50 @@ pub trait Collector: Sync + Send {
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let with_scoring = self.requires_scoring();
let mut segment_collector = self.for_segment(segment_ord, reader)?;
match (reader.alive_bitset(), self.requires_scoring()) {
(Some(alive_bitset), true) => {
weight.for_each(reader, &mut |doc, score| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
(None, true) => {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
default_collect_segment_impl(&mut segment_collector, weight, reader, with_scoring)?;
Ok(segment_collector.harvest())
}
}
pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
segment_collector: &mut TSegmentCollector,
weight: &dyn Weight,
reader: &SegmentReader,
with_scoring: bool,
) -> crate::Result<()> {
match (reader.alive_bitset(), with_scoring) {
(Some(alive_bitset), true) => {
weight.for_each(reader, &mut |doc, score| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
(None, true) => {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
Ok(())
}
impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCollector> {
type Fruit = Option<TSegmentCollector::Fruit>;
@@ -214,6 +229,12 @@ impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCo
}
}
fn collect_block(&mut self, docs: &[DocId]) {
if let Some(segment_collector) = self {
segment_collector.collect_block(docs);
}
}
fn harvest(self) -> Self::Fruit {
self.map(|segment_collector| segment_collector.harvest())
}
@@ -224,6 +245,13 @@ impl<TCollector: Collector> Collector for Option<TCollector> {
type Child = Option<<TCollector as Collector>::Child>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
if let Some(underlying_collector) = self {
underlying_collector.check_schema(schema)?;
}
Ok(())
}
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
@@ -299,6 +327,12 @@ where
type Fruit = (Left::Fruit, Right::Fruit);
type Child = (Left::Child, Right::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -342,6 +376,11 @@ where
self.1.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(self.0.harvest(), self.1.harvest())
}
@@ -358,6 +397,13 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
type Child = (One::Child, Two::Child, Three::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -407,6 +453,12 @@ where
self.2.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
self.2.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(self.0.harvest(), self.1.harvest(), self.2.harvest())
}
@@ -424,6 +476,14 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
type Child = (One::Child, Two::Child, Three::Child, Four::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
self.3.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -482,6 +542,13 @@ where
self.3.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
self.2.collect_block(docs);
self.3.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(
self.0.harvest(),

View File

@@ -3,6 +3,7 @@ use std::ops::Deref;
use super::{Collector, SegmentCollector};
use crate::collector::Fruit;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader, TantivyError};
/// MultiFruit keeps Fruits from every nested Collector
@@ -16,6 +17,10 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
type Fruit = Box<dyn Fruit>;
type Child = Box<dyn BoxableSegmentCollector>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -147,7 +152,7 @@ impl<TFruit: Fruit> FruitHandle<TFruit> {
/// let searcher = reader.searcher();
///
/// let mut collectors = MultiCollector::new();
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2));
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
/// let count_handle = collectors.add_collector(Count);
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary").unwrap();
@@ -194,6 +199,13 @@ impl Collector for MultiCollector<'_> {
type Fruit = MultiFruit;
type Child = MultiCollectorChild;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
for collector in &self.collector_wrappers {
collector.check_schema(schema)?;
}
Ok(())
}
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
@@ -250,6 +262,12 @@ impl SegmentCollector for MultiCollectorChild {
}
}
fn collect_block(&mut self, docs: &[DocId]) {
for child in &mut self.children {
child.collect_block(docs);
}
}
fn harvest(self) -> MultiFruit {
MultiFruit {
sub_fruits: self
@@ -293,7 +311,7 @@ mod tests {
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut collectors = MultiCollector::new();
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2));
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
let count_handler = collectors.add_collector(Count);
let mut multifruits = searcher.search(&query, &collectors).unwrap();

View File

@@ -0,0 +1,407 @@
mod order;
mod sort_by_score;
mod sort_by_static_fast_value;
mod sort_by_string;
mod sort_key_computer;
pub use order::*;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
#[cfg(test)]
pub(crate) mod tests {
// By spec, regardless of whether ascending or descending order was requested, in presence of a
// tie, we sort by ascending doc id/doc address.
pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
hits: &mut [ComparableDoc<TSortKey, D>],
order: Order,
) {
if order.is_asc() {
hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
} else {
hits.sort_by(|l, r| {
l.sort_key
.cmp(&r.sort_key)
.reverse() // This is descending
.then(l.doc.cmp(&r.doc))
});
}
}
use std::collections::HashMap;
use std::ops::Range;
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, QueryParser};
use crate::schema::{Schema, FAST, TEXT};
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
fn make_index() -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let id = schema_builder.add_u64_field("id", FAST);
let city = schema_builder.add_text_field("city", TEXT | FAST);
let catchphrase = schema_builder.add_text_field("catchphrase", TEXT);
let altitude = schema_builder.add_f64_field("altitude", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
fn create_segment(index: &Index, docs: Vec<impl Document>) -> crate::Result<()> {
let mut index_writer = index.writer_for_tests()?;
index_writer.set_merge_policy(Box::new(NoMergePolicy));
for doc in docs {
index_writer.add_document(doc)?;
}
index_writer.commit()?;
Ok(())
}
create_segment(
&index,
vec![
doc!(
id => 0_u64,
city => "austin",
catchphrase => "Hills, Barbeque, Glow",
altitude => 149.0,
),
doc!(
id => 1_u64,
city => "greenville",
catchphrase => "Grow, Glow, Glow",
altitude => 27.0,
),
],
)?;
create_segment(
&index,
vec![doc!(
id => 2_u64,
city => "tokyo",
catchphrase => "Glow, Glow, Glow",
altitude => 40.0,
)],
)?;
create_segment(
&index,
vec![doc!(
id => 3_u64,
catchphrase => "No, No, No",
altitude => 0.0,
)],
)?;
Ok(index)
}
// NOTE: You cannot determine the SegmentIds that will be generated for Segments
// ahead of time, so DocAddresses must be mapped back to a unique id for each Searcher.
fn id_mapping(searcher: &Searcher) -> HashMap<DocAddress, u64> {
searcher
.search(&AllQuery, &DocSetCollector)
.unwrap()
.into_iter()
.map(|doc_address| {
let column = searcher.segment_readers()[doc_address.segment_ord as usize]
.fast_fields()
.u64("id")
.unwrap();
(doc_address, column.first(doc_address.doc_id).unwrap())
})
.collect()
}
#[test]
fn test_order_by_string() -> crate::Result<()> {
let index = make_index()?;
#[track_caller]
fn assert_query(
index: &Index,
order: Order,
doc_range: Range<usize>,
expected: Vec<(Option<String>, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
// Try as primitive.
let top_collector = TopDocs::for_doc_range(doc_range)
.order_by((SortByString::for_field("city"), order));
let actual = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(sort_key_opt, doc)| (sort_key_opt, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
}
assert_query(
&index,
Order::Asc,
0..4,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
(None, 3),
],
)?;
assert_query(
&index,
Order::Asc,
0..3,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
],
)?;
assert_query(
&index,
Order::Asc,
0..2,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
],
)?;
assert_query(
&index,
Order::Asc,
0..1,
vec![(Some("austin".to_string()), 0)],
)?;
assert_query(
&index,
Order::Asc,
1..3,
vec![
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
],
)?;
assert_query(
&index,
Order::Desc,
0..4,
vec![
(Some("tokyo".to_owned()), 2),
(Some("greenville".to_owned()), 1),
(Some("austin".to_owned()), 0),
(None, 3),
],
)?;
assert_query(
&index,
Order::Desc,
1..3,
vec![
(Some("greenville".to_owned()), 1),
(Some("austin".to_owned()), 0),
],
)?;
assert_query(
&index,
Order::Desc,
0..1,
vec![(Some("tokyo".to_owned()), 2)],
)?;
Ok(())
}
#[test]
fn test_order_by_f64() -> crate::Result<()> {
let index = make_index()?;
fn assert_query(
index: &Index,
order: Order,
expected: Vec<(Option<f64>, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
// Try as primitive.
let top_collector = TopDocs::with_limit(3)
.order_by((SortByStaticFastValue::<f64>::for_field("altitude"), order));
let actual = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(altitude_opt, doc)| (altitude_opt, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
}
assert_query(
&index,
Order::Asc,
vec![(Some(0.0), 3), (Some(27.0), 1), (Some(40.0), 2)],
)?;
assert_query(
&index,
Order::Desc,
vec![(Some(149.0), 0), (Some(40.0), 2), (Some(27.0), 1)],
)?;
Ok(())
}
#[test]
fn test_order_by_score() -> crate::Result<()> {
let index = make_index()?;
fn query(index: &Index, order: Order) -> crate::Result<Vec<(Score, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((SortBySimilarityScore, order));
let field = index.schema().get_field("catchphrase").unwrap();
let query_parser = QueryParser::for_index(index, vec![field]);
let text_query = query_parser.parse_query("glow")?;
Ok(searcher
.search(&text_query, &top_collector)?
.into_iter()
.map(|(score, doc)| (score, ids[&doc]))
.collect())
}
assert_eq!(
&query(&index, Order::Desc)?,
&[(0.5604893, 2), (0.4904281, 1), (0.35667497, 0),]
);
assert_eq!(
&query(&index, Order::Asc)?,
&[(0.35667497, 0), (0.4904281, 1), (0.5604893, 2),]
);
Ok(())
}
#[test]
fn test_order_by_score_then_string() -> crate::Result<()> {
let index = make_index()?;
type SortKey = (Score, Option<String>);
fn query(
index: &Index,
score_order: Order,
city_order: Order,
) -> crate::Result<Vec<(SortKey, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((
(SortBySimilarityScore, score_order),
(SortByString::for_field("city"), city_order),
));
Ok(searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(f, doc)| (f, ids[&doc]))
.collect())
}
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, Some("austin".to_owned())), 0),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("tokyo".to_owned())), 2),
((1.0, None), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, Some("tokyo".to_owned())), 2),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("austin".to_owned())), 0),
((1.0, None), 3),
]
);
Ok(())
}
use proptest::prelude::*;
proptest! {
#[test]
fn test_order_by_string_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..64_usize,
offset in 0..64_usize,
segments_terms in
proptest::collection::vec(
proptest::collection::vec(0..32_u8, 1..32_usize),
0..8_usize,
)
) {
let mut schema_builder = Schema::builder();
let city = schema_builder.add_text_field("city", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
// A Vec<Vec<u8>>, where the outer Vec represents segments, and the inner Vec
// represents terms.
for segment_terms in segments_terms.into_iter() {
for term in segment_terms.into_iter() {
let term = format!("{term:0>3}");
index_writer.add_document(doc!(
city => term,
))?;
}
index_writer.commit()?;
}
let searcher = index.reader()?.searcher();
let top_n_results = searcher.search(&AllQuery, &TopDocs::with_limit(limit)
.and_offset(offset)
.order_by_string_fast_field("city", order))?;
let all_results = searcher.search(&AllQuery, &DocSetCollector)?.into_iter().map(|doc_address| {
// Get the term for this address.
let column = searcher.segment_readers()[doc_address.segment_ord as usize].fast_fields().str("city").unwrap().unwrap();
let value = column.term_ords(doc_address.doc_id).next().map(|term_ord| {
let mut city = Vec::new();
column.dictionary().ord_to_term(term_ord, &mut city).unwrap();
String::try_from(city).unwrap()
});
(value, doc_address)
});
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
sort_hits(&mut comparable_docs, order);
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
prop_assert_eq!(
expected_docs,
top_n_results
);
}
}
}

View File

@@ -0,0 +1,349 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Order, Score};
/// Comparator trait defining the order in which documents should be ordered.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values.
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
}
/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap()
}
}
/// Sorts document in reverse order.
///
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the comparator.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;
impl<T> Comparator<T> for ReverseComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
}
/// Sorts document in reverse order, but considers None as having the lowest value.
///
/// This is usually what is wanted when sorting by a field in an ascending order.
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price at last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;
impl<T> Comparator<Option<T>> for ReverseNoneIsLowerComparator
where ReverseComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
match (lhs_opt, rhs_opt) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
}
}
}
impl Comparator<u32> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<u64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<f64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<f32> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<i64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<String> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
/// An enum representing the different sort orders.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ComparatorEnum {
/// Natural order (See [NaturalComparator])
#[default]
Natural,
/// Reverse order (See [ReverseComparator])
Reverse,
/// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
ReverseNoneLower,
}
impl From<Order> for ComparatorEnum {
fn from(order: Order) -> Self {
match order {
Order::Asc => ComparatorEnum::ReverseNoneLower,
Order::Desc => ComparatorEnum::Natural,
}
}
}
impl<T> Comparator<T> for ComparatorEnum
where
ReverseNoneIsLowerComparator: Comparator<T>,
NaturalComparator: Comparator<T>,
ReverseComparator: Comparator<T>,
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
match self {
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
}
}
}
impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
for (LeftComparator, RightComparator)
where
LeftComparator: Comparator<Head>,
RightComparator: Comparator<Tail>,
{
#[inline(always)]
fn compare(&self, lhs: &(Head, Tail), rhs: &(Head, Tail)) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
for (Comparator1, Comparator2, Comparator3)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
{
#[inline(always)]
fn compare(&self, lhs: &(Type1, (Type2, Type3)), rhs: &(Type1, (Type2, Type3))) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
for (Comparator1, Comparator2, Comparator3)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
{
#[inline(always)]
fn compare(&self, lhs: &(Type1, Type2, Type3), rhs: &(Type1, Type2, Type3)) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
Comparator<(Type1, (Type2, (Type3, Type4)))>
for (Comparator1, Comparator2, Comparator3, Comparator4)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
Comparator4: Comparator<Type4>,
{
#[inline(always)]
fn compare(
&self,
lhs: &(Type1, (Type2, (Type3, Type4))),
rhs: &(Type1, (Type2, (Type3, Type4))),
) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
Comparator<(Type1, Type2, Type3, Type4)>
for (Comparator1, Comparator2, Comparator3, Comparator4)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
Comparator4: Comparator<Type4>,
{
#[inline(always)]
fn compare(
&self,
lhs: &(Type1, Type2, Type3, Type4),
rhs: &(Type1, Type2, Type3, Type4),
) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
where
TSortKeyComputer: SortKeyComputer,
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
ComparatorEnum: Comparator<
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
>,
{
type SortKey = TSortKeyComputer::SortKey;
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
type Comparator = ComparatorEnum;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
}
fn comparator(&self) -> Self::Comparator {
self.1
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {
segment_sort_key_computer: child,
comparator: self.comparator(),
})
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, Order)
where
TSortKeyComputer: SortKeyComputer,
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
ComparatorEnum: Comparator<
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
>,
{
type SortKey = TSortKeyComputer::SortKey;
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
type Comparator = ComparatorEnum;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
}
fn comparator(&self) -> Self::Comparator {
self.1.into()
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {
segment_sort_key_computer: child,
comparator: self.comparator(),
})
}
}
/// A segment sort key computer with a custom ordering.
pub struct SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator> {
segment_sort_key_computer: TSegmentSortKeyComputer,
comparator: TComparator,
}
impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComputer
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
{
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score)
}
#[inline(always)]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.comparator.compare(left, right)
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
self.segment_sort_key_computer
.convert_segment_sort_key(sort_key)
}
}

View File

@@ -0,0 +1,77 @@
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
#[derive(Clone, Debug, Copy)]
pub struct SortBySimilarityScore;
impl SortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type Child = SortBySimilarityScore;
type Comparator = NaturalComparator;
fn requires_scoring(&self) -> bool {
true
}
fn segment_sort_key_computer(
&self,
_segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScore)
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
fn collect_segment_top_k(
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
TopNComputer::new_with_comparator(k, self.comparator());
if let Some(alive_bitset) = reader.alive_bitset() {
let mut threshold = Score::MIN;
top_n.threshold = Some(threshold);
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
if alive_bitset.is_deleted(doc) {
return threshold;
}
top_n.push(score, doc);
threshold = top_n.threshold.unwrap_or(Score::MIN);
threshold
})?;
} else {
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
top_n.push(score, doc);
top_n.threshold.unwrap_or(Score::MIN)
})?;
}
Ok(top_n
.into_vec()
.into_iter()
.map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
.collect())
}
}
impl SegmentSortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type SegmentSortKey = Score;
#[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
score
}
fn convert_segment_sort_key(&self, score: Score) -> Score {
score
}
}

View File

@@ -0,0 +1,98 @@
use std::marker::PhantomData;
use columnar::Column;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::{DocId, Score, SegmentReader};
/// Sorts by a fast value (u64, i64, f64, bool).
///
/// The field must appear explicitly in the schema, with the right type, and declared as
/// a fast field..
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByStaticFastValue<T: FastValue> {
field: String,
typ: PhantomData<T>,
}
impl<T: FastValue> SortByStaticFastValue<T> {
/// Creates a new `SortByStaticFastValue` instance for the given field.
pub fn for_field(column_name: impl ToString) -> SortByStaticFastValue<T> {
Self {
field: column_name.to_string(),
typ: PhantomData,
}
}
}
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
type Child = SortByFastValueSegmentSortKeyComputer<T>;
type SortKey = Option<T>;
type Comparator = NaturalComparator;
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
// At the segment sort key computer level, we rely on the u64 representation.
// The mapping is monotonic, so it is sufficient to compute our top-K docs.
let field = schema.get_field(&self.field)?;
let field_entry = schema.get_field_entry(field);
if !field_entry.is_fast() {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is not a fast field.",
self.field,
)));
}
let schema_type = field_entry.field_type().value_type();
if schema_type != T::to_type() {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {schema_type:?}, not of the type {:?}.",
&self.field,
T::to_type()
)));
}
Ok(())
}
fn segment_sort_key_computer(
&self,
segment_reader: &SegmentReader,
) -> crate::Result<Self::Child> {
let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
let (sort_column, _sort_column_type) =
sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
field_name: self.field.clone(),
})?;
Ok(SortByFastValueSegmentSortKeyComputer {
sort_column,
typ: PhantomData,
})
}
}
pub struct SortByFastValueSegmentSortKeyComputer<T> {
sort_column: Column<u64>,
typ: PhantomData<T>,
}
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
type SortKey = Option<T>;
type SegmentSortKey = Option<u64>;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
self.sort_column.first(doc)
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key.map(T::from_u64)
}
}

View File

@@ -0,0 +1,72 @@
use columnar::StrColumn;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
/// Sort by the first value of a string column.
///
/// The string can be dynamic (coming from a json field)
/// or static (being specificaly defined in the configuration).
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByString {
column_name: String,
}
impl SortByString {
/// Creates a new sort by string sort key computer.
pub fn for_field(column_name: impl ToString) -> Self {
SortByString {
column_name: column_name.to_string(),
}
}
}
impl SortKeyComputer for SortByString {
type SortKey = Option<String>;
type Child = ByStringColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
}
}
pub struct ByStringColumnSegmentSortKeyComputer {
str_column_opt: Option<StrColumn>,
}
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
let str_column = self.str_column_opt.as_ref()?;
str_column.ords().first(doc)
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
let term_ord = term_ord_opt?;
let str_column = self.str_column_opt.as_ref()?;
let mut bytes = Vec::new();
str_column
.dictionary()
.ord_to_term(term_ord, &mut bytes)
.ok()?;
String::try_from(bytes).ok()
}
}

View File

@@ -0,0 +1,631 @@
use std::cmp::Ordering;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
/// A `SegmentSortKeyComputer` makes it possible to modify the default score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static {
/// The final score being emitted.
type SortKey: 'static + PartialOrd + Send + Sync + Clone;
/// Sort key used by at the segment level by the `SegmentSortKeyComputer`.
///
/// It is typically small like a `u64`, and is meant to be converted
/// to the final score at the end of the collection of the segment.
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;
/// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
/// Computes the sort key and pushes the document in a TopN Computer.
///
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
let sort_key = self.segment_sort_key(doc, score);
top_n_computer.push(sort_key, doc);
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
/// This method must be consistent with the `SortKey` ordering.
#[inline(always)]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
NaturalComparator.compare(left, right)
}
/// Implementing this method makes it possible to avoid computing
/// a sort_key entirely if we can assess that it won't pass a threshold
/// with a partial computation.
///
/// This is currently used for lexicographic sorting.
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let sort_key = self.segment_sort_key(doc_id, score);
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
if cmp == Ordering::Less {
None
} else {
Some((cmp, sort_key))
}
}
/// Convert a segment level sort key into the global sort key.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
}
/// `SortKeyComputer` defines the sort key to be used by a TopK Collector.
///
/// The `SortKeyComputer` itself does not make much of the computation itself.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the sort key at a segment scale.
pub trait SortKeyComputer: Sync {
/// The sort key type.
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
/// Type of the associated [`SegmentSortKeyComputer`].
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
/// Comparator type.
type Comparator: Comparator<Self::SortKey>
+ Comparator<<Self::Child as SegmentSortKeyComputer>::SegmentSortKey>
+ 'static;
/// Checks whether the schema is compatible with the sort key computer.
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
Ok(())
}
/// Returns the sort key comparator.
fn comparator(&self) -> Self::Comparator {
Self::Comparator::default()
}
/// Indicates whether the sort key actually uses the similarity score (by default BM25).
/// If set to false, the similary score might not be computed (as an optimization),
/// and the score fed in the segment sort key computer could take any value.
fn requires_scoring(&self) -> bool {
false
}
/// Sorting by score has a overriding implementation for BM25 scores, using Block-WAND.
fn collect_segment_top_k(
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let with_scoring = self.requires_scoring();
let segment_sort_key_computer = self.segment_sort_key_computer(reader)?;
let topn_computer = TopNComputer::new_with_comparator(k, self.comparator());
let mut segment_top_key_collector = TopBySortKeySegmentCollector {
topn_computer,
segment_ord,
segment_sort_key_computer,
};
default_collect_segment_impl(&mut segment_top_key_collector, weight, reader, with_scoring)?;
Ok(segment_top_key_collector.harvest())
}
/// Builds a child sort key computer for a specific segment.
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}
impl<HeadSortKeyComputer, TailSortKeyComputer> SortKeyComputer
for (HeadSortKeyComputer, TailSortKeyComputer)
where
HeadSortKeyComputer: SortKeyComputer,
TailSortKeyComputer: SortKeyComputer,
{
type SortKey = (
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type Comparator = (
HeadSortKeyComputer::Comparator,
TailSortKeyComputer::Comparator,
);
fn comparator(&self) -> Self::Comparator {
(self.0.comparator(), self.1.comparator())
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
))
}
/// Checks whether the schema is compatible with the sort key computer.
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
Ok(())
}
/// Indicates whether the sort key actually uses the similarity score (by default BM25).
/// If set to false, the similary score might not be computed (as an optimization),
/// and the score fed in the segment sort key computer could take any value.
fn requires_scoring(&self) -> bool {
self.0.requires_scoring() || self.1.requires_scoring()
}
}
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
where
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
{
type SortKey = (
HeadSegmentSortKeyComputer::SortKey,
TailSegmentSortKeyComputer::SortKey,
);
type SegmentSortKey = (
HeadSegmentSortKeyComputer::SegmentSortKey,
TailSegmentSortKeyComputer::SegmentSortKey,
);
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
/// By default, it uses the natural ordering.
#[inline]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.0
.compare_segment_sort_key(&left.0, &right.0)
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
}
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
let sort_key: Self::SegmentSortKey;
if let Some(threshold) = &top_n_computer.threshold {
if let Some((_cmp, lazy_sort_key)) = self.accept_sort_key_lazy(doc, score, threshold) {
sort_key = lazy_sort_key;
} else {
return;
}
} else {
sort_key = self.segment_sort_key(doc, score);
};
top_n_computer.append_doc(doc, sort_key);
}
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.0.segment_sort_key(doc, score);
let tail_sort_key = self.1.segment_sort_key(doc, score);
(head_sort_key, tail_sort_key)
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let (head_cmp, head_sort_key) =
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
if head_cmp == Ordering::Equal {
let (tail_cmp, tail_sort_key) =
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
Some((tail_cmp, (head_sort_key, tail_sort_key)))
} else {
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
let (head_sort_key, tail_sort_key) = sort_key;
(
self.0.convert_segment_sort_key(head_sort_key),
self.1.convert_segment_sort_key(tail_sort_key),
)
}
}
/// This struct is used as an adapter to take a sort key computer and map its score to another
/// new sort key.
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
sort_key_computer: T,
map: fn(PreviousSortKey) -> NewSortKey,
}
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
where
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
NewScore: 'static + Clone + Send + Sync + PartialOrd,
{
type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score)
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
self.sort_key_computer
.accept_sort_key_lazy(doc_id, score, threshold)
}
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
self.sort_key_computer
.compute_sort_key_and_collect(doc, score, top_n_computer);
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
(self.map)(
self.sort_key_computer
.convert_segment_sort_key(segment_sort_key),
)
}
}
// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c,
// ...) as the chain (a, (b, (c, ...)))
impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3> SortKeyComputer
for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3)
where
SortKeyComputer1: SortKeyComputer,
SortKeyComputer2: SortKeyComputer,
SortKeyComputer3: SortKeyComputer,
{
type SortKey = (
SortKeyComputer1::SortKey,
SortKeyComputer2::SortKey,
SortKeyComputer3::SortKey,
);
type Child = MappedSegmentSortKeyComputer<
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
),
Self::SortKey,
>;
type Comparator = (
SortKeyComputer1::Comparator,
SortKeyComputer2::Comparator,
SortKeyComputer3::Comparator,
);
fn comparator(&self) -> Self::Comparator {
(
self.0.comparator(),
self.1.comparator(),
self.2.comparator(),
)
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
map,
})
}
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
Ok(())
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring() || self.1.requires_scoring() || self.2.requires_scoring()
}
}
impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3, SortKeyComputer4> SortKeyComputer
for (
SortKeyComputer1,
SortKeyComputer2,
SortKeyComputer3,
SortKeyComputer4,
)
where
SortKeyComputer1: SortKeyComputer,
SortKeyComputer2: SortKeyComputer,
SortKeyComputer3: SortKeyComputer,
SortKeyComputer4: SortKeyComputer,
{
type Child = MappedSegmentSortKeyComputer<
<(
SortKeyComputer1,
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(
SortKeyComputer2::SortKey,
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
),
),
Self::SortKey,
>;
type SortKey = (
SortKeyComputer1::SortKey,
SortKeyComputer2::SortKey,
SortKeyComputer3::SortKey,
SortKeyComputer4::SortKey,
);
type Comparator = (
SortKeyComputer1::Comparator,
SortKeyComputer2::Comparator,
SortKeyComputer3::Comparator,
SortKeyComputer4::Comparator,
);
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (
sort_key_computer1,
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
),
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
(sort_key1, sort_key2, sort_key3, sort_key4)
},
})
}
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
self.3.check_schema(schema)?;
Ok(())
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
|| self.1.requires_scoring()
|| self.2.requires_scoring()
|| self.3.requires_scoring()
}
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
SegmentF: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
type Child = SegmentF;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TSortKey> SegmentSortKeyComputer for F
where
F: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self)(doc)
}
/// Convert a segment level score into the global level score.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Index, Order, SegmentReader};
fn build_test_index() -> Index {
let schema = Schema::builder().build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::TantivyDocument::default())
.unwrap();
index_writer.commit().unwrap();
index
}
#[test]
fn test_lazy_score_computer() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
"b"
}
};
let lazy_score_computer = (score_computer_primary, score_computer_secondary);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let expected_sort_key = (200, "b");
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "a"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "c"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "a"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "c"));
assert!(sort_key_opt.is_none());
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "a"));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "c"));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
}
}
#[test]
fn test_lazy_score_computer_dynamic_ordering() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
2u32
}
};
let lazy_score_computer = (
(score_computer_primary, Order::Desc),
(score_computer_secondary, Order::Asc),
);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let expected_sort_key = (200, 2u32);
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 1u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 3u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 1u32));
assert!(sort_key_opt.is_none());
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 3u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 1u32));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 3u32));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
}
assert_eq!(
segment_sort_key_computer.convert_segment_sort_key(expected_sort_key),
(200u32, 2u32)
);
}
}

View File

@@ -0,0 +1,193 @@
use std::ops::Range;
use crate::collector::sort_key::{Comparator, SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{Collector, SegmentCollector, TopNComputer};
use crate::query::Weight;
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
pub(crate) struct TopBySortKeyCollector<TSortKeyComputer> {
sort_key_computer: TSortKeyComputer,
doc_range: Range<usize>,
}
impl<TSortKeyComputer> TopBySortKeyCollector<TSortKeyComputer> {
pub fn new(sort_key_computer: TSortKeyComputer, doc_range: Range<usize>) -> Self {
TopBySortKeyCollector {
sort_key_computer,
doc_range,
}
}
}
impl<TSortKeyComputer> Collector for TopBySortKeyCollector<TSortKeyComputer>
where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
{
type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>;
type Child =
TopBySortKeySegmentCollector<TSortKeyComputer::Child, TSortKeyComputer::Comparator>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.sort_key_computer.check_schema(schema)
}
fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result<Self::Child> {
let segment_sort_key_computer = self
.sort_key_computer
.segment_sort_key_computer(segment_reader)?;
let topn_computer = TopNComputer::new_with_comparator(
self.doc_range.end,
self.sort_key_computer.comparator(),
);
Ok(TopBySortKeySegmentCollector {
topn_computer,
segment_ord,
segment_sort_key_computer,
})
}
fn requires_scoring(&self) -> bool {
self.sort_key_computer.requires_scoring()
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
Ok(merge_top_k(
segment_fruits.into_iter().flatten(),
self.doc_range.clone(),
self.sort_key_computer.comparator(),
))
}
fn collect_segment(
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<Vec<(TSortKeyComputer::SortKey, DocAddress)>> {
let k = self.doc_range.end;
let docs = self
.sort_key_computer
.collect_segment_top_k(k, weight, reader, segment_ord)?;
Ok(docs)
}
}
fn merge_top_k<D: Ord, TSortKey: Clone + std::fmt::Debug, C: Comparator<TSortKey>>(
sort_key_docs: impl Iterator<Item = (TSortKey, D)>,
doc_range: Range<usize>,
comparator: C,
) -> Vec<(TSortKey, D)> {
if doc_range.is_empty() {
return Vec::new();
}
let mut top_collector: TopNComputer<TSortKey, D, C> =
TopNComputer::new_with_comparator(doc_range.end, comparator);
for (sort_key, doc) in sort_key_docs {
top_collector.push(sort_key, doc);
}
top_collector
.into_sorted_vec()
.into_iter()
.skip(doc_range.start)
.map(|cdoc| (cdoc.sort_key, cdoc.doc))
.collect()
}
pub struct TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
{
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
pub(crate) segment_ord: u32,
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
}
impl<TSegmentSortKeyComputer, C> SegmentCollector
for TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
where
TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey> + 'static,
{
type Fruit = Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)>;
fn collect(&mut self, doc: DocId, score: Score) {
self.segment_sort_key_computer.compute_sort_key_and_collect(
doc,
score,
&mut self.topn_computer,
);
}
fn harvest(self) -> Self::Fruit {
let segment_ord = self.segment_ord;
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self
.topn_computer
.into_vec()
.into_iter()
.map(|comparable_doc| {
let sort_key = self
.segment_sort_key_computer
.convert_segment_sort_key(comparable_doc.sort_key);
(
sort_key,
DocAddress {
segment_ord,
doc_id: comparable_doc.doc,
},
)
})
.collect();
segment_hits
}
}
#[cfg(test)]
mod tests {
use std::ops::Range;
use rand;
use rand::seq::SliceRandom as _;
use super::merge_top_k;
use crate::collector::sort_key::ComparatorEnum;
use crate::Order;
fn test_merge_top_k_aux(
order: Order,
doc_range: Range<usize>,
expected: &[(crate::Score, usize)],
) {
let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
vals.shuffle(&mut rand::thread_rng());
let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
assert_eq!(&vals_merged, expected);
}
#[test]
fn test_merge_top_k() {
test_merge_top_k_aux(Order::Asc, 0..0, &[]);
test_merge_top_k_aux(Order::Asc, 3..3, &[]);
test_merge_top_k_aux(Order::Asc, 0..3, &[(0.0f32, 0), (1.0f32, 1), (2.0f32, 2)]);
test_merge_top_k_aux(
Order::Asc,
0..11,
&[
(0.0f32, 0),
(1.0f32, 1),
(2.0f32, 2),
(3.0f32, 3),
(4.0f32, 4),
(5.0f32, 5),
(6.0f32, 6),
(7.0f32, 7),
(8.0f32, 8),
(9.0f32, 9),
],
);
test_merge_top_k_aux(Order::Asc, 1..3, &[(1.0f32, 1), (2.0f32, 2)]);
test_merge_top_k_aux(Order::Desc, 0..2, &[(9.0f32, 9), (8.0f32, 8)]);
test_merge_top_k_aux(Order::Desc, 2..4, &[(7.0f32, 7), (6.0f32, 6)]);
}
}

View File

@@ -40,7 +40,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
let filter_some_collector = FilterCollector::new(
"price".to_string(),
&|value: u64| value > 20_120u64,
TopDocs::with_limit(2),
TopDocs::with_limit(2).order_by_score(),
);
let top_docs = searcher.search(&query, &filter_some_collector)?;
@@ -50,7 +50,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(
"price".to_string(),
&|value| value < 5u64,
TopDocs::with_limit(2),
TopDocs::with_limit(2).order_by_score(),
);
let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
@@ -62,8 +62,11 @@ pub fn test_filter_collector() -> crate::Result<()> {
> 0
}
let filter_dates_collector =
FilterCollector::new("date".to_string(), &date_filter, TopDocs::with_limit(5));
let filter_dates_collector = FilterCollector::new(
"date".to_string(),
&date_filter,
TopDocs::with_limit(5).order_by_score(),
);
let filtered_date_docs = searcher.search(&query, &filter_dates_collector)?;
assert_eq!(filtered_date_docs.len(), 2);

View File

@@ -1,374 +1,22 @@
use std::cmp::Ordering;
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
use super::top_score_collector::TopNComputer;
use crate::index::SegmentReader;
use crate::{DocAddress, DocId, SegmentOrdinal};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// It guarantees stable sorting: in case of a tie on the feature, the document
/// address is used.
///
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
/// should be reversed, which is useful for achieving for example largest-first
/// semantics without having to wrap the feature in a `Reverse`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ComparableDoc<T, D> {
/// The feature of the document. In practice, this is
/// is any type that implements `PartialOrd`.
pub feature: T,
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
/// is a type which can be compared with a `Comparator<T>`.
pub sort_key: T,
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
.field("feature", &self.feature)
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("ComparableDoc")
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.feature
.partial_cmp(&other.feature)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
// In case of a tie on the feature, we sort by ascending
// `DocAddress` in order to ensure a stable sorting of the
// documents.
by_feature.then_with(lazy_by_doc_address)
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}
pub(crate) struct TopCollector<T> {
pub limit: usize,
pub offset: usize,
_marker: PhantomData<T>,
}
impl<T> TopCollector<T>
where T: PartialOrd + Clone
{
/// Creates a top collector, with a number of documents equal to "limit".
///
/// # Panics
/// The method panics if limit is 0
pub fn with_limit(limit: usize) -> TopCollector<T> {
assert!(limit >= 1, "Limit must be strictly greater than 0.");
Self {
limit,
offset: 0,
_marker: PhantomData,
}
}
/// Skip the first "offset" documents when collecting.
///
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
/// Lucene's TopDocsCollector.
pub fn and_offset(mut self, offset: usize) -> TopCollector<T> {
self.offset = offset;
self
}
pub fn merge_fruits(
&self,
children: Vec<Vec<(T, DocAddress)>>,
) -> crate::Result<Vec<(T, DocAddress)>> {
if self.limit == 0 {
return Ok(Vec::new());
}
let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset);
for child_fruit in children {
for (feature, doc) in child_fruit {
top_collector.push(feature, doc);
}
}
Ok(top_collector
.into_sorted_vec()
.into_iter()
.skip(self.offset)
.map(|cdoc| (cdoc.feature, cdoc.doc))
.collect())
}
pub(crate) fn for_segment<F: PartialOrd + Clone>(
&self,
segment_id: SegmentOrdinal,
_: &SegmentReader,
) -> TopSegmentCollector<F> {
TopSegmentCollector::new(segment_id, self.limit + self.offset)
}
/// Create a new TopCollector with the same limit and offset.
///
/// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits
/// to fail.
#[doc(hidden)]
pub(crate) fn into_tscore<TScore: PartialOrd + Clone>(self) -> TopCollector<TScore> {
TopCollector {
limit: self.limit,
offset: self.offset,
_marker: PhantomData,
}
}
}
/// The Top Collector keeps track of the K documents
/// sorted by type `T`.
///
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
/// The theoretical complexity for collecting the top `K` out of `n` documents
/// is `O(n + K)`.
pub(crate) struct TopSegmentCollector<T> {
/// We reverse the order of the feature in order to
/// have top-semantics instead of bottom semantics.
topn_computer: TopNComputer<T, DocId>,
segment_ord: u32,
}
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
TopSegmentCollector {
topn_computer: TopNComputer::new(limit),
segment_ord,
}
}
}
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
pub fn harvest(self) -> Vec<(T, DocAddress)> {
let segment_ord = self.segment_ord;
self.topn_computer
.into_sorted_vec()
.into_iter()
.map(|comparable_doc| {
(
comparable_doc.feature,
DocAddress {
segment_ord,
doc_id: comparable_doc.doc,
},
)
})
.collect()
}
/// Collects a document scored by the given feature
///
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
/// will compare the lowest scoring item with the given one and keep whichever is greater.
#[inline]
pub fn collect(&mut self, doc: DocId, feature: T) {
self.topn_computer.push(feature, doc);
}
}
#[cfg(test)]
mod tests {
use super::{TopCollector, TopSegmentCollector};
use crate::DocAddress;
#[test]
fn test_top_collector_not_at_capacity() {
let mut top_collector = TopSegmentCollector::new(0, 4);
top_collector.collect(1, 0.8);
top_collector.collect(3, 0.2);
top_collector.collect(5, 0.3);
assert_eq!(
top_collector.harvest(),
vec![
(0.8, DocAddress::new(0, 1)),
(0.3, DocAddress::new(0, 5)),
(0.2, DocAddress::new(0, 3))
]
);
}
#[test]
fn test_top_collector_at_capacity() {
let mut top_collector = TopSegmentCollector::new(0, 4);
top_collector.collect(1, 0.8);
top_collector.collect(3, 0.2);
top_collector.collect(5, 0.3);
top_collector.collect(7, 0.9);
top_collector.collect(9, -0.2);
assert_eq!(
top_collector.harvest(),
vec![
(0.9, DocAddress::new(0, 7)),
(0.8, DocAddress::new(0, 1)),
(0.3, DocAddress::new(0, 5)),
(0.2, DocAddress::new(0, 3))
]
);
}
#[test]
fn test_top_segment_collector_stable_ordering_for_equal_feature() {
// given that the documents are collected in ascending doc id order,
// when harvesting we have to guarantee stable sorting in case of a tie
// on the score
let doc_ids_collection = [4, 5, 6];
let score = 3.3f32;
let mut top_collector_limit_2 = TopSegmentCollector::new(0, 2);
for id in &doc_ids_collection {
top_collector_limit_2.collect(*id, score);
}
let mut top_collector_limit_3 = TopSegmentCollector::new(0, 3);
for id in &doc_ids_collection {
top_collector_limit_3.collect(*id, score);
}
assert_eq!(
top_collector_limit_2.harvest(),
top_collector_limit_3.harvest()[..2].to_vec(),
);
}
#[test]
fn test_top_collector_with_limit_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
(0.7, DocAddress::new(0, 3)),
(0.6, DocAddress::new(0, 4)),
(0.5, DocAddress::new(0, 5)),
]])
.unwrap();
assert_eq!(
results,
vec![(0.8, DocAddress::new(0, 2)), (0.7, DocAddress::new(0, 3)),]
);
}
#[test]
fn test_top_collector_with_limit_larger_than_set_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
]])
.unwrap();
assert_eq!(results, vec![(0.8, DocAddress::new(0, 2)),]);
}
#[test]
fn test_top_collector_with_limit_and_offset_larger_than_set() {
let collector = TopCollector::with_limit(2).and_offset(20);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
]])
.unwrap();
assert_eq!(results, vec![]);
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use test::Bencher;
use super::TopSegmentCollector;
#[bench]
fn bench_top_segment_collector_collect_not_at_capacity(b: &mut Bencher) {
let mut top_collector = TopSegmentCollector::new(0, 400);
b.iter(|| {
for i in 0..100 {
top_collector.collect(i, 0.8);
}
});
}
#[bench]
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
let mut top_collector = TopSegmentCollector::new(0, 100);
for i in 0..100 {
top_collector.collect(i, 0.8);
}
b.iter(|| {
for i in 0..100 {
top_collector.collect(i, 0.8);
}
});
}
#[bench]
fn bench_top_segment_collector_collect_and_harvest_many_ties(b: &mut Bencher) {
b.iter(|| {
let mut top_collector = TopSegmentCollector::new(0, 100);
for i in 0..100 {
top_collector.collect(i, 0.8);
}
// it would be nice to be able to do the setup N times but still
// measure only harvest(). We can't since harvest() consumes
// the top_collector.
top_collector.harvest()
});
}
#[bench]
fn bench_top_segment_collector_collect_and_harvest_no_tie(b: &mut Bencher) {
b.iter(|| {
let mut top_collector = TopSegmentCollector::new(0, 100);
let mut score = 1.0;
for i in 0..100 {
score += 1.0;
top_collector.collect(i, score);
}
// it would be nice to be able to do the setup N times but still
// measure only harvest(). We can't since harvest() consumes
// the top_collector.
top_collector.harvest()
});
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,124 +0,0 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
pub(crate) struct TweakedScoreTopCollector<TScoreTweaker, TScore = Score> {
score_tweaker: TScoreTweaker,
collector: TopCollector<TScore>,
}
impl<TScoreTweaker, TScore> TweakedScoreTopCollector<TScoreTweaker, TScore>
where TScore: Clone + PartialOrd
{
pub fn new(
score_tweaker: TScoreTweaker,
collector: TopCollector<TScore>,
) -> TweakedScoreTopCollector<TScoreTweaker, TScore> {
TweakedScoreTopCollector {
score_tweaker,
collector,
}
}
}
/// A `ScoreSegmentTweaker` makes it possible to modify the default score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`ScoreTweaker`].
pub trait ScoreSegmentTweaker<TScore>: 'static {
/// Tweak the given `score` for the document `doc`.
fn score(&mut self, doc: DocId, score: Score) -> TScore;
}
/// `ScoreTweaker` makes it possible to tweak the score
/// emitted by the scorer into another one.
///
/// The `ScoreTweaker` itself does not make much of the computation itself.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the score at a segment scale.
pub trait ScoreTweaker<TScore>: Sync {
/// Type of the associated [`ScoreSegmentTweaker`].
type Child: ScoreSegmentTweaker<TScore>;
/// Builds a child tweaker for a specific segment. The child scorer is associated with
/// a specific segment.
fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}
impl<TScoreTweaker, TScore> Collector for TweakedScoreTopCollector<TScoreTweaker, TScore>
where
TScoreTweaker: ScoreTweaker<TScore> + Send + Sync,
TScore: 'static + PartialOrd + Clone + Send + Sync,
{
type Fruit = Vec<(TScore, DocAddress)>;
type Child = TopTweakedScoreSegmentCollector<TScoreTweaker::Child, TScore>;
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
) -> Result<Self::Child> {
let segment_scorer = self.score_tweaker.segment_tweaker(segment_reader)?;
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
Ok(TopTweakedScoreSegmentCollector {
segment_collector,
segment_scorer,
})
}
fn requires_scoring(&self) -> bool {
true
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
self.collector.merge_fruits(segment_fruits)
}
}
pub struct TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
{
segment_collector: TopSegmentCollector<TScore>,
segment_scorer: TSegmentScoreTweaker,
}
impl<TSegmentScoreTweaker, TScore> SegmentCollector
for TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync,
TSegmentScoreTweaker: 'static + ScoreSegmentTweaker<TScore>,
{
type Fruit = Vec<(TScore, DocAddress)>;
fn collect(&mut self, doc: DocId, score: Score) {
let score = self.segment_scorer.score(doc, score);
self.segment_collector.collect(doc, score);
}
fn harvest(self) -> Vec<(TScore, DocAddress)> {
self.segment_collector.harvest()
}
}
impl<F, TScore, TSegmentScoreTweaker> ScoreTweaker<TScore> for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker,
TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
{
type Child = TSegmentScoreTweaker;
fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TScore> ScoreSegmentTweaker<TScore> for F
where F: 'static + FnMut(DocId, Score) -> TScore
{
fn score(&mut self, doc: DocId, score: Score) -> TScore {
(self)(doc, score)
}
}

View File

@@ -69,7 +69,7 @@ fn assert_date_time_precision(index: &Index, doc_store_precision: DateTimePrecis
.parse_query("dateformat")
.expect("Failed to parse query");
let top_docs = searcher
.search(&query, &TopDocs::with_limit(1))
.search(&query, &TopDocs::with_limit(1).order_by_score())
.expect("Search failed");
assert_eq!(top_docs.len(), 1, "Expected 1 search result");

View File

@@ -3,6 +3,7 @@ use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP};
use common::{replace_in_place, JsonPathWriter};
use rustc_hash::FxHashMap;
use crate::indexer::indexing_term::IndexingTerm;
use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::document::{ReferenceValue, ReferenceValueLeaf, Value};
use crate::schema::{Type, DATE_TIME_PRECISION_INDEXED};
@@ -77,7 +78,7 @@ fn index_json_object<'a, V: Value<'a>>(
doc: DocId,
json_visitor: V::ObjectIter,
text_analyzer: &mut TextAnalyzer,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
json_path_writer: &mut JsonPathWriter,
postings_writer: &mut dyn PostingsWriter,
ctx: &mut IndexingContext,
@@ -107,17 +108,17 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
doc: DocId,
json_value: V,
text_analyzer: &mut TextAnalyzer,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
json_path_writer: &mut JsonPathWriter,
postings_writer: &mut dyn PostingsWriter,
ctx: &mut IndexingContext,
positions_per_path: &mut IndexingPositionsPerPath,
) {
let set_path_id = |term_buffer: &mut Term, unordered_id: u32| {
let set_path_id = |term_buffer: &mut IndexingTerm, unordered_id: u32| {
term_buffer.truncate_value_bytes(0);
term_buffer.append_bytes(&unordered_id.to_be_bytes());
};
let set_type = |term_buffer: &mut Term, typ: Type| {
let set_type = |term_buffer: &mut IndexingTerm, typ: Type| {
term_buffer.append_bytes(&[typ.to_code()]);
};
@@ -405,7 +406,7 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_str("red");
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00sred")
assert_eq!(term.serialized_value_bytes(), b"color\x00sred".to_vec())
}
#[test]
@@ -415,8 +416,8 @@ mod tests {
term.append_type_and_fast_value(-4i64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc"
term.serialized_value_bytes(),
b"color\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc".to_vec()
)
}
@@ -427,8 +428,8 @@ mod tests {
term.append_type_and_fast_value(4u64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00u\x00\x00\x00\x00\x00\x00\x00\x04"
term.serialized_value_bytes(),
b"color\x00u\x00\x00\x00\x00\x00\x00\x00\x04".to_vec()
)
}
@@ -438,8 +439,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(4.0f64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00f\xc0\x10\x00\x00\x00\x00\x00\x00"
term.serialized_value_bytes(),
b"color\x00f\xc0\x10\x00\x00\x00\x00\x00\x00".to_vec()
)
}
@@ -449,8 +450,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(true);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00o\x00\x00\x00\x00\x00\x00\x00\x01"
term.serialized_value_bytes(),
b"color\x00o\x00\x00\x00\x00\x00\x00\x00\x01".to_vec()
)
}

View File

@@ -225,6 +225,7 @@ impl Searcher {
enabled_scoring: EnableScoring,
) -> crate::Result<C::Fruit> {
let weight = query.weight(enabled_scoring)?;
collector.check_schema(self.schema())?;
let segment_readers = self.segment_readers();
let fruits = executor.map(
|(segment_ord, segment_reader)| {

View File

@@ -5,7 +5,7 @@ use std::ops::Range;
use common::{BinarySerializable, CountingWriter, HasLen, VInt};
use crate::directory::{FileSlice, TerminatingWrite, WritePtr};
use crate::schema::Field;
use crate::schema::{Field, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
#[derive(Eq, PartialEq, Hash, Copy, Ord, PartialOrd, Clone, Debug)]
@@ -167,10 +167,11 @@ impl CompositeFile {
.map(|byte_range| self.data.slice(byte_range.clone()))
}
pub fn space_usage(&self) -> PerFieldSpaceUsage {
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {
let mut field_usage = FieldUsage::empty(field_addr.field);
let field_name = schema.get_field_name(field_addr.field).to_string();
let mut field_usage = FieldUsage::empty(field_name);
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
fields.push(field_usage);
}

View File

@@ -108,7 +108,7 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// Opens a file and returns a boxed `FileHandle`.
///
/// Users of `Directory` should typically call `Directory::open_read(...)`,
/// while `Directory` implementor should implement `get_file_handle()`.
/// while `Directory` implementer should implement `get_file_handle()`.
fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError>;
/// Once a virtual file is open, its data may not

View File

@@ -104,7 +104,7 @@ pub enum TantivyError {
#[error("{0:?}")]
IncompatibleIndex(Incompatibility),
/// An internal error occurred. This is are internal states that should not be reached.
/// e.g. a datastructure is incorrectly inititalized.
/// e.g. a datastructure is incorrectly initialized.
#[error("Internal error: '{0}'")]
InternalError(String),
#[error("Deserialize error: {0}")]

View File

@@ -726,22 +726,22 @@ mod tests {
.column_opt::<DateTime>("multi_date")
.unwrap()
.unwrap();
let mut dates = Vec::new();
{
assert_eq!(date_fast_field.get_val(0).into_timestamp_nanos(), 1i64);
dates_fast_field.fill_vals(0u32, &mut dates);
let dates: Vec<DateTime> = dates_fast_field.values_for_doc(0u32).collect();
assert_eq!(dates.len(), 2);
assert_eq!(dates[0].into_timestamp_nanos(), 2i64);
assert_eq!(dates[1].into_timestamp_nanos(), 3i64);
}
{
assert_eq!(date_fast_field.get_val(1).into_timestamp_nanos(), 4i64);
dates_fast_field.fill_vals(1u32, &mut dates);
let dates: Vec<DateTime> = dates_fast_field.values_for_doc(1u32).collect();
assert!(dates.is_empty());
}
{
assert_eq!(date_fast_field.get_val(2).into_timestamp_nanos(), 0i64);
dates_fast_field.fill_vals(2u32, &mut dates);
let dates: Vec<DateTime> = dates_fast_field.values_for_doc(2u32).collect();
assert_eq!(dates.len(), 2);
assert_eq!(dates[0].into_timestamp_nanos(), 5i64);
assert_eq!(dates[1].into_timestamp_nanos(), 6i64);

View File

@@ -8,7 +8,7 @@ use columnar::{
};
use common::ByteCount;
use crate::core::json_utils::encode_column_name;
use crate::core::json_utils::{encode_column_name, json_path_sep_to_dot};
use crate::directory::FileSlice;
use crate::schema::{Field, FieldEntry, FieldType, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
@@ -39,19 +39,15 @@ impl FastFieldReaders {
self.resolve_column_name_given_default_field(column_name, default_field_opt)
}
pub(crate) fn space_usage(&self, schema: &Schema) -> io::Result<PerFieldSpaceUsage> {
pub(crate) fn space_usage(&self) -> io::Result<PerFieldSpaceUsage> {
let mut per_field_usages: Vec<FieldUsage> = Default::default();
for (field, field_entry) in schema.fields() {
let column_handles = self.columnar.read_columns(field_entry.name())?;
let num_bytes: ByteCount = column_handles
.iter()
.map(|column_handle| column_handle.num_bytes())
.sum();
let mut field_usage = FieldUsage::empty(field);
field_usage.add_field_idx(0, num_bytes);
for (mut field_name, column_handle) in self.columnar.iter_columns()? {
json_path_sep_to_dot(&mut field_name);
let space_usage = column_handle.space_usage()?;
let mut field_usage = FieldUsage::empty(field_name);
field_usage.set_column_usage(space_usage);
per_field_usages.push(field_usage);
}
// TODO fix space usage for JSON fields.
Ok(PerFieldSpaceUsage::new(per_field_usages))
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use super::{fieldnorm_to_id, id_to_fieldnorm};
use crate::directory::{CompositeFile, FileSlice, OwnedBytes};
use crate::schema::Field;
use crate::schema::{Field, Schema};
use crate::space_usage::PerFieldSpaceUsage;
use crate::DocId;
@@ -37,8 +37,8 @@ impl FieldNormReaders {
}
/// Return a break down of the space usage per field.
pub fn space_usage(&self) -> PerFieldSpaceUsage {
self.data.space_usage()
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
self.data.space_usage(schema)
}
/// Returns a handle to inner file

View File

@@ -13,9 +13,9 @@ use crate::store::Compressor;
use crate::{Inventory, Opstamp, TrackedObject};
#[derive(Clone, Debug, Serialize, Deserialize)]
struct DeleteMeta {
pub struct DeleteMeta {
num_deleted_docs: u32,
opstamp: Opstamp,
pub opstamp: Opstamp,
}
#[derive(Clone, Default)]
@@ -213,7 +213,7 @@ impl SegmentMeta {
struct InnerSegmentMeta {
segment_id: SegmentId,
max_doc: u32,
deletes: Option<DeleteMeta>,
pub deletes: Option<DeleteMeta>,
/// If you want to avoid the SegmentComponent::TempStore file to be covered by
/// garbage collection and deleted, set this to true. This is used during merge.
#[serde(skip)]
@@ -276,13 +276,14 @@ impl Default for IndexSettings {
}
/// The order to sort by
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum Order {
/// Ascending Order
Asc,
/// Descending Order
Desc,
}
impl Order {
/// return if the Order is ascending
pub fn is_asc(&self) -> bool {

View File

@@ -46,7 +46,7 @@ impl Segment {
///
/// This method is only used when updating `max_doc` from 0
/// as we finalize a fresh new segment.
pub(crate) fn with_max_doc(self, max_doc: u32) -> Segment {
pub fn with_max_doc(self, max_doc: u32) -> Segment {
Segment {
index: self.index,
meta: self.meta.with_max_doc(max_doc),

View File

@@ -455,11 +455,11 @@ impl SegmentReader {
pub fn space_usage(&self) -> io::Result<SegmentSpaceUsage> {
Ok(SegmentSpaceUsage::new(
self.num_docs(),
self.termdict_composite.space_usage(),
self.postings_composite.space_usage(),
self.positions_composite.space_usage(),
self.fast_fields_readers.space_usage(self.schema())?,
self.fieldnorm_readers.space_usage(),
self.termdict_composite.space_usage(self.schema()),
self.postings_composite.space_usage(self.schema()),
self.positions_composite.space_usage(self.schema()),
self.fast_fields_readers.space_usage()?,
self.fieldnorm_readers.space_usage(self.schema()),
self.get_store_reader(0)?.space_usage(),
self.alive_bitset_opt
.as_ref()
@@ -608,7 +608,7 @@ mod test {
term_dictionary_size: Some(ByteCount::from(100u64)),
postings_size: Some(ByteCount::from(1_000u64)),
positions_size: Some(ByteCount::from(2_000u64)),
fast_size: Some(ByteCount::from(1_000u64).into()),
fast_size: Some(ByteCount::from(1_000u64)),
};
let field_metadata2 = FieldMetadata {
field_name: "a".to_string(),
@@ -617,7 +617,7 @@ mod test {
term_dictionary_size: Some(ByteCount::from(80u64)),
postings_size: Some(ByteCount::from(1_500u64)),
positions_size: Some(ByteCount::from(2_500u64)),
fast_size: Some(ByteCount::from(3_000u64).into()),
fast_size: Some(ByteCount::from(3_000u64)),
};
let expected = FieldMetadata {
field_name: "a".to_string(),
@@ -626,7 +626,7 @@ mod test {
term_dictionary_size: Some(ByteCount::from(180u64)),
postings_size: Some(ByteCount::from(2_500u64)),
positions_size: Some(ByteCount::from(4_500u64)),
fast_size: Some(ByteCount::from(4_000u64).into()),
fast_size: Some(ByteCount::from(4_000u64)),
};
assert_merge(
&[vec![field_metadata1.clone()], vec![field_metadata2]],

View File

@@ -23,13 +23,18 @@ struct InnerDeleteQueue {
last_block: Weak<Block>,
}
/// The delete queue is a linked list storing delete operations.
///
/// Several consumers can hold a reference to it. Delete operations
/// get dropped/gc'ed when no more consumers are holding a reference
/// to them.
#[derive(Clone)]
pub struct DeleteQueue {
inner: Arc<RwLock<InnerDeleteQueue>>,
}
impl DeleteQueue {
// Creates a new delete queue.
/// Creates a new empty delete queue.
pub fn new() -> DeleteQueue {
DeleteQueue {
inner: Arc::default(),
@@ -58,10 +63,10 @@ impl DeleteQueue {
block
}
// Creates a new cursor that makes it possible to
// consume future delete operations.
//
// Past delete operations are not accessible.
/// Creates a new cursor that makes it possible to
/// consume future delete operations.
///
/// Past delete operations are not accessible.
pub fn cursor(&self) -> DeleteCursor {
let last_block = self.get_last_block();
let operations_len = last_block.operations.len();
@@ -71,7 +76,7 @@ impl DeleteQueue {
}
}
// Appends a new delete operations.
/// Appends a new delete operations.
pub fn push(&self, delete_operation: DeleteOperation) {
self.inner
.write()
@@ -169,6 +174,7 @@ struct Block {
next: NextBlock,
}
/// As we process delete operations, keeps track of our position.
#[derive(Clone)]
pub struct DeleteCursor {
block: Arc<Block>,

View File

@@ -128,7 +128,7 @@ fn compute_deleted_bitset(
/// is `==` target_opstamp.
/// For instance, there was no delete operation between the state of the `segment_entry` and
/// the `target_opstamp`, `segment_entry` is not updated.
pub(crate) fn advance_deletes(
pub fn advance_deletes(
mut segment: Segment,
segment_entry: &mut SegmentEntry,
target_opstamp: Opstamp,
@@ -513,7 +513,7 @@ impl<D: Document> IndexWriter<D> {
/// let searcher = index.reader()?.searcher();
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query_promo = query_parser.parse_query("Prometheus")?;
/// let top_docs_promo = searcher.search(&query_promo, &TopDocs::with_limit(1))?;
/// let top_docs_promo = searcher.search(&query_promo, &TopDocs::with_limit(1).order_by_score())?;
///
/// assert!(top_docs_promo.is_empty());
/// Ok(())
@@ -946,11 +946,11 @@ mod tests {
let searcher = reader.searcher();
let a_docs = searcher
.search(&a_query, &TopDocs::with_limit(1))
.search(&a_query, &TopDocs::with_limit(1).order_by_score())
.expect("search for a failed");
let b_docs = searcher
.search(&b_query, &TopDocs::with_limit(1))
.search(&b_query, &TopDocs::with_limit(1).order_by_score())
.expect("search for b failed");
assert_eq!(a_docs.len(), 1);
@@ -2014,8 +2014,9 @@ mod tests {
let query = QueryParser::for_index(&index, vec![field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(1000)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(1000).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1).collect::<Vec<_>>()
};
@@ -2449,8 +2450,9 @@ mod tests {
Term::from_field_u64(id_field, existing_id),
IndexRecordOption::Basic,
);
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), 1); // Was failing
@@ -2491,8 +2493,9 @@ mod tests {
Term::from_field_i64(id_field, 10i64),
IndexRecordOption::Basic,
);
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), 1); // Fails
@@ -2500,8 +2503,9 @@ mod tests {
Term::from_field_i64(id_field, 30i64),
IndexRecordOption::Basic,
);
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), 1); // Fails

View File

@@ -0,0 +1,214 @@
use std::net::Ipv6Addr;
use columnar::MonotonicallyMappableToU128;
use crate::fastfield::FastValue;
use crate::schema::Field;
/// IndexingTerm is used to represent a term during indexing.
/// It's a serialized representation over field and value.
///
/// It actually wraps a `Vec<u8>`. The first 4 bytes are the field.
///
/// We serialize the field, because we index everything in a single
/// global term dictionary during indexing.
#[derive(Clone)]
pub(crate) struct IndexingTerm<B = Vec<u8>>(B)
where B: AsRef<[u8]>;
/// The number of bytes used as metadata by `Term`.
const TERM_METADATA_LENGTH: usize = 4;
impl IndexingTerm {
/// Create a new Term with a buffer with a given capacity.
pub fn with_capacity(capacity: usize) -> IndexingTerm {
let mut data = Vec::with_capacity(TERM_METADATA_LENGTH + capacity);
data.resize(TERM_METADATA_LENGTH, 0u8);
IndexingTerm(data)
}
/// Panics when the term is not empty... ie: some value is set.
/// Use `clear_with_field_and_type` in that case.
///
/// Sets field and the type.
pub(crate) fn set_field(&mut self, field: Field) {
assert!(self.is_empty());
self.0[0..4].clone_from_slice(field.field_id().to_be_bytes().as_ref());
}
/// Is empty if there are no value bytes.
pub fn is_empty(&self) -> bool {
self.0.len() == TERM_METADATA_LENGTH
}
/// Removes the value_bytes and set the field
pub(crate) fn clear_with_field(&mut self, field: Field) {
self.truncate_value_bytes(0);
self.set_field(field);
}
/// Sets a u64 value in the term.
///
/// U64 are serialized using (8-byte) BigEndian
/// representation.
/// The use of BigEndian has the benefit of preserving
/// the natural order of the values.
pub fn set_u64(&mut self, val: u64) {
self.set_fast_value(val);
}
/// Sets a `i64` value in the term.
pub fn set_i64(&mut self, val: i64) {
self.set_fast_value(val);
}
/// Sets a `f64` value in the term.
pub fn set_f64(&mut self, val: f64) {
self.set_fast_value(val);
}
/// Sets a `bool` value in the term.
pub fn set_bool(&mut self, val: bool) {
self.set_fast_value(val);
}
fn set_fast_value<T: FastValue>(&mut self, val: T) {
self.set_bytes(val.to_u64().to_be_bytes().as_ref());
}
/// Append a type marker + fast value to a term.
/// This is used in JSON type to append a fast value after the path.
///
/// It will not clear existing bytes.
pub fn append_type_and_fast_value<T: FastValue>(&mut self, val: T) {
self.0.push(T::to_type().to_code());
let value = val.to_u64();
self.0.extend(value.to_be_bytes().as_ref());
}
/// Sets a `Ipv6Addr` value in the term.
pub fn set_ip_addr(&mut self, val: Ipv6Addr) {
self.set_bytes(val.to_u128().to_be_bytes().as_ref());
}
/// Sets the value of a `Bytes` field.
pub fn set_bytes(&mut self, bytes: &[u8]) {
self.truncate_value_bytes(0);
self.0.extend(bytes);
}
/// Truncates the value bytes of the term. Value and field type stays the same.
pub fn truncate_value_bytes(&mut self, len: usize) {
self.0.truncate(len + TERM_METADATA_LENGTH);
}
/// The length of the bytes.
pub fn len_bytes(&self) -> usize {
self.0.len() - TERM_METADATA_LENGTH
}
/// Appends value bytes to the Term.
///
/// This function returns the segment that has just been added.
#[inline]
pub fn append_bytes(&mut self, bytes: &[u8]) -> &mut [u8] {
let len_before = self.0.len();
self.0.extend_from_slice(bytes);
&mut self.0[len_before..]
}
}
impl<B> IndexingTerm<B>
where B: AsRef<[u8]>
{
/// Wraps serialized term bytes.
///
/// The input buffer is expected to be the concatenation of the big endian encoded field id
/// followed by the serialized value bytes (type tag + payload).
#[inline]
pub fn wrap(serialized_term: B) -> IndexingTerm<B> {
debug_assert!(serialized_term.as_ref().len() >= TERM_METADATA_LENGTH);
IndexingTerm(serialized_term)
}
/// Returns the field this term belongs to.
#[inline]
pub fn field(&self) -> Field {
let field_id_bytes: [u8; 4] = self.0.as_ref()[..4].try_into().unwrap();
Field::from_field_id(u32::from_be_bytes(field_id_bytes))
}
/// Returns the serialized representation of Term.
/// This includes field_id, value type and value.
///
/// Do NOT rely on this byte representation in the index.
/// This value is likely to change in the future.
#[inline]
pub fn serialized_term(&self) -> &[u8] {
self.0.as_ref()
}
}
#[cfg(test)]
mod tests {
use super::IndexingTerm;
use crate::schema::*;
#[test]
pub fn test_term_str() {
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("text", STRING);
let title_field = schema_builder.add_text_field("title", STRING);
let mut term = IndexingTerm::with_capacity(0);
term.set_field(title_field);
term.set_bytes(b"test");
assert_eq!(term.field(), title_field);
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01test".to_vec())
}
/// Size (in bytes) of the buffer of a fast value (u64, i64, f64, or date) term.
/// <field> + <type byte> + <value len>
///
/// - <field> is a big endian encoded u32 field id
/// - <value> is, if this is not the json term, a binary representation specific to the type.
/// If it is a JSON Term, then it is prepended with the path that leads to this leaf value.
const FAST_VALUE_TERM_LEN: usize = 4 + 8;
#[test]
pub fn test_term_u64() {
let mut schema_builder = Schema::builder();
let count_field = schema_builder.add_u64_field("count", INDEXED);
let mut term = IndexingTerm::with_capacity(0);
term.set_field(count_field);
term.set_u64(983u64);
assert_eq!(term.field(), count_field);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
}
#[test]
pub fn test_term_bool() {
let mut schema_builder = Schema::builder();
let bool_field = schema_builder.add_bool_field("bool", INDEXED);
let term = {
let mut term = IndexingTerm::with_capacity(0);
term.set_field(bool_field);
term.set_bool(true);
term
};
assert_eq!(term.field(), bool_field);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
}
#[test]
pub fn indexing_term_wrap_extracts_field() {
let field = Field::from_field_id(7u32);
let mut term = IndexingTerm::with_capacity(0);
term.set_field(field);
term.append_bytes(b"abc");
let wrapped = IndexingTerm::wrap(term.serialized_term());
assert_eq!(wrapped.field(), field);
assert_eq!(wrapped.serialized_term(), term.serialized_term());
}
}

View File

@@ -104,8 +104,9 @@ mod tests {
let query = QueryParser::for_index(&index, vec![my_text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(3).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};

View File

@@ -1518,7 +1518,8 @@ mod tests {
let searcher = reader.searcher();
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(searcher.segment_reader(0u32), 1.0)?;
.term_scorer_for_test(searcher.segment_reader(0u32), 1.0)?
.unwrap();
assert_eq!(term_scorer.doc(), 0);
assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855);
assert_nearly_equals!(term_scorer.score(), 0.0079681855);
@@ -1533,7 +1534,8 @@ mod tests {
for segment_reader in searcher.segment_readers() {
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(segment_reader, 1.0)?;
.term_scorer_for_test(segment_reader, 1.0)?
.unwrap();
// the difference compared to before is intrinsic to the bm25 formula. no worries
// there.
for doc in segment_reader.doc_ids_alive() {
@@ -1558,7 +1560,8 @@ mod tests {
let segment_reader = searcher.segment_reader(0u32);
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(segment_reader, 1.0)?;
.term_scorer_for_test(segment_reader, 1.0)?
.unwrap();
// the difference compared to before is intrinsic to the bm25 formula. no worries there.
for doc in segment_reader.doc_ids_alive() {
assert_eq!(term_scorer.doc(), doc);

View File

@@ -4,7 +4,7 @@
//! `IndexWriter` is the main entry point for that, which created from
//! [`Index::writer`](crate::Index::writer).
pub(crate) mod delete_queue;
pub mod delete_queue;
pub(crate) mod path_to_unordered_id;
pub(crate) mod doc_id_mapping;
@@ -12,6 +12,7 @@ mod doc_opstamp_mapping;
mod flat_map_with_buffer;
pub(crate) mod index_writer;
pub(crate) mod index_writer_status;
pub(crate) mod indexing_term;
mod log_merge_policy;
mod merge_index_test;
mod merge_operation;
@@ -31,12 +32,11 @@ mod stamper;
use crossbeam_channel as channel;
use smallvec::SmallVec;
pub use self::index_writer::{IndexWriter, IndexWriterOptions};
pub use self::index_writer::{advance_deletes, IndexWriter, IndexWriterOptions};
pub use self::log_merge_policy::LogMergePolicy;
pub use self::merge_operation::MergeOperation;
pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy};
use self::operation::AddOperation;
pub use self::operation::UserOperation;
pub use self::operation::{AddOperation, DeleteOperation, UserOperation};
pub use self::prepared_commit::PreparedCommit;
pub use self::segment_entry::SegmentEntry;
pub(crate) use self::segment_serializer::SegmentSerializer;
@@ -181,6 +181,7 @@ mod tests_mmap {
let field_name_out = ".";
test_json_field_name(field_name_in, field_name_out);
}
#[test]
fn test_json_field_dot() {
// Test when field name contains a '.'
@@ -587,7 +588,9 @@ mod tests_mmap {
};
let query_str = &format!("{}:{}", indexed_field.field_name, val);
let query = query_parser.parse_query(query_str).unwrap();
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2)).unwrap();
let count_docs = searcher
.search(&*query, &TopDocs::with_limit(2).order_by_score())
.unwrap();
if indexed_field.field_name.contains("empty") || indexed_field.typ == Type::Json {
assert_eq!(count_docs.len(), 0);
} else {
@@ -659,7 +662,9 @@ mod tests_mmap {
for (indexed_field, val) in fields_and_vals.iter() {
let query_str = &format!("{indexed_field}:{val}");
let query = query_parser.parse_query(query_str).unwrap();
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2)).unwrap();
let count_docs = searcher
.search(&*query, &TopDocs::with_limit(2).order_by_score())
.unwrap();
assert!(!count_docs.is_empty(), "{indexed_field}:{val}");
}
// Test if field name can be used for aggregation

View File

@@ -5,14 +5,20 @@ use crate::Opstamp;
/// Timestamped Delete operation.
pub struct DeleteOperation {
/// Operation stamp.
/// It is used to check whether the delete operation
/// applies to an added document operation.
pub opstamp: Opstamp,
/// Weight is used to define the set of documents to be deleted.
pub target: Box<dyn Weight>,
}
/// Timestamped Add operation.
#[derive(Eq, PartialEq, Debug)]
pub struct AddOperation<D: Document = TantivyDocument> {
/// Operation stamp.
pub opstamp: Opstamp,
/// Document to be added.
pub document: D,
}

View File

@@ -1052,8 +1052,9 @@ mod tests {
let query = QueryParser::for_index(&index, vec![text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(3).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};

View File

@@ -7,6 +7,7 @@ use super::operation::AddOperation;
use crate::fastfield::FastFieldsWriter;
use crate::fieldnorm::{FieldNormReaders, FieldNormsWriter};
use crate::index::{Segment, SegmentComponent};
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::segment_serializer::SegmentSerializer;
use crate::json_utils::{index_json_value, IndexingPositionsPerPath};
use crate::postings::{
@@ -14,7 +15,7 @@ use crate::postings::{
PerFieldPostingsWriter, PostingsWriter,
};
use crate::schema::document::{Document, Value};
use crate::schema::{FieldEntry, FieldType, Schema, Term, DATE_TIME_PRECISION_INDEXED};
use crate::schema::{FieldEntry, FieldType, Schema, DATE_TIME_PRECISION_INDEXED};
use crate::tokenizer::{FacetTokenizer, PreTokenizedStream, TextAnalyzer, Tokenizer};
use crate::{DocId, Opstamp, TantivyError};
@@ -55,7 +56,7 @@ pub struct SegmentWriter {
pub(crate) json_positions_per_path: IndexingPositionsPerPath,
pub(crate) doc_opstamps: Vec<Opstamp>,
per_field_text_analyzers: Vec<TextAnalyzer>,
term_buffer: Term,
term_buffer: IndexingTerm,
schema: Schema,
}
@@ -112,7 +113,7 @@ impl SegmentWriter {
)?,
doc_opstamps: Vec::with_capacity(1_000),
per_field_text_analyzers,
term_buffer: Term::with_capacity(16),
term_buffer: IndexingTerm::with_capacity(16),
schema,
})
}
@@ -170,7 +171,7 @@ impl SegmentWriter {
let (term_buffer, ctx) = (&mut self.term_buffer, &mut self.ctx);
let postings_writer: &mut dyn PostingsWriter =
self.per_field_postings_writers.get_for_field_mut(field);
term_buffer.clear_with_field_and_type(field_entry.field_type().value_type(), field);
term_buffer.clear_with_field(field);
match field_entry.field_type() {
FieldType::Facet(_) => {
@@ -519,7 +520,7 @@ mod tests {
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(4))
.search(&text_query, &TopDocs::with_limit(4).order_by_score())
.unwrap();
assert_eq!(score_docs.len(), 1);
@@ -528,7 +529,7 @@ mod tests {
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(4))
.search(&text_query, &TopDocs::with_limit(4).order_by_score())
.unwrap();
assert_eq!(score_docs.len(), 2);
}
@@ -561,7 +562,7 @@ mod tests {
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(4))
.search(&text_query, &TopDocs::with_limit(4).order_by_score())
.unwrap();
assert_eq!(score_docs.len(), 1);
};

View File

@@ -42,7 +42,6 @@ mod test {
use super::Stamper;
#[expect(clippy::redundant_clone)]
#[test]
fn test_stamper() {
let stamper = Stamper::new(7u64);
@@ -58,7 +57,6 @@ mod test {
assert_eq!(stamper.stamp(), 15u64);
}
#[expect(clippy::redundant_clone)]
#[test]
fn test_stamper_revert() {
let stamper = Stamper::new(7u64);

View File

@@ -85,7 +85,7 @@
//! // Perform search.
//! // `topdocs` contains the 10 most relevant doc ids, sorted by decreasing scores...
//! let top_docs: Vec<(Score, DocAddress)> =
//! searcher.search(&query, &TopDocs::with_limit(10))?;
//! searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
//!
//! for (_score, doc_address) in top_docs {
//! // Retrieve the actual content of documents given its `doc_address`.
@@ -125,7 +125,7 @@
//!
//! - **Searching**: [Searcher] searches the segments with anything that implements
//! [Query](query::Query) and merges the results. The list of [supported
//! queries](query::Query#implementors). Custom Queries are supported by implementing the
//! queries](query::Query#implementers). Custom Queries are supported by implementing the
//! [Query](query::Query) trait.
//!
//! - **[Directory](directory)**: Abstraction over the storage where the index data is stored.
@@ -216,9 +216,7 @@ use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
pub use self::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED};
#[doc(hidden)]
pub use crate::core::json_utils;
pub use crate::core::{Executor, Searcher, SearcherGeneration};
pub use crate::core::{json_utils, Executor, Searcher, SearcherGeneration};
pub use crate::directory::Directory;
pub use crate::index::{
Index, IndexBuilder, IndexMeta, IndexSettings, InvertedIndexReader, Order, Segment,

View File

@@ -1,8 +1,10 @@
use bitpacking::{BitPacker, BitPacker4x};
use common::FixedSize;
pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN;
const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * u32::SIZE_IN_BYTES;
// in vint encoding, each byte stores 7 bits of data, so we need at most 32 / 7 = 4.57 bytes to
// store a u32 in the worst case, rounding up to 5 bytes total
const MAX_VINT_SIZE: usize = 5;
const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * MAX_VINT_SIZE;
mod vint;
@@ -267,7 +269,6 @@ impl VIntDecoder for BlockDecoder {
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::TERMINATED;
@@ -372,6 +373,13 @@ pub(crate) mod tests {
}
}
}
#[test]
fn test_compress_vint_unsorted_does_not_overflow() {
let mut encoder = BlockEncoder::new();
let input: Vec<u32> = vec![u32::MAX; COMPRESSION_BLOCK_SIZE];
encoder.compress_vint_unsorted(&input);
}
}
#[cfg(all(test, feature = "unstable"))]

View File

@@ -3,13 +3,14 @@ use std::io;
use common::json_path_writer::JSON_END_OF_PATH;
use stacker::Addr;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::postings_writer::SpecializedPostingsWriter;
use crate::postings::recorder::{BufferLender, DocIdRecorder, Recorder};
use crate::postings::{FieldSerializer, IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::{Field, Type};
use crate::tokenizer::TokenStream;
use crate::{DocId, Term};
use crate::DocId;
/// The `JsonPostingsWriter` is odd in that it relies on a hidden contract:
///
@@ -33,7 +34,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
&mut self,
doc: crate::DocId,
pos: u32,
term: &crate::Term,
term: &IndexingTerm,
ctx: &mut IndexingContext,
) {
self.non_str_posting_writer.subscribe(doc, pos, term, ctx);
@@ -43,7 +44,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
&mut self,
doc_id: DocId,
token_stream: &mut dyn TokenStream,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
ctx: &mut IndexingContext,
indexing_position: &mut IndexingPosition,
) {
@@ -64,40 +65,37 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
ctx: &IndexingContext,
serializer: &mut FieldSerializer,
) -> io::Result<()> {
let mut term_buffer = Term::with_capacity(48);
let mut term_buffer = JsonTermSerializer(Vec::with_capacity(48));
let mut buffer_lender = BufferLender::default();
term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0));
let mut prev_term_id = u32::MAX;
let mut term_path_len = 0; // this will be set in the first iteration
for (_field, path_id, term, addr) in ordered_term_addrs {
if prev_term_id != path_id.path_id() {
term_buffer.truncate_value_bytes(0);
term_buffer.append_path(ordered_id_to_path[path_id.path_id() as usize].as_bytes());
term_buffer.append_bytes(&[JSON_END_OF_PATH]);
term_path_len = term_buffer.len_bytes();
term_buffer.clear();
term_buffer.append_json_path(ordered_id_to_path[path_id.path_id() as usize]);
term_path_len = term_buffer.len();
prev_term_id = path_id.path_id();
}
term_buffer.truncate_value_bytes(term_path_len);
term_buffer.truncate(term_path_len);
term_buffer.append_bytes(term);
if let Some(json_value) = term_buffer.value().as_json_value_bytes() {
let typ = json_value.typ();
if typ == Type::Str {
SpecializedPostingsWriter::<Rec>::serialize_one_term(
term_buffer.serialized_value_bytes(),
*addr,
&mut buffer_lender,
ctx,
serializer,
)?;
} else {
SpecializedPostingsWriter::<DocIdRecorder>::serialize_one_term(
term_buffer.serialized_value_bytes(),
*addr,
&mut buffer_lender,
ctx,
serializer,
)?;
}
let typ = Type::from_code(term[0]).expect("Invalid type code in JSON term");
if typ == Type::Str {
SpecializedPostingsWriter::<Rec>::serialize_one_term(
term_buffer.as_bytes(),
*addr,
&mut buffer_lender,
ctx,
serializer,
)?;
} else {
SpecializedPostingsWriter::<DocIdRecorder>::serialize_one_term(
term_buffer.as_bytes(),
*addr,
&mut buffer_lender,
ctx,
serializer,
)?;
}
}
Ok(())
@@ -107,3 +105,50 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
self.str_posting_writer.total_num_tokens() + self.non_str_posting_writer.total_num_tokens()
}
}
/// Helper to build the JSON term bytes that land in the term dictionary.
/// Format: `[json path utf8][JSON_END_OF_PATH][type tag][payload]`
struct JsonTermSerializer(Vec<u8>);
impl JsonTermSerializer {
/// Appends a JSON path to the Term.
/// The path is terminated by a special end-of-path 0 byte.
#[inline]
pub fn append_json_path(&mut self, path: &str) {
let bytes = path.as_bytes();
// Replace any occurrence of the end-of-path byte with Ascii '0' byte.
if bytes.contains(&JSON_END_OF_PATH) {
self.0.extend(
bytes
.iter()
.map(|&b| if b == JSON_END_OF_PATH { b'0' } else { b }),
);
} else {
self.0.extend_from_slice(bytes);
}
self.0.push(JSON_END_OF_PATH);
}
/// Appends value bytes to the Term.
///
/// This function returns the segment that has just been added.
#[inline]
pub fn append_bytes(&mut self, bytes: &[u8]) -> &mut [u8] {
let len_before = self.0.len();
self.0.extend_from_slice(bytes);
&mut self.0[len_before..]
}
fn clear(&mut self) {
self.0.clear();
}
fn truncate(&mut self, len: usize) {
self.0.truncate(len);
}
fn len(&self) -> usize {
self.0.len()
}
fn as_bytes(&self) -> &[u8] {
&self.0
}
}

View File

@@ -5,12 +5,13 @@ use std::ops::Range;
use stacker::Addr;
use crate::fieldnorm::FieldNormReaders;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::recorder::{BufferLender, Recorder};
use crate::postings::{
FieldSerializer, IndexingContext, InvertedIndexSerializer, PerFieldPostingsWriter,
};
use crate::schema::{Field, Schema, Term, Type};
use crate::schema::{Field, Schema, Type};
use crate::tokenizer::{Token, TokenStream, MAX_TOKEN_LEN};
use crate::DocId;
@@ -58,14 +59,14 @@ pub(crate) fn serialize_postings(
let mut term_offsets: Vec<(Field, OrderedPathId, &[u8], Addr)> =
Vec::with_capacity(ctx.term_index.len());
term_offsets.extend(ctx.term_index.iter().map(|(key, addr)| {
let field = Term::wrap(key).field();
let field = IndexingTerm::wrap(key).field();
if schema.get_field_entry(field).field_type().value_type() == Type::Json {
let byte_range_path = 5..5 + 4;
let byte_range_path = 4..4 + 4;
let unordered_id = u32::from_be_bytes(key[byte_range_path.clone()].try_into().unwrap());
let path_id = unordered_id_to_ordered_id[unordered_id as usize];
(field, path_id, &key[byte_range_path.end..], addr)
} else {
(field, 0.into(), &key[5..], addr)
(field, 0.into(), &key[4..], addr)
}
}));
// Sort by field, path, and term
@@ -111,7 +112,7 @@ pub(crate) trait PostingsWriter: Send + Sync {
/// * term - the term
/// * ctx - Contains a term hashmap and a memory arena to store all necessary posting list
/// information.
fn subscribe(&mut self, doc: DocId, pos: u32, term: &Term, ctx: &mut IndexingContext);
fn subscribe(&mut self, doc: DocId, pos: u32, term: &IndexingTerm, ctx: &mut IndexingContext);
/// Serializes the postings on disk.
/// The actual serialization format is handled by the `PostingsSerializer`.
@@ -128,7 +129,7 @@ pub(crate) trait PostingsWriter: Send + Sync {
&mut self,
doc_id: DocId,
token_stream: &mut dyn TokenStream,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
ctx: &mut IndexingContext,
indexing_position: &mut IndexingPosition,
) {
@@ -198,7 +199,13 @@ impl<Rec: Recorder> SpecializedPostingsWriter<Rec> {
impl<Rec: Recorder> PostingsWriter for SpecializedPostingsWriter<Rec> {
#[inline]
fn subscribe(&mut self, doc: DocId, position: u32, term: &Term, ctx: &mut IndexingContext) {
fn subscribe(
&mut self,
doc: DocId,
position: u32,
term: &IndexingTerm,
ctx: &mut IndexingContext,
) {
debug_assert!(term.serialized_term().len() >= 4);
self.total_num_tokens += 1;
let (term_index, arena) = (&mut ctx.term_index, &mut ctx.arena);

View File

@@ -6,17 +6,21 @@ use crate::{DocId, Score, TERMINATED};
// doc num bits uses the following encoding:
// given 0b a b cdefgh
// |1|2| 3 |
// |1|2|3| 4 |
// - 1: unused
// - 2: is delta-1 encoded. 0 if not, 1, if yes
// - 3: a 6 bit number in 0..=32, the actual bitwidth
// - 3: unused
// - 4: a 5 bit number in 0..32, the actual bitwidth. Bitpacking could in theory say this is 32
// (requiring a 6th bit), but the biggest doc_id we can want to encode is TERMINATED-1, which can
// be represented on 31b without delta encoding.
fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 {
assert!(bitwidth < 32);
bitwidth | ((delta_1 as u8) << 6)
}
fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) {
let delta_1 = ((raw_bitwidth >> 6) & 1) != 0;
let bitwidth = raw_bitwidth & 0x3f;
let bitwidth = raw_bitwidth & 0x1f;
(bitwidth, delta_1)
}
@@ -430,7 +434,7 @@ mod tests {
#[test]
fn test_encode_decode_bitwidth() {
for bitwidth in 0..=32 {
for bitwidth in 0..32 {
for delta_1 in [false, true] {
assert_eq!(
(bitwidth, delta_1),

View File

@@ -23,7 +23,11 @@ pub struct AllWeight;
impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer::new(reader.max_doc());
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
if boost != 1.0 {
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
} else {
Ok(Box::new(all_scorer))
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {

View File

@@ -1,3 +1,5 @@
use std::sync::Arc;
use crate::fieldnorm::FieldNormReader;
use crate::query::Explanation;
use crate::schema::Field;
@@ -57,13 +59,13 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
}
fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] {
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
let mut cache: [Score; 256] = [0.0; 256];
for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
*cache_mut = cached_tf_component(fieldnorm, average_fieldnorm);
}
cache
Arc::new(cache)
}
/// A struct used for computing BM25 scores.
@@ -71,17 +73,20 @@ fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] {
pub struct Bm25Weight {
idf_explain: Option<Explanation>,
weight: Score,
cache: [Score; 256],
cache: Arc<[Score; 256]>,
average_fieldnorm: Score,
}
impl Bm25Weight {
/// Increase the weight by a multiplicative factor.
pub fn boost_by(&self, boost: Score) -> Bm25Weight {
if boost == 1.0f32 {
return self.clone();
}
Bm25Weight {
idf_explain: self.idf_explain.clone(),
weight: self.weight * boost,
cache: self.cache,
cache: self.cache.clone(),
average_fieldnorm: self.average_fieldnorm,
}
}

View File

@@ -9,7 +9,7 @@ use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer};
use crate::query::{
intersect_scorers, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
intersect_scorers, AllScorer, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
RequiredOptionalScorer, Scorer, Weight,
};
use crate::{DocId, Score};
@@ -97,6 +97,74 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
}
}
/// Returns the effective MUST scorer, accounting for removed AllScorers.
///
/// When AllScorer instances are removed from must_scorers as an optimization,
/// we must restore the "match all" semantics if the list becomes empty.
fn effective_must_scorer(
must_scorers: Vec<Box<dyn Scorer>>,
removed_all_scorer_count: usize,
max_doc: DocId,
num_docs: u32,
) -> Option<Box<dyn Scorer>> {
if must_scorers.is_empty() {
if removed_all_scorer_count > 0 {
// Had AllScorer(s) only - all docs match
Some(Box::new(AllScorer::new(max_doc)))
} else {
// No MUST constraint at all
None
}
} else {
Some(intersect_scorers(must_scorers, num_docs))
}
}
/// Returns a SHOULD scorer with AllScorer union if any were removed.
///
/// For union semantics (OR): if any SHOULD clause was an AllScorer, the result
/// should include all documents. We restore this by unioning with AllScorer.
///
/// When `scoring_enabled` is false, we can just return AllScorer alone since
/// we don't need score contributions from the should_scorer.
fn effective_should_scorer_for_union<TScoreCombiner: ScoreCombiner>(
should_scorer: SpecializedScorer,
removed_all_scorer_count: usize,
max_doc: DocId,
num_docs: u32,
score_combiner_fn: impl Fn() -> TScoreCombiner,
scoring_enabled: bool,
) -> SpecializedScorer {
if removed_all_scorer_count > 0 {
if scoring_enabled {
// Need to union to get score contributions from both
let all_scorers: Vec<Box<dyn Scorer>> = vec![
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
Box::new(AllScorer::new(max_doc)),
];
SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
all_scorers,
score_combiner_fn,
num_docs,
)))
} else {
// Scoring disabled - AllScorer alone is sufficient
SpecializedScorer::Other(Box::new(AllScorer::new(max_doc)))
}
} else {
should_scorer
}
}
enum ShouldScorersCombinationMethod {
// Should scorers are irrelevant.
Ignored,
// Only contributes to final score.
Optional(SpecializedScorer),
// Regardless of score, the should scorers may impact whether a document is matching or not.
Required(SpecializedScorer),
}
/// Weight associated to the `BoolQuery`.
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
weights: Vec<(Occur, Box<dyn Weight>)>,
@@ -159,27 +227,50 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
) -> crate::Result<SpecializedScorer> {
let num_docs = reader.num_docs();
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
// Indicate how should clauses are combined with other clauses.
enum CombinationMethod {
Ignored,
// Only contributes to final score.
Optional(SpecializedScorer),
Required(SpecializedScorer),
// Indicate how should clauses are combined with must clauses.
let mut must_scorers: Vec<Box<dyn Scorer>> =
per_occur_scorers.remove(&Occur::Must).unwrap_or_default();
let must_special_scorer_counts = remove_and_count_all_and_empty_scorers(&mut must_scorers);
if must_special_scorer_counts.num_empty_scorers > 0 {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
let mut must_scorers = per_occur_scorers.remove(&Occur::Must);
let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should)
{
let mut should_scorers = per_occur_scorers.remove(&Occur::Should).unwrap_or_default();
let should_special_scorer_counts =
remove_and_count_all_and_empty_scorers(&mut should_scorers);
let mut exclude_scorers: Vec<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.unwrap_or_default();
let exclude_special_scorer_counts =
remove_and_count_all_and_empty_scorers(&mut exclude_scorers);
if exclude_special_scorer_counts.num_all_scorers > 0 {
// We exclude all documents at one point.
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
let effective_minimum_number_should_match = self
.minimum_number_should_match
.saturating_sub(should_special_scorer_counts.num_all_scorers);
let should_scorers: ShouldScorersCombinationMethod = {
let num_of_should_scorers = should_scorers.len();
if self.minimum_number_should_match > num_of_should_scorers {
if effective_minimum_number_should_match > num_of_should_scorers {
// We don't have enough scorers to satisfy the minimum number of should matches.
// The request will match no documents.
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match self.minimum_number_should_match {
0 => CombinationMethod::Optional(scorer_union(
match effective_minimum_number_should_match {
0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored,
0 => ShouldScorersCombinationMethod::Optional(scorer_union(
should_scorers,
&score_combiner_fn,
num_docs,
)),
1 => CombinationMethod::Required(scorer_union(
1 => ShouldScorersCombinationMethod::Required(scorer_union(
should_scorers,
&score_combiner_fn,
num_docs,
@@ -187,76 +278,145 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
n if num_of_should_scorers == n => {
// When num_of_should_scorers equals the number of should clauses,
// they are no different from must clauses.
must_scorers = match must_scorers.take() {
Some(mut must_scorers) => {
must_scorers.append(&mut should_scorers);
Some(must_scorers)
}
None => Some(should_scorers),
};
CombinationMethod::Ignored
must_scorers.append(&mut should_scorers);
ShouldScorersCombinationMethod::Ignored
}
_ => CombinationMethod::Required(SpecializedScorer::Other(scorer_disjunction(
should_scorers,
score_combiner_fn(),
self.minimum_number_should_match,
))),
}
} else {
// None of should clauses are provided.
if self.minimum_number_should_match > 0 {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
} else {
CombinationMethod::Ignored
_ => ShouldScorersCombinationMethod::Required(SpecializedScorer::Other(
scorer_disjunction(
should_scorers,
score_combiner_fn(),
effective_minimum_number_should_match,
),
)),
}
};
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs))
.map(|specialized_scorer: SpecializedScorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
});
let positive_scorer = match (should_opt, must_scorers) {
(CombinationMethod::Ignored, Some(must_scorers)) => {
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
let exclude_scorer_opt: Option<Box<dyn Scorer>> = if exclude_scorers.is_empty() {
None
} else {
let exclude_specialized_scorer: SpecializedScorer =
scorer_union(exclude_scorers, DoNothingCombiner::default, num_docs);
Some(into_box_scorer(
exclude_specialized_scorer,
DoNothingCombiner::default,
num_docs,
))
};
let include_scorer = match (should_scorers, must_scorers) {
(ShouldScorersCombinationMethod::Ignored, must_scorers) => {
// No SHOULD clauses (or they were absorbed into MUST).
// Result depends entirely on MUST + any removed AllScorers.
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers;
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
let must_scorer = intersect_scorers(must_scorers, num_docs);
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(
RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
),
))
} else {
SpecializedScorer::Other(must_scorer)
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
// Optional SHOULD: contributes to scoring but not required for matching.
match effective_must_scorer(
must_scorers,
must_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
) {
None => {
// No MUST constraint: promote SHOULD to required.
// Must preserve any removed AllScorers from SHOULD via union.
effective_should_scorer_for_union(
should_scorer,
should_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
&score_combiner_fn,
self.scoring_enabled,
)
}
Some(must_scorer) => {
// Has MUST constraint: SHOULD only affects scoring.
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
_,
_,
TScoreCombiner,
>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
)))
} else {
SpecializedScorer::Other(must_scorer)
}
}
}
}
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
(ShouldScorersCombinationMethod::Required(should_scorer), must_scorers) => {
// Required SHOULD: at least `minimum_number_should_match` must match.
// Semantics: (MUST constraint) AND (SHOULD constraint)
match effective_must_scorer(
must_scorers,
must_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
) {
None => {
// No MUST constraint: SHOULD alone determines matching.
should_scorer
}
Some(must_scorer) => {
// Has MUST constraint: intersect MUST with SHOULD.
let should_boxed =
into_box_scorer(should_scorer, &score_combiner_fn, num_docs);
SpecializedScorer::Other(intersect_scorers(
vec![must_scorer, should_boxed],
num_docs,
))
}
}
}
(CombinationMethod::Ignored, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
}
(CombinationMethod::Required(should_scorer), None) => should_scorer,
// Optional options are promoted to required if no must scorers exists.
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
};
if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed =
into_box_scorer(positive_scorer, &score_combiner_fn, num_docs);
let include_scorer_boxed =
into_box_scorer(include_scorer, &score_combiner_fn, num_docs);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
positive_scorer_boxed,
include_scorer_boxed,
exclude_scorer,
))))
} else {
Ok(positive_scorer)
Ok(include_scorer)
}
}
}
#[derive(Default, Copy, Clone, Debug)]
struct AllAndEmptyScorerCounts {
num_all_scorers: usize,
num_empty_scorers: usize,
}
fn remove_and_count_all_and_empty_scorers(
scorers: &mut Vec<Box<dyn Scorer>>,
) -> AllAndEmptyScorerCounts {
let mut counts = AllAndEmptyScorerCounts::default();
scorers.retain(|scorer| {
if scorer.is::<AllScorer>() {
counts.num_all_scorers += 1;
false
} else if scorer.is::<EmptyScorer>() {
counts.num_empty_scorers += 1;
false
} else {
true
}
});
counts
}
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let num_docs = reader.num_docs();
@@ -293,7 +453,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
let mut explanation = Explanation::new("BooleanClause. sum of ...", scorer.score());
for (occur, subweight) in &self.weights {
if is_positive_occur(*occur) {
if is_include_occur(*occur) {
if let Ok(child_explanation) = subweight.explain(reader, doc) {
explanation.add_detail(child_explanation);
}
@@ -377,7 +537,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
}
}
fn is_positive_occur(occur: Occur) -> bool {
fn is_include_occur(occur: Occur) -> bool {
match occur {
Occur::Must | Occur::Should => true,
Occur::MustNot => false,

View File

@@ -9,13 +9,15 @@ pub use self::boolean_weight::BooleanWeight;
#[cfg(test)]
mod tests {
use std::ops::Bound;
use super::*;
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
use crate::collector::TopDocs;
use crate::collector::{Count, TopDocs};
use crate::query::term_query::TermScorer;
use crate::query::{
EnableScoring, Intersection, Occur, Query, QueryParser, RequiredOptionalScorer, Scorer,
SumCombiner, TermQuery,
AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, RangeQuery,
RequiredOptionalScorer, Scorer, SumCombiner, TermQuery,
};
use crate::schema::*;
use crate::{assert_nearly_equals, DocAddress, DocId, Index, IndexWriter, Score};
@@ -182,7 +184,7 @@ mod tests {
let matching_topdocs = |query: &dyn Query| {
reader
.searcher()
.search(query, &TopDocs::with_limit(3))
.search(query, &TopDocs::with_limit(3).order_by_score())
.unwrap()
};
@@ -311,4 +313,530 @@ mod tests {
assert_nearly_equals!(explanation.value(), std::f32::consts::LN_2);
Ok(())
}
#[test]
pub fn test_boolean_weight_optimization() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
index_writer.add_document(doc!(text_field=>"hello"))?;
index_writer.add_document(doc!(text_field=>"hello happy"))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let term_match_all: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
));
let term_match_some: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "happy"),
IndexRecordOption::Basic,
));
let term_match_none: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "tax"),
IndexRecordOption::Basic,
));
{
let query = BooleanQuery::from(vec![
(Occur::Must, term_match_all.box_clone()),
(Occur::Must, term_match_some.box_clone()),
]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<TermScorer>());
}
{
let query = BooleanQuery::from(vec![
(Occur::Must, term_match_all.box_clone()),
(Occur::Must, term_match_some.box_clone()),
(Occur::Must, term_match_none.box_clone()),
]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<EmptyScorer>());
}
{
let query = BooleanQuery::from(vec![
(Occur::Should, term_match_all.box_clone()),
(Occur::Should, term_match_none.box_clone()),
]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<AllScorer>());
}
{
let query = BooleanQuery::from(vec![
(Occur::Should, term_match_some.box_clone()),
(Occur::Should, term_match_none.box_clone()),
]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<TermScorer>());
}
Ok(())
}
#[test]
pub fn test_min_should_match_with_all_query() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let effective_all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "apple"),
IndexRecordOption::Basic,
));
// in some previous version, we would remove the 2 all_match, but then say we need *4*
// matches out of the 3 term queries, which matches nothing.
let mut bool_query = BooleanQuery::new(vec![
(Occur::Should, effective_all_match_query.box_clone()),
(Occur::Should, effective_all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
(Occur::Should, term_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
bool_query.set_minimum_number_should_match(4);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 1);
Ok(())
}
// =========================================================================
// AllScorer Preservation Regression Tests
// =========================================================================
//
// These tests verify the fix for a bug where AllScorer instances (produced by
// queries matching all documents, such as range queries covering all values)
// were incorrectly removed from Boolean query processing, causing documents
// to be unexpectedly excluded from results.
//
// The bug manifested in several scenarios:
// 1. SHOULD + SHOULD where one clause is AllScorer
// 2. MUST (AllScorer) + SHOULD
// 3. Range queries in Boolean clauses when all documents match the range
/// Regression test: SHOULD clause with AllScorer combined with other SHOULD clauses.
///
/// When a SHOULD clause produces an AllScorer (e.g., from a range query matching
/// all documents), the Boolean query should still match all documents.
///
/// Bug before fix: AllScorer was removed during optimization, leaving only the
/// other SHOULD clauses, which incorrectly excluded documents.
#[test]
pub fn test_should_with_all_scorer_regression() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range query will return AllScorer
index_writer.add_document(doc!(text_field => "hello", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "world", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "hello world", num_field => 30i64))?;
index_writer.add_document(doc!(text_field => "foo", num_field => 40i64))?;
index_writer.add_document(doc!(text_field => "bar", num_field => 50i64))?;
index_writer.add_document(doc!(text_field => "baz", num_field => 60i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Range query matching all docs (returns AllScorer)
let all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
));
// Verify range matches all 6 docs
assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 6);
// RangeQuery(all) OR TermQuery should match all 6 docs
let bool_query = BooleanQuery::new(vec![
(Occur::Should, all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 6, "SHOULD with AllScorer should match all docs");
// Order should not matter
let bool_query_reversed = BooleanQuery::new(vec![
(Occur::Should, term_query.box_clone()),
(Occur::Should, all_match_query.box_clone()),
]);
let count_reversed = searcher.search(&bool_query_reversed, &Count)?;
assert_eq!(
count_reversed, 6,
"Order of SHOULD clauses should not matter"
);
Ok(())
}
/// Regression test: MUST clause with AllScorer combined with SHOULD clause.
///
/// When MUST contains an AllScorer, all documents satisfy the MUST constraint.
/// The SHOULD clause should only affect scoring, not filtering.
///
/// Bug before fix: AllScorer was removed, leaving an empty must_scorers vector.
/// intersect_scorers([]) incorrectly returned EmptyScorer, matching 0 documents.
#[test]
pub fn test_must_all_with_should_regression() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range query will return AllScorer
index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "cherry", num_field => 30i64))?;
index_writer.add_document(doc!(text_field => "date", num_field => 40i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Range query matching all docs (returns AllScorer)
let all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "apple"),
IndexRecordOption::Basic,
));
// Verify range matches all 4 docs
assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 4);
// MUST(range matching all) AND SHOULD(term) should match all 4 docs
let bool_query = BooleanQuery::new(vec![
(Occur::Must, all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 4, "MUST AllScorer + SHOULD should match all docs");
Ok(())
}
/// Regression test: Range queries in Boolean clauses when all documents match.
///
/// Range queries can return AllScorer as an optimization when all indexed values
/// fall within the range. This test ensures such queries work correctly in
/// Boolean combinations.
///
/// This is the most common real-world manifestation of the bug, occurring in
/// queries like: (age > 50 OR name = 'Alice') AND status = 'active'
/// when all documents have age > 50.
#[test]
pub fn test_range_query_all_match_in_boolean() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let name_field = schema_builder.add_text_field("name", TEXT);
let age_field =
schema_builder.add_i64_field("age", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All documents have age > 50, so range query will return AllScorer
index_writer.add_document(doc!(name_field => "alice", age_field => 55_i64))?;
index_writer.add_document(doc!(name_field => "bob", age_field => 60_i64))?;
index_writer.add_document(doc!(name_field => "charlie", age_field => 70_i64))?;
index_writer.add_document(doc!(name_field => "diana", age_field => 80_i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let range_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(age_field, 50)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(name_field, "alice"),
IndexRecordOption::Basic,
));
// Verify preconditions
assert_eq!(searcher.search(range_query.as_ref(), &Count)?, 4);
assert_eq!(searcher.search(term_query.as_ref(), &Count)?, 1);
// SHOULD(range) OR SHOULD(term): range matches all, so result is 4
let should_query = BooleanQuery::new(vec![
(Occur::Should, range_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&should_query, &Count)?,
4,
"SHOULD range OR term should match all"
);
// MUST(range) AND SHOULD(term): range matches all, term is optional
let must_should_query = BooleanQuery::new(vec![
(Occur::Must, range_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&must_should_query, &Count)?,
4,
"MUST range + SHOULD term should match all"
);
Ok(())
}
/// Test multiple AllScorer instances in different clause types.
///
/// Verifies correct behavior when AllScorers appear in multiple positions.
#[test]
pub fn test_multiple_all_scorers() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range queries will return AllScorer
index_writer.add_document(doc!(text_field => "doc1", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "doc2", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "doc3", num_field => 30i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Two different range queries that both match all docs (return AllScorer)
let all_query1: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let all_query2: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 5)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "doc1"),
IndexRecordOption::Basic,
));
// Multiple AllScorers in SHOULD
let multi_all_should = BooleanQuery::new(vec![
(Occur::Should, all_query1.box_clone()),
(Occur::Should, all_query2.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&multi_all_should, &Count)?,
3,
"Multiple AllScorers in SHOULD"
);
// AllScorer in both MUST and SHOULD
let all_must_and_should = BooleanQuery::new(vec![
(Occur::Must, all_query1.box_clone()),
(Occur::Should, all_query2.box_clone()),
]);
assert_eq!(
searcher.search(&all_must_and_should, &Count)?,
3,
"AllScorer in both MUST and SHOULD"
);
Ok(())
}
}
/// A proptest which generates arbitrary permutations of a simple boolean AST, and then matches
/// the result against an index which contains all permutations of documents with N fields.
#[cfg(test)]
mod proptest_boolean_query {
use std::collections::{BTreeMap, HashSet};
use std::ops::Bound;
use proptest::collection::vec;
use proptest::prelude::*;
use crate::collector::DocSetCollector;
use crate::query::{AllQuery, BooleanQuery, Occur, Query, RangeQuery, TermQuery};
use crate::schema::{Field, NumericOptions, OwnedValue, Schema, TEXT};
use crate::{DocId, Index, Term};
#[derive(Debug, Clone)]
enum BooleanQueryAST {
/// Matches all documents via AllQuery (wraps AllScorer in BoostScorer)
All,
/// Matches all documents via RangeQuery (returns bare AllScorer)
/// This is the actual trigger for the AllScorer preservation bug
RangeAll,
/// Matches documents where the field has value "true"
Leaf {
field_idx: usize,
},
Union(Vec<BooleanQueryAST>),
Intersection(Vec<BooleanQueryAST>),
}
impl BooleanQueryAST {
fn matches(&self, doc_id: DocId) -> bool {
match self {
BooleanQueryAST::All => true,
BooleanQueryAST::RangeAll => true,
BooleanQueryAST::Leaf { field_idx } => Self::matches_field(doc_id, *field_idx),
BooleanQueryAST::Union(children) => {
children.iter().any(|child| child.matches(doc_id))
}
BooleanQueryAST::Intersection(children) => {
children.iter().all(|child| child.matches(doc_id))
}
}
}
fn matches_field(doc_id: DocId, field_idx: usize) -> bool {
((doc_id as usize) >> field_idx) & 1 == 1
}
fn to_query(&self, fields: &[Field], range_field: Field) -> Box<dyn Query> {
match self {
BooleanQueryAST::All => Box::new(AllQuery),
BooleanQueryAST::RangeAll => {
// Range query that matches all docs (all have value >= 0)
// This returns bare AllScorer, triggering the bug we fixed
Box::new(RangeQuery::new(
Bound::Included(Term::from_field_i64(range_field, 0)),
Bound::Unbounded,
))
}
BooleanQueryAST::Leaf { field_idx } => Box::new(TermQuery::new(
Term::from_field_text(fields[*field_idx], "true"),
crate::schema::IndexRecordOption::Basic,
)),
BooleanQueryAST::Union(children) => {
let sub_queries = children
.iter()
.map(|child| (Occur::Should, child.to_query(fields, range_field)))
.collect();
Box::new(BooleanQuery::new(sub_queries))
}
BooleanQueryAST::Intersection(children) => {
let sub_queries = children
.iter()
.map(|child| (Occur::Must, child.to_query(fields, range_field)))
.collect();
Box::new(BooleanQuery::new(sub_queries))
}
}
}
}
fn doc_ids(num_docs: usize, num_fields: usize) -> impl Iterator<Item = DocId> {
let permutations = 1 << num_fields;
let copies = (num_docs as f32 / permutations as f32).ceil() as u32;
(0..(permutations * copies)).into_iter()
}
fn create_index_with_boolean_permutations(
num_docs: usize,
num_fields: usize,
) -> (Index, Vec<Field>, Field) {
let mut schema_builder = Schema::builder();
let fields: Vec<Field> = (0..num_fields)
.map(|i| schema_builder.add_text_field(&format!("field_{}", i), TEXT))
.collect();
// Add a numeric field for RangeQuery tests - all docs have value = doc_id
let range_field = schema_builder.add_i64_field(
"range_field",
NumericOptions::default().set_fast().set_indexed(),
);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
for doc_id in doc_ids(num_docs, num_fields) {
let mut doc: BTreeMap<_, OwnedValue> = BTreeMap::default();
for (field_idx, &field) in fields.iter().enumerate() {
if (doc_id >> field_idx) & 1 == 1 {
doc.insert(field, "true".into());
}
}
// All docs have non-negative values, so RangeQuery(>=0) matches all
doc.insert(range_field, (doc_id as i64).into());
writer.add_document(doc).unwrap();
}
writer.commit().unwrap();
(index, fields, range_field)
}
fn arb_boolean_query_ast(num_fields: usize) -> impl Strategy<Value = BooleanQueryAST> {
// Leaf strategies: term queries, AllQuery, and RangeQuery matching all docs
let leaf = prop_oneof![
(0..num_fields).prop_map(|field_idx| BooleanQueryAST::Leaf { field_idx }),
Just(BooleanQueryAST::All),
Just(BooleanQueryAST::RangeAll),
];
leaf.prop_recursive(
8, // 8 levels of recursion
256, // 256 nodes max
10, // 10 items per collection
|inner| {
prop_oneof![
vec(inner.clone(), 1..10).prop_map(BooleanQueryAST::Union),
vec(inner, 1..10).prop_map(BooleanQueryAST::Intersection),
]
},
)
}
#[test]
fn proptest_boolean_query() {
// In the presence of optimizations around buffering, it can take large numbers of
// documents to uncover some issues.
let num_docs = 10000;
let num_fields = 8;
let num_docs = 1 << num_fields;
let (index, fields, range_field) =
create_index_with_boolean_permutations(num_docs, num_fields);
let searcher = index.reader().unwrap().searcher();
proptest!(|(ast in arb_boolean_query_ast(num_fields))| {
let query = ast.to_query(&fields, range_field);
let mut matching_docs = HashSet::new();
for doc_id in doc_ids(num_docs, num_fields) {
if ast.matches(doc_id as DocId) {
matching_docs.insert(doc_id as DocId);
}
}
let doc_addresses = searcher.search(&*query, &DocSetCollector).unwrap();
let result_docs: HashSet<DocId> =
doc_addresses.into_iter().map(|doc_address| doc_address.doc_id).collect();
prop_assert_eq!(result_docs, matching_docs);
});
}
}

View File

@@ -53,7 +53,7 @@ use crate::{Score, Term};
/// // TermQuery "diary" and "girl" should be present and only one should be accounted in score
/// let queries1 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
/// let diary_and_girl = DisjunctionMaxQuery::new(queries1);
/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3))?;
/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3).order_by_score())?;
/// assert_eq!(documents[0].0, documents[1].0);
/// assert_eq!(documents[1].0, documents[2].0);
///
@@ -62,7 +62,7 @@ use crate::{Score, Term};
/// let queries2 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
/// let tie_breaker = 0.7;
/// let diary_and_girl_with_tie_breaker = DisjunctionMaxQuery::with_tie_breaker(queries2, tie_breaker);
/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3))?;
/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3).order_by_score())?;
/// assert_eq!(documents[1].0, documents[2].0);
/// // For this test all terms brings the same score. So we can do easy math and assume that
/// // `DisjunctionMaxQuery` with tie breakers score should be equal

View File

@@ -127,7 +127,11 @@ impl Weight for ExistsWeight {
.any(|col| matches!(col.column_index(), ColumnIndex::Full))
{
let all_scorer = AllScorer::new(max_doc);
return Ok(Box::new(BoostScorer::new(all_scorer, boost)));
if boost != 1.0f32 {
return Ok(Box::new(BoostScorer::new(all_scorer, boost)));
} else {
return Ok(Box::new(all_scorer));
}
}
// If we have a single dynamic column, use ExistsDocSet

View File

@@ -67,7 +67,7 @@ impl Automaton for DfaWrapper {
/// {
/// let term = Term::from_field_text(title, "Diary");
/// let query = FuzzyTermQuery::new(term, 1, true);
/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count)).unwrap();
/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count)).unwrap();
/// assert_eq!(count, 2);
/// assert_eq!(top_docs.len(), 2);
/// }
@@ -241,7 +241,8 @@ mod test {
{
let term = get_json_path_term("attributes.aa:japan")?;
let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document");
}
@@ -252,7 +253,8 @@ mod test {
let term = get_json_path_term("attributes.a:japon")?;
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document");
}
@@ -262,7 +264,8 @@ mod test {
let term = get_json_path_term("attributes.a:jap")?;
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 0, "Expected no document");
}
@@ -292,7 +295,8 @@ mod test {
{
let term = Term::from_field_text(country_field, "japon");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);
@@ -303,7 +307,8 @@ mod test {
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 0, "Expected no document");
}
@@ -311,7 +316,8 @@ mod test {
{
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new_prefix(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);

View File

@@ -267,7 +267,7 @@ mod tests {
.with_boost_factor(1.0)
.with_stop_words(vec!["old".to_string()])
.with_document(DocAddress::new(0, 0));
let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(5).order_by_score())?;
let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect();
doc_ids.sort_unstable();
@@ -283,7 +283,7 @@ mod tests {
.with_max_word_length(5)
.with_boost_factor(1.0)
.with_document(DocAddress::new(0, 4));
let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(5).order_by_score())?;
let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect();
doc_ids.sort_unstable();

View File

@@ -266,8 +266,9 @@ mod tests {
use super::RangeQuery;
use crate::collector::{Count, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::range_query::fast_field_range_doc_set::RangeDocSet;
use crate::query::range_query::range_query::InvertedIndexRangeQuery;
use crate::query::QueryParser;
use crate::query::{AllScorer, ConstScorer, EmptyScorer, EnableScoring, Query, QueryParser};
use crate::schema::{
Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT,
};
@@ -495,7 +496,7 @@ mod tests {
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![title]);
let query = query_parser.parse_query("hemoglobin AND year:[1970 TO 1990]")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
assert_eq!(top_docs.len(), 1);
Ok(())
}
@@ -549,7 +550,7 @@ mod tests {
let get_num_hits = |query| {
let (_top_docs, count) = searcher
.search(&query, &(TopDocs::with_limit(10), Count))
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
.unwrap();
count
};
@@ -660,4 +661,46 @@ mod tests {
0
);
}
#[test]
fn test_range_query_simplified() {
// This test checks that if the targeted column values are entirely
// within the range, and the column is full, we end up with a AllScorer.
let mut schema_builder = Schema::builder();
let u64_field = schema_builder.add_u64_field("u64_field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer_for_tests().unwrap();
index_writer.add_document(doc!(u64_field=> 2u64)).unwrap();
index_writer.add_document(doc!(u64_field=> 4u64)).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let make_term = |value: u64| Term::from_field_u64(u64_field, value);
let make_scorer = move |lower_bound: Bound<u64>, upper_bound: Bound<u64>| {
let lower_bound_term = lower_bound.map(make_term);
let upper_bound_term = upper_bound.map(make_term);
let range_query = RangeQuery::new(lower_bound_term, upper_bound_term);
let range_weight = range_query
.weight(EnableScoring::disabled_from_schema(&schema))
.unwrap();
let range_scorer = range_weight
.scorer(&searcher.segment_readers()[0], 1.0f32)
.unwrap();
range_scorer
};
let range_scorer = make_scorer(Bound::Included(1), Bound::Included(4));
assert!(range_scorer.is::<AllScorer>());
let range_scorer = make_scorer(Bound::Included(0), Bound::Included(2));
assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
let range_scorer = make_scorer(Bound::Included(3), Bound::Included(10));
assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
let range_scorer = make_scorer(Bound::Included(10), Bound::Included(12));
assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
let range_scorer = make_scorer(Bound::Included(0), Bound::Included(1));
assert!(range_scorer.is::<EmptyScorer>());
let range_scorer = make_scorer(Bound::Included(0), Bound::Excluded(2));
assert!(range_scorer.is::<EmptyScorer>());
}
}

View File

@@ -6,8 +6,8 @@ use std::net::Ipv6Addr;
use std::ops::{Bound, RangeInclusive};
use columnar::{
Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64, NumericalType,
StrColumn,
Cardinality, Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
NumericalType, StrColumn,
};
use common::bounds::{BoundsRange, TransformBound};
@@ -397,6 +397,8 @@ fn search_on_u64_ff(
boost: Score,
bounds: BoundsRange<u64>,
) -> crate::Result<Box<dyn Scorer>> {
let col_min_value = column.min_value();
let col_max_value = column.max_value();
#[expect(clippy::reversed_empty_ranges)]
let value_range = bound_to_value_range(
&bounds.lower_bound,
@@ -408,6 +410,22 @@ fn search_on_u64_ff(
if value_range.is_empty() {
return Ok(Box::new(EmptyScorer));
}
if col_min_value >= *value_range.start() && col_max_value <= *value_range.end() {
// all values in the column are within the range.
if column.index.get_cardinality() == Cardinality::Full {
if boost != 1.0f32 {
return Ok(Box::new(ConstScorer::new(
AllScorer::new(column.num_docs()),
boost,
)));
} else {
return Ok(Box::new(AllScorer::new(column.num_docs())));
}
} else {
// TODO Make it a field presence request for that specific column
}
}
let docset = RangeDocSet::new(value_range, column);
Ok(Box::new(ConstScorer::new(docset, boost)))
}
@@ -509,7 +527,9 @@ mod tests {
let test_query = |query, num_hits| {
let query = query_parser.parse_query(query).unwrap();
let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), num_hits);
};
@@ -595,7 +615,9 @@ mod tests {
let query_parser = QueryParser::for_index(&index, vec![date_field]);
let test_query = |query, num_hits| {
let query = query_parser.parse_query(query).unwrap();
let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), num_hits);
};
@@ -975,7 +997,9 @@ mod tests {
let query_parser = QueryParser::for_index(&index, vec![json_field]);
let test_query = |query, num_hits| {
let query = query_parser.parse_query(query).unwrap();
let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), num_hits);
};

View File

@@ -125,14 +125,20 @@ mod test {
let searcher = reader.searcher();
{
let scored_docs = searcher
.search(&query_matching_one, &TopDocs::with_limit(2))
.search(
&query_matching_one,
&TopDocs::with_limit(2).order_by_score(),
)
.unwrap();
assert_eq!(scored_docs.len(), 1, "Expected only 1 document");
let (score, _) = scored_docs[0];
assert_nearly_equals!(1.0, score);
}
let top_docs = searcher
.search(&query_matching_zero, &TopDocs::with_limit(2))
.search(
&query_matching_zero,
&TopDocs::with_limit(2).order_by_score(),
)
.unwrap();
assert!(top_docs.is_empty(), "Expected ZERO document");
}

View File

@@ -153,7 +153,8 @@ mod tests {
let terms = vec![Term::from_field_text(field1, "doc1")];
let term_set_query = TermSetQuery::new(terms);
let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&term_set_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);
@@ -164,7 +165,8 @@ mod tests {
let terms = vec![Term::from_field_text(field1, "doc4")];
let term_set_query = TermSetQuery::new(terms);
let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(1))?;
let top_docs =
searcher.search(&term_set_query, &TopDocs::with_limit(1).order_by_score())?;
assert!(top_docs.is_empty(), "Expected 0 document");
}
@@ -176,7 +178,8 @@ mod tests {
];
let term_set_query = TermSetQuery::new(terms);
let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&term_set_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 2, "Expected 2 documents");
for (score, _) in top_docs {
assert_nearly_equals!(1.0, score);
@@ -192,7 +195,8 @@ mod tests {
];
let term_set_query = TermSetQuery::new(terms);
let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?;
let top_docs =
searcher.search(&term_set_query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(top_docs.len(), 2, "Expected 2 document");
for (score, _) in top_docs {
@@ -205,13 +209,15 @@ mod tests {
let terms = vec![Term::from_field_text(field1, "doc3")];
let term_set_query = TermSetQuery::new(terms);
let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?;
let top_docs =
searcher.search(&term_set_query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected 1 document");
let terms = vec![Term::from_field_text(field2, "doc3")];
let term_set_query = TermSetQuery::new(terms);
let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?;
let top_docs =
searcher.search(&term_set_query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected 1 document");
let terms = vec![
@@ -220,7 +226,8 @@ mod tests {
];
let term_set_query = TermSetQuery::new(terms);
let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?;
let top_docs =
searcher.search(&term_set_query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(top_docs.len(), 2, "Expected 2 document");
}
@@ -249,7 +256,7 @@ mod tests {
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![]);
let query = query_parser.parse_query("field: IN [val1 val2]")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(3))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(3).order_by_score())?;
assert_eq!(top_docs.len(), 2);
Ok(())
}

View File

@@ -10,7 +10,10 @@ mod tests {
use crate::collector::TopDocs;
use crate::docset::DocSet;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::query::{EnableScoring, Query, QueryParser, Scorer, TermQuery};
use crate::query::term_query::TermScorer;
use crate::query::{
AllScorer, EmptyScorer, EnableScoring, Query, QueryParser, Scorer, TermQuery,
};
use crate::schema::{Field, IndexRecordOption, Schema, FAST, STRING, TEXT};
use crate::{assert_nearly_equals, DocAddress, Index, IndexWriter, Term, TERMINATED};
@@ -97,7 +100,7 @@ mod tests {
{
let term = Term::from_field_text(left_field, "left2");
let term_query = TermQuery::new(term, IndexRecordOption::WithFreqs);
let topdocs = searcher.search(&term_query, &TopDocs::with_limit(2))?;
let topdocs = searcher.search(&term_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(topdocs.len(), 1);
let (score, _) = topdocs[0];
assert_nearly_equals!(0.77802235, score);
@@ -105,7 +108,8 @@ mod tests {
{
let term = Term::from_field_text(left_field, "left1");
let term_query = TermQuery::new(term, IndexRecordOption::WithFreqs);
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&term_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 2);
let (score1, _) = top_docs[0];
assert_nearly_equals!(0.27101856, score1);
@@ -115,7 +119,7 @@ mod tests {
{
let query_parser = QueryParser::for_index(&index, Vec::new());
let query = query_parser.parse_query("left:left2 left:left1")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(2))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 2);
let (score1, _) = top_docs[0];
assert_nearly_equals!(0.9153879, score1);
@@ -435,9 +439,87 @@ mod tests {
// Using TopDocs requires scoring; since the field is not indexed,
// TermQuery cannot score and should return a SchemaError.
let res = searcher.search(&tq, &TopDocs::with_limit(1));
let res = searcher.search(&tq, &TopDocs::with_limit(1).order_by_score());
assert!(matches!(res, Err(crate::TantivyError::SchemaError(_))));
Ok(())
}
#[test]
fn test_term_weight_all_query_optimization() {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", crate::schema::TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(text_field=>"hello"))
.unwrap();
index_writer
.add_document(doc!(text_field=>"hello happy"))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let get_scorer_for_term = |term: &str| {
let term_query = TermQuery::new(
Term::from_field_text(text_field, term),
IndexRecordOption::Basic,
);
let term_weight = term_query
.weight(EnableScoring::disabled_from_schema(&schema))
.unwrap();
term_weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap()
};
// Should be an allscorer
let match_all_scorer = get_scorer_for_term("hello");
// Should be a term scorer
let match_some_scorer = get_scorer_for_term("happy");
// Should be an empty scorer
let empty_scorer = get_scorer_for_term("tax");
assert!(match_all_scorer.is::<AllScorer>());
assert!(match_some_scorer.is::<TermScorer>());
assert!(empty_scorer.is::<EmptyScorer>());
}
#[test]
fn test_term_weight_all_query_optimization_disable_when_scoring_enabled() {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", crate::schema::TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(text_field=>"hello"))
.unwrap();
index_writer
.add_document(doc!(text_field=>"hello happy"))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let get_scorer_for_term = |term: &str| {
let term_query = TermQuery::new(
Term::from_field_text(text_field, term),
IndexRecordOption::Basic,
);
let term_weight = term_query
.weight(EnableScoring::enabled_from_searcher(&searcher))
.unwrap();
term_weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap()
};
// Should be an allscorer
let match_all_scorer = get_scorer_for_term("hello");
// Should be a term scorer
let one_scorer = get_scorer_for_term("happy");
// Should be an empty scorer
let empty_scorer = get_scorer_for_term("tax");
assert!(match_all_scorer.is::<TermScorer>());
assert!(one_scorer.is::<TermScorer>());
assert!(empty_scorer.is::<EmptyScorer>());
}
}

View File

@@ -50,7 +50,7 @@ use crate::Term;
/// Term::from_field_text(title, "diary"),
/// IndexRecordOption::Basic,
/// );
/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
/// assert_eq!(count, 2);
/// Ok(())
/// # }
@@ -101,7 +101,7 @@ impl TermQuery {
EnableScoring::Enabled {
statistics_provider,
..
} => Bm25Weight::for_terms(statistics_provider, &[self.term.clone()])?,
} => Bm25Weight::for_terms(statistics_provider, std::slice::from_ref(&self.term))?,
EnableScoring::Disabled { .. } => {
Bm25Weight::new(Explanation::new("<no score>", 1.0f32), 1.0f32)
}
@@ -190,7 +190,7 @@ mod tests {
let assert_single_hit = |query| {
let (_top_docs, count) = searcher
.search(&query, &(TopDocs::with_limit(2), Count))
.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))
.unwrap();
assert_eq!(count, 1);
};

View File

@@ -259,7 +259,7 @@ mod tests {
let mut block_max_scores_b = vec![];
let mut docs = vec![];
{
let mut term_scorer = term_weight.specialized_scorer(reader, 1.0)?;
let mut term_scorer = term_weight.term_scorer_for_test(reader, 1.0)?.unwrap();
while term_scorer.doc() != TERMINATED {
let mut score = term_scorer.score();
docs.push(term_scorer.doc());
@@ -273,7 +273,7 @@ mod tests {
}
}
{
let mut term_scorer = term_weight.specialized_scorer(reader, 1.0)?;
let mut term_scorer = term_weight.term_scorer_for_test(reader, 1.0)?.unwrap();
for d in docs {
term_scorer.seek_block(d);
block_max_scores_b.push(term_scorer.block_max_score());

Some files were not shown because too many files have changed in this diff Show More