Compare commits

..

31 Commits

Author SHA1 Message Date
Stu Hood
db9e35e7ee Property test for Comparator/ValueRange consistency, and fixes. 2026-01-04 19:19:08 -08:00
Stu Hood
7f39d5eab9 test_order_by_u64_prop 2026-01-04 15:23:30 -08:00
Stu Hood
af53ffe5df Use a Buffer generic scratch buffer parameter on TopNComputer and push directly from ColumnValues into a TopNComputer buffer in some cases. 2026-01-04 15:23:28 -08:00
Stu Hood
041c6f01a3 Convert test_order_by_compound_filtering_with_none to a proptest. 2026-01-04 15:16:05 -08:00
Stu Hood
9615eb73b8 Implement collect_block for lazy scorers using SegmentSortKeyComputer::segment_sort_keys. 2026-01-04 15:16:00 -08:00
Paul Masurel
77505c3d03 Making stemming optional. (#2791)
Fixed code and CI to run on no default features.

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2026-01-02 12:40:42 +01:00
PSeitz
735c588f4f fix union performance regression (#2790)
* add inlines

* fix union performance regression

Remove unwrap from hotpath generates better assembly.

closes #2788
2026-01-02 12:06:51 +01:00
PSeitz
242a1531bf fix flaky test (#2784)
Signed-off-by: Pascal Seitz <pascal.seitz@gmail.com>
2026-01-02 11:30:51 +01:00
trinity-1686a
6443b63177 document 1bit hole and some queries supporting running with just fastfield (#2779)
* add small doc on some queries using fast field when not indexed

* document 1 unused bit in skiplist
2026-01-02 10:32:37 +01:00
Stu Hood
4987495ee4 Add an erased SortKeyComputer to sort on types which are not known until runtime (#2770)
* Remove PartialOrd bound on compared values.

* Fix declared `SortKey` type of `impl<..> SortKeyComputer for (HeadSortKeyComputer, TailSortKeyComputer)`

* Add a SortByOwnedValue implementation to provide a type-erased column.

* Add support for comparing mismatched `OwnedValue` types.

* Support JSON columns.

* Refer to https://github.com/quickwit-oss/tantivy/issues/2776

* Rename to `SortByErasedType`.

* Comment on transitivity.

Co-authored-by: Paul Masurel <paul@quickwit.io>

* Fix clippy warnings in new code.

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
2026-01-02 10:28:47 +01:00
Paul Masurel
b11605f045 Addressing clippy comments (#2789)
Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-31 18:02:00 +01:00
ChangRui-Ryan
75d7989cc6 add benchmark for boolean query with range sub query (#2787) 2025-12-31 12:00:53 +01:00
PSeitz
923f0508f2 seek_exact + cost based intersection (#2538)
* seek_exact + cost based intersection

Adds `seek_exact` and `cost` to `DocSet` for a more efficient intersection.
Unlike `seek`, `seek_exact` does not require the DocSet to advance to the next hit, if the target does not exist.

`cost` allows to address the different DocSet types and their cost
model and is used to determine the DocSet that drives the intersection.
E.g. fast field range queries may do a full scan. Phrase queries load the positions to check if a we have a hit.
They both have a higher cost than their size_hint would suggest.

Improves `size_hint` estimation for intersection and union, by having a
estimation based on random distribution with a co-location factor.

Refactor range query benchmark.

Closes #2531

*Future Work*

Implement `seek_exact` for BufferedUnionScorer and RangeDocSet (fast field range queries)
Evaluate replacing `seek` with `seek_exact` to reduce code complexity

* Apply suggestions from code review

Co-authored-by: Paul Masurel <paul@quickwit.io>

* add API contract verfication

* impl seek_exact on union

* rename seek_exact

* add mixed AND OR test, fix buffered_union

* Add a proptest of BooleanQuery. (#2690)

* fix build

* Increase the document count.

* fix merge conflict

* fix debug assert

* Fix compilation errors after rebase

- Remove duplicate proptest_boolean_query module
- Remove duplicate cost() method implementations
- Fix TopDocs API usage (add .order_by_score())
- Remove duplicate imports
- Remove unused variable assignments

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
Co-authored-by: Stu Hood <stuhood@gmail.com>
2025-12-30 14:43:25 +01:00
ChangRui-Ryan
e0b62e00ac optimize RangeDocSet for non-overlapping query ranges (#2783) 2025-12-29 16:55:28 +01:00
Stu Hood
ce97beb86f Add support for natural-order-with-none-highest in TopDocs::order_by (#2780)
* Add `ComparatorEnum::NaturalNoneHigher`.

* Fix comments.
2025-12-23 09:22:20 +01:00
Stu Hood
c0f21a45ae Use a strict comparison in TopNComputer (#2777)
* Remove `(Partial)Ord` from `ComparableDoc`, and unify comparison between `TopNComputer` and `Comparator`.

* Doc cleanups.

* Require Ord for `ComparableDoc`.

* Semantics are actually _ascending_ DocId order.

* Adjust docs again for ascending DocId order.

* minor change

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-18 12:13:23 +01:00
Moe
73657dff77 fix: fixed integer overflow in ExpUnrolledLinkedList for large datasets (#2735)
* Fixed the overflow issue.

* Fixed lint issues.

* Applied PR fixes.

* Fixed a lint issue.
2025-12-16 22:57:12 +01:00
Moe
e3c9be1f92 fix: boolean query incorrectly dropping documents when AllScorer is present (#2760)
* Fixed the range issue.

* Fixed the second all scorer issue

* Improved docs + tests

* Improved code.

* Fixed lint issues.

* Improved tests + logic based on PR comments.

* Fixed lint issues.

* Increase the document count.

* Improved the prop-tests

* Expand the index size, and remove unused parameter.

---------

Co-authored-by: Stu Hood <stuhood@gmail.com>
2025-12-16 22:52:02 +01:00
Ming
ba61ed6ef3 fix: vint buffer can overflow (#2778)
* fix vint overflow

* comment
2025-12-16 22:50:41 +01:00
trinity-1686a
d0e1600135 fix bug with minimum_should_match and AllScorer (#2774) 2025-12-14 10:10:45 +01:00
PSeitz-dd
e9020d17d4 fix coverage (#2769) 2025-12-11 11:35:58 +01:00
PSeitz-dd
5ba0031f7d move rand_distr to dev_dep (#2772) 2025-12-11 18:23:50 +08:00
Philippe Noël
22dde8f9ae chore: Make some delete-related functions public (#46) (#2766)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-11 01:22:15 +01:00
Philippe Noël
14cc24614e Make DeleteMeta pub (#2765)
Co-authored-by: Ming Ying <ming.ying.nyc@gmail.com>
2025-12-11 00:11:03 +01:00
Philippe Noël
8a1079b2dc expose AddOperation and with_max_doc (#7) (#2762)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-11 00:10:42 +01:00
Philippe Noël
794ff1ffc9 chore: Make Language hashable (#79) (#2763)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-10 15:38:43 +01:00
PSeitz-dd
c6912ce89a Handle JSON fields and columnar in space_usage (#2761)
return field names in space_usage instead of `Field`
more detailed info for columns
2025-12-10 20:33:33 +08:00
PSeitz
618e3bd11b Term and IndexingTerm cleanup (#2750)
* refactor term

* add deprecated functions

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-12-05 09:48:40 +08:00
PSeitz
b2f99c6217 add term->histogram benchmark (#2758)
* add term->histogram benchmark

* add more term aggs

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-12-04 02:29:37 +01:00
PSeitz
76de5bab6f fix unsafe warnings (#2757) 2025-12-03 20:15:21 +08:00
rustmailer
b7eb31162b docs: add usage example to README (#2743) 2025-12-02 21:56:57 +01:00
122 changed files with 6003 additions and 4654 deletions

View File

@@ -15,11 +15,11 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
continue-on-error: true

View File

@@ -39,11 +39,11 @@ jobs:
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
- name: Check Stable Compilation
run: cargo build --all-features
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
@@ -59,10 +59,10 @@ jobs:
strategy:
matrix:
features: [
{ label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" },
{ label: "quickwit", flags: "mmap,quickwit,failpoints" }
]
features:
- { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" }
- { label: "quickwit", flags: "mmap,quickwit,failpoints" }
- { label: "none", flags: "" }
name: test-${{ matrix.features.label}}
@@ -80,7 +80,21 @@ jobs:
- uses: Swatinem/rust-cache@v2
- name: Run tests
run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace
run: |
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
# (as most of them rely on mmap) otherwise run all
if [ -z "${{ matrix.features.flags }}" ]; then
cargo +stable nextest run --lib --no-default-features --verbose --workspace
else
cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace
fi
- name: Run doctests
run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
run: |
# if matrix.feature.flags is empty then run on --lib to avoid compiling examples
# (as most of them rely on mmap) otherwise run all
if [ -z "${{ matrix.features.flags }}" ]; then
echo "no doctest for no feature flag"
else
cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
fi

View File

@@ -37,7 +37,7 @@ fs4 = { version = "0.13.1", optional = true }
levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4"
rust-stemmers = "1.2.0"
rust-stemmers = { version = "1.2.0", optional = true }
downcast-rs = "2.0.1"
bitpacking = { version = "0.9.2", default-features = false, features = [
"bitpacker4x",
@@ -56,7 +56,6 @@ itertools = "0.14.0"
measure_time = "0.9.0"
arc-swap = "1.5.0"
bon = "3.3.1"
i_triangle = "0.38.0"
columnar = { version = "0.6", path = "./columnar", package = "tantivy-columnar" }
sstable = { version = "0.6", path = "./sstable", package = "tantivy-sstable", optional = true }
@@ -71,18 +70,17 @@ futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
typetag = "0.2.21"
geo-types = "0.7.17"
[target.'cfg(windows)'.dependencies]
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.14.0"
binggan = "0.14.2"
rand = "0.8.5"
maplit = "1.0.2"
matches = "0.1.9"
pretty_assertions = "1.2.1"
proptest = "1.0.0"
proptest = "1.7.0"
test-log = "0.2.10"
futures = "0.3.21"
paste = "1.0.11"
@@ -115,7 +113,8 @@ debug-assertions = true
overflow-checks = true
[features]
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"]
default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"]
stemmer = ["rust-stemmers"]
mmap = ["fs4", "tempfile", "memmap2"]
stopwords = []
@@ -175,6 +174,18 @@ harness = false
name = "exists_json"
harness = false
[[bench]]
name = "range_query"
harness = false
[[bench]]
name = "and_or_queries"
harness = false
[[bench]]
name = "range_queries"
harness = false
[[bench]]
name = "bool_queries_with_range"
harness = false

View File

@@ -123,6 +123,7 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
- [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
- [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
- [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
- and [more](https://github.com/search?q=tantivy)!
### On average, how much faster is Tantivy compared to Lucene?

View File

@@ -1,5 +1,6 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use rand::distributions::WeightedIndex;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
@@ -54,12 +55,18 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_few);
register!(group, terms_all_unique);
register!(group, terms_many);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_few_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status);
register!(group, terms_few_with_histogram);
register!(group, terms_status_with_histogram);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
@@ -132,12 +139,12 @@ fn extendedstats_f64(index: &Index) {
}
fn percentiles_f64(index: &Index) {
let agg_req = json!({
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
}
}
});
execute_agg(index, agg_req);
}
@@ -174,6 +181,19 @@ fn terms_few(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
});
execute_agg(index, agg_req);
}
fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_many(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
@@ -222,6 +242,39 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_all_unique_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_few_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_few_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
@@ -234,6 +287,17 @@ fn terms_few_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
@@ -419,14 +483,21 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
// Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare.
let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap();
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
@@ -442,15 +513,21 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)];
let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -475,8 +552,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)],
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,

View File

@@ -0,0 +1,288 @@
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Collector, Count, DocSetCollector, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
query_parser: QueryParser,
}
fn build_shared_indices(num_docs: usize, p_title_a: f32, distribution: &str) -> BenchIndex {
// Unified schema
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
let f_num_rand = schema_builder.add_u64_field("num_rand", INDEXED);
let f_num_asc = schema_builder.add_u64_field("num_asc", INDEXED);
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
// Always add title to avoid empty documents
let title_token = if rng.gen_bool(p_title_a as f64) {
"a"
} else {
"b"
};
let num_rand = rng.gen_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_title=>title_token,
f_num_rand=>num_rand,
f_num_asc=>num_asc,
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
// Always add title to avoid empty documents
let title_token = if rng.gen_bool(p_title_a as f64) {
"a"
} else {
"b"
};
let num_rand = rng.gen_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_title=>title_token,
f_num_rand=>num_rand,
f_num_asc=>num_asc,
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
// Build query parser for title field
let qp_title = QueryParser::for_index(&index, vec![f_title]);
BenchIndex {
index,
searcher,
query_parser: qp_title,
}
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
(
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
0,
9,
),
(
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
990,
999,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
0,
9,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
9_999_990,
9_999_999,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define all four field types
let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"];
// Define the three terms we want to test with
let terms = ["a", "b", "z"];
// Generate all combinations of terms and field names
let mut queries = Vec::new();
for &term in &terms {
for &field_name in &field_names {
let query_str = format!(
"{} AND {}:[{} TO {}]",
term, field_name, range_low, range_high
);
queries.push((query_str, field_name.to_string()));
}
}
let query_str = format!(
"{}:[{} TO {}] AND {}:[{} TO {}]",
"num_rand_fast", range_low, range_high, "num_asc_fast", range_low, range_high
);
queries.push((query_str, "num_asc_fast".to_string()));
// Run all benchmark tasks for each query and its corresponding field name
for (query_str, field_name) in queries {
run_benchmark_tasks(&mut group, &bench_index, &query_str, &field_name);
}
group.run();
}
}
/// Run all benchmark tasks for a given query string and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
field_name: &str,
) {
// Test count
add_bench_task(bench_group, bench_index, query_str, Count, "count");
// Test all results
add_bench_task(
bench_group,
bench_index,
query_str,
DocSetCollector,
"all results",
);
// Test top 100 by the field (if it's a FAST field)
if field_name.ends_with("_fast") {
// Ascending order
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task(
bench_group,
bench_index,
query_str,
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Asc),
&collector_name,
);
}
// Descending order
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task(
bench_group,
bench_index,
query_str,
TopDocs::with_limit(100).order_by_fast_field::<u64>(field_name_owned, Order::Desc),
&collector_name,
);
}
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &self.collector).unwrap();
if let Some(count) = (&result as &dyn std::any::Any).downcast_ref::<usize>() {
*count
} else if let Some(top_docs) = (&result as &dyn std::any::Any)
.downcast_ref::<Vec<(Option<u64>, tantivy::DocAddress)>>()
{
top_docs.len()
} else if let Some(top_docs) =
(&result as &dyn std::any::Any).downcast_ref::<Vec<(u64, tantivy::DocAddress)>>()
{
top_docs.len()
} else if let Some(doc_set) = (&result as &dyn std::any::Any)
.downcast_ref::<std::collections::HashSet<tantivy::DocAddress>>()
{
doc_set.len()
} else {
eprintln!(
"Unknown collector result type: {:?}",
std::any::type_name::<C::Fruit>()
);
0
}
}
}

365
benches/range_queries.rs Normal file
View File

@@ -0,0 +1,365 @@
use std::ops::Bound;
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector, TopDocs};
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, FAST, INDEXED};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
}
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
// Schema with fast fields only
let mut schema_builder = Schema::builder();
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
let num_rand = rng.gen_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
let num_rand = rng.gen_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
BenchIndex { index, searcher }
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
// Dense distribution - random values in small range (0-999)
(
"dense_values_search_low_value_range".to_string(),
10_000_000,
"dense",
0,
9,
),
(
"dense_values_search_high_value_range".to_string(),
10_000_000,
"dense",
990,
999,
),
(
"dense_values_search_out_of_range".to_string(),
10_000_000,
"dense",
1000,
1002,
),
(
"sparse_values_search_low_value_range".to_string(),
10_000_000,
"sparse",
0,
9,
),
(
"sparse_values_search_high_value_range".to_string(),
10_000_000,
"sparse",
9_999_990,
9_999_999,
),
(
"sparse_values_search_out_of_range".to_string(),
10_000_000,
"sparse",
10_000_000,
10_000_002,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define fast field types
let field_names = ["num_rand_fast", "num_asc_fast"];
// Generate range queries for fast fields
for &field_name in &field_names {
// Create the range query
let field = bench_index.searcher.schema().get_field(field_name).unwrap();
let lower_term = Term::from_field_u64(field, range_low);
let upper_term = Term::from_field_u64(field, range_high);
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
run_benchmark_tasks(
&mut group,
&bench_index,
query,
field_name,
range_low,
range_high,
);
}
group.run();
}
}
/// Run all benchmark tasks for a given range query and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
field_name: &str,
range_low: u64,
range_high: u64,
) {
// Test count
add_bench_task_count(
bench_group,
bench_index,
query.clone(),
"count",
field_name,
range_low,
range_high,
);
// Test top 100 by the field (ascending order)
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_asc(
bench_group,
bench_index,
query.clone(),
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
// Test top 100 by the field (descending order)
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_desc(
bench_group,
bench_index,
query,
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
}
fn add_bench_task_count(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = CountSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_docset(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = DocSetSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_asc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100AscSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_desc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100DescSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct CountSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl CountSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &Count).unwrap()
}
}
struct DocSetSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl DocSetSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
result.len()
}
}
struct Top100AscSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100AscSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Asc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}
struct Top100DescSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100DescSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Desc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}

260
benches/range_query.rs Normal file
View File

@@ -0,0 +1,260 @@
use std::fmt::Display;
use std::net::Ipv6Addr;
use std::ops::RangeInclusive;
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use columnar::MonotonicallyMappableToU128;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tantivy::collector::{Count, TopDocs};
use tantivy::query::QueryParser;
use tantivy::schema::*;
use tantivy::{doc, Index};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
fn main() {
bench_range_query();
}
fn bench_range_query() {
let index = get_index_0_to_100();
let mut runner = BenchRunner::new();
runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
runner.set_name("range_query on u64");
let field_name_and_descr: Vec<_> = vec![
("id", "Single Valued Range Field"),
("ids", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent()),
("10_percent", get_10_percent()),
("1_percent", get_1_percent()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
runner.set_name("range_query on ip");
let field_name_and_descr: Vec<_> = vec![
("ip", "Single Valued Range Field"),
("ips", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent_ip()),
("10_percent", get_10_percent_ip()),
("1_percent", get_1_percent_ip()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
}
fn test_range<T: Display>(
runner: &mut BenchRunner,
index: &Index,
field_name_and_descr: &[(&str, &str)],
range_num_hits: Vec<(&str, RangeInclusive<T>)>,
) {
for (field, suffix) in field_name_and_descr {
let term_num_hits = vec![
("", ""),
("1_percent", "veryfew"),
("10_percent", "few"),
("90_percent", "most"),
];
let mut group = runner.new_group();
group.set_name(suffix);
// all intersect combinations
for (range_name, range) in &range_num_hits {
for (term_name, term) in &term_num_hits {
let index = &index;
let test_name = if term_name.is_empty() {
format!("id_range_hit_{}", range_name)
} else {
format!(
"id_range_hit_{}_intersect_with_term_{}",
range_name, term_name
)
};
group.register(test_name, move |_| {
let query = if term_name.is_empty() {
"".to_string()
} else {
format!("AND id_name:{}", term)
};
black_box(execute_query(field, range, &query, index));
});
}
}
group.run();
}
}
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id_name = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"most".to_string() // 90%
};
Doc {
id_name,
id: rng.gen_range(0..100),
// Multiply by 1000, so that we create most buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_docs(&docs)
}
#[derive(Clone, Debug)]
pub struct Doc {
pub id_name: String,
pub id: u64,
pub ip: Ipv6Addr,
}
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST);
let ids_u64_field =
schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed());
let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST);
let ids_f64_field = schema_builder.add_f64_field(
"ids_f64",
NumericOptions::default().set_fast().set_indexed(),
);
let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST);
let ids_i64_field = schema_builder.add_i64_field(
"ids_i64",
NumericOptions::default().set_fast().set_indexed(),
);
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
let ip_field = schema_builder.add_ip_addr_field("ip", FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ids_i64_field => doc.id as i64,
ids_i64_field => doc.id as i64,
ids_f64_field => doc.id as f64,
ids_f64_field => doc.id as f64,
ids_u64_field => doc.id,
ids_u64_field => doc.id,
id_u64_field => doc.id,
id_f64_field => doc.id as f64,
id_i64_field => doc.id as i64,
text_field => doc.id_name.to_string(),
text_field2 => doc.id_name.to_string(),
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn get_90_percent() -> RangeInclusive<u64> {
0..=90
}
fn get_10_percent() -> RangeInclusive<u64> {
0..=10
}
fn get_1_percent() -> RangeInclusive<u64> {
10..=10
}
fn get_90_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
struct NumHits {
count: usize,
}
impl OutputValue for NumHits {
fn column_title() -> &'static str {
"NumHits"
}
fn format(&self) -> Option<String> {
Some(self.count.to_string())
}
}
fn execute_query<T: Display>(
field: &str,
id_range: &RangeInclusive<T>,
suffix: &str,
index: &Index,
) -> NumHits {
let gen_query_inclusive = |from: &T, to: &T| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(id_range.start(), id_range.end());
execute_query_(&query, index)
}
fn execute_query_(query: &str, index: &Index) -> NumHits {
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let num_hits = searcher
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
.unwrap()
.1;
NumHits { count: num_hits }
}

View File

@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
@@ -66,17 +66,19 @@ unsafe fn filter_vec_avx2_aux(
]);
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
for _ in 0..num_words {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
unsafe {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
}
output_tail.offset_from(output) as usize
unsafe { output_tail.offset_from(output) as usize }
}
#[inline]
@@ -92,8 +94,7 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
let too_low = op_greater(*range.start(), val);
let too_high = op_greater(val, *range.end());
let inside = op_or(too_low, too_high);
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
as u8
255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
}
union U8x32 {

View File

@@ -16,7 +16,7 @@ stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
common = { version= "0.10", path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
serde = "1.0.152"
serde = { version = "1.0.152", features = ["derive"] }
downcast-rs = "2.0.1"
[dev-dependencies]

View File

@@ -1,6 +1,6 @@
use binggan::{InputGroup, black_box};
use common::*;
use tantivy_columnar::Column;
use tantivy_columnar::{Column, ValueRange};
pub mod common;
@@ -46,16 +46,16 @@ fn bench_group(mut runner: InputGroup<Column>) {
runner.register("access_first_vals", |column| {
let mut sum = 0;
const BLOCK_SIZE: usize = 32;
let mut docs = vec![0; BLOCK_SIZE];
let mut buffer = vec![None; BLOCK_SIZE];
let mut docs = Vec::with_capacity(BLOCK_SIZE);
let mut buffer = Vec::with_capacity(BLOCK_SIZE);
for i in (0..NUM_DOCS).step_by(BLOCK_SIZE) {
// fill docs
#[allow(clippy::needless_range_loop)]
docs.clear();
for idx in 0..BLOCK_SIZE {
docs[idx] = idx as u32 + i;
docs.push(idx as u32 + i);
}
column.first_vals(&docs, &mut buffer);
buffer.clear();
column.first_vals_in_value_range(&mut docs, &mut buffer, ValueRange::All);
for val in buffer.iter() {
let Some(val) = val else { continue };
sum += *val;

View File

@@ -1,6 +1,7 @@
mod dictionary_encoded;
mod serialize;
use std::cell::RefCell;
use std::fmt::{self, Debug};
use std::io::Write;
use std::ops::{Range, RangeInclusive};
@@ -19,6 +20,11 @@ use crate::column_values::monotonic_mapping::StrictlyMonotonicMappingToInternal;
use crate::column_values::{ColumnValues, monotonic_map_column};
use crate::{Cardinality, DocId, EmptyColumnValues, MonotonicallyMappableToU64, RowId};
thread_local! {
static ROWS: RefCell<Vec<RowId>> = const { RefCell::new(Vec::new()) };
static DOCS: RefCell<Vec<DocId>> = const { RefCell::new(Vec::new()) };
}
#[derive(Clone)]
pub struct Column<T = u64> {
pub index: ColumnIndex,
@@ -89,31 +95,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
self.values_for_doc(row_id).next()
}
/// Load the first value for each docid in the provided slice.
#[inline]
pub fn first_vals(&self, docids: &[DocId], output: &mut [Option<T>]) {
match &self.index {
ColumnIndex::Empty { .. } => {}
ColumnIndex::Full => self.values.get_vals_opt(docids, output),
ColumnIndex::Optional(optional_index) => {
for (i, docid) in docids.iter().enumerate() {
output[i] = optional_index
.rank_if_exists(*docid)
.map(|rowid| self.values.get_val(rowid));
}
}
ColumnIndex::Multivalued(multivalued_index) => {
for (i, docid) in docids.iter().enumerate() {
let range = multivalued_index.range(*docid);
let is_empty = range.start == range.end;
if !is_empty {
output[i] = Some(self.values.get_val(range.start));
}
}
}
}
}
/// Translates a block of docids to row_ids.
///
/// returns the row_ids and the matching docids on the same index
@@ -143,7 +124,7 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
#[inline]
pub fn get_docids_for_value_range(
&self,
value_range: RangeInclusive<T>,
value_range: ValueRange<T>,
selected_docid_range: Range<u32>,
doc_ids: &mut Vec<u32>,
) {
@@ -168,6 +149,194 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
}
}
// Separate impl block for methods requiring `Default` for `T`.
impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
/// Load the first value for each docid in the provided slice.
///
/// The `docids` vector is mutated: documents that do not match the `value_range` are removed.
/// The `values` vector is populated with the values of the remaining documents.
#[inline]
pub fn first_vals_in_value_range(
&self,
input_docs: &[DocId],
output: &mut Vec<crate::ComparableDoc<Option<T>, DocId>>,
value_range: ValueRange<T>,
) {
match (&self.index, value_range) {
(ColumnIndex::Empty { .. }, value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
if nulls_match {
for &doc in input_docs {
output.push(crate::ComparableDoc {
doc,
sort_key: None,
});
}
}
}
(ColumnIndex::Full, value_range) => {
self.values
.get_vals_in_value_range(input_docs, input_docs, output, value_range);
}
(ColumnIndex::Optional(optional_index), value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
let fallback_needed = ROWS.with(|rows_cell| {
DOCS.with(|docs_cell| {
let mut rows = rows_cell.borrow_mut();
let mut docs = docs_cell.borrow_mut();
rows.clear();
docs.clear();
let mut has_nulls = false;
for &doc_id in input_docs {
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
rows.push(row_id);
docs.push(doc_id);
} else {
has_nulls = true;
if nulls_match {
break;
}
}
}
if !has_nulls || !nulls_match {
self.values.get_vals_in_value_range(
&rows,
&docs,
output,
value_range.clone(),
);
return false;
}
true
})
});
if fallback_needed {
for &doc_id in input_docs {
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
let val = self.values.get_val(row_id);
let value_matches = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&val),
ValueRange::GreaterThan(t, _) => val > *t,
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
ValueRange::LessThan(t, _) => val < *t,
ValueRange::LessThanOrEqual(t, _) => val <= *t,
};
if value_matches {
output.push(crate::ComparableDoc {
doc: doc_id,
sort_key: Some(val),
});
}
} else if nulls_match {
output.push(crate::ComparableDoc {
doc: doc_id,
sort_key: None,
});
}
}
}
}
(ColumnIndex::Multivalued(multivalued_index), value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
for i in 0..input_docs.len() {
let docid = input_docs[i];
let row_range = multivalued_index.range(docid);
let is_empty = row_range.start == row_range.end;
if !is_empty {
let val = self.values.get_val(row_range.start);
let matches = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&val),
ValueRange::GreaterThan(t, _) => val > *t,
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
ValueRange::LessThan(t, _) => val < *t,
ValueRange::LessThanOrEqual(t, _) => val <= *t,
};
if matches {
output.push(crate::ComparableDoc {
doc: docid,
sort_key: Some(val),
});
}
} else if nulls_match {
output.push(crate::ComparableDoc {
doc: docid,
sort_key: None,
});
}
}
}
}
}
}
/// A range of values.
///
/// This type is intended to be used in batch APIs, where the cost of unpacking the enum
/// is outweighed by the time spent processing a batch.
///
/// Implementers should pattern match on the variants to use optimized loops for each case.
#[derive(Clone, Debug)]
pub enum ValueRange<T> {
/// A range that includes both start and end.
Inclusive(RangeInclusive<T>),
/// A range that matches all values.
All,
/// A range that matches all values greater than the threshold.
/// The boolean flag indicates if null values should be included.
GreaterThan(T, bool),
/// A range that matches all values greater than or equal to the threshold.
/// The boolean flag indicates if null values should be included.
GreaterThanOrEqual(T, bool),
/// A range that matches all values less than the threshold.
/// The boolean flag indicates if null values should be included.
LessThan(T, bool),
/// A range that matches all values less than or equal to the threshold.
/// The boolean flag indicates if null values should be included.
LessThanOrEqual(T, bool),
}
impl<T: PartialOrd> ValueRange<T> {
pub fn intersects(&self, min: T, max: T) -> bool {
match self {
ValueRange::Inclusive(range) => *range.start() <= max && *range.end() >= min,
ValueRange::All => true,
ValueRange::GreaterThan(val, _) => max > *val,
ValueRange::GreaterThanOrEqual(val, _) => max >= *val,
ValueRange::LessThan(val, _) => min < *val,
ValueRange::LessThanOrEqual(val, _) => min <= *val,
}
}
}
impl BinarySerializable for Cardinality {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> std::io::Result<()> {
self.to_code().serialize(writer)

View File

@@ -333,7 +333,7 @@ mod tests {
use std::ops::Range;
use super::MultiValueIndex;
use crate::{ColumnarReader, DynamicColumn};
use crate::{ColumnarReader, DynamicColumn, ValueRange};
fn index_to_pos_helper(
index: &MultiValueIndex,
@@ -413,7 +413,7 @@ mod tests {
assert_eq!(row_id_range, 0..4);
let check = |range, expected| {
let full_range = 0..=u64::MAX;
let full_range = ValueRange::All;
let mut docids = Vec::new();
column.get_docids_for_value_range(full_range, range, &mut docids);
assert_eq!(docids, expected);

View File

@@ -7,13 +7,15 @@
//! - Monotonically map values to u64/u128
use std::fmt::Debug;
use std::ops::{Range, RangeInclusive};
use std::ops::Range;
use std::sync::Arc;
use downcast_rs::DowncastSync;
pub use monotonic_mapping::{MonotonicallyMappableToU64, StrictlyMonotonicFn};
pub use monotonic_mapping_u128::MonotonicallyMappableToU128;
use crate::column::ValueRange;
mod merge;
pub(crate) mod monotonic_mapping;
pub(crate) mod monotonic_mapping_u128;
@@ -109,6 +111,307 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
}
}
/// Load the values for the provided docids.
///
/// The values are filtered by the provided value range.
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
let len = input_indexes.len();
let mut read_head = 0;
match value_range {
ValueRange::All => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
read_head += 4;
}
}
ValueRange::Inclusive(ref range) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if range.contains(&val0) {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if range.contains(&val1) {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if range.contains(&val2) {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if range.contains(&val3) {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::GreaterThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 > *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 > *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 > *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 > *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::GreaterThanOrEqual(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::LessThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 < *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 < *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 < *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 < *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::LessThanOrEqual(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
}
// Process remaining elements (0 to 3)
while read_head < len {
let idx = input_indexes[read_head];
let doc = input_doc_ids[read_head];
let val = self.get_val(idx);
let matches = match value_range {
// 'value_range' is still moved here. This is the outer `value_range`
ValueRange::All => true,
ValueRange::Inclusive(ref r) => r.contains(&val),
ValueRange::GreaterThan(ref t, _) => val > *t,
ValueRange::GreaterThanOrEqual(ref t, _) => val >= *t,
ValueRange::LessThan(ref t, _) => val < *t,
ValueRange::LessThanOrEqual(ref t, _) => val <= *t,
};
if matches {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(val),
});
}
read_head += 1;
}
}
/// Fills an output buffer with the fast field values
/// associated with the `DocId` going from
/// `start` to `start + output.len()`.
@@ -129,15 +432,54 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
/// Note that position == docid for single value fast fields
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<T>,
value_range: ValueRange<T>,
row_id_range: Range<RowId>,
row_id_hits: &mut Vec<RowId>,
) {
let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals());
for idx in row_id_range {
let val = self.get_val(idx);
if value_range.contains(&val) {
row_id_hits.push(idx);
match value_range {
ValueRange::Inclusive(range) => {
for idx in row_id_range {
let val = self.get_val(idx);
if range.contains(&val) {
row_id_hits.push(idx);
}
}
}
ValueRange::GreaterThan(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val > threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val >= threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::LessThan(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val < threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::LessThanOrEqual(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val <= threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::All => {
row_id_hits.extend(row_id_range);
}
}
}
@@ -193,6 +535,17 @@ impl<T: PartialOrd + Default> ColumnValues<T> for EmptyColumnValues {
fn num_vals(&self) -> u32 {
0
}
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
let _ = (input_indexes, input_doc_ids, output, value_range);
panic!("Internal Error: Called get_vals_in_value_range of empty column.")
}
}
impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnValues<T>> {
@@ -206,6 +559,18 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
self.as_ref().get_vals_opt(indexes, output)
}
#[inline(always)]
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
self.as_ref()
.get_vals_in_value_range(input_indexes, input_doc_ids, output, value_range)
}
#[inline(always)]
fn min_value(&self) -> T {
self.as_ref().min_value()
@@ -234,7 +599,7 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
#[inline(always)]
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<T>,
range: ValueRange<T>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {

View File

@@ -1,8 +1,9 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Range, RangeInclusive};
use std::ops::Range;
use crate::ColumnValues;
use crate::column::ValueRange;
use crate::column_values::monotonic_mapping::StrictlyMonotonicFn;
struct MonotonicMappingColumn<C, T, Input> {
@@ -80,16 +81,52 @@ where
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<Output>,
range: ValueRange<Output>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
self.from_column.get_row_ids_for_value_range(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
doc_id_range,
positions,
)
match range {
ValueRange::Inclusive(range) => self.from_column.get_row_ids_for_value_range(
ValueRange::Inclusive(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
),
doc_id_range,
positions,
),
ValueRange::All => self.from_column.get_row_ids_for_value_range(
ValueRange::All,
doc_id_range,
positions,
),
ValueRange::GreaterThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
ValueRange::GreaterThan(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
),
ValueRange::GreaterThanOrEqual(threshold, _) => {
self.from_column.get_row_ids_for_value_range(
ValueRange::GreaterThanOrEqual(
self.monotonic_mapping.inverse(threshold),
false,
),
doc_id_range,
positions,
)
}
ValueRange::LessThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
ValueRange::LessThan(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
),
ValueRange::LessThanOrEqual(threshold, _) => {
self.from_column.get_row_ids_for_value_range(
ValueRange::LessThanOrEqual(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
)
}
}
}
// We voluntarily do not implement get_range as it yields a regression,

View File

@@ -25,6 +25,7 @@ use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128};
use tantivy_bitpacker::{BitPacker, BitUnpacker};
use crate::RowId;
use crate::column::ValueRange;
use crate::column_values::ColumnValues;
/// The cost per blank is quite hard actually, since blanks are delta encoded, the actual cost of
@@ -338,14 +339,48 @@ impl ColumnValues<u64> for CompactSpaceU64Accessor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<u64>,
value_range: ValueRange<u64>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let value_range = self.0.compact_to_u128(*value_range.start() as u32)
..=self.0.compact_to_u128(*value_range.end() as u32);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
match value_range {
ValueRange::Inclusive(value_range) => {
let value_range = ValueRange::Inclusive(
self.0.compact_to_u128(*value_range.start() as u32)
..=self.0.compact_to_u128(*value_range.end() as u32),
);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::All => {
let position_range = position_range.start..position_range.end.min(self.num_vals());
positions.extend(position_range);
}
ValueRange::GreaterThan(threshold, _) => {
let value_range =
ValueRange::GreaterThan(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
let value_range =
ValueRange::GreaterThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::LessThan(threshold, _) => {
let value_range =
ValueRange::LessThan(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::LessThanOrEqual(threshold, _) => {
let value_range =
ValueRange::LessThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
}
}
}
@@ -375,10 +410,47 @@ impl ColumnValues<u128> for CompactSpaceDecompressor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<u128>,
value_range: ValueRange<u128>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let value_range = match value_range {
ValueRange::Inclusive(value_range) => value_range,
ValueRange::All => {
let position_range = position_range.start..position_range.end.min(self.num_vals());
positions.extend(position_range);
return;
}
ValueRange::GreaterThan(threshold, _) => {
let max = self.max_value();
if threshold >= max {
return;
}
(threshold + 1)..=max
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
let max = self.max_value();
if threshold > max {
return;
}
threshold..=max
}
ValueRange::LessThan(threshold, _) => {
let min = self.min_value();
if threshold <= min {
return;
}
min..=(threshold - 1)
}
ValueRange::LessThanOrEqual(threshold, _) => {
let min = self.min_value();
if threshold < min {
return;
}
min..=threshold
}
};
if value_range.start() > value_range.end() {
return;
}
@@ -560,7 +632,7 @@ mod tests {
.collect::<Vec<_>>();
let mut positions = Vec::new();
decompressor.get_row_ids_for_value_range(
range,
ValueRange::Inclusive(range),
0..decompressor.num_vals(),
&mut positions,
);
@@ -604,7 +676,11 @@ mod tests {
let val = *val;
let pos = pos as u32;
let mut positions = Vec::new();
decomp.get_row_ids_for_value_range(val..=val, pos..pos + 1, &mut positions);
decomp.get_row_ids_for_value_range(
ValueRange::Inclusive(val..=val),
pos..pos + 1,
&mut positions,
);
assert_eq!(positions, vec![pos]);
}
@@ -746,7 +822,11 @@ mod tests {
doc_id_range: Range<u32>,
) -> Vec<u32> {
let mut positions = Vec::new();
column.get_row_ids_for_value_range(value_range, doc_id_range, &mut positions);
column.get_row_ids_for_value_range(
ValueRange::Inclusive(value_range),
doc_id_range,
&mut positions,
);
positions
}

View File

@@ -6,6 +6,7 @@ use common::{BinarySerializable, OwnedBytes};
use fastdivide::DividerU64;
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
use crate::column::ValueRange;
use crate::column_values::u64_based::{ColumnCodec, ColumnCodecEstimator, ColumnStats};
use crate::{ColumnValues, RowId};
@@ -41,12 +42,6 @@ fn transform_range_before_linear_transformation(
if range.is_empty() {
return None;
}
if stats.min_value > *range.end() {
return None;
}
if stats.max_value < *range.start() {
return None;
}
let shifted_range =
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);
@@ -72,24 +67,273 @@ impl ColumnValues for BitpackedReader {
self.stats.num_rows
}
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<u64>, crate::DocId>>,
value_range: ValueRange<u64>,
) {
match value_range {
ValueRange::All => {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
}
ValueRange::Inclusive(range) => {
if let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
{
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.get_val(idx);
if transformed_range.contains(&raw_val) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold >= self.stats.max_value {
// All filtered out
} else {
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.get_val(idx);
if raw_val > raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
if threshold <= self.stats.min_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold > self.stats.max_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = (diff + gcd - 1) / gcd;
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.get_val(idx);
if raw_val >= raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold <= self.stats.min_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = if diff % gcd == 0 {
diff / gcd
} else {
diff / gcd + 1
};
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.get_val(idx);
if raw_val < raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::LessThanOrEqual(threshold, _) => {
if threshold >= self.stats.max_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold < self.stats.min_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = diff / gcd;
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.get_val(idx);
if raw_val <= raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
}
}
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<u64>,
range: ValueRange<u64>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
match range {
ValueRange::All => {
positions.extend(doc_id_range);
return;
}
ValueRange::Inclusive(range) => {
let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
positions.extend(doc_id_range);
return;
}
if threshold >= self.stats.max_value {
return;
}
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
let transformed_range = (raw_threshold + 1)..=max_raw;
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
if threshold <= self.stats.min_value {
positions.extend(doc_id_range);
return;
}
if threshold > self.stats.max_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = (diff + gcd - 1) / gcd;
// We want raw >= raw_threshold.
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
let transformed_range = raw_threshold..=max_raw;
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
positions.extend(doc_id_range);
return;
}
if threshold <= self.stats.min_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
// We want raw < raw_threshold_limit
// raw <= raw_threshold_limit - 1
let raw_threshold_limit = if diff % gcd == 0 {
diff / gcd
} else {
diff / gcd + 1
};
if raw_threshold_limit == 0 {
return;
}
let transformed_range = 0..=(raw_threshold_limit - 1);
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
ValueRange::LessThanOrEqual(threshold, _) => {
if threshold >= self.stats.max_value {
positions.extend(doc_id_range);
return;
}
if threshold < self.stats.min_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
// We want raw <= raw_threshold.
let raw_threshold = diff / gcd;
let transformed_range = 0..=raw_threshold;
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
}
}
}

View File

@@ -131,7 +131,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
.collect();
let mut positions = Vec::new();
reader.get_row_ids_for_value_range(
vals[test_rand_idx]..=vals[test_rand_idx],
crate::column::ValueRange::Inclusive(vals[test_rand_idx]..=vals[test_rand_idx]),
0..vals.len() as u32,
&mut positions,
);

View File

@@ -0,0 +1,22 @@
use serde::{Deserialize, Serialize};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ComparableDoc<T, D> {
/// The feature of the document. In practice, this is
/// is a type which can be compared with a `Comparator<T>`.
pub sort_key: T,
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("ComparableDoc")
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}

View File

@@ -3,7 +3,8 @@ use std::sync::Arc;
use std::{fmt, io};
use common::file_slice::FileSlice;
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use common::{ByteCount, DateTime, OwnedBytes};
use serde::{Deserialize, Serialize};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
@@ -317,10 +318,89 @@ impl DynamicColumnHandle {
}
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.len().into()
self.file_slice.num_bytes()
}
/// Legacy helper returning the column space usage.
pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
self.space_usage()
}
/// Return the space usage of the column, optionally broken down by dictionary and column
/// values.
///
/// For dictionary encoded columns (strings and bytes), this splits the total footprint into
/// the dictionary and the remaining column data (including index and values).
/// For all other column types, the dictionary size is `None` and the column size
/// equals the total bytes.
pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
let total_num_bytes = self.num_bytes();
let dynamic_column = self.open()?;
let dictionary_num_bytes = match &dynamic_column {
DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
_ => {
return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
}
};
assert!(dictionary_num_bytes <= total_num_bytes);
let column_num_bytes =
ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
Ok(ColumnSpaceUsage::new(
column_num_bytes,
Some(dictionary_num_bytes),
))
}
pub fn column_type(&self) -> ColumnType {
self.column_type
}
}
/// Represents space usage of a column.
///
/// `column_num_bytes` tracks the column payload (index, values and footer).
/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ColumnSpaceUsage {
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
}
impl ColumnSpaceUsage {
pub(crate) fn new(
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
) -> Self {
ColumnSpaceUsage {
column_num_bytes,
dictionary_num_bytes,
}
}
pub fn column_num_bytes(&self) -> ByteCount {
self.column_num_bytes
}
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
self.dictionary_num_bytes
}
pub fn total_num_bytes(&self) -> ByteCount {
self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
}
/// Merge two space usage values by summing their components.
pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
(Some(val), None) | (None, Some(val)) => Some(val),
(None, None) => None,
};
ColumnSpaceUsage {
column_num_bytes: self.column_num_bytes + other.column_num_bytes,
dictionary_num_bytes,
}
}
}

View File

@@ -29,6 +29,7 @@ mod column;
pub mod column_index;
pub mod column_values;
mod columnar;
mod comparable_doc;
mod dictionary;
mod dynamic_column;
mod iterable;
@@ -36,7 +37,7 @@ pub(crate) mod utils;
mod value;
pub use block_accessor::ColumnBlockAccessor;
pub use column::{BytesColumn, Column, StrColumn};
pub use column::{BytesColumn, Column, StrColumn, ValueRange};
pub use column_index::ColumnIndex;
pub use column_values::{
ColumnValues, EmptyColumnValues, MonotonicallyMappableToU64, MonotonicallyMappableToU128,
@@ -45,10 +46,11 @@ pub use columnar::{
CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar,
};
pub use comparable_doc::ComparableDoc;
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};
pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};
pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
pub type RowId = u32;
pub type DocId = u32;

View File

@@ -1,66 +0,0 @@
use geo_types::Point;
use tantivy::collector::TopDocs;
use tantivy::query::SpatialQuery;
use tantivy::schema::{Schema, Value, SPATIAL, STORED, TEXT};
use tantivy::spatial::point::GeoPoint;
use tantivy::{Index, IndexWriter, TantivyDocument};
fn main() -> tantivy::Result<()> {
let mut schema_builder = Schema::builder();
schema_builder.add_json_field("properties", STORED | TEXT);
schema_builder.add_spatial_field("geometry", STORED | SPATIAL);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer: IndexWriter = index.writer(50_000_000)?;
let doc = TantivyDocument::parse_json(
&schema,
r#"{
"type":"Feature",
"geometry":{
"type":"Polygon",
"coordinates":[[[-99.483911,45.577697],[-99.483869,45.571457],[-99.481739,45.571461],[-99.474881,45.571584],[-99.473167,45.571615],[-99.463394,45.57168],[-99.463391,45.57883],[-99.463368,45.586076],[-99.48177,45.585926],[-99.48384,45.585953],[-99.483885,45.57873],[-99.483911,45.577697]]]
},
"properties":{
"admin_level":"8",
"border_type":"city",
"boundary":"administrative",
"gnis:feature_id":"1267426",
"name":"Hosmer",
"place":"city",
"source":"TIGER/Line® 2008 Place Shapefiles (http://www.census.gov/geo/www/tiger/)",
"wikidata":"Q2442118",
"wikipedia":"en:Hosmer, South Dakota"
}
}"#,
)?;
index_writer.add_document(doc)?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let field = schema.get_field("geometry").unwrap();
let query = SpatialQuery::new(
field,
[
GeoPoint {
lon: -99.49,
lat: 45.56,
},
GeoPoint {
lon: -99.45,
lat: 45.59,
},
],
tantivy::query::SpatialQueryType::Intersects,
);
let hits = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (_score, doc_address) in &hits {
let retrieved_doc: TantivyDocument = searcher.doc(*doc_address)?;
if let Some(field_value) = retrieved_doc.get_first(field) {
if let Some(geometry_box) = field_value.as_value().into_geometry() {
println!("Retrieved geometry: {:?}", geometry_box);
}
}
}
assert_eq!(hits.len(), 1);
Ok(())
}

View File

@@ -1,7 +1,8 @@
use std::cmp::Ordering;
use std::collections::HashMap;
use std::net::Ipv6Addr;
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn};
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn, ValueRange};
use common::json_path_writer::JSON_PATH_SEGMENT_SEP_STR;
use common::DateTime;
use regex::Regex;
@@ -16,7 +17,7 @@ use crate::aggregation::intermediate_agg_result::{
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::AggregationError;
use crate::collector::sort_key::ReverseComparator;
use crate::collector::sort_key::{Comparator, ReverseComparator};
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
@@ -383,7 +384,7 @@ impl From<FastFieldValue> for OwnedValue {
/// Holds a fast field value in its u64 representation, and the order in which it should be sorted.
#[derive(Clone, Serialize, Deserialize, Debug)]
struct DocValueAndOrder {
pub(crate) struct DocValueAndOrder {
/// A fast field value in its u64 representation.
value: Option<u64>,
/// Sort order for the value
@@ -455,6 +456,37 @@ impl PartialEq for DocSortValuesAndFields {
impl Eq for DocSortValuesAndFields {}
impl Comparator<DocSortValuesAndFields> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &DocSortValuesAndFields, rhs: &DocSortValuesAndFields) -> Ordering {
rhs.cmp(lhs)
}
fn threshold_to_valuerange(
&self,
threshold: DocSortValuesAndFields,
) -> ValueRange<DocSortValuesAndFields> {
ValueRange::LessThan(threshold, true)
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct TopHitsSegmentSortKey(pub Vec<DocValueAndOrder>);
impl Comparator<TopHitsSegmentSortKey> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &TopHitsSegmentSortKey, rhs: &TopHitsSegmentSortKey) -> Ordering {
rhs.cmp(lhs)
}
fn threshold_to_valuerange(
&self,
threshold: TopHitsSegmentSortKey,
) -> ValueRange<TopHitsSegmentSortKey> {
ValueRange::LessThan(threshold, true)
}
}
/// The TopHitsCollector used for collecting over segments and merging results.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
@@ -518,7 +550,7 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>,
top_n: TopNComputer<TopHitsSegmentSortKey, DocAddress, ReverseComparator>,
}
impl TopHitsSegmentCollector {
@@ -539,13 +571,15 @@ impl TopHitsSegmentCollector {
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
let mut top_hits_computer = TopHitsTopNComputer::new(req);
// Map TopHitsSegmentSortKey back to Vec<DocValueAndOrder> if needed or use directly
// The TopNComputer here stores TopHitsSegmentSortKey.
let top_results = self.top_n.into_vec();
for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
top_hits_computer.collect(
DocSortValuesAndFields {
sorts: res.sort_key,
sorts: res.sort_key.0,
doc_value_fields,
},
res.doc,
@@ -579,7 +613,7 @@ impl TopHitsSegmentCollector {
.collect();
self.top_n.push(
sorts,
TopHitsSegmentSortKey(sorts),
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,

View File

@@ -96,10 +96,9 @@ mod histogram_collector;
pub use histogram_collector::HistogramCollector;
mod multi_collector;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
pub use columnar::ComparableDoc;
mod top_collector;
pub use self::top_collector::ComparableDoc;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_score_collector;
pub use self::top_score_collector::{TopDocs, TopNComputer};

View File

@@ -1,25 +1,49 @@
mod order;
mod sort_by_erased_type;
mod sort_by_score;
mod sort_by_static_fast_value;
mod sort_by_string;
mod sort_key_computer;
pub use order::*;
pub use sort_by_erased_type::SortByErasedType;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
#[cfg(test)]
mod tests {
pub(crate) mod tests {
// By spec, regardless of whether ascending or descending order was requested, in presence of a
// tie, we sort by ascending doc id/doc address.
pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
hits: &mut [ComparableDoc<TSortKey, D>],
order: Order,
) {
if order.is_asc() {
hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
} else {
hits.sort_by(|l, r| {
l.sort_key
.cmp(&r.sort_key)
.reverse() // This is descending
.then(l.doc.cmp(&r.doc))
});
}
}
use std::collections::HashMap;
use std::ops::Range;
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
use crate::collector::sort_key::{
SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::top_score_collector::compare_for_top_k;
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, QueryParser};
use crate::schema::{Schema, FAST, TEXT};
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
fn make_index() -> crate::Result<Index> {
@@ -294,11 +318,9 @@ mod tests {
(SortBySimilarityScore, score_order),
(SortByString::for_field("city"), city_order),
));
Ok(searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(f, doc)| (f, ids[&doc]))
.collect())
let results: Vec<((Score, Option<String>), DocAddress)> =
searcher.search(&AllQuery, &top_collector)?;
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
}
assert_eq!(
@@ -323,6 +345,97 @@ mod tests {
Ok(())
}
#[test]
fn test_order_by_score_then_owned_value() -> crate::Result<()> {
let index = make_index()?;
type SortKey = (Score, OwnedValue);
fn query(
index: &Index,
score_order: Order,
city_order: Order,
) -> crate::Result<Vec<(SortKey, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>((
(SortBySimilarityScore, score_order),
(SortByErasedType::for_field("city"), city_order),
));
let results: Vec<((Score, OwnedValue), DocAddress)> =
searcher.search(&AllQuery, &top_collector)?;
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
}
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, OwnedValue::Str("austin".to_owned())), 0),
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
((1.0, OwnedValue::Null), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
((1.0, OwnedValue::Str("austin".to_owned())), 0),
((1.0, OwnedValue::Null), 3),
]
);
Ok(())
}
#[test]
fn test_order_by_compound_fast_fields() -> crate::Result<()> {
let index = make_index()?;
type CompoundSortKey = (Option<String>, Option<f64>);
fn assert_query(
index: &Index,
city_order: Order,
altitude_order: Order,
expected: Vec<(CompoundSortKey, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((
(SortByString::for_field("city"), city_order),
(
SortByStaticFastValue::<f64>::for_field("altitude"),
altitude_order,
),
));
let actual = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(key, doc)| (key, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
}
assert_query(
&index,
Order::Asc,
Order::Desc,
vec![
((Some("austin".to_owned()), Some(149.0)), 0),
((Some("greenville".to_owned()), Some(27.0)), 1),
((Some("tokyo".to_owned()), Some(40.0)), 2),
((None, Some(0.0)), 3),
],
)?;
Ok(())
}
use proptest::prelude::*;
proptest! {
@@ -372,15 +485,10 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
sort_hits(&mut comparable_docs, order);
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
@@ -390,4 +498,197 @@ mod tests {
);
}
}
proptest! {
#[test]
fn test_order_by_compound_prop(
city_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
altitude_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..20_usize,
offset in 0..20_usize,
segments_data in proptest::collection::vec(
proptest::collection::vec(
(proptest::option::of("[a-c]"), proptest::option::of(0..50u64)),
1..10_usize // segment size
),
1..4_usize // num segments
)
) {
use crate::collector::sort_key::ComparatorEnum;
use crate::TantivyDocument;
let mut schema_builder = Schema::builder();
let city = schema_builder.add_text_field("city", TEXT | FAST);
let altitude = schema_builder.add_u64_field("altitude", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
for segment_data in segments_data.into_iter() {
for (city_val, altitude_val) in segment_data.into_iter() {
let mut doc = TantivyDocument::default();
if let Some(c) = city_val {
doc.add_text(city, c);
}
if let Some(a) = altitude_val {
doc.add_u64(altitude, a);
}
index_writer.add_document(doc).unwrap();
}
index_writer.commit().unwrap();
}
let searcher = index.reader().unwrap().searcher();
let top_collector = TopDocs::with_limit(limit)
.and_offset(offset)
.order_by((
(SortByString::for_field("city"), city_order),
(
SortByStaticFastValue::<u64>::for_field("altitude"),
altitude_order,
),
));
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
let actual_doc_ids: Vec<DocAddress> =
actual_results.into_iter().map(|(_, doc)| doc).collect();
// Verification logic
let all_docs_collector = DocSetCollector;
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
let docs_with_keys: Vec<((Option<String>, Option<u64>), DocAddress)> = all_docs
.into_iter()
.map(|doc_addr| {
let reader = searcher.segment_reader(doc_addr.segment_ord);
let city_val = if let Some(col) = reader.fast_fields().str("city").unwrap() {
let ord = col.ords().first(doc_addr.doc_id);
if let Some(ord) = ord {
let mut out = Vec::new();
col.dictionary().ord_to_term(ord, &mut out).unwrap();
String::from_utf8(out).ok()
} else {
None
}
} else {
None
};
let alt_val = if let Some((col, _)) = reader.fast_fields().u64_lenient("altitude").unwrap() {
col.first(doc_addr.doc_id)
} else {
None
};
((city_val, alt_val), doc_addr)
})
.collect();
let city_comparator = ComparatorEnum::from(city_order);
let alt_comparator = ComparatorEnum::from(altitude_order);
let comparator = (city_comparator, alt_comparator);
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
.into_iter()
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
.collect();
comparable_docs.sort_by(|l, r| compare_for_top_k(&comparator, l, r));
let expected_results = comparable_docs
.into_iter()
.skip(offset)
.take(limit)
.collect::<Vec<_>>();
let expected_doc_ids: Vec<DocAddress> =
expected_results.into_iter().map(|cd| cd.doc).collect();
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
}
}
proptest! {
#[test]
fn test_order_by_u64_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..20_usize,
offset in 0..20_usize,
segments_data in proptest::collection::vec(
proptest::collection::vec(
proptest::option::of(0..100u64),
1..1000_usize // segment size
),
1..4_usize // num segments
)
) {
use crate::collector::sort_key::ComparatorEnum;
use crate::TantivyDocument;
let mut schema_builder = Schema::builder();
let field = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
for segment_data in segments_data.into_iter() {
for val in segment_data.into_iter() {
let mut doc = TantivyDocument::default();
if let Some(v) = val {
doc.add_u64(field, v);
}
index_writer.add_document(doc).unwrap();
}
index_writer.commit().unwrap();
}
let searcher = index.reader().unwrap().searcher();
let top_collector = TopDocs::with_limit(limit)
.and_offset(offset)
.order_by((SortByStaticFastValue::<u64>::for_field("field"), order));
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
let actual_doc_ids: Vec<DocAddress> =
actual_results.into_iter().map(|(_, doc)| doc).collect();
// Verification logic
let all_docs_collector = DocSetCollector;
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
let docs_with_keys: Vec<(Option<u64>, DocAddress)> = all_docs
.into_iter()
.map(|doc_addr| {
let reader = searcher.segment_reader(doc_addr.segment_ord);
let val = if let Some((col, _)) = reader.fast_fields().u64_lenient("field").unwrap() {
col.first(doc_addr.doc_id)
} else {
None
};
(val, doc_addr)
})
.collect();
let comparator = ComparatorEnum::from(order);
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
.into_iter()
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
.collect();
comparable_docs.sort_by(|l, r| compare_for_top_k(&comparator, l, r));
let expected_results = comparable_docs
.into_iter()
.skip(offset)
.take(limit)
.collect::<Vec<_>>();
let expected_doc_ids: Vec<DocAddress> =
expected_results.into_iter().map(|cd| cd.doc).collect();
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
}
}
}

View File

@@ -1,53 +1,204 @@
use std::cmp::Ordering;
use columnar::{ComparableDoc, MonotonicallyMappableToU64, ValueRange};
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::schema::{OwnedValue, Schema};
use crate::{DocId, Order, Score};
fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
match (lhs, rhs) {
(OwnedValue::Null, OwnedValue::Null) => Ordering::Equal,
(OwnedValue::Null, _) => {
if NULLS_FIRST {
Ordering::Less
} else {
Ordering::Greater
}
}
(_, OwnedValue::Null) => {
if NULLS_FIRST {
Ordering::Greater
} else {
Ordering::Less
}
}
(OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b),
(OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b),
(OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b),
(OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b),
(OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()),
(OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b),
(OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b),
(OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b),
(OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b),
(OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b),
(OwnedValue::U64(a), OwnedValue::I64(b)) => {
if *b < 0 {
Ordering::Greater
} else {
a.cmp(&(*b as u64))
}
}
(OwnedValue::I64(a), OwnedValue::U64(b)) => {
if *a < 0 {
Ordering::Less
} else {
(*a as u64).cmp(b)
}
}
(OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
(OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
(OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
(OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
(a, b) => {
let ord = a.discriminant_value().cmp(&b.discriminant_value());
// If the discriminant is equal, it's because a new type was added, but hasn't been
// included in this `match` statement.
assert!(
ord != Ordering::Equal,
"Unimplemented comparison for type of {a:?}, {b:?}"
);
ord
}
}
}
/// Comparator trait defining the order in which documents should be ordered.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values.
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
/// Return a `ValueRange` that matches all values that are greater than the provided threshold.
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T>;
}
/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
/// Compare values naturally (e.g. 1 < 2).
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first).
///
/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value,
/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`).
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap()
lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
}
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
ValueRange::GreaterThan(threshold, false)
}
}
/// Sorts document in reverse order.
/// A (partial) implementation of comparison for OwnedValue.
///
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
/// first.
/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with
/// mismatched types. The one exception is Null, for which we do define all comparisons.
impl Comparator<OwnedValue> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::GreaterThan(threshold, false)
}
}
/// Compare values in reverse (e.g. 2 < 1).
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first).
///
/// `None` is considered smaller than `Some` in the underlying comparator, but because the
/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting
/// Ascending sort (e.g. `[None, Some(10), Some(20)]`).
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;
impl<T> Comparator<T> for ReverseComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
NaturalComparator.compare(rhs, lhs)
macro_rules! impl_reverse_comparator_primitive {
($($t:ty),*) => {
$(
impl Comparator<$t> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &$t, rhs: &$t) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: $t) -> ValueRange<$t> {
ValueRange::LessThan(threshold, true)
}
}
)*
}
}
/// Sorts document in reverse order, but considers None as having the lowest value.
impl_reverse_comparator_primitive!(
bool,
u8,
u16,
u32,
u64,
u128,
usize,
i8,
i16,
i32,
i64,
i128,
isize,
f32,
f64,
String,
crate::DateTime,
Vec<u8>,
crate::schema::Facet
);
impl<T: PartialOrd + Send + Sync + std::fmt::Debug + Clone + 'static> Comparator<Option<T>>
for ReverseComparator
{
#[inline(always)]
fn compare(&self, lhs: &Option<T>, rhs: &Option<T>) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
let is_some = threshold.is_some();
ValueRange::LessThan(threshold, is_some)
}
}
impl Comparator<OwnedValue> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
let is_not_null = !matches!(threshold, OwnedValue::Null);
ValueRange::LessThan(threshold, is_not_null)
}
}
/// Compare values in reverse, but treating `None` as lower than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first), but with `None` values appearing last
/// (e.g. `[Some(10), Some(20), None]`).
///
/// This is usually what is wanted when sorting by a field in an ascending order.
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price at last.
/// For instance, in an e-commerce website, if sorting by price ascending,
/// the cheapest items would appear first, and items without a price would appear last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;
@@ -63,6 +214,14 @@ where ReverseComparator: Comparator<T>
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
if threshold.is_some() {
ValueRange::LessThan(threshold, false)
} else {
ValueRange::GreaterThan(threshold, false)
}
}
}
impl Comparator<u32> for ReverseNoneIsLowerComparator {
@@ -70,6 +229,10 @@ impl Comparator<u32> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<u64> for ReverseNoneIsLowerComparator {
@@ -77,6 +240,10 @@ impl Comparator<u64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<f64> for ReverseNoneIsLowerComparator {
@@ -84,6 +251,10 @@ impl Comparator<f64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<f32> for ReverseNoneIsLowerComparator {
@@ -91,6 +262,10 @@ impl Comparator<f32> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<i64> for ReverseNoneIsLowerComparator {
@@ -98,6 +273,10 @@ impl Comparator<i64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<String> for ReverseNoneIsLowerComparator {
@@ -105,6 +284,129 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::LessThan(threshold, false)
}
}
/// Compare values naturally, but treating `None` as higher than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first), but with `None` values appearing first
/// (e.g. `[None, Some(20), Some(10)]`).
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalNoneIsHigherComparator;
impl<T> Comparator<Option<T>> for NaturalNoneIsHigherComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
match (lhs_opt, rhs_opt) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Greater,
(Some(_), None) => Ordering::Less,
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
if threshold.is_some() {
let is_some = threshold.is_some();
ValueRange::GreaterThan(threshold, is_some)
} else {
ValueRange::LessThan(threshold, false)
}
}
}
impl Comparator<u32> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<u64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<f64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<f32> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<i64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<String> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::GreaterThan(threshold, true)
}
}
/// An enum representing the different sort orders.
@@ -115,8 +417,10 @@ pub enum ComparatorEnum {
Natural,
/// Reverse order (See [ReverseComparator])
Reverse,
/// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
/// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator])
ReverseNoneLower,
/// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator])
NaturalNoneHigher,
}
impl From<Order> for ComparatorEnum {
@@ -133,6 +437,7 @@ where
ReverseNoneIsLowerComparator: Comparator<T>,
NaturalComparator: Comparator<T>,
ReverseComparator: Comparator<T>,
NaturalNoneIsHigherComparator: Comparator<T>,
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
@@ -140,6 +445,20 @@ where
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
match self {
ComparatorEnum::Natural => NaturalComparator.threshold_to_valuerange(threshold),
ComparatorEnum::Reverse => ReverseComparator.threshold_to_valuerange(threshold),
ComparatorEnum::ReverseNoneLower => {
ReverseNoneIsLowerComparator.threshold_to_valuerange(threshold)
}
ComparatorEnum::NaturalNoneHigher => {
NaturalNoneIsHigherComparator.threshold_to_valuerange(threshold)
}
}
}
}
@@ -156,6 +475,10 @@ where
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
}
fn threshold_to_valuerange(&self, threshold: (Head, Tail)) -> ValueRange<(Head, Tail)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
@@ -172,6 +495,13 @@ where
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, (Type2, Type3)),
) -> ValueRange<(Type1, (Type2, Type3))> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
@@ -188,6 +518,13 @@ where
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, Type2, Type3),
) -> ValueRange<(Type1, Type2, Type3)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
@@ -211,6 +548,13 @@ where
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, (Type2, (Type3, Type4))),
) -> ValueRange<(Type1, (Type2, (Type3, Type4)))> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
@@ -234,6 +578,13 @@ where
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, Type2, Type3, Type4),
) -> ValueRange<(Type1, Type2, Type3, Type4)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
@@ -322,16 +673,33 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComput
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
TSegmentSortKey: Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + Clone + 'static + Sync + Send,
{
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
type SegmentComparator = TComparator;
type Buffer = TSegmentSortKeyComputer::Buffer;
fn segment_comparator(&self) -> Self::SegmentComparator {
self.comparator.clone()
}
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.segment_sort_key_computer
.segment_sort_keys(input_docs, output, buffer, filter)
}
#[inline(always)]
fn compare_segment_sort_key(
&self,
@@ -346,3 +714,55 @@ where
.convert_segment_sort_key(sort_key)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::OwnedValue;
#[test]
fn test_natural_none_is_higher() {
let comp = NaturalNoneIsHigherComparator;
let null = None;
let v1 = Some(1_u64);
let v2 = Some(2_u64);
// NaturalNoneIsGreaterComparator logic:
// 1. Delegates to NaturalComparator for non-nulls.
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
// 2. Treats None (Null) as Greater than any value.
// compare(None, Some(2)) should be Greater.
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
// compare(Some(1), None) should be Less.
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
// compare(None, None) should be Equal.
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
}
#[test]
fn test_mixed_ownedvalue_compare() {
let u = OwnedValue::U64(10);
let i = OwnedValue::I64(10);
let f = OwnedValue::F64(10.0);
let nc = NaturalComparator;
assert_eq!(nc.compare(&u, &i), Ordering::Equal);
assert_eq!(nc.compare(&u, &f), Ordering::Equal);
assert_eq!(nc.compare(&i, &f), Ordering::Equal);
let u2 = OwnedValue::U64(11);
assert_eq!(nc.compare(&u2, &f), Ordering::Greater);
let s = OwnedValue::Str("a".to_string());
// Str < U64
assert_eq!(nc.compare(&s, &u), Ordering::Less);
// Str < I64
assert_eq!(nc.compare(&s, &i), Ordering::Less);
// Str < F64
assert_eq!(nc.compare(&s, &f), Ordering::Less);
}
}

View File

@@ -0,0 +1,410 @@
use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange};
use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer;
use crate::collector::sort_key::{
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};
/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
///
/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders
/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to
/// use a SortKeyComputer implementation with a known-type at compile time.
#[derive(Debug, Clone)]
pub enum SortByErasedType {
/// Sort by a fast field
Field(String),
/// Sort by score
Score,
}
impl SortByErasedType {
/// Creates a new sort key computer which will sort by the given fast field column, with type
/// erasure.
pub fn for_field(column_name: impl ToString) -> Self {
Self::Field(column_name.to_string())
}
/// Creates a new sort key computer which will sort by score, with type erasure.
pub fn for_score() -> Self {
Self::Score
}
}
trait ErasedSegmentSortKeyComputer: Send + Sync {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
filter: ValueRange<Option<u64>>,
);
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}
struct ErasedSegmentSortKeyComputerWrapper<C, F>
where
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
inner: C,
converter: F,
buffer: C::Buffer,
}
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
where
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
filter: ValueRange<Option<u64>>,
) {
self.inner
.segment_sort_keys(input_docs, output, &mut self.buffer, filter)
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let val = self.inner.convert_segment_sort_key(sort_key);
(self.converter)(val)
}
}
struct ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScoreSegmentComputer,
}
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into();
Some(score_value.to_u64())
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
_filter: ValueRange<Option<u64>>,
) {
unimplemented!("Batch computation not supported for score sorting")
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
OwnedValue::F64(f64::from_u64(score_value))
}
}
impl SortKeyComputer for SortByErasedType {
type SortKey = OwnedValue;
type Child = ErasedColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn requires_scoring(&self) -> bool {
matches!(self, Self::Score)
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
Self::Field(column_name) => {
let fast_fields = segment_reader.fast_fields();
// TODO: We currently double-open the column to avoid relying on the implementation
// details of `SortByString` or `SortByStaticFastValue`. Once
// https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should
// consider directly constructing the appropriate `SegmentSortKeyComputer` type for
// the column that we open here.
let (_column, column_type) =
fast_fields.u64_lenient(column_name)?.ok_or_else(|| {
FastFieldNotAvailableError {
field_name: column_name.to_owned(),
}
})?;
match column_type {
ColumnType::Str => {
let computer = SortByString::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<String>| {
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::U64 => {
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<u64>| {
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::I64 => {
let computer = SortByStaticFastValue::<i64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<i64>| {
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::F64 => {
let computer = SortByStaticFastValue::<f64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<f64>| {
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::Bool => {
let computer = SortByStaticFastValue::<bool>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<bool>| {
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::DateTime => {
let computer = SortByStaticFastValue::<DateTime>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<DateTime>| {
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
column_type => {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {column_type:?}, which is not supported for \
sorting by owned value yet.",
column_name
)))
}
}
}
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore
.segment_sort_key_computer(segment_reader)?,
}),
};
Ok(ErasedColumnSegmentSortKeyComputer { inner })
}
}
pub struct ErasedColumnSegmentSortKeyComputer {
inner: Box<dyn ErasedSegmentSortKeyComputer>,
}
impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
type SortKey = OwnedValue;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.inner.segment_sort_keys(input_docs, output, filter)
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
self.inner.convert_segment_sort_key(segment_sort_key)
}
}
#[cfg(test)]
mod tests {
use crate::collector::sort_key::{ComparatorEnum, SortByErasedType};
use crate::collector::TopDocs;
use crate::query::AllQuery;
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
use crate::Index;
#[test]
fn test_sort_by_owned_u64() {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(id_field => 10u64)).unwrap();
writer.add_document(doc!(id_field => 2u64)).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null]
);
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("id"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null]
);
}
#[test]
fn test_sort_by_owned_string() {
let mut schema_builder = Schema::builder();
let city_field = schema_builder.add_text_field("city", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(city_field => "tokyo")).unwrap();
writer.add_document(doc!(city_field => "austin")).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("city"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Str("austin".to_string()),
OwnedValue::Str("tokyo".to_string()),
OwnedValue::Null
]
);
}
#[test]
fn test_sort_by_owned_reverse() {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(id_field => 10u64)).unwrap();
writer.add_document(doc!(id_field => 2u64)).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)]
);
}
#[test]
fn test_sort_by_owned_score() {
let mut schema_builder = Schema::builder();
let body_field = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(body_field => "a a")).unwrap();
writer.add_document(doc!(body_field => "a")).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]);
let query = query_parser.parse_query("a").unwrap();
// Sort by score descending (Natural)
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_score(), ComparatorEnum::Natural));
let top_docs = searcher.search(&query, &collector).unwrap();
let values: Vec<f64> = top_docs
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {key:?}"),
})
.collect();
assert_eq!(values.len(), 2);
assert!(values[0] > values[1]);
// Sort by score ascending (ReverseNoneLower)
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_score(),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&query, &collector).unwrap();
let values: Vec<f64> = top_docs
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {key:?}"),
})
.collect();
assert_eq!(values.len(), 2);
assert!(values[0] < values[1]);
}
}

View File

@@ -1,5 +1,7 @@
use columnar::ValueRange;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
@@ -9,7 +11,7 @@ pub struct SortBySimilarityScore;
impl SortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type Child = SortBySimilarityScore;
type Child = SortBySimilarityScoreSegmentComputer;
type Comparator = NaturalComparator;
@@ -21,7 +23,7 @@ impl SortKeyComputer for SortBySimilarityScore {
&self,
_segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScore)
Ok(SortBySimilarityScoreSegmentComputer)
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
@@ -61,16 +63,29 @@ impl SortKeyComputer for SortBySimilarityScore {
}
}
impl SegmentSortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
pub struct SortBySimilarityScoreSegmentComputer;
impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
type SortKey = Score;
type SegmentSortKey = Score;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
score
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
unimplemented!("Batch computation not supported for score sorting")
}
fn convert_segment_sort_key(&self, score: Score) -> Score {
score
}

View File

@@ -1,9 +1,10 @@
use std::marker::PhantomData;
use columnar::Column;
use columnar::{Column, ValueRange};
use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::{DocId, Score, SegmentReader};
@@ -34,9 +35,7 @@ impl<T: FastValue> SortByStaticFastValue<T> {
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
type Child = SortByFastValueSegmentSortKeyComputer<T>;
type SortKey = Option<T>;
type Comparator = NaturalComparator;
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
@@ -84,15 +83,112 @@ pub struct SortByFastValueSegmentSortKeyComputer<T> {
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
type SortKey = Option<T>;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
self.sort_column.first(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
self.sort_column
.first_vals_in_value_range(input_docs, output, u64_filter);
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key.map(T::from_u64)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{Schema, FAST};
use crate::Index;
#[test]
fn test_sort_by_fast_value_batch() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => 10u64))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => 20u64))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(10), Some(20), None]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
}
#[test]
fn test_sort_by_fast_value_batch_with_filter() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => 10u64))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => 20u64))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(
&mut docs,
&mut output,
&mut buffer,
ValueRange::GreaterThan(Some(15u64), false /* inclusive */),
);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(20)]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
}
}

View File

@@ -1,7 +1,10 @@
use columnar::StrColumn;
use columnar::{StrColumn, ValueRange};
use crate::collector::sort_key::sort_key_computer::{
convert_optional_u64_range_to_u64_range, range_contains_none,
};
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
@@ -30,9 +33,7 @@ impl SortByString {
impl SortKeyComputer for SortByString {
type SortKey = Option<String>;
type Child = ByStringColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
@@ -50,8 +51,9 @@ pub struct ByStringColumnSegmentSortKeyComputer {
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
@@ -59,7 +61,31 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
str_column.ords().first(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
if let Some(str_column) = &self.str_column_opt {
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
str_column
.ords()
.first_vals_in_value_range(input_docs, output, u64_filter);
} else if range_contains_none(&filter) {
for &doc in input_docs {
output.push(ComparableDoc {
doc,
sort_key: None,
});
}
}
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
let term_ord = term_ord_opt?;
let str_column = self.str_column_opt.as_ref()?;
let mut bytes = Vec::new();
@@ -70,3 +96,90 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
String::try_from(bytes).ok()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{Schema, FAST, TEXT};
use crate::Index;
#[test]
fn test_sort_by_string_batch() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => "a"))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => "c"))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(0), Some(1), None]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
}
#[test]
fn test_sort_by_string_batch_with_filter() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => "a"))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => "c"))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
// Filter: > "b". "a" is 0, "c" is 1.
// We want > "a" (ord 0). So we filter > ord 0.
// 0 is "a", 1 is "c".
let mut buffer = ();
computer.segment_sort_keys(
&mut docs,
&mut output,
&mut buffer,
ValueRange::GreaterThan(Some(0), false /* inclusive */),
);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(1)]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
}
}

View File

@@ -1,8 +1,12 @@
use std::cmp::Ordering;
use columnar::ValueRange;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::collector::{
default_collect_segment_impl, ComparableDoc, SegmentCollector as _, TopNComputer,
};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
@@ -12,17 +16,40 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader};
/// It is the segment local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static {
/// The final score being emitted.
type SortKey: 'static + PartialOrd + Send + Sync + Clone;
type SortKey: 'static + Send + Sync + Clone;
/// Sort key used by at the segment level by the `SegmentSortKeyComputer`.
///
/// It is typically small like a `u64`, and is meant to be converted
/// to the final score at the end of the collection of the segment.
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;
type SegmentSortKey: 'static + Clone + Send + Sync + Clone;
/// Comparator type.
type SegmentComparator: Comparator<Self::SegmentSortKey> + Clone + 'static;
/// Buffer type used for scratch space.
type Buffer: Default + Send + Sync + 'static;
/// Returns the segment sort key comparator.
fn segment_comparator(&self) -> Self::SegmentComparator {
Self::SegmentComparator::default()
}
/// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
/// Computes the sort keys for a batch of documents.
///
/// The computed sort keys and document IDs are pushed into the `output` vector.
/// The `buffer` is used for scratch space.
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
);
/// Computes the sort key and pushes the document in a TopN Computer.
///
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
@@ -31,12 +58,32 @@ pub trait SegmentSortKeyComputer: 'static {
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
let sort_key = self.segment_sort_key(doc, score);
top_n_computer.push(sort_key, doc);
}
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
// The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
// should always be able to `reserve` space for the entire block.
top_n_computer.reserve(docs.len());
let comparator = self.segment_comparator();
let value_range = if let Some(threshold) = &top_n_computer.threshold {
comparator.threshold_to_valuerange(threshold.clone())
} else {
ValueRange::All
};
let (buffer, scratch) = top_n_computer.buffer_and_scratch();
self.segment_sort_keys(docs, buffer, scratch, value_range);
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
@@ -47,27 +94,7 @@ pub trait SegmentSortKeyComputer: 'static {
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
NaturalComparator.compare(left, right)
}
/// Implementing this method makes it possible to avoid computing
/// a sort_key entirely if we can assess that it won't pass a threshold
/// with a partial computation.
///
/// This is currently used for lexicographic sorting.
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let sort_key = self.segment_sort_key(doc_id, score);
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
if cmp == Ordering::Less {
None
} else {
Some((cmp, sort_key))
}
self.segment_comparator().compare(left, right)
}
/// Convert a segment level sort key into the global sort key.
@@ -81,7 +108,7 @@ pub trait SegmentSortKeyComputer: 'static {
/// the sort key at a segment scale.
pub trait SortKeyComputer: Sync {
/// The sort key type.
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug;
/// Type of the associated [`SegmentSortKeyComputer`].
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
/// Comparator type.
@@ -136,11 +163,9 @@ where
HeadSortKeyComputer: SortKeyComputer,
TailSortKeyComputer: SortKeyComputer,
{
type SortKey = (
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
type Child =
ChainSegmentSortKeyComputer<HeadSortKeyComputer::Child, TailSortKeyComputer::Child>;
type Comparator = (
HeadSortKeyComputer::Comparator,
@@ -152,10 +177,10 @@ where
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
))
Ok(ChainSegmentSortKeyComputer {
head: self.0.segment_sort_key_computer(segment_reader)?,
tail: self.1.segment_sort_key_computer(segment_reader)?,
})
}
/// Checks whether the schema is compatible with the sort key computer.
@@ -173,20 +198,91 @@ where
}
}
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
pub struct ChainSegmentSortKeyComputer<Head, Tail>
where
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
type SortKey = (
HeadSegmentSortKeyComputer::SortKey,
TailSegmentSortKeyComputer::SortKey,
);
type SegmentSortKey = (
HeadSegmentSortKeyComputer::SegmentSortKey,
TailSegmentSortKeyComputer::SegmentSortKey,
);
head: Head,
tail: Tail,
}
pub struct ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey> {
pub head: HeadBuffer,
pub tail: TailBuffer,
pub head_output: Vec<ComparableDoc<HeadKey, DocId>>,
pub tail_output: Vec<ComparableDoc<TailKey, DocId>>,
pub tail_input_docs: Vec<DocId>,
}
impl<HeadBuffer: Default, TailBuffer: Default, HeadKey, TailKey> Default
for ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey>
{
fn default() -> Self {
ChainBuffer {
head: HeadBuffer::default(),
tail: TailBuffer::default(),
head_output: Vec::new(),
tail_output: Vec::new(),
tail_input_docs: Vec::new(),
}
}
}
impl<Head, Tail> ChainSegmentSortKeyComputer<Head, Tail>
where
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &<Self as SegmentSortKeyComputer>::SegmentSortKey,
) -> Option<(Ordering, <Self as SegmentSortKeyComputer>::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let head_sort_key = self.head.segment_sort_key(doc_id, score);
let head_cmp = self
.head
.compare_segment_sort_key(&head_sort_key, head_threshold);
if head_cmp == Ordering::Less {
None
} else if head_cmp == Ordering::Equal {
let tail_sort_key = self.tail.segment_sort_key(doc_id, score);
let tail_cmp = self
.tail
.compare_segment_sort_key(&tail_sort_key, tail_threshold);
if tail_cmp == Ordering::Less {
None
} else {
Some((tail_cmp, (head_sort_key, tail_sort_key)))
}
} else {
let tail_sort_key = self.tail.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
}
impl<Head, Tail> SegmentSortKeyComputer for ChainSegmentSortKeyComputer<Head, Tail>
where
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
type SortKey = (Head::SortKey, Tail::SortKey);
type SegmentSortKey = (Head::SegmentSortKey, Tail::SegmentSortKey);
type SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator);
type Buffer =
ChainBuffer<Head::Buffer, Tail::Buffer, Head::SegmentSortKey, Tail::SegmentSortKey>;
fn segment_comparator(&self) -> Self::SegmentComparator {
(
self.head.segment_comparator(),
self.tail.segment_comparator(),
)
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
@@ -198,9 +294,90 @@ where
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.0
self.head
.compare_segment_sort_key(&left.0, &right.0)
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
.then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1))
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
let (head_filter, threshold) = match filter {
ValueRange::GreaterThan((head_threshold, tail_threshold), _)
| ValueRange::LessThan((head_threshold, tail_threshold), _) => {
let head_cmp = self.head.segment_comparator();
let strict_head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone());
let head_filter = match strict_head_filter {
ValueRange::GreaterThan(t, m) => ValueRange::GreaterThanOrEqual(t, m),
ValueRange::LessThan(t, m) => ValueRange::LessThanOrEqual(t, m),
other => other,
};
(head_filter, Some((head_threshold, tail_threshold)))
}
_ => (ValueRange::All, None),
};
buffer.head_output.clear();
self.head.segment_sort_keys(
input_docs,
&mut buffer.head_output,
&mut buffer.head,
head_filter,
);
if buffer.head_output.is_empty() {
return;
}
buffer.tail_output.clear();
buffer.tail_input_docs.clear();
for cd in &buffer.head_output {
buffer.tail_input_docs.push(cd.doc);
}
self.tail.segment_sort_keys(
&buffer.tail_input_docs,
&mut buffer.tail_output,
&mut buffer.tail,
ValueRange::All,
);
let head_cmp = self.head.segment_comparator();
let tail_cmp = self.tail.segment_comparator();
for (head_doc, tail_doc) in buffer
.head_output
.drain(..)
.zip(buffer.tail_output.drain(..))
{
debug_assert_eq!(head_doc.doc, tail_doc.doc);
let doc = head_doc.doc;
let head_key = head_doc.sort_key;
let tail_key = tail_doc.sort_key;
let accept = if let Some((head_threshold, tail_threshold)) = &threshold {
let head_ord = head_cmp.compare(&head_key, head_threshold);
let ord = if head_ord == Ordering::Equal {
tail_cmp.compare(&tail_key, tail_threshold)
} else {
head_ord
};
ord == Ordering::Greater
} else {
true
};
if accept {
output.push(ComparableDoc {
sort_key: (head_key, tail_key),
doc,
});
}
}
}
#[inline(always)]
@@ -208,7 +385,7 @@ where
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
let sort_key: Self::SegmentSortKey;
if let Some(threshold) = &top_n_computer.threshold {
@@ -225,68 +402,56 @@ where
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.0.segment_sort_key(doc, score);
let tail_sort_key = self.1.segment_sort_key(doc, score);
let head_sort_key = self.head.segment_sort_key(doc, score);
let tail_sort_key = self.tail.segment_sort_key(doc, score);
(head_sort_key, tail_sort_key)
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let (head_cmp, head_sort_key) =
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
if head_cmp == Ordering::Equal {
let (tail_cmp, tail_sort_key) =
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
Some((tail_cmp, (head_sort_key, tail_sort_key)))
} else {
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
let (head_sort_key, tail_sort_key) = sort_key;
(
self.0.convert_segment_sort_key(head_sort_key),
self.1.convert_segment_sort_key(tail_sort_key),
self.head.convert_segment_sort_key(head_sort_key),
self.tail.convert_segment_sort_key(tail_sort_key),
)
}
}
/// This struct is used as an adapter to take a sort key computer and map its score to another
/// new sort key.
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
pub struct MappedSegmentSortKeyComputer<T: SegmentSortKeyComputer, NewSortKey> {
sort_key_computer: T,
map: fn(PreviousSortKey) -> NewSortKey,
map: fn(T::SortKey) -> NewSortKey,
}
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
for MappedSegmentSortKeyComputer<T, NewScore>
where
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
NewScore: 'static + Clone + Send + Sync + PartialOrd,
PreviousScore: 'static + Clone + Send + Sync,
NewScore: 'static + Clone + Send + Sync,
{
type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey;
type SegmentComparator = T::SegmentComparator;
type Buffer = T::Buffer;
fn segment_comparator(&self) -> Self::SegmentComparator {
self.sort_key_computer.segment_comparator()
}
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score)
}
fn accept_sort_key_lazy(
fn segment_sort_keys(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.sort_key_computer
.accept_sort_key_lazy(doc_id, score, threshold)
.segment_sort_keys(input_docs, output, buffer, filter)
}
#[inline(always)]
@@ -294,12 +459,21 @@ where
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
self.sort_key_computer
.compute_sort_key_and_collect(doc, score, top_n_computer);
}
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
self.sort_key_computer
.compute_sort_keys_and_collect(docs, top_n_computer);
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
(self.map)(
self.sort_key_computer
@@ -325,10 +499,6 @@ where
);
type Child = MappedSegmentSortKeyComputer<
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
),
Self::SortKey,
>;
@@ -352,7 +522,13 @@ where
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
sort_key_computer: ChainSegmentSortKeyComputer {
head: sort_key_computer1,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: sort_key_computer3,
},
},
map,
})
}
@@ -387,13 +563,6 @@ where
SortKeyComputer1,
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(
SortKeyComputer2::SortKey,
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
),
),
Self::SortKey,
>;
type SortKey = (
@@ -415,10 +584,16 @@ where
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (
sort_key_computer1,
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
),
sort_key_computer: ChainSegmentSortKeyComputer {
head: sort_key_computer1,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer3,
tail: sort_key_computer4,
},
},
},
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
(sort_key1, sort_key2, sort_key3, sort_key4)
},
@@ -441,6 +616,13 @@ where
}
}
use std::marker::PhantomData;
pub struct FuncSegmentSortKeyComputer<F, TSortKey> {
func: F,
_phantom: PhantomData<TSortKey>,
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
@@ -448,24 +630,44 @@ where
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
type Child = SegmentF;
type Child = FuncSegmentSortKeyComputer<SegmentF, TSortKey>;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
Ok(FuncSegmentSortKeyComputer {
func: (self)(segment_reader),
_phantom: PhantomData,
})
}
}
impl<F, TSortKey> SegmentSortKeyComputer for F
impl<F, TSortKey> SegmentSortKeyComputer for FuncSegmentSortKeyComputer<F, TSortKey>
where
F: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
type Buffer = ();
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self)(doc)
(self.func)(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
for &doc in input_docs {
output.push(ComparableDoc {
sort_key: (self.func)(doc),
doc,
});
}
}
/// Convert a segment level score into the global level score.
@@ -474,13 +676,75 @@ where
}
}
pub(crate) fn range_contains_none(range: &ValueRange<Option<u64>>) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&None),
ValueRange::GreaterThan(_threshold, match_nulls) => *match_nulls,
ValueRange::GreaterThanOrEqual(_threshold, match_nulls) => *match_nulls,
ValueRange::LessThan(_threshold, match_nulls) => *match_nulls,
ValueRange::LessThanOrEqual(_threshold, match_nulls) => *match_nulls,
}
}
pub(crate) fn convert_optional_u64_range_to_u64_range(
range: ValueRange<Option<u64>>,
) -> ValueRange<u64> {
match range {
ValueRange::Inclusive(r) => {
let start = r.start().unwrap_or(0);
let end = r.end().unwrap_or(u64::MAX);
ValueRange::Inclusive(start..=end)
}
ValueRange::GreaterThan(Some(val), match_nulls) => {
ValueRange::GreaterThan(val, match_nulls)
}
ValueRange::GreaterThan(None, match_nulls) => {
if match_nulls {
ValueRange::All
} else {
ValueRange::Inclusive(u64::MIN..=u64::MAX)
}
}
ValueRange::GreaterThanOrEqual(Some(val), match_nulls) => {
ValueRange::GreaterThanOrEqual(val, match_nulls)
}
ValueRange::GreaterThanOrEqual(None, match_nulls) => {
if match_nulls {
ValueRange::All
} else {
ValueRange::Inclusive(u64::MIN..=u64::MAX)
}
}
ValueRange::LessThan(None, match_nulls) => {
if match_nulls {
ValueRange::LessThan(u64::MIN, true)
} else {
ValueRange::Inclusive(1..=0)
}
}
ValueRange::LessThan(Some(val), match_nulls) => ValueRange::LessThan(val, match_nulls),
ValueRange::LessThanOrEqual(None, match_nulls) => {
if match_nulls {
ValueRange::LessThan(u64::MIN, true)
} else {
ValueRange::Inclusive(1..=0)
}
}
ValueRange::LessThanOrEqual(Some(val), match_nulls) => {
ValueRange::LessThanOrEqual(val, match_nulls)
}
ValueRange::All => ValueRange::All,
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::schema::Schema;
use crate::{DocId, Index, Order, SegmentReader};
@@ -628,4 +892,178 @@ mod tests {
(200u32, 2u32)
);
}
#[test]
fn test_batch_score_computer_edge_case() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let score_computer_secondary = |_segment_reader: &SegmentReader| |_doc: DocId| "b";
let lazy_score_computer = (score_computer_primary, score_computer_secondary);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let mut top_n_computer =
TopNComputer::new_with_comparator(10, lazy_score_computer.comparator());
// Threshold (200, "a"). Doc is (200, "b"). 200 == 200, "b" > "a". Should be accepted.
top_n_computer.threshold = Some((200, "a"));
let docs = vec![0];
segment_sort_key_computer.compute_sort_keys_and_collect(&docs, &mut top_n_computer);
let results = top_n_computer.into_sorted_vec();
assert_eq!(results.len(), 1);
let result = &results[0];
assert_eq!(result.doc, 0);
assert_eq!(result.sort_key, (200, "b"));
}
}
#[cfg(test)]
mod proptest_tests {
use proptest::prelude::*;
use super::*;
use crate::collector::sort_key::order::*;
// Re-implement logic to interpret ValueRange<Option<u64>> manually to verify expectations
fn range_contains_opt(range: &ValueRange<Option<u64>>, val: &Option<u64>) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(val),
ValueRange::GreaterThan(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val > t
}
}
ValueRange::GreaterThanOrEqual(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val >= t
}
}
ValueRange::LessThan(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val < t
}
}
ValueRange::LessThanOrEqual(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val <= t
}
}
}
}
fn range_contains_u64(range: &ValueRange<u64>, val: &u64) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(val),
ValueRange::GreaterThan(t, _) => val > t,
ValueRange::GreaterThanOrEqual(t, _) => val >= t,
ValueRange::LessThan(t, _) => val < t,
ValueRange::LessThanOrEqual(t, _) => val <= t,
}
}
proptest! {
#[test]
fn test_comparator_consistency_natural_none_is_lower(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<NaturalComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_reverse(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<ReverseComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_reverse_none_is_lower(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<ReverseNoneIsLowerComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_natural_none_is_higher(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<NaturalNoneIsHigherComparator>(threshold, val)?;
}
}
fn check_comparator<C: Comparator<Option<u64>>>(
threshold: Option<u64>,
val: Option<u64>,
) -> std::result::Result<(), proptest::test_runner::TestCaseError> {
let comparator = C::default();
let range = comparator.threshold_to_valuerange(threshold);
let ordering = comparator.compare(&val, &threshold);
let should_be_in_range = ordering == Ordering::Greater;
let in_range_opt = range_contains_opt(&range, &val);
prop_assert_eq!(
in_range_opt,
should_be_in_range,
"Comparator consistency failed for {:?}. Threshold: {:?}, Val: {:?}, Range: {:?}, \
Ordering: {:?}. range_contains_opt says {}, but compare says {}",
std::any::type_name::<C>(),
threshold,
val,
range,
ordering,
in_range_opt,
should_be_in_range
);
// Check range_contains_none
let expected_none_in_range = range_contains_opt(&range, &None);
let actual_none_in_range = range_contains_none(&range);
prop_assert_eq!(
actual_none_in_range,
expected_none_in_range,
"range_contains_none failed for {:?}. Range: {:?}. Expected (from \
range_contains_opt): {}, Actual: {}",
std::any::type_name::<C>(),
range,
expected_none_in_range,
actual_none_in_range
);
// Check convert_optional_u64_range_to_u64_range
let u64_range = convert_optional_u64_range_to_u64_range(range.clone());
if let Some(v) = val {
let in_u64_range = range_contains_u64(&u64_range, &v);
let in_opt_range = range_contains_opt(&range, &Some(v));
prop_assert_eq!(
in_u64_range,
in_opt_range,
"convert_optional_u64_range_to_u64_range failed for {:?}. Val: {:?}, OptRange: \
{:?}, U64Range: {:?}. Opt says {}, U64 says {}",
std::any::type_name::<C>(),
v,
range,
u64_range,
in_opt_range,
in_u64_range
);
}
Ok(())
}
}

View File

@@ -99,7 +99,12 @@ where
TSegmentSortKeyComputer: SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
{
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
pub(crate) topn_computer: TopNComputer<
TSegmentSortKeyComputer::SegmentSortKey,
DocId,
C,
TSegmentSortKeyComputer::Buffer,
>,
pub(crate) segment_ord: u32,
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
}
@@ -120,6 +125,11 @@ where
);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.segment_sort_key_computer
.compute_sort_keys_and_collect(docs, &mut self.topn_computer);
}
fn harvest(self) -> Self::Fruit {
let segment_ord = self.segment_ord;
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self

View File

@@ -1,64 +0,0 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// It guarantees stable sorting: in case of a tie on the feature, the document
/// address is used.
///
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
/// should be reversed, which is useful for achieving for example largest-first
/// semantics without having to wrap the feature in a `Reverse`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// The feature of the document. In practice, this is
/// is any type that implements `PartialOrd`.
pub sort_key: T,
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.sort_key
.partial_cmp(&other.sort_key)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
// In case of a tie on the feature, we sort by ascending
// `DocAddress` in order to ensure a stable sorting of the
// documents.
by_feature.then_with(lazy_by_doc_address)
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}

View File

@@ -2,6 +2,7 @@ use std::cmp::Ordering;
use std::fmt;
use std::ops::Range;
use columnar::ValueRange;
use serde::{Deserialize, Serialize};
use super::Collector;
@@ -10,8 +11,7 @@ use crate::collector::sort_key::{
SortByStaticFastValue, SortByString,
};
use crate::collector::sort_key_top_collector::TopBySortKeyCollector;
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastValue;
use crate::{DocAddress, DocId, Order, Score, SegmentReader};
@@ -23,10 +23,9 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader};
/// The theoretical complexity for collecting the top `K` out of `N` documents
/// is `O(N + K)`.
///
/// This collector does not guarantee a stable sorting in case of a tie on the
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
/// like docid in case of score equality.
/// Only then, it is suitable for pagination.
/// This collector guarantees a stable sorting in case of a tie on the
/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker.
/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`.
///
/// ```rust
/// use tantivy::collector::TopDocs;
@@ -325,7 +324,7 @@ impl TopDocs {
sort_key_computer: impl SortKeyComputer<SortKey = TSortKey> + Send + 'static,
) -> impl Collector<Fruit = Vec<(TSortKey, DocAddress)>>
where
TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug,
TSortKey: 'static + Clone + Send + Sync + std::fmt::Debug,
{
TopBySortKeyCollector::new(sort_key_computer, self.doc_range())
}
@@ -446,7 +445,7 @@ where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn,
TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey,
TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>:
SegmentSortKeyComputer<SortKey = TSortKey>,
SegmentSortKeyComputer<SortKey = TSortKey, SegmentSortKey = TSortKey>,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
@@ -481,11 +480,23 @@ where
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
type Buffer = ();
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey {
(self.sort_key_fn)(doc, score)
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
unimplemented!("Batch computation is not supported for tweak score.")
}
/// Convert a segment level score into the global level score.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key
@@ -500,16 +511,23 @@ where
///
/// For TopN == 0, it will be relative expensive.
///
/// When using the natural comparator, the top N computer returns the top N elements in
/// descending order, as expected for a top N.
/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress):
/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in
/// ascending order, regardless of the `Comparator` used for the `Score` type.
///
/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the
/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
#[derive(Serialize, Deserialize)]
#[serde(from = "TopNComputerDeser<Score, D, C>")]
pub struct TopNComputer<Score, D, C> {
pub struct TopNComputer<Score, D, C, Buffer = ()> {
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
buffer: Vec<ComparableDoc<Score, D>>,
top_n: usize,
pub(crate) threshold: Option<Score>,
comparator: C,
#[serde(skip)]
scratch: Buffer,
}
// Intermediate struct for TopNComputer for deserialization, to keep vec capacity
@@ -521,7 +539,9 @@ struct TopNComputerDeser<Score, D, C> {
comparator: C,
}
impl<Score, D, C> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C> {
impl<Score, D, C, Buffer> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C, Buffer>
where Buffer: Default
{
fn from(mut value: TopNComputerDeser<Score, D, C>) -> Self {
let expected_cap = value.top_n.max(1) * 2;
let current_cap = value.buffer.capacity();
@@ -536,12 +556,15 @@ impl<Score, D, C> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D
top_n: value.top_n,
threshold: value.threshold,
comparator: value.comparator,
scratch: Buffer::default(),
}
}
}
impl<Score: std::fmt::Debug, D, C> std::fmt::Debug for TopNComputer<Score, D, C>
where C: Comparator<Score>
impl<Score: std::fmt::Debug, D, C, Buffer> std::fmt::Debug for TopNComputer<Score, D, C, Buffer>
where
C: Comparator<Score>,
Buffer: std::fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result {
f.debug_struct("TopNComputer")
@@ -549,12 +572,13 @@ where C: Comparator<Score>
.field("top_n", &self.top_n)
.field("current_threshold", &self.threshold)
.field("comparator", &self.comparator)
.field("scratch", &self.scratch)
.finish()
}
}
// Custom clone to keep capacity
impl<Score: Clone, D: Clone, C: Clone> Clone for TopNComputer<Score, D, C> {
impl<Score: Clone, D: Clone, C: Clone, Buffer: Clone> Clone for TopNComputer<Score, D, C, Buffer> {
fn clone(&self) -> Self {
let mut buffer_clone = Vec::with_capacity(self.buffer.capacity());
buffer_clone.extend(self.buffer.iter().cloned());
@@ -563,15 +587,17 @@ impl<Score: Clone, D: Clone, C: Clone> Clone for TopNComputer<Score, D, C> {
top_n: self.top_n,
threshold: self.threshold.clone(),
comparator: self.comparator.clone(),
scratch: self.scratch.clone(),
}
}
}
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator>
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator, ()>
where
D: Ord,
TSortKey: Clone,
NaturalComparator: Comparator<TSortKey>,
ReverseComparator: Comparator<TSortKey>,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
@@ -580,30 +606,50 @@ where
}
}
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
#[inline(always)]
pub fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
c: &C,
lhs: &ComparableDoc<TSortKey, D>,
rhs: &ComparableDoc<TSortKey, D>,
) -> std::cmp::Ordering {
c.compare(&lhs.sort_key, &rhs.sort_key)
.reverse() // Reverse here because we want top K.
.then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we
// sort by doc id
}
impl<TSortKey, D, C, Buffer> TopNComputer<TSortKey, D, C, Buffer>
where
D: Ord,
TSortKey: Clone,
C: Comparator<TSortKey>,
Buffer: Default,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
/// Internally it will allocate a buffer of size `(top_n.max(1) * 2) +
/// COLLECT_BLOCK_BUFFER_LEN`.
pub fn new_with_comparator(top_n: usize, comparator: C) -> Self {
let vec_cap = top_n.max(1) * 2;
// We ensure that there is always enough space to include an entire block in the buffer if
// need be, so that `push_block_lazy` can avoid checking capacity inside its loop.
let vec_cap = (top_n.max(1) * 2) + crate::COLLECT_BLOCK_BUFFER_LEN;
TopNComputer {
buffer: Vec::with_capacity(vec_cap),
top_n,
threshold: None,
comparator,
scratch: Buffer::default(),
}
}
/// Push a new document to the top n.
/// If the document is below the current threshold, it will be ignored.
///
/// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order.
#[inline]
pub fn push(&mut self, sort_key: TSortKey, doc: D) {
if let Some(last_median) = &self.threshold {
if self.comparator.compare(&sort_key, last_median) == Ordering::Less {
// See the struct docs for an explanation of why this comparison is strict.
if self.comparator.compare(&sort_key, last_median) != Ordering::Greater {
return;
}
}
@@ -615,23 +661,33 @@ where
// At this point, we need to have established that the doc is above the threshold.
#[inline(always)]
pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) {
if self.buffer.len() == self.buffer.capacity() {
let median = self.truncate_top_n();
self.threshold = Some(median);
}
// This cannot panic, because we truncate_median will at least remove one element, since
// the min capacity is 2.
self.reserve(1);
// This cannot panic, because we've reserved room for one element.
let comparable_doc = ComparableDoc { doc, sort_key };
push_assuming_capacity(comparable_doc, &mut self.buffer);
}
// Ensure that there is capacity to push `additional` more elements without resizing.
#[inline(always)]
pub(crate) fn reserve(&mut self, additional: usize) {
if self.buffer.len() + additional > self.buffer.capacity() {
let median = self.truncate_top_n();
debug_assert!(self.buffer.len() + additional <= self.buffer.capacity());
self.threshold = Some(median);
}
}
pub(crate) fn buffer_and_scratch(
&mut self,
) -> (&mut Vec<ComparableDoc<TSortKey, D>>, &mut Buffer) {
(&mut self.buffer, &mut self.scratch)
}
#[inline(never)]
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
self.comparator
.compare(&rhs.sort_key, &lhs.sort_key)
.then_with(|| lhs.doc.cmp(&rhs.doc))
compare_for_top_k(&self.comparator, lhs, rhs)
});
let median_score = median_el.sort_key.clone();
@@ -646,11 +702,8 @@ where
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
self.buffer.sort_unstable_by(|left, right| {
self.comparator
.compare(&right.sort_key, &left.sort_key)
.then_with(|| left.doc.cmp(&right.doc))
});
self.buffer
.sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs));
self.buffer
}
@@ -669,7 +722,7 @@ where
//
// Panics if there is not enough capacity to add an element.
#[inline(always)]
fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
pub fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
let prev_len = buf.len();
assert!(prev_len < buf.capacity());
// This is mimicking the current (non-stabilized) implementation in std.
@@ -687,8 +740,7 @@ mod tests {
use super::{TopDocs, TopNComputer};
use crate::collector::sort_key::{ComparatorEnum, NaturalComparator, ReverseComparator};
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{Collector, DocSetCollector};
use crate::collector::{Collector, ComparableDoc, DocSetCollector};
use crate::query::{AllQuery, Query, QueryParser};
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
use crate::time::format_description::well_known::Rfc3339;
@@ -755,6 +807,33 @@ mod tests {
);
}
#[test]
fn test_topn_computer_duplicates() {
let mut computer: TopNComputer<u32, u32, NaturalComparator> =
TopNComputer::new_with_comparator(2, NaturalComparator);
computer.push(1u32, 1u32);
computer.push(1u32, 2u32);
computer.push(1u32, 3u32);
computer.push(1u32, 4u32);
computer.push(1u32, 5u32);
// In the presence of duplicates, DocIds are always ascending order.
assert_eq!(
computer.into_sorted_vec(),
&[
ComparableDoc {
sort_key: 1u32,
doc: 1u32,
},
ComparableDoc {
sort_key: 1u32,
doc: 2u32,
}
]
);
}
#[test]
fn test_topn_computer_no_panic() {
for top_n in 0..10 {
@@ -772,14 +851,17 @@ mod tests {
#[test]
fn test_topn_computer_asc_prop(
limit in 0..10_usize,
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
) {
// NB: TopNComputer must receive inputs in ascending DocId order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator);
for (feature, doc) in &docs {
computer.push(*feature, *doc);
}
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort();
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> =
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc);
comparable_docs.truncate(limit);
prop_assert_eq!(
computer.into_sorted_vec(),
@@ -1363,11 +1445,11 @@ mod tests {
#[test]
fn test_top_field_collect_string_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..256_usize,
offset in 0..256_usize,
limit in 1..32_usize,
offset in 0..32_usize,
segments_terms in
proptest::collection::vec(
proptest::collection::vec(0..32_u8, 1..32_usize),
proptest::collection::vec(0..64_u8, 1..256_usize),
0..8_usize,
)
) {
@@ -1406,15 +1488,10 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order);
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
@@ -1693,7 +1770,8 @@ mod tests {
#[test]
fn test_top_n_computer_not_at_capacity() {
let mut top_n_computer = TopNComputer::new_with_comparator(4, NaturalComparator);
let mut top_n_computer: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(4, NaturalComparator);
top_n_computer.append_doc(1, 0.8);
top_n_computer.append_doc(3, 0.2);
top_n_computer.append_doc(5, 0.3);
@@ -1718,7 +1796,8 @@ mod tests {
#[test]
fn test_top_n_computer_at_capacity() {
let mut top_collector = TopNComputer::new_with_comparator(4, NaturalComparator);
let mut top_collector: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(4, NaturalComparator);
top_collector.append_doc(1, 0.8);
top_collector.append_doc(3, 0.2);
top_collector.append_doc(5, 0.3);
@@ -1755,12 +1834,14 @@ mod tests {
let doc_ids_collection = [4, 5, 6];
let score = 3.3f32;
let mut top_collector_limit_2 = TopNComputer::new_with_comparator(2, NaturalComparator);
let mut top_collector_limit_2: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(2, NaturalComparator);
for id in &doc_ids_collection {
top_collector_limit_2.append_doc(*id, score);
}
let mut top_collector_limit_3 = TopNComputer::new_with_comparator(3, NaturalComparator);
let mut top_collector_limit_3: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(3, NaturalComparator);
for id in &doc_ids_collection {
top_collector_limit_3.append_doc(*id, score);
}
@@ -1781,15 +1862,16 @@ mod bench {
#[bench]
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
let mut top_collector = TopNComputer::new_with_comparator(100, NaturalComparator);
let mut top_collector: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(100, NaturalComparator);
for i in 0..100 {
top_collector.append_doc(i, 0.8);
top_collector.append_doc(i as u32, 0.8);
}
b.iter(|| {
for i in 0..100 {
top_collector.append_doc(i, 0.8);
top_collector.append_doc(i as u32, 0.8);
}
});
}

View File

@@ -227,9 +227,6 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
ReferenceValueLeaf::IpAddr(_) => {
unimplemented!("IP address support in dynamic fields is not yet implemented")
}
ReferenceValueLeaf::Geometry(_) => {
unimplemented!("Geometry support in dynamic fields is not implemented")
}
},
ReferenceValue::Array(elements) => {
for val in elements {
@@ -409,7 +406,7 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_str("red");
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00sred")
assert_eq!(term.serialized_value_bytes(), b"color\x00sred".to_vec())
}
#[test]
@@ -419,8 +416,8 @@ mod tests {
term.append_type_and_fast_value(-4i64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc"
term.serialized_value_bytes(),
b"color\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc".to_vec()
)
}
@@ -431,8 +428,8 @@ mod tests {
term.append_type_and_fast_value(4u64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00u\x00\x00\x00\x00\x00\x00\x00\x04"
term.serialized_value_bytes(),
b"color\x00u\x00\x00\x00\x00\x00\x00\x00\x04".to_vec()
)
}
@@ -442,8 +439,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(4.0f64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00f\xc0\x10\x00\x00\x00\x00\x00\x00"
term.serialized_value_bytes(),
b"color\x00f\xc0\x10\x00\x00\x00\x00\x00\x00".to_vec()
)
}
@@ -453,8 +450,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(true);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00o\x00\x00\x00\x00\x00\x00\x00\x01"
term.serialized_value_bytes(),
b"color\x00o\x00\x00\x00\x00\x00\x00\x00\x01".to_vec()
)
}

View File

@@ -5,7 +5,7 @@ use std::ops::Range;
use common::{BinarySerializable, CountingWriter, HasLen, VInt};
use crate::directory::{FileSlice, TerminatingWrite, WritePtr};
use crate::schema::Field;
use crate::schema::{Field, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
#[derive(Eq, PartialEq, Hash, Copy, Ord, PartialOrd, Clone, Debug)]
@@ -167,10 +167,11 @@ impl CompositeFile {
.map(|byte_range| self.data.slice(byte_range.clone()))
}
pub fn space_usage(&self) -> PerFieldSpaceUsage {
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {
let mut field_usage = FieldUsage::empty(field_addr.field);
let field_name = schema.get_field_name(field_addr.field).to_string();
let mut field_usage = FieldUsage::empty(field_name);
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
fields.push(field_usage);
}

View File

@@ -1,3 +1,5 @@
mod file_watcher;
use std::collections::HashMap;
use std::fmt;
use std::fs::{self, File, OpenOptions};
@@ -7,6 +9,7 @@ use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock, Weak};
use common::StableDeref;
use file_watcher::FileWatcher;
use fs4::fs_std::FileExt;
#[cfg(all(feature = "mmap", unix))]
pub use memmap2::Advice;
@@ -18,7 +21,6 @@ use crate::core::META_FILEPATH;
use crate::directory::error::{
DeleteError, LockError, OpenDirectoryError, OpenReadError, OpenWriteError,
};
use crate::directory::file_watcher::FileWatcher;
use crate::directory::{
AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite,
WatchCallback, WatchHandle, WritePtr,

View File

@@ -5,7 +5,6 @@ mod mmap_directory;
mod directory;
mod directory_lock;
mod file_watcher;
pub mod footer;
mod managed_directory;
mod ram_directory;

View File

@@ -40,6 +40,8 @@ pub trait DocSet: Send {
/// of `DocSet` should support it.
///
/// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`.
///
/// `target` has to be larger or equal to `.doc()` when calling `seek`.
fn seek(&mut self, target: DocId) -> DocId {
let mut doc = self.doc();
debug_assert!(doc <= target);
@@ -49,6 +51,33 @@ pub trait DocSet: Send {
doc
}
/// Seeks to the target if possible and returns true if the target is in the DocSet.
///
/// DocSets that already have an efficient `seek` method don't need to implement
/// `seek_into_the_danger_zone`. All wrapper DocSets should forward
/// `seek_into_the_danger_zone` to the underlying DocSet.
///
/// ## API Behaviour
/// If `seek_into_the_danger_zone` is returning true, a call to `doc()` has to return target.
/// If `seek_into_the_danger_zone` is returning false, a call to `doc()` may return any doc
/// between the last doc that matched and target or a doc that is a valid next hit after
/// target. The DocSet is considered to be in an invalid state until
/// `seek_into_the_danger_zone` returns true again.
///
/// `target` needs to be equal or larger than `doc` when in a valid state.
///
/// Consecutive calls are not allowed to have decreasing `target` values.
///
/// # Warning
/// This is an advanced API used by intersection. The API contract is tricky, avoid using it.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
let current_doc = self.doc();
if current_doc < target {
self.seek(target);
}
self.doc() == target
}
/// Fills a given mutable buffer with the next doc ids from the
/// `DocSet`
///
@@ -94,6 +123,15 @@ pub trait DocSet: Send {
/// which would be the number of documents in the DocSet.
///
/// By default this returns `size_hint()`.
///
/// DocSets may have vastly different cost depending on their type,
/// e.g. an intersection with 10 hits is much cheaper than
/// a phrase search with 10 hits, since it needs to load positions.
///
/// ### Future Work
/// We may want to differentiate `DocSet` costs more more granular, e.g.
/// creation_cost, advance_cost, seek_cost on to get a good estimation
/// what query types to choose.
fn cost(&self) -> u64 {
self.size_hint() as u64
}
@@ -137,6 +175,10 @@ impl DocSet for &mut dyn DocSet {
(**self).seek(target)
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
(**self).seek_into_the_danger_zone(target)
}
fn doc(&self) -> u32 {
(**self).doc()
}
@@ -169,6 +211,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
unboxed.seek(target)
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.seek_into_the_danger_zone(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.fill_buffer(buffer)

View File

@@ -79,7 +79,7 @@ mod tests {
use std::ops::{Range, RangeInclusive};
use std::path::Path;
use columnar::StrColumn;
use columnar::{StrColumn, ValueRange};
use common::{ByteCount, DateTimePrecision, HasLen, TerminatingWrite};
use once_cell::sync::Lazy;
use rand::prelude::SliceRandom;
@@ -683,7 +683,7 @@ mod tests {
}
#[test]
fn test_datefastfield() {
fn test_datefastfield() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let date_field = schema_builder.add_date_field(
"date",
@@ -697,28 +697,22 @@ mod tests {
);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
let mut index_writer = index.writer_for_tests()?;
index_writer.set_merge_policy(Box::new(NoMergePolicy));
index_writer
.add_document(doc!(
date_field => DateTime::from_u64(1i64.to_u64()),
multi_date_field => DateTime::from_u64(2i64.to_u64()),
multi_date_field => DateTime::from_u64(3i64.to_u64())
))
.unwrap();
index_writer
.add_document(doc!(
date_field => DateTime::from_u64(4i64.to_u64())
))
.unwrap();
index_writer
.add_document(doc!(
multi_date_field => DateTime::from_u64(5i64.to_u64()),
multi_date_field => DateTime::from_u64(6i64.to_u64())
))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
index_writer.add_document(doc!(
date_field => DateTime::from_u64(1i64.to_u64()),
multi_date_field => DateTime::from_u64(2i64.to_u64()),
multi_date_field => DateTime::from_u64(3i64.to_u64())
))?;
index_writer.add_document(doc!(
date_field => DateTime::from_u64(4i64.to_u64())
))?;
index_writer.add_document(doc!(
multi_date_field => DateTime::from_u64(5i64.to_u64()),
multi_date_field => DateTime::from_u64(6i64.to_u64())
))?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
@@ -752,6 +746,7 @@ mod tests {
assert_eq!(dates[0].into_timestamp_nanos(), 5i64);
assert_eq!(dates[1].into_timestamp_nanos(), 6i64);
}
Ok(())
}
#[test]
@@ -949,7 +944,7 @@ mod tests {
let test_range = |range: RangeInclusive<u64>| {
let expected_count = numbers.iter().filter(|num| range.contains(*num)).count();
let mut vec = vec![];
field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec);
field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec);
assert_eq!(vec.len(), expected_count);
};
test_range(50..=50);
@@ -1027,7 +1022,7 @@ mod tests {
let test_range = |range: RangeInclusive<u64>| {
let expected_count = numbers.iter().filter(|num| range.contains(*num)).count();
let mut vec = vec![];
field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec);
field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec);
assert_eq!(vec.len(), expected_count);
};
let test_range_variant = |start, stop| {

View File

@@ -8,7 +8,7 @@ use columnar::{
};
use common::ByteCount;
use crate::core::json_utils::encode_column_name;
use crate::core::json_utils::{encode_column_name, json_path_sep_to_dot};
use crate::directory::FileSlice;
use crate::schema::{Field, FieldEntry, FieldType, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
@@ -39,19 +39,15 @@ impl FastFieldReaders {
self.resolve_column_name_given_default_field(column_name, default_field_opt)
}
pub(crate) fn space_usage(&self, schema: &Schema) -> io::Result<PerFieldSpaceUsage> {
pub(crate) fn space_usage(&self) -> io::Result<PerFieldSpaceUsage> {
let mut per_field_usages: Vec<FieldUsage> = Default::default();
for (field, field_entry) in schema.fields() {
let column_handles = self.columnar.read_columns(field_entry.name())?;
let num_bytes: ByteCount = column_handles
.iter()
.map(|column_handle| column_handle.num_bytes())
.sum();
let mut field_usage = FieldUsage::empty(field);
field_usage.add_field_idx(0, num_bytes);
for (mut field_name, column_handle) in self.columnar.iter_columns()? {
json_path_sep_to_dot(&mut field_name);
let space_usage = column_handle.space_usage()?;
let mut field_usage = FieldUsage::empty(field_name);
field_usage.set_column_usage(space_usage);
per_field_usages.push(field_usage);
}
// TODO fix space usage for JSON fields.
Ok(PerFieldSpaceUsage::new(per_field_usages))
}

View File

@@ -189,9 +189,6 @@ impl FastFieldsWriter {
.record_str(doc_id, field_name, &token.text);
}
}
ReferenceValueLeaf::Geometry(_) => {
panic!("Geometry fields should not be routed to fast field writer")
}
},
ReferenceValue::Array(val) => {
// TODO: Check this is the correct behaviour we want.
@@ -323,9 +320,6 @@ fn record_json_value_to_columnar_writer<'a, V: Value<'a>>(
"Pre-tokenized string support in dynamic fields is not yet implemented"
)
}
ReferenceValueLeaf::Geometry(_) => {
unimplemented!("Geometry support in dynamic fields is not yet implemented")
}
},
ReferenceValue::Array(elements) => {
for el in elements {

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use super::{fieldnorm_to_id, id_to_fieldnorm};
use crate::directory::{CompositeFile, FileSlice, OwnedBytes};
use crate::schema::Field;
use crate::schema::{Field, Schema};
use crate::space_usage::PerFieldSpaceUsage;
use crate::DocId;
@@ -37,8 +37,8 @@ impl FieldNormReaders {
}
/// Return a break down of the space usage per field.
pub fn space_usage(&self) -> PerFieldSpaceUsage {
self.data.space_usage()
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
self.data.space_usage(schema)
}
/// Returns a handle to inner file

View File

@@ -13,9 +13,9 @@ use crate::store::Compressor;
use crate::{Inventory, Opstamp, TrackedObject};
#[derive(Clone, Debug, Serialize, Deserialize)]
struct DeleteMeta {
pub struct DeleteMeta {
num_deleted_docs: u32,
opstamp: Opstamp,
pub opstamp: Opstamp,
}
#[derive(Clone, Default)]
@@ -142,7 +142,6 @@ impl SegmentMeta {
SegmentComponent::FastFields => ".fast".to_string(),
SegmentComponent::FieldNorms => ".fieldnorm".to_string(),
SegmentComponent::Delete => format!(".{}.del", self.delete_opstamp().unwrap_or(0)),
SegmentComponent::Spatial => ".spatial".to_string(),
});
PathBuf::from(path)
}
@@ -214,7 +213,7 @@ impl SegmentMeta {
struct InnerSegmentMeta {
segment_id: SegmentId,
max_doc: u32,
deletes: Option<DeleteMeta>,
pub deletes: Option<DeleteMeta>,
/// If you want to avoid the SegmentComponent::TempStore file to be covered by
/// garbage collection and deleted, set this to true. This is used during merge.
#[serde(skip)]
@@ -405,7 +404,10 @@ mod tests {
schema_builder.build()
};
let index_metas = IndexMeta {
index_settings: IndexSettings::default(),
index_settings: IndexSettings {
docstore_compression: Compressor::None,
..Default::default()
},
segments: Vec::new(),
schema,
opstamp: 0u64,
@@ -414,7 +416,7 @@ mod tests {
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
assert_eq!(
json,
r#"{"index_settings":{"docstore_compression":"lz4","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
);
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
@@ -495,6 +497,8 @@ mod tests {
#[test]
#[cfg(feature = "lz4-compression")]
fn test_index_settings_default() {
use crate::store::Compressor;
let mut index_settings = IndexSettings::default();
assert_eq!(
index_settings,

View File

@@ -46,7 +46,7 @@ impl Segment {
///
/// This method is only used when updating `max_doc` from 0
/// as we finalize a fresh new segment.
pub(crate) fn with_max_doc(self, max_doc: u32) -> Segment {
pub fn with_max_doc(self, max_doc: u32) -> Segment {
Segment {
index: self.index,
meta: self.meta.with_max_doc(max_doc),

View File

@@ -28,14 +28,12 @@ pub enum SegmentComponent {
/// Bitset describing which document of the segment is alive.
/// (It was representing deleted docs but changed to represent alive docs from v0.17)
Delete,
/// HUSH
Spatial,
}
impl SegmentComponent {
/// Iterates through the components.
pub fn iterator() -> slice::Iter<'static, SegmentComponent> {
static SEGMENT_COMPONENTS: [SegmentComponent; 9] = [
static SEGMENT_COMPONENTS: [SegmentComponent; 8] = [
SegmentComponent::Postings,
SegmentComponent::Positions,
SegmentComponent::FastFields,
@@ -44,7 +42,6 @@ impl SegmentComponent {
SegmentComponent::Store,
SegmentComponent::TempStore,
SegmentComponent::Delete,
SegmentComponent::Spatial,
];
SEGMENT_COMPONENTS.iter()
}

View File

@@ -14,7 +14,6 @@ use crate::index::{InvertedIndexReader, Segment, SegmentComponent, SegmentId};
use crate::json_utils::json_path_sep_to_dot;
use crate::schema::{Field, IndexRecordOption, Schema, Type};
use crate::space_usage::SegmentSpaceUsage;
use crate::spatial::reader::SpatialReaders;
use crate::store::StoreReader;
use crate::termdict::TermDictionary;
use crate::{DocId, Opstamp};
@@ -44,7 +43,6 @@ pub struct SegmentReader {
positions_composite: CompositeFile,
fast_fields_readers: FastFieldReaders,
fieldnorm_readers: FieldNormReaders,
spatial_readers: SpatialReaders,
store_file: FileSlice,
alive_bitset_opt: Option<AliveBitSet>,
@@ -94,11 +92,6 @@ impl SegmentReader {
&self.fast_fields_readers
}
/// HUSH
pub fn spatial_fields(&self) -> &SpatialReaders {
&self.spatial_readers
}
/// Accessor to the `FacetReader` associated with a given `Field`.
pub fn facet_reader(&self, field_name: &str) -> crate::Result<FacetReader> {
let schema = self.schema();
@@ -180,12 +173,6 @@ impl SegmentReader {
let fast_fields_readers = FastFieldReaders::open(fast_fields_data, schema.clone())?;
let fieldnorm_data = segment.open_read(SegmentComponent::FieldNorms)?;
let fieldnorm_readers = FieldNormReaders::open(fieldnorm_data)?;
let spatial_readers = if schema.contains_spatial_field() {
let spatial_data = segment.open_read(SegmentComponent::Spatial)?;
SpatialReaders::open(spatial_data)?
} else {
SpatialReaders::empty()
};
let original_bitset = if segment.meta().has_deletes() {
let alive_doc_file_slice = segment.open_read(SegmentComponent::Delete)?;
@@ -211,7 +198,6 @@ impl SegmentReader {
postings_composite,
fast_fields_readers,
fieldnorm_readers,
spatial_readers,
segment_id: segment.id(),
delete_opstamp: segment.meta().delete_opstamp(),
store_file,
@@ -469,12 +455,11 @@ impl SegmentReader {
pub fn space_usage(&self) -> io::Result<SegmentSpaceUsage> {
Ok(SegmentSpaceUsage::new(
self.num_docs(),
self.termdict_composite.space_usage(),
self.postings_composite.space_usage(),
self.positions_composite.space_usage(),
self.fast_fields_readers.space_usage(self.schema())?,
self.fieldnorm_readers.space_usage(),
self.spatial_readers.space_usage(),
self.termdict_composite.space_usage(self.schema()),
self.postings_composite.space_usage(self.schema()),
self.positions_composite.space_usage(self.schema()),
self.fast_fields_readers.space_usage()?,
self.fieldnorm_readers.space_usage(self.schema()),
self.get_store_reader(0)?.space_usage(),
self.alive_bitset_opt
.as_ref()

View File

@@ -4,38 +4,37 @@ use std::sync::{Arc, RwLock, Weak};
use super::operation::DeleteOperation;
use crate::Opstamp;
// The DeleteQueue is similar in conceptually to a multiple
// consumer single producer broadcast channel.
//
// All consumer will receive all messages.
//
// Consumer of the delete queue are holding a `DeleteCursor`,
// which points to a specific place of the `DeleteQueue`.
//
// New consumer can be created in two ways
// - calling `delete_queue.cursor()` returns a cursor, that will include all future delete operation
// (and some or none of the past operations... The client is in charge of checking the opstamps.).
// - cloning an existing cursor returns a new cursor, that is at the exact same position, and can
// now advance independently from the original cursor.
/// The DeleteQueue is similar in conceptually to a multiple
/// consumer single producer broadcast channel.
///
/// All consumer will receive all messages.
///
/// Consumer of the delete queue are holding a `DeleteCursor`,
/// which points to a specific place of the `DeleteQueue`.
///
/// New consumer can be created in two ways
/// - calling `delete_queue.cursor()` returns a cursor, that will include all future delete
/// operation (and some or none of the past operations... The client is in charge of checking the
/// opstamps.).
/// - cloning an existing cursor returns a new cursor, that is at the exact same position, and can
/// now advance independently from the original cursor.
#[derive(Default)]
struct InnerDeleteQueue {
writer: Vec<DeleteOperation>,
last_block: Weak<Block>,
}
#[derive(Clone)]
/// The delete queue is a linked list storing delete operations.
///
/// Several consumers can hold a reference to it. Delete operations
/// get dropped/gc'ed when no more consumers are holding a reference
/// to them.
#[derive(Clone, Default)]
pub struct DeleteQueue {
inner: Arc<RwLock<InnerDeleteQueue>>,
}
impl DeleteQueue {
// Creates a new delete queue.
pub fn new() -> DeleteQueue {
DeleteQueue {
inner: Arc::default(),
}
}
fn get_last_block(&self) -> Arc<Block> {
{
// try get the last block with simply acquiring the read lock.
@@ -58,10 +57,10 @@ impl DeleteQueue {
block
}
// Creates a new cursor that makes it possible to
// consume future delete operations.
//
// Past delete operations are not accessible.
/// Creates a new cursor that makes it possible to
/// consume future delete operations.
///
/// Past delete operations are not accessible.
pub fn cursor(&self) -> DeleteCursor {
let last_block = self.get_last_block();
let operations_len = last_block.operations.len();
@@ -71,7 +70,7 @@ impl DeleteQueue {
}
}
// Appends a new delete operations.
/// Appends a new delete operations.
pub fn push(&self, delete_operation: DeleteOperation) {
self.inner
.write()
@@ -169,6 +168,7 @@ struct Block {
next: NextBlock,
}
/// As we process delete operations, keeps track of our position.
#[derive(Clone)]
pub struct DeleteCursor {
block: Arc<Block>,
@@ -261,7 +261,7 @@ mod tests {
#[test]
fn test_deletequeue() {
let delete_queue = DeleteQueue::new();
let delete_queue = DeleteQueue::default();
let make_op = |i: usize| DeleteOperation {
opstamp: i as u64,

View File

@@ -128,7 +128,7 @@ fn compute_deleted_bitset(
/// is `==` target_opstamp.
/// For instance, there was no delete operation between the state of the `segment_entry` and
/// the `target_opstamp`, `segment_entry` is not updated.
pub(crate) fn advance_deletes(
pub fn advance_deletes(
mut segment: Segment,
segment_entry: &mut SegmentEntry,
target_opstamp: Opstamp,
@@ -303,7 +303,7 @@ impl<D: Document> IndexWriter<D> {
let (document_sender, document_receiver) =
crossbeam_channel::bounded(PIPELINE_MAX_SIZE_IN_DOCS);
let delete_queue = DeleteQueue::new();
let delete_queue = DeleteQueue::default();
let current_opstamp = index.load_metas()?.opstamp;

View File

@@ -3,21 +3,21 @@ use std::net::Ipv6Addr;
use columnar::MonotonicallyMappableToU128;
use crate::fastfield::FastValue;
use crate::schema::{Field, Type};
use crate::schema::Field;
/// Term represents the value that the token can take.
/// It's a serialized representation over different types.
/// IndexingTerm is used to represent a term during indexing.
/// It's a serialized representation over field and value.
///
/// It actually wraps a `Vec<u8>`. The first 5 bytes are metadata.
/// 4 bytes are the field id, and the last byte is the type.
/// It actually wraps a `Vec<u8>`. The first 4 bytes are the field.
///
/// The serialized value `ValueBytes` is considered everything after the 4 first bytes (term id).
/// We serialize the field, because we index everything in a single
/// global term dictionary during indexing.
#[derive(Clone)]
pub(crate) struct IndexingTerm<B = Vec<u8>>(B)
where B: AsRef<[u8]>;
/// The number of bytes used as metadata by `Term`.
const TERM_METADATA_LENGTH: usize = 5;
const TERM_METADATA_LENGTH: usize = 4;
impl IndexingTerm {
/// Create a new Term with a buffer with a given capacity.
@@ -31,10 +31,9 @@ impl IndexingTerm {
/// Use `clear_with_field_and_type` in that case.
///
/// Sets field and the type.
pub(crate) fn set_field_and_type(&mut self, field: Field, typ: Type) {
pub(crate) fn set_field(&mut self, field: Field) {
assert!(self.is_empty());
self.0[0..4].clone_from_slice(field.field_id().to_be_bytes().as_ref());
self.0[4] = typ.to_code();
}
/// Is empty if there are no value bytes.
@@ -42,10 +41,10 @@ impl IndexingTerm {
self.0.len() == TERM_METADATA_LENGTH
}
/// Removes the value_bytes and set the field and type code.
pub(crate) fn clear_with_field_and_type(&mut self, typ: Type, field: Field) {
/// Removes the value_bytes and set the field
pub(crate) fn clear_with_field(&mut self, field: Field) {
self.truncate_value_bytes(0);
self.set_field_and_type(field, typ);
self.set_field(field);
}
/// Sets a u64 value in the term.
@@ -122,6 +121,23 @@ impl IndexingTerm {
impl<B> IndexingTerm<B>
where B: AsRef<[u8]>
{
/// Wraps serialized term bytes.
///
/// The input buffer is expected to be the concatenation of the big endian encoded field id
/// followed by the serialized value bytes (type tag + payload).
#[inline]
pub fn wrap(serialized_term: B) -> IndexingTerm<B> {
debug_assert!(serialized_term.as_ref().len() >= TERM_METADATA_LENGTH);
IndexingTerm(serialized_term)
}
/// Returns the field this term belongs to.
#[inline]
pub fn field(&self) -> Field {
let field_id_bytes: [u8; 4] = self.0.as_ref()[..4].try_into().unwrap();
Field::from_field_id(u32::from_be_bytes(field_id_bytes))
}
/// Returns the serialized representation of Term.
/// This includes field_id, value type and value.
///
@@ -136,6 +152,7 @@ where B: AsRef<[u8]>
#[cfg(test)]
mod tests {
use super::IndexingTerm;
use crate::schema::*;
#[test]
@@ -143,42 +160,55 @@ mod tests {
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("text", STRING);
let title_field = schema_builder.add_text_field("title", STRING);
let term = Term::from_field_text(title_field, "test");
let mut term = IndexingTerm::with_capacity(0);
term.set_field(title_field);
term.set_bytes(b"test");
assert_eq!(term.field(), title_field);
assert_eq!(term.typ(), Type::Str);
assert_eq!(term.value().as_str(), Some("test"))
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01test".to_vec())
}
/// Size (in bytes) of the buffer of a fast value (u64, i64, f64, or date) term.
/// <field> + <type byte> + <value len>
///
/// - <field> is a big endian encoded u32 field id
/// - <type_byte>'s most significant bit expresses whether the term is a json term or not The
/// remaining 7 bits are used to encode the type of the value. If this is a JSON term, the
/// type is the type of the leaf of the json.
/// - <value> is, if this is not the json term, a binary representation specific to the type.
/// If it is a JSON Term, then it is prepended with the path that leads to this leaf value.
const FAST_VALUE_TERM_LEN: usize = 4 + 1 + 8;
const FAST_VALUE_TERM_LEN: usize = 4 + 8;
#[test]
pub fn test_term_u64() {
let mut schema_builder = Schema::builder();
let count_field = schema_builder.add_u64_field("count", INDEXED);
let term = Term::from_field_u64(count_field, 983u64);
let mut term = IndexingTerm::with_capacity(0);
term.set_field(count_field);
term.set_u64(983u64);
assert_eq!(term.field(), count_field);
assert_eq!(term.typ(), Type::U64);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
assert_eq!(term.value().as_u64(), Some(983u64))
}
#[test]
pub fn test_term_bool() {
let mut schema_builder = Schema::builder();
let bool_field = schema_builder.add_bool_field("bool", INDEXED);
let term = Term::from_field_bool(bool_field, true);
let term = {
let mut term = IndexingTerm::with_capacity(0);
term.set_field(bool_field);
term.set_bool(true);
term
};
assert_eq!(term.field(), bool_field);
assert_eq!(term.typ(), Type::Bool);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
assert_eq!(term.value().as_bool(), Some(true))
}
#[test]
pub fn indexing_term_wrap_extracts_field() {
let field = Field::from_field_id(7u32);
let mut term = IndexingTerm::with_capacity(0);
term.set_field(field);
term.append_bytes(b"abc");
let wrapped = IndexingTerm::wrap(term.serialized_term());
assert_eq!(wrapped.field(), field);
assert_eq!(wrapped.serialized_term(), term.serialized_term());
}
}

View File

@@ -1,5 +1,3 @@
use std::collections::HashMap;
use std::io::{BufWriter, Write};
use std::sync::Arc;
use columnar::{
@@ -8,7 +6,6 @@ use columnar::{
use common::ReadOnlyBitSet;
use itertools::Itertools;
use measure_time::debug_time;
use tempfile::NamedTempFile;
use crate::directory::WritePtr;
use crate::docset::{DocSet, TERMINATED};
@@ -20,8 +17,6 @@ use crate::indexer::doc_id_mapping::{MappingType, SegmentDocIdMapping};
use crate::indexer::SegmentSerializer;
use crate::postings::{InvertedIndexSerializer, Postings, SegmentPostings};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema};
use crate::spatial::bkd::LeafPageIterator;
use crate::spatial::triangle::Triangle;
use crate::store::StoreWriter;
use crate::termdict::{TermMerger, TermOrdinal};
use crate::{DocAddress, DocId, InvertedIndexReader};
@@ -175,7 +170,6 @@ impl IndexMerger {
let mut readers = vec![];
for (segment, new_alive_bitset_opt) in segments.iter().zip(alive_bitset_opt) {
if segment.meta().num_docs() > 0 {
dbg!("segment");
let reader =
SegmentReader::open_with_custom_alive_set(segment, new_alive_bitset_opt)?;
readers.push(reader);
@@ -526,89 +520,6 @@ impl IndexMerger {
Ok(())
}
fn write_spatial_fields(
&self,
serializer: &mut SegmentSerializer,
doc_id_mapping: &SegmentDocIdMapping,
) -> crate::Result<()> {
/// We need to rebuild a BKD-tree based off the list of triangles.
///
/// Because the data can be large, we do this by writing the sequence of triangles to
/// disk, and mmapping it as mutable slice, and calling the same code as what
/// is done for the segment serialization.
///
/// The OS is in charge of deciding how to handle its page cache.
/// This is the same as what would have happened with swapping,
/// except by explicitly mapping the file, the OS is more likely to
/// swap, the memory will not be accounted as anonymous memory,
/// swap space is reserved etc.
use crate::spatial::bkd::Segment;
let Some(mut spatial_serializer) = serializer.extract_spatial_serializer() else {
// The schema does not contain any spatial field.
return Ok(());
};
let mut segment_mappings: Vec<Vec<Option<DocId>>> = Vec::new();
for reader in &self.readers {
let max_doc = reader.max_doc();
segment_mappings.push(vec![None; max_doc as usize]);
}
for (new_doc_id, old_doc_addr) in doc_id_mapping.iter_old_doc_addrs().enumerate() {
segment_mappings[old_doc_addr.segment_ord as usize][old_doc_addr.doc_id as usize] =
Some(new_doc_id as DocId);
}
let mut temp_files: HashMap<Field, NamedTempFile> = HashMap::new();
for (field, field_entry) in self.schema.fields() {
if matches!(field_entry.field_type(), FieldType::Spatial(_)) {
temp_files.insert(field, NamedTempFile::new()?);
}
}
for (segment_ord, reader) in self.readers.iter().enumerate() {
for (field, temp_file) in &mut temp_files {
let mut buf_temp_file = BufWriter::new(temp_file);
let spatial_readers = reader.spatial_fields();
let Some(spatial_reader) = spatial_readers.get_field(*field)? else {
continue;
};
let segment = Segment::new(spatial_reader.get_bytes());
for triangle_result in LeafPageIterator::new(&segment) {
let triangles = triangle_result?;
for triangle in triangles {
if let Some(new_doc_id) =
segment_mappings[segment_ord][triangle.doc_id as usize]
{
// This is really just a temporary file, not meant to be portable, so we
// use native endianness here.
for &word in &triangle.words {
buf_temp_file.write_all(&word.to_ne_bytes())?;
}
buf_temp_file.write_all(&new_doc_id.to_ne_bytes())?;
}
}
}
buf_temp_file.flush()?;
// No need to fsync here. This file is not here for persistency.
}
}
for (field, temp_file) in temp_files {
// Memory map the triangle file.
use memmap2::MmapOptions;
let mmap = unsafe { MmapOptions::new().map_mut(temp_file.as_file())? };
// Cast to &[Triangle] slice
let triangle_count = mmap.len() / std::mem::size_of::<Triangle>();
let triangles = unsafe {
std::slice::from_raw_parts_mut(mmap.as_ptr() as *mut Triangle, triangle_count)
};
// Get spatial writer and rebuild block kd-tree.
spatial_serializer.serialize_field(field, triangles)?;
}
spatial_serializer.close()?;
Ok(())
}
/// Writes the merged segment by pushing information
/// to the `SegmentSerializer`.
///
@@ -633,10 +544,9 @@ impl IndexMerger {
debug!("write-storagefields");
self.write_storable_fields(serializer.get_store_writer())?;
debug!("write-spatialfields");
self.write_spatial_fields(&mut serializer, &doc_id_mapping)?;
debug!("write-fastfields");
self.write_fast_fields(serializer.get_fast_field_write(), doc_id_mapping)?;
debug!("close-serializer");
serializer.close()?;
Ok(self.max_doc)

View File

@@ -4,6 +4,7 @@
//! `IndexWriter` is the main entry point for that, which created from
//! [`Index::writer`](crate::Index::writer).
/// Delete queue implementation for broadcasting delete operations to consumers.
pub(crate) mod delete_queue;
pub(crate) mod path_to_unordered_id;
@@ -32,12 +33,11 @@ mod stamper;
use crossbeam_channel as channel;
use smallvec::SmallVec;
pub use self::index_writer::{IndexWriter, IndexWriterOptions};
pub use self::index_writer::{advance_deletes, IndexWriter, IndexWriterOptions};
pub use self::log_merge_policy::LogMergePolicy;
pub use self::merge_operation::MergeOperation;
pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy};
use self::operation::AddOperation;
pub use self::operation::UserOperation;
pub use self::operation::{AddOperation, DeleteOperation, UserOperation};
pub use self::prepared_commit::PreparedCommit;
pub use self::segment_entry::SegmentEntry;
pub(crate) use self::segment_serializer::SegmentSerializer;

View File

@@ -5,14 +5,20 @@ use crate::Opstamp;
/// Timestamped Delete operation.
pub struct DeleteOperation {
/// Operation stamp.
/// It is used to check whether the delete operation
/// applies to an added document operation.
pub opstamp: Opstamp,
/// Weight is used to define the set of documents to be deleted.
pub target: Box<dyn Weight>,
}
/// Timestamped Add operation.
#[derive(Eq, PartialEq, Debug)]
pub struct AddOperation<D: Document = TantivyDocument> {
/// Operation stamp.
pub opstamp: Opstamp,
/// Document to be added.
pub document: D,
}

View File

@@ -117,7 +117,7 @@ mod tests {
#[test]
fn test_segment_register() {
let inventory = SegmentMetaInventory::default();
let delete_queue = DeleteQueue::new();
let delete_queue = DeleteQueue::default();
let mut segment_register = SegmentRegister::default();
let segment_id_a = SegmentId::generate_random();

View File

@@ -4,7 +4,6 @@ use crate::directory::WritePtr;
use crate::fieldnorm::FieldNormsSerializer;
use crate::index::{Segment, SegmentComponent};
use crate::postings::InvertedIndexSerializer;
use crate::spatial::serializer::SpatialSerializer;
use crate::store::StoreWriter;
/// Segment serializer is in charge of laying out on disk
@@ -13,7 +12,6 @@ pub struct SegmentSerializer {
segment: Segment,
pub(crate) store_writer: StoreWriter,
fast_field_write: WritePtr,
spatial_serializer: Option<SpatialSerializer>,
fieldnorms_serializer: Option<FieldNormsSerializer>,
postings_serializer: InvertedIndexSerializer,
}
@@ -37,20 +35,11 @@ impl SegmentSerializer {
let fieldnorms_write = segment.open_write(SegmentComponent::FieldNorms)?;
let fieldnorms_serializer = FieldNormsSerializer::from_write(fieldnorms_write)?;
let spatial_serializer: Option<SpatialSerializer> =
if segment.schema().contains_spatial_field() {
let spatial_write = segment.open_write(SegmentComponent::Spatial)?;
Some(SpatialSerializer::from_write(spatial_write)?)
} else {
None
};
let postings_serializer = InvertedIndexSerializer::open(&mut segment)?;
Ok(SegmentSerializer {
segment,
store_writer,
fast_field_write,
spatial_serializer,
fieldnorms_serializer: Some(fieldnorms_serializer),
postings_serializer,
})
@@ -75,11 +64,6 @@ impl SegmentSerializer {
&mut self.fast_field_write
}
/// Accessor to the `SpatialSerializer`
pub fn extract_spatial_serializer(&mut self) -> Option<SpatialSerializer> {
self.spatial_serializer.take()
}
/// Extract the field norm serializer.
///
/// Note the fieldnorms serializer can only be extracted once.
@@ -97,9 +81,6 @@ impl SegmentSerializer {
if let Some(fieldnorms_serializer) = self.extract_fieldnorms_serializer() {
fieldnorms_serializer.close()?;
}
if let Some(spatial_serializer) = self.extract_spatial_serializer() {
spatial_serializer.close()?;
}
self.fast_field_write.terminate()?;
self.postings_serializer.close()?;
self.store_writer.close()?;

View File

@@ -16,7 +16,6 @@ use crate::postings::{
};
use crate::schema::document::{Document, Value};
use crate::schema::{FieldEntry, FieldType, Schema, DATE_TIME_PRECISION_INDEXED};
use crate::spatial::writer::SpatialWriter;
use crate::tokenizer::{FacetTokenizer, PreTokenizedStream, TextAnalyzer, Tokenizer};
use crate::{DocId, Opstamp, TantivyError};
@@ -53,7 +52,6 @@ pub struct SegmentWriter {
pub(crate) segment_serializer: SegmentSerializer,
pub(crate) fast_field_writers: FastFieldsWriter,
pub(crate) fieldnorms_writer: FieldNormsWriter,
pub(crate) spatial_writer: SpatialWriter,
pub(crate) json_path_writer: JsonPathWriter,
pub(crate) json_positions_per_path: IndexingPositionsPerPath,
pub(crate) doc_opstamps: Vec<Opstamp>,
@@ -106,7 +104,6 @@ impl SegmentWriter {
ctx: IndexingContext::new(table_size),
per_field_postings_writers,
fieldnorms_writer: FieldNormsWriter::for_schema(&schema),
spatial_writer: SpatialWriter::default(),
json_path_writer: JsonPathWriter::default(),
json_positions_per_path: IndexingPositionsPerPath::default(),
segment_serializer,
@@ -133,7 +130,6 @@ impl SegmentWriter {
self.ctx,
self.fast_field_writers,
&self.fieldnorms_writer,
&mut self.spatial_writer,
self.segment_serializer,
)?;
Ok(self.doc_opstamps)
@@ -146,7 +142,6 @@ impl SegmentWriter {
+ self.fieldnorms_writer.mem_usage()
+ self.fast_field_writers.mem_usage()
+ self.segment_serializer.mem_usage()
+ self.spatial_writer.mem_usage()
}
fn index_document<D: Document>(&mut self, doc: &D) -> crate::Result<()> {
@@ -176,7 +171,7 @@ impl SegmentWriter {
let (term_buffer, ctx) = (&mut self.term_buffer, &mut self.ctx);
let postings_writer: &mut dyn PostingsWriter =
self.per_field_postings_writers.get_for_field_mut(field);
term_buffer.clear_with_field_and_type(field_entry.field_type().value_type(), field);
term_buffer.clear_with_field(field);
match field_entry.field_type() {
FieldType::Facet(_) => {
@@ -343,13 +338,6 @@ impl SegmentWriter {
self.fieldnorms_writer.record(doc_id, field, num_vals);
}
}
FieldType::Spatial(_) => {
for value in values {
if let Some(geometry) = value.as_geometry() {
self.spatial_writer.add_geometry(doc_id, field, *geometry);
}
}
}
}
}
Ok(())
@@ -404,16 +392,12 @@ fn remap_and_write(
ctx: IndexingContext,
fast_field_writers: FastFieldsWriter,
fieldnorms_writer: &FieldNormsWriter,
spatial_writer: &mut SpatialWriter,
mut serializer: SegmentSerializer,
) -> crate::Result<()> {
debug!("remap-and-write");
if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() {
fieldnorms_writer.serialize(fieldnorms_serializer)?;
}
if let Some(spatial_serializer) = serializer.extract_spatial_serializer() {
spatial_writer.serialize(spatial_serializer)?;
}
let fieldnorm_data = serializer
.segment()
.open_read(SegmentComponent::FieldNorms)?;
@@ -437,10 +421,9 @@ fn remap_and_write(
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use std::path::Path;
use columnar::ColumnType;
use tempfile::TempDir;
use crate::collector::{Count, TopDocs};
use crate::directory::RamDirectory;
@@ -1083,10 +1066,7 @@ mod tests {
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("title", text_options);
let schema = schema_builder.build();
let tempdir = TempDir::new().unwrap();
let tempdir_path = PathBuf::from(tempdir.path());
Index::create_in_dir(&tempdir_path, schema).unwrap();
let index = Index::open_in_dir(tempdir_path).unwrap();
let index = Index::create_in_ram(schema);
let schema = index.schema();
let mut index_writer = index.writer(50_000_000).unwrap();
let title = schema.get_field("title").unwrap();

View File

@@ -17,6 +17,7 @@
//!
//! ```rust
//! # use std::path::Path;
//! # use std::fs;
//! # use tempfile::TempDir;
//! # use tantivy::collector::TopDocs;
//! # use tantivy::query::QueryParser;
@@ -27,8 +28,11 @@
//! # // Let's create a temporary directory for the
//! # // sake of this example
//! # if let Ok(dir) = TempDir::new() {
//! # run_example(dir.path()).unwrap();
//! # dir.close().unwrap();
//! # let index_path = dir.path().join("index");
//! # // In case the directory already exists, we remove it
//! # let _ = fs::remove_dir_all(&index_path);
//! # fs::create_dir_all(&index_path).unwrap();
//! # run_example(&index_path).unwrap();
//! # }
//! # }
//! #
@@ -191,7 +195,6 @@ pub mod fieldnorm;
pub mod index;
pub mod positions;
pub mod postings;
pub mod spatial;
/// Module containing the different query implementations.
pub mod query;
@@ -204,6 +207,7 @@ mod docset;
mod reader;
#[cfg(test)]
#[cfg(feature = "mmap")]
mod compat_tests;
pub use self::reader::{IndexReader, IndexReaderBuilder, ReloadPolicy, Warmer};
@@ -217,9 +221,7 @@ use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
pub use self::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED};
#[doc(hidden)]
pub use crate::core::json_utils;
pub use crate::core::{Executor, Searcher, SearcherGeneration};
pub use crate::core::{json_utils, Executor, Searcher, SearcherGeneration};
pub use crate::directory::Directory;
pub use crate::index::{
Index, IndexBuilder, IndexMeta, IndexSettings, InvertedIndexReader, Order, Segment,
@@ -1173,12 +1175,11 @@ pub mod tests {
#[test]
fn test_validate_checksum() -> crate::Result<()> {
let index_path = tempfile::tempdir().expect("dir");
let mut builder = Schema::builder();
let body = builder.add_text_field("body", TEXT | STORED);
let schema = builder.build();
let index = Index::create_in_dir(&index_path, schema)?;
let mut writer: IndexWriter = index.writer(50_000_000)?;
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer_for_tests()?;
writer.set_merge_policy(Box::new(NoMergePolicy));
for _ in 0..5000 {
writer.add_document(doc!(body => "foo"))?;

View File

@@ -1,12 +1,15 @@
use bitpacking::{BitPacker, BitPacker4x};
use common::FixedSize;
pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN;
const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * u32::SIZE_IN_BYTES;
// in vint encoding, each byte stores 7 bits of data, so we need at most 32 / 7 = 4.57 bytes to
// store a u32 in the worst case, rounding up to 5 bytes total
const MAX_VINT_SIZE: usize = 5;
const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * MAX_VINT_SIZE;
mod vint;
/// Returns the size in bytes of a compressed block, given `num_bits`.
#[inline]
pub fn compressed_block_size(num_bits: u8) -> usize {
(num_bits as usize) * COMPRESSION_BLOCK_SIZE / 8
}
@@ -267,7 +270,6 @@ impl VIntDecoder for BlockDecoder {
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::TERMINATED;
@@ -372,6 +374,13 @@ pub(crate) mod tests {
}
}
}
#[test]
fn test_compress_vint_unsorted_does_not_overflow() {
let mut encoder = BlockEncoder::new();
let input: Vec<u32> = vec![u32::MAX; COMPRESSION_BLOCK_SIZE];
encoder.compress_vint_unsorted(&input);
}
}
#[cfg(all(test, feature = "unstable"))]

View File

@@ -8,7 +8,7 @@ use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::postings_writer::SpecializedPostingsWriter;
use crate::postings::recorder::{BufferLender, DocIdRecorder, Recorder};
use crate::postings::{FieldSerializer, IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::{Field, Type, ValueBytes};
use crate::schema::{Field, Type};
use crate::tokenizer::TokenStream;
use crate::DocId;
@@ -79,8 +79,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
term_buffer.truncate(term_path_len);
term_buffer.append_bytes(term);
let json_value = ValueBytes::wrap(term);
let typ = json_value.typ();
let typ = Type::from_code(term[0]).expect("Invalid type code in JSON term");
if typ == Type::Str {
SpecializedPostingsWriter::<Rec>::serialize_one_term(
term_buffer.as_bytes(),
@@ -107,6 +106,8 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
}
}
/// Helper to build the JSON term bytes that land in the term dictionary.
/// Format: `[json path utf8][JSON_END_OF_PATH][type tag][payload]`
struct JsonTermSerializer(Vec<u8>);
impl JsonTermSerializer {
/// Appends a JSON path to the Term.

View File

@@ -527,6 +527,7 @@ pub(crate) mod tests {
}
impl<TScorer: Scorer> Scorer for UnoptimizedDocSet<TScorer> {
#[inline]
fn score(&mut self) -> Score {
self.0.score()
}

View File

@@ -51,7 +51,6 @@ fn posting_writer_from_field_entry(field_entry: &FieldEntry) -> Box<dyn Postings
| FieldType::Date(_)
| FieldType::Bytes(_)
| FieldType::IpAddr(_)
| FieldType::Spatial(_)
| FieldType::Facet(_) => Box::<SpecializedPostingsWriter<DocIdRecorder>>::default(),
FieldType::JsonObject(ref json_object_options) => {
if let Some(text_indexing_option) = json_object_options.get_text_indexing_options() {

View File

@@ -11,7 +11,7 @@ use crate::postings::recorder::{BufferLender, Recorder};
use crate::postings::{
FieldSerializer, IndexingContext, InvertedIndexSerializer, PerFieldPostingsWriter,
};
use crate::schema::{Field, Schema, Term, Type};
use crate::schema::{Field, Schema, Type};
use crate::tokenizer::{Token, TokenStream, MAX_TOKEN_LEN};
use crate::DocId;
@@ -59,14 +59,14 @@ pub(crate) fn serialize_postings(
let mut term_offsets: Vec<(Field, OrderedPathId, &[u8], Addr)> =
Vec::with_capacity(ctx.term_index.len());
term_offsets.extend(ctx.term_index.iter().map(|(key, addr)| {
let field = Term::wrap(key).field();
let field = IndexingTerm::wrap(key).field();
if schema.get_field_entry(field).field_type().value_type() == Type::Json {
let byte_range_path = 5..5 + 4;
let byte_range_path = 4..4 + 4;
let unordered_id = u32::from_be_bytes(key[byte_range_path.clone()].try_into().unwrap());
let path_id = unordered_id_to_ordered_id[unordered_id as usize];
(field, path_id, &key[byte_range_path.end..], addr)
} else {
(field, 0.into(), &key[5..], addr)
(field, 0.into(), &key[4..], addr)
}
}));
// Sort by field, path, and term

View File

@@ -6,17 +6,21 @@ use crate::{DocId, Score, TERMINATED};
// doc num bits uses the following encoding:
// given 0b a b cdefgh
// |1|2| 3 |
// |1|2|3| 4 |
// - 1: unused
// - 2: is delta-1 encoded. 0 if not, 1, if yes
// - 3: a 6 bit number in 0..=32, the actual bitwidth
// - 3: unused
// - 4: a 5 bit number in 0..32, the actual bitwidth. Bitpacking could in theory say this is 32
// (requiring a 6th bit), but the biggest doc_id we can want to encode is TERMINATED-1, which can
// be represented on 31b without delta encoding.
fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 {
assert!(bitwidth < 32);
bitwidth | ((delta_1 as u8) << 6)
}
fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) {
let delta_1 = ((raw_bitwidth >> 6) & 1) != 0;
let bitwidth = raw_bitwidth & 0x3f;
let bitwidth = raw_bitwidth & 0x1f;
(bitwidth, delta_1)
}
@@ -430,7 +434,7 @@ mod tests {
#[test]
fn test_encode_decode_bitwidth() {
for bitwidth in 0..=32 {
for bitwidth in 0..32 {
for delta_1 in [false, true] {
assert_eq!(
(bitwidth, delta_1),

View File

@@ -23,7 +23,11 @@ pub struct AllWeight;
impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer::new(reader.max_doc());
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
if boost != 1.0 {
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
} else {
Ok(Box::new(all_scorer))
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
@@ -58,6 +62,15 @@ impl DocSet for AllScorer {
self.doc
}
fn seek(&mut self, target: DocId) -> DocId {
debug_assert!(target >= self.doc);
self.doc = target;
if self.doc >= self.max_doc {
self.doc = TERMINATED;
}
self.doc
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
if self.doc() == TERMINATED {
return 0;
@@ -92,6 +105,7 @@ impl DocSet for AllScorer {
}
impl Scorer for AllScorer {
#[inline]
fn score(&mut self) -> Score {
1.0
}

View File

@@ -483,7 +483,7 @@ mod tests {
let checkpoints_for_each_pruning =
compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k);
let checkpoints_manual =
compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000);
compute_checkpoints_manual(term_scorers.clone(), top_k, max_doc as u32);
assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len());
for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning
.iter()

View File

@@ -97,6 +97,65 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
}
}
/// Returns the effective MUST scorer, accounting for removed AllScorers.
///
/// When AllScorer instances are removed from must_scorers as an optimization,
/// we must restore the "match all" semantics if the list becomes empty.
fn effective_must_scorer(
must_scorers: Vec<Box<dyn Scorer>>,
removed_all_scorer_count: usize,
max_doc: DocId,
num_docs: u32,
) -> Option<Box<dyn Scorer>> {
if must_scorers.is_empty() {
if removed_all_scorer_count > 0 {
// Had AllScorer(s) only - all docs match
Some(Box::new(AllScorer::new(max_doc)))
} else {
// No MUST constraint at all
None
}
} else {
Some(intersect_scorers(must_scorers, num_docs))
}
}
/// Returns a SHOULD scorer with AllScorer union if any were removed.
///
/// For union semantics (OR): if any SHOULD clause was an AllScorer, the result
/// should include all documents. We restore this by unioning with AllScorer.
///
/// When `scoring_enabled` is false, we can just return AllScorer alone since
/// we don't need score contributions from the should_scorer.
fn effective_should_scorer_for_union<TScoreCombiner: ScoreCombiner>(
should_scorer: SpecializedScorer,
removed_all_scorer_count: usize,
max_doc: DocId,
num_docs: u32,
score_combiner_fn: impl Fn() -> TScoreCombiner,
scoring_enabled: bool,
) -> SpecializedScorer {
if removed_all_scorer_count > 0 {
if scoring_enabled {
// Need to union to get score contributions from both
let all_scorers: Vec<Box<dyn Scorer>> = vec![
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
Box::new(AllScorer::new(max_doc)),
];
SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
all_scorers,
score_combiner_fn,
num_docs,
)))
} else {
// Scoring disabled - AllScorer alone is sufficient
SpecializedScorer::Other(Box::new(AllScorer::new(max_doc)))
}
} else {
should_scorer
}
}
enum ShouldScorersCombinationMethod {
// Should scorers are irrelevant.
Ignored,
@@ -193,18 +252,18 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
let minimum_number_should_match = self
let effective_minimum_number_should_match = self
.minimum_number_should_match
.saturating_sub(should_special_scorer_counts.num_all_scorers);
let should_scorers: ShouldScorersCombinationMethod = {
let num_of_should_scorers = should_scorers.len();
if minimum_number_should_match > num_of_should_scorers {
if effective_minimum_number_should_match > num_of_should_scorers {
// We don't have enough scorers to satisfy the minimum number of should matches.
// The request will match no documents.
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match minimum_number_should_match {
match effective_minimum_number_should_match {
0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored,
0 => ShouldScorersCombinationMethod::Optional(scorer_union(
should_scorers,
@@ -226,7 +285,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
scorer_disjunction(
should_scorers,
score_combiner_fn(),
self.minimum_number_should_match,
effective_minimum_number_should_match,
),
)),
}
@@ -246,53 +305,78 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
let include_scorer = match (should_scorers, must_scorers) {
(ShouldScorersCombinationMethod::Ignored, must_scorers) => {
let boxed_scorer: Box<dyn Scorer> = if must_scorers.is_empty() {
// We do not have any should scorers, nor all scorers.
// There are still two cases here.
//
// If this follows the removal of some AllScorers in the should/must clauses,
// then we match all documents.
//
// Otherwise, it is really just an EmptyScorer.
if must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers
> 0
{
Box::new(AllScorer::new(reader.max_doc()))
} else {
Box::new(EmptyScorer)
}
} else {
intersect_scorers(must_scorers, num_docs)
};
// No SHOULD clauses (or they were absorbed into MUST).
// Result depends entirely on MUST + any removed AllScorers.
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers;
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
if must_scorers.is_empty() && must_special_scorer_counts.num_all_scorers == 0 {
// Optional options are promoted to required if no must scorers exists.
should_scorer
} else {
let must_scorer = intersect_scorers(must_scorers, num_docs);
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
_,
_,
TScoreCombiner,
>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
)))
} else {
SpecializedScorer::Other(must_scorer)
// Optional SHOULD: contributes to scoring but not required for matching.
match effective_must_scorer(
must_scorers,
must_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
) {
None => {
// No MUST constraint: promote SHOULD to required.
// Must preserve any removed AllScorers from SHOULD via union.
effective_should_scorer_for_union(
should_scorer,
should_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
&score_combiner_fn,
self.scoring_enabled,
)
}
Some(must_scorer) => {
// Has MUST constraint: SHOULD only affects scoring.
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
_,
_,
TScoreCombiner,
>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
)))
} else {
SpecializedScorer::Other(must_scorer)
}
}
}
}
(ShouldScorersCombinationMethod::Required(should_scorer), mut must_scorers) => {
if must_scorers.is_empty() {
should_scorer
} else {
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
(ShouldScorersCombinationMethod::Required(should_scorer), must_scorers) => {
// Required SHOULD: at least `minimum_number_should_match` must match.
// Semantics: (MUST constraint) AND (SHOULD constraint)
match effective_must_scorer(
must_scorers,
must_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
) {
None => {
// No MUST constraint: SHOULD alone determines matching.
should_scorer
}
Some(must_scorer) => {
// Has MUST constraint: intersect MUST with SHOULD.
let should_boxed =
into_box_scorer(should_scorer, &score_combiner_fn, num_docs);
SpecializedScorer::Other(intersect_scorers(
vec![must_scorer, should_boxed],
num_docs,
))
}
}
}
};

View File

@@ -9,12 +9,14 @@ pub use self::boolean_weight::BooleanWeight;
#[cfg(test)]
mod tests {
use std::ops::Bound;
use super::*;
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
use crate::collector::TopDocs;
use crate::collector::{Count, TopDocs};
use crate::query::term_query::TermScorer;
use crate::query::{
AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser,
AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, RangeQuery,
RequiredOptionalScorer, Scorer, SumCombiner, TermQuery,
};
use crate::schema::*;
@@ -374,4 +376,466 @@ mod tests {
}
Ok(())
}
#[test]
pub fn test_min_should_match_with_all_query() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let effective_all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "apple"),
IndexRecordOption::Basic,
));
// in some previous version, we would remove the 2 all_match, but then say we need *4*
// matches out of the 3 term queries, which matches nothing.
let mut bool_query = BooleanQuery::new(vec![
(Occur::Should, effective_all_match_query.box_clone()),
(Occur::Should, effective_all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
(Occur::Should, term_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
bool_query.set_minimum_number_should_match(4);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 1);
Ok(())
}
// =========================================================================
// AllScorer Preservation Regression Tests
// =========================================================================
//
// These tests verify the fix for a bug where AllScorer instances (produced by
// queries matching all documents, such as range queries covering all values)
// were incorrectly removed from Boolean query processing, causing documents
// to be unexpectedly excluded from results.
//
// The bug manifested in several scenarios:
// 1. SHOULD + SHOULD where one clause is AllScorer
// 2. MUST (AllScorer) + SHOULD
// 3. Range queries in Boolean clauses when all documents match the range
/// Regression test: SHOULD clause with AllScorer combined with other SHOULD clauses.
///
/// When a SHOULD clause produces an AllScorer (e.g., from a range query matching
/// all documents), the Boolean query should still match all documents.
///
/// Bug before fix: AllScorer was removed during optimization, leaving only the
/// other SHOULD clauses, which incorrectly excluded documents.
#[test]
pub fn test_should_with_all_scorer_regression() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range query will return AllScorer
index_writer.add_document(doc!(text_field => "hello", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "world", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "hello world", num_field => 30i64))?;
index_writer.add_document(doc!(text_field => "foo", num_field => 40i64))?;
index_writer.add_document(doc!(text_field => "bar", num_field => 50i64))?;
index_writer.add_document(doc!(text_field => "baz", num_field => 60i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Range query matching all docs (returns AllScorer)
let all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
));
// Verify range matches all 6 docs
assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 6);
// RangeQuery(all) OR TermQuery should match all 6 docs
let bool_query = BooleanQuery::new(vec![
(Occur::Should, all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 6, "SHOULD with AllScorer should match all docs");
// Order should not matter
let bool_query_reversed = BooleanQuery::new(vec![
(Occur::Should, term_query.box_clone()),
(Occur::Should, all_match_query.box_clone()),
]);
let count_reversed = searcher.search(&bool_query_reversed, &Count)?;
assert_eq!(
count_reversed, 6,
"Order of SHOULD clauses should not matter"
);
Ok(())
}
/// Regression test: MUST clause with AllScorer combined with SHOULD clause.
///
/// When MUST contains an AllScorer, all documents satisfy the MUST constraint.
/// The SHOULD clause should only affect scoring, not filtering.
///
/// Bug before fix: AllScorer was removed, leaving an empty must_scorers vector.
/// intersect_scorers([]) incorrectly returned EmptyScorer, matching 0 documents.
#[test]
pub fn test_must_all_with_should_regression() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range query will return AllScorer
index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "cherry", num_field => 30i64))?;
index_writer.add_document(doc!(text_field => "date", num_field => 40i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Range query matching all docs (returns AllScorer)
let all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "apple"),
IndexRecordOption::Basic,
));
// Verify range matches all 4 docs
assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 4);
// MUST(range matching all) AND SHOULD(term) should match all 4 docs
let bool_query = BooleanQuery::new(vec![
(Occur::Must, all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 4, "MUST AllScorer + SHOULD should match all docs");
Ok(())
}
/// Regression test: Range queries in Boolean clauses when all documents match.
///
/// Range queries can return AllScorer as an optimization when all indexed values
/// fall within the range. This test ensures such queries work correctly in
/// Boolean combinations.
///
/// This is the most common real-world manifestation of the bug, occurring in
/// queries like: (age > 50 OR name = 'Alice') AND status = 'active'
/// when all documents have age > 50.
#[test]
pub fn test_range_query_all_match_in_boolean() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let name_field = schema_builder.add_text_field("name", TEXT);
let age_field =
schema_builder.add_i64_field("age", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All documents have age > 50, so range query will return AllScorer
index_writer.add_document(doc!(name_field => "alice", age_field => 55_i64))?;
index_writer.add_document(doc!(name_field => "bob", age_field => 60_i64))?;
index_writer.add_document(doc!(name_field => "charlie", age_field => 70_i64))?;
index_writer.add_document(doc!(name_field => "diana", age_field => 80_i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let range_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(age_field, 50)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(name_field, "alice"),
IndexRecordOption::Basic,
));
// Verify preconditions
assert_eq!(searcher.search(range_query.as_ref(), &Count)?, 4);
assert_eq!(searcher.search(term_query.as_ref(), &Count)?, 1);
// SHOULD(range) OR SHOULD(term): range matches all, so result is 4
let should_query = BooleanQuery::new(vec![
(Occur::Should, range_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&should_query, &Count)?,
4,
"SHOULD range OR term should match all"
);
// MUST(range) AND SHOULD(term): range matches all, term is optional
let must_should_query = BooleanQuery::new(vec![
(Occur::Must, range_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&must_should_query, &Count)?,
4,
"MUST range + SHOULD term should match all"
);
Ok(())
}
/// Test multiple AllScorer instances in different clause types.
///
/// Verifies correct behavior when AllScorers appear in multiple positions.
#[test]
pub fn test_multiple_all_scorers() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range queries will return AllScorer
index_writer.add_document(doc!(text_field => "doc1", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "doc2", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "doc3", num_field => 30i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Two different range queries that both match all docs (return AllScorer)
let all_query1: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let all_query2: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 5)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "doc1"),
IndexRecordOption::Basic,
));
// Multiple AllScorers in SHOULD
let multi_all_should = BooleanQuery::new(vec![
(Occur::Should, all_query1.box_clone()),
(Occur::Should, all_query2.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&multi_all_should, &Count)?,
3,
"Multiple AllScorers in SHOULD"
);
// AllScorer in both MUST and SHOULD
let all_must_and_should = BooleanQuery::new(vec![
(Occur::Must, all_query1.box_clone()),
(Occur::Should, all_query2.box_clone()),
]);
assert_eq!(
searcher.search(&all_must_and_should, &Count)?,
3,
"AllScorer in both MUST and SHOULD"
);
Ok(())
}
}
/// A proptest which generates arbitrary permutations of a simple boolean AST, and then matches
/// the result against an index which contains all permutations of documents with N fields.
#[cfg(test)]
mod proptest_boolean_query {
use std::collections::{BTreeMap, HashSet};
use std::ops::{Bound, Range};
use proptest::collection::vec;
use proptest::prelude::*;
use crate::collector::DocSetCollector;
use crate::query::{AllQuery, BooleanQuery, Occur, Query, RangeQuery, TermQuery};
use crate::schema::{Field, NumericOptions, OwnedValue, Schema, TEXT};
use crate::{DocId, Index, Term};
#[derive(Debug, Clone)]
enum BooleanQueryAST {
/// Matches all documents via AllQuery (wraps AllScorer in BoostScorer)
All,
/// Matches all documents via RangeQuery (returns bare AllScorer)
/// This is the actual trigger for the AllScorer preservation bug
RangeAll,
/// Matches documents where the field has value "true"
Leaf {
field_idx: usize,
},
Union(Vec<BooleanQueryAST>),
Intersection(Vec<BooleanQueryAST>),
}
impl BooleanQueryAST {
fn matches(&self, doc_id: DocId) -> bool {
match self {
BooleanQueryAST::All => true,
BooleanQueryAST::RangeAll => true,
BooleanQueryAST::Leaf { field_idx } => Self::matches_field(doc_id, *field_idx),
BooleanQueryAST::Union(children) => {
children.iter().any(|child| child.matches(doc_id))
}
BooleanQueryAST::Intersection(children) => {
children.iter().all(|child| child.matches(doc_id))
}
}
}
fn matches_field(doc_id: DocId, field_idx: usize) -> bool {
((doc_id as usize) >> field_idx) & 1 == 1
}
fn to_query(&self, fields: &[Field], range_field: Field) -> Box<dyn Query> {
match self {
BooleanQueryAST::All => Box::new(AllQuery),
BooleanQueryAST::RangeAll => {
// Range query that matches all docs (all have value >= 0)
// This returns bare AllScorer, triggering the bug we fixed
Box::new(RangeQuery::new(
Bound::Included(Term::from_field_i64(range_field, 0)),
Bound::Unbounded,
))
}
BooleanQueryAST::Leaf { field_idx } => Box::new(TermQuery::new(
Term::from_field_text(fields[*field_idx], "true"),
crate::schema::IndexRecordOption::Basic,
)),
BooleanQueryAST::Union(children) => {
let sub_queries = children
.iter()
.map(|child| (Occur::Should, child.to_query(fields, range_field)))
.collect();
Box::new(BooleanQuery::new(sub_queries))
}
BooleanQueryAST::Intersection(children) => {
let sub_queries = children
.iter()
.map(|child| (Occur::Must, child.to_query(fields, range_field)))
.collect();
Box::new(BooleanQuery::new(sub_queries))
}
}
}
}
fn doc_ids(num_docs: usize, num_fields: usize) -> Range<DocId> {
let permutations = 1 << num_fields;
let copies = (num_docs as f32 / permutations as f32).ceil() as u32;
0..(permutations * copies)
}
fn create_index_with_boolean_permutations(
num_docs: usize,
num_fields: usize,
) -> (Index, Vec<Field>, Field) {
let mut schema_builder = Schema::builder();
let fields: Vec<Field> = (0..num_fields)
.map(|i| schema_builder.add_text_field(&format!("field_{}", i), TEXT))
.collect();
// Add a numeric field for RangeQuery tests - all docs have value = doc_id
let range_field = schema_builder.add_i64_field(
"range_field",
NumericOptions::default().set_fast().set_indexed(),
);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
for doc_id in doc_ids(num_docs, num_fields) {
let mut doc: BTreeMap<_, OwnedValue> = BTreeMap::default();
for (field_idx, &field) in fields.iter().enumerate() {
if (doc_id >> field_idx) & 1 == 1 {
doc.insert(field, "true".into());
}
}
// All docs have non-negative values, so RangeQuery(>=0) matches all
doc.insert(range_field, (doc_id as i64).into());
writer.add_document(doc).unwrap();
}
writer.commit().unwrap();
(index, fields, range_field)
}
fn arb_boolean_query_ast(num_fields: usize) -> impl Strategy<Value = BooleanQueryAST> {
// Leaf strategies: term queries, AllQuery, and RangeQuery matching all docs
let leaf = prop_oneof![
(0..num_fields).prop_map(|field_idx| BooleanQueryAST::Leaf { field_idx }),
Just(BooleanQueryAST::All),
Just(BooleanQueryAST::RangeAll),
];
leaf.prop_recursive(
8, // 8 levels of recursion
256, // 256 nodes max
10, // 10 items per collection
|inner| {
prop_oneof![
vec(inner.clone(), 1..10).prop_map(BooleanQueryAST::Union),
vec(inner, 1..10).prop_map(BooleanQueryAST::Intersection),
]
},
)
}
#[test]
fn proptest_boolean_query() {
// In the presence of optimizations around buffering, it can take large numbers of
// documents to uncover some issues.
let num_fields = 8;
let num_docs = 1 << num_fields;
let (index, fields, range_field) =
create_index_with_boolean_permutations(num_docs, num_fields);
let searcher = index.reader().unwrap().searcher();
proptest!(|(ast in arb_boolean_query_ast(num_fields))| {
let query = ast.to_query(&fields, range_field);
let mut matching_docs = HashSet::new();
for doc_id in doc_ids(num_docs, num_fields) {
if ast.matches(doc_id as DocId) {
matching_docs.insert(doc_id as DocId);
}
}
let doc_addresses = searcher.search(&*query, &DocSetCollector).unwrap();
let result_docs: HashSet<DocId> =
doc_addresses.into_iter().map(|doc_address| doc_address.doc_id).collect();
prop_assert_eq!(result_docs, matching_docs);
});
}
}

View File

@@ -104,6 +104,9 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
fn seek(&mut self, target: DocId) -> DocId {
self.underlying.seek(target)
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
self.underlying.seek_into_the_danger_zone(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
self.underlying.fill_buffer(buffer)
@@ -131,6 +134,7 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
}
impl<S: Scorer> Scorer for BoostScorer<S> {
#[inline]
fn score(&mut self) -> Score {
self.underlying.score() * self.boost
}

View File

@@ -137,6 +137,7 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
}
impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
#[inline]
fn score(&mut self) -> Score {
self.score
}

View File

@@ -62,6 +62,16 @@ impl<T: Scorer> DocSet for ScorerWrapper<T> {
self.current_doc = doc_id;
doc_id
}
fn seek(&mut self, target: DocId) -> DocId {
let doc_id = self.scorer.seek(target);
self.current_doc = doc_id;
doc_id
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
let found = self.scorer.seek_into_the_danger_zone(target);
self.current_doc = self.scorer.doc();
found
}
fn doc(&self) -> DocId {
self.current_doc
@@ -163,6 +173,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
for Disjunction<TScorer, TScoreCombiner>
{
#[inline]
fn score(&mut self) -> Score {
self.current_score
}
@@ -297,6 +308,7 @@ mod tests {
}
impl Scorer for DummyScorer {
#[inline]
fn score(&mut self) -> Score {
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
}

View File

@@ -55,6 +55,7 @@ impl DocSet for EmptyScorer {
}
impl Scorer for EmptyScorer {
#[inline]
fn score(&mut self) -> Score {
0.0
}

View File

@@ -84,6 +84,7 @@ where
TScorer: Scorer,
TDocSetExclude: DocSet + 'static,
{
#[inline]
fn score(&mut self) -> Score {
self.underlying_docset.score()
}

View File

@@ -1,5 +1,5 @@
use super::size_hint::estimate_intersection;
use crate::docset::{DocSet, TERMINATED};
use crate::query::size_hint::estimate_intersection;
use crate::query::term_query::TermScorer;
use crate::query::{EmptyScorer, Scorer};
use crate::{DocId, Score};
@@ -12,6 +12,9 @@ use crate::{DocId, Score};
/// For better performance, the function uses a
/// specialized implementation if the two
/// shortest scorers are `TermScorer`s.
///
/// num_docs_segment is the number of documents in the segment. It is used for estimating the
/// `size_hint` of the intersection.
pub fn intersect_scorers(
mut scorers: Vec<Box<dyn Scorer>>,
num_docs_segment: u32,
@@ -102,35 +105,48 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
}
impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOtherDocSet> {
#[inline]
fn advance(&mut self) -> DocId {
let (left, right) = (&mut self.left, &mut self.right);
let mut candidate = left.advance();
if candidate == TERMINATED {
return TERMINATED;
}
'outer: loop {
loop {
// In the first part we look for a document in the intersection
// of the two rarest `DocSet` in the intersection.
loop {
let right_doc = right.seek(candidate);
candidate = left.seek(right_doc);
if candidate == right_doc {
if right.seek_into_the_danger_zone(candidate) {
break;
}
let right_doc = right.doc();
// TODO: Think about which value would make sense here
// It depends on the DocSet implementation, when a seek would outweigh an advance.
if right_doc > candidate.wrapping_add(100) {
candidate = left.seek(right_doc);
} else {
candidate = left.advance();
}
if candidate == TERMINATED {
return TERMINATED;
}
}
debug_assert_eq!(left.doc(), right.doc());
// test the remaining scorers;
for docset in self.others.iter_mut() {
let seek_doc = docset.seek(candidate);
if seek_doc > candidate {
candidate = left.seek(seek_doc);
continue 'outer;
}
// test the remaining scorers
if self
.others
.iter_mut()
.all(|docset| docset.seek_into_the_danger_zone(candidate))
{
debug_assert_eq!(candidate, self.left.doc());
debug_assert_eq!(candidate, self.right.doc());
debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate));
return candidate;
}
debug_assert_eq!(candidate, self.left.doc());
debug_assert_eq!(candidate, self.right.doc());
debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate));
return candidate;
candidate = left.advance();
}
}
@@ -146,6 +162,20 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
doc
}
/// Seeks to the target if necessary and checks if the target is an exact match.
///
/// Some implementations may choose to advance past the target if beneficial for performance.
/// The return value is `true` if the target is in the docset, and `false` otherwise.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
self.left.seek_into_the_danger_zone(target)
&& self.right.seek_into_the_danger_zone(target)
&& self
.others
.iter_mut()
.all(|docset| docset.seek_into_the_danger_zone(target))
}
#[inline]
fn doc(&self) -> DocId {
self.left.doc()
}
@@ -172,6 +202,7 @@ where
TScorer: Scorer,
TOtherScorer: Scorer,
{
#[inline]
fn score(&mut self) -> Score {
self.left.score()
+ self.right.score()
@@ -181,6 +212,8 @@ where
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use super::Intersection;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::tests::test_skip_against_unoptimized;
@@ -270,4 +303,38 @@ mod tests {
let intersection = Intersection::new(vec![a, b, c], 10);
assert_eq!(intersection.doc(), TERMINATED);
}
// Strategy to generate sorted and deduplicated vectors of u32 document IDs
fn sorted_deduped_vec(max_val: u32, max_size: usize) -> impl Strategy<Value = Vec<u32>> {
prop::collection::vec(0..max_val, 0..max_size).prop_map(|mut vec| {
vec.sort();
vec.dedup();
vec
})
}
proptest! {
#[test]
fn prop_test_intersection_consistency(
a in sorted_deduped_vec(100, 10),
b in sorted_deduped_vec(100, 10),
num_docs in 100u32..500u32
) {
let left = VecDocSet::from(a.clone());
let right = VecDocSet::from(b.clone());
let mut intersection = Intersection::new(vec![left, right], num_docs);
let expected: Vec<u32> = a.iter()
.cloned()
.filter(|doc| b.contains(doc))
.collect();
for expected_doc in expected {
assert_eq!(intersection.doc(), expected_doc);
intersection.advance();
}
assert_eq!(intersection.doc(), TERMINATED);
}
}
}

View File

@@ -24,7 +24,6 @@ mod reqopt_scorer;
mod scorer;
mod set_query;
mod size_hint;
mod spatial_query;
mod term_query;
mod union;
mod weight;
@@ -63,7 +62,6 @@ pub use self::reqopt_scorer::RequiredOptionalScorer;
pub use self::score_combiner::{DisjunctionMaxCombiner, ScoreCombiner, SumCombiner};
pub use self::scorer::Scorer;
pub use self::set_query::TermSetQuery;
pub use self::spatial_query::{SpatialQuery, SpatialQueryType};
pub use self::term_query::TermQuery;
pub use self::union::BufferedUnionScorer;
#[cfg(test)]
@@ -72,9 +70,83 @@ pub use self::weight::Weight;
#[cfg(test)]
mod tests {
use crate::collector::TopDocs;
use crate::query::phrase_query::tests::create_index;
use crate::query::QueryParser;
use crate::schema::{Schema, TEXT};
use crate::{Index, Term};
use crate::{DocAddress, Index, Term};
#[test]
pub fn test_mixed_intersection_and_union() -> crate::Result<()> {
let index = create_index(&["a b", "a c", "a b c", "b"])?;
let schema = index.schema();
let text_field = schema.get_field("text").unwrap();
let searcher = index.reader()?.searcher();
let do_search = |term: &str| {
let query = QueryParser::for_index(&index, vec![text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};
assert_eq!(do_search("a AND b"), vec![0, 2]);
assert_eq!(do_search("(a OR b) AND C"), vec![2, 1]);
// The intersection code has special code for more than 2 intersections
// left, right + others
// The will place the union in the "others" insersection to that seek_into_the_danger_zone
// is called
assert_eq!(
do_search("(a OR b) AND (c OR a) AND (b OR c)"),
vec![2, 1, 0]
);
Ok(())
}
#[test]
pub fn test_mixed_intersection_and_union_with_skip() -> crate::Result<()> {
// Test 4096 skip in BufferedUnionScorer
let mut data: Vec<&str> = Vec::new();
data.push("a b");
let zz_data = vec!["z z"; 5000];
data.extend_from_slice(&zz_data);
data.extend_from_slice(&["a c"]);
data.extend_from_slice(&zz_data);
data.extend_from_slice(&["a b c", "b"]);
let index = create_index(&data)?;
let schema = index.schema();
let text_field = schema.get_field("text").unwrap();
let searcher = index.reader()?.searcher();
let do_search = |term: &str| {
let query = QueryParser::for_index(&index, vec![text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};
assert_eq!(do_search("a AND b"), vec![0, 10002]);
assert_eq!(do_search("(a OR b) AND C"), vec![10002, 5001]);
// The intersection code has special code for more than 2 intersections
// left, right + others
// The will place the union in the "others" insersection to that seek_into_the_danger_zone
// is called
assert_eq!(
do_search("(a OR b) AND (c OR a) AND (b OR c)"),
vec![10002, 5001, 0]
);
Ok(())
}
#[test]
fn test_query_terms() {

View File

@@ -81,6 +81,7 @@ impl<TPostings: Postings> DocSet for PhraseKind<TPostings> {
}
impl<TPostings: Postings> Scorer for PhraseKind<TPostings> {
#[inline]
fn score(&mut self) -> Score {
match self {
PhraseKind::SinglePrefix { positions, .. } => {
@@ -193,6 +194,14 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
self.advance()
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
if self.phrase_scorer.seek_into_the_danger_zone(target) {
self.matches_prefix()
} else {
false
}
}
fn doc(&self) -> DocId {
self.phrase_scorer.doc()
}
@@ -207,6 +216,7 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
}
impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> {
#[inline]
fn score(&mut self) -> Score {
// TODO modify score??
self.phrase_scorer.score()

View File

@@ -382,8 +382,9 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
PostingsWithOffset::new(postings, (max_offset - offset) as u32)
})
.collect::<Vec<_>>();
let intersection_docset = Intersection::new(postings_with_offsets, num_docs);
let mut scorer = PhraseScorer {
intersection_docset: Intersection::new(postings_with_offsets, num_docs),
intersection_docset,
num_terms: num_docsets,
left_positions: Vec::with_capacity(100),
right_positions: Vec::with_capacity(100),
@@ -529,25 +530,40 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
self.advance()
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
debug_assert!(target >= self.doc());
if self.intersection_docset.seek_into_the_danger_zone(target) && self.phrase_match() {
return true;
}
false
}
fn doc(&self) -> DocId {
self.intersection_docset.doc()
}
fn size_hint(&self) -> u32 {
self.intersection_docset.size_hint()
// We adjust the intersection estimate, since actual phrase hits are much lower than where
// the all appear.
// The estimate should depend on average field length, e.g. if the field is really short
// a phrase hit is more likely
self.intersection_docset.size_hint() / (10 * self.num_terms as u32)
}
/// Returns a best-effort hint of the
/// cost to drive the docset.
fn cost(&self) -> u64 {
// Evaluating phrase matches is generally more expensive than simple term matches,
// as it requires loading and comparing positions. Use a conservative multiplier
// based on the number of terms.
// While determing a potential hit is cheap for phrases, evaluating an actual hit is
// expensive since it requires to load positions for a doc and check if they are next to
// each other.
// So the cost estimation would be the number of times we need to check if a doc is a hit *
// 10 * self.num_terms.
self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64
}
}
impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {
#[inline]
fn score(&mut self) -> Score {
let doc = self.doc();
let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);

View File

@@ -524,9 +524,6 @@ impl QueryParser {
let ip_v6 = IpAddr::from_str(phrase)?.into_ipv6_addr();
Ok(Term::from_field_ip_addr(field, ip_v6))
}
FieldType::Spatial(_) => Err(QueryParserError::UnsupportedQuery(
"Spatial queries are not yet supported in text query parser".to_string(),
)),
}
}
@@ -627,10 +624,6 @@ impl QueryParser {
let term = Term::from_field_ip_addr(field, ip_v6);
Ok(vec![LogicalLiteral::Term(term)])
}
FieldType::Spatial(_) => Err(QueryParserError::UnsupportedQuery(format!(
"Spatial queries are not yet supported for field '{}'",
field_name
))),
}
}

View File

@@ -1,7 +1,6 @@
use core::fmt::Debug;
use std::ops::RangeInclusive;
use columnar::Column;
use columnar::{Column, ValueRange};
use crate::{DocId, DocSet, TERMINATED};
@@ -41,7 +40,7 @@ impl VecCursor {
pub(crate) struct RangeDocSet<T> {
/// The range filter on the values.
value_range: RangeInclusive<T>,
value_range: ValueRange<T>,
column: Column<T>,
/// The next docid start range to fetch (inclusive).
next_fetch_start: u32,
@@ -61,7 +60,18 @@ pub(crate) struct RangeDocSet<T> {
const DEFAULT_FETCH_HORIZON: u32 = 128;
impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
pub(crate) fn new(value_range: RangeInclusive<T>, column: Column<T>) -> Self {
pub(crate) fn new(value_range: ValueRange<T>, column: Column<T>) -> Self {
if !value_range.intersects(column.min_value(), column.max_value()) {
return Self {
value_range,
column,
loaded_docs: VecCursor::new(),
next_fetch_start: TERMINATED,
fetch_horizon: DEFAULT_FETCH_HORIZON,
last_seek_pos_opt: None,
};
}
let mut range_docset = Self {
value_range,
column,
@@ -81,6 +91,9 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
/// Returns true if more data could be fetched
fn fetch_block(&mut self) {
if self.next_fetch_start >= self.column.num_docs() {
return;
}
const MAX_HORIZON: u32 = 100_000;
while self.loaded_docs.is_empty() {
let finished_to_end = self.fetch_horizon(self.fetch_horizon);
@@ -105,10 +118,10 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
fn fetch_horizon(&mut self, horizon: u32) -> bool {
let mut finished_to_end = false;
let limit = self.column.num_docs();
let mut end = self.next_fetch_start + horizon;
if end >= limit {
end = limit;
let num_docs = self.column.num_docs();
let mut fetch_end = self.next_fetch_start + horizon;
if fetch_end >= num_docs {
fetch_end = num_docs;
finished_to_end = true;
}
@@ -116,7 +129,7 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
let doc_buffer: &mut Vec<DocId> = self.loaded_docs.get_cleared_data();
self.column.get_docids_for_value_range(
self.value_range.clone(),
self.next_fetch_start..end,
self.next_fetch_start..fetch_end,
doc_buffer,
);
if let Some(last_doc) = last_doc {
@@ -124,7 +137,7 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
self.loaded_docs.next();
}
}
self.next_fetch_start = end;
self.next_fetch_start = fetch_end;
finished_to_end
}
@@ -136,9 +149,6 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
if let Some(docid) = self.loaded_docs.next() {
return docid;
}
if self.next_fetch_start >= self.column.num_docs() {
return TERMINATED;
}
self.fetch_block();
self.loaded_docs.current().unwrap_or(TERMINATED)
}
@@ -174,15 +184,25 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
}
fn size_hint(&self) -> u32 {
self.column.num_docs()
// TODO: Implement a better size hint
self.column.num_docs() / 10
}
/// Returns a best-effort hint of the
/// cost to drive the docset.
fn cost(&self) -> u64 {
// Advancing the docset is relatively expensive since it scans the column.
// Keep cost relative to a term query driver; use num_docs as baseline.
self.column.num_docs() as u64
// Advancing the docset is pretty expensive since it scans the whole column, there is no
// index currently (will change with an kd-tree)
// Since we use SIMD to scan the fast field range query we lower the cost a little bit,
// assuming that we hit 10% of the docs like in size_hint.
//
// If we would return a cost higher than num_docs, we would never choose ff range query as
// the driver in a DocSet, when intersecting a term query with a fast field. But
// it's the faster choice when the term query has a lot of docids and the range
// query has not.
//
// Ideally this would take the fast field codec into account
(self.column.num_docs() as f64 * 0.8) as u64
}
}
@@ -236,4 +256,52 @@ mod tests {
let count = searcher.search(&query, &Count).unwrap();
assert_eq!(count, 500);
}
#[test]
fn range_query_no_overlap_optimization() {
let mut schema_builder = schema::SchemaBuilder::new();
let id_field = schema_builder.add_text_field("id", schema::STRING);
let value_field = schema_builder.add_u64_field("value", schema::FAST | schema::INDEXED);
let dir = RamDirectory::default();
let index = IndexBuilder::new()
.schema(schema_builder.build())
.open_or_create(dir)
.unwrap();
{
let mut writer = index.writer(15_000_000).unwrap();
// Add documents with values in the range [10, 20]
for i in 0..100 {
let mut doc = TantivyDocument::new();
doc.add_text(id_field, format!("doc{i}"));
doc.add_u64(value_field, 10 + (i % 11) as u64); // values in range 10-20
writer.add_document(doc).unwrap();
}
writer.commit().unwrap();
}
let reader = index.reader().unwrap();
let searcher = reader.searcher();
// Test a range query [100, 200] that has no overlap with data range [10, 20]
let query = RangeQuery::new(
Bound::Included(Term::from_field_u64(value_field, 100)),
Bound::Included(Term::from_field_u64(value_field, 200)),
);
let count = searcher.search(&query, &Count).unwrap();
assert_eq!(count, 0); // should return 0 results since there's no overlap
// Test another non-overlapping range: [0, 5] while data range is [10, 20]
let query2 = RangeQuery::new(
Bound::Included(Term::from_field_u64(value_field, 0)),
Bound::Included(Term::from_field_u64(value_field, 5)),
);
let count2 = searcher.search(&query2, &Count).unwrap();
assert_eq!(count2, 0); // should return 0 results since there's no overlap
}
}

View File

@@ -20,6 +20,6 @@ pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
| Type::Date
| Type::Json
| Type::IpAddr => true,
Type::Facet | Type::Bytes | Type::Spatial => false,
Type::Facet | Type::Bytes => false,
}
}

View File

@@ -7,7 +7,7 @@ use std::ops::{Bound, RangeInclusive};
use columnar::{
Cardinality, Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
NumericalType, StrColumn,
NumericalType, StrColumn, ValueRange,
};
use common::bounds::{BoundsRange, TransformBound};
@@ -128,15 +128,12 @@ impl Weight for FastFieldRangeWeight {
BoundsRange::new(bounds.lower_bound, bounds.upper_bound),
)
}
Type::Bool
| Type::Facet
| Type::Bytes
| Type::Json
| Type::IpAddr
| Type::Spatial => Err(crate::TantivyError::InvalidArgument(format!(
"unsupported value bytes type in json term value_bytes {:?}",
term_value.typ()
))),
Type::Bool | Type::Facet | Type::Bytes | Type::Json | Type::IpAddr => {
Err(crate::TantivyError::InvalidArgument(format!(
"unsupported value bytes type in json term value_bytes {:?}",
term_value.typ()
)))
}
}
} else if field_type.is_ip_addr() {
let parse_ip_from_bytes = |term: &Term| {
@@ -157,7 +154,7 @@ impl Weight for FastFieldRangeWeight {
ip_addr_column.min_value(),
ip_addr_column.max_value(),
);
let docset = RangeDocSet::new(value_range, ip_addr_column);
let docset = RangeDocSet::new(ValueRange::Inclusive(value_range), ip_addr_column);
Ok(Box::new(ConstScorer::new(docset, boost)))
} else if field_type.is_str() {
let Some(str_dict_column): Option<StrColumn> = reader.fast_fields().str(&field_name)?
@@ -429,7 +426,7 @@ fn search_on_u64_ff(
}
}
let docset = RangeDocSet::new(value_range, column);
let docset = RangeDocSet::new(ValueRange::Inclusive(value_range), column);
Ok(Box::new(ConstScorer::new(docset, boost)))
}
@@ -438,7 +435,7 @@ pub(crate) fn maps_to_u64_fastfield(typ: Type) -> bool {
match typ {
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::IpAddr => false,
Type::Str | Type::Facet | Type::Bytes | Type::Json | Type::Spatial => false,
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
}
}
@@ -1601,449 +1598,3 @@ pub(crate) mod ip_range_tests {
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::tests::*;
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::Index;
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id_name = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"many".to_string() // 90%
};
Doc {
id_name,
id: rng.gen_range(0..100),
}
})
.collect();
create_index_from_docs(&docs, false)
}
fn get_90_percent() -> RangeInclusive<u64> {
0..=90
}
fn get_10_percent() -> RangeInclusive<u64> {
0..=10
}
fn get_1_percent() -> RangeInclusive<u64> {
10..=10
}
fn execute_query(
field: &str,
id_range: RangeInclusive<u64>,
suffix: &str,
index: &Index,
) -> usize {
let gen_query_inclusive = |from: &u64, to: &u64| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(id_range.start(), id_range.end());
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(&query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
searcher.search(&query, &(Count)).unwrap()
}
#[bench]
fn bench_id_range_hit_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_90_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_10_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_1_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:veryfew", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:veryfew", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_90_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_10_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_1_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:veryfew", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index));
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench_ip {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::ip_range_tests::*;
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::Index;
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"many".to_string() // 90%
};
Doc {
id,
// Multiply by 1000, so that we create many buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_ip_docs(&docs)
}
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn execute_query(
field: &str,
ip_range: RangeInclusive<Ipv6Addr>,
suffix: &str,
index: &Index,
) -> usize {
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(&query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
searcher.search(&query, &(Count)).unwrap()
}
#[bench]
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:veryfew", &index));
}
}

View File

@@ -56,6 +56,11 @@ where
self.req_scorer.seek(target)
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
self.score_cache = None;
self.req_scorer.seek_into_the_danger_zone(target)
}
fn doc(&self) -> DocId {
self.req_scorer.doc()
}
@@ -76,6 +81,7 @@ where
TOptScorer: Scorer,
TScoreCombiner: ScoreCombiner,
{
#[inline]
fn score(&mut self) -> Score {
if let Some(score) = self.score_cache {
return score;

View File

@@ -29,6 +29,7 @@ impl ScoreCombiner for DoNothingCombiner {
fn clear(&mut self) {}
#[inline]
fn score(&self) -> Score {
1.0
}
@@ -49,6 +50,7 @@ impl ScoreCombiner for SumCombiner {
self.score = 0.0;
}
#[inline]
fn score(&self) -> Score {
self.score
}
@@ -86,6 +88,7 @@ impl ScoreCombiner for DisjunctionMaxCombiner {
self.sum = 0.0;
}
#[inline]
fn score(&self) -> Score {
self.max + (self.sum - self.max) * self.tie_breaker
}

View File

@@ -18,6 +18,7 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
impl_downcast!(Scorer);
impl Scorer for Box<dyn Scorer> {
#[inline]
fn score(&mut self) -> Score {
self.deref_mut().score()
}

View File

@@ -1,186 +0,0 @@
//! HUSH
use common::BitSet;
use crate::query::explanation::does_not_match;
use crate::query::{BitSetDocSet, Explanation, Query, Scorer, Weight};
use crate::schema::Field;
use crate::spatial::bkd::{search_intersects, Segment};
use crate::spatial::point::GeoPoint;
use crate::spatial::writer::as_point_i32;
use crate::{DocId, DocSet, Score, TantivyError, TERMINATED};
#[derive(Clone, Copy, Debug)]
/// HUSH
pub enum SpatialQueryType {
/// HUSH
Intersects,
// Within,
// Contains,
}
#[derive(Clone, Copy, Debug)]
/// HUSH
pub struct SpatialQuery {
field: Field,
bounds: [(i32, i32); 2],
query_type: SpatialQueryType,
}
impl SpatialQuery {
/// HUSH
pub fn new(field: Field, bounds: [GeoPoint; 2], query_type: SpatialQueryType) -> Self {
SpatialQuery {
field,
bounds: [as_point_i32(bounds[0]), as_point_i32(bounds[1])],
query_type,
}
}
}
impl Query for SpatialQuery {
fn weight(
&self,
_enable_scoring: super::EnableScoring<'_>,
) -> crate::Result<Box<dyn super::Weight>> {
Ok(Box::new(SpatialWeight::new(
self.field,
self.bounds,
self.query_type,
)))
}
}
pub struct SpatialWeight {
field: Field,
bounds: [(i32, i32); 2],
query_type: SpatialQueryType,
}
impl SpatialWeight {
fn new(field: Field, bounds: [(i32, i32); 2], query_type: SpatialQueryType) -> Self {
SpatialWeight {
field,
bounds,
query_type,
}
}
}
impl Weight for SpatialWeight {
fn scorer(
&self,
reader: &crate::SegmentReader,
boost: crate::Score,
) -> crate::Result<Box<dyn super::Scorer>> {
let spatial_reader = reader
.spatial_fields()
.get_field(self.field)?
.ok_or_else(|| TantivyError::SchemaError(format!("No spatial data for field")))?;
let block_kd_tree = Segment::new(spatial_reader.get_bytes());
match self.query_type {
SpatialQueryType::Intersects => {
let mut include = BitSet::with_max_value(reader.max_doc());
search_intersects(
&block_kd_tree,
block_kd_tree.root_offset,
&[
self.bounds[0].1,
self.bounds[0].0,
self.bounds[1].1,
self.bounds[1].0,
],
&mut include,
)?;
Ok(Box::new(SpatialScorer::new(boost, include, None)))
}
}
}
fn explain(
&self,
reader: &crate::SegmentReader,
doc: DocId,
) -> crate::Result<super::Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
let query_type_desc = match self.query_type {
SpatialQueryType::Intersects => "SpatialQuery::Intersects",
};
let score = scorer.score();
let mut explanation = Explanation::new(query_type_desc, score);
explanation.add_context(format!(
"bounds: [({}, {}), ({}, {})]",
self.bounds[0].0, self.bounds[0].1, self.bounds[1].0, self.bounds[1].1,
));
explanation.add_context(format!("field: {:?}", self.field));
Ok(explanation)
}
}
struct SpatialScorer {
include: BitSetDocSet,
exclude: Option<BitSet>,
doc_id: DocId,
score: Score,
}
impl SpatialScorer {
pub fn new(score: Score, include: BitSet, exclude: Option<BitSet>) -> Self {
let mut scorer = SpatialScorer {
include: BitSetDocSet::from(include),
exclude,
doc_id: 0,
score,
};
scorer.prime();
scorer
}
fn prime(&mut self) {
self.doc_id = self.include.doc();
while self.exclude() {
self.doc_id = self.include.advance();
}
}
fn exclude(&self) -> bool {
if self.doc_id == TERMINATED {
return false;
}
match &self.exclude {
Some(exclude) => exclude.contains(self.doc_id),
None => false,
}
}
}
impl Scorer for SpatialScorer {
fn score(&mut self) -> Score {
self.score
}
}
impl DocSet for SpatialScorer {
fn advance(&mut self) -> DocId {
if self.doc_id == TERMINATED {
return TERMINATED;
}
self.doc_id = self.include.advance();
while self.exclude() {
self.doc_id = self.include.advance();
}
self.doc_id
}
fn size_hint(&self) -> u32 {
match &self.exclude {
Some(exclude) => self.include.size_hint() - exclude.len() as u32,
None => self.include.size_hint(),
}
}
fn doc(&self) -> DocId {
self.doc_id
}
}

View File

@@ -98,14 +98,17 @@ impl TermScorer {
}
impl DocSet for TermScorer {
#[inline]
fn advance(&mut self) -> DocId {
self.postings.advance()
}
#[inline]
fn seek(&mut self, target: DocId) -> DocId {
self.postings.seek(target)
}
#[inline]
fn doc(&self) -> DocId {
self.postings.doc()
}
@@ -116,6 +119,7 @@ impl DocSet for TermScorer {
}
impl Scorer for TermScorer {
#[inline]
fn score(&mut self) -> Score {
let fieldnorm_id = self.fieldnorm_id();
let term_freq = self.term_freq();

View File

@@ -15,7 +15,7 @@ const HORIZON: u32 = 64u32 * 64u32;
// This function is similar except that it does is not unstable, and
// it does not keep the original vector ordering.
//
// Also, it does not "yield" any elements.
// Elements are dropped and not yielded.
fn unordered_drain_filter<T, P>(v: &mut Vec<T>, mut predicate: P)
where P: FnMut(&mut T) -> bool {
let mut i = 0;
@@ -128,6 +128,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
}
}
#[inline]
fn advance_buffered(&mut self) -> bool {
while self.bucket_idx < HORIZON_NUM_TINYBITSETS {
if let Some(val) = self.bitsets[self.bucket_idx].pop_lowest() {
@@ -143,6 +144,12 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
}
false
}
fn is_in_horizon(&self, target: DocId) -> bool {
// wrapping_sub, because target may be < window_start_doc
let gap = target.wrapping_sub(self.window_start_doc);
gap < HORIZON
}
}
impl<TScorer, TScoreCombiner> DocSet for BufferedUnionScorer<TScorer, TScoreCombiner>
@@ -150,6 +157,7 @@ where
TScorer: Scorer,
TScoreCombiner: ScoreCombiner,
{
#[inline]
fn advance(&mut self) -> DocId {
if self.advance_buffered() {
return self.doc;
@@ -217,8 +225,29 @@ where
}
}
// TODO Also implement `count` with deletes efficiently.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
if self.is_in_horizon(target) {
// Our value is within the buffered horizon and the docset may already have been
// processed and removed, so we need to use seek, which uses the regular advance.
self.seek(target) == target
} else {
// The docsets are not in the buffered range, so we can use seek_into_the_danger_zone
// of the underlying docsets
let is_hit = self
.docsets
.iter_mut()
.any(|docset| docset.seek_into_the_danger_zone(target));
// The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone`
// returns true.
if is_hit {
self.seek(target);
}
is_hit
}
}
#[inline]
fn doc(&self) -> DocId {
self.doc
}
@@ -231,6 +260,7 @@ where
self.docsets.iter().map(|docset| docset.cost()).sum()
}
// TODO Also implement `count` with deletes efficiently.
fn count_including_deleted(&mut self) -> u32 {
if self.doc == TERMINATED {
return 0;
@@ -259,6 +289,7 @@ where
TScoreCombiner: ScoreCombiner,
TScorer: Scorer,
{
#[inline]
fn score(&mut self) -> Score {
self.score
}

View File

@@ -92,6 +92,7 @@ impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
}
fn size_hint(&self) -> u32 {
// TODO: use estimate_union
self.docsets
.iter()
.map(|docset| docset.size_hint())

View File

@@ -22,7 +22,6 @@ use super::se::BinaryObjectSerializer;
use super::{OwnedValue, Value};
use crate::schema::document::type_codes;
use crate::schema::{Facet, Field};
use crate::spatial::geometry::Geometry;
use crate::store::DocStoreVersion;
use crate::tokenizer::PreTokenizedString;
@@ -130,9 +129,6 @@ pub trait ValueDeserializer<'de> {
/// Attempts to deserialize a pre-tokenized string value from the deserializer.
fn deserialize_pre_tokenized_string(self) -> Result<PreTokenizedString, DeserializeError>;
/// HUSH
fn deserialize_geometry(self) -> Result<Geometry, DeserializeError>;
/// Attempts to deserialize the value using a given visitor.
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, DeserializeError>
where V: ValueVisitor;
@@ -170,8 +166,6 @@ pub enum ValueType {
/// A JSON object value. Deprecated.
#[deprecated(note = "We keep this for backwards compatibility, use Object instead")]
JSONObject,
/// HUSH
Geometry,
}
/// A value visitor for deserializing a document value.
@@ -252,12 +246,6 @@ pub trait ValueVisitor {
Err(DeserializeError::UnsupportedType(ValueType::PreTokStr))
}
#[inline]
/// Called when the deserializer visits a geometry value.
fn visit_geometry(&self, _val: Geometry) -> Result<Self::Value, DeserializeError> {
Err(DeserializeError::UnsupportedType(ValueType::Geometry))
}
#[inline]
/// Called when the deserializer visits an array.
fn visit_array<'de, A>(&self, _access: A) -> Result<Self::Value, DeserializeError>
@@ -392,7 +380,6 @@ where R: Read
match ext_type_code {
type_codes::TOK_STR_EXT_CODE => ValueType::PreTokStr,
type_codes::GEO_EXT_CODE => ValueType::Geometry,
_ => {
return Err(DeserializeError::from(io::Error::new(
io::ErrorKind::InvalidData,
@@ -508,11 +495,6 @@ where R: Read
.map_err(DeserializeError::from)
}
fn deserialize_geometry(self) -> Result<Geometry, DeserializeError> {
self.validate_type(ValueType::Geometry)?;
<Geometry as BinarySerializable>::deserialize(self.reader).map_err(DeserializeError::from)
}
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, DeserializeError>
where V: ValueVisitor {
match self.value_type {
@@ -557,10 +539,6 @@ where R: Read
let val = self.deserialize_pre_tokenized_string()?;
visitor.visit_pre_tokenized_string(val)
}
ValueType::Geometry => {
let val = self.deserialize_geometry()?;
visitor.visit_geometry(val)
}
ValueType::Array => {
let access =
BinaryArrayDeserializer::from_reader(self.reader, self.doc_store_version)?;

View File

@@ -13,7 +13,6 @@ use crate::schema::document::{
};
use crate::schema::field_type::ValueParsingError;
use crate::schema::{Facet, Field, NamedFieldDocument, OwnedValue, Schema};
use crate::spatial::geometry::Geometry;
use crate::tokenizer::PreTokenizedString;
#[repr(C, packed)]
@@ -255,7 +254,6 @@ impl CompactDoc {
}
ReferenceValueLeaf::IpAddr(num) => write_into(&mut self.node_data, num.to_u128()),
ReferenceValueLeaf::PreTokStr(pre_tok) => write_into(&mut self.node_data, *pre_tok),
ReferenceValueLeaf::Geometry(geometry) => write_into(&mut self.node_data, *geometry),
};
ValueAddr { type_id, val_addr }
}
@@ -466,12 +464,6 @@ impl<'a> CompactDocValue<'a> {
.map(Into::into)
.map(ReferenceValueLeaf::PreTokStr)
.map(Into::into),
ValueType::Geometry => self
.container
.read_from::<Geometry>(addr)
.map(Into::into)
.map(ReferenceValueLeaf::Geometry)
.map(Into::into),
ValueType::Object => Ok(ReferenceValue::Object(CompactDocObjectIter::new(
self.container,
addr,
@@ -550,8 +542,6 @@ pub enum ValueType {
Object = 11,
/// Pre-tokenized str type,
Array = 12,
/// HUSH
Geometry = 13,
}
impl BinarySerializable for ValueType {
@@ -597,7 +587,6 @@ impl<'a> From<&ReferenceValueLeaf<'a>> for ValueType {
ReferenceValueLeaf::PreTokStr(_) => ValueType::PreTokStr,
ReferenceValueLeaf::Facet(_) => ValueType::Facet,
ReferenceValueLeaf::Bytes(_) => ValueType::Bytes,
ReferenceValueLeaf::Geometry(_) => ValueType::Geometry,
}
}
}

View File

@@ -273,5 +273,4 @@ pub(crate) mod type_codes {
// Extended type codes
pub const TOK_STR_EXT_CODE: u8 = 0;
pub const GEO_EXT_CODE: u8 = 1;
}

View File

@@ -15,7 +15,6 @@ use crate::schema::document::{
ValueDeserializer, ValueVisitor,
};
use crate::schema::Facet;
use crate::spatial::geometry::Geometry;
use crate::tokenizer::PreTokenizedString;
use crate::DateTime;
@@ -50,8 +49,6 @@ pub enum OwnedValue {
Object(Vec<(String, Self)>),
/// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`.
IpAddr(Ipv6Addr),
/// A GeoRust multi-polygon.
Geometry(Geometry),
}
impl AsRef<OwnedValue> for OwnedValue {
@@ -61,6 +58,31 @@ impl AsRef<OwnedValue> for OwnedValue {
}
}
impl OwnedValue {
/// Returns a u8 discriminant value for the `OwnedValue` variant.
///
/// This can be used to sort `OwnedValue` instances by their type.
pub fn discriminant_value(&self) -> u8 {
match self {
OwnedValue::Null => 0,
OwnedValue::Str(_) => 1,
OwnedValue::PreTokStr(_) => 2,
// It is key to make sure U64, I64, F64 are grouped together in there, otherwise we
// might be breaking transivity.
OwnedValue::U64(_) => 3,
OwnedValue::I64(_) => 4,
OwnedValue::F64(_) => 5,
OwnedValue::Bool(_) => 6,
OwnedValue::Date(_) => 7,
OwnedValue::Facet(_) => 8,
OwnedValue::Bytes(_) => 9,
OwnedValue::Array(_) => 10,
OwnedValue::Object(_) => 11,
OwnedValue::IpAddr(_) => 12,
}
}
}
impl<'a> Value<'a> for &'a OwnedValue {
type ArrayIter = std::slice::Iter<'a, OwnedValue>;
type ObjectIter = ObjectMapIter<'a>;
@@ -80,9 +102,6 @@ impl<'a> Value<'a> for &'a OwnedValue {
OwnedValue::IpAddr(val) => ReferenceValueLeaf::IpAddr(*val).into(),
OwnedValue::Array(array) => ReferenceValue::Array(array.iter()),
OwnedValue::Object(object) => ReferenceValue::Object(ObjectMapIter(object.iter())),
OwnedValue::Geometry(geometry) => {
ReferenceValueLeaf::Geometry(Box::new(geometry.clone())).into()
}
}
}
}
@@ -142,10 +161,6 @@ impl ValueDeserialize for OwnedValue {
Ok(OwnedValue::PreTokStr(val))
}
fn visit_geometry(&self, val: Geometry) -> Result<Self::Value, DeserializeError> {
Ok(OwnedValue::Geometry(val))
}
fn visit_array<'de, A>(&self, mut access: A) -> Result<Self::Value, DeserializeError>
where A: ArrayAccess<'de> {
let mut elements = Vec::with_capacity(access.size_hint());
@@ -208,7 +223,6 @@ impl serde::Serialize for OwnedValue {
}
}
OwnedValue::Array(ref array) => array.serialize(serializer),
OwnedValue::Geometry(ref geometry) => geometry.to_geojson().serialize(serializer),
}
}
}
@@ -296,7 +310,6 @@ impl<'a, V: Value<'a>> From<ReferenceValue<'a, V>> for OwnedValue {
ReferenceValueLeaf::IpAddr(val) => OwnedValue::IpAddr(val),
ReferenceValueLeaf::Bool(val) => OwnedValue::Bool(val),
ReferenceValueLeaf::PreTokStr(val) => OwnedValue::PreTokStr(*val.clone()),
ReferenceValueLeaf::Geometry(val) => OwnedValue::Geometry(*val.clone()),
},
ReferenceValue::Array(val) => {
OwnedValue::Array(val.map(|v| v.as_value().into()).collect())

View File

@@ -133,10 +133,6 @@ where W: Write
self.write_type_code(type_codes::EXT_CODE)?;
self.serialize_with_type_code(type_codes::TOK_STR_EXT_CODE, &*val)
}
ReferenceValueLeaf::Geometry(val) => {
self.write_type_code(type_codes::EXT_CODE)?;
self.serialize_with_type_code(type_codes::GEO_EXT_CODE, &*val)
}
},
ReferenceValue::Array(elements) => {
self.write_type_code(type_codes::ARRAY_CODE)?;

View File

@@ -3,7 +3,6 @@ use std::net::Ipv6Addr;
use common::DateTime;
use crate::spatial::geometry::Geometry;
use crate::tokenizer::PreTokenizedString;
/// A single field value.
@@ -109,12 +108,6 @@ pub trait Value<'a>: Send + Sync + Debug {
None
}
}
#[inline]
/// HUSH
fn as_geometry(&self) -> Option<Box<Geometry>> {
self.as_leaf().and_then(|leaf| leaf.into_geometry())
}
}
/// A enum representing a leaf value for tantivy to index.
@@ -143,8 +136,6 @@ pub enum ReferenceValueLeaf<'a> {
Bool(bool),
/// Pre-tokenized str type,
PreTokStr(Box<PreTokenizedString>),
/// HUSH
Geometry(Box<Geometry>),
}
impl From<u64> for ReferenceValueLeaf<'_> {
@@ -229,9 +220,6 @@ impl<'a, T: Value<'a> + ?Sized> From<ReferenceValueLeaf<'a>> for ReferenceValue<
ReferenceValueLeaf::PreTokStr(val) => {
ReferenceValue::Leaf(ReferenceValueLeaf::PreTokStr(val))
}
ReferenceValueLeaf::Geometry(val) => {
ReferenceValue::Leaf(ReferenceValueLeaf::Geometry(val))
}
}
}
}
@@ -343,16 +331,6 @@ impl<'a> ReferenceValueLeaf<'a> {
None
}
}
#[inline]
/// HUSH
pub fn into_geometry(self) -> Option<Box<Geometry>> {
if let Self::Geometry(val) = self {
Some(val)
} else {
None
}
}
}
/// A enum representing a value for tantivy to index.
@@ -470,10 +448,4 @@ where V: Value<'a>
pub fn is_object(&self) -> bool {
matches!(self, Self::Object(_))
}
#[inline]
/// HUSH
pub fn into_geometry(self) -> Option<Box<Geometry>> {
self.into_leaf().and_then(|leaf| leaf.into_geometry())
}
}

View File

@@ -1,7 +1,6 @@
use serde::{Deserialize, Serialize};
use super::ip_options::IpAddrOptions;
use super::spatial_options::SpatialOptions;
use crate::schema::bytes_options::BytesOptions;
use crate::schema::{
is_valid_field_name, DateOptions, FacetOptions, FieldType, JsonObjectOptions, NumericOptions,
@@ -81,11 +80,6 @@ impl FieldEntry {
Self::new(field_name, FieldType::JsonObject(json_object_options))
}
/// Creates a field entry for a spatial field
pub fn new_spatial(field_name: String, spatial_options: SpatialOptions) -> FieldEntry {
Self::new(field_name, FieldType::Spatial(spatial_options))
}
/// Returns the name of the field
pub fn name(&self) -> &str {
&self.name
@@ -135,7 +129,6 @@ impl FieldEntry {
FieldType::Bytes(ref options) => options.is_stored(),
FieldType::JsonObject(ref options) => options.is_stored(),
FieldType::IpAddr(ref options) => options.is_stored(),
FieldType::Spatial(ref options) => options.is_stored(),
}
}
}

View File

@@ -9,7 +9,6 @@ use serde_json::Value as JsonValue;
use thiserror::Error;
use super::ip_options::IpAddrOptions;
use super::spatial_options::SpatialOptions;
use super::IntoIpv6Addr;
use crate::schema::bytes_options::BytesOptions;
use crate::schema::facet_options::FacetOptions;
@@ -17,7 +16,6 @@ use crate::schema::{
DateOptions, Facet, IndexRecordOption, JsonObjectOptions, NumericOptions, OwnedValue,
TextFieldIndexing, TextOptions,
};
use crate::spatial::geometry::Geometry;
use crate::time::format_description::well_known::Rfc3339;
use crate::time::OffsetDateTime;
use crate::tokenizer::PreTokenizedString;
@@ -73,8 +71,6 @@ pub enum Type {
Json = b'j',
/// IpAddr
IpAddr = b'p',
/// Spatial
Spatial = b't',
}
impl From<ColumnType> for Type {
@@ -143,7 +139,6 @@ impl Type {
Type::Bytes => "Bytes",
Type::Json => "Json",
Type::IpAddr => "IpAddr",
Type::Spatial => "Spatial",
}
}
@@ -194,8 +189,6 @@ pub enum FieldType {
JsonObject(JsonObjectOptions),
/// IpAddr field
IpAddr(IpAddrOptions),
/// Spatial field
Spatial(SpatialOptions),
}
impl FieldType {
@@ -212,7 +205,6 @@ impl FieldType {
FieldType::Bytes(_) => Type::Bytes,
FieldType::JsonObject(_) => Type::Json,
FieldType::IpAddr(_) => Type::IpAddr,
FieldType::Spatial(_) => Type::Spatial,
}
}
@@ -249,7 +241,6 @@ impl FieldType {
FieldType::Bytes(ref bytes_options) => bytes_options.is_indexed(),
FieldType::JsonObject(ref json_object_options) => json_object_options.is_indexed(),
FieldType::IpAddr(ref ip_addr_options) => ip_addr_options.is_indexed(),
FieldType::Spatial(ref _spatial_options) => true,
}
}
@@ -287,7 +278,6 @@ impl FieldType {
FieldType::IpAddr(ref ip_addr_options) => ip_addr_options.is_fast(),
FieldType::Facet(_) => true,
FieldType::JsonObject(ref json_object_options) => json_object_options.is_fast(),
FieldType::Spatial(_) => false,
}
}
@@ -307,7 +297,6 @@ impl FieldType {
FieldType::Bytes(ref bytes_options) => bytes_options.fieldnorms(),
FieldType::JsonObject(ref _json_object_options) => false,
FieldType::IpAddr(ref ip_addr_options) => ip_addr_options.fieldnorms(),
FieldType::Spatial(_) => false,
}
}
@@ -359,8 +348,6 @@ impl FieldType {
None
}
}
FieldType::Spatial(_) => None, /* Geometry types cannot be indexed in the inverted
* index. */
}
}
@@ -462,10 +449,6 @@ impl FieldType {
Ok(OwnedValue::IpAddr(ip_addr.into_ipv6_addr()))
}
FieldType::Spatial(_) => Err(ValueParsingError::TypeError {
expected: "spatial field parsing not implemented",
json: JsonValue::String(field_text),
}),
}
}
JsonValue::Number(field_val_num) => match self {
@@ -525,10 +508,6 @@ impl FieldType {
expected: "a string with an ip addr",
json: JsonValue::Number(field_val_num),
}),
FieldType::Spatial(_) => Err(ValueParsingError::TypeError {
expected: "spatial field parsing not implemented",
json: JsonValue::Number(field_val_num),
}),
},
JsonValue::Object(json_map) => match self {
FieldType::Str(_) => {
@@ -544,14 +523,6 @@ impl FieldType {
}
}
FieldType::JsonObject(_) => Ok(OwnedValue::from(json_map)),
FieldType::Spatial(_) => Ok(OwnedValue::Geometry(
Geometry::from_geojson(&json_map).map_err(|e| {
ValueParsingError::ParseError {
error: format!("{:?}", e),
json: JsonValue::Object(json_map),
}
})?,
)),
_ => Err(ValueParsingError::TypeError {
expected: self.value_type().name(),
json: JsonValue::Object(json_map),

View File

@@ -1,6 +1,6 @@
use std::ops::BitOr;
use crate::schema::{DateOptions, NumericOptions, SpatialOptions, TextOptions};
use crate::schema::{DateOptions, NumericOptions, TextOptions};
#[derive(Clone)]
pub struct StoredFlag;
@@ -95,14 +95,6 @@ impl<T: Clone + Into<TextOptions>> BitOr<TextOptions> for SchemaFlagList<T, ()>
}
}
impl<T: Clone + Into<SpatialOptions>> BitOr<SpatialOptions> for SchemaFlagList<T, ()> {
type Output = SpatialOptions;
fn bitor(self, rhs: SpatialOptions) -> Self::Output {
self.head.into() | rhs
}
}
#[derive(Clone)]
pub struct SchemaFlagList<Head: Clone, Tail: Clone> {
pub head: Head,

View File

@@ -98,6 +98,10 @@
//! make it possible to access the value given the doc id rapidly. This is useful if the value
//! of the field is required during scoring or collection for instance.
//!
//! Some queries may leverage Fast fields when run on a field that is not indexed. This can be
//! handy if that kind of request is infrequent, however note that searching on a Fast field is
//! generally much slower than searching in an index.
//!
//! ```
//! use tantivy::schema::*;
//! let mut schema_builder = Schema::builder();
@@ -124,7 +128,6 @@ mod ip_options;
mod json_object_options;
mod named_field_document;
mod numeric_options;
mod spatial_options;
mod text_options;
use columnar::ColumnType;
@@ -145,7 +148,6 @@ pub use self::json_object_options::JsonObjectOptions;
pub use self::named_field_document::NamedFieldDocument;
pub use self::numeric_options::NumericOptions;
pub use self::schema::{Schema, SchemaBuilder};
pub use self::spatial_options::{SpatialOptions, SPATIAL};
pub use self::term::{Term, ValueBytes};
pub use self::text_options::{TextFieldIndexing, TextOptions, STRING, TEXT};
@@ -170,7 +172,6 @@ pub(crate) fn value_type_to_column_type(typ: Type) -> Option<ColumnType> {
Type::Bytes => Some(ColumnType::Bytes),
Type::IpAddr => Some(ColumnType::IpAddr),
Type::Json => None,
Type::Spatial => None,
}
}

View File

@@ -194,16 +194,6 @@ impl SchemaBuilder {
self.add_field(field_entry)
}
/// Adds a spatial entry to the schema in build.
pub fn add_spatial_field<T: Into<SpatialOptions>>(
&mut self,
field_name: &str,
field_options: T,
) -> Field {
let field_entry = FieldEntry::new_spatial(field_name.to_string(), field_options.into());
self.add_field(field_entry)
}
/// Adds a field entry to the schema in build.
pub fn add_field(&mut self, field_entry: FieldEntry) -> Field {
let field = Field::from_field_id(self.fields.len() as u32);
@@ -218,14 +208,9 @@ impl SchemaBuilder {
/// Finalize the creation of a `Schema`
/// This will consume your `SchemaBuilder`
pub fn build(self) -> Schema {
let contains_spatial_field = self
.fields
.iter()
.any(|field_entry| field_entry.field_type().value_type() == Type::Spatial);
Schema(Arc::new(InnerSchema {
fields: self.fields,
fields_map: self.fields_map,
contains_spatial_field,
}))
}
}
@@ -233,7 +218,6 @@ impl SchemaBuilder {
struct InnerSchema {
fields: Vec<FieldEntry>,
fields_map: HashMap<String, Field>, // transient
contains_spatial_field: bool,
}
impl PartialEq for InnerSchema {
@@ -384,11 +368,6 @@ impl Schema {
}
Some((field, json_path))
}
/// Returns true if the schema contains a spatial field.
pub(crate) fn contains_spatial_field(&self) -> bool {
self.0.contains_spatial_field
}
}
impl Serialize for Schema {
@@ -416,16 +395,16 @@ impl<'de> Deserialize<'de> for Schema {
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where A: SeqAccess<'de> {
let mut schema_builder = SchemaBuilder {
let mut schema = SchemaBuilder {
fields: Vec::with_capacity(seq.size_hint().unwrap_or(0)),
fields_map: HashMap::with_capacity(seq.size_hint().unwrap_or(0)),
};
while let Some(value) = seq.next_element()? {
schema_builder.add_field(value);
schema.add_field(value);
}
Ok(schema_builder.build())
Ok(schema.build())
}
}
@@ -1041,33 +1020,4 @@ mod tests {
Some((default, "foobar"))
);
}
#[test]
fn test_contains_spatial_field() {
// No spatial field
{
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("title", TEXT);
let schema = schema_builder.build();
assert!(!schema.contains_spatial_field());
// Serialization check
let schema_json = serde_json::to_string(&schema).unwrap();
let schema_deserialized: Schema = serde_json::from_str(&schema_json).unwrap();
assert!(!schema_deserialized.contains_spatial_field());
}
// With spatial field
{
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("title", TEXT);
schema_builder.add_spatial_field("location", SPATIAL);
let schema = schema_builder.build();
assert!(schema.contains_spatial_field());
// Serialization check
let schema_json = serde_json::to_string(&schema).unwrap();
let schema_deserialized: Schema = serde_json::from_str(&schema_json).unwrap();
assert!(schema_deserialized.contains_spatial_field());
}
}
}

View File

@@ -1,53 +0,0 @@
use std::ops::BitOr;
use serde::{Deserialize, Serialize};
use crate::schema::flags::StoredFlag;
/// Define how a spatial field should be handled by tantivy.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
pub struct SpatialOptions {
#[serde(default)]
stored: bool,
}
/// The field will be untokenized and indexed.
pub const SPATIAL: SpatialOptions = SpatialOptions { stored: false };
impl SpatialOptions {
/// Returns true if the geometry is to be stored.
#[inline]
pub fn is_stored(&self) -> bool {
self.stored
}
}
impl<T: Into<SpatialOptions>> BitOr<T> for SpatialOptions {
type Output = SpatialOptions;
fn bitor(self, other: T) -> SpatialOptions {
let other = other.into();
SpatialOptions {
stored: self.stored | other.stored,
}
}
}
impl From<StoredFlag> for SpatialOptions {
fn from(_: StoredFlag) -> SpatialOptions {
SpatialOptions { stored: true }
}
}
// #[cfg(test)]
// mod tests {
// use crate::schema::*;
//
// #[test]
// fn test_field_options() {
// let field_options = STORED | SPATIAL;
// assert!(field_options.is_stored());
// let mut schema_builder = Schema::builder();
// schema_builder.add_spatial_index("where", SPATIAL | STORED);
// }
// }

Some files were not shown because too many files have changed in this diff Show More