Compare commits


11 Commits

Author SHA1 Message Date
Pascal Seitz
b345c11786 add key_as_string for numbers in term agg 2024-07-25 13:20:56 +08:00
PSeitz
7ebcc15b17 add support for str fast field range query (#2453)
* add support for str fast field range query

Add support for range queries on fast fields by converting term bounds to
term ordinal bounds.

closes https://github.com/quickwit-oss/tantivy/issues/2023
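
A minimal sketch of the bound conversion, with the term dictionary modeled as a sorted slice of terms (illustrative names, not tantivy's actual API):

```rust
use std::ops::Bound;

// Resolve a string lower bound to a term-ordinal bound against a sorted
// term dictionary (modeled here as a sorted slice).
fn lower_bound_to_ord(dict: &[&str], bound: Bound<&str>) -> Bound<u64> {
    match bound {
        Bound::Unbounded => Bound::Unbounded,
        Bound::Included(t) => match dict.binary_search(&t) {
            Ok(ord) => Bound::Included(ord as u64),
            // Term absent: the first ordinal >= t becomes an inclusive bound.
            Err(ins) => Bound::Included(ins as u64),
        },
        Bound::Excluded(t) => match dict.binary_search(&t) {
            Ok(ord) => Bound::Excluded(ord as u64),
            Err(ins) => Bound::Included(ins as u64),
        },
    }
}

// Upper bounds resolve differently when the term is absent: everything
// strictly below the insertion point is still in range.
fn upper_bound_to_ord(dict: &[&str], bound: Bound<&str>) -> Bound<u64> {
    match bound {
        Bound::Unbounded => Bound::Unbounded,
        Bound::Included(t) => match dict.binary_search(&t) {
            Ok(ord) => Bound::Included(ord as u64),
            Err(ins) => Bound::Excluded(ins as u64),
        },
        Bound::Excluded(t) => match dict.binary_search(&t) {
            Ok(ord) => Bound::Excluded(ord as u64),
            Err(ins) => Bound::Excluded(ins as u64),
        },
    }
}

fn main() {
    let dict = ["apple", "banana", "cherry"];
    // ["b", "cherry") maps to ordinal range [1, 2): only "banana" matches.
    assert_eq!(lower_bound_to_ord(&dict, Bound::Included("b")), Bound::Included(1));
    assert_eq!(upper_bound_to_ord(&dict, Bound::Excluded("cherry")), Bound::Excluded(2));
}
```

The resulting ordinal bounds can then be evaluated by the existing u64 fast field range machinery.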

* extend tests, rename

* update comment

* update comment
2024-07-17 09:31:42 +08:00
PSeitz
1b4076691f refactor fast field query (#2452)
As preparation for #2023 and #1709

* Use Term to pass parameters
* merge u64 and ip fast field range query

Side note: I did not rename range_query_u64_fastfield, because git would then be unable to track the changes.
2024-07-15 18:08:05 +08:00
Robert Caulk
eab660873a doc: fix typo in readme (#2450) 2024-07-09 15:12:22 +08:00
PSeitz
232f37126e fix coverage (#2448) 2024-07-05 12:04:18 +08:00
PSeitz
13e9885dfd faster term aggregation fetch terms (#2447)
Big impact for term aggregations with a large `size` parameter (e.g. 1000).
Adds a top-1000 term agg bench.

```
full
terms_few                                      Memory: 27.3 KB (+79.09%)    Avg: 3.8058ms (+2.40%)      Median: 3.7192ms (+3.47%)       [3.6224ms .. 4.3721ms]
terms_many                                     Memory: 6.9 MB               Avg: 12.6102ms (-4.70%)     Median: 12.1389ms (-6.58%)      [10.2847ms .. 15.4857ms]
terms_many_top_1000                            Memory: 6.9 MB               Avg: 15.8216ms (-83.19%)    Median: 15.4899ms (-83.46%)     [13.4250ms .. 20.6897ms]
terms_many_order_by_term                       Memory: 6.9 MB               Avg: 14.7820ms (-3.95%)     Median: 14.2236ms (-4.28%)      [12.6669ms .. 21.0968ms]
terms_many_with_top_hits                       Memory: 58.2 MB              Avg: 551.6218ms (+7.18%)    Median: 549.8826ms (+11.01%)    [496.7371ms .. 592.1299ms]
terms_many_with_avg_sub_agg                    Memory: 27.8 MB              Avg: 197.7029ms (+2.66%)    Median: 190.1564ms (+0.64%)     [167.9226ms .. 245.6651ms]
terms_many_json_mixed_type_with_avg_sub_agg    Memory: 42.0 MB (+0.00%)     Avg: 242.0121ms (+0.92%)    Median: 237.7084ms (-2.85%)     [201.9959ms .. 302.2136ms]
terms_few_with_cardinality_agg                 Memory: 10.6 MB              Avg: 122.6036ms (+1.21%)    Median: 119.0033ms (+2.60%)     [109.2859ms .. 161.5858ms]
range_agg_with_term_agg_few                    Memory: 45.4 KB (+39.75%)    Avg: 24.5454ms (+2.14%)     Median: 24.2861ms (+2.44%)      [23.5109ms .. 27.8406ms]
range_agg_with_term_agg_many                   Memory: 6.9 MB               Avg: 56.8049ms (+3.01%)     Median: 50.9706ms (+1.52%)      [41.4517ms .. 90.3934ms]
dense
terms_few                                      Memory: 28.8 KB (+81.74%)    Avg: 8.9092ms (-2.24%)      Median: 8.7143ms (-1.31%)      [8.6148ms .. 10.3868ms]
terms_many                                     Memory: 6.9 MB (-0.00%)      Avg: 17.9604ms (-10.18%)    Median: 17.1552ms (-11.93%)    [14.8979ms .. 26.2779ms]
terms_many_top_1000                            Memory: 6.9 MB               Avg: 21.4963ms (-78.90%)    Median: 21.2924ms (-78.98%)    [18.2033ms .. 28.0087ms]
terms_many_order_by_term                       Memory: 6.9 MB               Avg: 20.4167ms (-9.13%)     Median: 19.5596ms (-11.37%)    [17.5153ms .. 29.5987ms]
terms_many_with_top_hits                       Memory: 58.2 MB              Avg: 518.4474ms (-6.41%)    Median: 514.9180ms (-9.44%)    [471.5550ms .. 579.0220ms]
terms_many_with_avg_sub_agg                    Memory: 27.8 MB              Avg: 263.6702ms (-2.78%)    Median: 260.8775ms (-2.55%)    [239.5754ms .. 304.6669ms]
terms_many_json_mixed_type_with_avg_sub_agg    Memory: 42.0 MB              Avg: 299.9791ms (-2.01%)    Median: 302.2180ms (-3.08%)    [239.2080ms .. 346.3649ms]
terms_few_with_cardinality_agg                 Memory: 10.6 MB              Avg: 136.3303ms (-3.12%)    Median: 132.3831ms (-2.88%)    [123.7564ms .. 164.7914ms]
range_agg_with_term_agg_few                    Memory: 47.1 KB (+37.81%)    Avg: 35.4538ms (+0.66%)     Median: 34.8754ms (-0.56%)     [34.2287ms .. 40.0884ms]
range_agg_with_term_agg_many                   Memory: 6.9 MB               Avg: 72.2269ms (-4.38%)     Median: 66.1174ms (-4.98%)     [55.5125ms .. 124.1622ms]
sparse
terms_few                                      Memory: 27.3 KB (+69.68%)    Avg: 19.6053ms (-1.15%)     Median: 19.4543ms (-0.38%)     [19.3056ms .. 24.0547ms]
terms_many                                     Memory: 1.8 MB               Avg: 21.2886ms (-6.28%)     Median: 21.1287ms (-6.65%)     [20.6640ms .. 24.6144ms]
terms_many_top_1000                            Memory: 2.6 MB               Avg: 23.4869ms (-85.53%)    Median: 23.3393ms (-85.61%)    [22.7789ms .. 25.0896ms]
terms_many_order_by_term                       Memory: 1.8 MB               Avg: 21.7437ms (-7.78%)     Median: 21.6272ms (-7.66%)     [21.0409ms .. 23.6517ms]
terms_many_with_top_hits                       Memory: 13.1 MB              Avg: 43.7926ms (-2.76%)     Median: 44.3602ms (+0.01%)     [37.8039ms .. 51.0451ms]
terms_many_with_avg_sub_agg                    Memory: 7.5 MB               Avg: 34.6307ms (+3.72%)     Median: 33.4522ms (+1.16%)     [32.4418ms .. 41.4196ms]
terms_many_json_mixed_type_with_avg_sub_agg    Memory: 7.4 MB               Avg: 46.4318ms (+1.16%)     Median: 46.4050ms (+2.03%)     [44.5986ms .. 48.5142ms]
terms_few_with_cardinality_agg                 Memory: 680.0 KB (-0.04%)    Avg: 35.4410ms (+2.05%)     Median: 35.1384ms (+1.19%)     [34.4402ms .. 39.1082ms]
range_agg_with_term_agg_few                    Memory: 45.7 KB (+39.44%)    Avg: 22.7760ms (+0.44%)     Median: 22.5152ms (-0.35%)     [22.3078ms .. 26.1567ms]
range_agg_with_term_agg_many                   Memory: 1.8 MB               Avg: 25.7696ms (-4.45%)     Median: 25.4009ms (-5.61%)     [24.7874ms .. 29.6434ms]
multivalue
terms_few                                      Memory: 244.4 KB            Avg: 15.1253ms (-2.85%)     Median: 15.0988ms (-0.54%)     [14.8790ms .. 15.8193ms]
terms_many                                     Memory: 6.9 MB (-0.00%)     Avg: 26.3019ms (-6.24%)     Median: 26.3662ms (-4.94%)     [21.3553ms .. 31.0564ms]
terms_many_top_1000                            Memory: 6.9 MB              Avg: 29.5212ms (-72.90%)    Median: 29.4257ms (-72.84%)    [24.2645ms .. 35.1607ms]
terms_many_order_by_term                       Memory: 6.9 MB              Avg: 28.6076ms (-4.93%)     Median: 28.1059ms (-6.64%)     [24.0845ms .. 34.1493ms]
terms_many_with_top_hits                       Memory: 58.3 MB             Avg: 570.1548ms (+1.52%)    Median: 572.7759ms (+0.53%)    [525.9567ms .. 617.0862ms]
terms_many_with_avg_sub_agg                    Memory: 27.8 MB             Avg: 305.5207ms (+0.24%)    Median: 296.0101ms (-0.22%)    [277.8579ms .. 373.5914ms]
terms_many_json_mixed_type_with_avg_sub_agg    Memory: 42.0 MB (-0.00%)    Avg: 324.7342ms (-2.51%)    Median: 319.0025ms (-2.58%)    [298.7122ms .. 368.6144ms]
terms_few_with_cardinality_agg                 Memory: 10.8 MB             Avg: 151.6126ms (-2.54%)    Median: 149.0616ms (-0.32%)    [136.5592ms .. 181.8942ms]
range_agg_with_term_agg_few                    Memory: 248.2 KB            Avg: 49.5225ms (+3.11%)     Median: 48.3994ms (+3.18%)     [46.4134ms .. 60.5989ms]
range_agg_with_term_agg_many                   Memory: 6.9 MB              Avg: 85.9824ms (-3.66%)     Median: 78.4266ms (-3.85%)     [64.1231ms .. 128.5279ms]
```
2024-07-03 12:42:59 +08:00
PSeitz
56d79cb203 fix cardinality aggregation performance (#2446)
* fix cardinality aggregation performance

Fix cardinality performance by fetching multiple terms at once. This
avoids decompressing the same block repeatedly and keeps the buffer state
between terms.
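
A sketch of the batched access pattern (the slice stands in for the real SSTable term dictionary; the callback shape mirrors the `sorted_ords_to_term_cb` call visible in the diffs further below):

```rust
use std::io;

/// Resolve term ordinals in one sorted forward pass. A real dictionary can
/// then stream through its compressed blocks once, reusing the decompressed
/// block and key buffer between consecutive terms, instead of re-seeking
/// and re-decompressing a block for every single ordinal.
fn sorted_ords_to_term_cb<'a>(
    dict: &[&'a str],
    sorted_ords: impl Iterator<Item = u64>,
    mut cb: impl FnMut(&'a str) -> io::Result<()>,
) -> io::Result<()> {
    for ord in sorted_ords {
        cb(dict[ord as usize])?;
    }
    Ok(())
}

fn main() -> io::Result<()> {
    let dict = ["apple", "banana", "cherry"];
    let mut ords = vec![2u64, 0];
    ords.sort_unstable(); // the contract: ordinals must arrive sorted
    sorted_ords_to_term_cb(&dict, ords.into_iter(), |term| {
        println!("{term}");
        Ok(())
    })
}
```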

add cardinality aggregation benchmark

bump rust version to 1.66

Performance comparison to before (AllQuery)
```
full
cardinality_agg                   Memory: 3.5 MB (-0.00%)    Avg: 21.2256ms (-97.78%)    Median: 21.0042ms (-97.82%)    [20.4717ms .. 23.6206ms]
terms_few_with_cardinality_agg    Memory: 10.6 MB            Avg: 81.9293ms (-97.37%)    Median: 81.5526ms (-97.38%)    [79.7564ms .. 88.0374ms]
dense
cardinality_agg                   Memory: 3.6 MB (-0.00%)    Avg: 25.9372ms (-97.24%)    Median: 25.7744ms (-97.25%)    [24.7241ms .. 27.8793ms]
terms_few_with_cardinality_agg    Memory: 10.6 MB            Avg: 93.9897ms (-96.91%)    Median: 92.7821ms (-96.94%)    [90.3312ms .. 117.4076ms]
sparse
cardinality_agg                   Memory: 895.4 KB (-0.00%)    Avg: 22.5113ms (-95.01%)    Median: 22.5629ms (-94.99%)    [22.1628ms .. 22.9436ms]
terms_few_with_cardinality_agg    Memory: 680.2 KB             Avg: 26.4250ms (-94.85%)    Median: 26.4135ms (-94.86%)    [26.3210ms .. 26.6774ms]
```

* clippy

* assert for sorted ordinals
2024-07-02 15:29:00 +08:00
Paul Masurel
0f4c2e27cf Fixes bug that causes out-of-order sstable key. (#2445)
The previous way to address the problem was to replace \u{0000}
with 0 in different places.

This logic had several flaws: done on the serializer side (like it was for
the columnar), there was a collision problem.

If a document in the segment contained a json field with a \0 and
another doc contained the same json field but with `0`, then we were sending
the same field path twice to the serializer.
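
A toy demonstration of the collision, using a hypothetical `normalize` helper that stands in for the old replacement logic:

```rust
// Hypothetical stand-in for the old serializer-side normalization.
fn normalize(path: &str) -> String {
    path.replace('\u{0000}', "0")
}

fn main() {
    // Two distinct JSON paths...
    assert_ne!("color\u{0000}", "color0");
    // ...become byte-identical after normalization, so the sstable
    // serializer would receive the same key twice (and out of order).
    assert_eq!(normalize("color\u{0000}"), normalize("color0"));
}
```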

Another option would have been to normalize all values on the writer
side.

This PR simplifies the logic and simply ignores json paths containing a
\0, both in the columnar and the inverted index.

Closes #2442
2024-07-01 15:40:07 +08:00
落叶乌龟
f9ae295507 feat(query): Make BooleanQuery support minimum_number_should_match (#2405)
* feat(query): Make `BooleanQuery` support `minimum_number_should_match`. See issue #2398.

In this commit, a new scorer named `DisjunctionScorer` is introduced, which performs the union of posting lists while enforcing the minimum number of required matching clauses; it is implemented via a min-heap. Necessary modifications to `BooleanQuery` and `BooleanWeight` are performed as well.
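
As a rough illustration of the approach (a minimal sketch over plain sorted vectors, not the actual `DisjunctionScorer`), a min-heap keyed by each posting list's next doc id lets the scorer pop all lists positioned on the same document, count them, and keep the document only if enough clauses matched:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Union of sorted posting lists, keeping only docs that appear in at
/// least `min_match` of them.
fn disjunction_with_min_match(postings: Vec<Vec<u32>>, min_match: usize) -> Vec<u32> {
    // Heap entries: (next doc id, list index, cursor into that list),
    // wrapped in `Reverse` to turn the max-heap into a min-heap.
    let mut heap: BinaryHeap<Reverse<(u32, usize, usize)>> = BinaryHeap::new();
    for (i, list) in postings.iter().enumerate() {
        if let Some(&doc) = list.first() {
            heap.push(Reverse((doc, i, 0)));
        }
    }
    let mut out = Vec::new();
    while let Some(&Reverse((doc, _, _))) = heap.peek() {
        // Pop every list currently positioned on `doc` and count them.
        let mut count = 0;
        while let Some(&Reverse((d, i, pos))) = heap.peek() {
            if d != doc {
                break;
            }
            heap.pop();
            count += 1;
            if let Some(&next) = postings[i].get(pos + 1) {
                heap.push(Reverse((next, i, pos + 1)));
            }
        }
        if count >= min_match {
            out.push(doc);
        }
    }
    out
}

fn main() {
    let postings = vec![vec![1, 3, 5], vec![1, 2, 3], vec![3, 5]];
    // Docs matching at least 2 of the 3 clauses: 1, 3 and 5 (doc 2 matches only one).
    assert_eq!(disjunction_with_min_match(postings, 2), vec![1, 3, 5]);
}
```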

* fixup! fix test

* fixup!: refactor code.

1. More meaningful names.
2. Add a cache for `Disjunction`'s scorers, and fix a bug.
3. Optimize `BooleanWeight::complex_scorer`

Thanks
 Paul Masurel <paul@quickwit.io>

* squash!: come up with better variable naming.

* squash!: fix naming issues.

* squash!: fix typo.

* squash!: Remove CombinationMethod::FullIntersection
2024-07-01 15:39:41 +08:00
Raphael Coeffic
d9db5302d9 feat: cardinality aggregation (#2337)
* WiP: cardinality aggregation

* Collect unique entries first, then insert into HyperLogLog

* Handle `missing`

* Hybrid approach

* Review changes

- insert `missing` value at most once
- `term_id` -> `term_ord`
- iterate directly over entries without collecting first

* Use salted hasher to include column type

* fix: formatting

* More review fixes

* Add cardinality to test_aggregation_flushing

* Formatting
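
For reference, a minimal sketch of the `hyperloglogplus` crate API this feature builds on (assuming version 0.4 as added to Cargo.toml; `RandomState` stands in for the salted hasher introduced in this PR):

```rust
use std::collections::hash_map::RandomState;

use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};

fn main() {
    // Precision 16 matches what CardinalityCollector::new uses below.
    let mut sketch: HyperLogLogPlus<u64, RandomState> =
        HyperLogLogPlus::new(16, RandomState::new()).unwrap();
    for val in [1u64, 2, 3, 1, 2] {
        sketch.insert(&val);
    }
    // The sparse representation keeps small counts essentially exact.
    println!("estimated cardinality: {}", sketch.count().trunc());
}
```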
2024-07-01 07:49:42 +08:00
Paul Masurel
e453848134 Recycling buffer in PrefixPhraseScorer (#2443) 2024-06-24 17:11:53 +09:00
54 changed files with 2513 additions and 1221 deletions

View File

@@ -15,11 +15,11 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2024-04-10 --profile minimal --component llvm-tools-preview
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2024-04-10 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
continue-on-error: true

View File

@@ -11,7 +11,7 @@ repository = "https://github.com/quickwit-oss/tantivy"
readme = "README.md"
keywords = ["search", "information", "retrieval"]
edition = "2021"
rust-version = "1.63"
rust-version = "1.66"
exclude = ["benches/*.json", "benches/*.txt"]
[dependencies]
@@ -38,7 +38,7 @@ levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4"
rust-stemmers = "1.2.0"
downcast-rs = "1.2.0"
downcast-rs = "1.2.1"
bitpacking = { version = "0.9.2", default-features = false, features = [
"bitpacker4x",
] }
@@ -64,6 +64,7 @@ tantivy-bitpacker = { version = "0.6", path = "./bitpacker" }
common = { version = "0.7", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.3", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
fnv = "1.0.7"

View File

@@ -18,7 +18,7 @@ Tantivy is, in fact, strongly inspired by Lucene's design.
## Benchmark
The following [benchmark](https://tantivy-search.github.io/bench/) breakdowns
The following [benchmark](https://tantivy-search.github.io/bench/) breaks down the
performance for different types of queries/collections.
Your mileage WILL vary depending on the nature of queries and their load.

View File

@@ -51,10 +51,15 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, percentiles_f64);
register!(group, terms_few);
register!(group, terms_many);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_sub_agg_card);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_few);
@@ -123,6 +128,33 @@ fn percentiles_f64(index: &Index) {
});
execute_agg(index, agg_req);
}
fn cardinality_agg(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_many_terms"
},
}
});
execute_agg(index, agg_req);
}
fn terms_few_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"cardinality": {
"cardinality": {
"field": "text_many_terms"
},
}
}
},
});
execute_agg(index, agg_req);
}
fn terms_few(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms" } },
@@ -135,6 +167,12 @@ fn terms_many(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_many_top_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms", "size": 1000 } },
});
execute_agg(index, agg_req);
}
fn terms_many_order_by_term(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms", "order": { "_key": "desc" } } },
@@ -171,7 +209,7 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_sub_agg_card(index: &Index) {
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "json.mixed_type" },
@@ -268,6 +306,7 @@ fn range_agg_with_term_agg_many(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram(index: &Index) {
let agg_req = json!({
"rangef64": {

View File

@@ -34,6 +34,7 @@ fn compute_stats(vals: impl Iterator<Item = u64>) -> ColumnStats {
fn value_iter() -> impl Iterator<Item = u64> {
0..20_000
}
fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues {
let mut bytes = Vec::new();
let stats = compute_stats(data.iter().cloned());
@@ -41,10 +42,13 @@ fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues
for val in data {
codec_serializer.collect(*val);
}
codec_serializer.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes);
codec_serializer
.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes)
.unwrap();
Codec::load(OwnedBytes::new(bytes)).unwrap()
}
fn bench_get<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
let col = get_reader_for_bench::<Codec>(data);
b.iter(|| {

View File

@@ -8,6 +8,7 @@ use std::net::Ipv6Addr;
use column_operation::ColumnOperation;
pub(crate) use column_writers::CompatibleNumericalTypes;
use common::json_path_writer::JSON_END_OF_PATH;
use common::CountingWriter;
pub(crate) use serializer::ColumnarSerializer;
use stacker::{Addr, ArenaHashMap, MemoryArena};
@@ -283,12 +284,17 @@ impl ColumnarWriter {
.iter()
.map(|(column_name, addr)| (column_name, ColumnType::DateTime, addr)),
);
// TODO: replace JSON_END_OF_PATH with b'0' in columns
columns.sort_unstable_by_key(|(column_name, col_type, _)| (*column_name, *col_type));
let (arena, buffers, dictionaries) = (&self.arena, &mut self.buffers, &self.dictionaries);
let mut symbol_byte_buffer: Vec<u8> = Vec::new();
for (column_name, column_type, addr) in columns {
if column_name.contains(&JSON_END_OF_PATH) {
// Tantivy uses b'\0' as a separator for nested fields in JSON.
// Column names with a b'\0' are now simply ignored by the columnar (and the inverted
// index).
continue;
}
match column_type {
ColumnType::Bool => {
let column_writer: ColumnWriter = self.bool_field_hash_map.read(addr);

View File

@@ -93,18 +93,3 @@ impl<'a, W: io::Write> io::Write for ColumnSerializer<'a, W> {
self.columnar_serializer.wrt.write_all(buf)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_prepare_key_bytes() {
let mut buffer: Vec<u8> = b"somegarbage".to_vec();
prepare_key(b"root\0child", ColumnType::Str, &mut buffer);
assert_eq!(buffer.len(), 12);
assert_eq!(&buffer[..10], b"root0child");
assert_eq!(buffer[10], 0u8);
assert_eq!(buffer[11], ColumnType::Str.to_code());
}
}

View File

@@ -9,7 +9,6 @@ documentation = "https://docs.rs/tantivy_common/"
homepage = "https://github.com/quickwit-oss/tantivy"
repository = "https://github.com/quickwit-oss/tantivy"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
@@ -20,8 +19,7 @@ time = { version = "0.3.10", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.8.1"
proptest = "1.0.0"
rand = "0.8.4"
[features]
unstable = [] # useful for benches.

View File

@@ -1,39 +1,64 @@
#![feature(test)]
use binggan::{black_box, BenchRunner};
use rand::seq::IteratorRandom;
use rand::thread_rng;
use tantivy_common::{serialize_vint_u32, BitSet, TinySet};
extern crate test;
fn bench_vint() {
let mut runner = BenchRunner::new();
#[cfg(test)]
mod tests {
use rand::seq::IteratorRandom;
use rand::thread_rng;
use tantivy_common::serialize_vint_u32;
use test::Bencher;
let vals: Vec<u32> = (0..20_000).collect();
runner.bench_function("bench_vint", move |_| {
let mut out = 0u64;
for val in vals.iter().cloned() {
let mut buf = [0u8; 8];
serialize_vint_u32(val, &mut buf);
out += u64::from(buf[0]);
}
black_box(out);
});
#[bench]
fn bench_vint(b: &mut Bencher) {
let vals: Vec<u32> = (0..20_000).collect();
b.iter(|| {
let mut out = 0u64;
for val in vals.iter().cloned() {
let mut buf = [0u8; 8];
serialize_vint_u32(val, &mut buf);
out += u64::from(buf[0]);
}
out
});
}
#[bench]
fn bench_vint_rand(b: &mut Bencher) {
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000);
b.iter(|| {
let mut out = 0u64;
for val in vals.iter().cloned() {
let mut buf = [0u8; 8];
serialize_vint_u32(val, &mut buf);
out += u64::from(buf[0]);
}
out
});
}
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000);
runner.bench_function("bench_vint_rand", move |_| {
let mut out = 0u64;
for val in vals.iter().cloned() {
let mut buf = [0u8; 8];
serialize_vint_u32(val, &mut buf);
out += u64::from(buf[0]);
}
black_box(out);
});
}
fn bench_bitset() {
let mut runner = BenchRunner::new();
runner.bench_function("bench_tinyset_pop", move |_| {
let mut tinyset = TinySet::singleton(black_box(31u32));
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
black_box(tinyset);
});
let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32);
runner.bench_function("bench_tinyset_sum", move |_| {
assert_eq!(black_box(tiny_set).into_iter().sum::<u32>(), 45u32);
});
let v = [10u32, 14u32, 21u32];
runner.bench_function("bench_tinyarr_sum", move |_| {
black_box(v.iter().cloned().sum::<u32>());
});
runner.bench_function("bench_bitset_initialize", move |_| {
black_box(BitSet::with_max_value(1_000_000));
});
}
fn main() {
bench_vint();
bench_bitset();
}

View File

@@ -696,43 +696,3 @@ mod tests {
}
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use test;
use super::{BitSet, TinySet};
#[bench]
fn bench_tinyset_pop(b: &mut test::Bencher) {
b.iter(|| {
let mut tinyset = TinySet::singleton(test::black_box(31u32));
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
});
}
#[bench]
fn bench_tinyset_sum(b: &mut test::Bencher) {
let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32);
b.iter(|| {
assert_eq!(test::black_box(tiny_set).into_iter().sum::<u32>(), 45u32);
});
}
#[bench]
fn bench_tinyarr_sum(b: &mut test::Bencher) {
let v = [10u32, 14u32, 21u32];
b.iter(|| test::black_box(v).iter().cloned().sum::<u32>());
}
#[bench]
fn bench_bitset_initialize(b: &mut test::Bencher) {
b.iter(|| BitSet::with_max_value(1_000_000));
}
}

View File

@@ -1,3 +1,5 @@
use std::ops::Bound;
// # Searching a range on an indexed int field.
//
// Below is an example of creating an indexed integer field in your schema
@@ -5,7 +7,7 @@
use tantivy::collector::Count;
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, INDEXED};
use tantivy::{doc, Index, IndexWriter, Result};
use tantivy::{doc, Index, IndexWriter, Result, Term};
fn main() -> Result<()> {
// For the sake of simplicity, this schema will only have 1 field
@@ -27,7 +29,10 @@ fn main() -> Result<()> {
reader.reload()?;
let searcher = reader.searcher();
// The end is excluded i.e. here we are searching up to 1969
let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960..1970);
let docs_in_the_sixties = RangeQuery::new(
Bound::Included(Term::from_field_u64(year_field, 1960)),
Bound::Excluded(Term::from_field_u64(year_field, 1970)),
);
// Uses a Count collector to sum the total number of docs in the range
let num_60s_books = searcher.search(&docs_in_the_sixties, &Count)?;
assert_eq!(num_60s_books, 10);

View File

@@ -34,8 +34,9 @@ use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregation,
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation,
TopHitsAggregationReq,
};
/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
@@ -159,7 +160,10 @@ pub enum AggregationVariants {
Percentiles(PercentilesAggregationReq),
/// Finds the top k values matching some order
#[serde(rename = "top_hits")]
TopHits(TopHitsAggregation),
TopHits(TopHitsAggregationReq),
/// Computes an estimate of the number of unique values
#[serde(rename = "cardinality")]
Cardinality(CardinalityAggregationReq),
}
impl AggregationVariants {
@@ -179,6 +183,7 @@ impl AggregationVariants {
AggregationVariants::Sum(sum) => vec![sum.field_name()],
AggregationVariants::Percentiles(per) => vec![per.field_name()],
AggregationVariants::TopHits(top_hits) => top_hits.field_names(),
AggregationVariants::Cardinality(per) => vec![per.field_name()],
}
}
@@ -203,7 +208,7 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregation> {
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregationReq> {
match &self {
AggregationVariants::TopHits(top_hits) => Some(top_hits),
_ => None,

View File

@@ -11,8 +11,8 @@ use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
StatsAggregation, SumAggregation,
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
MaxAggregation, MinAggregation, StatsAggregation, SumAggregation,
};
use super::segment_agg_result::AggregationLimits;
use super::VecWithNames;
@@ -162,6 +162,11 @@ impl AggregationWithAccessor {
field: ref field_name,
ref missing,
..
})
| Cardinality(CardinalityAggregationReq {
field: ref field_name,
ref missing,
..
}) => {
let str_dict_column = reader.fast_fields().str(field_name)?;
let allowed_column_types = [

View File

@@ -98,6 +98,8 @@ pub enum MetricResult {
Percentiles(PercentilesMetricResult),
/// Top hits metric result
TopHits(TopHitsMetricResult),
/// Cardinality metric result
Cardinality(SingleMetricResult),
}
impl MetricResult {
@@ -116,6 +118,7 @@ impl MetricResult {
MetricResult::TopHits(_) => Err(TantivyError::AggregationError(
AggregationError::InvalidRequest("top_hits can't be used to order".to_string()),
)),
MetricResult::Cardinality(card) => Ok(card.value),
}
}
}

View File

@@ -110,6 +110,16 @@ fn test_aggregation_flushing(
}
}
}
},
"cardinality_string_id":{
"cardinality": {
"field": "string_id"
}
},
"cardinality_score":{
"cardinality": {
"field": "score"
}
}
});
@@ -212,6 +222,9 @@ fn test_aggregation_flushing(
)
);
assert_eq!(res["cardinality_string_id"]["value"], 2.0);
assert_eq!(res["cardinality_score"]["value"], 80.0);
Ok(())
}
@@ -926,10 +939,10 @@ fn test_aggregation_on_json_object_mixed_types() {
},
"termagg": {
"buckets": [
{ "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } },
{ "doc_count": 1, "key": 10.0, "key_as_string": "10", "min_price": { "value": 10.0 } },
{ "doc_count": 3, "key": "blue", "min_price": { "value": 5.0 } },
{ "doc_count": 2, "key": "red", "min_price": { "value": 1.0 } },
{ "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } },
{ "doc_count": 1, "key": -20.5, "key_as_string": "-20.5", "min_price": { "value": -20.5 } },
{ "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } },
],
"sum_other_doc_count": 0

View File

@@ -1,10 +1,9 @@
use std::fmt::Debug;
use std::io;
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{
BytesColumn, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64, StrColumn,
};
use columnar::{ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
@@ -466,49 +465,66 @@ impl SegmentTermCollector {
};
if self.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = agg_with_accessor
.str_dict_column
.as_ref()
.cloned()
.unwrap_or_else(|| {
StrColumn::wrap(BytesColumn::empty(agg_with_accessor.accessor.num_docs()))
});
let mut buffer = String::new();
for (term_id, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;
// Special case for missing key
if term_id == u64::MAX {
let missing_key = self
.req
.missing
.as_ref()
.expect("Found placeholder term_id but `missing` is None");
match missing_key {
Key::Str(missing) => {
buffer.clear();
buffer.push_str(missing);
dict.insert(
IntermediateKey::Str(buffer.to_string()),
intermediate_entry,
);
}
Key::F64(val) => {
buffer.push_str(&val.to_string());
dict.insert(IntermediateKey::F64(*val), intermediate_entry);
}
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
let mut buffer = Vec::new();
// special case for missing key
if let Some(index) = entries.iter().position(|value| value.0 == u64::MAX) {
let entry = entries[index];
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)?;
let missing_key = self
.req
.missing
.as_ref()
.expect("Found placeholder term_id but `missing` is None");
match missing_key {
Key::Str(missing) => {
buffer.clear();
buffer.extend_from_slice(missing.as_bytes());
dict.insert(
IntermediateKey::Str(
String::from_utf8(buffer.to_vec())
.expect("could not convert to String"),
),
intermediate_entry,
);
}
} else {
if !term_dict.ord_to_str(term_id, &mut buffer)? {
return Err(TantivyError::InternalError(format!(
"Couldn't find term_id {term_id} in dict"
)));
Key::F64(val) => {
dict.insert(IntermediateKey::F64(*val), intermediate_entry);
}
dict.insert(IntermediateKey::Str(buffer.to_string()), intermediate_entry);
}
entries.swap_remove(index);
}
// Sort by term ord
entries.sort_unstable_by_key(|bucket| bucket.0);
let mut idx = 0;
term_dict.sorted_ords_to_term_cb(
entries.iter().map(|(term_id, _)| *term_id),
|term| {
let entry = entries[idx];
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
idx += 1;
Ok(())
},
)?;
if self.req.min_doc_count == 0 {
// TODO: Handle rev streaming for descending sorting by keys
let mut stream = term_dict.dictionary().stream()?;
let mut stream = term_dict.stream()?;
let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req(
agg_with_accessor.agg.sub_aggregation(),
);

View File

@@ -26,6 +26,7 @@ use super::segment_agg_result::AggregationLimits;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError;
/// Contains the intermediate aggregation result, which is optimized to be merged with other
@@ -227,6 +228,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
TopHits(ref req) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::TopHits(TopHitsTopNComputer::new(req)),
),
Cardinality(_) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
),
}
}
@@ -291,6 +295,8 @@ pub enum IntermediateMetricResult {
Sum(IntermediateSum),
/// Intermediate top_hits result
TopHits(TopHitsTopNComputer),
/// Intermediate cardinality result
Cardinality(CardinalityCollector),
}
impl IntermediateMetricResult {
@@ -324,6 +330,9 @@ impl IntermediateMetricResult {
IntermediateMetricResult::TopHits(top_hits) => {
MetricResult::TopHits(top_hits.into_final_result())
}
IntermediateMetricResult::Cardinality(cardinality) => {
MetricResult::Cardinality(cardinality.finalize().into())
}
}
}
@@ -372,6 +381,12 @@ impl IntermediateMetricResult {
(IntermediateMetricResult::TopHits(left), IntermediateMetricResult::TopHits(right)) => {
left.merge_fruits(right)?;
}
(
IntermediateMetricResult::Cardinality(left),
IntermediateMetricResult::Cardinality(right),
) => {
left.merge_fruits(right)?;
}
_ => {
panic!("incompatible fruit types in tree or missing merge_fruits handler");
}
@@ -584,6 +599,7 @@ impl IntermediateTermBucketResult {
let val = if key { "true" } else { "false" };
Some(val.to_string())
}
IntermediateKey::F64(val) => Some(val.to_string()),
_ => None,
};
Ok(BucketEntry {

View File

@@ -0,0 +1,466 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::Dictionary;
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
#[derive(Clone, Debug, Serialize, Deserialize)]
struct BuildSaltedHasher {
salt: u8,
}
impl BuildHasher for BuildSaltedHasher {
type Hasher = DefaultHasher;
fn build_hasher(&self) -> Self::Hasher {
let mut hasher = DefaultHasher::new();
hasher.write_u8(self.salt);
hasher
}
}
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
/// of the number of different values in a data set based on the
/// HyperLogLog++ algorithm. This is particularly useful for understanding the
/// uniqueness of values in a large dataset where counting each unique value
/// individually would be computationally expensive.
///
/// For example, you might use a cardinality aggregation to estimate the number
/// of unique visitors to a website by aggregating on a field that contains
/// user IDs or session IDs.
///
/// To use the cardinality aggregation, you'll need to provide a field to
/// aggregate on. The following example demonstrates a request for the cardinality
/// of the "user_id" field:
///
/// ```JSON
/// {
/// "cardinality": {
/// "field": "user_id"
/// }
/// }
/// ```
///
/// This request will return an estimate of the number of unique values in the
/// "user_id" field.
///
/// ## Missing Values
///
/// The `missing` parameter defines how documents that are missing a value should be treated.
/// By default, documents without a value for the specified field are ignored. However, you can
/// specify a default value for these documents using the `missing` parameter. This can be useful
/// when you want to include documents with missing values in the aggregation.
///
/// For example, the following request treats documents with missing values in the "user_id"
/// field as if they had a value of "unknown":
///
/// ```JSON
/// {
/// "cardinality": {
/// "field": "user_id",
/// "missing": "unknown"
/// }
/// }
/// ```
///
/// # Estimation Accuracy
///
/// The cardinality aggregation provides an approximate count, which is usually
/// accurate within a small error range. This trade-off allows for efficient
/// computation even on very large datasets.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CardinalityAggregationReq {
/// The field name to compute the cardinality on.
pub field: String,
/// The missing parameter defines how documents that are missing a value should be treated.
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(skip_serializing_if = "Option::is_none", default)]
pub missing: Option<Key>,
}
impl CardinalityAggregationReq {
/// Creates a new [`CardinalityAggregationReq`] instance from a field name.
pub fn from_field_name(field_name: String) -> Self {
Self {
field: field_name,
missing: None,
}
}
/// Returns the field name the aggregation is computed on.
pub fn field_name(&self) -> &str {
&self.field
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
column_type: ColumnType,
accessor_idx: usize,
missing: Option<Key>,
}
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option<Key>) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(),
column_type,
accessor_idx,
missing: missing.clone(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = agg_accessor.missing_value_for_accessor {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
agg_with_accessor: &AggregationWithAccessor,
) -> crate::Result<IntermediateMetricResult> {
if self.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = agg_with_accessor
.str_dict_column
.as_ref()
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
let mut has_missing = false;
// TODO: replace FxHashSet with something that allows iterating in order
// (e.g. sparse bitvec)
let mut term_ids = Vec::new();
for term_ord in self.entries.into_iter() {
if term_ord == u64::MAX {
has_missing = true;
} else {
// we can reasonably exclude values above u32::MAX
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
Ok(())
})?;
if has_missing {
let missing_key = self
.missing
.as_ref()
.expect("Found placeholder term_ord but `missing` is None");
match missing_key {
Key::Str(missing) => {
self.cardinality.sketch.insert_any(&missing);
}
Key::F64(val) => {
let val = f64_to_u64(*val);
self.cardinality.sketch.insert_any(&val);
}
}
}
}
Ok(IntermediateMetricResult::Cardinality(self.cardinality))
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
)?;
Ok(())
}
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.fetch_block_with_field(docs, bucket_agg_accessor);
let col_block_accessor = &bucket_agg_accessor.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord);
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = bucket_agg_accessor
.accessor
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
self.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
self.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
/// The cardinality collector used during segment collection and for merging results.
pub struct CardinalityCollector {
sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
}
impl Default for CardinalityCollector {
fn default() -> Self {
Self::new(0)
}
}
impl PartialEq for CardinalityCollector {
fn eq(&self, _other: &Self) -> bool {
false
}
}
impl CardinalityCollector {
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.clone().count().trunc())
}
fn new(salt: u8) -> Self {
Self {
sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
}
}
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
"Error while merging cardinality {err:?}"
)))
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::net::IpAddr;
use std::str::FromStr;
use columnar::MonotonicallyMappableToU64;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
use crate::schema::{IntoIpv6Addr, Schema, FAST};
use crate::Index;
#[test]
fn cardinality_aggregation_test_empty_index() -> crate::Result<()> {
let values = vec![];
let index = get_test_index_from_terms(false, &values)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "string_id",
}
},
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 0.0);
Ok(())
}
#[test]
fn cardinality_aggregation_test_single_segment() -> crate::Result<()> {
cardinality_aggregation_test_merge_segment(true)
}
#[test]
fn cardinality_aggregation_test() -> crate::Result<()> {
cardinality_aggregation_test_merge_segment(false)
}
fn cardinality_aggregation_test_merge_segment(merge_segments: bool) -> crate::Result<()> {
let segment_and_terms = vec![
vec!["terma"],
vec!["termb"],
vec!["termc"],
vec!["terma"],
vec!["terma"],
vec!["terma"],
vec!["termb"],
vec!["terma"],
];
let index = get_test_index_from_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "string_id",
}
},
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 3.0);
Ok(())
}
#[test]
fn cardinality_aggregation_u64() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
writer.add_document(doc!(id_field => 1u64))?;
writer.add_document(doc!(id_field => 2u64))?;
writer.add_document(doc!(id_field => 3u64))?;
writer.add_document(doc!())?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "id",
"missing": 0u64
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 4.0);
Ok(())
}
#[test]
fn cardinality_aggregation_ip_addr() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_ip_addr_field("ip_field", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
// IpV6 loopback
writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?;
writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?;
// IpV4
writer.add_document(
doc!(field=>IpAddr::from_str("127.0.0.1").unwrap().into_ipv6_addr()),
)?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "ip_field"
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 2.0);
Ok(())
}
#[test]
fn cardinality_aggregation_json() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_json_field("json", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
writer.add_document(doc!(field => json!({"value": false})))?;
writer.add_document(doc!(field => json!({"value": true})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(0u64)})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(1u64)})))?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "json.value"
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 4.0);
Ok(())
}
}

View File

@@ -17,6 +17,7 @@
//! - [Percentiles](PercentilesAggregationReq)
mod average;
mod cardinality;
mod count;
mod extended_stats;
mod max;
@@ -29,6 +30,7 @@ mod top_hits;
use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
pub use count::*;
pub use extended_stats::*;
pub use max::*;

View File

@@ -89,7 +89,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal};
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct TopHitsAggregation {
pub struct TopHitsAggregationReq {
sort: Vec<KeyOrder>,
size: usize,
from: Option<usize>,
@@ -164,7 +164,7 @@ fn unsupported_err(parameter: &str) -> crate::Result<()> {
))
}
impl TopHitsAggregation {
impl TopHitsAggregationReq {
/// Validate and resolve field retrieval parameters
pub fn validate_and_resolve_field_names(
&mut self,
@@ -431,7 +431,7 @@ impl Eq for DocSortValuesAndFields {}
/// The TopHitsCollector used for collecting over segments and merging results.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
req: TopHitsAggregation,
req: TopHitsAggregationReq,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
}
@@ -443,7 +443,7 @@ impl std::cmp::PartialEq for TopHitsTopNComputer {
impl TopHitsTopNComputer {
/// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregation) -> Self {
pub fn new(req: &TopHitsAggregationReq) -> Self {
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
req: req.clone(),
@@ -496,7 +496,7 @@ pub(crate) struct TopHitsSegmentCollector {
impl TopHitsSegmentCollector {
pub fn from_req(
req: &TopHitsAggregation,
req: &TopHitsAggregationReq,
accessor_idx: usize,
segment_ordinal: SegmentOrdinal,
) -> Self {
@@ -509,7 +509,7 @@ impl TopHitsSegmentCollector {
fn into_top_hits_collector(
self,
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregation,
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = self.top_n.into_vec();
@@ -532,7 +532,7 @@ impl TopHitsSegmentCollector {
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregation,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req

View File

@@ -44,11 +44,14 @@
//! - [Metric](metric)
//! - [Average](metric::AverageAggregation)
//! - [Stats](metric::StatsAggregation)
//! - [ExtendedStats](metric::ExtendedStatsAggregation)
//! - [Min](metric::MinAggregation)
//! - [Max](metric::MaxAggregation)
//! - [Sum](metric::SumAggregation)
//! - [Count](metric::CountAggregation)
//! - [Percentiles](metric::PercentilesAggregationReq)
//! - [Cardinality](metric::CardinalityAggregationReq)
//! - [TopHits](metric::TopHitsAggregationReq)
//!
//! # Example
//! Compute the average metric, by building [`agg_req::Aggregations`], which is built from an

View File

@@ -16,7 +16,10 @@ use super::metric::{
SumAggregation,
};
use crate::aggregation::bucket::TermMissingAgg;
use crate::aggregation::metric::{SegmentExtendedStatsCollector, TopHitsSegmentCollector};
use crate::aggregation::metric::{
CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
TopHitsSegmentCollector,
};
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn add_intermediate_aggregation_result(
@@ -169,6 +172,9 @@ pub(crate) fn build_single_agg_segment_collector(
accessor_idx,
req.segment_ordinal,
))),
Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new(
SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing),
)),
}
}

View File

@@ -1,4 +1,4 @@
use common::json_path_writer::JSON_PATH_SEGMENT_SEP;
use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP};
use common::{replace_in_place, JsonPathWriter};
use rustc_hash::FxHashMap;
@@ -83,6 +83,9 @@ fn index_json_object<'a, V: Value<'a>>(
positions_per_path: &mut IndexingPositionsPerPath,
) {
for (json_path_segment, json_value_visitor) in json_visitor {
if json_path_segment.as_bytes().contains(&JSON_END_OF_PATH) {
continue;
}
json_path_writer.push(json_path_segment);
index_json_value(
doc,

View File

@@ -815,8 +815,9 @@ mod tests {
use crate::indexer::NoMergePolicy;
use crate::query::{QueryParser, TermQuery};
use crate::schema::{
self, Facet, FacetOptions, IndexRecordOption, IpAddrOptions, NumericOptions,
TextFieldIndexing, TextOptions, Value, FAST, INDEXED, STORED, STRING, TEXT,
self, Facet, FacetOptions, IndexRecordOption, IpAddrOptions, JsonObjectOptions,
NumericOptions, Schema, TextFieldIndexing, TextOptions, Value, FAST, INDEXED, STORED,
STRING, TEXT,
};
use crate::store::DOCSTORE_CACHE_CAPACITY;
use crate::{
@@ -1573,11 +1574,11 @@ mod tests {
deleted_ids.remove(id);
}
IndexingOp::DeleteDoc { id } => {
existing_ids.remove(&id);
existing_ids.remove(id);
deleted_ids.insert(*id);
}
IndexingOp::DeleteDocQuery { id } => {
existing_ids.remove(&id);
existing_ids.remove(id);
deleted_ids.insert(*id);
}
_ => {}
@@ -2378,11 +2379,11 @@ mod tests {
#[test]
fn test_bug_1617_2() {
assert!(test_operation_strategy(
test_operation_strategy(
&[
IndexingOp::AddDoc {
id: 13,
value: Default::default()
value: Default::default(),
},
IndexingOp::DeleteDoc { id: 13 },
IndexingOp::Commit,
@@ -2390,9 +2391,9 @@ mod tests {
IndexingOp::Commit,
IndexingOp::Merge,
],
true
true,
)
.is_ok());
.unwrap();
}
#[test]
@@ -2492,9 +2493,9 @@ mod tests {
}
#[test]
fn test_bug_2442() -> crate::Result<()> {
fn test_bug_2442_reserved_character_fast_field() -> crate::Result<()> {
let mut schema_builder = schema::Schema::builder();
let json_field = schema_builder.add_json_field("json", TEXT | FAST);
let json_field = schema_builder.add_json_field("json", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::builder().schema(schema).create_in_ram()?;
@@ -2515,4 +2516,21 @@ mod tests {
Ok(())
}
#[test]
fn test_bug_2442_reserved_character_columnar() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let options = JsonObjectOptions::from(FAST).set_expand_dots_enabled();
let field = schema_builder.add_json_field("json", options);
let index = Index::create_in_ram(schema_builder.build());
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(field=>json!({"\u{0000}": "A"})))
.unwrap();
index_writer
.add_document(doc!(field=>json!({format!("\u{0000}\u{0000}"): "A"})))
.unwrap();
index_writer.commit().unwrap();
Ok(())
}
}

View File

@@ -145,15 +145,27 @@ mod tests_mmap {
}
}
#[test]
fn test_json_field_null_byte() {
// Test when field name contains a zero byte, which has special meaning in tantivy.
// As a workaround, we convert the zero byte to the ASCII character '0'.
// https://github.com/quickwit-oss/tantivy/issues/2340
// https://github.com/quickwit-oss/tantivy/issues/2193
let field_name_in = "\u{0000}";
let field_name_out = "0";
test_json_field_name(field_name_in, field_name_out);
fn test_json_field_null_byte_is_ignored() {
let mut schema_builder = Schema::builder();
let options = JsonObjectOptions::from(TEXT | FAST).set_expand_dots_enabled();
let field = schema_builder.add_json_field("json", options);
let index = Index::create_in_ram(schema_builder.build());
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(field=>json!({"key": "test1", "invalidkey\u{0000}": "test2"})))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let inv_indexer = segment_reader.inverted_index(field).unwrap();
let term_dict = inv_indexer.terms();
assert_eq!(term_dict.num_terms(), 1);
let mut term_bytes = Vec::new();
term_dict.ord_to_term(0, &mut term_bytes).unwrap();
assert_eq!(term_bytes, b"key\0stest1");
}
#[test]
fn test_json_field_1byte() {
// Test when field name contains a '1' byte, which has special meaning in tantivy.
@@ -291,7 +303,7 @@ mod tests_mmap {
Type::Str,
),
(format!("{field_name_out_internal}a"), Type::Str),
(format!("{field_name_out_internal}"), Type::Str),
(field_name_out_internal.to_string(), Type::Str),
(format!("num{field_name_out_internal}"), Type::I64),
];
expected_fields.sort();

View File

@@ -1,5 +1,3 @@
use common::json_path_writer::JSON_END_OF_PATH;
use common::replace_in_place;
use fnv::FnvHashMap;
/// `Field` is represented by an unsigned 32-bit integer type.
@@ -40,13 +38,7 @@ impl PathToUnorderedId {
#[cold]
fn insert_new_path(&mut self, path: &str) -> u32 {
let next_id = self.map.len() as u32;
let mut new_path = path.to_string();
// The unsafe below is safe as long as b'.' and JSON_PATH_SEGMENT_SEP are
// valid single byte ut8 strings.
// By utf-8 design, they cannot be part of another codepoint.
unsafe { replace_in_place(JSON_END_OF_PATH, b'0', new_path.as_bytes_mut()) };
let new_path = path.to_string();
self.map.insert(new_path, next_id);
next_id
}

View File

@@ -3,7 +3,7 @@
//! In "The beauty and the beast", the term "the" appears in position 0 and position 3.
//! This information is useful to run phrase queries.
//!
//! The [position](crate::SegmentComponent::Positions) file contains all of the
//! The [position](crate::index::SegmentComponent::Positions) file contains all of the
//! bitpacked positions delta, for all terms of a given field, one term after the other.
//!
//! Each term is encoded independently.

View File

@@ -59,7 +59,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
/// The actual serialization format is handled by the `PostingsSerializer`.
fn serialize(
&self,
term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
ordered_term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
ordered_id_to_path: &[&str],
ctx: &IndexingContext,
serializer: &mut FieldSerializer,
@@ -69,7 +69,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0));
let mut prev_term_id = u32::MAX;
let mut term_path_len = 0; // this will be set in the first iteration
for (_field, path_id, term, addr) in term_addrs {
for (_field, path_id, term, addr) in ordered_term_addrs {
if prev_term_id != path_id.path_id() {
term_buffer.truncate_value_bytes(0);
term_buffer.append_path(ordered_id_to_path[path_id.path_id() as usize].as_bytes());

View File

@@ -15,6 +15,7 @@ pub trait Postings: DocSet + 'static {
fn term_freq(&self) -> u32;
/// Returns the positions offsetted with a given value.
/// It is not necessary to clear the `output` before calling this method.
/// The output vector will be resized to the `term_freq`.
fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>);

View File

@@ -22,10 +22,7 @@ pub struct AllWeight;
impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer {
doc: 0u32,
max_doc: reader.max_doc(),
};
let all_scorer = AllScorer::new(reader.max_doc());
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
}
@@ -43,6 +40,13 @@ pub struct AllScorer {
max_doc: DocId,
}
impl AllScorer {
/// Creates a new AllScorer with `max_doc` docs.
pub fn new(max_doc: DocId) -> AllScorer {
AllScorer { doc: 0u32, max_doc }
}
}
impl DocSet for AllScorer {
#[inline(always)]
fn advance(&mut self) -> DocId {

View File

@@ -66,6 +66,10 @@ use crate::schema::{IndexRecordOption, Term};
/// Term::from_field_text(title, "diary"),
/// IndexRecordOption::Basic,
/// ));
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(title, "cow"),
/// IndexRecordOption::Basic
/// ));
/// // A TermQuery with "found" in the body
/// let body_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(body, "found"),
@@ -74,7 +78,7 @@ use crate::schema::{IndexRecordOption, Term};
/// // TermQuery "diary" must and "girl" must not be present
/// let queries_with_occurs1 = vec![
/// (Occur::Must, diary_term_query.box_clone()),
/// (Occur::MustNot, girl_term_query),
/// (Occur::MustNot, girl_term_query.box_clone()),
/// ];
/// // Make a BooleanQuery equivalent to
/// // title:+diary title:-girl
@@ -82,15 +86,10 @@ use crate::schema::{IndexRecordOption, Term};
/// let count1 = searcher.search(&diary_must_and_girl_mustnot, &Count)?;
/// assert_eq!(count1, 1);
///
/// // TermQuery for "cow" in the title
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(title, "cow"),
/// IndexRecordOption::Basic,
/// ));
/// // "title:diary OR title:cow"
/// let title_diary_or_cow = BooleanQuery::new(vec![
/// (Occur::Should, diary_term_query.box_clone()),
/// (Occur::Should, cow_term_query),
/// (Occur::Should, cow_term_query.box_clone()),
/// ]);
/// let count2 = searcher.search(&title_diary_or_cow, &Count)?;
/// assert_eq!(count2, 4);
@@ -118,21 +117,38 @@ use crate::schema::{IndexRecordOption, Term};
/// ]);
/// let count4 = searcher.search(&nested_query, &Count)?;
/// assert_eq!(count4, 1);
///
/// // You may call `with_minimum_required_clauses` to specify the minimum
/// // number of `Should` clauses the returned documents must match.
/// let minimum_required_query = BooleanQuery::with_minimum_required_clauses(vec![
/// (Occur::Should, cow_term_query.box_clone()),
/// (Occur::Should, girl_term_query.box_clone()),
/// (Occur::Should, diary_term_query.box_clone()),
/// ], 2);
/// // Returns documents containing "Diary Cow", "Diary Girl" or "Cow Girl".
/// // Notice: "Diary" isn't "Dairy". ;-)
/// let count5 = searcher.search(&minimum_required_query, &Count)?;
/// assert_eq!(count5, 1);
/// Ok(())
/// }
/// ```
#[derive(Debug)]
pub struct BooleanQuery {
subqueries: Vec<(Occur, Box<dyn Query>)>,
minimum_number_should_match: usize,
}
impl Clone for BooleanQuery {
fn clone(&self) -> Self {
self.subqueries
let subqueries = self
.subqueries
.iter()
.map(|(occur, subquery)| (*occur, subquery.box_clone()))
.collect::<Vec<_>>()
.into()
.collect::<Vec<_>>();
Self {
subqueries,
minimum_number_should_match: self.minimum_number_should_match,
}
}
}
@@ -149,8 +165,9 @@ impl Query for BooleanQuery {
.iter()
.map(|(occur, subquery)| Ok((*occur, subquery.weight(enable_scoring)?)))
.collect::<crate::Result<_>>()?;
Ok(Box::new(BooleanWeight::new(
Ok(Box::new(BooleanWeight::with_minimum_number_should_match(
sub_weights,
self.minimum_number_should_match,
enable_scoring.is_scoring_enabled(),
Box::new(SumWithCoordsCombiner::default),
)))
@@ -166,7 +183,41 @@ impl Query for BooleanQuery {
impl BooleanQuery {
/// Creates a new boolean query.
pub fn new(subqueries: Vec<(Occur, Box<dyn Query>)>) -> BooleanQuery {
BooleanQuery { subqueries }
// If the bool query includes at least one `Should` clause and no `Must` or
// `MustNot` clauses, the default value is 1. Otherwise, the default value is 0.
// This matches Elasticsearch's behavior.
let mut minimum_required = 0;
for (occur, _) in &subqueries {
match occur {
Occur::Should => minimum_required = 1,
Occur::Must | Occur::MustNot => {
minimum_required = 0;
break;
}
}
}
Self::with_minimum_required_clauses(subqueries, minimum_required)
}
/// Creates a new boolean query, specifying the minimum number of `Should` clauses that must match.
pub fn with_minimum_required_clauses(
subqueries: Vec<(Occur, Box<dyn Query>)>,
minimum_number_should_match: usize,
) -> BooleanQuery {
BooleanQuery {
subqueries,
minimum_number_should_match,
}
}
/// Getter for `minimum_number_should_match`
pub fn get_minimum_number_should_match(&self) -> usize {
self.minimum_number_should_match
}
/// Setter for `minimum_number_should_match`
pub fn set_minimum_number_should_match(&mut self, minimum_number_should_match: usize) {
self.minimum_number_should_match = minimum_number_should_match;
}
/// Returns the intersection of the queries.
@@ -181,6 +232,18 @@ impl BooleanQuery {
BooleanQuery::new(subqueries)
}
/// Returns the union of the queries, with a minimum number of required `Should` clauses.
pub fn union_with_minimum_required_clauses(
queries: Vec<Box<dyn Query>>,
minimum_required_clauses: usize,
) -> BooleanQuery {
let subqueries = queries
.into_iter()
.map(|sub_query| (Occur::Should, sub_query))
.collect();
BooleanQuery::with_minimum_required_clauses(subqueries, minimum_required_clauses)
}
/// Helper method to create a boolean query matching a given list of terms.
/// The resulting query is a disjunction of the terms.
pub fn new_multiterms_query(terms: Vec<Term>) -> BooleanQuery {
@@ -203,11 +266,13 @@ impl BooleanQuery {
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use super::BooleanQuery;
use crate::collector::{Count, DocSetCollector};
use crate::query::{QueryClone, QueryParser, TermQuery};
use crate::schema::{IndexRecordOption, Schema, TEXT};
use crate::{DocAddress, Index, Term};
use crate::query::{Query, QueryClone, QueryParser, TermQuery};
use crate::schema::{Field, IndexRecordOption, Schema, TEXT};
use crate::{DocAddress, DocId, Index, Term};
fn create_test_index() -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
@@ -223,6 +288,73 @@ mod tests {
Ok(index)
}
#[test]
fn test_minimum_required() -> crate::Result<()> {
fn create_test_index_with<T: IntoIterator<Item = &'static str>>(
docs: T,
) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let text = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests()?;
for doc in docs {
writer.add_document(doc!(text => doc))?;
}
writer.commit()?;
Ok(index)
}
fn create_boolean_query_with_mr<T: IntoIterator<Item = &'static str>>(
queries: T,
field: Field,
mr: usize,
) -> BooleanQuery {
let terms = queries
.into_iter()
.map(|t| Term::from_field_text(field, t))
.map(|t| TermQuery::new(t, IndexRecordOption::Basic))
.map(|q| -> Box<dyn Query> { Box::new(q) })
.collect();
BooleanQuery::union_with_minimum_required_clauses(terms, mr)
}
fn check_doc_id<T: IntoIterator<Item = DocId>>(
expected: T,
actually: HashSet<DocAddress>,
seg: u32,
) {
assert_eq!(
actually,
expected
.into_iter()
.map(|id| DocAddress::new(seg, id))
.collect()
);
}
let index = create_test_index_with(["a b c", "a c e", "d f g", "z z z", "c i b"])?;
let searcher = index.reader()?.searcher();
let text = index.schema().get_field("text").unwrap();
// Documents containing 'a c', 'a z', 'a i', 'c z', 'c i' or 'z i' shall be returned.
let q1 = create_boolean_query_with_mr(["a", "c", "z", "i"], text, 2);
let docs = searcher.search(&q1, &DocSetCollector)?;
check_doc_id([0, 1, 4], docs, 0);
// Documents containing 'a b c', 'a b e', 'a c e' or 'b c e' shall be returned.
let q2 = create_boolean_query_with_mr(["a", "b", "c", "e"], text, 3);
let docs = searcher.search(&q2, &DocSetCollector)?;
check_doc_id([0, 1], docs, 0);
// Nothing matches since minimum_required exceeds the number of clauses.
let q3 = create_boolean_query_with_mr(["a", "b"], text, 3);
let docs = searcher.search(&q3, &DocSetCollector)?;
assert!(docs.is_empty());
// When mr is set to zero or one, there is no difference from `BooleanQuery::union`.
let q4 = create_boolean_query_with_mr(["a", "z"], text, 1);
let docs = searcher.search(&q4, &DocSetCollector)?;
check_doc_id([0, 1, 3], docs, 0);
let q5 = create_boolean_query_with_mr(["a", "b"], text, 0);
let docs = searcher.search(&q5, &DocSetCollector)?;
check_doc_id([0, 1, 4], docs, 0);
Ok(())
}
#[test]
fn test_union() -> crate::Result<()> {
let index = create_test_index()?;
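Beyond the doc-test, the new API composes like any other query. A self-contained sketch (not part of the diff; schema, values, and memory budget are illustrative) requiring at least two of three terms to match:

use tantivy::collector::Count;
use tantivy::query::{BooleanQuery, Query, TermQuery};
use tantivy::schema::{IndexRecordOption, Schema, TEXT};
use tantivy::{doc, Index, IndexWriter, Term};

fn main() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    let text = schema_builder.add_text_field("text", TEXT);
    let index = Index::create_in_ram(schema_builder.build());
    let mut writer: IndexWriter = index.writer(50_000_000)?;
    writer.add_document(doc!(text => "a b c"))?; // matches "a", "b" and "c"
    writer.add_document(doc!(text => "a x y"))?; // matches only "a"
    writer.commit()?;

    let term_query = |t: &str| -> Box<dyn Query> {
        Box::new(TermQuery::new(
            Term::from_field_text(text, t),
            IndexRecordOption::Basic,
        ))
    };
    // At least two of the three Should clauses must match.
    let query = BooleanQuery::union_with_minimum_required_clauses(
        vec![term_query("a"), term_query("b"), term_query("c")],
        2,
    );
    let searcher = index.reader()?.searcher();
    assert_eq!(searcher.search(&query, &Count)?, 1); // only "a b c" qualifies
    Ok(())
}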

View File

@@ -3,6 +3,7 @@ use std::collections::HashMap;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::index::SegmentReader;
use crate::postings::FreqReadingOption;
use crate::query::disjunction::Disjunction;
use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer;
@@ -18,6 +19,26 @@ enum SpecializedScorer {
Other(Box<dyn Scorer>),
}
fn scorer_disjunction<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner: TScoreCombiner,
minimum_match_required: usize,
) -> Box<dyn Scorer>
where
TScoreCombiner: ScoreCombiner,
{
debug_assert!(!scorers.is_empty());
debug_assert!(minimum_match_required > 1);
if scorers.len() == 1 {
return scorers.into_iter().next().unwrap(); // Safe unwrap.
}
Box::new(Disjunction::new(
scorers,
score_combiner,
minimum_match_required,
))
}
fn scorer_union<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner_fn: impl Fn() -> TScoreCombiner,
@@ -70,6 +91,7 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
/// Weight associated to the `BoolQuery`.
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
weights: Vec<(Occur, Box<dyn Weight>)>,
minimum_number_should_match: usize,
scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>,
}
@@ -85,6 +107,22 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
weights,
scoring_enabled,
score_combiner_fn,
minimum_number_should_match: 1,
}
}
/// Creates a new boolean weight, specifying the minimum number of `Should` clauses that must match.
pub fn with_minimum_number_should_match(
weights: Vec<(Occur, Box<dyn Weight>)>,
minimum_number_should_match: usize,
scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send + 'static>,
) -> BooleanWeight<TScoreCombiner> {
BooleanWeight {
weights,
minimum_number_should_match,
scoring_enabled,
score_combiner_fn,
}
}
@@ -111,43 +149,89 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
) -> crate::Result<SpecializedScorer> {
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
.remove(&Occur::Should)
.map(|scorers| scorer_union(scorers, &score_combiner_fn));
// Indicates how the should clauses are combined with the other clauses.
enum CombinationMethod {
Ignored,
// Only contributes to the final score.
Optional(SpecializedScorer),
// Must be matched.
Required(Box<dyn Scorer>),
}
let mut must_scorers = per_occur_scorers.remove(&Occur::Must);
let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should)
{
let num_of_should_scorers = should_scorers.len();
if self.minimum_number_should_match > num_of_should_scorers {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match self.minimum_number_should_match {
0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
1 => CombinationMethod::Required(into_box_scorer(
scorer_union(should_scorers, &score_combiner_fn),
&score_combiner_fn,
)),
n if num_of_should_scorers == n => {
// When minimum_number_should_match equals the number of should scorers,
// the should clauses are no different from must clauses.
must_scorers = match must_scorers.take() {
Some(mut must_scorers) => {
must_scorers.append(&mut should_scorers);
Some(must_scorers)
}
None => Some(should_scorers),
};
CombinationMethod::Ignored
}
_ => CombinationMethod::Required(scorer_disjunction(
should_scorers,
score_combiner_fn(),
self.minimum_number_should_match,
)),
}
} else {
// No should clauses were provided.
if self.minimum_number_should_match > 0 {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
} else {
CombinationMethod::Ignored
}
};
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
.map(|specialized_scorer| {
.map(|specialized_scorer: SpecializedScorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
});
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::Must)
.map(intersect_scorers);
let positive_scorer: SpecializedScorer = match (should_scorer_opt, must_scorer_opt) {
(Some(should_scorer), Some(must_scorer)) => {
let positive_scorer = match (should_opt, must_scorers) {
(CombinationMethod::Ignored, Some(must_scorers)) => {
SpecializedScorer::Other(intersect_scorers(must_scorers))
}
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
let must_scorer = intersect_scorers(must_scorers);
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
Box<dyn Scorer>,
Box<dyn Scorer>,
TComplexScoreCombiner,
>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn),
)))
SpecializedScorer::Other(Box::new(
RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn),
),
))
} else {
SpecializedScorer::Other(must_scorer)
}
}
(None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer),
(Some(should_scorer), None) => should_scorer,
(None, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
must_scorers.push(should_scorer);
SpecializedScorer::Other(intersect_scorers(must_scorers))
}
(CombinationMethod::Ignored, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
}
(CombinationMethod::Required(should_scorer), None) => {
SpecializedScorer::Other(should_scorer)
}
// Optional scorers are promoted to required if no must scorers exist.
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
};
if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
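The `CombinationMethod` dispatch above reduces to a small decision table on the number of should scorers versus `minimum_number_should_match`. A standalone sketch of that table (illustrative only, not tantivy API):

/// How should clauses are handled for `num_should` scorers and a
/// configured minimum number of should matches.
fn should_clause_handling(num_should: usize, min_match: usize) -> &'static str {
    if min_match > num_should {
        "empty scorer: the requirement can never be met"
    } else if min_match == 0 {
        "optional: should clauses only contribute to the score"
    } else if min_match == 1 {
        "required: plain union of the should scorers"
    } else if min_match == num_should {
        "folded into the must clauses: plain intersection"
    } else {
        "required: counting disjunction (see disjunction.rs below)"
    }
}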

327
src/query/disjunction.rs Normal file
View File

@@ -0,0 +1,327 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::{ScoreCombiner, Scorer};
use crate::{DocId, DocSet, Score, TERMINATED};
/// `Disjunction` is responsible for merging `DocSet`s from multiple
/// sources. Specifically, it takes the union of two or more `DocSet`s,
/// then filters out elements that appear fewer times than a
/// specified threshold.
pub struct Disjunction<TScorer, TScoreCombiner = DoNothingCombiner> {
chains: BinaryHeap<ScorerWrapper<TScorer>>,
minimum_matches_required: usize,
score_combiner: TScoreCombiner,
current_doc: DocId,
current_score: Score,
}
/// A wrapper around a `Scorer` that caches the current `doc_id` and implements the `DocSet` trait.
/// Also, the `Ord` trait and its relatives are implemented in reverse, so that a
/// `std::collections::BinaryHeap<ScorerWrapper<T>>` behaves as a min-heap keyed on the current doc id.
struct ScorerWrapper<T> {
scorer: T,
current_doc: DocId,
}
impl<T: Scorer> ScorerWrapper<T> {
fn new(scorer: T) -> Self {
let current_doc = scorer.doc();
Self {
scorer,
current_doc,
}
}
}
impl<T: Scorer> PartialEq for ScorerWrapper<T> {
fn eq(&self, other: &Self) -> bool {
self.doc() == other.doc()
}
}
impl<T: Scorer> Eq for ScorerWrapper<T> {}
impl<T: Scorer> PartialOrd for ScorerWrapper<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: Scorer> Ord for ScorerWrapper<T> {
fn cmp(&self, other: &Self) -> Ordering {
self.doc().cmp(&other.doc()).reverse()
}
}
impl<T: Scorer> DocSet for ScorerWrapper<T> {
fn advance(&mut self) -> DocId {
let doc_id = self.scorer.advance();
self.current_doc = doc_id;
doc_id
}
fn doc(&self) -> DocId {
self.current_doc
}
fn size_hint(&self) -> u32 {
self.scorer.size_hint()
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
pub fn new<T: IntoIterator<Item = TScorer>>(
docsets: T,
score_combiner: TScoreCombiner,
minimum_matches_required: usize,
) -> Self {
debug_assert!(
minimum_matches_required > 1,
"union scorer works better if just one matches required"
);
let chains = docsets
.into_iter()
.map(|doc| ScorerWrapper::new(doc))
.collect();
let mut disjunction = Self {
chains,
score_combiner,
current_doc: TERMINATED,
minimum_matches_required,
current_score: 0.0,
};
if minimum_matches_required > disjunction.chains.len() {
return disjunction;
}
disjunction.advance();
disjunction
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
for Disjunction<TScorer, TScoreCombiner>
{
fn advance(&mut self) -> DocId {
let mut current_num_matches = 0;
while let Some(mut candidate) = self.chains.pop() {
let next = candidate.doc();
if next != TERMINATED {
// Peek next doc.
if self.current_doc != next {
if current_num_matches >= self.minimum_matches_required {
self.chains.push(candidate);
self.current_score = self.score_combiner.score();
return self.current_doc;
}
// Reset current_num_matches and scores.
current_num_matches = 0;
self.current_doc = next;
self.score_combiner.clear();
}
current_num_matches += 1;
self.score_combiner.update(&mut candidate.scorer);
candidate.advance();
self.chains.push(candidate);
}
}
if current_num_matches < self.minimum_matches_required {
self.current_doc = TERMINATED;
}
self.current_score = self.score_combiner.score();
self.current_doc
}
#[inline]
fn doc(&self) -> DocId {
self.current_doc
}
fn size_hint(&self) -> u32 {
self.chains
.iter()
.map(|docset| docset.size_hint())
.max()
.unwrap_or(0u32)
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
for Disjunction<TScorer, TScoreCombiner>
{
fn score(&mut self) -> Score {
self.current_score
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use super::Disjunction;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::{ConstScorer, Scorer, SumCombiner, VecDocSet};
use crate::{DocId, DocSet, Score, TERMINATED};
fn conjunct<T: Ord + Copy>(arrays: &[Vec<T>], pass_line: usize) -> Vec<T> {
let mut counts = BTreeMap::new();
for array in arrays {
for &element in array {
*counts.entry(element).or_insert(0) += 1;
}
}
counts
.iter()
.filter_map(|(&element, &count)| {
if count >= pass_line {
Some(element)
} else {
None
}
})
.collect()
}
fn aux_test_conjunction(vals: Vec<Vec<u32>>, min_match: usize) {
let mut union_expected = VecDocSet::from(conjunct(&vals, min_match));
let make_scorer = || {
Disjunction::new(
vals.iter()
.cloned()
.map(VecDocSet::from)
.map(|d| ConstScorer::new(d, 1.0)),
DoNothingCombiner,
min_match,
)
};
let mut scorer: Disjunction<_, DoNothingCombiner> = make_scorer();
let mut count = 0;
while scorer.doc() != TERMINATED {
assert_eq!(union_expected.doc(), scorer.doc());
assert_eq!(union_expected.advance(), scorer.advance());
count += 1;
}
assert_eq!(union_expected.advance(), TERMINATED);
assert_eq!(count, make_scorer().count_including_deleted());
}
#[should_panic]
#[test]
fn test_arg_check1() {
aux_test_conjunction(vec![], 0);
}
#[should_panic]
#[test]
fn test_arg_check2() {
aux_test_conjunction(vec![], 1);
}
#[test]
fn test_corner_case() {
aux_test_conjunction(vec![], 2);
aux_test_conjunction(vec![vec![]; 1000], 2);
aux_test_conjunction(vec![vec![]; 100], usize::MAX);
aux_test_conjunction(vec![vec![0xC0FFEE]; 10000], usize::MAX);
aux_test_conjunction((1..10000u32).map(|i| vec![i]).collect::<Vec<_>>(), 2);
}
#[test]
fn test_conjunction() {
aux_test_conjunction(
vec![
vec![1, 3333, 100000000u32],
vec![1, 2, 100000000u32],
vec![1, 2, 100000000u32],
],
2,
);
aux_test_conjunction(
vec![vec![8], vec![3, 4, 0xC0FFEEu32], vec![1, 2, 100000000u32]],
2,
);
aux_test_conjunction(
vec![
vec![1, 3333, 100000000u32],
vec![1, 2, 100000000u32],
vec![1, 2, 100000000u32],
],
3,
)
}
// This dummy scorer does nothing but yield its doc ids in increasing order,
// each with the score recorded alongside it.
#[derive(Clone)]
struct DummyScorer {
cursor: usize,
foo: Vec<(DocId, f32)>,
}
impl DummyScorer {
fn new(doc_score: Vec<(DocId, f32)>) -> Self {
Self {
cursor: 0,
foo: doc_score,
}
}
}
impl DocSet for DummyScorer {
fn advance(&mut self) -> DocId {
self.cursor += 1;
self.doc()
}
fn doc(&self) -> DocId {
self.foo.get(self.cursor).map(|x| x.0).unwrap_or(TERMINATED)
}
fn size_hint(&self) -> u32 {
self.foo.len() as u32
}
}
impl Scorer for DummyScorer {
fn score(&mut self) -> Score {
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
}
}
#[test]
fn test_score_calculate() {
let mut scorer = Disjunction::new(
vec![
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (4, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
],
SumCombiner::default(),
3,
);
assert_eq!(scorer.score(), 5.0);
assert_eq!(scorer.advance(), 2);
assert_eq!(scorer.score(), 3.0);
}
#[test]
fn test_score_calculate_corner_case() {
let mut scorer = Disjunction::new(
vec![
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
],
SumCombiner::default(),
2,
);
assert_eq!(scorer.doc(), 1);
assert_eq!(scorer.score(), 3.0);
assert_eq!(scorer.advance(), 3);
assert_eq!(scorer.score(), 2.0);
}
}
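The reversed `Ord` on `ScorerWrapper` is the usual trick for turning `std::collections::BinaryHeap`, a max-heap, into a min-heap. The same effect in isolation, using `std::cmp::Reverse` on plain doc ids:

use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn main() {
    let mut heap = BinaryHeap::new();
    for doc_id in [7u32, 2, 9] {
        heap.push(Reverse(doc_id));
    }
    // The smallest doc id surfaces first, as `Disjunction::advance` requires.
    assert_eq!(heap.pop(), Some(Reverse(2)));
}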

View File

@@ -149,7 +149,7 @@ mod tests {
use crate::query::exist_query::ExistsQuery;
use crate::query::{BooleanQuery, RangeQuery};
use crate::schema::{Facet, FacetOptions, Schema, FAST, INDEXED, STRING, TEXT};
use crate::{Index, Searcher};
use crate::{Index, Searcher, Term};
#[test]
fn test_exists_query_simple() -> crate::Result<()> {
@@ -188,9 +188,8 @@ mod tests {
// exercise seek
let query = BooleanQuery::intersection(vec![
Box::new(RangeQuery::new_u64_bounds(
"all".to_string(),
Bound::Included(50),
Box::new(RangeQuery::new(
Bound::Included(Term::from_field_u64(all_field, 50)),
Bound::Unbounded,
)),
Box::new(ExistsQuery::new_exists_query("even".to_string())),
@@ -198,10 +197,9 @@ mod tests {
assert_eq!(searcher.search(&query, &Count)?, 25);
let query = BooleanQuery::intersection(vec![
Box::new(RangeQuery::new_u64_bounds(
"all".to_string(),
Bound::Included(0),
Bound::Excluded(50),
Box::new(RangeQuery::new(
Bound::Included(Term::from_field_u64(all_field, 0)),
Bound::Included(Term::from_field_u64(all_field, 50)),
)),
Box::new(ExistsQuery::new_exists_query("odd".to_string())),
]);

View File

@@ -5,6 +5,7 @@ mod bm25;
mod boolean_query;
mod boost_query;
mod const_score_query;
mod disjunction;
mod disjunction_max_query;
mod empty_query;
mod exclude;
@@ -53,7 +54,7 @@ pub use self::phrase_prefix_query::PhrasePrefixQuery;
pub use self::phrase_query::PhraseQuery;
pub use self::query::{EnableScoring, Query, QueryClone};
pub use self::query_parser::{QueryParser, QueryParserError};
pub use self::range_query::{FastFieldRangeWeight, IPFastFieldRangeWeight, RangeQuery};
pub use self::range_query::{FastFieldRangeWeight, RangeQuery};
pub use self::regex_query::RegexQuery;
pub use self::reqopt_scorer::RequiredOptionalScorer;
pub use self::score_combiner::{

View File

@@ -145,15 +145,7 @@ impl Query for PhrasePrefixQuery {
Bound::Unbounded
};
let mut range_query = RangeQuery::new_term_bounds(
enable_scoring
.schema()
.get_field_name(self.field)
.to_owned(),
self.prefix.1.typ(),
&Bound::Included(self.prefix.1.clone()),
&end_term,
);
let mut range_query = RangeQuery::new(Bound::Included(self.prefix.1.clone()), end_term);
range_query.limit(self.max_expansions as u64);
range_query.weight(enable_scoring)
}

View File

@@ -97,6 +97,7 @@ pub struct PhrasePrefixScorer<TPostings: Postings> {
suffixes: Vec<TPostings>,
suffix_offset: u32,
phrase_count: u32,
suffix_position_buffer: Vec<u32>,
}
impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
@@ -140,6 +141,7 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
suffixes,
suffix_offset: (max_offset - suffix_pos) as u32,
phrase_count: 0,
suffix_position_buffer: Vec::with_capacity(100),
};
if phrase_prefix_scorer.doc() != TERMINATED && !phrase_prefix_scorer.matches_prefix() {
phrase_prefix_scorer.advance();
@@ -153,7 +155,6 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
fn matches_prefix(&mut self) -> bool {
let mut count = 0;
let mut positions = Vec::new();
let current_doc = self.doc();
let pos_matching = self.phrase_scorer.get_intersection();
for suffix in &mut self.suffixes {
@@ -162,8 +163,8 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
}
let doc = suffix.seek(current_doc);
if doc == current_doc {
suffix.positions_with_offset(self.suffix_offset, &mut positions);
count += intersection_count(pos_matching, &positions);
suffix.positions_with_offset(self.suffix_offset, &mut self.suffix_position_buffer);
count += intersection_count(pos_matching, &self.suffix_position_buffer);
}
}
self.phrase_count = count as u32;

View File

@@ -2,7 +2,7 @@ use std::fmt;
use std::ops::Bound;
use crate::query::Occur;
use crate::schema::{Term, Type};
use crate::schema::Term;
use crate::Score;
#[derive(Clone)]
@@ -14,8 +14,6 @@ pub enum LogicalLiteral {
prefix: bool,
},
Range {
field: String,
value_type: Type,
lower: Bound<Term>,
upper: Bound<Term>,
},

View File

@@ -790,8 +790,6 @@ impl QueryParser {
let (field, json_path) = try_tuple!(self
.split_full_path(&full_path)
.ok_or_else(|| QueryParserError::FieldDoesNotExist(full_path.clone())));
let field_entry = self.schema.get_field_entry(field);
let value_type = field_entry.field_type().value_type();
let mut errors = Vec::new();
let lower = match self.resolve_bound(field, json_path, &lower) {
Ok(bound) => bound,
@@ -812,12 +810,8 @@ impl QueryParser {
// we failed to parse something. Either way, there is no point emitting it
return (None, errors);
}
let logical_ast = LogicalAst::Leaf(Box::new(LogicalLiteral::Range {
field: self.schema.get_field_name(field).to_string(),
value_type,
lower,
upper,
}));
let logical_ast =
LogicalAst::Leaf(Box::new(LogicalLiteral::Range { lower, upper }));
(Some(logical_ast), errors)
}
UserInputLeaf::Set {
@@ -884,14 +878,7 @@ fn convert_literal_to_query(
Box::new(PhraseQuery::new_with_offset_and_slop(terms, slop))
}
}
LogicalLiteral::Range {
field,
value_type,
lower,
upper,
} => Box::new(RangeQuery::new_term_bounds(
field, value_type, &lower, &upper,
)),
LogicalLiteral::Range { lower, upper } => Box::new(RangeQuery::new(lower, upper)),
LogicalLiteral::Set { elements, .. } => Box::new(TermSetQuery::new(elements)),
LogicalLiteral::All => Box::new(AllQuery),
}
@@ -1136,8 +1123,8 @@ mod test {
let query = make_query_parser().parse_query("title:[A TO B]").unwrap();
assert_eq!(
format!("{query:?}"),
"RangeQuery { field: \"title\", value_type: Str, lower_bound: Included([97]), \
upper_bound: Included([98]), limit: None }"
"RangeQuery { lower_bound: Included(Term(field=0, type=Str, \"a\")), upper_bound: \
Included(Term(field=0, type=Str, \"b\")), limit: None }"
);
}
@@ -1815,7 +1802,8 @@ mod test {
\"bad\"))], prefix: (2, Term(field=0, type=Str, \"wo\")), max_expansions: 50 }), \
(Should, PhrasePrefixQuery { field: Field(1), phrase_terms: [(0, Term(field=1, \
type=Str, \"big\")), (1, Term(field=1, type=Str, \"bad\"))], prefix: (2, \
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })] }"
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })], \
minimum_number_should_match: 1 }"
);
}
@@ -1880,7 +1868,8 @@ mod test {
format!("{query:?}"),
"BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, \
type=Str, \"abc\"), distance: 1, transposition_cost_one: true, prefix: false }), \
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))] }"
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))], \
minimum_number_should_match: 1 }"
);
}
@@ -1897,7 +1886,8 @@ mod test {
format!("{query:?}"),
"BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, \
\"abc\"))), (Should, FuzzyTermQuery { term: Term(field=1, type=Str, \"abc\"), \
distance: 2, transposition_cost_one: false, prefix: true })] }"
distance: 2, transposition_cost_one: false, prefix: true })], \
minimum_number_should_match: 1 }"
);
}
}

View File

@@ -180,10 +180,12 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
#[cfg(test)]
mod tests {
use std::ops::Bound;
use crate::collector::Count;
use crate::directory::RamDirectory;
use crate::query::RangeQuery;
use crate::{schema, IndexBuilder, TantivyDocument};
use crate::{schema, IndexBuilder, TantivyDocument, Term};
#[test]
fn range_query_fast_optional_field_minimum() {
@@ -218,10 +220,9 @@ mod tests {
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let query = RangeQuery::new_u64_bounds(
"score".to_string(),
std::ops::Bound::Included(70),
std::ops::Bound::Unbounded,
let query = RangeQuery::new(
Bound::Included(Term::from_field_u64(score_field, 70)),
Bound::Unbounded,
);
let count = searcher.search(&query, &Count).unwrap();

View File

@@ -2,21 +2,19 @@ use std::ops::Bound;
use crate::schema::Type;
mod fast_field_range_query;
mod fast_field_range_doc_set;
mod range_query;
mod range_query_ip_fastfield;
mod range_query_u64_fastfield;
pub use self::range_query::RangeQuery;
pub use self::range_query_ip_fastfield::IPFastFieldRangeWeight;
pub use self::range_query_u64_fastfield::FastFieldRangeWeight;
// TODO is this correct?
pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
match typ {
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::Str | Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::IpAddr => true,
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
Type::Facet | Type::Bytes | Type::Json => false,
}
}
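Since `Type::Str` is now accepted, a range query over a string fast field can take the fast-field path as well. A hedged sketch (schema and values are illustrative):

use std::ops::Bound;
use tantivy::collector::Count;
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, FAST, STRING};
use tantivy::{doc, Index, IndexWriter, Term};

fn main() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    let city = schema_builder.add_text_field("city", STRING | FAST);
    let index = Index::create_in_ram(schema_builder.build());
    let mut writer: IndexWriter = index.writer(50_000_000)?;
    writer.add_document(doc!(city => "berlin"))?;
    writer.add_document(doc!(city => "oslo"))?;
    writer.commit()?;

    // Term bounds on a str fast field are resolved via term ordinals.
    let query = RangeQuery::new(
        Bound::Included(Term::from_field_text(city, "a")),
        Bound::Excluded(Term::from_field_text(city, "m")),
    );
    let searcher = index.reader()?.searcher();
    assert_eq!(searcher.search(&query, &Count)?, 1); // only "berlin"
    Ok(())
}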

View File

@@ -1,21 +1,17 @@
use std::io;
use std::net::Ipv6Addr;
use std::ops::{Bound, Range};
use std::ops::Bound;
use columnar::MonotonicallyMappableToU128;
use common::{BinarySerializable, BitSet};
use common::BitSet;
use super::map_bound;
use super::range_query_u64_fastfield::FastFieldRangeWeight;
use crate::error::TantivyError;
use crate::index::SegmentReader;
use crate::query::explanation::does_not_match;
use crate::query::range_query::range_query_ip_fastfield::IPFastFieldRangeWeight;
use crate::query::range_query::{is_type_valid_for_fastfield_range_query, map_bound_res};
use crate::query::range_query::is_type_valid_for_fastfield_range_query;
use crate::query::{BitSetDocSet, ConstScorer, EnableScoring, Explanation, Query, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption, Term, Type};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::{DateTime, DocId, Score};
use crate::{DocId, Score};
/// `RangeQuery` matches all documents that have at least one term within a defined range.
///
@@ -40,8 +36,10 @@ use crate::{DateTime, DocId, Score};
/// ```rust
/// use tantivy::collector::Count;
/// use tantivy::query::RangeQuery;
/// use tantivy::Term;
/// use tantivy::schema::{Schema, INDEXED};
/// use tantivy::{doc, Index, IndexWriter};
/// use std::ops::Bound;
/// # fn test() -> tantivy::Result<()> {
/// let mut schema_builder = Schema::builder();
/// let year_field = schema_builder.add_u64_field("year", INDEXED);
@@ -59,7 +57,10 @@ use crate::{DateTime, DocId, Score};
///
/// let reader = index.reader()?;
/// let searcher = reader.searcher();
/// let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960..1970);
/// let docs_in_the_sixties = RangeQuery::new(
/// Bound::Included(Term::from_field_u64(year_field, 1960)),
/// Bound::Excluded(Term::from_field_u64(year_field, 1970)),
/// );
/// let num_60s_books = searcher.search(&docs_in_the_sixties, &Count)?;
/// assert_eq!(num_60s_books, 2285);
/// Ok(())
@@ -68,246 +69,46 @@ use crate::{DateTime, DocId, Score};
/// ```
#[derive(Clone, Debug)]
pub struct RangeQuery {
field: String,
value_type: Type,
lower_bound: Bound<Vec<u8>>,
upper_bound: Bound<Vec<u8>>,
lower_bound: Bound<Term>,
upper_bound: Bound<Term>,
limit: Option<u64>,
}
/// Returns the inner value of a `Bound`
pub(crate) fn inner_bound(val: &Bound<Term>) -> Option<&Term> {
match val {
Bound::Included(term) | Bound::Excluded(term) => Some(term),
Bound::Unbounded => None,
}
}
impl RangeQuery {
/// Creates a new `RangeQuery` from lower and upper bounds given as `Term`s.
///
/// If the value type is not correct, something may go terribly wrong when
/// the `Weight` object is created.
pub fn new_term_bounds(
field: String,
value_type: Type,
lower_bound: &Bound<Term>,
upper_bound: &Bound<Term>,
) -> RangeQuery {
let verify_and_unwrap_term = |val: &Term| val.serialized_value_bytes().to_owned();
pub fn new(lower_bound: Bound<Term>, upper_bound: Bound<Term>) -> RangeQuery {
RangeQuery {
field,
value_type,
lower_bound: map_bound(lower_bound, verify_and_unwrap_term),
upper_bound: map_bound(upper_bound, verify_and_unwrap_term),
lower_bound,
upper_bound,
limit: None,
}
}
/// Creates a new `RangeQuery` over a `i64` field.
///
/// If the field is not of the type `i64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_i64(field: String, range: Range<i64>) -> RangeQuery {
RangeQuery::new_i64_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Create a new `RangeQuery` over a `i64` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `i64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_i64_bounds(
field: String,
lower_bound: Bound<i64>,
upper_bound: Bound<i64>,
) -> RangeQuery {
let make_term_val = |val: &i64| {
Term::from_field_i64(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::I64,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Creates a new `RangeQuery` over a `f64` field.
///
/// If the field is not of the type `f64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_f64(field: String, range: Range<f64>) -> RangeQuery {
RangeQuery::new_f64_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Create a new `RangeQuery` over a `f64` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `f64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_f64_bounds(
field: String,
lower_bound: Bound<f64>,
upper_bound: Bound<f64>,
) -> RangeQuery {
let make_term_val = |val: &f64| {
Term::from_field_f64(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::F64,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `u64` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `u64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_u64_bounds(
field: String,
lower_bound: Bound<u64>,
upper_bound: Bound<u64>,
) -> RangeQuery {
let make_term_val = |val: &u64| {
Term::from_field_u64(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::U64,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `ip` field.
///
/// If the field is not of the type `ip`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_ip_bounds(
field: String,
lower_bound: Bound<Ipv6Addr>,
upper_bound: Bound<Ipv6Addr>,
) -> RangeQuery {
let make_term_val = |val: &Ipv6Addr| {
Term::from_field_ip_addr(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::IpAddr,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `u64` field.
///
/// If the field is not of the type `u64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_u64(field: String, range: Range<u64>) -> RangeQuery {
RangeQuery::new_u64_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Create a new `RangeQuery` over a `date` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `date`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_date_bounds(
field: String,
lower_bound: Bound<DateTime>,
upper_bound: Bound<DateTime>,
) -> RangeQuery {
let make_term_val = |val: &DateTime| {
Term::from_field_date(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::Date,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `date` field.
///
/// If the field is not of the type `date`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_date(field: String, range: Range<DateTime>) -> RangeQuery {
RangeQuery::new_date_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Create a new `RangeQuery` over a `Str` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `Str`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_str_bounds(
field: String,
lower_bound: Bound<&str>,
upper_bound: Bound<&str>,
) -> RangeQuery {
let make_term_val = |val: &&str| val.as_bytes().to_vec();
RangeQuery {
field,
value_type: Type::Str,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `Str` field.
///
/// If the field is not of the type `Str`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_str(field: String, range: Range<&str>) -> RangeQuery {
RangeQuery::new_str_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Field to search over
pub fn field(&self) -> &str {
&self.field
pub fn field(&self) -> Field {
self.get_term().field()
}
/// The value type of the field
pub fn value_type(&self) -> Type {
self.get_term().typ()
}
pub(crate) fn get_term(&self) -> &Term {
inner_bound(&self.lower_bound)
.or(inner_bound(&self.upper_bound))
.expect("At least one bound must be set")
}
/// Limit the number of terms the `RangeQuery` will go through.
@@ -319,70 +120,23 @@ impl RangeQuery {
}
}
/// Returns true if the type maps to a u64 fast field
pub(crate) fn maps_to_u64_fastfield(typ: Type) -> bool {
match typ {
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::IpAddr => false,
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
}
}
impl Query for RangeQuery {
fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
let schema = enable_scoring.schema();
let field_type = schema
.get_field_entry(schema.get_field(&self.field)?)
.field_type();
let value_type = field_type.value_type();
if value_type != self.value_type {
let err_msg = format!(
"Create a range query of the type {:?}, when the field given was of type \
{value_type:?}",
self.value_type
);
return Err(TantivyError::SchemaError(err_msg));
}
let field_type = schema.get_field_entry(self.field()).field_type();
if field_type.is_fast() && is_type_valid_for_fastfield_range_query(self.value_type) {
if field_type.is_ip_addr() {
let parse_ip_from_bytes = |data: &Vec<u8>| {
let ip_u128_bytes: [u8; 16] = data.as_slice().try_into().map_err(|_| {
crate::TantivyError::InvalidArgument(
"Expected 8 bytes for ip address".to_string(),
)
})?;
let ip_u128 = u128::from_be_bytes(ip_u128_bytes);
crate::Result::<Ipv6Addr>::Ok(Ipv6Addr::from_u128(ip_u128))
};
let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?;
let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?;
Ok(Box::new(IPFastFieldRangeWeight::new(
self.field.to_string(),
lower_bound,
upper_bound,
)))
} else {
// We run the range query on u64 value space for performance reasons and simplicity
// assert the type maps to u64
assert!(maps_to_u64_fastfield(self.value_type));
let parse_from_bytes = |data: &Vec<u8>| {
u64::from_be(BinarySerializable::deserialize(&mut &data[..]).unwrap())
};
let lower_bound = map_bound(&self.lower_bound, parse_from_bytes);
let upper_bound = map_bound(&self.upper_bound, parse_from_bytes);
Ok(Box::new(FastFieldRangeWeight::new_u64_lenient(
self.field.to_string(),
lower_bound,
upper_bound,
)))
}
if field_type.is_fast() && is_type_valid_for_fastfield_range_query(self.value_type()) {
Ok(Box::new(FastFieldRangeWeight::new(
self.field(),
self.lower_bound.clone(),
self.upper_bound.clone(),
)))
} else {
let verify_and_unwrap_term = |val: &Term| val.serialized_value_bytes().to_owned();
Ok(Box::new(RangeWeight {
field: self.field.to_string(),
lower_bound: self.lower_bound.clone(),
upper_bound: self.upper_bound.clone(),
field: self.field(),
lower_bound: map_bound(&self.lower_bound, verify_and_unwrap_term),
upper_bound: map_bound(&self.upper_bound, verify_and_unwrap_term),
limit: self.limit,
}))
}
@@ -390,7 +144,7 @@ impl Query for RangeQuery {
}
pub struct RangeWeight {
field: String,
field: Field,
lower_bound: Bound<Vec<u8>>,
upper_bound: Bound<Vec<u8>>,
limit: Option<u64>,
@@ -423,7 +177,7 @@ impl Weight for RangeWeight {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(reader.schema().get_field(&self.field)?)?;
let inverted_index = reader.inverted_index(self.field)?;
let term_dict = inverted_index.terms();
let mut term_range = self.term_range(term_dict)?;
let mut processed_count = 0;
@@ -477,7 +231,7 @@ mod tests {
use crate::schema::{
Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT,
};
use crate::{Index, IndexWriter};
use crate::{Index, IndexWriter, Term};
#[test]
fn test_range_query_simple() -> crate::Result<()> {
@@ -499,7 +253,10 @@ mod tests {
let reader = index.reader()?;
let searcher = reader.searcher();
let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960u64..1970u64);
let docs_in_the_sixties = RangeQuery::new(
Bound::Included(Term::from_field_u64(year_field, 1960)),
Bound::Excluded(Term::from_field_u64(year_field, 1970)),
);
// ... or make the upper bound `Bound::Included(Term::from_field_u64(year_field, 1969))`
// for the equivalent inclusive range `1960..=1969`.
let count = searcher.search(&docs_in_the_sixties, &Count)?;
@@ -530,7 +287,10 @@ mod tests {
let reader = index.reader()?;
let searcher = reader.searcher();
let mut docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960u64..1970u64);
let mut docs_in_the_sixties = RangeQuery::new(
Bound::Included(Term::from_field_u64(year_field, 1960)),
Bound::Excluded(Term::from_field_u64(year_field, 1970)),
);
docs_in_the_sixties.limit(5);
// due to the limit and no docs in 1963, it's really only 1960..=1965
@@ -575,29 +335,29 @@ mod tests {
|range_query: RangeQuery| searcher.search(&range_query, &Count).unwrap();
assert_eq!(
count_multiples(RangeQuery::new_i64("intfield".to_string(), 10..11)),
count_multiples(RangeQuery::new(
Bound::Included(Term::from_field_i64(int_field, 10)),
Bound::Excluded(Term::from_field_i64(int_field, 11)),
)),
9
);
assert_eq!(
count_multiples(RangeQuery::new_i64_bounds(
"intfield".to_string(),
Bound::Included(10),
Bound::Included(11)
count_multiples(RangeQuery::new(
Bound::Included(Term::from_field_i64(int_field, 10)),
Bound::Included(Term::from_field_i64(int_field, 11)),
)),
18
);
assert_eq!(
count_multiples(RangeQuery::new_i64_bounds(
"intfield".to_string(),
Bound::Excluded(9),
Bound::Included(10)
count_multiples(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(int_field, 9)),
Bound::Included(Term::from_field_i64(int_field, 10)),
)),
9
);
assert_eq!(
count_multiples(RangeQuery::new_i64_bounds(
"intfield".to_string(),
Bound::Included(9),
count_multiples(RangeQuery::new(
Bound::Included(Term::from_field_i64(int_field, 9)),
Bound::Unbounded
)),
91
@@ -646,29 +406,29 @@ mod tests {
|range_query: RangeQuery| searcher.search(&range_query, &Count).unwrap();
assert_eq!(
count_multiples(RangeQuery::new_f64("floatfield".to_string(), 10.0..11.0)),
count_multiples(RangeQuery::new(
Bound::Included(Term::from_field_f64(float_field, 10.0)),
Bound::Excluded(Term::from_field_f64(float_field, 11.0)),
)),
9
);
assert_eq!(
count_multiples(RangeQuery::new_f64_bounds(
"floatfield".to_string(),
Bound::Included(10.0),
Bound::Included(11.0)
count_multiples(RangeQuery::new(
Bound::Included(Term::from_field_f64(float_field, 10.0)),
Bound::Included(Term::from_field_f64(float_field, 11.0)),
)),
18
);
assert_eq!(
count_multiples(RangeQuery::new_f64_bounds(
"floatfield".to_string(),
Bound::Excluded(9.0),
Bound::Included(10.0)
count_multiples(RangeQuery::new(
Bound::Excluded(Term::from_field_f64(float_field, 9.0)),
Bound::Included(Term::from_field_f64(float_field, 10.0)),
)),
9
);
assert_eq!(
count_multiples(RangeQuery::new_f64_bounds(
"floatfield".to_string(),
Bound::Included(9.0),
count_multiples(RangeQuery::new(
Bound::Included(Term::from_field_f64(float_field, 9.0)),
Bound::Unbounded
)),
91
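For code migrating off the removed typed constructors, the rewrite is mechanical: each bound becomes a `Term` that carries the field and the value type together. A sketch of the `new_u64_bounds` replacement (the function name is illustrative):

use std::ops::Bound;
use tantivy::query::RangeQuery;
use tantivy::schema::Field;
use tantivy::Term;

// Before: RangeQuery::new_u64_bounds("year".to_string(),
//             Bound::Included(1960), Bound::Excluded(1970))
// After: the field name and value type travel inside the `Term` bounds.
fn sixties(year_field: Field) -> RangeQuery {
    RangeQuery::new(
        Bound::Included(Term::from_field_u64(year_field, 1960)),
        Bound::Excluded(Term::from_field_u64(year_field, 1970)),
    )
}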

View File

@@ -1,512 +0,0 @@
//! IP Fastfields support efficient scanning for range queries.
//! We use this variant only if the fastfield exists, otherwise the default in `range_query` is
//! used, which uses the term dictionary + postings.
use std::net::Ipv6Addr;
use std::ops::{Bound, RangeInclusive};
use columnar::{Column, MonotonicallyMappableToU128};
use crate::query::range_query::fast_field_range_query::RangeDocSet;
use crate::query::{ConstScorer, EmptyScorer, Explanation, Scorer, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
/// `IPFastFieldRangeWeight` uses the ip address fast field to execute range queries.
pub struct IPFastFieldRangeWeight {
field: String,
lower_bound: Bound<Ipv6Addr>,
upper_bound: Bound<Ipv6Addr>,
}
impl IPFastFieldRangeWeight {
/// Creates a new IPFastFieldRangeWeight.
pub fn new(field: String, lower_bound: Bound<Ipv6Addr>, upper_bound: Bound<Ipv6Addr>) -> Self {
Self {
field,
lower_bound,
upper_bound,
}
}
}
impl Weight for IPFastFieldRangeWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let Some(ip_addr_column): Option<Column<Ipv6Addr>> =
reader.fast_fields().column_opt(&self.field)?
else {
return Ok(Box::new(EmptyScorer));
};
let value_range = bound_to_value_range(
&self.lower_bound,
&self.upper_bound,
ip_addr_column.min_value(),
ip_addr_column.max_value(),
);
let docset = RangeDocSet::new(value_range, ip_addr_column);
Ok(Box::new(ConstScorer::new(docset, boost)))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(TantivyError::InvalidArgument(format!(
"Document #({doc}) does not match"
)));
}
let explanation = Explanation::new("Const", scorer.score());
Ok(explanation)
}
}
fn bound_to_value_range(
lower_bound: &Bound<Ipv6Addr>,
upper_bound: &Bound<Ipv6Addr>,
min_value: Ipv6Addr,
max_value: Ipv6Addr,
) -> RangeInclusive<Ipv6Addr> {
let start_value = match lower_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() + 1),
Bound::Unbounded => min_value,
};
let end_value = match upper_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() - 1),
Bound::Unbounded => max_value,
};
start_value..=end_value
}
#[cfg(test)]
pub mod tests {
use proptest::prelude::ProptestConfig;
use proptest::strategy::Strategy;
use proptest::{prop_oneof, proptest};
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::schema::{Schema, FAST, INDEXED, STORED, STRING};
use crate::{Index, IndexWriter};
#[derive(Clone, Debug)]
pub struct Doc {
pub id: String,
pub ip: Ipv6Addr,
}
fn operation_strategy() -> impl Strategy<Value = Doc> {
prop_oneof![
(0u64..10_000u64).prop_map(doc_from_id_1),
(1u64..10_000u64).prop_map(doc_from_id_2),
]
}
pub fn doc_from_id_1(id: u64) -> Doc {
let id = id * 1000;
Doc {
// ip != id
id: id.to_string(),
ip: Ipv6Addr::from_u128(id as u128),
}
}
fn doc_from_id_2(id: u64) -> Doc {
let id = id * 1000;
Doc {
// ip != id
id: (id - 1).to_string(),
ip: Ipv6Addr::from_u128(id as u128),
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_ip_range_for_docs_prop(ops in proptest::collection::vec(operation_strategy(), 1..1000)) {
assert!(test_ip_range_for_docs(&ops).is_ok());
}
}
#[test]
fn test_ip_range_regression1() {
let ops = &[doc_from_id_1(0)];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression2() {
let ops = &[
doc_from_id_1(52),
doc_from_id_1(63),
doc_from_id_1(12),
doc_from_id_2(91),
doc_from_id_2(33),
];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression3() {
let ops = &[doc_from_id_1(1), doc_from_id_1(2), doc_from_id_1(3)];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression3_simple() {
let mut schema_builder = Schema::builder();
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer_for_tests().unwrap();
let ip_addrs: Vec<Ipv6Addr> = [1000, 2000, 3000]
.into_iter()
.map(Ipv6Addr::from_u128)
.collect();
for &ip_addr in &ip_addrs {
writer
.add_document(doc!(ips_field=>ip_addr, ips_field=>ip_addr))
.unwrap();
}
writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let range_weight = IPFastFieldRangeWeight {
field: "ips".to_string(),
lower_bound: Bound::Included(ip_addrs[1]),
upper_bound: Bound::Included(ip_addrs[2]),
};
let count = range_weight.count(searcher.segment_reader(0)).unwrap();
assert_eq!(count, 2);
}
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let ip_field = schema_builder.add_ip_addr_field("ip", STORED | FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
let text_field = schema_builder.add_text_field("id", STRING | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(2, 60_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
text_field => doc.id.to_string(),
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn test_ip_range_for_docs(docs: &[Doc]) -> crate::Result<()> {
let index = create_index_from_docs(docs);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let get_num_hits = |query| searcher.search(&query, &Count).unwrap();
let query_from_text = |text: &str| {
QueryParser::for_index(&index, vec![])
.parse_query(text)
.unwrap()
};
let gen_query_inclusive = |field: &str, ip_range: &RangeInclusive<Ipv6Addr>| {
format!("{field}:[{} TO {}]", ip_range.start(), ip_range.end())
};
let test_sample = |sample_docs: &[Doc]| {
let mut ips: Vec<Ipv6Addr> = sample_docs.iter().map(|doc| doc.ip).collect();
ips.sort();
let ip_range = ips[0]..=ips[1];
let expected_num_hits = docs
.iter()
.filter(|doc| (ips[0]..=ips[1]).contains(&doc.ip))
.count();
let query = gen_query_inclusive("ip", &ip_range);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
let query = gen_query_inclusive("ips", &ip_range);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Intersection search
let id_filter = sample_docs[0].id.to_string();
let expected_num_hits = docs
.iter()
.filter(|doc| ip_range.contains(&doc.ip) && doc.id == id_filter)
.count();
let query = format!(
"{} AND id:{}",
gen_query_inclusive("ip", &ip_range),
&id_filter
);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Intersection search on multivalue ip field
let id_filter = sample_docs[0].id.to_string();
let query = format!(
"{} AND id:{}",
gen_query_inclusive("ips", &ip_range),
&id_filter
);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
};
test_sample(&[docs[0].clone(), docs[0].clone()]);
if docs.len() > 1 {
test_sample(&[docs[0].clone(), docs[1].clone()]);
test_sample(&[docs[1].clone(), docs[1].clone()]);
}
if docs.len() > 2 {
test_sample(&[docs[1].clone(), docs[2].clone()]);
}
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::tests::*;
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::Index;
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"many".to_string() // 90%
};
Doc {
id,
// Multiply by 1000, so that we create many buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_docs(&docs)
}
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn excute_query(
field: &str,
ip_range: RangeInclusive<Ipv6Addr>,
suffix: &str,
index: &Index,
) -> usize {
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(&query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
searcher.search(&query, &(Count)).unwrap()
}
#[bench]
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:veryfew", &index));
}
}

View File

@@ -2,54 +2,34 @@
//! We use this variant only if the fastfield exists, otherwise the default in `range_query` is
//! used, which uses the term dictionary + postings.
use std::net::Ipv6Addr;
use std::ops::{Bound, RangeInclusive};
use columnar::{ColumnType, HasAssociatedColumnType, MonotonicallyMappableToU64};
use columnar::{Column, MonotonicallyMappableToU128, MonotonicallyMappableToU64, StrColumn};
use common::BinarySerializable;
use super::fast_field_range_query::RangeDocSet;
use super::map_bound;
use crate::query::{ConstScorer, EmptyScorer, Explanation, Query, Scorer, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
use super::fast_field_range_doc_set::RangeDocSet;
use super::{map_bound, map_bound_res};
use crate::query::range_query::range_query::inner_bound;
use crate::query::{AllScorer, ConstScorer, EmptyScorer, Explanation, Query, Scorer, Weight};
use crate::schema::{Field, Type};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError, Term};
/// `FastFieldRangeWeight` uses the fast field to execute range queries.
#[derive(Clone, Debug)]
pub struct FastFieldRangeWeight {
field: String,
lower_bound: Bound<u64>,
upper_bound: Bound<u64>,
column_type_opt: Option<ColumnType>,
lower_bound: Bound<Term>,
upper_bound: Bound<Term>,
field: Field,
}
impl FastFieldRangeWeight {
/// Create a new FastFieldRangeWeight, using the u64 representation of any fast field.
pub(crate) fn new_u64_lenient(
field: String,
lower_bound: Bound<u64>,
upper_bound: Bound<u64>,
) -> Self {
let lower_bound = map_bound(&lower_bound, |val| *val);
let upper_bound = map_bound(&upper_bound, |val| *val);
/// Create a new FastFieldRangeWeight
pub(crate) fn new(field: Field, lower_bound: Bound<Term>, upper_bound: Bound<Term>) -> Self {
Self {
field,
lower_bound,
upper_bound,
column_type_opt: None,
}
}
/// Create a new `FastFieldRangeWeight` for a range of a u64-mappable type.
pub fn new<T: HasAssociatedColumnType + MonotonicallyMappableToU64>(
field: String,
lower_bound: Bound<T>,
upper_bound: Bound<T>,
) -> Self {
let lower_bound = map_bound(&lower_bound, |val| val.to_u64());
let upper_bound = map_bound(&upper_bound, |val| val.to_u64());
Self {
field,
lower_bound,
upper_bound,
column_type_opt: Some(T::column_type()),
}
}
}
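Since bounds are now passed as `Term`s, call sites build them with the `Term::from_field_*` constructors instead of raw `u64` values. A minimal sketch of the new call shape, assuming a hypothetical u64 fast field `price_field` (imports as at the top of this file):
// Sketch only: the Term bounds carry both the value and its type, and the
// field is addressed by `Field` instead of by name.
fn price_range_weight(price_field: Field) -> FastFieldRangeWeight {
    FastFieldRangeWeight::new(
        price_field,
        Bound::Included(Term::from_field_u64(price_field, 10)),
        Bound::Excluded(Term::from_field_u64(price_field, 100)),
    )
}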
@@ -65,30 +45,101 @@ impl Query for FastFieldRangeWeight {
impl Weight for FastFieldRangeWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let fast_field_reader = reader.fast_fields();
let column_type_opt: Option<[ColumnType; 1]> =
self.column_type_opt.map(|column_type| [column_type]);
let column_type_opt_ref: Option<&[ColumnType]> = column_type_opt
.as_ref()
.map(|column_types| column_types.as_slice());
let Some((column, _)) =
fast_field_reader.u64_lenient_for_type(column_type_opt_ref, &self.field)?
else {
return Ok(Box::new(EmptyScorer));
};
#[allow(clippy::reversed_empty_ranges)]
let value_range = bound_to_value_range(
&self.lower_bound,
&self.upper_bound,
column.min_value(),
column.max_value(),
)
.unwrap_or(1..=0); // empty range
if value_range.is_empty() {
return Ok(Box::new(EmptyScorer));
// Check if both bounds are Bound::Unbounded
if self.lower_bound == Bound::Unbounded && self.upper_bound == Bound::Unbounded {
return Ok(Box::new(AllScorer::new(reader.max_doc())));
}
let field_name = reader.schema().get_field_name(self.field);
let field_type = reader.schema().get_field_entry(self.field).field_type();
let term = inner_bound(&self.lower_bound)
.or(inner_bound(&self.upper_bound))
.expect("At least one bound must be set");
assert_eq!(
term.typ(),
field_type.value_type(),
"Field is of type {:?}, but got term of type {:?}",
field_type,
term.typ()
);
if field_type.is_ip_addr() {
let parse_ip_from_bytes = |term: &Term| {
term.value().as_ip_addr().ok_or_else(|| {
crate::TantivyError::InvalidArgument("Expected ip address".to_string())
})
};
let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?;
let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?;
let Some(ip_addr_column): Option<Column<Ipv6Addr>> =
reader.fast_fields().column_opt(field_name)?
else {
return Ok(Box::new(EmptyScorer));
};
let value_range = bound_to_value_range_ip(
&lower_bound,
&upper_bound,
ip_addr_column.min_value(),
ip_addr_column.max_value(),
);
let docset = RangeDocSet::new(value_range, ip_addr_column);
Ok(Box::new(ConstScorer::new(docset, boost)))
} else {
let (lower_bound, upper_bound) = if field_type.is_str() {
let Some(str_dict_column): Option<StrColumn> =
reader.fast_fields().str(field_name)?
else {
return Ok(Box::new(EmptyScorer));
};
let dict = str_dict_column.dictionary();
let lower_bound = map_bound(&self.lower_bound, |term| {
term.serialized_value_bytes().to_vec()
});
let upper_bound = map_bound(&self.upper_bound, |term| {
term.serialized_value_bytes().to_vec()
});
// Convert the term bounds into term ordinal bounds
let (lower_bound, upper_bound) =
dict.term_bounds_to_ord(lower_bound, upper_bound)?;
(lower_bound, upper_bound)
} else {
assert!(
maps_to_u64_fastfield(field_type.value_type()),
"{:?}",
field_type
);
let parse_from_bytes = |term: &Term| {
u64::from_be(
BinarySerializable::deserialize(&mut &term.serialized_value_bytes()[..])
.unwrap(),
)
};
let lower_bound = map_bound(&self.lower_bound, parse_from_bytes);
let upper_bound = map_bound(&self.upper_bound, parse_from_bytes);
(lower_bound, upper_bound)
};
let fast_field_reader = reader.fast_fields();
let Some((column, _)) = fast_field_reader.u64_lenient_for_type(None, field_name)?
else {
return Ok(Box::new(EmptyScorer));
};
#[allow(clippy::reversed_empty_ranges)]
let value_range = bound_to_value_range(
&lower_bound,
&upper_bound,
column.min_value(),
column.max_value(),
)
.unwrap_or(1..=0); // empty range
if value_range.is_empty() {
return Ok(Box::new(EmptyScorer));
}
let docset = RangeDocSet::new(value_range, column);
Ok(Box::new(ConstScorer::new(docset, boost)))
}
let docset = RangeDocSet::new(value_range, column);
Ok(Box::new(ConstScorer::new(docset, boost)))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
@@ -104,6 +155,35 @@ impl Weight for FastFieldRangeWeight {
}
}
/// Returns true if the type maps to a u64 fast field
pub(crate) fn maps_to_u64_fastfield(typ: Type) -> bool {
match typ {
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::IpAddr => false,
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
}
}
fn bound_to_value_range_ip(
lower_bound: &Bound<Ipv6Addr>,
upper_bound: &Bound<Ipv6Addr>,
min_value: Ipv6Addr,
max_value: Ipv6Addr,
) -> RangeInclusive<Ipv6Addr> {
let start_value = match lower_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() + 1),
Bound::Unbounded => min_value,
};
let end_value = match upper_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() - 1),
Bound::Unbounded => max_value,
};
start_value..=end_value
}
// Returns `None` if the range cannot be converted to an inclusive range (which equals an
// empty range).
fn bound_to_value_range<T: MonotonicallyMappableToU64>(
@@ -137,11 +217,72 @@ pub mod tests {
use rand::seq::SliceRandom;
use rand::SeedableRng;
use crate::collector::Count;
use crate::collector::{Count, TopDocs};
use crate::query::range_query::range_query_u64_fastfield::FastFieldRangeWeight;
use crate::query::{QueryParser, Weight};
use crate::schema::{NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING};
use crate::{Index, IndexWriter, TERMINATED};
use crate::schema::{
NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING, TEXT,
};
use crate::{Index, IndexWriter, Term, TERMINATED};
#[test]
fn test_text_field_ff_range_query() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("title", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer_for_tests()?;
let title = schema.get_field("title").unwrap();
index_writer.add_document(doc!(
title => "bbb"
))?;
index_writer.add_document(doc!(
title => "ddd"
))?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![title]);
let test_query = |query, num_hits| {
let query = query_parser.parse_query(query).unwrap();
let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
assert_eq!(top_docs.len(), num_hits);
};
test_query("title:[aaa TO ccc]", 1);
test_query("title:[aaa TO bbb]", 1);
test_query("title:[bbb TO bbb]", 1);
test_query("title:[bbb TO ddd]", 2);
test_query("title:[bbb TO eee]", 2);
test_query("title:[bb TO eee]", 2);
test_query("title:[ccc TO ccc]", 0);
test_query("title:[ccc TO ddd]", 1);
test_query("title:[ccc TO eee]", 1);
test_query("title:[aaa TO *}", 2);
test_query("title:[bbb TO *]", 2);
test_query("title:[bb TO *]", 2);
test_query("title:[ccc TO *]", 1);
test_query("title:[ddd TO *]", 1);
test_query("title:[dddd TO *]", 0);
test_query("title:{aaa TO *}", 2);
test_query("title:{bbb TO *]", 1);
test_query("title:{bb TO *]", 2);
test_query("title:{ccc TO *]", 1);
test_query("title:{ddd TO *]", 0);
test_query("title:{dddd TO *]", 0);
test_query("title:[* TO bb]", 0);
test_query("title:[* TO bbb]", 1);
test_query("title:[* TO ccc]", 1);
test_query("title:[* TO ddd]", 2);
test_query("title:[* TO ddd}", 1);
test_query("title:[* TO eee]", 2);
Ok(())
}
#[derive(Clone, Debug)]
pub struct Doc {
@@ -159,14 +300,14 @@ pub mod tests {
fn doc_from_id_1(id: u64) -> Doc {
let id = id * 1000;
Doc {
id_name: id.to_string(),
id_name: format!("id_name{:010}", id),
id,
}
}
fn doc_from_id_2(id: u64) -> Doc {
let id = id * 1000;
Doc {
id_name: (id - 1).to_string(),
id_name: format!("id_name{:010}", id - 1),
id,
}
}
@@ -213,10 +354,10 @@ pub mod tests {
writer.add_document(doc!(field=>52_000u64)).unwrap();
writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let range_query = FastFieldRangeWeight::new_u64_lenient(
"test_field".to_string(),
Bound::Included(50_000),
Bound::Included(50_002),
let range_query = FastFieldRangeWeight::new(
field,
Bound::Included(Term::from_field_u64(field, 50_000)),
Bound::Included(Term::from_field_u64(field, 50_002)),
);
let scorer = range_query
.scorer(searcher.segment_reader(0), 1.0f32)
@@ -254,7 +395,8 @@ pub mod tests {
NumericOptions::default().set_fast().set_indexed(),
);
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
let text_field = schema_builder.add_text_field("id_name", STRING | STORED | FAST);
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
@@ -273,6 +415,7 @@ pub mod tests {
id_f64_field => doc.id as f64,
id_i64_field => doc.id as i64,
text_field => doc.id_name.to_string(),
text_field2 => doc.id_name.to_string(),
))
.unwrap();
}
@@ -317,6 +460,24 @@ pub mod tests {
let query = gen_query_inclusive("ids", ids[0]..=ids[1]);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Text query
{
let test_text_query = |field_name: &str| {
let mut id_names: Vec<&str> =
sample_docs.iter().map(|doc| doc.id_name.as_str()).collect();
id_names.sort();
let expected_num_hits = docs
.iter()
.filter(|doc| (id_names[0]..=id_names[1]).contains(&doc.id_name.as_str()))
.count();
let query = format!("{}:[{} TO {}]", field_name, id_names[0], id_names[1]);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
};
test_text_query("id_name");
test_text_query("id_name_fast");
}
// Exclusive range
let expected_num_hits = docs
.iter()
@@ -394,6 +555,202 @@ pub mod tests {
}
}
#[cfg(test)]
pub mod ip_range_tests {
use proptest::prelude::ProptestConfig;
use proptest::strategy::Strategy;
use proptest::{prop_oneof, proptest};
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::schema::{Schema, FAST, INDEXED, STORED, STRING};
use crate::{Index, IndexWriter};
#[derive(Clone, Debug)]
pub struct Doc {
pub id: String,
pub ip: Ipv6Addr,
}
fn operation_strategy() -> impl Strategy<Value = Doc> {
prop_oneof![
(0u64..10_000u64).prop_map(doc_from_id_1),
(1u64..10_000u64).prop_map(doc_from_id_2),
]
}
pub fn doc_from_id_1(id: u64) -> Doc {
let id = id * 1000;
Doc {
// ip != id
id: id.to_string(),
ip: Ipv6Addr::from_u128(id as u128),
}
}
fn doc_from_id_2(id: u64) -> Doc {
let id = id * 1000;
Doc {
// ip != id
id: (id - 1).to_string(),
ip: Ipv6Addr::from_u128(id as u128),
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_ip_range_for_docs_prop(ops in proptest::collection::vec(operation_strategy(), 1..1000)) {
assert!(test_ip_range_for_docs(&ops).is_ok());
}
}
#[test]
fn test_ip_range_regression1() {
let ops = &[doc_from_id_1(0)];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression2() {
let ops = &[
doc_from_id_1(52),
doc_from_id_1(63),
doc_from_id_1(12),
doc_from_id_2(91),
doc_from_id_2(33),
];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression3() {
let ops = &[doc_from_id_1(1), doc_from_id_1(2), doc_from_id_1(3)];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression3_simple() {
let mut schema_builder = Schema::builder();
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer_for_tests().unwrap();
let ip_addrs: Vec<Ipv6Addr> = [1000, 2000, 3000]
.into_iter()
.map(Ipv6Addr::from_u128)
.collect();
for &ip_addr in &ip_addrs {
writer
.add_document(doc!(ips_field=>ip_addr, ips_field=>ip_addr))
.unwrap();
}
writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let range_weight = FastFieldRangeWeight::new(
ips_field,
Bound::Included(Term::from_field_ip_addr(ips_field, ip_addrs[1])),
Bound::Included(Term::from_field_ip_addr(ips_field, ip_addrs[2])),
);
let count =
crate::query::weight::Weight::count(&range_weight, searcher.segment_reader(0)).unwrap();
assert_eq!(count, 2);
}
pub fn create_index_from_ip_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let ip_field = schema_builder.add_ip_addr_field("ip", STORED | FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
let text_field = schema_builder.add_text_field("id", STRING | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(2, 60_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
text_field => doc.id.to_string(),
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn test_ip_range_for_docs(docs: &[Doc]) -> crate::Result<()> {
let index = create_index_from_ip_docs(docs);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let get_num_hits = |query| searcher.search(&query, &Count).unwrap();
let query_from_text = |text: &str| {
QueryParser::for_index(&index, vec![])
.parse_query(text)
.unwrap()
};
let gen_query_inclusive = |field: &str, ip_range: &RangeInclusive<Ipv6Addr>| {
format!("{field}:[{} TO {}]", ip_range.start(), ip_range.end())
};
let test_sample = |sample_docs: &[Doc]| {
let mut ips: Vec<Ipv6Addr> = sample_docs.iter().map(|doc| doc.ip).collect();
ips.sort();
let ip_range = ips[0]..=ips[1];
let expected_num_hits = docs
.iter()
.filter(|doc| (ips[0]..=ips[1]).contains(&doc.ip))
.count();
let query = gen_query_inclusive("ip", &ip_range);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
let query = gen_query_inclusive("ips", &ip_range);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Intersection search
let id_filter = sample_docs[0].id.to_string();
let expected_num_hits = docs
.iter()
.filter(|doc| ip_range.contains(&doc.ip) && doc.id == id_filter)
.count();
let query = format!(
"{} AND id:{}",
gen_query_inclusive("ip", &ip_range),
&id_filter
);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Intersection search on multivalue ip field
let id_filter = sample_docs[0].id.to_string();
let query = format!(
"{} AND id:{}",
gen_query_inclusive("ips", &ip_range),
&id_filter
);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
};
test_sample(&[docs[0].clone(), docs[0].clone()]);
if docs.len() > 1 {
test_sample(&[docs[0].clone(), docs[1].clone()]);
test_sample(&[docs[1].clone(), docs[1].clone()]);
}
if docs.len() > 2 {
test_sample(&[docs[1].clone(), docs[2].clone()]);
}
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
@@ -601,3 +958,242 @@ mod bench {
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index));
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench_ip {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::ip_range_tests::*;
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::Index;
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"many".to_string() // 90%
};
Doc {
id,
// Multiply by 1000, so that we create many buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_ip_docs(&docs)
}
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn execute_query(
field: &str,
ip_range: RangeInclusive<Ipv6Addr>,
suffix: &str,
index: &Index,
) -> usize {
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(&query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
searcher.search(&query, &(Count)).unwrap()
}
#[bench]
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:veryfew", &index));
}
}

View File

@@ -508,7 +508,7 @@ impl std::fmt::Debug for ValueAddr {
/// An enum representing a value for tantivy to index.
///
/// Any changes need to be reflected in `BinarySerializable` for `ValueType`
/// ** Any changes need to be reflected in `BinarySerializable` for `ValueType` **
///
/// We can't use [schema::Type] or [columnar::ColumnType] here, because they are missing
/// some items like Array and PreTokStr.
@@ -553,7 +553,7 @@ impl BinarySerializable for ValueType {
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let num = u8::deserialize(reader)?;
let type_id = if (0..=12).contains(&num) {
unsafe { std::mem::transmute(num) }
unsafe { std::mem::transmute::<u8, ValueType>(num) }
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidData,

View File

@@ -201,6 +201,11 @@ impl FieldType {
matches!(self, FieldType::IpAddr(_))
}
/// returns true if this is a str field
pub fn is_str(&self) -> bool {
matches!(self, FieldType::Str(_))
}
/// returns true if this is a date field
pub fn is_date(&self) -> bool {
matches!(self, FieldType::Date(_))

View File

@@ -249,15 +249,8 @@ impl Term {
#[inline]
pub fn append_path(&mut self, bytes: &[u8]) -> &mut [u8] {
let len_before = self.0.len();
if bytes.contains(&JSON_END_OF_PATH) {
self.0.extend(
bytes
.iter()
.map(|&b| if b == JSON_END_OF_PATH { b'0' } else { b }),
);
} else {
self.0.extend_from_slice(bytes);
}
assert!(!bytes.contains(&JSON_END_OF_PATH));
self.0.extend_from_slice(bytes);
&mut self.0[len_before..]
}
}

View File

@@ -16,9 +16,7 @@ fn make_test_sstable(suffix: &str) -> FileSlice {
let table = builder.finish().unwrap();
let table = Arc::new(OwnedBytes::new(table));
let slice = common::file_slice::FileSlice::new(table.clone());
slice
common::file_slice::FileSlice::new(table.clone())
}
pub fn criterion_benchmark(c: &mut Criterion) {

View File

@@ -7,7 +7,7 @@ use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tantivy_sstable::{Dictionary, MonotonicU64SSTable};
const CHARSET: &'static [u8] = b"abcdefghij";
const CHARSET: &[u8] = b"abcdefghij";
fn generate_key(rng: &mut impl Rng) -> String {
let len = rng.gen_range(3..12);

View File

@@ -56,6 +56,53 @@ impl Dictionary<VoidSSTable> {
}
}
fn map_bound<TFrom, TTo>(bound: &Bound<TFrom>, transform: impl Fn(&TFrom) -> TTo) -> Bound<TTo> {
use self::Bound::*;
match bound {
Excluded(ref from_val) => Bound::Excluded(transform(from_val)),
Included(ref from_val) => Bound::Included(transform(from_val)),
Unbounded => Unbounded,
}
}
/// Takes a bound and transforms the inner value into a new bound via a closure.
/// The bound variant may change depending on the value returned by the closure.
fn transform_bound_inner<TFrom, TTo>(
bound: &Bound<TFrom>,
transform: impl Fn(&TFrom) -> io::Result<Bound<TTo>>,
) -> io::Result<Bound<TTo>> {
use self::Bound::*;
Ok(match bound {
Excluded(ref from_val) => transform(from_val)?,
Included(ref from_val) => transform(from_val)?,
Unbounded => Unbounded,
})
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TermOrdHit {
/// Exact term ord hit
Exact(TermOrdinal),
/// Next best term ordinal
Next(TermOrdinal),
}
impl TermOrdHit {
fn into_exact(self) -> Option<TermOrdinal> {
match self {
TermOrdHit::Exact(ord) => Some(ord),
TermOrdHit::Next(_) => None,
}
}
fn map<F: FnOnce(TermOrdinal) -> TermOrdinal>(self, f: F) -> Self {
match self {
TermOrdHit::Exact(ord) => TermOrdHit::Exact(f(ord)),
TermOrdHit::Next(ord) => TermOrdHit::Next(f(ord)),
}
}
}
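The crux of the conversion is how a `TermOrdHit` folds back into a `Bound`; a small sketch of the lower-bound rule that `term_bounds_to_ord` applies further below (helper name hypothetical):
// Exact hit: keep the caller's Included/Excluded variant.
// Miss: the next existing ordinal is always included.
fn lower_bound_from_hit(requested: Bound<&[u8]>, hit: TermOrdHit) -> Bound<TermOrdinal> {
    match hit {
        TermOrdHit::Exact(ord) => map_bound(&requested, |_| ord),
        TermOrdHit::Next(ord) => Bound::Included(ord),
    }
}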
impl<TSSTable: SSTable> Dictionary<TSSTable> {
pub fn builder<W: io::Write>(wrt: W) -> io::Result<crate::Writer<W, TSSTable::ValueWriter>> {
Ok(TSSTable::writer(wrt))
@@ -257,6 +304,17 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
key: K,
sstable_delta_reader: &mut DeltaReader<TSSTable::ValueReader>,
) -> io::Result<Option<TermOrdinal>> {
self.decode_up_to_or_next(key, sstable_delta_reader)
.map(|hit| hit.into_exact())
}
/// Decode a DeltaReader up to `key`, returning the matching term ordinal as an
/// exact hit.
///
/// If the key was not found, it returns the next term ordinal instead.
fn decode_up_to_or_next<K: AsRef<[u8]>>(
&self,
key: K,
sstable_delta_reader: &mut DeltaReader<TSSTable::ValueReader>,
) -> io::Result<TermOrdHit> {
let mut term_ord = 0;
let key_bytes = key.as_ref();
let mut ok_bytes = 0;
@@ -265,7 +323,7 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
let suffix = sstable_delta_reader.suffix();
match prefix_len.cmp(&ok_bytes) {
Ordering::Less => return Ok(None), // popped bytes already matched => too far
Ordering::Less => return Ok(TermOrdHit::Next(term_ord)), /* popped bytes already matched => too far */
Ordering::Equal => (),
Ordering::Greater => {
// the ok prefix is less than current entry prefix => continue to next elem
@@ -277,25 +335,26 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
// we have ok_bytes byte of common prefix, check if this key adds more
for (key_byte, suffix_byte) in key_bytes[ok_bytes..].iter().zip(suffix) {
match suffix_byte.cmp(key_byte) {
Ordering::Less => break, // byte too small
Ordering::Equal => ok_bytes += 1, // new matching byte
Ordering::Greater => return Ok(None), // too far
Ordering::Less => break, // byte too small
Ordering::Equal => ok_bytes += 1, // new matching byte
Ordering::Greater => return Ok(TermOrdHit::Next(term_ord)), // too far
}
}
if ok_bytes == key_bytes.len() {
if prefix_len + suffix.len() == ok_bytes {
return Ok(Some(term_ord));
return Ok(TermOrdHit::Exact(term_ord));
} else {
// current key is a prefix of current element, not a match
return Ok(None);
return Ok(TermOrdHit::Next(term_ord));
}
}
term_ord += 1;
}
Ok(None)
Ok(TermOrdHit::Next(term_ord))
}
/// Returns the ordinal associated with a given term.
@@ -312,6 +371,61 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
.map(|opt| opt.map(|ord| ord + first_ordinal))
}
/// Returns the ordinal associated with a given term, or the closest next term id.
/// The closest next term id may not exist.
pub fn term_ord_or_next<K: AsRef<[u8]>>(&self, key: K) -> io::Result<TermOrdHit> {
let key_bytes = key.as_ref();
let Some(block_addr) = self.sstable_index.get_block_with_key(key_bytes) else {
// TODO: it would be more consistent to return the last term id + 1
return Ok(TermOrdHit::Next(u64::MAX));
};
let first_ordinal = block_addr.first_ordinal;
let mut sstable_delta_reader = self.sstable_delta_reader_block(block_addr)?;
self.decode_up_to_or_next(key_bytes, &mut sstable_delta_reader)
.map(|hit| hit.map(|ord| ord + first_ordinal))
}
/// Converts term bounds into term ordinal bounds.
/// This handles several special cases when the term is not exactly in the dictionary.
/// e.g. for a dictionary [bbb, ddd]:
/// lower_bound: Bound::Included(aaa) => Included(0) // "Next" term id
/// lower_bound: Bound::Excluded(aaa) => Included(0) // "Next" term id + Change the Bounds
/// lower_bound: Bound::Included(ccc) => Included(1) // "Next" term id
/// lower_bound: Bound::Excluded(ccc) => Included(1) // "Next" term id + Change the Bounds
/// lower_bound: Bound::Included(zzz) => Included(2) // "Next" term id
/// lower_bound: Bound::Excluded(zzz) => Included(2) // "Next" term id + Change the Bounds
/// For zzz we should have some post-processing to return an empty query.
///
/// upper_bound: Bound::Included(aaa) => Excluded(0) // "Next" term id + Change the bounds
/// upper_bound: Bound::Excluded(aaa) => Excluded(0) // "Next" term id
/// upper_bound: Bound::Included(ccc) => Excluded(1) // Next term id + Change the bounds
/// upper_bound: Bound::Excluded(ccc) => Excluded(1) // Next term id
/// upper_bound: Bound::Included(zzz) => Excluded(2) // Next term id + Change the bounds
/// upper_bound: Bound::Excluded(zzz) => Excluded(2) // Next term id
pub fn term_bounds_to_ord<K: AsRef<[u8]>>(
&self,
lower_bound: Bound<K>,
upper_bound: Bound<K>,
) -> io::Result<(Bound<TermOrdinal>, Bound<TermOrdinal>)> {
let lower_bound = transform_bound_inner(&lower_bound, |start_bound_bytes| {
let ord = self.term_ord_or_next(start_bound_bytes)?;
match ord {
TermOrdHit::Exact(ord) => Ok(map_bound(&lower_bound, |_| ord)),
TermOrdHit::Next(ord) => Ok(Bound::Included(ord)), // Change bounds to included
}
})?;
let upper_bound = transform_bound_inner(&upper_bound, |end_bound_bytes| {
let ord = self.term_ord_or_next(end_bound_bytes)?;
match ord {
TermOrdHit::Exact(ord) => Ok(map_bound(&upper_bound, |_| ord)),
TermOrdHit::Next(ord) => Ok(Bound::Excluded(ord)), // Change bounds to excluded
}
})?;
Ok((lower_bound, upper_bound))
}
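A worked example of the rules above, assuming a dictionary holding exactly ["bbb", "ddd"] (ordinals 0 and 1) as in the doc comment; the function is illustrative, not part of the diff:
fn demo<TSSTable: SSTable>(dict: &Dictionary<TSSTable>) -> io::Result<()> {
    // "ccc" is absent: the lower bound snaps to the next ordinal and is
    // turned into Included; "ddd" is an exact hit and keeps its variant.
    let (lower, upper) = dict.term_bounds_to_ord(
        Bound::Excluded(b"ccc".as_slice()),
        Bound::Included(b"ddd".as_slice()),
    )?;
    assert_eq!(lower, Bound::Included(1));
    assert_eq!(upper, Bound::Included(1));
    Ok(())
}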
/// Returns the term associated with a given term ordinal.
///
/// Term ordinals are defined as the position of the term in
@@ -338,6 +452,45 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
Ok(true)
}
/// Returns the terms for a _sorted_ list of term ordinals.
///
/// Returns true if and only if all terms have been found.
pub fn sorted_ords_to_term_cb<F: FnMut(&[u8]) -> io::Result<()>>(
&self,
ord: impl Iterator<Item = TermOrdinal>,
mut cb: F,
) -> io::Result<bool> {
let mut bytes = Vec::new();
let mut current_block_addr = self.sstable_index.get_block_with_ord(0);
let mut current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
let mut current_ordinal = 0;
for ord in ord {
assert!(ord >= current_ordinal);
// check if block changed for new term_ord
let new_block_addr = self.sstable_index.get_block_with_ord(ord);
if new_block_addr != current_block_addr {
current_block_addr = new_block_addr;
current_ordinal = current_block_addr.first_ordinal;
current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
bytes.clear();
}
// move to ord inside that block
for _ in current_ordinal..=ord {
if !current_sstable_delta_reader.advance()? {
return Ok(false);
}
bytes.truncate(current_sstable_delta_reader.common_prefix_len());
bytes.extend_from_slice(current_sstable_delta_reader.suffix());
}
current_ordinal = ord + 1;
cb(&bytes)?;
}
Ok(true)
}
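A usage sketch for the callback API above, materializing owned terms for a sorted list of ordinals (helper name hypothetical):
fn terms_for_ords<TSSTable: SSTable>(
    dict: &Dictionary<TSSTable>,
    ords: &[TermOrdinal], // must be sorted ascending
) -> io::Result<Vec<Vec<u8>>> {
    let mut terms = Vec::with_capacity(ords.len());
    let all_found = dict.sorted_ords_to_term_cb(ords.iter().copied(), |bytes| {
        terms.push(bytes.to_vec());
        Ok(())
    })?;
    debug_assert!(all_found, "an ordinal exceeded the dictionary size");
    Ok(terms)
}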
/// Returns the number of terms in the dictionary.
pub fn term_info_from_ord(&self, term_ord: TermOrdinal) -> io::Result<Option<TSSTable::Value>> {
// find block in which the term would be
@@ -416,12 +569,13 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
#[cfg(test)]
mod tests {
use std::ops::Range;
use std::ops::{Bound, Range};
use std::sync::{Arc, Mutex};
use common::OwnedBytes;
use super::Dictionary;
use crate::dictionary::TermOrdHit;
use crate::MonotonicU64SSTable;
#[derive(Debug)]
@@ -485,6 +639,140 @@ mod tests {
(dictionary, table)
}
#[test]
fn test_term_to_ord_or_next() {
let dict = {
let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();
builder.insert(b"bbb", &1).unwrap();
builder.insert(b"ddd", &2).unwrap();
let table = builder.finish().unwrap();
let table = Arc::new(PermissionedHandle::new(table));
let slice = common::file_slice::FileSlice::new(table.clone());
Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
};
assert_eq!(dict.term_ord_or_next(b"aaa").unwrap(), TermOrdHit::Next(0));
assert_eq!(dict.term_ord_or_next(b"bbb").unwrap(), TermOrdHit::Exact(0));
assert_eq!(dict.term_ord_or_next(b"bb").unwrap(), TermOrdHit::Next(0));
assert_eq!(dict.term_ord_or_next(b"bbbb").unwrap(), TermOrdHit::Next(1));
assert_eq!(dict.term_ord_or_next(b"dd").unwrap(), TermOrdHit::Next(1));
assert_eq!(dict.term_ord_or_next(b"ddd").unwrap(), TermOrdHit::Exact(1));
assert_eq!(dict.term_ord_or_next(b"dddd").unwrap(), TermOrdHit::Next(2));
// This is not u64::MAX because for very small sstables (only one block),
// we don't store an index, and the pseudo-index always replies that the
// answer lies in block number 0
assert_eq!(
dict.term_ord_or_next(b"zzzzzzz").unwrap(),
TermOrdHit::Next(2)
);
}
#[test]
fn test_term_to_ord_or_next_2() {
let dict = {
let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();
let mut term_ord = 0;
builder.insert(b"bbb", &term_ord).unwrap();
// Fill blocks in between
for elem in 0..50_000 {
term_ord += 1;
let key = format!("ccccc{elem:05X}").into_bytes();
builder.insert(&key, &term_ord).unwrap();
}
term_ord += 1;
builder.insert(b"eee", &term_ord).unwrap();
let table = builder.finish().unwrap();
let table = Arc::new(PermissionedHandle::new(table));
let slice = common::file_slice::FileSlice::new(table.clone());
Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
};
assert_eq!(dict.term_ord(b"bbb").unwrap(), Some(0));
assert_eq!(dict.term_ord_or_next(b"bbb").unwrap(), TermOrdHit::Exact(0));
assert_eq!(dict.term_ord_or_next(b"aaa").unwrap(), TermOrdHit::Next(0));
assert_eq!(dict.term_ord_or_next(b"bb").unwrap(), TermOrdHit::Next(0));
assert_eq!(dict.term_ord_or_next(b"bbbb").unwrap(), TermOrdHit::Next(1));
assert_eq!(
dict.term_ord_or_next(b"ee").unwrap(),
TermOrdHit::Next(50001)
);
assert_eq!(
dict.term_ord_or_next(b"eee").unwrap(),
TermOrdHit::Exact(50001)
);
assert_eq!(
dict.term_ord_or_next(b"eeee").unwrap(),
TermOrdHit::Next(u64::MAX)
);
assert_eq!(
dict.term_ord_or_next(b"zzzzzzz").unwrap(),
TermOrdHit::Next(u64::MAX)
);
}
#[test]
fn test_term_bounds_to_ord() {
let dict = {
let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();
builder.insert(b"bbb", &1).unwrap();
builder.insert(b"ddd", &2).unwrap();
let table = builder.finish().unwrap();
let table = Arc::new(PermissionedHandle::new(table));
let slice = common::file_slice::FileSlice::new(table.clone());
Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
};
// Test cases for lower_bound
let test_lower_bound = |bound, expected| {
assert_eq!(
dict.term_bounds_to_ord::<&[u8]>(bound, Bound::Included(b"ignored"))
.unwrap()
.0,
expected
);
};
test_lower_bound(Bound::Included(b"aaa".as_slice()), Bound::Included(0));
test_lower_bound(Bound::Excluded(b"aaa".as_slice()), Bound::Included(0));
test_lower_bound(Bound::Included(b"bbb".as_slice()), Bound::Included(0));
test_lower_bound(Bound::Excluded(b"bbb".as_slice()), Bound::Excluded(0));
test_lower_bound(Bound::Included(b"ccc".as_slice()), Bound::Included(1));
test_lower_bound(Bound::Excluded(b"ccc".as_slice()), Bound::Included(1));
test_lower_bound(Bound::Included(b"zzz".as_slice()), Bound::Included(2));
test_lower_bound(Bound::Excluded(b"zzz".as_slice()), Bound::Included(2));
// Test cases for upper_bound
let test_upper_bound = |bound, expected| {
assert_eq!(
dict.term_bounds_to_ord::<&[u8]>(Bound::Included(b"ignored"), bound,)
.unwrap()
.1,
expected
);
};
test_upper_bound(Bound::Included(b"ccc".as_slice()), Bound::Excluded(1));
test_upper_bound(Bound::Excluded(b"ccc".as_slice()), Bound::Excluded(1));
test_upper_bound(Bound::Included(b"zzz".as_slice()), Bound::Excluded(2));
test_upper_bound(Bound::Excluded(b"zzz".as_slice()), Bound::Excluded(2));
test_upper_bound(Bound::Included(b"ddd".as_slice()), Bound::Included(1));
test_upper_bound(Bound::Excluded(b"ddd".as_slice()), Bound::Excluded(1));
}
#[test]
fn test_ord_term_conversion() {
let (dic, slice) = make_test_sstable();
@@ -551,6 +839,61 @@ mod tests {
assert!(dic.term_ord(b"1000").unwrap().is_none());
}
#[test]
fn test_ords_term() {
let (dic, _slice) = make_test_sstable();
// Single term
let mut terms = Vec::new();
assert!(dic
.sorted_ords_to_term_cb(100_000..100_001, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap());
assert_eq!(terms, vec![format!("{:05X}", 100_000).into_bytes(),]);
// Single term
let mut terms = Vec::new();
assert!(dic
.sorted_ords_to_term_cb(100_001..100_002, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap());
assert_eq!(terms, vec![format!("{:05X}", 100_001).into_bytes(),]);
// both terms
let mut terms = Vec::new();
assert!(dic
.sorted_ords_to_term_cb(100_000..100_002, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap());
assert_eq!(
terms,
vec![
format!("{:05X}", 100_000).into_bytes(),
format!("{:05X}", 100_001).into_bytes(),
]
);
// Test cross block
let mut terms = Vec::new();
assert!(dic
.sorted_ords_to_term_cb(98653..=98655, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap());
assert_eq!(
terms,
vec![
format!("{:05X}", 98653).into_bytes(),
format!("{:05X}", 98654).into_bytes(),
format!("{:05X}", 98655).into_bytes(),
]
);
}
#[test]
fn test_range() {
let (dic, slice) = make_test_sstable();

View File

@@ -78,6 +78,7 @@ impl ValueWriter for RangeValueWriter {
}
#[cfg(test)]
#[allow(clippy::single_range_in_vec_init)]
mod tests {
use super::*;

View File

@@ -39,8 +39,6 @@ pub fn deserialize_read(buf: &[u8]) -> (usize, u64) {
#[cfg(test)]
mod tests {
use std::u64;
use super::{deserialize_read, serialize};
fn aux_test_int(val: u64, expect_len: usize) {

View File

@@ -54,7 +54,7 @@ fn bench_hashmap_throughput(c: &mut Criterion) {
);
// numbers
let input_bytes = 1_000_000 * 8 as u64;
let input_bytes = 1_000_000 * 8;
group.throughput(Throughput::Bytes(input_bytes));
let numbers: Vec<[u8; 8]> = (0..1_000_000u64).map(|el| el.to_le_bytes()).collect();
@@ -82,7 +82,7 @@ fn bench_hashmap_throughput(c: &mut Criterion) {
let mut rng = StdRng::from_seed([3u8; 32]);
let zipf = zipf::ZipfDistribution::new(10_000, 1.03).unwrap();
let input_bytes = 1_000_000 * 8 as u64;
let input_bytes = 1_000_000 * 8;
group.throughput(Throughput::Bytes(input_bytes));
let zipf_numbers: Vec<[u8; 8]> = (0..1_000_000u64)
.map(|_| zipf.sample(&mut rng).to_le_bytes())
@@ -110,7 +110,7 @@ impl DocIdRecorder {
}
}
fn create_hash_map<'a, T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
fn create_hash_map<T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
let mut map = ArenaHashMap::with_capacity(HASHMAP_SIZE);
for term in terms {
map.mutate_or_create(term.as_ref(), |val| {
@@ -126,7 +126,7 @@ fn create_hash_map<'a, T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaH
map
}
fn create_hash_map_with_expull<'a, T: AsRef<[u8]>>(
fn create_hash_map_with_expull<T: AsRef<[u8]>>(
terms: impl Iterator<Item = (u32, T)>,
) -> ArenaHashMap {
let mut memory_arena = MemoryArena::default();
@@ -145,7 +145,7 @@ fn create_hash_map_with_expull<'a, T: AsRef<[u8]>>(
map
}
fn create_fx_hash_ref_map_with_expull<'a>(
fn create_fx_hash_ref_map_with_expull(
terms: impl Iterator<Item = &'static [u8]>,
) -> FxHashMap<&'static [u8], Vec<u32>> {
let terms = terms.enumerate();
@@ -158,7 +158,7 @@ fn create_fx_hash_ref_map_with_expull<'a>(
map
}
fn create_fx_hash_owned_map_with_expull<'a>(
fn create_fx_hash_owned_map_with_expull(
terms: impl Iterator<Item = &'static [u8]>,
) -> FxHashMap<Vec<u8>, Vec<u32>> {
let terms = terms.enumerate();

View File

@@ -10,7 +10,7 @@ fn main() {
}
}
fn create_hash_map<'a, T: AsRef<str>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
fn create_hash_map<T: AsRef<str>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
let mut map = ArenaHashMap::with_capacity(4);
for term in terms {
map.mutate_or_create(term.as_ref().as_bytes(), |val| {