mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-05 16:52:55 +00:00
Compare commits
108 Commits
low_card_o
...
stuhood.la
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
147214b0eb | ||
|
|
865a12f4bb | ||
|
|
00110312c9 | ||
|
|
b2e980b450 | ||
|
|
1a701b86bd | ||
|
|
ee4538d6c2 | ||
|
|
25f1e9aa9f | ||
|
|
6b03b28bac | ||
|
|
7a5241cb83 | ||
|
|
0f5e0f6f87 | ||
|
|
a654115d9a | ||
|
|
1a17515ead | ||
|
|
0f1b0ce527 | ||
|
|
0c920dfc61 | ||
|
|
996fc936f6 | ||
|
|
5ff38e1605 | ||
|
|
e8a4adeedd | ||
|
|
efc9e585a9 | ||
|
|
f4252fc184 | ||
|
|
53c067d1f3 | ||
|
|
259c1ed965 | ||
|
|
1afc432df8 | ||
|
|
b8acd3ac94 | ||
|
|
b5321d2125 | ||
|
|
ad3e2363fe | ||
|
|
9ec5750c25 | ||
|
|
03f09a2b5b | ||
|
|
9ffe4af096 | ||
|
|
c56ddcb6d7 | ||
|
|
5b8fff154b | ||
|
|
ff6ee3a5db | ||
|
|
eda9aa437f | ||
|
|
538da08eb5 | ||
|
|
7bd5cc5417 | ||
|
|
5d46137556 | ||
|
|
92c784f697 | ||
|
|
b3541d10e1 | ||
|
|
7183ac6cbc | ||
|
|
e0476d2eb2 | ||
|
|
9fe0899934 | ||
|
|
aaa5abb7d6 | ||
|
|
f8b8fd0321 | ||
|
|
cd878a5c90 | ||
|
|
30c237e895 | ||
|
|
b6cd39872b | ||
|
|
c96d801c68 | ||
|
|
7a13e0294d | ||
|
|
20d00701ee | ||
|
|
526afc6111 | ||
|
|
f9e4a8413b | ||
|
|
58124bb164 | ||
|
|
176f7e852a | ||
|
|
cfa5f94114 | ||
|
|
5e449e7dda | ||
|
|
1617459b01 | ||
|
|
0e1a7e213e | ||
|
|
b0660ba196 | ||
|
|
936d6af471 | ||
|
|
2560de3a01 | ||
|
|
75a8384c2b | ||
|
|
5b6da9123c | ||
|
|
8b7db36c99 | ||
|
|
eabe589814 | ||
|
|
65d3574dfd | ||
|
|
26d623c411 | ||
|
|
0552dddeb9 | ||
|
|
1b88bb61f9 | ||
|
|
16da31cf06 | ||
|
|
658b9b22e0 | ||
|
|
95661fba30 | ||
|
|
ddd169b77c | ||
|
|
bb4c4b8522 | ||
|
|
ffa558e3a9 | ||
|
|
a35e3dcb5a | ||
|
|
1e3998fbad | ||
|
|
f3df079d6b | ||
|
|
f7c0335857 | ||
|
|
2584325e0d | ||
|
|
1f2c2d0c8a | ||
|
|
91db6909d1 | ||
|
|
7639b47615 | ||
|
|
8b55f0f355 | ||
|
|
8d29f19110 | ||
|
|
d742d3277a | ||
|
|
3afe3714a2 | ||
|
|
67ea8e53a8 | ||
|
|
3adc85c017 | ||
|
|
6bb3a22c98 | ||
|
|
5503cfb8ef | ||
|
|
ea0e88ae4b | ||
|
|
dee2dd3f21 | ||
|
|
794ff1ffc9 | ||
|
|
c6912ce89a | ||
|
|
618e3bd11b | ||
|
|
b2f99c6217 | ||
|
|
76de5bab6f | ||
|
|
b7eb31162b | ||
|
|
63c66005db | ||
|
|
7d513a44c5 | ||
|
|
ca87fcd454 | ||
|
|
08a92675dc | ||
|
|
f7f4b354d6 | ||
|
|
25d44fcec8 | ||
|
|
842fe9295f | ||
|
|
f88b7200b2 | ||
|
|
8725594d47 | ||
|
|
43a784671a | ||
|
|
c363bbd23d |
29
.github/workflows/coverage.yml
vendored
29
.github/workflows/coverage.yml
vendored
@@ -1,29 +0,0 @@
|
||||
name: Coverage
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
# Ensures that we cancel running jobs for the same PR / same workflow.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
coverage:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install Rust
|
||||
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- uses: taiki-e/install-action@cargo-llvm-cov
|
||||
- name: Generate code coverage
|
||||
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
continue-on-error: true
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
|
||||
files: lcov.info
|
||||
fail_ci_if_error: true
|
||||
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -76,7 +76,9 @@ jobs:
|
||||
profile: minimal
|
||||
override: true
|
||||
|
||||
- uses: taiki-e/install-action@nextest
|
||||
- uses: taiki-e/install-action@v2
|
||||
with:
|
||||
tool: 'nextest'
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Run tests
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -6,7 +6,6 @@ target
|
||||
target/debug
|
||||
.vscode
|
||||
target/release
|
||||
Cargo.lock
|
||||
benchmark
|
||||
.DS_Store
|
||||
*.bk
|
||||
@@ -15,3 +14,7 @@ trace.dat
|
||||
cargo-timing*
|
||||
control
|
||||
variable
|
||||
|
||||
# for `sample record -p`
|
||||
profile.json
|
||||
profile.json.gz
|
||||
|
||||
@@ -78,7 +78,7 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
|
||||
|
||||
- **Store DateTime as nanoseconds in doc store** DateTime in the doc store was truncated to microseconds previously. This removes this truncation, while still keeping backwards compatibility. [#2486](https://github.com/quickwit-oss/tantivy/pull/2486)(@PSeitz)
|
||||
|
||||
- **Performace/Memory**
|
||||
- **Performance/Memory**
|
||||
- lift clauses in LogicalAst for optimized ast during execution [#2449](https://github.com/quickwit-oss/tantivy/pull/2449)(@PSeitz)
|
||||
- Use Vec instead of BTreeMap to back OwnedValue object [#2364](https://github.com/quickwit-oss/tantivy/pull/2364)(@fulmicoton)
|
||||
- Replace TantivyDocument with CompactDoc. CompactDoc is much smaller and provides similar performance. [#2402](https://github.com/quickwit-oss/tantivy/pull/2402)(@PSeitz)
|
||||
|
||||
2361
Cargo.lock
generated
Normal file
2361
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
34
Cargo.toml
34
Cargo.toml
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tantivy"
|
||||
version = "0.25.0"
|
||||
version = "0.26.0"
|
||||
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
|
||||
license = "MIT"
|
||||
categories = ["database-implementations", "data-structures"]
|
||||
@@ -21,11 +21,11 @@ byteorder = "1.4.3"
|
||||
crc32fast = "1.3.2"
|
||||
once_cell = "1.10.0"
|
||||
regex = { version = "1.5.5", default-features = false, features = [
|
||||
"std",
|
||||
"unicode",
|
||||
"std",
|
||||
"unicode",
|
||||
] }
|
||||
aho-corasick = "1.0"
|
||||
tantivy-fst = "0.5"
|
||||
tantivy-fst = { git = "https://github.com/paradedb/fst.git" }
|
||||
memmap2 = { version = "0.9.0", optional = true }
|
||||
lz4_flex = { version = "0.11", default-features = false, optional = true }
|
||||
zstd = { version = "0.13", optional = true, default-features = false }
|
||||
@@ -38,9 +38,10 @@ levenshtein_automata = "0.2.1"
|
||||
uuid = { version = "1.0.0", features = ["v4", "serde"] }
|
||||
crossbeam-channel = "0.5.4"
|
||||
rust-stemmers = "1.2.0"
|
||||
tantivy-stemmers = { version = "0.4.0", default-features = false, features = ["polish_yarovoy"] }
|
||||
downcast-rs = "2.0.1"
|
||||
bitpacking = { version = "0.9.2", default-features = false, features = [
|
||||
"bitpacker4x",
|
||||
"bitpacker4x",
|
||||
] }
|
||||
census = "0.4.2"
|
||||
rustc-hash = "2.0.0"
|
||||
@@ -48,6 +49,10 @@ thiserror = "2.0.1"
|
||||
htmlescape = "0.3.1"
|
||||
fail = { version = "0.5.0", optional = true }
|
||||
time = { version = "0.3.35", features = ["serde-well-known"] }
|
||||
# TODO: We have integer wrappers with PartialOrd, and a misfeature of
|
||||
# `deranged` causes inference to fail in a bunch of cases. See
|
||||
# https://github.com/jhpratt/deranged/issues/18#issuecomment-2746844093
|
||||
deranged = "=0.4.0"
|
||||
smallvec = "1.8.0"
|
||||
rayon = "1.5.2"
|
||||
lru = "0.12.0"
|
||||
@@ -69,6 +74,7 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
|
||||
futures-util = { version = "0.3.28", optional = true }
|
||||
futures-channel = { version = "0.3.28", optional = true }
|
||||
fnv = "1.0.7"
|
||||
parking_lot = "0.12.4"
|
||||
typetag = "0.2.21"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
@@ -88,7 +94,7 @@ more-asserts = "0.3.1"
|
||||
rand_distr = "0.4.3"
|
||||
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
|
||||
postcard = { version = "1.0.4", features = [
|
||||
"use-std",
|
||||
"use-std",
|
||||
], default-features = false }
|
||||
|
||||
[target.'cfg(not(windows))'.dev-dependencies]
|
||||
@@ -135,14 +141,14 @@ compare_hash_only = ["stacker/compare_hash_only"]
|
||||
|
||||
[workspace]
|
||||
members = [
|
||||
"query-grammar",
|
||||
"bitpacker",
|
||||
"common",
|
||||
"ownedbytes",
|
||||
"stacker",
|
||||
"sstable",
|
||||
"tokenizer-api",
|
||||
"columnar",
|
||||
"query-grammar",
|
||||
"bitpacker",
|
||||
"common",
|
||||
"ownedbytes",
|
||||
"stacker",
|
||||
"sstable",
|
||||
"tokenizer-api",
|
||||
"columnar",
|
||||
]
|
||||
|
||||
# Following the "fail" crate best practises, we isolate
|
||||
|
||||
@@ -123,6 +123,7 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
|
||||
- [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
|
||||
- [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
|
||||
- [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
|
||||
- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
|
||||
- and [more](https://github.com/search?q=tantivy)!
|
||||
|
||||
### On average, how much faster is Tantivy compared to Lucene?
|
||||
|
||||
2
TODO.txt
2
TODO.txt
@@ -10,7 +10,7 @@ rename FastFieldReaders::open to load
|
||||
remove fast field reader
|
||||
|
||||
find a way to unify the two DateTime.
|
||||
readd type check in the filter wrapper
|
||||
re-add type check in the filter wrapper
|
||||
|
||||
add unit test on columnar list columns.
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use binggan::plugins::PeakMemAllocPlugin;
|
||||
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
|
||||
use rand::distributions::WeightedIndex;
|
||||
use rand::prelude::SliceRandom;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
@@ -54,11 +55,19 @@ fn bench_agg(mut group: InputGroup<Index>) {
|
||||
register!(group, extendedstats_f64);
|
||||
register!(group, percentiles_f64);
|
||||
register!(group, terms_few);
|
||||
register!(group, terms_all_unique);
|
||||
register!(group, terms_many);
|
||||
register!(group, terms_many_top_1000);
|
||||
register!(group, terms_many_order_by_term);
|
||||
register!(group, terms_many_with_top_hits);
|
||||
register!(group, terms_all_unique_with_avg_sub_agg);
|
||||
register!(group, terms_many_with_avg_sub_agg);
|
||||
register!(group, terms_few_with_avg_sub_agg);
|
||||
register!(group, terms_status_with_avg_sub_agg);
|
||||
register!(group, terms_status);
|
||||
register!(group, terms_few_with_histogram);
|
||||
register!(group, terms_status_with_histogram);
|
||||
|
||||
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
|
||||
|
||||
register!(group, cardinality_agg);
|
||||
@@ -130,12 +139,12 @@ fn extendedstats_f64(index: &Index) {
|
||||
}
|
||||
fn percentiles_f64(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"mypercentiles": {
|
||||
"percentiles": {
|
||||
"field": "score_f64",
|
||||
"percents": [ 95, 99, 99.9 ]
|
||||
"mypercentiles": {
|
||||
"percentiles": {
|
||||
"field": "score_f64",
|
||||
"percents": [ 95, 99, 99.9 ]
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
@@ -172,6 +181,19 @@ fn terms_few(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_few_terms_status" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_all_unique(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_many(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": { "terms": { "field": "text_many_terms" } },
|
||||
@@ -220,6 +242,63 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_all_unique_terms" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_few_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_histogram(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms_status" },
|
||||
"aggs": {
|
||||
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
|
||||
}
|
||||
}
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_few_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
fn terms_status_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
"terms": { "field": "text_few_terms_status" },
|
||||
"aggs": {
|
||||
"average_f64": { "avg": { "field": "score_f64" } }
|
||||
}
|
||||
},
|
||||
});
|
||||
execute_agg(index, agg_req);
|
||||
}
|
||||
|
||||
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
|
||||
let agg_req = json!({
|
||||
"my_texts": {
|
||||
@@ -404,14 +483,21 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
.set_stored();
|
||||
let text_field = schema_builder.add_text_field("text", text_fieldtype);
|
||||
let json_field = schema_builder.add_json_field("json", FAST);
|
||||
let text_field_all_unique_terms =
|
||||
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
|
||||
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
|
||||
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
|
||||
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
|
||||
let text_field_few_terms_status =
|
||||
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
|
||||
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
|
||||
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
|
||||
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
|
||||
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
|
||||
let index = Index::create_from_tempdir(schema_builder.build())?;
|
||||
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
|
||||
// Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare.
|
||||
let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap();
|
||||
|
||||
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
|
||||
|
||||
@@ -427,15 +513,21 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
index_writer.add_document(doc!())?;
|
||||
}
|
||||
if cardinality == Cardinality::Multivalued {
|
||||
let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)];
|
||||
let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)];
|
||||
index_writer.add_document(doc!(
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
json_field => json!({"mixed_type": 10.0}),
|
||||
text_field => "cool",
|
||||
text_field => "cool",
|
||||
text_field_all_unique_terms => "cool",
|
||||
text_field_all_unique_terms => "coolo",
|
||||
text_field_many_terms => "cool",
|
||||
text_field_many_terms => "cool",
|
||||
text_field_few_terms => "cool",
|
||||
text_field_few_terms => "cool",
|
||||
text_field_few_terms_status => log_level_sample_a,
|
||||
text_field_few_terms_status => log_level_sample_b,
|
||||
score_field => 1u64,
|
||||
score_field => 1u64,
|
||||
score_field_f64 => lg_norm.sample(&mut rng),
|
||||
@@ -460,8 +552,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
|
||||
index_writer.add_document(doc!(
|
||||
text_field => "cool",
|
||||
json_field => json,
|
||||
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
|
||||
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
|
||||
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
|
||||
text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)],
|
||||
score_field => val as u64,
|
||||
score_field_f64 => lg_norm.sample(&mut rng),
|
||||
score_field_i64 => val as i64,
|
||||
|
||||
@@ -16,14 +16,15 @@
|
||||
// - This bench isolates boolean iteration speed and intersection/union cost.
|
||||
// - Use `cargo bench --bench boolean_conjunction` to run.
|
||||
|
||||
use binggan::{black_box, BenchRunner};
|
||||
use binggan::{black_box, BenchGroup, BenchRunner};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use tantivy::collector::{Count, TopDocs};
|
||||
use tantivy::query::QueryParser;
|
||||
use tantivy::schema::{Schema, TEXT};
|
||||
use tantivy::{doc, Index, ReloadPolicy, Searcher};
|
||||
use tantivy::collector::sort_key::SortByStaticFastValue;
|
||||
use tantivy::collector::{Collector, Count, TopDocs};
|
||||
use tantivy::query::{Query, QueryParser};
|
||||
use tantivy::schema::{Schema, FAST, TEXT};
|
||||
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct BenchIndex {
|
||||
@@ -33,23 +34,6 @@ struct BenchIndex {
|
||||
query_parser: QueryParser,
|
||||
}
|
||||
|
||||
impl BenchIndex {
|
||||
#[inline(always)]
|
||||
fn count_query(&self, query_str: &str) -> usize {
|
||||
let query = self.query_parser.parse_query(query_str).unwrap();
|
||||
self.searcher.search(&query, &Count).unwrap()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn topk_len(&self, query_str: &str, k: usize) -> usize {
|
||||
let query = self.query_parser.parse_query(query_str).unwrap();
|
||||
self.searcher
|
||||
.search(&query, &TopDocs::with_limit(k))
|
||||
.unwrap()
|
||||
.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a single index containing both fields (title, body) and
|
||||
/// return two BenchIndex views:
|
||||
/// - single_field: QueryParser defaults to only "body"
|
||||
@@ -59,6 +43,8 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
|
||||
let mut schema_builder = Schema::builder();
|
||||
let f_title = schema_builder.add_text_field("title", TEXT);
|
||||
let f_body = schema_builder.add_text_field("body", TEXT);
|
||||
let f_score = schema_builder.add_u64_field("score", FAST);
|
||||
let f_score2 = schema_builder.add_u64_field("score2", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
|
||||
@@ -67,11 +53,13 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
|
||||
|
||||
// Populate: spread each present token 90/10 to body/title
|
||||
{
|
||||
let mut writer = index.writer(500_000_000).unwrap();
|
||||
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
|
||||
for _ in 0..num_docs {
|
||||
let has_a = rng.gen_bool(p_a as f64);
|
||||
let has_b = rng.gen_bool(p_b as f64);
|
||||
let has_c = rng.gen_bool(p_c as f64);
|
||||
let score = rng.gen_range(0u64..100u64);
|
||||
let score2 = rng.gen_range(0u64..100_000u64);
|
||||
let mut title_tokens: Vec<&str> = Vec::new();
|
||||
let mut body_tokens: Vec<&str> = Vec::new();
|
||||
if has_a {
|
||||
@@ -101,7 +89,9 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
|
||||
writer
|
||||
.add_document(doc!(
|
||||
f_title=>title_tokens.join(" "),
|
||||
f_body=>body_tokens.join(" ")
|
||||
f_body=>body_tokens.join(" "),
|
||||
f_score=>score,
|
||||
f_score2=>score2,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
@@ -153,72 +143,76 @@ fn main() {
|
||||
),
|
||||
];
|
||||
|
||||
let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
|
||||
|
||||
let mut runner = BenchRunner::new();
|
||||
for (label, n, pa, pb, pc) in scenarios {
|
||||
let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);
|
||||
|
||||
// Single-field group: default field is body only
|
||||
for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
|
||||
{
|
||||
// Single-field group: default field is body only
|
||||
let mut group = runner.new_group();
|
||||
group.set_name(format!("single_field — {}", label));
|
||||
group.register_with_input("+a_+b_count", &single_view, |benv: &BenchIndex| {
|
||||
black_box(benv.count_query("+a +b"))
|
||||
});
|
||||
group.register_with_input("+a_+b_+c_count", &single_view, |benv: &BenchIndex| {
|
||||
black_box(benv.count_query("+a +b +c"))
|
||||
});
|
||||
group.register_with_input("+a_+b_top10", &single_view, |benv: &BenchIndex| {
|
||||
black_box(benv.topk_len("+a +b", 10))
|
||||
});
|
||||
group.register_with_input("+a_+b_+c_top10", &single_view, |benv: &BenchIndex| {
|
||||
black_box(benv.topk_len("+a +b +c", 10))
|
||||
});
|
||||
// OR queries
|
||||
group.register_with_input("a_OR_b_count", &single_view, |benv: &BenchIndex| {
|
||||
black_box(benv.count_query("a OR b"))
|
||||
});
|
||||
group.register_with_input("a_OR_b_OR_c_count", &single_view, |benv: &BenchIndex| {
|
||||
black_box(benv.count_query("a OR b OR c"))
|
||||
});
|
||||
group.register_with_input("a_OR_b_top10", &single_view, |benv: &BenchIndex| {
|
||||
black_box(benv.topk_len("a OR b", 10))
|
||||
});
|
||||
group.register_with_input("a_OR_b_OR_c_top10", &single_view, |benv: &BenchIndex| {
|
||||
black_box(benv.topk_len("a OR b OR c", 10))
|
||||
});
|
||||
group.run();
|
||||
}
|
||||
|
||||
// Multi-field group: default fields are [title, body]
|
||||
{
|
||||
let mut group = runner.new_group();
|
||||
group.set_name(format!("multi_field — {}", label));
|
||||
group.register_with_input("+a_+b_count", &multi_view, |benv: &BenchIndex| {
|
||||
black_box(benv.count_query("+a +b"))
|
||||
});
|
||||
group.register_with_input("+a_+b_+c_count", &multi_view, |benv: &BenchIndex| {
|
||||
black_box(benv.count_query("+a +b +c"))
|
||||
});
|
||||
group.register_with_input("+a_+b_top10", &multi_view, |benv: &BenchIndex| {
|
||||
black_box(benv.topk_len("+a +b", 10))
|
||||
});
|
||||
group.register_with_input("+a_+b_+c_top10", &multi_view, |benv: &BenchIndex| {
|
||||
black_box(benv.topk_len("+a +b +c", 10))
|
||||
});
|
||||
// OR queries
|
||||
group.register_with_input("a_OR_b_count", &multi_view, |benv: &BenchIndex| {
|
||||
black_box(benv.count_query("a OR b"))
|
||||
});
|
||||
group.register_with_input("a_OR_b_OR_c_count", &multi_view, |benv: &BenchIndex| {
|
||||
black_box(benv.count_query("a OR b OR c"))
|
||||
});
|
||||
group.register_with_input("a_OR_b_top10", &multi_view, |benv: &BenchIndex| {
|
||||
black_box(benv.topk_len("a OR b", 10))
|
||||
});
|
||||
group.register_with_input("a_OR_b_OR_c_top10", &multi_view, |benv: &BenchIndex| {
|
||||
black_box(benv.topk_len("a OR b OR c", 10))
|
||||
});
|
||||
group.set_name(format!("{} — {}", view_name, label));
|
||||
for query_str in queries {
|
||||
add_bench_task(&mut group, &bench_index, query_str, Count, "count");
|
||||
add_bench_task(
|
||||
&mut group,
|
||||
&bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(10).order_by_score(),
|
||||
"top10",
|
||||
);
|
||||
add_bench_task(
|
||||
&mut group,
|
||||
&bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(10).order_by_fast_field::<u64>("score", Order::Asc),
|
||||
"top10_by_ff",
|
||||
);
|
||||
add_bench_task(
|
||||
&mut group,
|
||||
&bench_index,
|
||||
query_str,
|
||||
TopDocs::with_limit(10).order_by((
|
||||
SortByStaticFastValue::<u64>::for_field("score"),
|
||||
SortByStaticFastValue::<u64>::for_field("score2"),
|
||||
)),
|
||||
"top10_by_2ff",
|
||||
);
|
||||
}
|
||||
group.run();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_bench_task<C: Collector + 'static>(
|
||||
bench_group: &mut BenchGroup,
|
||||
bench_index: &BenchIndex,
|
||||
query_str: &str,
|
||||
collector: C,
|
||||
collector_name: &str,
|
||||
) {
|
||||
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
|
||||
let query = bench_index.query_parser.parse_query(query_str).unwrap();
|
||||
let search_task = SearchTask {
|
||||
searcher: bench_index.searcher.clone(),
|
||||
collector,
|
||||
query,
|
||||
};
|
||||
bench_group.register(task_name, move |_| black_box(search_task.run()));
|
||||
}
|
||||
|
||||
struct SearchTask<C: Collector> {
|
||||
searcher: Searcher,
|
||||
collector: C,
|
||||
query: Box<dyn Query>,
|
||||
}
|
||||
|
||||
impl<C: Collector> SearchTask<C> {
|
||||
#[inline(never)]
|
||||
pub fn run(&self) -> usize {
|
||||
self.searcher.search(&self.query, &self.collector).unwrap();
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,9 +11,6 @@ keywords = []
|
||||
documentation = "https://docs.rs/tantivy-bitpacker/latest/tantivy_bitpacker"
|
||||
homepage = "https://github.com/quickwit-oss/tantivy"
|
||||
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ impl BitPacker {
|
||||
|
||||
pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
|
||||
if self.mini_buffer_written > 0 {
|
||||
let num_bytes = self.mini_buffer_written.div_ceil(8);
|
||||
let num_bytes = (self.mini_buffer_written + 7) / 8;
|
||||
let bytes = self.mini_buffer.to_le_bytes();
|
||||
output.write_all(&bytes[..num_bytes])?;
|
||||
self.mini_buffer_written = 0;
|
||||
@@ -65,10 +65,16 @@ impl BitPacker {
|
||||
|
||||
#[derive(Clone, Debug, Default, Copy)]
|
||||
pub struct BitUnpacker {
|
||||
num_bits: usize,
|
||||
num_bits: u32,
|
||||
mask: u64,
|
||||
}
|
||||
|
||||
pub type BlockNumber = usize;
|
||||
|
||||
// 16k
|
||||
const BLOCK_SIZE_MIN_POW: u8 = 14;
|
||||
const BLOCK_SIZE_MIN: usize = 2 << BLOCK_SIZE_MIN_POW;
|
||||
|
||||
impl BitUnpacker {
|
||||
/// Creates a bit unpacker, that assumes the same bitwidth for all values.
|
||||
///
|
||||
@@ -82,8 +88,9 @@ impl BitUnpacker {
|
||||
} else {
|
||||
(1u64 << num_bits) - 1u64
|
||||
};
|
||||
|
||||
BitUnpacker {
|
||||
num_bits: usize::from(num_bits),
|
||||
num_bits: u32::from(num_bits),
|
||||
mask,
|
||||
}
|
||||
}
|
||||
@@ -92,16 +99,69 @@ impl BitUnpacker {
|
||||
self.num_bits as u8
|
||||
}
|
||||
|
||||
/// Calculates a block number for the given `idx`.
|
||||
#[inline]
|
||||
pub fn block_num(&self, idx: u32) -> BlockNumber {
|
||||
// Find the address in bits of the index.
|
||||
let addr_in_bits = (idx * self.num_bits) as usize;
|
||||
|
||||
// Then round down to the nearest byte.
|
||||
let addr_in_bytes = addr_in_bits >> 3;
|
||||
|
||||
// And compute the containing BlockNumber.
|
||||
addr_in_bytes >> (BLOCK_SIZE_MIN_POW + 1)
|
||||
}
|
||||
|
||||
/// Given a block number and dataset length, calculates a data Range for the block.
|
||||
pub fn block(&self, block: BlockNumber, data_len: usize) -> Range<usize> {
|
||||
let block_addr = block << (BLOCK_SIZE_MIN_POW + 1);
|
||||
// We extend the end of the block by a constant factor, so that it overlaps the next
|
||||
// block. That ensures that we never need to read on a block boundary.
|
||||
block_addr..(std::cmp::min(block_addr + BLOCK_SIZE_MIN + 8, data_len))
|
||||
}
|
||||
|
||||
/// Calculates the number of blocks for the given data_len.
|
||||
///
|
||||
/// Usually only called at startup to pre-allocate structures.
|
||||
pub fn block_count(&self, data_len: usize) -> usize {
|
||||
let block_count = data_len / (BLOCK_SIZE_MIN as usize);
|
||||
if data_len % (BLOCK_SIZE_MIN as usize) == 0 {
|
||||
block_count
|
||||
} else {
|
||||
block_count + 1
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a range within the data which covers the given id_range.
|
||||
///
|
||||
/// NOTE: This method is used for batch reads which bypass blocks to avoid dealing with block
|
||||
/// boundaries.
|
||||
#[inline]
|
||||
pub fn block_oblivious_range(&self, id_range: Range<u32>, data_len: usize) -> Range<usize> {
|
||||
let start_in_bits = id_range.start * self.num_bits;
|
||||
let start = (start_in_bits >> 3) as usize;
|
||||
let end_in_bits = id_range.end * self.num_bits;
|
||||
let end = (end_in_bits >> 3) as usize;
|
||||
// TODO: We fetch more than we need and then truncate.
|
||||
start..(std::cmp::min(end + 8, data_len))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
|
||||
let addr_in_bits = idx as usize * self.num_bits;
|
||||
let addr = addr_in_bits >> 3;
|
||||
self.get_from_subset(idx, 0, data)
|
||||
}
|
||||
|
||||
/// Get the value at the given idx, which must exist within the given subset of the data.
|
||||
#[inline]
|
||||
pub fn get_from_subset(&self, idx: u32, data_offset: usize, data: &[u8]) -> u64 {
|
||||
let addr_in_bits = idx * self.num_bits;
|
||||
let addr = (addr_in_bits >> 3) as usize - data_offset;
|
||||
if addr + 8 > data.len() {
|
||||
if self.num_bits == 0 {
|
||||
return 0;
|
||||
}
|
||||
let bit_shift = addr_in_bits & 7;
|
||||
return self.get_slow_path(addr, bit_shift as u32, data);
|
||||
return self.get_slow_path(addr, bit_shift, data);
|
||||
}
|
||||
let bit_shift = addr_in_bits & 7;
|
||||
let bytes: [u8; 8] = (&data[addr..addr + 8]).try_into().unwrap();
|
||||
@@ -113,6 +173,7 @@ impl BitUnpacker {
|
||||
#[inline(never)]
|
||||
fn get_slow_path(&self, addr: usize, bit_shift: u32, data: &[u8]) -> u64 {
|
||||
let mut bytes: [u8; 8] = [0u8; 8];
|
||||
|
||||
let available_bytes = data.len() - addr;
|
||||
// This function is meant to only be called if we did not have 8 bytes to load.
|
||||
debug_assert!(available_bytes < 8);
|
||||
@@ -128,26 +189,25 @@ impl BitUnpacker {
|
||||
// #Panics
|
||||
//
|
||||
// This methods panics if `num_bits` is > 32.
|
||||
fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
|
||||
fn get_batch_u32s(&self, start_idx: u32, data_offset: usize, data: &[u8], output: &mut [u32]) {
|
||||
assert!(
|
||||
self.bit_width() <= 32,
|
||||
"Bitwidth must be <= 32 to use this method."
|
||||
);
|
||||
|
||||
let end_idx: u32 = start_idx + output.len() as u32;
|
||||
let end_idx = start_idx + output.len() as u32;
|
||||
|
||||
// We use `usize` here to avoid overflow issues.
|
||||
let end_bit_read = (end_idx as usize) * self.num_bits;
|
||||
let end_byte_read = end_bit_read.div_ceil(8);
|
||||
let end_bit_read = end_idx * self.num_bits;
|
||||
let end_byte_read = (end_bit_read + 7) / 8;
|
||||
assert!(
|
||||
end_byte_read <= data.len(),
|
||||
end_byte_read as usize <= data_offset + data.len(),
|
||||
"Requested index is out of bounds."
|
||||
);
|
||||
|
||||
// Simple slow implementation of get_batch_u32s, to deal with our ramps.
|
||||
let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
|
||||
for (out, idx) in output.iter_mut().zip(start_idx..) {
|
||||
*out = self.get(idx, data) as u32;
|
||||
*out = self.get_from_subset(idx, data_offset, data) as u32;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -160,24 +220,24 @@ impl BitUnpacker {
|
||||
// We want the start of the fast track to start align with bytes.
|
||||
// A sufficient condition is to start with an idx that is a multiple of 8,
|
||||
// so highway start is the closest multiple of 8 that is >= start_idx.
|
||||
let entrance_ramp_len: u32 = 8 - (start_idx % 8) % 8;
|
||||
let entrance_ramp_len = 8 - (start_idx % 8) % 8;
|
||||
|
||||
let highway_start: u32 = start_idx + entrance_ramp_len;
|
||||
|
||||
if highway_start + (BitPacker1x::BLOCK_LEN as u32) > end_idx {
|
||||
if highway_start + BitPacker1x::BLOCK_LEN as u32 > end_idx {
|
||||
// We don't have enough values to have even a single block of highway.
|
||||
// Let's just supply the values the simple way.
|
||||
get_batch_ramp(start_idx, output);
|
||||
return;
|
||||
}
|
||||
|
||||
let num_blocks: usize = (end_idx - highway_start) as usize / BitPacker1x::BLOCK_LEN;
|
||||
let num_blocks: u32 = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN as u32;
|
||||
|
||||
// Entrance ramp
|
||||
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);
|
||||
|
||||
// Highway
|
||||
let mut offset = (highway_start as usize * self.num_bits) / 8;
|
||||
let mut offset = ((highway_start * self.num_bits) as usize / 8) - data_offset;
|
||||
let mut output_cursor = (highway_start - start_idx) as usize;
|
||||
for _ in 0..num_blocks {
|
||||
offset += BitPacker1x.decompress(
|
||||
@@ -189,7 +249,7 @@ impl BitUnpacker {
|
||||
}
|
||||
|
||||
// Exit ramp
|
||||
let highway_end: u32 = highway_start + (num_blocks * BitPacker1x::BLOCK_LEN) as u32;
|
||||
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN as u32;
|
||||
get_batch_ramp(highway_end, &mut output[output_cursor..]);
|
||||
}
|
||||
|
||||
@@ -199,16 +259,27 @@ impl BitUnpacker {
|
||||
id_range: Range<u32>,
|
||||
data: &[u8],
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
self.get_ids_for_value_range_from_subset(range, id_range, 0, data, positions)
|
||||
}
|
||||
|
||||
pub fn get_ids_for_value_range_from_subset(
|
||||
&self,
|
||||
range: RangeInclusive<u64>,
|
||||
id_range: Range<u32>,
|
||||
data_offset: usize,
|
||||
data: &[u8],
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
if self.bit_width() > 32 {
|
||||
self.get_ids_for_value_range_slow(range, id_range, data, positions)
|
||||
self.get_ids_for_value_range_slow(range, id_range, data_offset, data, positions)
|
||||
} else {
|
||||
if *range.start() > u32::MAX as u64 {
|
||||
positions.clear();
|
||||
return;
|
||||
}
|
||||
let range_u32 = (*range.start() as u32)..=(*range.end()).min(u32::MAX as u64) as u32;
|
||||
self.get_ids_for_value_range_fast(range_u32, id_range, data, positions)
|
||||
self.get_ids_for_value_range_fast(range_u32, id_range, data_offset, data, positions)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,6 +287,7 @@ impl BitUnpacker {
|
||||
&self,
|
||||
range: RangeInclusive<u64>,
|
||||
id_range: Range<u32>,
|
||||
data_offset: usize,
|
||||
data: &[u8],
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
@@ -223,7 +295,7 @@ impl BitUnpacker {
|
||||
for i in id_range {
|
||||
// If we cared we could make this branchless, but the slow implementation should rarely
|
||||
// kick in.
|
||||
let val = self.get(i, data);
|
||||
let val = self.get_from_subset(i, data_offset, data);
|
||||
if range.contains(&val) {
|
||||
positions.push(i);
|
||||
}
|
||||
@@ -234,11 +306,12 @@ impl BitUnpacker {
|
||||
&self,
|
||||
value_range: RangeInclusive<u32>,
|
||||
id_range: Range<u32>,
|
||||
data_offset: usize,
|
||||
data: &[u8],
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
positions.resize(id_range.len(), 0u32);
|
||||
self.get_batch_u32s(id_range.start, data, positions);
|
||||
self.get_batch_u32s(id_range.start, data_offset, data, positions);
|
||||
crate::filter_vec::filter_vec_in_place(value_range, id_range.start, positions)
|
||||
}
|
||||
}
|
||||
@@ -258,7 +331,7 @@ mod test {
|
||||
bitpacker.write(val, num_bits, &mut data).unwrap();
|
||||
}
|
||||
bitpacker.close(&mut data).unwrap();
|
||||
assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8);
|
||||
assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
|
||||
let bitunpacker = BitUnpacker::new(num_bits);
|
||||
(bitunpacker, vals, data)
|
||||
}
|
||||
@@ -304,7 +377,7 @@ mod test {
|
||||
bitpacker.write(val, num_bits, &mut buffer).unwrap();
|
||||
}
|
||||
bitpacker.flush(&mut buffer).unwrap();
|
||||
assert_eq!(buffer.len(), (vals.len() * num_bits as usize + 7) / 8);
|
||||
assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
|
||||
let bitunpacker = BitUnpacker::new(num_bits);
|
||||
let max_val = if num_bits == 64 {
|
||||
u64::MAX
|
||||
@@ -329,14 +402,14 @@ mod test {
|
||||
fn test_get_batch_panics_over_32_bits() {
|
||||
let bitunpacker = BitUnpacker::new(33);
|
||||
let mut output: [u32; 1] = [0u32];
|
||||
bitunpacker.get_batch_u32s(0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
|
||||
bitunpacker.get_batch_u32s(0, 0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_batch_limit() {
|
||||
let bitunpacker = BitUnpacker::new(1);
|
||||
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
|
||||
bitunpacker.get_batch_u32s(8 * 4 - 3, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
|
||||
bitunpacker.get_batch_u32s(8 * 4 - 3, 0, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -345,7 +418,7 @@ mod test {
|
||||
let bitunpacker = BitUnpacker::new(1);
|
||||
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
|
||||
// We are missing exactly one bit.
|
||||
bitunpacker.get_batch_u32s(8 * 4 - 2, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
|
||||
bitunpacker.get_batch_u32s(8 * 4 - 2, 0, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
|
||||
}
|
||||
|
||||
proptest::proptest! {
|
||||
@@ -368,7 +441,7 @@ mod test {
|
||||
for len in [0, 1, 2, 32, 33, 34, 64] {
|
||||
for start_idx in 0u32..32u32 {
|
||||
output.resize(len, 0);
|
||||
bitunpacker.get_batch_u32s(start_idx, &buffer, &mut output);
|
||||
bitunpacker.get_batch_u32s(start_idx, 0, &buffer, &mut output);
|
||||
for (i, output_byte) in output.iter().enumerate() {
|
||||
let expected = (start_idx + i as u32) & mask;
|
||||
assert_eq!(*output_byte, expected);
|
||||
|
||||
@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
|
||||
#[inline]
|
||||
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
|
||||
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
|
||||
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
|
||||
unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
|
||||
}
|
||||
|
||||
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
|
||||
@@ -66,17 +66,19 @@ unsafe fn filter_vec_avx2_aux(
|
||||
]);
|
||||
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
|
||||
for _ in 0..num_words {
|
||||
let word = load_unaligned(input);
|
||||
let word = u32_to_i32_avx2(word);
|
||||
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
|
||||
let added_len = keeper_bitset.count_ones();
|
||||
let filtered_doc_ids = compact(ids, keeper_bitset);
|
||||
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
|
||||
output_tail = output_tail.offset(added_len as isize);
|
||||
ids = op_add(ids, SHIFT);
|
||||
input = input.offset(1);
|
||||
unsafe {
|
||||
let word = load_unaligned(input);
|
||||
let word = u32_to_i32_avx2(word);
|
||||
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
|
||||
let added_len = keeper_bitset.count_ones();
|
||||
let filtered_doc_ids = compact(ids, keeper_bitset);
|
||||
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
|
||||
output_tail = output_tail.offset(added_len as isize);
|
||||
ids = op_add(ids, SHIFT);
|
||||
input = input.offset(1);
|
||||
}
|
||||
}
|
||||
output_tail.offset_from(output) as usize
|
||||
unsafe { output_tail.offset_from(output) as usize }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -92,8 +94,7 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
|
||||
let too_low = op_greater(*range.start(), val);
|
||||
let too_high = op_greater(val, *range.end());
|
||||
let inside = op_or(too_low, too_high);
|
||||
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
|
||||
as u8
|
||||
255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
|
||||
}
|
||||
|
||||
union U8x32 {
|
||||
|
||||
@@ -16,7 +16,7 @@ stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
|
||||
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
|
||||
common = { version= "0.10", path = "../common", package = "tantivy-common" }
|
||||
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
|
||||
serde = "1.0.152"
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
downcast-rs = "2.0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -73,7 +73,7 @@ The crate introduces the following concepts.
|
||||
`Columnar` is an equivalent of a dataframe.
|
||||
It maps `column_key` to `Column`.
|
||||
|
||||
A `Column<T>` asssociates a `RowId` (u32) to any
|
||||
A `Column<T>` associates a `RowId` (u32) to any
|
||||
number of values.
|
||||
|
||||
This is made possible by wrapping a `ColumnIndex` and a `ColumnValue` object.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use binggan::{InputGroup, black_box};
|
||||
use common::*;
|
||||
use tantivy_columnar::Column;
|
||||
use tantivy_columnar::{Column, ValueRange};
|
||||
|
||||
pub mod common;
|
||||
|
||||
@@ -46,16 +46,16 @@ fn bench_group(mut runner: InputGroup<Column>) {
|
||||
runner.register("access_first_vals", |column| {
|
||||
let mut sum = 0;
|
||||
const BLOCK_SIZE: usize = 32;
|
||||
let mut docs = vec![0; BLOCK_SIZE];
|
||||
let mut buffer = vec![None; BLOCK_SIZE];
|
||||
let mut docs = Vec::with_capacity(BLOCK_SIZE);
|
||||
let mut buffer = Vec::with_capacity(BLOCK_SIZE);
|
||||
for i in (0..NUM_DOCS).step_by(BLOCK_SIZE) {
|
||||
// fill docs
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
docs.clear();
|
||||
for idx in 0..BLOCK_SIZE {
|
||||
docs[idx] = idx as u32 + i;
|
||||
docs.push(idx as u32 + i);
|
||||
}
|
||||
|
||||
column.first_vals(&docs, &mut buffer);
|
||||
buffer.clear();
|
||||
column.first_vals_in_value_range(&mut docs, &mut buffer, ValueRange::All);
|
||||
for val in buffer.iter() {
|
||||
let Some(val) = val else { continue };
|
||||
sum += *val;
|
||||
|
||||
@@ -89,13 +89,6 @@ fn main() {
|
||||
black_box(sum);
|
||||
});
|
||||
|
||||
group.register("first_block_fetch", |column| {
|
||||
let mut block: Vec<Option<u64>> = vec![None; 64];
|
||||
let fetch_docids = (0..64).collect::<Vec<_>>();
|
||||
column.first_vals(&fetch_docids, &mut block);
|
||||
black_box(block[0]);
|
||||
});
|
||||
|
||||
group.register("first_block_single_calls", |column| {
|
||||
let mut block: Vec<Option<u64>> = vec![None; 64];
|
||||
let fetch_docids = (0..64).collect::<Vec<_>>();
|
||||
|
||||
@@ -40,7 +40,14 @@ fn main() {
|
||||
let columnar_readers = columnar_readers.iter().collect::<Vec<_>>();
|
||||
let merge_row_order = StackMergeOrder::stack(&columnar_readers[..]);
|
||||
|
||||
merge_columnar(&columnar_readers, &[], merge_row_order.into(), &mut out).unwrap();
|
||||
merge_columnar(
|
||||
&columnar_readers,
|
||||
&[],
|
||||
merge_row_order.into(),
|
||||
&mut out,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
Some(out.len() as u64)
|
||||
},
|
||||
);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod dictionary_encoded;
|
||||
mod serialize;
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::{self, Debug};
|
||||
use std::io::Write;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
@@ -19,6 +20,11 @@ use crate::column_values::monotonic_mapping::StrictlyMonotonicMappingToInternal;
|
||||
use crate::column_values::{ColumnValues, monotonic_map_column};
|
||||
use crate::{Cardinality, DocId, EmptyColumnValues, MonotonicallyMappableToU64, RowId};
|
||||
|
||||
thread_local! {
|
||||
static ROWS: RefCell<Vec<RowId>> = const { RefCell::new(Vec::new()) };
|
||||
static DOCS: RefCell<Vec<DocId>> = const { RefCell::new(Vec::new()) };
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Column<T = u64> {
|
||||
pub index: ColumnIndex,
|
||||
@@ -89,31 +95,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
self.values_for_doc(row_id).next()
|
||||
}
|
||||
|
||||
/// Load the first value for each docid in the provided slice.
|
||||
#[inline]
|
||||
pub fn first_vals(&self, docids: &[DocId], output: &mut [Option<T>]) {
|
||||
match &self.index {
|
||||
ColumnIndex::Empty { .. } => {}
|
||||
ColumnIndex::Full => self.values.get_vals_opt(docids, output),
|
||||
ColumnIndex::Optional(optional_index) => {
|
||||
for (i, docid) in docids.iter().enumerate() {
|
||||
output[i] = optional_index
|
||||
.rank_if_exists(*docid)
|
||||
.map(|rowid| self.values.get_val(rowid));
|
||||
}
|
||||
}
|
||||
ColumnIndex::Multivalued(multivalued_index) => {
|
||||
for (i, docid) in docids.iter().enumerate() {
|
||||
let range = multivalued_index.range(*docid);
|
||||
let is_empty = range.start == range.end;
|
||||
if !is_empty {
|
||||
output[i] = Some(self.values.get_val(range.start));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Translates a block of docids to row_ids.
|
||||
///
|
||||
/// returns the row_ids and the matching docids on the same index
|
||||
@@ -131,6 +112,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
self.index.docids_to_rowids(doc_ids, doc_ids_out, row_ids)
|
||||
}
|
||||
|
||||
/// Get an iterator over the values for the provided docid.
|
||||
#[inline]
|
||||
pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator<Item = T> + '_ {
|
||||
self.index
|
||||
.value_row_ids(doc_id)
|
||||
@@ -141,7 +124,7 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
#[inline]
|
||||
pub fn get_docids_for_value_range(
|
||||
&self,
|
||||
value_range: RangeInclusive<T>,
|
||||
value_range: ValueRange<T>,
|
||||
selected_docid_range: Range<u32>,
|
||||
doc_ids: &mut Vec<u32>,
|
||||
) {
|
||||
@@ -158,15 +141,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
.select_batch_in_place(selected_docid_range.start, doc_ids);
|
||||
}
|
||||
|
||||
/// Fills the output vector with the (possibly multiple values that are associated_with
|
||||
/// `row_id`.
|
||||
///
|
||||
/// This method clears the `output` vector.
|
||||
pub fn fill_vals(&self, row_id: RowId, output: &mut Vec<T>) {
|
||||
output.clear();
|
||||
output.extend(self.values_for_doc(row_id));
|
||||
}
|
||||
|
||||
pub fn first_or_default_col(self, default_value: T) -> Arc<dyn ColumnValues<T>> {
|
||||
Arc::new(FirstValueWithDefault {
|
||||
column: self,
|
||||
@@ -175,6 +149,181 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
|
||||
}
|
||||
}
|
||||
|
||||
// Separate impl block for methods requiring `Default` for `T`.
|
||||
impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
|
||||
/// Load the first value for each docid in the provided slice.
|
||||
///
|
||||
/// The `docids` vector is mutated: documents that do not match the `value_range` are removed.
|
||||
/// The `values` vector is populated with the values of the remaining documents.
|
||||
#[inline]
|
||||
pub fn first_vals_in_value_range(
|
||||
&self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<T>, DocId>>,
|
||||
value_range: ValueRange<T>,
|
||||
) {
|
||||
match (&self.index, value_range) {
|
||||
(ColumnIndex::Empty { .. }, value_range) => {
|
||||
let nulls_match = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(_) => false,
|
||||
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
};
|
||||
if nulls_match {
|
||||
for &doc in input_docs {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
(ColumnIndex::Full, value_range) => {
|
||||
self.values
|
||||
.get_vals_in_value_range(input_docs, input_docs, output, value_range);
|
||||
}
|
||||
(ColumnIndex::Optional(optional_index), value_range) => {
|
||||
let nulls_match = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(_) => false,
|
||||
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
};
|
||||
|
||||
let fallback_needed = ROWS.with(|rows_cell| {
|
||||
DOCS.with(|docs_cell| {
|
||||
let mut rows = rows_cell.borrow_mut();
|
||||
let mut docs = docs_cell.borrow_mut();
|
||||
rows.clear();
|
||||
docs.clear();
|
||||
|
||||
let mut has_nulls = false;
|
||||
|
||||
for &doc_id in input_docs {
|
||||
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
|
||||
rows.push(row_id);
|
||||
docs.push(doc_id);
|
||||
} else {
|
||||
has_nulls = true;
|
||||
if nulls_match {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has_nulls || !nulls_match {
|
||||
self.values.get_vals_in_value_range(
|
||||
&rows,
|
||||
&docs,
|
||||
output,
|
||||
value_range.clone(),
|
||||
);
|
||||
return false;
|
||||
}
|
||||
true
|
||||
})
|
||||
});
|
||||
|
||||
if fallback_needed {
|
||||
for &doc_id in input_docs {
|
||||
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
|
||||
let val = self.values.get_val(row_id);
|
||||
let value_matches = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(r) => r.contains(&val),
|
||||
ValueRange::GreaterThan(t, _) => val > *t,
|
||||
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
|
||||
ValueRange::LessThan(t, _) => val < *t,
|
||||
ValueRange::LessThanOrEqual(t, _) => val <= *t,
|
||||
};
|
||||
|
||||
if value_matches {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc_id,
|
||||
sort_key: Some(val),
|
||||
});
|
||||
}
|
||||
} else if nulls_match {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc_id,
|
||||
sort_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
(ColumnIndex::Multivalued(multivalued_index), value_range) => {
|
||||
let nulls_match = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(_) => false,
|
||||
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThan(_, nulls_match) => *nulls_match,
|
||||
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
|
||||
};
|
||||
for i in 0..input_docs.len() {
|
||||
let docid = input_docs[i];
|
||||
let row_range = multivalued_index.range(docid);
|
||||
let is_empty = row_range.start == row_range.end;
|
||||
if !is_empty {
|
||||
let val = self.values.get_val(row_range.start);
|
||||
let matches = match &value_range {
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(r) => r.contains(&val),
|
||||
ValueRange::GreaterThan(t, _) => val > *t,
|
||||
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
|
||||
ValueRange::LessThan(t, _) => val < *t,
|
||||
ValueRange::LessThanOrEqual(t, _) => val <= *t,
|
||||
};
|
||||
if matches {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: docid,
|
||||
sort_key: Some(val),
|
||||
});
|
||||
}
|
||||
} else if nulls_match {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: docid,
|
||||
sort_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A range of values.
|
||||
///
|
||||
/// This type is intended to be used in batch APIs, where the cost of unpacking the enum
|
||||
/// is outweighed by the time spent processing a batch.
|
||||
///
|
||||
/// Implementers should pattern match on the variants to use optimized loops for each case.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ValueRange<T> {
|
||||
/// A range that includes both start and end.
|
||||
Inclusive(RangeInclusive<T>),
|
||||
/// A range that matches all values.
|
||||
All,
|
||||
/// A range that matches all values greater than the threshold.
|
||||
/// The boolean flag indicates if null values should be included.
|
||||
GreaterThan(T, bool),
|
||||
/// A range that matches all values greater than or equal to the threshold.
|
||||
/// The boolean flag indicates if null values should be included.
|
||||
GreaterThanOrEqual(T, bool),
|
||||
/// A range that matches all values less than the threshold.
|
||||
/// The boolean flag indicates if null values should be included.
|
||||
LessThan(T, bool),
|
||||
/// A range that matches all values less than or equal to the threshold.
|
||||
/// The boolean flag indicates if null values should be included.
|
||||
LessThanOrEqual(T, bool),
|
||||
}
|
||||
|
||||
impl BinarySerializable for Cardinality {
|
||||
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
self.to_code().serialize(writer)
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::io;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::OwnedBytes;
|
||||
use common::file_slice::FileSlice;
|
||||
use sstable::Dictionary;
|
||||
|
||||
use crate::column::{BytesColumn, Column};
|
||||
@@ -41,12 +41,13 @@ pub fn serialize_column_mappable_to_u64<T: MonotonicallyMappableToU64>(
|
||||
}
|
||||
|
||||
pub fn open_column_u64<T: MonotonicallyMappableToU64>(
|
||||
bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
format_version: Version,
|
||||
) -> io::Result<Column<T>> {
|
||||
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
|
||||
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
|
||||
let column_index_num_bytes = u32::from_le_bytes(
|
||||
column_index_num_bytes_payload
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
@@ -61,12 +62,13 @@ pub fn open_column_u64<T: MonotonicallyMappableToU64>(
|
||||
}
|
||||
|
||||
pub fn open_column_u128<T: MonotonicallyMappableToU128>(
|
||||
bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
format_version: Version,
|
||||
) -> io::Result<Column<T>> {
|
||||
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
|
||||
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
|
||||
let column_index_num_bytes = u32::from_le_bytes(
|
||||
column_index_num_bytes_payload
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
@@ -84,12 +86,13 @@ pub fn open_column_u128<T: MonotonicallyMappableToU128>(
|
||||
///
|
||||
/// See [`open_u128_as_compact_u64`] for more details.
|
||||
pub fn open_column_u128_as_compact_u64(
|
||||
bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
format_version: Version,
|
||||
) -> io::Result<Column<u64>> {
|
||||
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
|
||||
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
|
||||
let column_index_num_bytes = u32::from_le_bytes(
|
||||
column_index_num_bytes_payload
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
@@ -103,11 +106,21 @@ pub fn open_column_u128_as_compact_u64(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_column_bytes(data: OwnedBytes, format_version: Version) -> io::Result<BytesColumn> {
|
||||
let (body, dictionary_len_bytes) = data.rsplit(4);
|
||||
let dictionary_len = u32::from_le_bytes(dictionary_len_bytes.as_slice().try_into().unwrap());
|
||||
pub fn open_column_bytes(
|
||||
file_slice: FileSlice,
|
||||
format_version: Version,
|
||||
) -> io::Result<BytesColumn> {
|
||||
let (body, dictionary_len_bytes) = file_slice.split_from_end(4);
|
||||
let dictionary_len = u32::from_le_bytes(
|
||||
dictionary_len_bytes
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
let (dictionary_bytes, column_bytes) = body.split(dictionary_len as usize);
|
||||
let dictionary = Arc::new(Dictionary::from_bytes(dictionary_bytes)?);
|
||||
|
||||
let dictionary = Arc::new(Dictionary::open(dictionary_bytes)?);
|
||||
let term_ord_column = crate::column::open_column_u64::<u64>(column_bytes, format_version)?;
|
||||
Ok(BytesColumn {
|
||||
dictionary,
|
||||
@@ -115,7 +128,7 @@ pub fn open_column_bytes(data: OwnedBytes, format_version: Version) -> io::Resul
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_column_str(data: OwnedBytes, format_version: Version) -> io::Result<StrColumn> {
|
||||
let bytes_column = open_column_bytes(data, format_version)?;
|
||||
pub fn open_column_str(file_slice: FileSlice, format_version: Version) -> io::Result<StrColumn> {
|
||||
let bytes_column = open_column_bytes(file_slice, format_version)?;
|
||||
Ok(StrColumn::wrap(bytes_column))
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ pub fn merge_column_index<'a>(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use common::OwnedBytes;
|
||||
use common::file_slice::FileSlice;
|
||||
|
||||
use crate::column_index::merge::detect_cardinality;
|
||||
use crate::column_index::multivalued_index::{
|
||||
@@ -178,7 +178,7 @@ mod tests {
|
||||
let mut output = Vec::new();
|
||||
serialize_multivalued_index(&start_index_iterable, &mut output).unwrap();
|
||||
let multivalue =
|
||||
open_multivalued_index(OwnedBytes::new(output), crate::Version::V2).unwrap();
|
||||
open_multivalued_index(FileSlice::from(output), crate::Version::V2).unwrap();
|
||||
let start_indexes: Vec<RowId> = multivalue.get_start_index_column().iter().collect();
|
||||
assert_eq!(&start_indexes, &[0, 3, 5]);
|
||||
}
|
||||
@@ -216,7 +216,7 @@ mod tests {
|
||||
let mut output = Vec::new();
|
||||
serialize_multivalued_index(&start_index_iterable, &mut output).unwrap();
|
||||
let multivalue =
|
||||
open_multivalued_index(OwnedBytes::new(output), crate::Version::V2).unwrap();
|
||||
open_multivalued_index(FileSlice::from(output), crate::Version::V2).unwrap();
|
||||
let start_indexes: Vec<RowId> = multivalue.get_start_index_column().iter().collect();
|
||||
assert_eq!(&start_indexes, &[0, 3, 5, 6]);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ use std::io::Write;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{CountingWriter, OwnedBytes};
|
||||
use common::CountingWriter;
|
||||
use common::file_slice::FileSlice;
|
||||
|
||||
use super::optional_index::{open_optional_index, serialize_optional_index};
|
||||
use super::{OptionalIndex, SerializableOptionalIndex, Set};
|
||||
@@ -44,21 +45,26 @@ pub fn serialize_multivalued_index(
|
||||
}
|
||||
|
||||
pub fn open_multivalued_index(
|
||||
bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
format_version: Version,
|
||||
) -> io::Result<MultiValueIndex> {
|
||||
match format_version {
|
||||
Version::V1 => {
|
||||
let start_index_column: Arc<dyn ColumnValues<RowId>> =
|
||||
load_u64_based_column_values(bytes)?;
|
||||
load_u64_based_column_values(file_slice)?;
|
||||
Ok(MultiValueIndex::MultiValueIndexV1(MultiValueIndexV1 {
|
||||
start_index_column,
|
||||
}))
|
||||
}
|
||||
Version::V2 => {
|
||||
let (body_bytes, optional_index_len) = bytes.rsplit(4);
|
||||
let optional_index_len =
|
||||
u32::from_le_bytes(optional_index_len.as_slice().try_into().unwrap());
|
||||
let (body_bytes, optional_index_len) = file_slice.split_from_end(4);
|
||||
let optional_index_len = u32::from_le_bytes(
|
||||
optional_index_len
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
let (optional_index_bytes, start_index_bytes) =
|
||||
body_bytes.split(optional_index_len as usize);
|
||||
let optional_index = open_optional_index(optional_index_bytes)?;
|
||||
@@ -185,8 +191,8 @@ impl MultiValueIndex {
|
||||
};
|
||||
let mut buffer = Vec::new();
|
||||
serialize_multivalued_index(&serializable_multivalued_index, &mut buffer).unwrap();
|
||||
let bytes = OwnedBytes::new(buffer);
|
||||
open_multivalued_index(bytes, Version::V2).unwrap()
|
||||
let file_slice = FileSlice::from(buffer);
|
||||
open_multivalued_index(file_slice, Version::V2).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_start_index_column(&self) -> &Arc<dyn crate::ColumnValues<RowId>> {
|
||||
@@ -333,7 +339,7 @@ mod tests {
|
||||
use std::ops::Range;
|
||||
|
||||
use super::MultiValueIndex;
|
||||
use crate::{ColumnarReader, DynamicColumn};
|
||||
use crate::{ColumnarReader, DynamicColumn, ValueRange};
|
||||
|
||||
fn index_to_pos_helper(
|
||||
index: &MultiValueIndex,
|
||||
@@ -413,7 +419,7 @@ mod tests {
|
||||
assert_eq!(row_id_range, 0..4);
|
||||
|
||||
let check = |range, expected| {
|
||||
let full_range = 0..=u64::MAX;
|
||||
let full_range = ValueRange::All;
|
||||
let mut docids = Vec::new();
|
||||
column.get_docids_for_value_range(full_range, range, &mut docids);
|
||||
assert_eq!(docids, expected);
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||
mod set;
|
||||
mod set_block;
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, OwnedBytes, VInt};
|
||||
pub use set::{SelectCursor, Set, SetCodec};
|
||||
use set_block::{
|
||||
@@ -268,8 +269,8 @@ impl OptionalIndex {
|
||||
);
|
||||
let mut buffer = Vec::new();
|
||||
serialize_optional_index(&row_ids, num_rows, &mut buffer).unwrap();
|
||||
let bytes = OwnedBytes::new(buffer);
|
||||
open_optional_index(bytes).unwrap()
|
||||
let file_slice = FileSlice::from(buffer);
|
||||
open_optional_index(file_slice).unwrap()
|
||||
}
|
||||
|
||||
pub fn num_docs(&self) -> RowId {
|
||||
@@ -486,10 +487,17 @@ fn deserialize_optional_index_block_metadatas(
|
||||
(block_metas.into_boxed_slice(), non_null_rows_before_block)
|
||||
}
|
||||
|
||||
pub fn open_optional_index(bytes: OwnedBytes) -> io::Result<OptionalIndex> {
|
||||
let (mut bytes, num_non_empty_blocks_bytes) = bytes.rsplit(2);
|
||||
let num_non_empty_block_bytes =
|
||||
u16::from_le_bytes(num_non_empty_blocks_bytes.as_slice().try_into().unwrap());
|
||||
pub fn open_optional_index(file_slice: FileSlice) -> io::Result<OptionalIndex> {
|
||||
let (bytes, num_non_empty_blocks_bytes) = file_slice.split_from_end(2);
|
||||
let num_non_empty_block_bytes = u16::from_le_bytes(
|
||||
num_non_empty_blocks_bytes
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let mut bytes = bytes.read_bytes()?;
|
||||
let num_docs = VInt::deserialize_u64(&mut bytes)? as u32;
|
||||
let block_metas_num_bytes =
|
||||
num_non_empty_block_bytes as usize * SERIALIZED_BLOCK_META_NUM_BYTES;
|
||||
|
||||
@@ -59,7 +59,7 @@ fn test_with_random_sets_simple() {
|
||||
let vals = 10..ELEMENTS_PER_BLOCK * 2;
|
||||
let mut out: Vec<u8> = Vec::new();
|
||||
serialize_optional_index(&vals, 100, &mut out).unwrap();
|
||||
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
|
||||
let null_index = open_optional_index(FileSlice::from(out)).unwrap();
|
||||
let ranks: Vec<u32> = (65_472u32..65_473u32).collect();
|
||||
let els: Vec<u32> = ranks.iter().copied().map(|rank| rank + 10).collect();
|
||||
let mut select_cursor = null_index.select_cursor();
|
||||
@@ -102,7 +102,7 @@ impl<'a> Iterable<RowId> for &'a [bool] {
|
||||
fn test_null_index(data: &[bool]) {
|
||||
let mut out: Vec<u8> = Vec::new();
|
||||
serialize_optional_index(&data, data.len() as RowId, &mut out).unwrap();
|
||||
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
|
||||
let null_index = open_optional_index(FileSlice::from(out)).unwrap();
|
||||
let orig_idx_with_value: Vec<u32> = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -223,3 +223,170 @@ fn test_optional_index_for_tests() {
|
||||
assert!(!optional_index.contains(3));
|
||||
assert_eq!(optional_index.num_docs(), 4);
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test::Bencher;
|
||||
|
||||
use super::*;
|
||||
|
||||
const TOTAL_NUM_VALUES: u32 = 1_000_000;
|
||||
fn gen_bools(fill_ratio: f64) -> OptionalIndex {
|
||||
let mut out = Vec::new();
|
||||
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
|
||||
let vals: Vec<RowId> = (0..TOTAL_NUM_VALUES)
|
||||
.map(|_| rng.gen_bool(fill_ratio))
|
||||
.enumerate()
|
||||
.filter(|(_pos, val)| *val)
|
||||
.map(|(pos, _)| pos as RowId)
|
||||
.collect();
|
||||
serialize_optional_index(&&vals[..], TOTAL_NUM_VALUES, &mut out).unwrap();
|
||||
|
||||
open_optional_index(FileSlice::from(out)).unwrap()
|
||||
}
|
||||
|
||||
fn random_range_iterator(
|
||||
start: u32,
|
||||
end: u32,
|
||||
avg_step_size: u32,
|
||||
avg_deviation: u32,
|
||||
) -> impl Iterator<Item = u32> {
|
||||
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
|
||||
let mut current = start;
|
||||
std::iter::from_fn(move || {
|
||||
current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
|
||||
if current >= end { None } else { Some(current) }
|
||||
})
|
||||
}
|
||||
|
||||
fn n_percent_step_iterator(percent: f32, num_values: u32) -> impl Iterator<Item = u32> {
|
||||
let ratio = percent / 100.0;
|
||||
let step_size = (1f32 / ratio) as u32;
|
||||
let deviation = step_size - 1;
|
||||
random_range_iterator(0, num_values, step_size, deviation)
|
||||
}
|
||||
|
||||
fn walk_over_data(codec: &OptionalIndex, avg_step_size: u32) -> Option<u32> {
|
||||
walk_over_data_from_positions(
|
||||
codec,
|
||||
random_range_iterator(0, TOTAL_NUM_VALUES, avg_step_size, 0),
|
||||
)
|
||||
}
|
||||
|
||||
fn walk_over_data_from_positions(
|
||||
codec: &OptionalIndex,
|
||||
positions: impl Iterator<Item = u32>,
|
||||
) -> Option<u32> {
|
||||
let mut dense_idx: Option<u32> = None;
|
||||
for idx in positions {
|
||||
dense_idx = dense_idx.or(codec.rank_if_exists(idx));
|
||||
}
|
||||
dense_idx
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_1percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.01f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_5percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.05f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_5percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.05f64);
|
||||
bench.iter(|| walk_over_data(&codec, 1000));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_1percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.01f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_10percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.1f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_full_scan_90percent_filled(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.9f64);
|
||||
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_10percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.1f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_50percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.5f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_orig_to_codec_90percent_filled_1percent_hit(bench: &mut Bencher) {
|
||||
let codec = gen_bools(0.9f64);
|
||||
bench.iter(|| walk_over_data(&codec, 100));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 0.005f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_10percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.1f64, 0.005f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_10percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 10f32, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_1percent_filled_full_scan(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.01f64, 100f32, bench);
|
||||
}
|
||||
|
||||
fn bench_translate_codec_to_orig_util(
|
||||
percent_filled: f64,
|
||||
percent_hit: f32,
|
||||
bench: &mut Bencher,
|
||||
) {
|
||||
let codec = gen_bools(percent_filled);
|
||||
let num_non_nulls = codec.num_non_nulls();
|
||||
let idxs: Vec<u32> = if percent_hit == 100.0f32 {
|
||||
(0..num_non_nulls).collect()
|
||||
} else {
|
||||
n_percent_step_iterator(percent_hit, num_non_nulls).collect()
|
||||
};
|
||||
let mut output = vec![0u32; idxs.len()];
|
||||
bench.iter(|| {
|
||||
output.copy_from_slice(&idxs[..]);
|
||||
codec.select_batch(&mut output);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_90percent_filled_0comma005percent_hit(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.9f64, 0.005, bench);
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_translate_codec_to_orig_90percent_filled_full_scan(bench: &mut Bencher) {
|
||||
bench_translate_codec_to_orig_util(0.9f64, 100.0f32, bench);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
|
||||
use common::{CountingWriter, OwnedBytes};
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{CountingWriter, HasLen};
|
||||
|
||||
use super::OptionalIndex;
|
||||
use super::multivalued_index::SerializableMultivalueIndex;
|
||||
@@ -65,27 +66,28 @@ pub fn serialize_column_index(
|
||||
|
||||
/// Open a serialized column index.
|
||||
pub fn open_column_index(
|
||||
mut bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
format_version: Version,
|
||||
) -> io::Result<ColumnIndex> {
|
||||
if bytes.is_empty() {
|
||||
if file_slice.len() == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"Failed to deserialize column index. Empty buffer.",
|
||||
));
|
||||
}
|
||||
let cardinality_code = bytes[0];
|
||||
let (header, body) = file_slice.split(1);
|
||||
let cardinality_code = header.read_bytes()?.as_slice()[0];
|
||||
let cardinality = Cardinality::try_from_code(cardinality_code)?;
|
||||
bytes.advance(1);
|
||||
|
||||
match cardinality {
|
||||
Cardinality::Full => Ok(ColumnIndex::Full),
|
||||
Cardinality::Optional => {
|
||||
let optional_index = super::optional_index::open_optional_index(bytes)?;
|
||||
let optional_index = super::optional_index::open_optional_index(body)?;
|
||||
Ok(ColumnIndex::Optional(optional_index))
|
||||
}
|
||||
Cardinality::Multivalued => {
|
||||
let multivalue_index =
|
||||
super::multivalued_index::open_multivalued_index(bytes, format_version)?;
|
||||
super::multivalued_index::open_multivalued_index(body, format_version)?;
|
||||
Ok(ColumnIndex::Multivalued(multivalue_index))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,13 +7,15 @@
|
||||
//! - Monotonically map values to u64/u128
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use downcast_rs::DowncastSync;
|
||||
pub use monotonic_mapping::{MonotonicallyMappableToU64, StrictlyMonotonicFn};
|
||||
pub use monotonic_mapping_u128::MonotonicallyMappableToU128;
|
||||
|
||||
use crate::column::ValueRange;
|
||||
|
||||
mod merge;
|
||||
pub(crate) mod monotonic_mapping;
|
||||
pub(crate) mod monotonic_mapping_u128;
|
||||
@@ -27,8 +29,7 @@ mod monotonic_column;
|
||||
pub(crate) use merge::MergedColumnValues;
|
||||
pub use stats::ColumnStats;
|
||||
pub use u64_based::{
|
||||
ALL_U64_CODEC_TYPES, CodecType, load_u64_based_column_values,
|
||||
serialize_and_load_u64_based_column_values, serialize_u64_based_column_values,
|
||||
ALL_U64_CODEC_TYPES, CodecType, load_u64_based_column_values, serialize_u64_based_column_values,
|
||||
};
|
||||
pub use u128_based::{
|
||||
CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
|
||||
@@ -109,6 +110,307 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the values for the provided docids.
|
||||
///
|
||||
/// The values are filtered by the provided value range.
|
||||
fn get_vals_in_value_range(
|
||||
&self,
|
||||
input_indexes: &[u32],
|
||||
input_doc_ids: &[u32],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
|
||||
value_range: ValueRange<T>,
|
||||
) {
|
||||
let len = input_indexes.len();
|
||||
let mut read_head = 0;
|
||||
|
||||
match value_range {
|
||||
ValueRange::All => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::Inclusive(ref range) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if range.contains(&val0) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if range.contains(&val1) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if range.contains(&val2) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if range.contains(&val3) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThan(ref threshold, _) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if val0 > *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if val1 > *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if val2 > *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if val3 > *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(ref threshold, _) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if val0 >= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if val1 >= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if val2 >= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if val3 >= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(ref threshold, _) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if val0 < *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if val1 < *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if val2 < *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if val3 < *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
ValueRange::LessThanOrEqual(ref threshold, _) => {
|
||||
while read_head + 3 < len {
|
||||
let idx0 = input_indexes[read_head];
|
||||
let idx1 = input_indexes[read_head + 1];
|
||||
let idx2 = input_indexes[read_head + 2];
|
||||
let idx3 = input_indexes[read_head + 3];
|
||||
|
||||
let doc0 = input_doc_ids[read_head];
|
||||
let doc1 = input_doc_ids[read_head + 1];
|
||||
let doc2 = input_doc_ids[read_head + 2];
|
||||
let doc3 = input_doc_ids[read_head + 3];
|
||||
|
||||
let val0 = self.get_val(idx0);
|
||||
let val1 = self.get_val(idx1);
|
||||
let val2 = self.get_val(idx2);
|
||||
let val3 = self.get_val(idx3);
|
||||
|
||||
if val0 <= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc0,
|
||||
sort_key: Some(val0),
|
||||
});
|
||||
}
|
||||
if val1 <= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc1,
|
||||
sort_key: Some(val1),
|
||||
});
|
||||
}
|
||||
if val2 <= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc2,
|
||||
sort_key: Some(val2),
|
||||
});
|
||||
}
|
||||
if val3 <= *threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc: doc3,
|
||||
sort_key: Some(val3),
|
||||
});
|
||||
}
|
||||
|
||||
read_head += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Process remaining elements (0 to 3)
|
||||
while read_head < len {
|
||||
let idx = input_indexes[read_head];
|
||||
let doc = input_doc_ids[read_head];
|
||||
let val = self.get_val(idx);
|
||||
let matches = match value_range {
|
||||
// 'value_range' is still moved here. This is the outer `value_range`
|
||||
ValueRange::All => true,
|
||||
ValueRange::Inclusive(ref r) => r.contains(&val),
|
||||
ValueRange::GreaterThan(ref t, _) => val > *t,
|
||||
ValueRange::GreaterThanOrEqual(ref t, _) => val >= *t,
|
||||
ValueRange::LessThan(ref t, _) => val < *t,
|
||||
ValueRange::LessThanOrEqual(ref t, _) => val <= *t,
|
||||
};
|
||||
if matches {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(val),
|
||||
});
|
||||
}
|
||||
read_head += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Fills an output buffer with the fast field values
|
||||
/// associated with the `DocId` going from
|
||||
/// `start` to `start + output.len()`.
|
||||
@@ -129,15 +431,54 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
|
||||
/// Note that position == docid for single value fast fields
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
value_range: RangeInclusive<T>,
|
||||
value_range: ValueRange<T>,
|
||||
row_id_range: Range<RowId>,
|
||||
row_id_hits: &mut Vec<RowId>,
|
||||
) {
|
||||
let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals());
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if value_range.contains(&val) {
|
||||
row_id_hits.push(idx);
|
||||
match value_range {
|
||||
ValueRange::Inclusive(range) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if range.contains(&val) {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if val > threshold {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if val >= threshold {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if val < threshold {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
for idx in row_id_range {
|
||||
let val = self.get_val(idx);
|
||||
if val <= threshold {
|
||||
row_id_hits.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::All => {
|
||||
row_id_hits.extend(row_id_range);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -193,6 +534,17 @@ impl<T: PartialOrd + Default> ColumnValues<T> for EmptyColumnValues {
|
||||
fn num_vals(&self) -> u32 {
|
||||
0
|
||||
}
|
||||
|
||||
fn get_vals_in_value_range(
|
||||
&self,
|
||||
input_indexes: &[u32],
|
||||
input_doc_ids: &[u32],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
|
||||
value_range: ValueRange<T>,
|
||||
) {
|
||||
let _ = (input_indexes, input_doc_ids, output, value_range);
|
||||
panic!("Internal Error: Called get_vals_in_value_range of empty column.")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnValues<T>> {
|
||||
@@ -206,6 +558,18 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
|
||||
self.as_ref().get_vals_opt(indexes, output)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_vals_in_value_range(
|
||||
&self,
|
||||
input_indexes: &[u32],
|
||||
input_doc_ids: &[u32],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
|
||||
value_range: ValueRange<T>,
|
||||
) {
|
||||
self.as_ref()
|
||||
.get_vals_in_value_range(input_indexes, input_doc_ids, output, value_range)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn min_value(&self) -> T {
|
||||
self.as_ref().min_value()
|
||||
@@ -234,7 +598,7 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
|
||||
#[inline(always)]
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
range: RangeInclusive<T>,
|
||||
range: ValueRange<T>,
|
||||
doc_id_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::ColumnValues;
|
||||
use crate::column::ValueRange;
|
||||
use crate::column_values::monotonic_mapping::StrictlyMonotonicFn;
|
||||
|
||||
struct MonotonicMappingColumn<C, T, Input> {
|
||||
@@ -80,16 +81,52 @@ where
|
||||
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
range: RangeInclusive<Output>,
|
||||
range: ValueRange<Output>,
|
||||
doc_id_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
self.from_column.get_row_ids_for_value_range(
|
||||
self.monotonic_mapping.inverse(range.start().clone())
|
||||
..=self.monotonic_mapping.inverse(range.end().clone()),
|
||||
doc_id_range,
|
||||
positions,
|
||||
)
|
||||
match range {
|
||||
ValueRange::Inclusive(range) => self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::Inclusive(
|
||||
self.monotonic_mapping.inverse(range.start().clone())
|
||||
..=self.monotonic_mapping.inverse(range.end().clone()),
|
||||
),
|
||||
doc_id_range,
|
||||
positions,
|
||||
),
|
||||
ValueRange::All => self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::All,
|
||||
doc_id_range,
|
||||
positions,
|
||||
),
|
||||
ValueRange::GreaterThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::GreaterThan(self.monotonic_mapping.inverse(threshold), false),
|
||||
doc_id_range,
|
||||
positions,
|
||||
),
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::GreaterThanOrEqual(
|
||||
self.monotonic_mapping.inverse(threshold),
|
||||
false,
|
||||
),
|
||||
doc_id_range,
|
||||
positions,
|
||||
)
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::LessThan(self.monotonic_mapping.inverse(threshold), false),
|
||||
doc_id_range,
|
||||
positions,
|
||||
),
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
self.from_column.get_row_ids_for_value_range(
|
||||
ValueRange::LessThanOrEqual(self.monotonic_mapping.inverse(threshold), false),
|
||||
doc_id_range,
|
||||
positions,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We voluntarily do not implement get_range as it yields a regression,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::fmt::Debug;
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
/// Montonic maps a value to u128 value space
|
||||
/// Monotonic maps a value to u128 value space
|
||||
/// Monotonic mapping enables `PartialOrd` on u128 space without conversion to original space.
|
||||
pub trait MonotonicallyMappableToU128: 'static + PartialOrd + Copy + Debug + Send + Sync {
|
||||
/// Converts a value to u128.
|
||||
|
||||
@@ -2,7 +2,8 @@ use std::io;
|
||||
use std::io::Write;
|
||||
use std::num::NonZeroU64;
|
||||
|
||||
use common::{BinarySerializable, VInt};
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, HasLen, VInt};
|
||||
|
||||
use crate::RowId;
|
||||
|
||||
@@ -27,6 +28,55 @@ impl ColumnStats {
|
||||
}
|
||||
}
|
||||
|
||||
impl ColumnStats {
|
||||
/// Deserialize from the tail of the given FileSlice, and return the stats and remaining prefix
|
||||
/// FileSlice.
|
||||
pub fn deserialize_from_tail(file_slice: FileSlice) -> io::Result<(Self, FileSlice)> {
|
||||
// [`deserialize_with_size`] deserializes 4 variable-width encoded u64s, which
|
||||
// could end up being, in the worst case, 9 bytes each. this is where the 36 comes from
|
||||
let (stats, _) = file_slice.clone().split(36.min(file_slice.len())); // hope that's enough bytes
|
||||
let mut stats = stats.read_bytes()?;
|
||||
let (stats, stats_nbytes) = ColumnStats::deserialize_with_size(&mut stats)?;
|
||||
let (_, remainder) = file_slice.split(stats_nbytes);
|
||||
Ok((stats, remainder))
|
||||
}
|
||||
|
||||
/// Same as [`BinarySeerializable::deserialize`] but also returns the number of bytes
|
||||
/// consumed from the reader `R`
|
||||
fn deserialize_with_size<R: io::Read>(reader: &mut R) -> io::Result<(Self, usize)> {
|
||||
let mut nbytes = 0;
|
||||
|
||||
let (min_value, len) = VInt::deserialize_with_size(reader)?;
|
||||
let min_value = min_value.0;
|
||||
nbytes += len;
|
||||
|
||||
let (gcd, len) = VInt::deserialize_with_size(reader)?;
|
||||
let gcd = gcd.0;
|
||||
let gcd = NonZeroU64::new(gcd)
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "GCD of 0 is forbidden"))?;
|
||||
nbytes += len;
|
||||
|
||||
let (amplitude, len) = VInt::deserialize_with_size(reader)?;
|
||||
let amplitude = amplitude.0 * gcd.get();
|
||||
let max_value = min_value + amplitude;
|
||||
nbytes += len;
|
||||
|
||||
let (num_rows, len) = VInt::deserialize_with_size(reader)?;
|
||||
let num_rows = num_rows.0 as RowId;
|
||||
nbytes += len;
|
||||
|
||||
Ok((
|
||||
ColumnStats {
|
||||
min_value,
|
||||
max_value,
|
||||
num_rows,
|
||||
gcd,
|
||||
},
|
||||
nbytes,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl BinarySerializable for ColumnStats {
|
||||
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
|
||||
VInt(self.min_value).serialize(writer)?;
|
||||
|
||||
@@ -25,6 +25,7 @@ use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128};
|
||||
use tantivy_bitpacker::{BitPacker, BitUnpacker};
|
||||
|
||||
use crate::RowId;
|
||||
use crate::column::ValueRange;
|
||||
use crate::column_values::ColumnValues;
|
||||
|
||||
/// The cost per blank is quite hard actually, since blanks are delta encoded, the actual cost of
|
||||
@@ -338,14 +339,48 @@ impl ColumnValues<u64> for CompactSpaceU64Accessor {
|
||||
#[inline]
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
value_range: RangeInclusive<u64>,
|
||||
value_range: ValueRange<u64>,
|
||||
position_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
let value_range = self.0.compact_to_u128(*value_range.start() as u32)
|
||||
..=self.0.compact_to_u128(*value_range.end() as u32);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
match value_range {
|
||||
ValueRange::Inclusive(value_range) => {
|
||||
let value_range = ValueRange::Inclusive(
|
||||
self.0.compact_to_u128(*value_range.start() as u32)
|
||||
..=self.0.compact_to_u128(*value_range.end() as u32),
|
||||
);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
ValueRange::All => {
|
||||
let position_range = position_range.start..position_range.end.min(self.num_vals());
|
||||
positions.extend(position_range);
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
let value_range =
|
||||
ValueRange::GreaterThan(self.0.compact_to_u128(threshold as u32), false);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
let value_range =
|
||||
ValueRange::GreaterThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
let value_range =
|
||||
ValueRange::LessThan(self.0.compact_to_u128(threshold as u32), false);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
let value_range =
|
||||
ValueRange::LessThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
|
||||
self.0
|
||||
.get_row_ids_for_value_range(value_range, position_range, positions)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,10 +410,47 @@ impl ColumnValues<u128> for CompactSpaceDecompressor {
|
||||
#[inline]
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
value_range: RangeInclusive<u128>,
|
||||
value_range: ValueRange<u128>,
|
||||
position_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
let value_range = match value_range {
|
||||
ValueRange::Inclusive(value_range) => value_range,
|
||||
ValueRange::All => {
|
||||
let position_range = position_range.start..position_range.end.min(self.num_vals());
|
||||
positions.extend(position_range);
|
||||
return;
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
let max = self.max_value();
|
||||
if threshold >= max {
|
||||
return;
|
||||
}
|
||||
(threshold + 1)..=max
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
let max = self.max_value();
|
||||
if threshold > max {
|
||||
return;
|
||||
}
|
||||
threshold..=max
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
let min = self.min_value();
|
||||
if threshold <= min {
|
||||
return;
|
||||
}
|
||||
min..=(threshold - 1)
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
let min = self.min_value();
|
||||
if threshold < min {
|
||||
return;
|
||||
}
|
||||
min..=threshold
|
||||
}
|
||||
};
|
||||
|
||||
if value_range.start() > value_range.end() {
|
||||
return;
|
||||
}
|
||||
@@ -560,7 +632,7 @@ mod tests {
|
||||
.collect::<Vec<_>>();
|
||||
let mut positions = Vec::new();
|
||||
decompressor.get_row_ids_for_value_range(
|
||||
range,
|
||||
ValueRange::Inclusive(range),
|
||||
0..decompressor.num_vals(),
|
||||
&mut positions,
|
||||
);
|
||||
@@ -604,7 +676,11 @@ mod tests {
|
||||
let val = *val;
|
||||
let pos = pos as u32;
|
||||
let mut positions = Vec::new();
|
||||
decomp.get_row_ids_for_value_range(val..=val, pos..pos + 1, &mut positions);
|
||||
decomp.get_row_ids_for_value_range(
|
||||
ValueRange::Inclusive(val..=val),
|
||||
pos..pos + 1,
|
||||
&mut positions,
|
||||
);
|
||||
assert_eq!(positions, vec![pos]);
|
||||
}
|
||||
|
||||
@@ -746,7 +822,11 @@ mod tests {
|
||||
doc_id_range: Range<u32>,
|
||||
) -> Vec<u32> {
|
||||
let mut positions = Vec::new();
|
||||
column.get_row_ids_for_value_range(value_range, doc_id_range, &mut positions);
|
||||
column.get_row_ids_for_value_range(
|
||||
ValueRange::Inclusive(value_range),
|
||||
doc_id_range,
|
||||
&mut positions,
|
||||
);
|
||||
positions
|
||||
}
|
||||
|
||||
@@ -769,7 +849,7 @@ mod tests {
|
||||
];
|
||||
let mut out = Vec::new();
|
||||
serialize_column_values_u128(&&vals[..], &mut out).unwrap();
|
||||
let decomp = open_u128_mapped(OwnedBytes::new(out)).unwrap();
|
||||
let decomp = open_u128_mapped(FileSlice::from(out)).unwrap();
|
||||
let complete_range = 0..vals.len() as u32;
|
||||
|
||||
assert_eq!(
|
||||
@@ -823,6 +903,7 @@ mod tests {
|
||||
let _data = test_aux_vals(vals);
|
||||
}
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use proptest::prelude::*;
|
||||
|
||||
fn num_strategy() -> impl Strategy<Value = u128> {
|
||||
|
||||
@@ -5,7 +5,8 @@ use std::sync::Arc;
|
||||
|
||||
mod compact_space;
|
||||
|
||||
use common::{BinarySerializable, OwnedBytes, VInt};
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, VInt};
|
||||
pub use compact_space::{
|
||||
CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
|
||||
};
|
||||
@@ -101,8 +102,9 @@ impl U128FastFieldCodecType {
|
||||
|
||||
/// Returns the correct codec reader wrapped in the `Arc` for the data.
|
||||
pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
|
||||
mut bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
) -> io::Result<Arc<dyn ColumnValues<T>>> {
|
||||
let mut bytes = file_slice.read_bytes()?;
|
||||
let header = U128Header::deserialize(&mut bytes)?;
|
||||
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
|
||||
let reader = CompactSpaceDecompressor::open(bytes)?;
|
||||
@@ -120,7 +122,8 @@ pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
|
||||
/// # Notice
|
||||
/// In case there are new codecs added, check for usages of `CompactSpaceDecompressorU64` and
|
||||
/// also handle the new codecs.
|
||||
pub fn open_u128_as_compact_u64(mut bytes: OwnedBytes) -> io::Result<Arc<dyn ColumnValues<u64>>> {
|
||||
pub fn open_u128_as_compact_u64(file_slice: FileSlice) -> io::Result<Arc<dyn ColumnValues<u64>>> {
|
||||
let mut bytes = file_slice.read_bytes()?;
|
||||
let header = U128Header::deserialize(&mut bytes)?;
|
||||
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
|
||||
let reader = CompactSpaceU64Accessor::open(bytes)?;
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
use std::io::{self, Write};
|
||||
use std::num::NonZeroU64;
|
||||
use std::ops::{Range, RangeInclusive};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use common::{BinarySerializable, OwnedBytes};
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, HasLen, OwnedBytes};
|
||||
use fastdivide::DividerU64;
|
||||
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
|
||||
|
||||
use crate::column::ValueRange;
|
||||
use crate::column_values::u64_based::{ColumnCodec, ColumnCodecEstimator, ColumnStats};
|
||||
use crate::{ColumnValues, RowId};
|
||||
|
||||
@@ -13,9 +16,40 @@ use crate::{ColumnValues, RowId};
|
||||
/// fast field is required.
|
||||
#[derive(Clone)]
|
||||
pub struct BitpackedReader {
|
||||
data: OwnedBytes,
|
||||
data: FileSlice,
|
||||
bit_unpacker: BitUnpacker,
|
||||
stats: ColumnStats,
|
||||
blocks: Arc<[OnceLock<Block>]>,
|
||||
}
|
||||
|
||||
impl BitpackedReader {
|
||||
#[inline(always)]
|
||||
fn unpack_val(&self, doc: u32) -> u64 {
|
||||
let block_num = self.bit_unpacker.block_num(doc);
|
||||
|
||||
if block_num == 0 && self.blocks.len() == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let block = self.blocks[block_num].get_or_init(|| {
|
||||
let block_range = self.bit_unpacker.block(block_num, self.data.len());
|
||||
let offset = block_range.start;
|
||||
let data = self
|
||||
.data
|
||||
.slice(block_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
Block { offset, data }
|
||||
});
|
||||
|
||||
self.bit_unpacker
|
||||
.get_from_subset(doc, block.offset, &block.data)
|
||||
}
|
||||
}
|
||||
|
||||
struct Block {
|
||||
offset: usize,
|
||||
data: OwnedBytes,
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@@ -57,8 +91,9 @@ fn transform_range_before_linear_transformation(
|
||||
impl ColumnValues for BitpackedReader {
|
||||
#[inline(always)]
|
||||
fn get_val(&self, doc: u32) -> u64 {
|
||||
self.stats.min_value + self.stats.gcd.get() * self.bit_unpacker.get(doc, &self.data)
|
||||
self.stats.min_value + self.stats.gcd.get() * self.unpack_val(doc)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn min_value(&self) -> u64 {
|
||||
self.stats.min_value
|
||||
@@ -72,24 +107,329 @@ impl ColumnValues for BitpackedReader {
|
||||
self.stats.num_rows
|
||||
}
|
||||
|
||||
fn get_vals_in_value_range(
|
||||
&self,
|
||||
input_indexes: &[u32],
|
||||
input_doc_ids: &[u32],
|
||||
output: &mut Vec<crate::ComparableDoc<Option<u64>, crate::DocId>>,
|
||||
value_range: ValueRange<u64>,
|
||||
) {
|
||||
match value_range {
|
||||
ValueRange::All => {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
}
|
||||
ValueRange::Inclusive(range) => {
|
||||
if let Some(transformed_range) =
|
||||
transform_range_before_linear_transformation(&self.stats, range)
|
||||
{
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if transformed_range.contains(&raw_val) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
if threshold < self.stats.min_value {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
} else if threshold >= self.stats.max_value {
|
||||
// All filtered out
|
||||
} else {
|
||||
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if raw_val > raw_threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
if threshold <= self.stats.min_value {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
} else if threshold > self.stats.max_value {
|
||||
// All filtered out
|
||||
} else {
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
let raw_threshold = (diff + gcd - 1) / gcd;
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if raw_val >= raw_threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
if threshold > self.stats.max_value {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
} else if threshold <= self.stats.min_value {
|
||||
// All filtered out
|
||||
} else {
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
let raw_threshold = if diff % gcd == 0 {
|
||||
diff / gcd
|
||||
} else {
|
||||
diff / gcd + 1
|
||||
};
|
||||
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if raw_val < raw_threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
if threshold >= self.stats.max_value {
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(self.get_val(idx)),
|
||||
});
|
||||
}
|
||||
} else if threshold < self.stats.min_value {
|
||||
// All filtered out
|
||||
} else {
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
let raw_threshold = diff / gcd;
|
||||
|
||||
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
|
||||
let raw_val = self.unpack_val(idx);
|
||||
if raw_val <= raw_threshold {
|
||||
output.push(crate::ComparableDoc {
|
||||
doc,
|
||||
sort_key: Some(
|
||||
self.stats.min_value + self.stats.gcd.get() * raw_val,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fn get_row_ids_for_value_range(
|
||||
&self,
|
||||
range: RangeInclusive<u64>,
|
||||
range: ValueRange<u64>,
|
||||
doc_id_range: Range<u32>,
|
||||
positions: &mut Vec<u32>,
|
||||
) {
|
||||
let Some(transformed_range) =
|
||||
transform_range_before_linear_transformation(&self.stats, range)
|
||||
else {
|
||||
positions.clear();
|
||||
return;
|
||||
};
|
||||
self.bit_unpacker.get_ids_for_value_range(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
&self.data,
|
||||
positions,
|
||||
);
|
||||
match range {
|
||||
ValueRange::All => {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
ValueRange::Inclusive(range) => {
|
||||
let Some(transformed_range) =
|
||||
transform_range_before_linear_transformation(&self.stats, range)
|
||||
else {
|
||||
positions.clear();
|
||||
return;
|
||||
};
|
||||
// TODO: This does not use the `self.blocks` cache, because callers are usually
|
||||
// already doing sequential, and fairly dense reads. Fix it to
|
||||
// iterate over blocks if that assumption turns out to be incorrect!
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
ValueRange::GreaterThan(threshold, _) => {
|
||||
if threshold < self.stats.min_value {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
if threshold >= self.stats.max_value {
|
||||
return;
|
||||
}
|
||||
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
|
||||
// We want raw > raw_threshold.
|
||||
// bit_unpacker.get_ids_for_value_range_from_subset takes a RangeInclusive.
|
||||
// We can construct a RangeInclusive: (raw_threshold + 1) ..= u64::MAX
|
||||
// But max raw value is known? (max_value - min_value) / gcd.
|
||||
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
|
||||
let transformed_range = (raw_threshold + 1)..=max_raw;
|
||||
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
ValueRange::GreaterThanOrEqual(threshold, _) => {
|
||||
if threshold <= self.stats.min_value {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
if threshold > self.stats.max_value {
|
||||
return;
|
||||
}
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
let raw_threshold = (diff + gcd - 1) / gcd;
|
||||
// We want raw >= raw_threshold.
|
||||
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
|
||||
let transformed_range = raw_threshold..=max_raw;
|
||||
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
ValueRange::LessThan(threshold, _) => {
|
||||
if threshold > self.stats.max_value {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
if threshold <= self.stats.min_value {
|
||||
return;
|
||||
}
|
||||
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
// We want raw < raw_threshold_limit
|
||||
// raw <= raw_threshold_limit - 1
|
||||
let raw_threshold_limit = if diff % gcd == 0 {
|
||||
diff / gcd
|
||||
} else {
|
||||
diff / gcd + 1
|
||||
};
|
||||
|
||||
if raw_threshold_limit == 0 {
|
||||
return;
|
||||
}
|
||||
let transformed_range = 0..=(raw_threshold_limit - 1);
|
||||
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
ValueRange::LessThanOrEqual(threshold, _) => {
|
||||
if threshold >= self.stats.max_value {
|
||||
positions.extend(doc_id_range);
|
||||
return;
|
||||
}
|
||||
if threshold < self.stats.min_value {
|
||||
return;
|
||||
}
|
||||
let diff = threshold - self.stats.min_value;
|
||||
let gcd = self.stats.gcd.get();
|
||||
// We want raw <= raw_threshold.
|
||||
let raw_threshold = diff / gcd;
|
||||
let transformed_range = 0..=raw_threshold;
|
||||
|
||||
let data_range = self
|
||||
.bit_unpacker
|
||||
.block_oblivious_range(doc_id_range.clone(), self.data.len());
|
||||
let data_offset = data_range.start;
|
||||
let data_subset = self
|
||||
.data
|
||||
.slice(data_range)
|
||||
.read_bytes()
|
||||
.expect("Failed to read column values.");
|
||||
self.bit_unpacker.get_ids_for_value_range_from_subset(
|
||||
transformed_range,
|
||||
doc_id_range,
|
||||
data_offset,
|
||||
&data_subset,
|
||||
positions,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,14 +473,20 @@ impl ColumnCodec for BitpackedCodec {
|
||||
type Estimator = BitpackedCodecEstimator;
|
||||
|
||||
/// Opens a fast field given a file.
|
||||
fn load(mut data: OwnedBytes) -> io::Result<Self::ColumnValues> {
|
||||
let stats = ColumnStats::deserialize(&mut data)?;
|
||||
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
|
||||
let (stats, data) = ColumnStats::deserialize_from_tail(file_slice)?;
|
||||
|
||||
let num_bits = num_bits(&stats);
|
||||
let bit_unpacker = BitUnpacker::new(num_bits);
|
||||
let block_count = bit_unpacker.block_count(data.len());
|
||||
Ok(BitpackedReader {
|
||||
data,
|
||||
bit_unpacker,
|
||||
stats,
|
||||
blocks: (0..block_count)
|
||||
.into_iter()
|
||||
.map(|_| OnceLock::new())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
use std::{io, iter};
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use common::{BinarySerializable, CountingWriter, DeserializeFrom, OwnedBytes};
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, CountingWriter, DeserializeFrom, HasLen, OwnedBytes};
|
||||
use fastdivide::DividerU64;
|
||||
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
|
||||
|
||||
@@ -172,32 +174,63 @@ impl ColumnCodec<u64> for BlockwiseLinearCodec {
|
||||
|
||||
type Estimator = BlockwiseLinearEstimator;
|
||||
|
||||
fn load(mut bytes: OwnedBytes) -> io::Result<Self::ColumnValues> {
|
||||
let stats = ColumnStats::deserialize(&mut bytes)?;
|
||||
let footer_len: u32 = (&bytes[bytes.len() - 4..]).deserialize()?;
|
||||
let footer_offset = bytes.len() - 4 - footer_len as usize;
|
||||
let (data, mut footer) = bytes.split(footer_offset);
|
||||
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
|
||||
let (stats, body) = ColumnStats::deserialize_from_tail(file_slice)?;
|
||||
|
||||
let (_, footer) = body.clone().split_from_end(4);
|
||||
|
||||
let footer_len: u32 = footer.read_bytes()?.as_slice().deserialize()?;
|
||||
let (data, footer) = body.split_from_end(footer_len as usize + 4);
|
||||
|
||||
let mut footer = footer.read_bytes()?;
|
||||
let num_blocks = compute_num_blocks(stats.num_rows);
|
||||
let mut blocks: Vec<Block> = iter::repeat_with(|| Block::deserialize(&mut footer))
|
||||
.take(num_blocks as usize)
|
||||
.collect::<io::Result<_>>()?;
|
||||
|
||||
let mut start_offset = 0;
|
||||
for block in &mut blocks {
|
||||
let mut blocks = Vec::with_capacity(num_blocks as usize);
|
||||
|
||||
for _ in 0..num_blocks {
|
||||
let mut block = Block::deserialize(&mut footer)?;
|
||||
let len = (block.bit_unpacker.bit_width() as usize) * BLOCK_SIZE as usize / 8;
|
||||
|
||||
block.data_start_offset = start_offset;
|
||||
start_offset += (block.bit_unpacker.bit_width() as usize) * BLOCK_SIZE as usize / 8;
|
||||
blocks.push(BlockWithData {
|
||||
block,
|
||||
file_slice: data.slice(start_offset..(start_offset + len).min(data.len())),
|
||||
data: Default::default(),
|
||||
});
|
||||
|
||||
start_offset += len;
|
||||
}
|
||||
Ok(BlockwiseLinearReader {
|
||||
blocks: blocks.into_boxed_slice().into(),
|
||||
data,
|
||||
stats,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct BlockWithData {
|
||||
block: Block,
|
||||
file_slice: FileSlice,
|
||||
data: OnceLock<OwnedBytes>,
|
||||
}
|
||||
|
||||
impl Deref for BlockWithData {
|
||||
type Target = Block;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.block
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for BlockWithData {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.block
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct BlockwiseLinearReader {
|
||||
blocks: Arc<[Block]>,
|
||||
data: OwnedBytes,
|
||||
blocks: Arc<[BlockWithData]>,
|
||||
stats: ColumnStats,
|
||||
}
|
||||
|
||||
@@ -208,7 +241,9 @@ impl ColumnValues for BlockwiseLinearReader {
|
||||
let idx_within_block = idx % BLOCK_SIZE;
|
||||
let block = &self.blocks[block_id];
|
||||
let interpoled_val: u64 = block.line.eval(idx_within_block);
|
||||
let block_bytes = &self.data[block.data_start_offset..];
|
||||
let block_bytes = block
|
||||
.data
|
||||
.get_or_init(|| block.file_slice.read_bytes().unwrap());
|
||||
let bitpacked_diff = block.bit_unpacker.get(idx_within_block, block_bytes);
|
||||
// TODO optimize me! the line parameters could be tweaked to include the multiplication and
|
||||
// remove the dependency.
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::column_values::ColumnValues;
|
||||
const MID_POINT: u64 = (1u64 << 32) - 1u64;
|
||||
|
||||
/// `Line` describes a line function `y: ax + b` using integer
|
||||
/// arithmetics.
|
||||
/// arithmetic.
|
||||
///
|
||||
/// The slope is in fact a decimal split into a 32 bit integer value,
|
||||
/// and a 32-bit decimal value.
|
||||
@@ -94,7 +94,7 @@ impl Line {
|
||||
// `(i, ys[])`.
|
||||
//
|
||||
// The best intercept therefore has the form
|
||||
// `y[i] - line.eval(i)` (using wrapping arithmetics).
|
||||
// `y[i] - line.eval(i)` (using wrapping arithmetic).
|
||||
// In other words, the best intercept is one of the `y - Line::eval(ys[i])`
|
||||
// and our task is just to pick the one that minimizes our error.
|
||||
//
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::io;
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{BinarySerializable, OwnedBytes};
|
||||
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
|
||||
|
||||
@@ -190,7 +191,8 @@ impl ColumnCodec for LinearCodec {
|
||||
|
||||
type Estimator = LinearCodecEstimator;
|
||||
|
||||
fn load(mut data: OwnedBytes) -> io::Result<Self::ColumnValues> {
|
||||
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
|
||||
let mut data = file_slice.read_bytes()?;
|
||||
let stats = ColumnStats::deserialize(&mut data)?;
|
||||
let linear_params = LinearParams::deserialize(&mut data)?;
|
||||
Ok(LinearReader {
|
||||
|
||||
@@ -8,7 +8,8 @@ use std::io;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{BinarySerializable, OwnedBytes};
|
||||
use common::BinarySerializable;
|
||||
use common::file_slice::FileSlice;
|
||||
|
||||
use crate::column_values::monotonic_mapping::{
|
||||
StrictlyMonotonicMappingInverter, StrictlyMonotonicMappingToInternal,
|
||||
@@ -52,7 +53,7 @@ pub trait ColumnCodecEstimator<T = u64>: 'static {
|
||||
) -> io::Result<()>;
|
||||
}
|
||||
|
||||
/// A column codec describes a colunm serialization format.
|
||||
/// A column codec describes a column serialization format.
|
||||
pub trait ColumnCodec<T: PartialOrd = u64> {
|
||||
/// Specialized `ColumnValues` type.
|
||||
type ColumnValues: ColumnValues<T> + 'static;
|
||||
@@ -60,7 +61,7 @@ pub trait ColumnCodec<T: PartialOrd = u64> {
|
||||
type Estimator: ColumnCodecEstimator + Default;
|
||||
|
||||
/// Loads a column that has been serialized using this codec.
|
||||
fn load(bytes: OwnedBytes) -> io::Result<Self::ColumnValues>;
|
||||
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues>;
|
||||
|
||||
/// Returns an estimator.
|
||||
fn estimator() -> Self::Estimator {
|
||||
@@ -111,20 +112,22 @@ impl CodecType {
|
||||
|
||||
fn load<T: MonotonicallyMappableToU64>(
|
||||
&self,
|
||||
bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
) -> io::Result<Arc<dyn ColumnValues<T>>> {
|
||||
match self {
|
||||
CodecType::Bitpacked => load_specific_codec::<BitpackedCodec, T>(bytes),
|
||||
CodecType::Linear => load_specific_codec::<LinearCodec, T>(bytes),
|
||||
CodecType::BlockwiseLinear => load_specific_codec::<BlockwiseLinearCodec, T>(bytes),
|
||||
CodecType::Bitpacked => load_specific_codec::<BitpackedCodec, T>(file_slice),
|
||||
CodecType::Linear => load_specific_codec::<LinearCodec, T>(file_slice),
|
||||
CodecType::BlockwiseLinear => {
|
||||
load_specific_codec::<BlockwiseLinearCodec, T>(file_slice)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_specific_codec<C: ColumnCodec, T: MonotonicallyMappableToU64>(
|
||||
bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
) -> io::Result<Arc<dyn ColumnValues<T>>> {
|
||||
let reader = C::load(bytes)?;
|
||||
let reader = C::load(file_slice)?;
|
||||
let reader_typed = monotonic_map_column(
|
||||
reader,
|
||||
StrictlyMonotonicMappingInverter::from(StrictlyMonotonicMappingToInternal::<T>::new()),
|
||||
@@ -189,25 +192,28 @@ pub fn serialize_u64_based_column_values<T: MonotonicallyMappableToU64>(
|
||||
///
|
||||
/// This method first identifies the codec off the first byte.
|
||||
pub fn load_u64_based_column_values<T: MonotonicallyMappableToU64>(
|
||||
mut bytes: OwnedBytes,
|
||||
file_slice: FileSlice,
|
||||
) -> io::Result<Arc<dyn ColumnValues<T>>> {
|
||||
let codec_type: CodecType = bytes
|
||||
.first()
|
||||
.copied()
|
||||
let (header, body) = file_slice.split(1);
|
||||
let codec_type: CodecType = header
|
||||
.read_bytes()?
|
||||
.as_slice()
|
||||
.get(0)
|
||||
.cloned()
|
||||
.and_then(CodecType::try_from_code)
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to read codec type"))?;
|
||||
bytes.advance(1);
|
||||
codec_type.load(bytes)
|
||||
codec_type.load(body)
|
||||
}
|
||||
|
||||
/// Helper function to serialize a column (autodetect from all codecs) and then open it
|
||||
#[cfg(test)]
|
||||
pub fn serialize_and_load_u64_based_column_values<T: MonotonicallyMappableToU64>(
|
||||
vals: &dyn Iterable,
|
||||
codec_types: &[CodecType],
|
||||
) -> Arc<dyn ColumnValues<T>> {
|
||||
let mut buffer = Vec::new();
|
||||
serialize_u64_based_column_values(vals, codec_types, &mut buffer).unwrap();
|
||||
load_u64_based_column_values::<T>(OwnedBytes::new(buffer)).unwrap()
|
||||
load_u64_based_column_values::<T>(FileSlice::from(buffer)).unwrap()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use common::HasLen;
|
||||
use proptest::prelude::*;
|
||||
use proptest::{prop_oneof, proptest};
|
||||
use rand::Rng;
|
||||
@@ -13,7 +14,7 @@ fn test_serialize_and_load_simple() {
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(buffer.len(), 7);
|
||||
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
|
||||
let col = load_u64_based_column_values::<u64>(FileSlice::from(buffer)).unwrap();
|
||||
assert_eq!(col.num_vals(), 3);
|
||||
assert_eq!(col.get_val(0), 1);
|
||||
assert_eq!(col.get_val(1), 2);
|
||||
@@ -30,7 +31,7 @@ fn test_empty_column_i64() {
|
||||
continue;
|
||||
}
|
||||
num_acceptable_codecs += 1;
|
||||
let col = load_u64_based_column_values::<i64>(OwnedBytes::new(buffer)).unwrap();
|
||||
let col = load_u64_based_column_values::<i64>(FileSlice::from(buffer)).unwrap();
|
||||
assert_eq!(col.num_vals(), 0);
|
||||
assert_eq!(col.min_value(), i64::MIN);
|
||||
assert_eq!(col.max_value(), i64::MIN);
|
||||
@@ -48,7 +49,7 @@ fn test_empty_column_u64() {
|
||||
continue;
|
||||
}
|
||||
num_acceptable_codecs += 1;
|
||||
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
|
||||
let col = load_u64_based_column_values::<u64>(FileSlice::from(buffer)).unwrap();
|
||||
assert_eq!(col.num_vals(), 0);
|
||||
assert_eq!(col.min_value(), u64::MIN);
|
||||
assert_eq!(col.max_value(), u64::MIN);
|
||||
@@ -66,7 +67,7 @@ fn test_empty_column_f64() {
|
||||
continue;
|
||||
}
|
||||
num_acceptable_codecs += 1;
|
||||
let col = load_u64_based_column_values::<f64>(OwnedBytes::new(buffer)).unwrap();
|
||||
let col = load_u64_based_column_values::<f64>(FileSlice::from(buffer)).unwrap();
|
||||
assert_eq!(col.num_vals(), 0);
|
||||
// FIXME. f64::MIN would be better!
|
||||
assert!(col.min_value().is_nan());
|
||||
@@ -97,7 +98,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
|
||||
|
||||
let actual_compression = buffer.len() as u64;
|
||||
|
||||
let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
|
||||
let reader = TColumnCodec::load(FileSlice::from(buffer)).unwrap();
|
||||
assert_eq!(reader.num_vals(), vals.len() as u32);
|
||||
let mut buffer = Vec::new();
|
||||
for (doc, orig_val) in vals.iter().copied().enumerate() {
|
||||
@@ -131,7 +132,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
|
||||
.collect();
|
||||
let mut positions = Vec::new();
|
||||
reader.get_row_ids_for_value_range(
|
||||
vals[test_rand_idx]..=vals[test_rand_idx],
|
||||
crate::column::ValueRange::Inclusive(vals[test_rand_idx]..=vals[test_rand_idx]),
|
||||
0..vals.len() as u32,
|
||||
&mut positions,
|
||||
);
|
||||
@@ -326,7 +327,7 @@ fn test_fastfield_gcd_i64_with_codec(codec_type: CodecType, num_vals: usize) ->
|
||||
&[codec_type],
|
||||
&mut buffer,
|
||||
)?;
|
||||
let buffer = OwnedBytes::new(buffer);
|
||||
let buffer = FileSlice::from(buffer);
|
||||
let column = crate::column_values::load_u64_based_column_values::<i64>(buffer.clone())?;
|
||||
assert_eq!(column.get_val(0), -4000i64);
|
||||
assert_eq!(column.get_val(1), -3000i64);
|
||||
@@ -343,7 +344,7 @@ fn test_fastfield_gcd_i64_with_codec(codec_type: CodecType, num_vals: usize) ->
|
||||
&[codec_type],
|
||||
&mut buffer_without_gcd,
|
||||
)?;
|
||||
let buffer_without_gcd = OwnedBytes::new(buffer_without_gcd);
|
||||
let buffer_without_gcd = FileSlice::from(buffer_without_gcd);
|
||||
assert!(buffer_without_gcd.len() > buffer.len());
|
||||
|
||||
Ok(())
|
||||
@@ -369,7 +370,7 @@ fn test_fastfield_gcd_u64_with_codec(codec_type: CodecType, num_vals: usize) ->
|
||||
&[codec_type],
|
||||
&mut buffer,
|
||||
)?;
|
||||
let buffer = OwnedBytes::new(buffer);
|
||||
let buffer = FileSlice::from(buffer);
|
||||
let column = crate::column_values::load_u64_based_column_values::<u64>(buffer.clone())?;
|
||||
assert_eq!(column.get_val(0), 1000u64);
|
||||
assert_eq!(column.get_val(1), 2000u64);
|
||||
@@ -386,7 +387,7 @@ fn test_fastfield_gcd_u64_with_codec(codec_type: CodecType, num_vals: usize) ->
|
||||
&[codec_type],
|
||||
&mut buffer_without_gcd,
|
||||
)?;
|
||||
let buffer_without_gcd = OwnedBytes::new(buffer_without_gcd);
|
||||
let buffer_without_gcd = FileSlice::from(buffer_without_gcd);
|
||||
assert!(buffer_without_gcd.len() > buffer.len());
|
||||
Ok(())
|
||||
}
|
||||
@@ -405,7 +406,7 @@ fn test_fastfield_gcd_u64() -> io::Result<()> {
|
||||
|
||||
#[test]
|
||||
pub fn test_fastfield2() {
|
||||
let test_fastfield = crate::column_values::serialize_and_load_u64_based_column_values::<u64>(
|
||||
let test_fastfield = serialize_and_load_u64_based_column_values::<u64>(
|
||||
&&[100u64, 200u64, 300u64][..],
|
||||
&ALL_U64_CODEC_TYPES,
|
||||
);
|
||||
|
||||
@@ -4,6 +4,7 @@ mod term_merger;
|
||||
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::net::Ipv6Addr;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -78,6 +79,7 @@ pub fn merge_columnar(
|
||||
required_columns: &[(String, ColumnType)],
|
||||
merge_row_order: MergeRowOrder,
|
||||
output: &mut impl io::Write,
|
||||
cancel: impl Fn() -> bool,
|
||||
) -> io::Result<()> {
|
||||
let mut serializer = ColumnarSerializer::new(output);
|
||||
let num_docs_per_columnar = columnar_readers
|
||||
@@ -87,6 +89,9 @@ pub fn merge_columnar(
|
||||
|
||||
let columns_to_merge = group_columns_for_merge(columnar_readers, required_columns)?;
|
||||
for res in columns_to_merge {
|
||||
if cancel() {
|
||||
return Err(io::Error::new(ErrorKind::Interrupted, "Merge cancelled"));
|
||||
}
|
||||
let ((column_name, _column_type_category), grouped_columns) = res;
|
||||
let grouped_columns = grouped_columns.open(&merge_row_order)?;
|
||||
if grouped_columns.is_empty() {
|
||||
|
||||
@@ -205,6 +205,7 @@ fn test_merge_columnar_numbers() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -233,6 +234,7 @@ fn test_merge_columnar_texts() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -282,6 +284,7 @@ fn test_merge_columnar_byte() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -338,6 +341,7 @@ fn test_merge_columnar_byte_with_missing() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -390,6 +394,7 @@ fn test_merge_columnar_different_types() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -455,6 +460,7 @@ fn test_merge_columnar_different_empty_cardinality() {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut buffer,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let columnar_reader = ColumnarReader::open(buffer).unwrap();
|
||||
@@ -565,6 +571,7 @@ proptest! {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut out,
|
||||
|| false,
|
||||
).unwrap();
|
||||
|
||||
let merged_reader = ColumnarReader::open(out).unwrap();
|
||||
@@ -582,6 +589,7 @@ proptest! {
|
||||
&[],
|
||||
MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut out,
|
||||
|| false,
|
||||
).unwrap();
|
||||
|
||||
}
|
||||
|
||||
22
columnar/src/comparable_doc.rs
Normal file
22
columnar/src/comparable_doc.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Contains a feature (field, score, etc.) of a document along with the document address.
|
||||
///
|
||||
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
|
||||
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ComparableDoc<T, D> {
|
||||
/// The feature of the document. In practice, this is
|
||||
/// is a type which can be compared with a `Comparator<T>`.
|
||||
pub sort_key: T,
|
||||
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
|
||||
pub doc: D,
|
||||
}
|
||||
|
||||
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct("ComparableDoc")
|
||||
.field("feature", &self.sort_key)
|
||||
.field("doc", &self.doc)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -71,7 +71,14 @@ fn test_format(path: &str) {
|
||||
let columnar_readers = vec![&reader, &reader2];
|
||||
let merge_row_order = StackMergeOrder::stack(&columnar_readers[..]);
|
||||
let mut out = Vec::new();
|
||||
merge_columnar(&columnar_readers, &[], merge_row_order.into(), &mut out).unwrap();
|
||||
merge_columnar(
|
||||
&columnar_readers,
|
||||
&[],
|
||||
merge_row_order.into(),
|
||||
&mut out,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let reader = ColumnarReader::open(out).unwrap();
|
||||
check_columns(&reader);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ use std::sync::Arc;
|
||||
use std::{fmt, io};
|
||||
|
||||
use common::file_slice::FileSlice;
|
||||
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
|
||||
use common::{ByteCount, DateTime};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::column::{BytesColumn, Column, StrColumn};
|
||||
use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
|
||||
@@ -238,8 +239,7 @@ pub struct DynamicColumnHandle {
|
||||
impl DynamicColumnHandle {
|
||||
// TODO rename load
|
||||
pub fn open(&self) -> io::Result<DynamicColumn> {
|
||||
let column_bytes: OwnedBytes = self.file_slice.read_bytes()?;
|
||||
self.open_internal(column_bytes)
|
||||
self.open_internal(self.file_slice.clone())
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
@@ -258,16 +258,15 @@ impl DynamicColumnHandle {
|
||||
/// If not, the fastfield reader will returns the u64-value associated with the original
|
||||
/// FastValue.
|
||||
pub fn open_u64_lenient(&self) -> io::Result<Option<Column<u64>>> {
|
||||
let column_bytes = self.file_slice.read_bytes()?;
|
||||
match self.column_type {
|
||||
ColumnType::Str | ColumnType::Bytes => {
|
||||
let column: BytesColumn =
|
||||
crate::column::open_column_bytes(column_bytes, self.format_version)?;
|
||||
crate::column::open_column_bytes(self.file_slice.clone(), self.format_version)?;
|
||||
Ok(Some(column.term_ord_column))
|
||||
}
|
||||
ColumnType::IpAddr => {
|
||||
let column = crate::column::open_column_u128_as_compact_u64(
|
||||
column_bytes,
|
||||
self.file_slice.clone(),
|
||||
self.format_version,
|
||||
)?;
|
||||
Ok(Some(column))
|
||||
@@ -277,50 +276,129 @@ impl DynamicColumnHandle {
|
||||
| ColumnType::U64
|
||||
| ColumnType::F64
|
||||
| ColumnType::DateTime => {
|
||||
let column =
|
||||
crate::column::open_column_u64::<u64>(column_bytes, self.format_version)?;
|
||||
let column = crate::column::open_column_u64::<u64>(
|
||||
self.file_slice.clone(),
|
||||
self.format_version,
|
||||
)?;
|
||||
Ok(Some(column))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn open_internal(&self, column_bytes: OwnedBytes) -> io::Result<DynamicColumn> {
|
||||
fn open_internal(&self, file_slice: FileSlice) -> io::Result<DynamicColumn> {
|
||||
let dynamic_column: DynamicColumn = match self.column_type {
|
||||
ColumnType::Bytes => {
|
||||
crate::column::open_column_bytes(column_bytes, self.format_version)?.into()
|
||||
crate::column::open_column_bytes(file_slice, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::Str => {
|
||||
crate::column::open_column_str(column_bytes, self.format_version)?.into()
|
||||
crate::column::open_column_str(file_slice, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::I64 => {
|
||||
crate::column::open_column_u64::<i64>(column_bytes, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<i64>(file_slice, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::U64 => {
|
||||
crate::column::open_column_u64::<u64>(column_bytes, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<u64>(file_slice, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::F64 => {
|
||||
crate::column::open_column_u64::<f64>(column_bytes, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<f64>(file_slice, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::Bool => {
|
||||
crate::column::open_column_u64::<bool>(column_bytes, self.format_version)?.into()
|
||||
crate::column::open_column_u64::<bool>(file_slice, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::IpAddr => {
|
||||
crate::column::open_column_u128::<Ipv6Addr>(column_bytes, self.format_version)?
|
||||
.into()
|
||||
crate::column::open_column_u128::<Ipv6Addr>(file_slice, self.format_version)?.into()
|
||||
}
|
||||
ColumnType::DateTime => {
|
||||
crate::column::open_column_u64::<DateTime>(column_bytes, self.format_version)?
|
||||
.into()
|
||||
crate::column::open_column_u64::<DateTime>(file_slice, self.format_version)?.into()
|
||||
}
|
||||
};
|
||||
Ok(dynamic_column)
|
||||
}
|
||||
|
||||
pub fn num_bytes(&self) -> ByteCount {
|
||||
self.file_slice.len().into()
|
||||
self.file_slice.num_bytes()
|
||||
}
|
||||
|
||||
/// Legacy helper returning the column space usage.
|
||||
pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
|
||||
self.space_usage()
|
||||
}
|
||||
|
||||
/// Return the space usage of the column, optionally broken down by dictionary and column
|
||||
/// values.
|
||||
///
|
||||
/// For dictionary encoded columns (strings and bytes), this splits the total footprint into
|
||||
/// the dictionary and the remaining column data (including index and values).
|
||||
/// For all other column types, the dictionary size is `None` and the column size
|
||||
/// equals the total bytes.
|
||||
pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
|
||||
let total_num_bytes = self.num_bytes();
|
||||
let dynamic_column = self.open()?;
|
||||
let dictionary_num_bytes = match &dynamic_column {
|
||||
DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
|
||||
DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
|
||||
_ => {
|
||||
return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
|
||||
}
|
||||
};
|
||||
assert!(dictionary_num_bytes <= total_num_bytes);
|
||||
let column_num_bytes =
|
||||
ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
|
||||
Ok(ColumnSpaceUsage::new(
|
||||
column_num_bytes,
|
||||
Some(dictionary_num_bytes),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn column_type(&self) -> ColumnType {
|
||||
self.column_type
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents space usage of a column.
|
||||
///
|
||||
/// `column_num_bytes` tracks the column payload (index, values and footer).
|
||||
/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
|
||||
/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ColumnSpaceUsage {
|
||||
column_num_bytes: ByteCount,
|
||||
dictionary_num_bytes: Option<ByteCount>,
|
||||
}
|
||||
|
||||
impl ColumnSpaceUsage {
|
||||
pub(crate) fn new(
|
||||
column_num_bytes: ByteCount,
|
||||
dictionary_num_bytes: Option<ByteCount>,
|
||||
) -> Self {
|
||||
ColumnSpaceUsage {
|
||||
column_num_bytes,
|
||||
dictionary_num_bytes,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn column_num_bytes(&self) -> ByteCount {
|
||||
self.column_num_bytes
|
||||
}
|
||||
|
||||
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
|
||||
self.dictionary_num_bytes
|
||||
}
|
||||
|
||||
pub fn total_num_bytes(&self) -> ByteCount {
|
||||
self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Merge two space usage values by summing their components.
|
||||
pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
|
||||
let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
|
||||
(Some(val), None) | (None, Some(val)) => Some(val),
|
||||
(None, None) => None,
|
||||
};
|
||||
ColumnSpaceUsage {
|
||||
column_num_bytes: self.column_num_bytes + other.column_num_bytes,
|
||||
dictionary_num_bytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ mod column;
|
||||
pub mod column_index;
|
||||
pub mod column_values;
|
||||
mod columnar;
|
||||
mod comparable_doc;
|
||||
mod dictionary;
|
||||
mod dynamic_column;
|
||||
mod iterable;
|
||||
@@ -36,7 +37,7 @@ pub(crate) mod utils;
|
||||
mod value;
|
||||
|
||||
pub use block_accessor::ColumnBlockAccessor;
|
||||
pub use column::{BytesColumn, Column, StrColumn};
|
||||
pub use column::{BytesColumn, Column, StrColumn, ValueRange};
|
||||
pub use column_index::ColumnIndex;
|
||||
pub use column_values::{
|
||||
ColumnValues, EmptyColumnValues, MonotonicallyMappableToU64, MonotonicallyMappableToU128,
|
||||
@@ -45,10 +46,11 @@ pub use columnar::{
|
||||
CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
|
||||
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar,
|
||||
};
|
||||
pub use comparable_doc::ComparableDoc;
|
||||
use sstable::VoidSSTable;
|
||||
pub use value::{NumericalType, NumericalValue};
|
||||
|
||||
pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};
|
||||
pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
|
||||
|
||||
pub type RowId = u32;
|
||||
pub type DocId = u32;
|
||||
|
||||
@@ -641,7 +641,7 @@ proptest! {
|
||||
let columnar_readers_arr: Vec<&ColumnarReader> = columnar_readers.iter().collect();
|
||||
let mut output: Vec<u8> = Vec::new();
|
||||
let stack_merge_order = StackMergeOrder::stack(&columnar_readers_arr[..]).into();
|
||||
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output).unwrap();
|
||||
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output, || false,).unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
let concat_rows: Vec<Vec<(&'static str, ColumnValue)>> = columnar_docs.iter().flatten().cloned().collect();
|
||||
let expected_merged_columnar = build_columnar(&concat_rows[..]);
|
||||
@@ -665,6 +665,7 @@ fn test_columnar_merging_empty_columnar() {
|
||||
&[],
|
||||
crate::MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -702,6 +703,7 @@ fn test_columnar_merging_number_columns() {
|
||||
&[],
|
||||
crate::MergeRowOrder::Stack(stack_merge_order),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -775,6 +777,7 @@ fn test_columnar_merge_and_remap(
|
||||
&[],
|
||||
shuffle_merge_order.into(),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -817,6 +820,7 @@ fn test_columnar_merge_empty() {
|
||||
&[],
|
||||
shuffle_merge_order.into(),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -843,6 +847,7 @@ fn test_columnar_merge_single_str_column() {
|
||||
&[],
|
||||
shuffle_merge_order.into(),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
@@ -875,6 +880,7 @@ fn test_delete_decrease_cardinality() {
|
||||
&[],
|
||||
shuffle_merge_order.into(),
|
||||
&mut output,
|
||||
|| false,
|
||||
)
|
||||
.unwrap();
|
||||
let merged_columnar = ColumnarReader::open(output).unwrap();
|
||||
|
||||
106
common/src/buffered_file_slice.rs
Normal file
106
common/src/buffered_file_slice.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
use std::cell::RefCell;
|
||||
use std::cmp::min;
|
||||
use std::io;
|
||||
use std::ops::Range;
|
||||
|
||||
use super::file_slice::FileSlice;
|
||||
use super::{HasLen, OwnedBytes};
|
||||
|
||||
const DEFAULT_BUFFER_MAX_SIZE: usize = 512 * 1024; // 512K
|
||||
|
||||
/// A buffered reader for a FileSlice.
|
||||
///
|
||||
/// Reads the underlying `FileSlice` in large, sequential chunks to amortize
|
||||
/// the cost of `read_bytes` calls, while keeping peak memory usage under control.
|
||||
///
|
||||
/// TODO: Rather than wrapping a `FileSlice` in buffering, it will usually be better to adjust a
|
||||
/// `FileHandle` to directly handle buffering itself.
|
||||
/// TODO: See: https://github.com/paradedb/paradedb/issues/3374
|
||||
pub struct BufferedFileSlice {
|
||||
file_slice: FileSlice,
|
||||
buffer: RefCell<OwnedBytes>,
|
||||
buffer_range: RefCell<Range<u64>>,
|
||||
buffer_max_size: usize,
|
||||
}
|
||||
|
||||
impl BufferedFileSlice {
|
||||
/// Creates a new `BufferedFileSlice`.
|
||||
///
|
||||
/// The `buffer_max_size` is the amount of data that will be read from the
|
||||
/// `FileSlice` on a buffer miss.
|
||||
pub fn new(file_slice: FileSlice, buffer_max_size: usize) -> Self {
|
||||
Self {
|
||||
file_slice,
|
||||
buffer: RefCell::new(OwnedBytes::empty()),
|
||||
buffer_range: RefCell::new(0..0),
|
||||
buffer_max_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new `BufferedFileSlice` with a default buffer max size.
|
||||
pub fn new_with_default_buffer_size(file_slice: FileSlice) -> Self {
|
||||
Self::new(file_slice, DEFAULT_BUFFER_MAX_SIZE)
|
||||
}
|
||||
|
||||
/// Creates an empty `BufferedFileSlice`.
|
||||
pub fn empty() -> Self {
|
||||
Self::new(FileSlice::empty(), 0)
|
||||
}
|
||||
|
||||
/// Returns an `OwnedBytes` corresponding to the given `required_range`.
|
||||
///
|
||||
/// If the requested range is not in the buffer, this will trigger a read
|
||||
/// from the underlying `FileSlice`.
|
||||
///
|
||||
/// If the requested range is larger than the buffer_max_size, it will be read directly from the
|
||||
/// source without buffering.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an `io::Error` if the underlying read fails or the range is
|
||||
/// out of bounds.
|
||||
pub fn get_bytes(&self, required_range: Range<u64>) -> io::Result<OwnedBytes> {
|
||||
let buffer_range = self.buffer_range.borrow();
|
||||
|
||||
// Cache miss condition: the required range is not fully contained in the current buffer.
|
||||
if required_range.start < buffer_range.start || required_range.end > buffer_range.end {
|
||||
drop(buffer_range); // release borrow before mutating
|
||||
|
||||
if required_range.end > self.file_slice.len() as u64 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"Requested range extends beyond the end of the file slice.",
|
||||
));
|
||||
}
|
||||
|
||||
if (required_range.end - required_range.start) as usize > self.buffer_max_size {
|
||||
// This read is larger than our buffer max size.
|
||||
// Read it directly and bypass the buffer to avoid churning.
|
||||
return self
|
||||
.file_slice
|
||||
.read_bytes_slice(required_range.start as usize..required_range.end as usize);
|
||||
}
|
||||
|
||||
let new_buffer_start = required_range.start;
|
||||
let new_buffer_end = min(
|
||||
new_buffer_start + self.buffer_max_size as u64,
|
||||
self.file_slice.len() as u64,
|
||||
);
|
||||
let read_range = new_buffer_start..new_buffer_end;
|
||||
|
||||
let new_buffer = self
|
||||
.file_slice
|
||||
.read_bytes_slice(read_range.start as usize..read_range.end as usize)?;
|
||||
|
||||
self.buffer.replace(new_buffer);
|
||||
self.buffer_range.replace(read_range);
|
||||
}
|
||||
|
||||
// Now the data is guaranteed to be in the buffer.
|
||||
let buffer = self.buffer.borrow();
|
||||
let buffer_range = self.buffer_range.borrow();
|
||||
let local_start = (required_range.start - buffer_range.start) as usize;
|
||||
let local_end = (required_range.end - buffer_range.start) as usize;
|
||||
Ok(buffer.slice(local_start..local_end))
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::fs::File;
|
||||
use std::ops::{Deref, Range, RangeBounds};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::{fmt, io};
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -339,6 +339,27 @@ impl FileHandle for OwnedBytes {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DeferredFileSlice {
|
||||
opener: Arc<dyn Fn() -> io::Result<FileSlice> + Send + Sync + 'static>,
|
||||
file_slice: OnceLock<std::io::Result<FileSlice>>,
|
||||
}
|
||||
|
||||
impl DeferredFileSlice {
|
||||
pub fn new(opener: impl Fn() -> io::Result<FileSlice> + Send + Sync + 'static) -> Self {
|
||||
DeferredFileSlice {
|
||||
opener: Arc::new(opener),
|
||||
file_slice: OnceLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open(&self) -> io::Result<&FileSlice> {
|
||||
match self.file_slice.get_or_init(|| (self.opener)()) {
|
||||
Ok(file_slice) => Ok(file_slice),
|
||||
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::io;
|
||||
|
||||
@@ -6,6 +6,7 @@ pub use byteorder::LittleEndian as Endianness;
|
||||
|
||||
mod bitset;
|
||||
pub mod bounds;
|
||||
pub mod buffered_file_slice;
|
||||
mod byte_count;
|
||||
mod datetime;
|
||||
pub mod file_slice;
|
||||
|
||||
@@ -28,7 +28,9 @@ impl BinarySerializable for VIntU128 {
|
||||
writer.write_all(&buffer)
|
||||
}
|
||||
|
||||
#[allow(clippy::unbuffered_bytes)]
|
||||
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
|
||||
#[allow(clippy::unbuffered_bytes)]
|
||||
let mut bytes = reader.bytes();
|
||||
let mut result = 0u128;
|
||||
let mut shift = 0u64;
|
||||
@@ -56,6 +58,33 @@ impl BinarySerializable for VIntU128 {
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub struct VInt(pub u64);
|
||||
|
||||
impl VInt {
|
||||
pub fn deserialize_with_size<R: Read>(reader: &mut R) -> io::Result<(Self, usize)> {
|
||||
let mut nbytes = 0;
|
||||
let mut bytes = reader.bytes();
|
||||
let mut result = 0u64;
|
||||
let mut shift = 0u64;
|
||||
loop {
|
||||
match bytes.next() {
|
||||
Some(Ok(b)) => {
|
||||
nbytes += 1;
|
||||
result |= u64::from(b % 128u8) << shift;
|
||||
if b >= STOP_BIT {
|
||||
return Ok((VInt(result), nbytes));
|
||||
}
|
||||
shift += 7;
|
||||
}
|
||||
_ => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Reach end of buffer while reading VInt",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const STOP_BIT: u8 = 128;
|
||||
|
||||
#[inline]
|
||||
@@ -195,7 +224,9 @@ impl BinarySerializable for VInt {
|
||||
writer.write_all(&buffer[0..num_bytes])
|
||||
}
|
||||
|
||||
#[allow(clippy::unbuffered_bytes)]
|
||||
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
|
||||
#[allow(clippy::unbuffered_bytes)]
|
||||
let mut bytes = reader.bytes();
|
||||
let mut result = 0u64;
|
||||
let mut shift = 0u64;
|
||||
@@ -221,7 +252,6 @@ impl BinarySerializable for VInt {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::{BinarySerializable, VInt, serialize_vint_u32};
|
||||
|
||||
fn aux_test_vint(val: u64) {
|
||||
|
||||
@@ -208,7 +208,7 @@ fn main() -> tantivy::Result<()> {
|
||||
// is the role of the `TopDocs` collector.
|
||||
|
||||
// We can now perform our query.
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
// The actual documents still need to be
|
||||
// retrieved from Tantivy's store.
|
||||
@@ -226,7 +226,7 @@ fn main() -> tantivy::Result<()> {
|
||||
let query = query_parser.parse_query("title:sea^20 body:whale^70")?;
|
||||
|
||||
let (_score, doc_address) = searcher
|
||||
.search(&query, &TopDocs::with_limit(1))?
|
||||
.search(&query, &TopDocs::with_limit(1).order_by_score())?
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
@@ -100,7 +100,7 @@ fn main() -> tantivy::Result<()> {
|
||||
// here we want to get a hit on the 'ken' in Frankenstein
|
||||
let query = query_parser.parse_query("ken")?;
|
||||
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
for (_, doc_address) in top_docs {
|
||||
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;
|
||||
|
||||
@@ -50,14 +50,14 @@ fn main() -> tantivy::Result<()> {
|
||||
{
|
||||
// Simple exact search on the date
|
||||
let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 1);
|
||||
}
|
||||
{
|
||||
// Range query on the date field
|
||||
let query = query_parser
|
||||
.parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 1);
|
||||
for (_score, doc_address) in count_docs {
|
||||
let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;
|
||||
|
||||
@@ -28,7 +28,7 @@ fn extract_doc_given_isbn(
|
||||
// The second argument is here to tell we don't care about decoding positions,
|
||||
// or term frequencies.
|
||||
let term_query = TermQuery::new(isbn_term.clone(), IndexRecordOption::Basic);
|
||||
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1))?;
|
||||
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1).order_by_score())?;
|
||||
|
||||
if let Some((_score, doc_address)) = top_docs.first() {
|
||||
let doc = searcher.doc(*doc_address)?;
|
||||
|
||||
@@ -145,7 +145,7 @@ fn main() -> tantivy::Result<()> {
|
||||
let query = FuzzyTermQuery::new(term, 2, true);
|
||||
|
||||
let (top_docs, count) = searcher
|
||||
.search(&query, &(TopDocs::with_limit(5), Count))
|
||||
.search(&query, &(TopDocs::with_limit(5).order_by_score(), Count))
|
||||
.unwrap();
|
||||
assert_eq!(count, 3);
|
||||
assert_eq!(top_docs.len(), 3);
|
||||
|
||||
@@ -69,25 +69,25 @@ fn main() -> tantivy::Result<()> {
|
||||
{
|
||||
// Inclusive range queries
|
||||
let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 1);
|
||||
}
|
||||
{
|
||||
// Exclusive range queries
|
||||
let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 0);
|
||||
}
|
||||
{
|
||||
// Find docs with IP addresses smaller equal 192.168.1.100
|
||||
let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 2);
|
||||
}
|
||||
{
|
||||
// Find docs with IP addresses smaller than 192.168.1.100
|
||||
let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 2);
|
||||
}
|
||||
|
||||
|
||||
@@ -59,12 +59,12 @@ fn main() -> tantivy::Result<()> {
|
||||
let query_parser = QueryParser::for_index(&index, vec![event_type, attributes]);
|
||||
{
|
||||
let query = query_parser.parse_query("target:submit-button")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 2);
|
||||
}
|
||||
{
|
||||
let query = query_parser.parse_query("target:submit")?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(count_docs.len(), 2);
|
||||
}
|
||||
{
|
||||
@@ -74,33 +74,33 @@ fn main() -> tantivy::Result<()> {
|
||||
}
|
||||
{
|
||||
let query = query_parser.parse_query("click AND cart.product_id:133")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 1);
|
||||
}
|
||||
{
|
||||
// The sub-fields in the json field marked as default field still need to be explicitly
|
||||
// addressed
|
||||
let query = query_parser.parse_query("click AND 133")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 0);
|
||||
}
|
||||
{
|
||||
// Default json fields are ignored if they collide with the schema
|
||||
let query = query_parser.parse_query("event_type:holiday-sale")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 0);
|
||||
}
|
||||
// # Query via full attribute path
|
||||
{
|
||||
// This only searches in our schema's `event_type` field
|
||||
let query = query_parser.parse_query("event_type:click")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 2);
|
||||
}
|
||||
{
|
||||
// Default json fields can still be accessed by full path
|
||||
let query = query_parser.parse_query("attributes.event_type:holiday-sale")?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
|
||||
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(hits.len(), 1);
|
||||
}
|
||||
Ok(())
|
||||
|
||||
86
examples/multiple_snippets.rs
Normal file
86
examples/multiple_snippets.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
// # Multiple Snippets Example
|
||||
//
|
||||
// This example demonstrates how to return multiple text fragments
|
||||
// from a document, useful for long documents with matches in different locations.
|
||||
|
||||
use tantivy::collector::TopDocs;
|
||||
use tantivy::query::QueryParser;
|
||||
use tantivy::schema::*;
|
||||
use tantivy::snippet::SnippetGenerator;
|
||||
use tantivy::{doc, Index, IndexWriter};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn main() -> tantivy::Result<()> {
|
||||
let index_path = TempDir::new()?;
|
||||
|
||||
// Define the schema
|
||||
let mut schema_builder = Schema::builder();
|
||||
let title = schema_builder.add_text_field("title", TEXT | STORED);
|
||||
let body = schema_builder.add_text_field("body", TEXT | STORED);
|
||||
let schema = schema_builder.build();
|
||||
|
||||
// Create the index
|
||||
let index = Index::create_in_dir(&index_path, schema)?;
|
||||
let mut index_writer: IndexWriter = index.writer(50_000_000)?;
|
||||
|
||||
// Index a long document with multiple occurrences of "rust"
|
||||
index_writer.add_document(doc!(
|
||||
title => "The Rust Programming Language",
|
||||
body => "Rust is a systems programming language that runs blazingly fast, prevents \
|
||||
segfaults, and guarantees thread safety. Lorem ipsum dolor sit amet, \
|
||||
consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore. \
|
||||
Rust empowers everyone to build reliable and efficient software. More filler \
|
||||
text to create distance between matches. Ut enim ad minim veniam, quis nostrud \
|
||||
exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. \
|
||||
The Rust compiler is known for its helpful error messages. Duis aute irure \
|
||||
dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla \
|
||||
pariatur. Rust has a strong type system and ownership model."
|
||||
))?;
|
||||
|
||||
index_writer.commit()?;
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
let query_parser = QueryParser::for_index(&index, vec![title, body]);
|
||||
let query = query_parser.parse_query("rust")?;
|
||||
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
// Create snippet generator
|
||||
let mut snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
|
||||
|
||||
println!("=== Single Snippet (Default Behavior) ===\n");
|
||||
for (score, doc_address) in &top_docs {
|
||||
let doc = searcher.doc::<TantivyDocument>(*doc_address)?;
|
||||
let snippet = snippet_generator.snippet_from_doc(&doc);
|
||||
println!("Document score: {}", score);
|
||||
println!("Title: {}", doc.get_first(title).unwrap().as_str().unwrap());
|
||||
println!("Single snippet: {}\n", snippet.to_html());
|
||||
}
|
||||
|
||||
println!("\n=== Multiple Snippets (New Feature) ===\n");
|
||||
|
||||
// Configure to return multiple snippets
|
||||
// Get up to 3 snippets
|
||||
snippet_generator.set_snippets_limit(3);
|
||||
// Smaller fragments
|
||||
snippet_generator.set_max_num_chars(80);
|
||||
// By default, multiple snippets are sorted by score. You can change this to sort by position.
|
||||
// snippet_generator.set_sort_order(SnippetSortOrder::Position);
|
||||
|
||||
for (score, doc_address) in top_docs {
|
||||
let doc = searcher.doc::<TantivyDocument>(doc_address)?;
|
||||
let snippets = snippet_generator.snippets_from_doc(&doc);
|
||||
|
||||
println!("Document score: {}", score);
|
||||
println!("Title: {}", doc.get_first(title).unwrap().as_str().unwrap());
|
||||
println!("Found {} snippets:", snippets.len());
|
||||
|
||||
for (i, snippet) in snippets.iter().enumerate() {
|
||||
println!(" Snippet {}: {}", i + 1, snippet.to_html());
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -63,7 +63,7 @@ fn main() -> Result<()> {
|
||||
// but not "in the Gulf Stream".
|
||||
let query = query_parser.parse_query("\"in the su\"*")?;
|
||||
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
let mut titles = top_docs
|
||||
.into_iter()
|
||||
.map(|(_score, doc_address)| {
|
||||
|
||||
@@ -107,7 +107,8 @@ fn main() -> tantivy::Result<()> {
|
||||
IndexRecordOption::Basic,
|
||||
);
|
||||
|
||||
let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
|
||||
let (top_docs, count) =
|
||||
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
|
||||
|
||||
assert_eq!(count, 2);
|
||||
|
||||
@@ -128,7 +129,8 @@ fn main() -> tantivy::Result<()> {
|
||||
IndexRecordOption::Basic,
|
||||
);
|
||||
|
||||
let (_top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
|
||||
let (_top_docs, count) =
|
||||
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
|
||||
|
||||
assert_eq!(count, 0);
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
|
||||
let query_parser = QueryParser::for_index(&index, vec![title, body]);
|
||||
let query = query_parser.parse_query("sycamore spring")?;
|
||||
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ fn main() -> tantivy::Result<()> {
|
||||
// stop words are applied on the query as well.
|
||||
// The following will be equivalent to `title:frankenstein`
|
||||
let query = query_parser.parse_query("title:\"the Frankenstein\"")?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
|
||||
for (score, doc_address) in top_docs {
|
||||
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;
|
||||
|
||||
@@ -164,7 +164,7 @@ fn main() -> tantivy::Result<()> {
|
||||
move |doc_id: DocId| Reverse(price[doc_id as usize])
|
||||
};
|
||||
|
||||
let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price);
|
||||
let most_expensive_first = TopDocs::with_limit(10).order_by(score_by_price);
|
||||
|
||||
let hits = searcher.search(&query, &most_expensive_first)?;
|
||||
assert_eq!(
|
||||
|
||||
@@ -758,7 +758,17 @@ fn negate(expr: UserInputAst) -> UserInputAst {
|
||||
fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
|
||||
alt((
|
||||
delimited(char('('), ast, char(')')),
|
||||
map(char('*'), |_| UserInputAst::from(UserInputLeaf::All)),
|
||||
map(
|
||||
terminated(
|
||||
char('*'),
|
||||
peek(alt((
|
||||
value((), multispace1),
|
||||
value((), char(')')),
|
||||
value((), eof),
|
||||
))),
|
||||
),
|
||||
|_| UserInputAst::from(UserInputLeaf::All),
|
||||
),
|
||||
map(preceded(tuple((tag("NOT"), multispace1)), leaf), negate),
|
||||
literal,
|
||||
))(inp)
|
||||
@@ -779,7 +789,17 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
|
||||
),
|
||||
),
|
||||
(
|
||||
value((), char('*')),
|
||||
value(
|
||||
(),
|
||||
terminated(
|
||||
char('*'),
|
||||
peek(alt((
|
||||
value((), multispace1),
|
||||
value((), char(')')),
|
||||
value((), eof),
|
||||
))),
|
||||
),
|
||||
),
|
||||
map(nothing, |_| {
|
||||
(Some(UserInputAst::from(UserInputLeaf::All)), Vec::new())
|
||||
}),
|
||||
@@ -1671,6 +1691,21 @@ mod test {
|
||||
test_parse_query_to_ast_helper("abc:a b", "(*\"abc\":a *b)");
|
||||
test_parse_query_to_ast_helper("abc:\"a b\"", "\"abc\":\"a b\"");
|
||||
test_parse_query_to_ast_helper("foo:[1 TO 5]", "\"foo\":[\"1\" TO \"5\"]");
|
||||
|
||||
// Phrase prefixed with *
|
||||
test_parse_query_to_ast_helper("foo:(*A)", "\"foo\":*A");
|
||||
test_parse_query_to_ast_helper("*A", "*A");
|
||||
test_parse_query_to_ast_helper("(*A)", "*A");
|
||||
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
|
||||
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
|
||||
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_query_all() {
|
||||
test_parse_query_to_ast_helper("*", "*");
|
||||
test_parse_query_to_ast_helper("(*)", "*");
|
||||
test_parse_query_to_ast_helper("(* )", "*");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
3
runtests.sh
Executable file
3
runtests.sh
Executable file
@@ -0,0 +1,3 @@
|
||||
#! /bin/bash
|
||||
|
||||
cargo +stable nextest run --features quickwit,mmap,stopwords,lz4-compression,zstd-compression,failpoints --verbose --workspace
|
||||
@@ -16,15 +16,16 @@ use crate::index::SegmentReader;
|
||||
/// That way we can use it the same way as if it would come from the fastfield.
|
||||
pub(crate) fn get_missing_val_as_u64_lenient(
|
||||
column_type: ColumnType,
|
||||
column_max_value: u64,
|
||||
missing: &Key,
|
||||
field_name: &str,
|
||||
) -> crate::Result<Option<u64>> {
|
||||
let missing_val = match missing {
|
||||
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
|
||||
Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
|
||||
// Allow fallback to number on text fields
|
||||
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
|
||||
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
|
||||
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
|
||||
Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
|
||||
Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
|
||||
Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
|
||||
Key::F64(val) if column_type.numerical_type().is_some() => {
|
||||
f64_to_fastfield_u64(*val, &column_type)
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ use crate::aggregation::accessor_helpers::{
|
||||
};
|
||||
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
|
||||
use crate::aggregation::bucket::{
|
||||
build_segment_aggregation_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds,
|
||||
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector,
|
||||
SegmentHistogramCollector, SegmentRangeCollector, TermMissingAgg, TermsAggReqData,
|
||||
TermsAggregation, TermsAggregationInternal,
|
||||
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
|
||||
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
|
||||
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
|
||||
TermsAggregationInternal,
|
||||
};
|
||||
use crate::aggregation::metric::{
|
||||
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
|
||||
@@ -373,7 +373,7 @@ pub(crate) fn build_segment_agg_collector(
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
|
||||
match node.kind {
|
||||
AggKind::Terms => build_segment_aggregation_collector(req, node),
|
||||
AggKind::Terms => crate::aggregation::bucket::build_segment_term_collector(req, node),
|
||||
AggKind::MissingTerm => {
|
||||
let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data];
|
||||
if req_data.accessors.is_empty() {
|
||||
@@ -496,7 +496,7 @@ pub(crate) fn build_aggregations_data_from_req(
|
||||
};
|
||||
|
||||
for (name, agg) in aggs.iter() {
|
||||
let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data)?;
|
||||
let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data, true)?;
|
||||
data.per_request.agg_tree.extend(nodes);
|
||||
}
|
||||
Ok(data)
|
||||
@@ -508,6 +508,7 @@ fn build_nodes(
|
||||
reader: &SegmentReader,
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
data: &mut AggregationsSegmentCtx,
|
||||
is_top_level: bool,
|
||||
) -> crate::Result<Vec<AggRefNode>> {
|
||||
use AggregationVariants::*;
|
||||
match &req.agg {
|
||||
@@ -594,6 +595,7 @@ fn build_nodes(
|
||||
data,
|
||||
&req.sub_aggregation,
|
||||
TermsOrCardinalityRequest::Terms(terms_req.clone()),
|
||||
is_top_level,
|
||||
),
|
||||
Cardinality(card_req) => build_terms_or_cardinality_nodes(
|
||||
agg_name,
|
||||
@@ -604,6 +606,7 @@ fn build_nodes(
|
||||
data,
|
||||
&req.sub_aggregation,
|
||||
TermsOrCardinalityRequest::Cardinality(card_req.clone()),
|
||||
is_top_level,
|
||||
),
|
||||
Average(AverageAggregation { field, missing, .. })
|
||||
| Max(MaxAggregation { field, missing, .. })
|
||||
@@ -732,7 +735,7 @@ fn build_nodes(
|
||||
// Build the query and evaluator upfront
|
||||
let schema = reader.schema();
|
||||
let tokenizers = &data.context.tokenizers;
|
||||
let query = filter_req.parse_query(&schema, tokenizers)?;
|
||||
let query = filter_req.parse_query(schema, tokenizers)?;
|
||||
let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
|
||||
query,
|
||||
schema.clone(),
|
||||
@@ -769,7 +772,14 @@ fn build_children(
|
||||
) -> crate::Result<Vec<AggRefNode>> {
|
||||
let mut children = Vec::new();
|
||||
for (name, agg) in aggs.iter() {
|
||||
children.extend(build_nodes(name, agg, reader, segment_ordinal, data)?);
|
||||
children.extend(build_nodes(
|
||||
name,
|
||||
agg,
|
||||
reader,
|
||||
segment_ordinal,
|
||||
data,
|
||||
false,
|
||||
)?);
|
||||
}
|
||||
Ok(children)
|
||||
}
|
||||
@@ -833,6 +843,7 @@ fn build_terms_or_cardinality_nodes(
|
||||
data: &mut AggregationsSegmentCtx,
|
||||
sub_aggs: &Aggregations,
|
||||
req: TermsOrCardinalityRequest,
|
||||
is_top_level: bool,
|
||||
) -> crate::Result<Vec<AggRefNode>> {
|
||||
let mut nodes = Vec::new();
|
||||
|
||||
@@ -889,7 +900,7 @@ fn build_terms_or_cardinality_nodes(
|
||||
let missing_value_for_accessor = if use_special_missing_agg {
|
||||
None
|
||||
} else if let Some(m) = missing.as_ref() {
|
||||
get_missing_val_as_u64_lenient(column_type, m, field_name)?
|
||||
get_missing_val_as_u64_lenient(column_type, accessor.max_value(), m, field_name)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -922,6 +933,7 @@ fn build_terms_or_cardinality_nodes(
|
||||
sub_aggregation_blueprint: None,
|
||||
sug_aggregations: sub_aggs.clone(),
|
||||
allowed_term_ids,
|
||||
is_top_level,
|
||||
});
|
||||
(idx_in_req_data, AggKind::Terms)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ pub struct AggregationLimitsGuard {
|
||||
/// Allocated memory with this guard.
|
||||
allocated_with_the_guard: u64,
|
||||
}
|
||||
|
||||
impl Clone for AggregationLimitsGuard {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
|
||||
@@ -16,7 +16,7 @@ use super::{AggregationError, Key};
|
||||
use crate::TantivyError;
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
/// The final aggegation result.
|
||||
/// The final aggregation result.
|
||||
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
|
||||
|
||||
impl AggregationResults {
|
||||
|
||||
@@ -32,7 +32,7 @@ use crate::{DocId, SegmentReader, TantivyError};
|
||||
///
|
||||
/// # Implementation Requirements
|
||||
///
|
||||
/// Implementors must:
|
||||
/// Implementers must:
|
||||
/// 1. Derive `Debug`, `Clone`, `Serialize`, and `Deserialize`
|
||||
/// 2. Use `#[typetag::serde]` attribute on the impl block
|
||||
/// 3. Implement `build_query()` to construct the query from schema/tokenizers
|
||||
@@ -639,16 +639,14 @@ pub struct IntermediateFilterBucketResult {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::*;
|
||||
use crate::aggregation::agg_req::Aggregations;
|
||||
use crate::aggregation::agg_result::AggregationResults;
|
||||
use crate::aggregation::{AggContextParams, AggregationCollector};
|
||||
use crate::query::{AllQuery, QueryParser, TermQuery};
|
||||
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, STORED, TEXT};
|
||||
use crate::query::{AllQuery, TermQuery};
|
||||
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, TEXT};
|
||||
use crate::{doc, Index, IndexWriter};
|
||||
|
||||
// Test helper functions
|
||||
@@ -729,12 +727,13 @@ mod tests {
|
||||
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer: IndexWriter = index.writer(50_000_000)?;
|
||||
let mut writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
writer.add_document(doc!(
|
||||
category => "electronics", brand => "apple",
|
||||
price => 999u64, rating => 4.5f64, in_stock => true
|
||||
))?;
|
||||
writer.commit()?;
|
||||
writer.add_document(doc!(
|
||||
category => "electronics", brand => "samsung",
|
||||
price => 799u64, rating => 4.2f64, in_stock => true
|
||||
@@ -938,7 +937,7 @@ mod tests {
|
||||
let index = create_standard_test_index()?;
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
assert_eq!(searcher.segment_readers().len(), 2);
|
||||
let agg = json!({
|
||||
"premium_electronics": {
|
||||
"filter": "category:electronics AND price:[800 TO *]",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,196 +0,0 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use columnar::ColumnType;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use super::OrderTarget;
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::agg_limits::MemoryConsumption;
|
||||
use crate::aggregation::bucket::get_agg_name_and_property;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::TantivyError;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
/// Container to store term_ids/or u64 values and their buckets.
|
||||
struct TermBuckets {
|
||||
pub(crate) entries: FxHashMap<u64, u32>,
|
||||
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
|
||||
}
|
||||
|
||||
impl TermBuckets {
|
||||
fn get_memory_consumption(&self) -> usize {
|
||||
let sub_aggs_mem = self.sub_aggs.memory_consumption();
|
||||
let buckets_mem = self.entries.memory_consumption();
|
||||
sub_aggs_mem + buckets_mem
|
||||
}
|
||||
|
||||
fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
for sub_aggregations in &mut self.sub_aggs.values_mut() {
|
||||
sub_aggregations.as_mut().flush(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// The collector puts values from the fast field into the correct buckets and does a conversion to
|
||||
/// the correct datatype.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SegmentTermCollector {
|
||||
/// The buckets containing the aggregation data.
|
||||
term_buckets: TermBuckets,
|
||||
accessor_idx: usize,
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentTermCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
|
||||
|
||||
let entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect();
|
||||
let bucket = super::into_intermediate_bucket_result(
|
||||
self.accessor_idx,
|
||||
entries,
|
||||
self.term_buckets.sub_aggs,
|
||||
agg_data,
|
||||
)?;
|
||||
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
|
||||
|
||||
let mem_pre = self.get_memory_consumption();
|
||||
|
||||
if let Some(missing) = req_data.missing_value_for_accessor {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&req_data.accessor,
|
||||
missing,
|
||||
);
|
||||
} else {
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
|
||||
for term_id in req_data.column_block_accessor.iter_vals() {
|
||||
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
|
||||
if !allowed_bs.contains(term_id as u32) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let entry = self.term_buckets.entries.entry(term_id).or_default();
|
||||
*entry += 1;
|
||||
}
|
||||
// has subagg
|
||||
if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() {
|
||||
for (doc, term_id) in req_data
|
||||
.column_block_accessor
|
||||
.iter_docid_vals(docs, &req_data.accessor)
|
||||
{
|
||||
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
|
||||
if !allowed_bs.contains(term_id as u32) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let sub_aggregations = self
|
||||
.term_buckets
|
||||
.sub_aggs
|
||||
.entry(term_id)
|
||||
.or_insert_with(|| blueprint.clone());
|
||||
sub_aggregations.collect(doc, agg_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
let mem_delta = self.get_memory_consumption() - mem_pre;
|
||||
if mem_delta > 0 {
|
||||
agg_data
|
||||
.context
|
||||
.limits
|
||||
.add_memory_consumed(mem_delta as u64)?;
|
||||
}
|
||||
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
self.term_buckets.force_flush(agg_data)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentTermCollector {
|
||||
pub fn from_req_and_validate(
|
||||
req_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
|
||||
let column_type = terms_req_data.column_type;
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
if column_type == ColumnType::Bytes {
|
||||
return Err(TantivyError::InvalidArgument(format!(
|
||||
"terms aggregation is not supported for column type {column_type:?}"
|
||||
)));
|
||||
}
|
||||
let term_buckets = TermBuckets::default();
|
||||
|
||||
// Validate sub aggregation exists
|
||||
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
|
||||
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
|
||||
|
||||
node.get_sub_agg(agg_name, &req_data.per_request)
|
||||
.ok_or_else(|| {
|
||||
TantivyError::InvalidArgument(format!(
|
||||
"could not find aggregation with name {agg_name} in metric \
|
||||
sub_aggregations"
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
let has_sub_aggregations = !node.children.is_empty();
|
||||
let blueprint = if has_sub_aggregations {
|
||||
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
|
||||
Some(sub_aggregation)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
|
||||
terms_req_data.sub_aggregation_blueprint = blueprint;
|
||||
|
||||
Ok(SegmentTermCollector {
|
||||
term_buckets,
|
||||
accessor_idx,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_memory_consumption(&self) -> usize {
|
||||
let self_mem = std::mem::size_of::<Self>();
|
||||
let term_buckets_mem = self.term_buckets.get_memory_consumption();
|
||||
self_mem + term_buckets_mem
|
||||
}
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
use std::vec;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::aggregation::agg_data::{
|
||||
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
|
||||
};
|
||||
use crate::aggregation::bucket::{get_agg_name_and_property, OrderTarget};
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::{DocId, TantivyError};
|
||||
|
||||
const MAX_BATCH_SIZE: usize = 1_024;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct LowCardTermBuckets {
|
||||
entries: Box<[u32]>,
|
||||
sub_aggs: Vec<Box<dyn SegmentAggregationCollector>>,
|
||||
doc_buffers: Box<[Vec<DocId>]>,
|
||||
}
|
||||
|
||||
impl LowCardTermBuckets {
|
||||
pub fn with_num_buckets(
|
||||
num_buckets: usize,
|
||||
sub_aggs_blueprint_opt: Option<&Box<dyn SegmentAggregationCollector>>,
|
||||
) -> Self {
|
||||
let sub_aggs = sub_aggs_blueprint_opt
|
||||
.as_ref()
|
||||
.map(|blueprint| {
|
||||
std::iter::repeat_with(|| blueprint.clone_box())
|
||||
.take(num_buckets)
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
Self {
|
||||
entries: vec![0; num_buckets].into_boxed_slice(),
|
||||
sub_aggs,
|
||||
doc_buffers: std::iter::repeat_with(|| Vec::with_capacity(MAX_BATCH_SIZE))
|
||||
.take(num_buckets)
|
||||
.collect::<Vec<_>>()
|
||||
.into_boxed_slice(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_memory_consumption(&self) -> usize {
|
||||
std::mem::size_of::<Self>()
|
||||
+ self.entries.len() * std::mem::size_of::<u32>()
|
||||
+ self.doc_buffers.len()
|
||||
* (std::mem::size_of::<Vec<DocId>>()
|
||||
+ std::mem::size_of::<DocId>() * MAX_BATCH_SIZE)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LowCardSegmentTermCollector {
|
||||
term_buckets: LowCardTermBuckets,
|
||||
accessor_idx: usize,
|
||||
}
|
||||
|
||||
impl LowCardSegmentTermCollector {
|
||||
pub fn from_req_and_validate(
|
||||
req_data: &mut AggregationsSegmentCtx,
|
||||
node: &AggRefNode,
|
||||
) -> crate::Result<Self> {
|
||||
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
|
||||
let accessor_idx = node.idx_in_req_data;
|
||||
let cardinality = terms_req_data
|
||||
.accessor
|
||||
.max_value()
|
||||
.max(terms_req_data.missing_value_for_accessor.unwrap_or(0))
|
||||
+ 1;
|
||||
assert!(cardinality <= super::LOW_CARDINALITY_THRESHOLD);
|
||||
|
||||
// Validate sub aggregation exists
|
||||
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
|
||||
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
|
||||
|
||||
node.get_sub_agg(agg_name, &req_data.per_request)
|
||||
.ok_or_else(|| {
|
||||
TantivyError::InvalidArgument(format!(
|
||||
"could not find aggregation with name {agg_name} in metric \
|
||||
sub_aggregations"
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
let has_sub_aggregations = !node.children.is_empty();
|
||||
let blueprint = if has_sub_aggregations {
|
||||
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
|
||||
Some(sub_aggregation)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
|
||||
|
||||
let term_buckets =
|
||||
LowCardTermBuckets::with_num_buckets(cardinality as usize, blueprint.as_ref());
|
||||
|
||||
terms_req_data.sub_aggregation_blueprint = blueprint;
|
||||
|
||||
Ok(LowCardSegmentTermCollector {
|
||||
term_buckets,
|
||||
accessor_idx,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_memory_consumption(&self) -> usize {
|
||||
let self_mem = std::mem::size_of::<Self>();
|
||||
let term_buckets_mem = self.term_buckets.get_memory_consumption();
|
||||
self_mem + term_buckets_mem
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for LowCardSegmentTermCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_data: &AggregationsSegmentCtx,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
|
||||
let sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>> = self
|
||||
.term_buckets
|
||||
.sub_aggs
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(bucket_id, _sub_agg)| self.term_buckets.entries[*bucket_id] > 0)
|
||||
.map(|(bucket_id, sub_agg)| (bucket_id as u64, sub_agg))
|
||||
.collect();
|
||||
let entries: Vec<(u64, u32)> = self
|
||||
.term_buckets
|
||||
.entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, count)| **count > 0)
|
||||
.map(|(bucket_id, count)| (bucket_id as u64, *count))
|
||||
.collect();
|
||||
|
||||
let bucket =
|
||||
super::into_intermediate_bucket_result(self.accessor_idx, entries, sub_aggs, agg_data)?;
|
||||
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
if docs.len() > MAX_BATCH_SIZE {
|
||||
for batch in docs.chunks(MAX_BATCH_SIZE) {
|
||||
self.collect_block(batch, agg_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
|
||||
|
||||
let mem_pre = self.get_memory_consumption();
|
||||
|
||||
if let Some(missing) = req_data.missing_value_for_accessor {
|
||||
req_data.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&req_data.accessor,
|
||||
missing,
|
||||
);
|
||||
} else {
|
||||
req_data
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &req_data.accessor);
|
||||
}
|
||||
|
||||
// has subagg
|
||||
if req_data.sub_aggregation_blueprint.is_some() {
|
||||
for (doc, term_id) in req_data
|
||||
.column_block_accessor
|
||||
.iter_docid_vals(docs, &req_data.accessor)
|
||||
{
|
||||
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
|
||||
if !allowed_bs.contains(term_id as u32) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
self.term_buckets.doc_buffers[term_id as usize].push(doc);
|
||||
}
|
||||
for (bucket_id, docs) in self.term_buckets.doc_buffers.iter_mut().enumerate() {
|
||||
self.term_buckets.entries[bucket_id] += docs.len() as u32;
|
||||
self.term_buckets.sub_aggs[bucket_id].collect_block(&docs[..], agg_data)?;
|
||||
docs.clear();
|
||||
}
|
||||
} else {
|
||||
for term_id in req_data.column_block_accessor.iter_vals() {
|
||||
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
|
||||
if !allowed_bs.contains(term_id as u32) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
self.term_buckets.entries[term_id as usize] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let mem_delta = self.get_memory_consumption() - mem_pre;
|
||||
if mem_delta > 0 {
|
||||
agg_data
|
||||
.context
|
||||
.limits
|
||||
.add_memory_consumed(mem_delta as u64)?;
|
||||
}
|
||||
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
|
||||
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
|
||||
for sub_aggregations in &mut self.term_buckets.sub_aggs.iter_mut() {
|
||||
sub_aggregations.as_mut().flush(agg_data)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,12 @@ use super::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::DocId;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
|
||||
|
||||
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
|
||||
|
||||
/// BufAggregationCollector buffers documents before calling collect_block().
|
||||
@@ -15,7 +20,7 @@ pub(crate) struct BufAggregationCollector {
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BufAggregationCollector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct("SegmentAggregationResultsCollector")
|
||||
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
|
||||
.field("num_staged_docs", &self.num_staged_docs)
|
||||
@@ -66,7 +71,6 @@ impl SegmentAggregationCollector for BufAggregationCollector {
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collector.collect_block(docs, agg_data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ impl ExtendedStatsAggregation {
|
||||
|
||||
/// Extended stats contains a collection of statistics
|
||||
/// they extends stats adding variance, standard deviation
|
||||
/// and bound informations
|
||||
/// and bound information
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ExtendedStats {
|
||||
/// The number of documents.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap;
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn};
|
||||
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn, ValueRange};
|
||||
use common::json_path_writer::JSON_PATH_SEGMENT_SEP_STR;
|
||||
use common::DateTime;
|
||||
use regex::Regex;
|
||||
@@ -16,6 +17,7 @@ use crate::aggregation::intermediate_agg_result::{
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::AggregationError;
|
||||
use crate::collector::sort_key::{Comparator, ReverseComparator};
|
||||
use crate::collector::TopNComputer;
|
||||
use crate::schema::OwnedValue;
|
||||
use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
@@ -382,7 +384,7 @@ impl From<FastFieldValue> for OwnedValue {
|
||||
|
||||
/// Holds a fast field value in its u64 representation, and the order in which it should be sorted.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
struct DocValueAndOrder {
|
||||
pub(crate) struct DocValueAndOrder {
|
||||
/// A fast field value in its u64 representation.
|
||||
value: Option<u64>,
|
||||
/// Sort order for the value
|
||||
@@ -454,11 +456,42 @@ impl PartialEq for DocSortValuesAndFields {
|
||||
|
||||
impl Eq for DocSortValuesAndFields {}
|
||||
|
||||
impl Comparator<DocSortValuesAndFields> for ReverseComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &DocSortValuesAndFields, rhs: &DocSortValuesAndFields) -> Ordering {
|
||||
rhs.cmp(lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: DocSortValuesAndFields,
|
||||
) -> ValueRange<DocSortValuesAndFields> {
|
||||
ValueRange::LessThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub(crate) struct TopHitsSegmentSortKey(pub Vec<DocValueAndOrder>);
|
||||
|
||||
impl Comparator<TopHitsSegmentSortKey> for ReverseComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &TopHitsSegmentSortKey, rhs: &TopHitsSegmentSortKey) -> Ordering {
|
||||
rhs.cmp(lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: TopHitsSegmentSortKey,
|
||||
) -> ValueRange<TopHitsSegmentSortKey> {
|
||||
ValueRange::LessThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// The TopHitsCollector used for collecting over segments and merging results.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct TopHitsTopNComputer {
|
||||
req: TopHitsAggregationReq,
|
||||
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
|
||||
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, ReverseComparator>,
|
||||
}
|
||||
|
||||
impl std::cmp::PartialEq for TopHitsTopNComputer {
|
||||
@@ -482,7 +515,7 @@ impl TopHitsTopNComputer {
|
||||
|
||||
pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> {
|
||||
for doc in other_fruit.top_n.into_vec() {
|
||||
self.collect(doc.feature, doc.doc);
|
||||
self.collect(doc.sort_key, doc.doc);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -494,9 +527,9 @@ impl TopHitsTopNComputer {
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.map(|doc| TopHitsVecEntry {
|
||||
sort: doc.feature.sorts.iter().map(|f| f.value).collect(),
|
||||
sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(),
|
||||
doc_value_fields: doc
|
||||
.feature
|
||||
.sort_key
|
||||
.doc_value_fields
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.into()))
|
||||
@@ -517,7 +550,7 @@ impl TopHitsTopNComputer {
|
||||
pub(crate) struct TopHitsSegmentCollector {
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
accessor_idx: usize,
|
||||
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>,
|
||||
top_n: TopNComputer<TopHitsSegmentSortKey, DocAddress, ReverseComparator>,
|
||||
}
|
||||
|
||||
impl TopHitsSegmentCollector {
|
||||
@@ -538,13 +571,15 @@ impl TopHitsSegmentCollector {
|
||||
req: &TopHitsAggregationReq,
|
||||
) -> TopHitsTopNComputer {
|
||||
let mut top_hits_computer = TopHitsTopNComputer::new(req);
|
||||
// Map TopHitsSegmentSortKey back to Vec<DocValueAndOrder> if needed or use directly
|
||||
// The TopNComputer here stores TopHitsSegmentSortKey.
|
||||
let top_results = self.top_n.into_vec();
|
||||
|
||||
for res in top_results {
|
||||
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
|
||||
top_hits_computer.collect(
|
||||
DocSortValuesAndFields {
|
||||
sorts: res.feature,
|
||||
sorts: res.sort_key.0,
|
||||
doc_value_fields,
|
||||
},
|
||||
res.doc,
|
||||
@@ -578,7 +613,7 @@ impl TopHitsSegmentCollector {
|
||||
.collect();
|
||||
|
||||
self.top_n.push(
|
||||
sorts,
|
||||
TopHitsSegmentSortKey(sorts),
|
||||
DocAddress {
|
||||
segment_ord: self.segment_ordinal,
|
||||
doc_id,
|
||||
@@ -645,6 +680,7 @@ mod tests {
|
||||
use crate::aggregation::bucket::tests::get_test_index_from_docs;
|
||||
use crate::aggregation::tests::get_test_index_from_values;
|
||||
use crate::aggregation::AggregationCollector;
|
||||
use crate::collector::sort_key::ReverseComparator;
|
||||
use crate::collector::ComparableDoc;
|
||||
use crate::query::AllQuery;
|
||||
use crate::schema::OwnedValue;
|
||||
@@ -660,7 +696,7 @@ mod tests {
|
||||
|
||||
fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer {
|
||||
super::TopHitsTopNComputer {
|
||||
top_n: super::TopNComputer::new(capacity),
|
||||
top_n: super::TopNComputer::new_with_comparator(capacity, ReverseComparator),
|
||||
req: Default::default(),
|
||||
}
|
||||
}
|
||||
@@ -774,12 +810,12 @@ mod tests {
|
||||
#[test]
|
||||
fn test_top_hits_collector_single_feature() -> crate::Result<()> {
|
||||
let docs = vec![
|
||||
ComparableDoc::<_, _, false> {
|
||||
ComparableDoc::<_, _> {
|
||||
doc: crate::DocAddress {
|
||||
segment_ord: 0,
|
||||
doc_id: 0,
|
||||
},
|
||||
feature: DocSortValuesAndFields {
|
||||
sort_key: DocSortValuesAndFields {
|
||||
sorts: vec![DocValueAndOrder {
|
||||
value: Some(1),
|
||||
order: Order::Asc,
|
||||
@@ -792,7 +828,7 @@ mod tests {
|
||||
segment_ord: 0,
|
||||
doc_id: 2,
|
||||
},
|
||||
feature: DocSortValuesAndFields {
|
||||
sort_key: DocSortValuesAndFields {
|
||||
sorts: vec![DocValueAndOrder {
|
||||
value: Some(3),
|
||||
order: Order::Asc,
|
||||
@@ -805,7 +841,7 @@ mod tests {
|
||||
segment_ord: 0,
|
||||
doc_id: 1,
|
||||
},
|
||||
feature: DocSortValuesAndFields {
|
||||
sort_key: DocSortValuesAndFields {
|
||||
sorts: vec![DocValueAndOrder {
|
||||
value: Some(5),
|
||||
order: Order::Asc,
|
||||
@@ -817,7 +853,7 @@ mod tests {
|
||||
|
||||
let mut collector = collector_with_capacity(3);
|
||||
for doc in docs.clone() {
|
||||
collector.collect(doc.feature, doc.doc);
|
||||
collector.collect(doc.sort_key, doc.doc);
|
||||
}
|
||||
|
||||
let res = collector.into_final_result();
|
||||
@@ -827,15 +863,15 @@ mod tests {
|
||||
super::TopHitsMetricResult {
|
||||
hits: vec![
|
||||
super::TopHitsVecEntry {
|
||||
sort: vec![docs[0].feature.sorts[0].value],
|
||||
sort: vec![docs[0].sort_key.sorts[0].value],
|
||||
doc_value_fields: Default::default(),
|
||||
},
|
||||
super::TopHitsVecEntry {
|
||||
sort: vec![docs[1].feature.sorts[0].value],
|
||||
sort: vec![docs[1].sort_key.sorts[0].value],
|
||||
doc_value_fields: Default::default(),
|
||||
},
|
||||
super::TopHitsVecEntry {
|
||||
sort: vec![docs[2].feature.sorts[0].value],
|
||||
sort: vec![docs[2].sort_key.sorts[0].value],
|
||||
doc_value_fields: Default::default(),
|
||||
},
|
||||
]
|
||||
|
||||
@@ -17,14 +17,11 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()>;
|
||||
|
||||
#[inline]
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_data: &mut AggregationsSegmentCtx,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_data)
|
||||
}
|
||||
) -> crate::Result<()>;
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
|
||||
use crate::collector::{Collector, SegmentCollector};
|
||||
use crate::{DocAddress, DocId, Score, SegmentReader};
|
||||
|
||||
pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
|
||||
custom_scorer: TCustomScorer,
|
||||
collector: TopCollector<TScore>,
|
||||
}
|
||||
|
||||
impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
|
||||
where TScore: Clone + PartialOrd
|
||||
{
|
||||
pub(crate) fn new(
|
||||
custom_scorer: TCustomScorer,
|
||||
collector: TopCollector<TScore>,
|
||||
) -> CustomScoreTopCollector<TCustomScorer, TScore> {
|
||||
CustomScoreTopCollector {
|
||||
custom_scorer,
|
||||
collector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A custom segment scorer makes it possible to define any kind of score
|
||||
/// for a given document belonging to a specific segment.
|
||||
///
|
||||
/// It is the segment local version of the [`CustomScorer`].
|
||||
pub trait CustomSegmentScorer<TScore>: 'static {
|
||||
/// Computes the score of a specific `doc`.
|
||||
fn score(&mut self, doc: DocId) -> TScore;
|
||||
}
|
||||
|
||||
/// `CustomScorer` makes it possible to define any kind of score.
|
||||
///
|
||||
/// The `CustomerScorer` itself does not make much of the computation itself.
|
||||
/// Instead, it helps constructing `Self::Child` instances that will compute
|
||||
/// the score at a segment scale.
|
||||
pub trait CustomScorer<TScore>: Sync {
|
||||
/// Type of the associated [`CustomSegmentScorer`].
|
||||
type Child: CustomSegmentScorer<TScore>;
|
||||
/// Builds a child scorer for a specific segment. The child scorer is associated with
|
||||
/// a specific segment.
|
||||
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>;
|
||||
}
|
||||
|
||||
impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
|
||||
where
|
||||
TCustomScorer: CustomScorer<TScore> + Send + Sync,
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync,
|
||||
{
|
||||
type Fruit = Vec<(TScore, DocAddress)>;
|
||||
|
||||
type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
segment_reader: &SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
|
||||
let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
|
||||
Ok(CustomScoreTopSegmentCollector {
|
||||
segment_collector,
|
||||
segment_scorer,
|
||||
})
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> {
|
||||
self.collector.merge_fruits(segment_fruits)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CustomScoreTopSegmentCollector<T, TScore>
|
||||
where
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
|
||||
T: CustomSegmentScorer<TScore>,
|
||||
{
|
||||
segment_collector: TopSegmentCollector<TScore>,
|
||||
segment_scorer: T,
|
||||
}
|
||||
|
||||
impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
|
||||
where
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync,
|
||||
T: 'static + CustomSegmentScorer<TScore>,
|
||||
{
|
||||
type Fruit = Vec<(TScore, DocAddress)>;
|
||||
|
||||
fn collect(&mut self, doc: DocId, _score: Score) {
|
||||
let score = self.segment_scorer.score(doc);
|
||||
self.segment_collector.collect(doc, score);
|
||||
}
|
||||
|
||||
fn harvest(self) -> Vec<(TScore, DocAddress)> {
|
||||
self.segment_collector.harvest()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, TScore, T> CustomScorer<TScore> for F
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
|
||||
T: CustomSegmentScorer<TScore>,
|
||||
{
|
||||
type Child = T;
|
||||
|
||||
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
|
||||
Ok((self)(segment_reader))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, TScore> CustomSegmentScorer<TScore> for F
|
||||
where F: 'static + FnMut(DocId) -> TScore
|
||||
{
|
||||
fn score(&mut self, doc: DocId) -> TScore {
|
||||
(self)(doc)
|
||||
}
|
||||
}
|
||||
@@ -821,7 +821,6 @@ mod tests {
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
use test::Bencher;
|
||||
|
||||
@@ -12,6 +12,7 @@ use std::marker::PhantomData;
|
||||
use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};
|
||||
|
||||
use crate::collector::{Collector, SegmentCollector};
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Score, SegmentReader};
|
||||
|
||||
/// The `FilterCollector` filters docs using a fast field value and a predicate.
|
||||
@@ -49,13 +50,13 @@ use crate::{DocId, Score, SegmentReader};
|
||||
///
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary")?;
|
||||
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2));
|
||||
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2).order_by_score());
|
||||
/// let top_docs = searcher.search(&query, &no_filter_collector)?;
|
||||
///
|
||||
/// assert_eq!(top_docs.len(), 1);
|
||||
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
|
||||
///
|
||||
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2));
|
||||
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2).order_by_score());
|
||||
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
|
||||
///
|
||||
/// assert_eq!(filtered_top_docs.len(), 0);
|
||||
@@ -104,6 +105,11 @@ where
|
||||
|
||||
type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.collector.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -120,6 +126,7 @@ where
|
||||
segment_collector,
|
||||
predicate: self.predicate.clone(),
|
||||
t_predicate_value: PhantomData,
|
||||
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -140,6 +147,7 @@ pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue
|
||||
segment_collector: TSegmentCollector,
|
||||
predicate: TPredicate,
|
||||
t_predicate_value: PhantomData<TPredicateValue>,
|
||||
filtered_docs: Vec<DocId>,
|
||||
}
|
||||
|
||||
impl<TSegmentCollector, TPredicate, TPredicateValue>
|
||||
@@ -176,6 +184,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.filtered_docs.clear();
|
||||
for &doc in docs {
|
||||
// TODO: `accept_document` could be further optimized to do batch lookups of column
|
||||
// values for single-valued columns.
|
||||
if self.accept_document(doc) {
|
||||
self.filtered_docs.push(doc);
|
||||
}
|
||||
}
|
||||
if !self.filtered_docs.is_empty() {
|
||||
self.segment_collector.collect_block(&self.filtered_docs);
|
||||
}
|
||||
}
|
||||
|
||||
fn harvest(self) -> TSegmentCollector::Fruit {
|
||||
self.segment_collector.harvest()
|
||||
}
|
||||
@@ -218,7 +240,7 @@ where
|
||||
///
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary")?;
|
||||
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
|
||||
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2).order_by_score());
|
||||
/// let top_docs = searcher.search(&query, &filter_collector)?;
|
||||
///
|
||||
/// assert_eq!(top_docs.len(), 1);
|
||||
@@ -258,6 +280,10 @@ where
|
||||
|
||||
type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.collector.check_schema(schema)
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -274,6 +300,7 @@ where
|
||||
segment_collector,
|
||||
predicate: self.predicate.clone(),
|
||||
buffer: Vec::new(),
|
||||
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -296,6 +323,7 @@ where TPredicate: 'static
|
||||
segment_collector: TSegmentCollector,
|
||||
predicate: TPredicate,
|
||||
buffer: Vec<u8>,
|
||||
filtered_docs: Vec<DocId>,
|
||||
}
|
||||
|
||||
impl<TSegmentCollector, TPredicate> BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
|
||||
@@ -334,6 +362,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.filtered_docs.clear();
|
||||
for &doc in docs {
|
||||
// TODO: `accept_document` could be further optimized to do batch lookups of column
|
||||
// values for single-valued columns.
|
||||
if self.accept_document(doc) {
|
||||
self.filtered_docs.push(doc);
|
||||
}
|
||||
}
|
||||
if !self.filtered_docs.is_empty() {
|
||||
self.segment_collector.collect_block(&self.filtered_docs);
|
||||
}
|
||||
}
|
||||
|
||||
fn harvest(self) -> TSegmentCollector::Fruit {
|
||||
self.segment_collector.harvest()
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@
|
||||
//! # let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
//! # let query = query_parser.parse_query("diary")?;
|
||||
//! let (doc_count, top_docs): (usize, Vec<(Score, DocAddress)>) =
|
||||
//! searcher.search(&query, &(Count, TopDocs::with_limit(2)))?;
|
||||
//! searcher.search(&query, &(Count, TopDocs::with_limit(2).order_by_score()))?;
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
@@ -83,28 +83,28 @@
|
||||
|
||||
use downcast_rs::impl_downcast;
|
||||
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
|
||||
|
||||
mod count_collector;
|
||||
pub use self::count_collector::Count;
|
||||
|
||||
/// Sort keys
|
||||
pub mod sort_key;
|
||||
|
||||
mod histogram_collector;
|
||||
pub use histogram_collector::HistogramCollector;
|
||||
|
||||
mod multi_collector;
|
||||
pub use columnar::ComparableDoc;
|
||||
|
||||
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
|
||||
|
||||
mod top_collector;
|
||||
|
||||
mod top_score_collector;
|
||||
pub use self::top_collector::ComparableDoc;
|
||||
pub use self::top_score_collector::{TopDocs, TopNComputer};
|
||||
|
||||
mod custom_score_top_collector;
|
||||
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
|
||||
|
||||
mod tweak_score_top_collector;
|
||||
pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
|
||||
mod sort_key_top_collector;
|
||||
pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
mod facet_collector;
|
||||
pub use self::facet_collector::{FacetCollector, FacetCounts};
|
||||
use crate::query::Weight;
|
||||
@@ -145,6 +145,11 @@ pub trait Collector: Sync + Send {
|
||||
/// Type of the `SegmentCollector` associated with this collector.
|
||||
type Child: SegmentCollector;
|
||||
|
||||
/// Returns an error if the schema is not compatible with the collector.
|
||||
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// `set_segment` is called before beginning to enumerate
|
||||
/// on this segment.
|
||||
fn for_segment(
|
||||
@@ -170,41 +175,50 @@ pub trait Collector: Sync + Send {
|
||||
segment_ord: u32,
|
||||
reader: &SegmentReader,
|
||||
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
|
||||
let with_scoring = self.requires_scoring();
|
||||
let mut segment_collector = self.for_segment(segment_ord, reader)?;
|
||||
|
||||
match (reader.alive_bitset(), self.requires_scoring()) {
|
||||
(Some(alive_bitset), true) => {
|
||||
weight.for_each(reader, &mut |doc, score| {
|
||||
if alive_bitset.is_alive(doc) {
|
||||
segment_collector.collect(doc, score);
|
||||
}
|
||||
})?;
|
||||
}
|
||||
(Some(alive_bitset), false) => {
|
||||
weight.for_each_no_score(reader, &mut |docs| {
|
||||
for doc in docs.iter().cloned() {
|
||||
if alive_bitset.is_alive(doc) {
|
||||
segment_collector.collect(doc, 0.0);
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
(None, true) => {
|
||||
weight.for_each(reader, &mut |doc, score| {
|
||||
segment_collector.collect(doc, score);
|
||||
})?;
|
||||
}
|
||||
(None, false) => {
|
||||
weight.for_each_no_score(reader, &mut |docs| {
|
||||
segment_collector.collect_block(docs);
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
default_collect_segment_impl(&mut segment_collector, weight, reader, with_scoring)?;
|
||||
Ok(segment_collector.harvest())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
|
||||
segment_collector: &mut TSegmentCollector,
|
||||
weight: &dyn Weight,
|
||||
reader: &SegmentReader,
|
||||
with_scoring: bool,
|
||||
) -> crate::Result<()> {
|
||||
match (reader.alive_bitset(), with_scoring) {
|
||||
(Some(alive_bitset), true) => {
|
||||
weight.for_each(reader, &mut |doc, score| {
|
||||
if alive_bitset.is_alive(doc) {
|
||||
segment_collector.collect(doc, score);
|
||||
}
|
||||
})?;
|
||||
}
|
||||
(Some(alive_bitset), false) => {
|
||||
weight.for_each_no_score(reader, &mut |docs| {
|
||||
for doc in docs.iter().cloned() {
|
||||
if alive_bitset.is_alive(doc) {
|
||||
segment_collector.collect(doc, 0.0);
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
(None, true) => {
|
||||
weight.for_each(reader, &mut |doc, score| {
|
||||
segment_collector.collect(doc, score);
|
||||
})?;
|
||||
}
|
||||
(None, false) => {
|
||||
weight.for_each_no_score(reader, &mut |docs| {
|
||||
segment_collector.collect_block(docs);
|
||||
})?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCollector> {
|
||||
type Fruit = Option<TSegmentCollector::Fruit>;
|
||||
|
||||
@@ -214,6 +228,12 @@ impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCo
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
if let Some(segment_collector) = self {
|
||||
segment_collector.collect_block(docs);
|
||||
}
|
||||
}
|
||||
|
||||
fn harvest(self) -> Self::Fruit {
|
||||
self.map(|segment_collector| segment_collector.harvest())
|
||||
}
|
||||
@@ -224,6 +244,13 @@ impl<TCollector: Collector> Collector for Option<TCollector> {
|
||||
|
||||
type Child = Option<<TCollector as Collector>::Child>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
if let Some(underlying_collector) = self {
|
||||
underlying_collector.check_schema(schema)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: SegmentOrdinal,
|
||||
@@ -299,6 +326,12 @@ where
|
||||
type Fruit = (Left::Fruit, Right::Fruit);
|
||||
type Child = (Left::Child, Right::Child);
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -342,6 +375,11 @@ where
|
||||
self.1.collect(doc, score);
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.0.collect_block(docs);
|
||||
self.1.collect_block(docs);
|
||||
}
|
||||
|
||||
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
|
||||
(self.0.harvest(), self.1.harvest())
|
||||
}
|
||||
@@ -358,6 +396,13 @@ where
|
||||
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
|
||||
type Child = (One::Child, Two::Child, Three::Child);
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
self.2.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -407,6 +452,12 @@ where
|
||||
self.2.collect(doc, score);
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.0.collect_block(docs);
|
||||
self.1.collect_block(docs);
|
||||
self.2.collect_block(docs);
|
||||
}
|
||||
|
||||
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
|
||||
(self.0.harvest(), self.1.harvest(), self.2.harvest())
|
||||
}
|
||||
@@ -424,6 +475,14 @@ where
|
||||
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
|
||||
type Child = (One::Child, Two::Child, Three::Child, Four::Child);
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)?;
|
||||
self.1.check_schema(schema)?;
|
||||
self.2.check_schema(schema)?;
|
||||
self.3.check_schema(schema)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -482,6 +541,13 @@ where
|
||||
self.3.collect(doc, score);
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.0.collect_block(docs);
|
||||
self.1.collect_block(docs);
|
||||
self.2.collect_block(docs);
|
||||
self.3.collect_block(docs);
|
||||
}
|
||||
|
||||
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
|
||||
(
|
||||
self.0.harvest(),
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::ops::Deref;
|
||||
|
||||
use super::{Collector, SegmentCollector};
|
||||
use crate::collector::Fruit;
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocId, Score, SegmentOrdinal, SegmentReader, TantivyError};
|
||||
|
||||
/// MultiFruit keeps Fruits from every nested Collector
|
||||
@@ -16,6 +17,10 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
|
||||
type Fruit = Box<dyn Fruit>;
|
||||
type Child = Box<dyn BoxableSegmentCollector>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
@@ -147,7 +152,7 @@ impl<TFruit: Fruit> FruitHandle<TFruit> {
|
||||
/// let searcher = reader.searcher();
|
||||
///
|
||||
/// let mut collectors = MultiCollector::new();
|
||||
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2));
|
||||
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
|
||||
/// let count_handle = collectors.add_collector(Count);
|
||||
/// let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
/// let query = query_parser.parse_query("diary").unwrap();
|
||||
@@ -194,6 +199,13 @@ impl Collector for MultiCollector<'_> {
|
||||
type Fruit = MultiFruit;
|
||||
type Child = MultiCollectorChild;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
for collector in &self.collector_wrappers {
|
||||
collector.check_schema(schema)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: SegmentOrdinal,
|
||||
@@ -250,6 +262,12 @@ impl SegmentCollector for MultiCollectorChild {
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
for child in &mut self.children {
|
||||
child.collect_block(docs);
|
||||
}
|
||||
}
|
||||
|
||||
fn harvest(self) -> MultiFruit {
|
||||
MultiFruit {
|
||||
sub_fruits: self
|
||||
@@ -263,7 +281,6 @@ impl SegmentCollector for MultiCollectorChild {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::collector::{Count, TopDocs};
|
||||
use crate::query::TermQuery;
|
||||
@@ -293,7 +310,7 @@ mod tests {
|
||||
let query = TermQuery::new(term, IndexRecordOption::Basic);
|
||||
|
||||
let mut collectors = MultiCollector::new();
|
||||
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2));
|
||||
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
|
||||
let count_handler = collectors.add_collector(Count);
|
||||
let mut multifruits = searcher.search(&query, &collectors).unwrap();
|
||||
|
||||
|
||||
679
src/collector/sort_key/mod.rs
Normal file
679
src/collector/sort_key/mod.rs
Normal file
@@ -0,0 +1,679 @@
|
||||
mod order;
|
||||
mod sort_by_erased_type;
|
||||
mod sort_by_score;
|
||||
mod sort_by_static_fast_value;
|
||||
mod sort_by_string;
|
||||
mod sort_key_computer;
|
||||
|
||||
pub use order::*;
|
||||
pub use sort_by_erased_type::SortByErasedType;
|
||||
pub use sort_by_score::SortBySimilarityScore;
|
||||
pub use sort_by_static_fast_value::SortByStaticFastValue;
|
||||
pub use sort_by_string::SortByString;
|
||||
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::collector::sort_key::{
|
||||
Comparator, NaturalComparator, ReverseComparator, SortByErasedType, SortBySimilarityScore,
|
||||
SortByStaticFastValue, SortByString,
|
||||
};
|
||||
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
|
||||
use crate::indexer::NoMergePolicy;
|
||||
use crate::query::{AllQuery, QueryParser};
|
||||
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
|
||||
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
|
||||
|
||||
fn make_index() -> crate::Result<Index> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id = schema_builder.add_u64_field("id", FAST);
|
||||
let city = schema_builder.add_text_field("city", TEXT | FAST);
|
||||
let catchphrase = schema_builder.add_text_field("catchphrase", TEXT);
|
||||
let altitude = schema_builder.add_f64_field("altitude", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
fn create_segment(index: &Index, docs: Vec<impl Document>) -> crate::Result<()> {
|
||||
let mut index_writer = index.writer_for_tests()?;
|
||||
index_writer.set_merge_policy(Box::new(NoMergePolicy));
|
||||
for doc in docs {
|
||||
index_writer.add_document(doc)?;
|
||||
}
|
||||
index_writer.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
create_segment(
|
||||
&index,
|
||||
vec![
|
||||
doc!(
|
||||
id => 0_u64,
|
||||
city => "austin",
|
||||
catchphrase => "Hills, Barbeque, Glow",
|
||||
altitude => 149.0,
|
||||
),
|
||||
doc!(
|
||||
id => 1_u64,
|
||||
city => "greenville",
|
||||
catchphrase => "Grow, Glow, Glow",
|
||||
altitude => 27.0,
|
||||
),
|
||||
],
|
||||
)?;
|
||||
create_segment(
|
||||
&index,
|
||||
vec![doc!(
|
||||
id => 2_u64,
|
||||
city => "tokyo",
|
||||
catchphrase => "Glow, Glow, Glow",
|
||||
altitude => 40.0,
|
||||
)],
|
||||
)?;
|
||||
create_segment(
|
||||
&index,
|
||||
vec![doc!(
|
||||
id => 3_u64,
|
||||
catchphrase => "No, No, No",
|
||||
altitude => 0.0,
|
||||
)],
|
||||
)?;
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
// NOTE: You cannot determine the SegmentIds that will be generated for Segments
|
||||
// ahead of time, so DocAddresses must be mapped back to a unique id for each Searcher.
|
||||
fn id_mapping(searcher: &Searcher) -> HashMap<DocAddress, u64> {
|
||||
searcher
|
||||
.search(&AllQuery, &DocSetCollector)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|doc_address| {
|
||||
let column = searcher.segment_readers()[doc_address.segment_ord as usize]
|
||||
.fast_fields()
|
||||
.u64("id")
|
||||
.unwrap();
|
||||
(doc_address, column.first(doc_address.doc_id).unwrap())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_string() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
#[track_caller]
|
||||
fn assert_query(
|
||||
index: &Index,
|
||||
order: Order,
|
||||
doc_range: Range<usize>,
|
||||
expected: Vec<(Option<String>, u64)>,
|
||||
) -> crate::Result<()> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
// Try as primitive.
|
||||
let top_collector = TopDocs::for_doc_range(doc_range)
|
||||
.order_by((SortByString::for_field("city"), order));
|
||||
let actual = searcher
|
||||
.search(&AllQuery, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(sort_key_opt, doc)| (sort_key_opt, ids[&doc]))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(actual, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
0..4,
|
||||
vec![
|
||||
(Some("austin".to_owned()), 0),
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("tokyo".to_owned()), 2),
|
||||
(None, 3),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
0..3,
|
||||
vec![
|
||||
(Some("austin".to_owned()), 0),
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("tokyo".to_owned()), 2),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
0..2,
|
||||
vec![
|
||||
(Some("austin".to_owned()), 0),
|
||||
(Some("greenville".to_owned()), 1),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
0..1,
|
||||
vec![(Some("austin".to_string()), 0)],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
1..3,
|
||||
vec![
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("tokyo".to_owned()), 2),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Desc,
|
||||
0..4,
|
||||
vec![
|
||||
(Some("tokyo".to_owned()), 2),
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("austin".to_owned()), 0),
|
||||
(None, 3),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Desc,
|
||||
1..3,
|
||||
vec![
|
||||
(Some("greenville".to_owned()), 1),
|
||||
(Some("austin".to_owned()), 0),
|
||||
],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Desc,
|
||||
0..1,
|
||||
vec![(Some("tokyo".to_owned()), 2)],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_f64() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
fn assert_query(
|
||||
index: &Index,
|
||||
order: Order,
|
||||
expected: Vec<(Option<f64>, u64)>,
|
||||
) -> crate::Result<()> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
// Try as primitive.
|
||||
let top_collector = TopDocs::with_limit(3)
|
||||
.order_by((SortByStaticFastValue::<f64>::for_field("altitude"), order));
|
||||
let actual = searcher
|
||||
.search(&AllQuery, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(altitude_opt, doc)| (altitude_opt, ids[&doc]))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(actual, expected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
vec![(Some(0.0), 3), (Some(27.0), 1), (Some(40.0), 2)],
|
||||
)?;
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Desc,
|
||||
vec![(Some(149.0), 0), (Some(40.0), 2), (Some(27.0), 1)],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_score() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
fn query(index: &Index, order: Order) -> crate::Result<Vec<(Score, u64)>> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
let top_collector = TopDocs::with_limit(4).order_by((SortBySimilarityScore, order));
|
||||
let field = index.schema().get_field("catchphrase").unwrap();
|
||||
let query_parser = QueryParser::for_index(index, vec![field]);
|
||||
let text_query = query_parser.parse_query("glow")?;
|
||||
|
||||
Ok(searcher
|
||||
.search(&text_query, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(score, doc)| (score, ids[&doc]))
|
||||
.collect())
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Desc)?,
|
||||
&[(0.5604893, 2), (0.4904281, 1), (0.35667497, 0),]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc)?,
|
||||
&[(0.35667497, 0), (0.4904281, 1), (0.5604893, 2),]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_score_then_string() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
type SortKey = (Score, Option<String>);
|
||||
|
||||
fn query(
|
||||
index: &Index,
|
||||
score_order: Order,
|
||||
city_order: Order,
|
||||
) -> crate::Result<Vec<(SortKey, u64)>> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
let top_collector = TopDocs::with_limit(4).order_by((
|
||||
(SortBySimilarityScore, score_order),
|
||||
(SortByString::for_field("city"), city_order),
|
||||
));
|
||||
let results: Vec<((Score, Option<String>), DocAddress)> =
|
||||
searcher.search(&AllQuery, &top_collector)?;
|
||||
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc, Order::Asc)?,
|
||||
&[
|
||||
((1.0, Some("austin".to_owned())), 0),
|
||||
((1.0, Some("greenville".to_owned())), 1),
|
||||
((1.0, Some("tokyo".to_owned())), 2),
|
||||
((1.0, None), 3),
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc, Order::Desc)?,
|
||||
&[
|
||||
((1.0, Some("tokyo".to_owned())), 2),
|
||||
((1.0, Some("greenville".to_owned())), 1),
|
||||
((1.0, Some("austin".to_owned())), 0),
|
||||
((1.0, None), 3),
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_score_then_owned_value() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
type SortKey = (Score, OwnedValue);
|
||||
|
||||
fn query(
|
||||
index: &Index,
|
||||
score_order: Order,
|
||||
city_order: Order,
|
||||
) -> crate::Result<Vec<(SortKey, u64)>> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>((
|
||||
(SortBySimilarityScore, score_order),
|
||||
(SortByErasedType::for_field("city"), city_order),
|
||||
));
|
||||
let results: Vec<((Score, OwnedValue), DocAddress)> =
|
||||
searcher.search(&AllQuery, &top_collector)?;
|
||||
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc, Order::Asc)?,
|
||||
&[
|
||||
((1.0, OwnedValue::Str("austin".to_owned())), 0),
|
||||
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
|
||||
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
|
||||
((1.0, OwnedValue::Null), 3),
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
&query(&index, Order::Asc, Order::Desc)?,
|
||||
&[
|
||||
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
|
||||
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
|
||||
((1.0, OwnedValue::Str("austin".to_owned())), 0),
|
||||
((1.0, OwnedValue::Null), 3),
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_by_compound_fast_fields() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
|
||||
type CompoundSortKey = (Option<String>, Option<f64>);
|
||||
|
||||
fn assert_query(
|
||||
index: &Index,
|
||||
city_order: Order,
|
||||
altitude_order: Order,
|
||||
expected: Vec<(CompoundSortKey, u64)>,
|
||||
) -> crate::Result<()> {
|
||||
let searcher = index.reader()?.searcher();
|
||||
let ids = id_mapping(&searcher);
|
||||
|
||||
let top_collector = TopDocs::with_limit(4).order_by((
|
||||
(SortByString::for_field("city"), city_order),
|
||||
(
|
||||
SortByStaticFastValue::<f64>::for_field("altitude"),
|
||||
altitude_order,
|
||||
),
|
||||
));
|
||||
let actual = searcher
|
||||
.search(&AllQuery, &top_collector)?
|
||||
.into_iter()
|
||||
.map(|(key, doc)| (key, ids[&doc]))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(actual, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
assert_query(
|
||||
&index,
|
||||
Order::Asc,
|
||||
Order::Desc,
|
||||
vec![
|
||||
((Some("austin".to_owned()), Some(149.0)), 0),
|
||||
((Some("greenville".to_owned()), Some(27.0)), 1),
|
||||
((Some("tokyo".to_owned()), Some(40.0)), 2),
|
||||
((None, Some(0.0)), 3),
|
||||
],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_order_by_string_prop(
|
||||
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
limit in 1..64_usize,
|
||||
offset in 0..64_usize,
|
||||
segments_terms in
|
||||
proptest::collection::vec(
|
||||
proptest::collection::vec(0..32_u8, 1..32_usize),
|
||||
0..8_usize,
|
||||
)
|
||||
) {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let city = schema_builder.add_text_field("city", TEXT | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests()?;
|
||||
|
||||
// A Vec<Vec<u8>>, where the outer Vec represents segments, and the inner Vec
|
||||
// represents terms.
|
||||
for segment_terms in segments_terms.into_iter() {
|
||||
for term in segment_terms.into_iter() {
|
||||
let term = format!("{term:0>3}");
|
||||
index_writer.add_document(doc!(
|
||||
city => term,
|
||||
))?;
|
||||
}
|
||||
index_writer.commit()?;
|
||||
}
|
||||
|
||||
let searcher = index.reader()?.searcher();
|
||||
let top_n_results = searcher.search(&AllQuery, &TopDocs::with_limit(limit)
|
||||
.and_offset(offset)
|
||||
.order_by_string_fast_field("city", order))?;
|
||||
let all_results = searcher.search(&AllQuery, &DocSetCollector)?.into_iter().map(|doc_address| {
|
||||
// Get the term for this address.
|
||||
let column = searcher.segment_readers()[doc_address.segment_ord as usize].fast_fields().str("city").unwrap().unwrap();
|
||||
let value = column.term_ords(doc_address.doc_id).next().map(|term_ord| {
|
||||
let mut city = Vec::new();
|
||||
column.dictionary().ord_to_term(term_ord, &mut city).unwrap();
|
||||
String::try_from(city).unwrap()
|
||||
});
|
||||
(value, doc_address)
|
||||
});
|
||||
|
||||
// Using the TopDocs collector should always be equivalent to sorting, skipping the
|
||||
// offset, and then taking the limit.
|
||||
let sorted_docs: Vec<_> = {
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
|
||||
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
|
||||
if order.is_desc() {
|
||||
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
|
||||
} else {
|
||||
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
|
||||
}
|
||||
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
|
||||
};
|
||||
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
|
||||
prop_assert_eq!(
|
||||
expected_docs,
|
||||
top_n_results
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_order_by_compound_prop(
|
||||
city_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
altitude_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
limit in 1..20_usize,
|
||||
offset in 0..20_usize,
|
||||
segments_data in proptest::collection::vec(
|
||||
proptest::collection::vec(
|
||||
(proptest::option::of("[a-c]"), proptest::option::of(0..50u64)),
|
||||
1..10_usize // segment size
|
||||
),
|
||||
1..4_usize // num segments
|
||||
)
|
||||
) {
|
||||
use crate::collector::sort_key::ComparatorEnum;
|
||||
use crate::TantivyDocument;
|
||||
|
||||
let mut schema_builder = Schema::builder();
|
||||
let city = schema_builder.add_text_field("city", TEXT | FAST);
|
||||
let altitude = schema_builder.add_u64_field("altitude", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
for segment_data in segments_data.into_iter() {
|
||||
for (city_val, altitude_val) in segment_data.into_iter() {
|
||||
let mut doc = TantivyDocument::default();
|
||||
if let Some(c) = city_val {
|
||||
doc.add_text(city, c);
|
||||
}
|
||||
if let Some(a) = altitude_val {
|
||||
doc.add_u64(altitude, a);
|
||||
}
|
||||
index_writer.add_document(doc).unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
|
||||
let top_collector = TopDocs::with_limit(limit)
|
||||
.and_offset(offset)
|
||||
.order_by((
|
||||
(SortByString::for_field("city"), city_order),
|
||||
(
|
||||
SortByStaticFastValue::<u64>::for_field("altitude"),
|
||||
altitude_order,
|
||||
),
|
||||
));
|
||||
|
||||
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
|
||||
let actual_doc_ids: Vec<DocAddress> =
|
||||
actual_results.into_iter().map(|(_, doc)| doc).collect();
|
||||
|
||||
// Verification logic
|
||||
let all_docs_collector = DocSetCollector;
|
||||
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
|
||||
|
||||
let docs_with_keys: Vec<((Option<String>, Option<u64>), DocAddress)> = all_docs
|
||||
.into_iter()
|
||||
.map(|doc_addr| {
|
||||
let reader = searcher.segment_reader(doc_addr.segment_ord);
|
||||
|
||||
let city_val = if let Some(col) = reader.fast_fields().str("city").unwrap() {
|
||||
let ord = col.ords().first(doc_addr.doc_id);
|
||||
if let Some(ord) = ord {
|
||||
let mut out = Vec::new();
|
||||
col.dictionary().ord_to_term(ord, &mut out).unwrap();
|
||||
String::from_utf8(out).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let alt_val = if let Some((col, _)) = reader.fast_fields().u64_lenient("altitude").unwrap() {
|
||||
col.first(doc_addr.doc_id)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
((city_val, alt_val), doc_addr)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let city_comparator = ComparatorEnum::from(city_order);
|
||||
let alt_comparator = ComparatorEnum::from(altitude_order);
|
||||
let comparator = (city_comparator, alt_comparator);
|
||||
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
|
||||
.into_iter()
|
||||
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
|
||||
.collect();
|
||||
|
||||
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
|
||||
|
||||
let expected_results = comparable_docs
|
||||
.into_iter()
|
||||
.skip(offset)
|
||||
.take(limit)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_doc_ids: Vec<DocAddress> =
|
||||
expected_results.into_iter().map(|cd| cd.doc).collect();
|
||||
|
||||
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_order_by_u64_prop(
|
||||
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
|
||||
limit in 1..20_usize,
|
||||
offset in 0..20_usize,
|
||||
segments_data in proptest::collection::vec(
|
||||
proptest::collection::vec(
|
||||
proptest::option::of(0..100u64),
|
||||
1..1000_usize // segment size
|
||||
),
|
||||
1..4_usize // num segments
|
||||
)
|
||||
) {
|
||||
use crate::collector::sort_key::ComparatorEnum;
|
||||
use crate::TantivyDocument;
|
||||
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field = schema_builder.add_u64_field("field", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
for segment_data in segments_data.into_iter() {
|
||||
for val in segment_data.into_iter() {
|
||||
let mut doc = TantivyDocument::default();
|
||||
if let Some(v) = val {
|
||||
doc.add_u64(field, v);
|
||||
}
|
||||
index_writer.add_document(doc).unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
|
||||
let top_collector = TopDocs::with_limit(limit)
|
||||
.and_offset(offset)
|
||||
.order_by((SortByStaticFastValue::<u64>::for_field("field"), order));
|
||||
|
||||
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
|
||||
let actual_doc_ids: Vec<DocAddress> =
|
||||
actual_results.into_iter().map(|(_, doc)| doc).collect();
|
||||
|
||||
// Verification logic
|
||||
let all_docs_collector = DocSetCollector;
|
||||
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
|
||||
|
||||
let docs_with_keys: Vec<(Option<u64>, DocAddress)> = all_docs
|
||||
.into_iter()
|
||||
.map(|doc_addr| {
|
||||
let reader = searcher.segment_reader(doc_addr.segment_ord);
|
||||
let val = if let Some((col, _)) = reader.fast_fields().u64_lenient("field").unwrap() {
|
||||
col.first(doc_addr.doc_id)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(val, doc_addr)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let comparator = ComparatorEnum::from(order);
|
||||
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
|
||||
.into_iter()
|
||||
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
|
||||
.collect();
|
||||
|
||||
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
|
||||
|
||||
let expected_results = comparable_docs
|
||||
.into_iter()
|
||||
.skip(offset)
|
||||
.take(limit)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_doc_ids: Vec<DocAddress> =
|
||||
expected_results.into_iter().map(|cd| cd.doc).collect();
|
||||
|
||||
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
785
src/collector/sort_key/order.rs
Normal file
785
src/collector/sort_key/order.rs
Normal file
@@ -0,0 +1,785 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use columnar::{MonotonicallyMappableToU64, ValueRange};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::schema::{OwnedValue, Schema};
|
||||
use crate::{DocId, Order, Score};
|
||||
|
||||
fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
match (lhs, rhs) {
|
||||
(OwnedValue::Null, OwnedValue::Null) => Ordering::Equal,
|
||||
(OwnedValue::Null, _) => {
|
||||
if NULLS_FIRST {
|
||||
Ordering::Less
|
||||
} else {
|
||||
Ordering::Greater
|
||||
}
|
||||
}
|
||||
(_, OwnedValue::Null) => {
|
||||
if NULLS_FIRST {
|
||||
Ordering::Greater
|
||||
} else {
|
||||
Ordering::Less
|
||||
}
|
||||
}
|
||||
(OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b),
|
||||
(OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b),
|
||||
(OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b),
|
||||
(OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b),
|
||||
(OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()),
|
||||
(OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b),
|
||||
(OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b),
|
||||
(OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b),
|
||||
(OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b),
|
||||
(OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b),
|
||||
(OwnedValue::U64(a), OwnedValue::I64(b)) => {
|
||||
if *b < 0 {
|
||||
Ordering::Greater
|
||||
} else {
|
||||
a.cmp(&(*b as u64))
|
||||
}
|
||||
}
|
||||
(OwnedValue::I64(a), OwnedValue::U64(b)) => {
|
||||
if *a < 0 {
|
||||
Ordering::Less
|
||||
} else {
|
||||
(*a as u64).cmp(b)
|
||||
}
|
||||
}
|
||||
(OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
|
||||
(OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
|
||||
(OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
|
||||
(OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
|
||||
(a, b) => {
|
||||
let ord = a.discriminant_value().cmp(&b.discriminant_value());
|
||||
// If the discriminant is equal, it's because a new type was added, but hasn't been
|
||||
// included in this `match` statement.
|
||||
assert!(
|
||||
ord != Ordering::Equal,
|
||||
"Unimplemented comparison for type of {a:?}, {b:?}"
|
||||
);
|
||||
ord
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Comparator trait defining the order in which documents should be ordered.
|
||||
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
|
||||
/// Return the order between two values.
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
|
||||
/// Return the order between two ComparableDoc values, using the semantics which are
|
||||
/// implemented by TopNComputer.
|
||||
#[inline(always)]
|
||||
fn compare_doc<D: Ord>(
|
||||
&self,
|
||||
lhs: &ComparableDoc<T, D>,
|
||||
rhs: &ComparableDoc<T, D>,
|
||||
) -> Ordering {
|
||||
// TopNComputer sorts in descending order of the SortKey by default: we apply that ordering
|
||||
// here to ease comparison in testing.
|
||||
self.compare(&rhs.sort_key, &lhs.sort_key).then_with(|| {
|
||||
// In case of a tie on the sort key, we always sort by ascending `DocAddress` in order
|
||||
// to ensure a stable sorting of the documents, regardless of the sort key's order.
|
||||
// See the TopNComputer docs for more information.
|
||||
lhs.doc.cmp(&rhs.doc)
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a `ValueRange` that matches all values that are greater than the provided threshold.
|
||||
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T>;
|
||||
}
|
||||
|
||||
/// Compare values naturally (e.g. 1 < 2).
|
||||
///
|
||||
/// When used with `TopDocs`, which reverses the order, this results in a
|
||||
/// "Descending" sort (Greatest values first).
|
||||
///
|
||||
/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value,
|
||||
/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`).
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct NaturalComparator;
|
||||
|
||||
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
lhs.partial_cmp(rhs).unwrap()
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
/// A (partial) implementation of comparison for OwnedValue.
|
||||
///
|
||||
/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with
|
||||
/// mismatched types. The one exception is Null, for which we do define all comparisons.
|
||||
impl Comparator<OwnedValue> for NaturalComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values in reverse (e.g. 2 < 1).
|
||||
///
|
||||
/// When used with `TopDocs`, which reverses the order, this results in an
|
||||
/// "Ascending" sort (Smallest values first).
|
||||
///
|
||||
/// `None` is considered smaller than `Some` in the underlying comparator, but because the
|
||||
/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting
|
||||
/// Ascending sort (e.g. `[None, Some(10), Some(20)]`).
|
||||
///
|
||||
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
|
||||
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
|
||||
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order.
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ReverseComparator;
|
||||
|
||||
macro_rules! impl_reverse_comparator_primitive {
|
||||
($($t:ty),*) => {
|
||||
$(
|
||||
impl Comparator<$t> for ReverseComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &$t, rhs: &$t) -> Ordering {
|
||||
NaturalComparator.compare(rhs, lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: $t) -> ValueRange<$t> {
|
||||
ValueRange::LessThan(threshold, true)
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
impl_reverse_comparator_primitive!(
|
||||
bool,
|
||||
u8,
|
||||
u16,
|
||||
u32,
|
||||
u64,
|
||||
u128,
|
||||
usize,
|
||||
i8,
|
||||
i16,
|
||||
i32,
|
||||
i64,
|
||||
i128,
|
||||
isize,
|
||||
f32,
|
||||
f64,
|
||||
String,
|
||||
crate::DateTime,
|
||||
Vec<u8>,
|
||||
crate::schema::Facet
|
||||
);
|
||||
|
||||
impl<T: PartialOrd + Send + Sync + std::fmt::Debug + Clone + 'static> Comparator<Option<T>>
|
||||
for ReverseComparator
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &Option<T>, rhs: &Option<T>) -> Ordering {
|
||||
NaturalComparator.compare(rhs, lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
|
||||
let is_some = threshold.is_some();
|
||||
ValueRange::LessThan(threshold, is_some)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<OwnedValue> for ReverseComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
NaturalComparator.compare(rhs, lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
|
||||
let is_not_null = !matches!(threshold, OwnedValue::Null);
|
||||
ValueRange::LessThan(threshold, is_not_null)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values in reverse, but treating `None` as lower than `Some`.
|
||||
///
|
||||
/// When used with `TopDocs`, which reverses the order, this results in an
|
||||
/// "Ascending" sort (Smallest values first), but with `None` values appearing last
|
||||
/// (e.g. `[Some(10), Some(20), None]`).
|
||||
///
|
||||
/// This is usually what is wanted when sorting by a field in an ascending order.
|
||||
/// For instance, in an e-commerce website, if sorting by price ascending,
|
||||
/// the cheapest items would appear first, and items without a price would appear last.
|
||||
#[derive(Debug, Copy, Clone, Default)]
|
||||
pub struct ReverseNoneIsLowerComparator;
|
||||
|
||||
impl<T> Comparator<Option<T>> for ReverseNoneIsLowerComparator
|
||||
where ReverseComparator: Comparator<T>
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
|
||||
match (lhs_opt, rhs_opt) {
|
||||
(None, None) => Ordering::Equal,
|
||||
(None, Some(_)) => Ordering::Less,
|
||||
(Some(_), None) => Ordering::Greater,
|
||||
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
|
||||
if threshold.is_some() {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
} else {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u32> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u64> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f64> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f32> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<i64> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<String> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
|
||||
ReverseComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare values naturally, but treating `None` as higher than `Some`.
|
||||
///
|
||||
/// When used with `TopDocs`, which reverses the order, this results in a
|
||||
/// "Descending" sort (Greatest values first), but with `None` values appearing first
|
||||
/// (e.g. `[None, Some(20), Some(10)]`).
|
||||
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct NaturalNoneIsHigherComparator;
|
||||
|
||||
impl<T> Comparator<Option<T>> for NaturalNoneIsHigherComparator
|
||||
where NaturalComparator: Comparator<T>
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
|
||||
match (lhs_opt, rhs_opt) {
|
||||
(None, None) => Ordering::Equal,
|
||||
(None, Some(_)) => Ordering::Greater,
|
||||
(Some(_), None) => Ordering::Less,
|
||||
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
|
||||
if threshold.is_some() {
|
||||
let is_some = threshold.is_some();
|
||||
ValueRange::GreaterThan(threshold, is_some)
|
||||
} else {
|
||||
ValueRange::LessThan(threshold, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u32> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<u64> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f64> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<f32> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<i64> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<String> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
|
||||
NaturalComparator.compare(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
|
||||
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
|
||||
ValueRange::GreaterThan(threshold, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum representing the different sort orders.
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
|
||||
pub enum ComparatorEnum {
|
||||
/// Natural order (See [NaturalComparator])
|
||||
#[default]
|
||||
Natural,
|
||||
/// Reverse order (See [ReverseComparator])
|
||||
Reverse,
|
||||
/// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator])
|
||||
ReverseNoneLower,
|
||||
/// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator])
|
||||
NaturalNoneHigher,
|
||||
}
|
||||
|
||||
impl From<Order> for ComparatorEnum {
|
||||
fn from(order: Order) -> Self {
|
||||
match order {
|
||||
Order::Asc => ComparatorEnum::ReverseNoneLower,
|
||||
Order::Desc => ComparatorEnum::Natural,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Comparator<T> for ComparatorEnum
|
||||
where
|
||||
ReverseNoneIsLowerComparator: Comparator<T>,
|
||||
NaturalComparator: Comparator<T>,
|
||||
ReverseComparator: Comparator<T>,
|
||||
NaturalNoneIsHigherComparator: Comparator<T>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
|
||||
match self {
|
||||
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
|
||||
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
|
||||
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
|
||||
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
|
||||
match self {
|
||||
ComparatorEnum::Natural => NaturalComparator.threshold_to_valuerange(threshold),
|
||||
ComparatorEnum::Reverse => ReverseComparator.threshold_to_valuerange(threshold),
|
||||
ComparatorEnum::ReverseNoneLower => {
|
||||
ReverseNoneIsLowerComparator.threshold_to_valuerange(threshold)
|
||||
}
|
||||
ComparatorEnum::NaturalNoneHigher => {
|
||||
NaturalNoneIsHigherComparator.threshold_to_valuerange(threshold)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
|
||||
for (LeftComparator, RightComparator)
|
||||
where
|
||||
LeftComparator: Comparator<Head>,
|
||||
RightComparator: Comparator<Tail>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &(Head, Tail), rhs: &(Head, Tail)) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(&self, threshold: (Head, Tail)) -> ValueRange<(Head, Tail)> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
|
||||
for (Comparator1, Comparator2, Comparator3)
|
||||
where
|
||||
Comparator1: Comparator<Type1>,
|
||||
Comparator2: Comparator<Type2>,
|
||||
Comparator3: Comparator<Type3>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &(Type1, (Type2, Type3)), rhs: &(Type1, (Type2, Type3))) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
|
||||
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: (Type1, (Type2, Type3)),
|
||||
) -> ValueRange<(Type1, (Type2, Type3))> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
|
||||
for (Comparator1, Comparator2, Comparator3)
|
||||
where
|
||||
Comparator1: Comparator<Type1>,
|
||||
Comparator2: Comparator<Type2>,
|
||||
Comparator3: Comparator<Type3>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(&self, lhs: &(Type1, Type2, Type3), rhs: &(Type1, Type2, Type3)) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
|
||||
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: (Type1, Type2, Type3),
|
||||
) -> ValueRange<(Type1, Type2, Type3)> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
|
||||
Comparator<(Type1, (Type2, (Type3, Type4)))>
|
||||
for (Comparator1, Comparator2, Comparator3, Comparator4)
|
||||
where
|
||||
Comparator1: Comparator<Type1>,
|
||||
Comparator2: Comparator<Type2>,
|
||||
Comparator3: Comparator<Type3>,
|
||||
Comparator4: Comparator<Type4>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(
|
||||
&self,
|
||||
lhs: &(Type1, (Type2, (Type3, Type4))),
|
||||
rhs: &(Type1, (Type2, (Type3, Type4))),
|
||||
) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
|
||||
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
|
||||
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: (Type1, (Type2, (Type3, Type4))),
|
||||
) -> ValueRange<(Type1, (Type2, (Type3, Type4)))> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
|
||||
Comparator<(Type1, Type2, Type3, Type4)>
|
||||
for (Comparator1, Comparator2, Comparator3, Comparator4)
|
||||
where
|
||||
Comparator1: Comparator<Type1>,
|
||||
Comparator2: Comparator<Type2>,
|
||||
Comparator3: Comparator<Type3>,
|
||||
Comparator4: Comparator<Type4>,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn compare(
|
||||
&self,
|
||||
lhs: &(Type1, Type2, Type3, Type4),
|
||||
rhs: &(Type1, Type2, Type3, Type4),
|
||||
) -> Ordering {
|
||||
self.0
|
||||
.compare(&lhs.0, &rhs.0)
|
||||
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
|
||||
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
|
||||
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
|
||||
}
|
||||
|
||||
fn threshold_to_valuerange(
|
||||
&self,
|
||||
threshold: (Type1, Type2, Type3, Type4),
|
||||
) -> ValueRange<(Type1, Type2, Type3, Type4)> {
|
||||
ValueRange::GreaterThan(threshold, false)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
|
||||
where
|
||||
TSortKeyComputer: SortKeyComputer,
|
||||
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
|
||||
ComparatorEnum: Comparator<
|
||||
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
|
||||
>,
|
||||
{
|
||||
type SortKey = TSortKeyComputer::SortKey;
|
||||
|
||||
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
|
||||
|
||||
type Comparator = ComparatorEnum;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.0.requires_scoring()
|
||||
}
|
||||
|
||||
fn comparator(&self) -> Self::Comparator {
|
||||
self.1
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let child = self.0.segment_sort_key_computer(segment_reader)?;
|
||||
Ok(SegmentSortKeyComputerWithComparator {
|
||||
segment_sort_key_computer: child,
|
||||
comparator: self.comparator(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, Order)
|
||||
where
|
||||
TSortKeyComputer: SortKeyComputer,
|
||||
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
|
||||
ComparatorEnum: Comparator<
|
||||
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
|
||||
>,
|
||||
{
|
||||
type SortKey = TSortKeyComputer::SortKey;
|
||||
|
||||
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
|
||||
|
||||
type Comparator = ComparatorEnum;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.0.check_schema(schema)
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.0.requires_scoring()
|
||||
}
|
||||
|
||||
fn comparator(&self) -> Self::Comparator {
|
||||
self.1.into()
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let child = self.0.segment_sort_key_computer(segment_reader)?;
|
||||
Ok(SegmentSortKeyComputerWithComparator {
|
||||
segment_sort_key_computer: child,
|
||||
comparator: self.comparator(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A segment sort key computer with a custom ordering.
|
||||
pub struct SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator> {
|
||||
segment_sort_key_computer: TSegmentSortKeyComputer,
|
||||
comparator: TComparator,
|
||||
}
|
||||
|
||||
impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComputer
|
||||
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
|
||||
where
|
||||
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
|
||||
TSegmentSortKey: Clone + 'static + Sync + Send,
|
||||
TComparator: Comparator<TSegmentSortKey> + Clone + 'static + Sync + Send,
|
||||
{
|
||||
type SortKey = TSegmentSortKeyComputer::SortKey;
|
||||
type SegmentSortKey = TSegmentSortKey;
|
||||
type SegmentComparator = TComparator;
|
||||
type Buffer = TSegmentSortKeyComputer::Buffer;
|
||||
|
||||
fn segment_comparator(&self) -> Self::SegmentComparator {
|
||||
self.comparator.clone()
|
||||
}
|
||||
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
|
||||
self.segment_sort_key_computer.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
self.segment_sort_key_computer
|
||||
.segment_sort_keys(input_docs, output, buffer, filter)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn compare_segment_sort_key(
|
||||
&self,
|
||||
left: &Self::SegmentSortKey,
|
||||
right: &Self::SegmentSortKey,
|
||||
) -> Ordering {
|
||||
self.comparator.compare(left, right)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
self.segment_sort_key_computer
|
||||
.convert_segment_sort_key(sort_key)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::schema::OwnedValue;
|
||||
|
||||
#[test]
|
||||
fn test_mixed_ownedvalue_compare() {
|
||||
let u = OwnedValue::U64(10);
|
||||
let i = OwnedValue::I64(10);
|
||||
let f = OwnedValue::F64(10.0);
|
||||
|
||||
let nc = NaturalComparator::default();
|
||||
assert_eq!(nc.compare(&u, &i), Ordering::Equal);
|
||||
assert_eq!(nc.compare(&u, &f), Ordering::Equal);
|
||||
assert_eq!(nc.compare(&i, &f), Ordering::Equal);
|
||||
|
||||
let u2 = OwnedValue::U64(11);
|
||||
assert_eq!(nc.compare(&u2, &f), Ordering::Greater);
|
||||
|
||||
let s = OwnedValue::Str("a".to_string());
|
||||
// Str < U64
|
||||
assert_eq!(nc.compare(&s, &u), Ordering::Less);
|
||||
// Str < I64
|
||||
assert_eq!(nc.compare(&s, &i), Ordering::Less);
|
||||
// Str < F64
|
||||
assert_eq!(nc.compare(&s, &f), Ordering::Less);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_natural_none_is_higher() {
|
||||
let comp = NaturalNoneIsHigherComparator;
|
||||
let null = OwnedValue::Null;
|
||||
let v1 = OwnedValue::U64(1);
|
||||
let v2 = OwnedValue::U64(2);
|
||||
|
||||
// NaturalNoneIsGreaterComparator logic:
|
||||
// 1. Delegates to NaturalComparator for non-nulls.
|
||||
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
|
||||
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
|
||||
|
||||
// 2. Treats None (Null) as Greater than any value.
|
||||
// compare(Null, 2) should be Greater.
|
||||
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
|
||||
|
||||
// compare(1, Null) should be Less.
|
||||
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
|
||||
|
||||
// compare(Null, Null) should be Equal.
|
||||
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
|
||||
}
|
||||
}
|
||||
410
src/collector/sort_key/sort_by_erased_type.rs
Normal file
410
src/collector/sort_key/sort_by_erased_type.rs
Normal file
@@ -0,0 +1,410 @@
|
||||
use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange};
|
||||
|
||||
use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer;
|
||||
use crate::collector::sort_key::{
|
||||
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
|
||||
};
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::fastfield::FastFieldNotAvailableError;
|
||||
use crate::schema::OwnedValue;
|
||||
use crate::{DateTime, DocId, Score};
|
||||
|
||||
/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
|
||||
///
|
||||
/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders
|
||||
/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to
|
||||
/// use a SortKeyComputer implementation with a known-type at compile time.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SortByErasedType {
|
||||
/// Sort by a fast field
|
||||
Field(String),
|
||||
/// Sort by score
|
||||
Score,
|
||||
}
|
||||
|
||||
impl SortByErasedType {
|
||||
/// Creates a new sort key computer which will sort by the given fast field column, with type
|
||||
/// erasure.
|
||||
pub fn for_field(column_name: impl ToString) -> Self {
|
||||
Self::Field(column_name.to_string())
|
||||
}
|
||||
|
||||
/// Creates a new sort key computer which will sort by score, with type erasure.
|
||||
pub fn for_score() -> Self {
|
||||
Self::Score
|
||||
}
|
||||
}
|
||||
|
||||
trait ErasedSegmentSortKeyComputer: Send + Sync {
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
|
||||
filter: ValueRange<Option<u64>>,
|
||||
);
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
|
||||
}
|
||||
|
||||
struct ErasedSegmentSortKeyComputerWrapper<C, F>
|
||||
where
|
||||
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
|
||||
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
|
||||
{
|
||||
inner: C,
|
||||
converter: F,
|
||||
buffer: C::Buffer,
|
||||
}
|
||||
|
||||
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
|
||||
where
|
||||
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
|
||||
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
|
||||
{
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
|
||||
self.inner.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
|
||||
filter: ValueRange<Option<u64>>,
|
||||
) {
|
||||
self.inner
|
||||
.segment_sort_keys(input_docs, output, &mut self.buffer, filter)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
|
||||
let val = self.inner.convert_segment_sort_key(sort_key);
|
||||
(self.converter)(val)
|
||||
}
|
||||
}
|
||||
|
||||
struct ScoreSegmentSortKeyComputer {
|
||||
segment_computer: SortBySimilarityScoreSegmentComputer,
|
||||
}
|
||||
|
||||
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
|
||||
let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into();
|
||||
Some(score_value.to_u64())
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
_input_docs: &[DocId],
|
||||
_output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
|
||||
_filter: ValueRange<Option<u64>>,
|
||||
) {
|
||||
unimplemented!("Batch computation not supported for score sorting")
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
|
||||
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
|
||||
OwnedValue::F64(f64::from_u64(score_value))
|
||||
}
|
||||
}
|
||||
|
||||
impl SortKeyComputer for SortByErasedType {
|
||||
type SortKey = OwnedValue;
|
||||
type Child = ErasedColumnSegmentSortKeyComputer;
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
matches!(self, Self::Score)
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
|
||||
Self::Field(column_name) => {
|
||||
let fast_fields = segment_reader.fast_fields();
|
||||
// TODO: We currently double-open the column to avoid relying on the implementation
|
||||
// details of `SortByString` or `SortByStaticFastValue`. Once
|
||||
// https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should
|
||||
// consider directly constructing the appropriate `SegmentSortKeyComputer` type for
|
||||
// the column that we open here.
|
||||
let (_column, column_type) =
|
||||
fast_fields.u64_lenient(column_name)?.ok_or_else(|| {
|
||||
FastFieldNotAvailableError {
|
||||
field_name: column_name.to_owned(),
|
||||
}
|
||||
})?;
|
||||
|
||||
match column_type {
|
||||
ColumnType::Str => {
|
||||
let computer = SortByString::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<String>| {
|
||||
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::U64 => {
|
||||
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<u64>| {
|
||||
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::I64 => {
|
||||
let computer = SortByStaticFastValue::<i64>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<i64>| {
|
||||
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::F64 => {
|
||||
let computer = SortByStaticFastValue::<f64>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<f64>| {
|
||||
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::Bool => {
|
||||
let computer = SortByStaticFastValue::<bool>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<bool>| {
|
||||
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
ColumnType::DateTime => {
|
||||
let computer = SortByStaticFastValue::<DateTime>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<DateTime>| {
|
||||
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
buffer: Default::default(),
|
||||
})
|
||||
}
|
||||
column_type => {
|
||||
return Err(crate::TantivyError::SchemaError(format!(
|
||||
"Field `{}` is of type {column_type:?}, which is not supported for \
|
||||
sorting by owned value yet.",
|
||||
column_name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
|
||||
segment_computer: SortBySimilarityScore
|
||||
.segment_sort_key_computer(segment_reader)?,
|
||||
}),
|
||||
};
|
||||
Ok(ErasedColumnSegmentSortKeyComputer { inner })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ErasedColumnSegmentSortKeyComputer {
|
||||
inner: Box<dyn ErasedSegmentSortKeyComputer>,
|
||||
}
|
||||
|
||||
impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
|
||||
type SortKey = OwnedValue;
|
||||
type SegmentSortKey = Option<u64>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
|
||||
self.inner.segment_sort_key(doc, score)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
self.inner.segment_sort_keys(input_docs, output, filter)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
|
||||
self.inner.convert_segment_sort_key(segment_sort_key)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::collector::sort_key::{ComparatorEnum, SortByErasedType};
|
||||
use crate::collector::TopDocs;
|
||||
use crate::query::AllQuery;
|
||||
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
|
||||
use crate::Index;
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_u64() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id_field = schema_builder.add_u64_field("id", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer.add_document(doc!(id_field => 10u64)).unwrap();
|
||||
writer.add_document(doc!(id_field => 2u64)).unwrap();
|
||||
writer.add_document(doc!()).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let collector = TopDocs::with_limit(10)
|
||||
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null]
|
||||
);
|
||||
|
||||
let collector = TopDocs::with_limit(10).order_by((
|
||||
SortByErasedType::for_field("id"),
|
||||
ComparatorEnum::ReverseNoneLower,
|
||||
));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_string() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let city_field = schema_builder.add_text_field("city", FAST | TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer.add_document(doc!(city_field => "tokyo")).unwrap();
|
||||
writer.add_document(doc!(city_field => "austin")).unwrap();
|
||||
writer.add_document(doc!()).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let collector = TopDocs::with_limit(10).order_by((
|
||||
SortByErasedType::for_field("city"),
|
||||
ComparatorEnum::ReverseNoneLower,
|
||||
));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![
|
||||
OwnedValue::Str("austin".to_string()),
|
||||
OwnedValue::Str("tokyo".to_string()),
|
||||
OwnedValue::Null
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_reverse() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id_field = schema_builder.add_u64_field("id", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer.add_document(doc!(id_field => 10u64)).unwrap();
|
||||
writer.add_document(doc!(id_field => 2u64)).unwrap();
|
||||
writer.add_document(doc!()).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let collector = TopDocs::with_limit(10)
|
||||
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_score() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let body_field = schema_builder.add_text_field("body", TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer.add_document(doc!(body_field => "a a")).unwrap();
|
||||
writer.add_document(doc!(body_field => "a")).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]);
|
||||
let query = query_parser.parse_query("a").unwrap();
|
||||
|
||||
// Sort by score descending (Natural)
|
||||
let collector = TopDocs::with_limit(10)
|
||||
.order_by((SortByErasedType::for_score(), ComparatorEnum::Natural));
|
||||
let top_docs = searcher.search(&query, &collector).unwrap();
|
||||
|
||||
let values: Vec<f64> = top_docs
|
||||
.into_iter()
|
||||
.map(|(key, _)| match key {
|
||||
OwnedValue::F64(val) => val,
|
||||
_ => panic!("Wrong type {:?}", key),
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(values.len(), 2);
|
||||
assert!(values[0] > values[1]);
|
||||
|
||||
// Sort by score ascending (ReverseNoneLower)
|
||||
let collector = TopDocs::with_limit(10).order_by((
|
||||
SortByErasedType::for_score(),
|
||||
ComparatorEnum::ReverseNoneLower,
|
||||
));
|
||||
let top_docs = searcher.search(&query, &collector).unwrap();
|
||||
|
||||
let values: Vec<f64> = top_docs
|
||||
.into_iter()
|
||||
.map(|(key, _)| match key {
|
||||
OwnedValue::F64(val) => val,
|
||||
_ => panic!("Wrong type {:?}", key),
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(values.len(), 2);
|
||||
assert!(values[0] < values[1]);
|
||||
}
|
||||
}
|
||||
92
src/collector/sort_key/sort_by_score.rs
Normal file
92
src/collector/sort_key/sort_by_score.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
use columnar::ValueRange;
|
||||
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
|
||||
use crate::{DocAddress, DocId, Score};
|
||||
|
||||
/// Sort by similarity score.
|
||||
#[derive(Clone, Debug, Copy)]
|
||||
pub struct SortBySimilarityScore;
|
||||
|
||||
impl SortKeyComputer for SortBySimilarityScore {
|
||||
type SortKey = Score;
|
||||
|
||||
type Child = SortBySimilarityScoreSegmentComputer;
|
||||
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
_segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
Ok(SortBySimilarityScoreSegmentComputer)
|
||||
}
|
||||
|
||||
// Sorting by score is special in that it allows for the Block-Wand optimization.
|
||||
fn collect_segment_top_k(
|
||||
&self,
|
||||
k: usize,
|
||||
weight: &dyn crate::query::Weight,
|
||||
reader: &crate::SegmentReader,
|
||||
segment_ord: u32,
|
||||
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
|
||||
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
|
||||
TopNComputer::new_with_comparator(k, self.comparator());
|
||||
|
||||
if let Some(alive_bitset) = reader.alive_bitset() {
|
||||
let mut threshold = Score::MIN;
|
||||
top_n.threshold = Some(threshold);
|
||||
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
|
||||
if alive_bitset.is_deleted(doc) {
|
||||
return threshold;
|
||||
}
|
||||
top_n.push(score, doc);
|
||||
threshold = top_n.threshold.unwrap_or(Score::MIN);
|
||||
threshold
|
||||
})?;
|
||||
} else {
|
||||
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
|
||||
top_n.push(score, doc);
|
||||
top_n.threshold.unwrap_or(Score::MIN)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(top_n
|
||||
.into_vec()
|
||||
.into_iter()
|
||||
.map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SortBySimilarityScoreSegmentComputer;
|
||||
|
||||
impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
|
||||
type SortKey = Score;
|
||||
type SegmentSortKey = Score;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
|
||||
score
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
_input_docs: &[DocId],
|
||||
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
_filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
unimplemented!("Batch computation not supported for score sorting")
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, score: Score) -> Score {
|
||||
score
|
||||
}
|
||||
}
|
||||
194
src/collector/sort_key/sort_by_static_fast_value.rs
Normal file
194
src/collector/sort_key/sort_by_static_fast_value.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use columnar::{Column, ValueRange};
|
||||
|
||||
use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range;
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
|
||||
use crate::{DocId, Score, SegmentReader};
|
||||
|
||||
/// Sorts by a fast value (u64, i64, f64, bool).
|
||||
///
|
||||
/// The field must appear explicitly in the schema, with the right type, and declared as
|
||||
/// a fast field..
|
||||
///
|
||||
/// If the field is multivalued, only the first value is considered.
|
||||
///
|
||||
/// Documents that do not have this value are still considered.
|
||||
/// Their sort key will simply be `None`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SortByStaticFastValue<T: FastValue> {
|
||||
field: String,
|
||||
typ: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: FastValue> SortByStaticFastValue<T> {
|
||||
/// Creates a new `SortByStaticFastValue` instance for the given field.
|
||||
pub fn for_field(column_name: impl ToString) -> SortByStaticFastValue<T> {
|
||||
Self {
|
||||
field: column_name.to_string(),
|
||||
typ: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
|
||||
type Child = SortByFastValueSegmentSortKeyComputer<T>;
|
||||
type SortKey = Option<T>;
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
|
||||
// At the segment sort key computer level, we rely on the u64 representation.
|
||||
// The mapping is monotonic, so it is sufficient to compute our top-K docs.
|
||||
let field = schema.get_field(&self.field)?;
|
||||
let field_entry = schema.get_field_entry(field);
|
||||
if !field_entry.is_fast() {
|
||||
return Err(crate::TantivyError::SchemaError(format!(
|
||||
"Field `{}` is not a fast field.",
|
||||
self.field,
|
||||
)));
|
||||
}
|
||||
let schema_type = field_entry.field_type().value_type();
|
||||
if schema_type != T::to_type() {
|
||||
return Err(crate::TantivyError::SchemaError(format!(
|
||||
"Field `{}` is of type {schema_type:?}, not of the type {:?}.",
|
||||
&self.field,
|
||||
T::to_type()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
|
||||
let (sort_column, _sort_column_type) =
|
||||
sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
|
||||
field_name: self.field.clone(),
|
||||
})?;
|
||||
Ok(SortByFastValueSegmentSortKeyComputer {
|
||||
sort_column,
|
||||
typ: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SortByFastValueSegmentSortKeyComputer<T> {
|
||||
sort_column: Column<u64>,
|
||||
typ: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
|
||||
type SortKey = Option<T>;
|
||||
type SegmentSortKey = Option<u64>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
|
||||
self.sort_column.first(doc)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
|
||||
self.sort_column
|
||||
.first_vals_in_value_range(input_docs, output, u64_filter);
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
|
||||
sort_key.map(T::from_u64)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::schema::{Schema, FAST};
|
||||
use crate::Index;
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_fast_value_batch() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_col = schema_builder.add_u64_field("field", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => 10u64))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => 20u64))
|
||||
.unwrap();
|
||||
index_writer.add_document(crate::doc!()).unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
let sorter = SortByStaticFastValue::<u64>::for_field("field");
|
||||
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
|
||||
|
||||
let mut docs = vec![0, 1, 2];
|
||||
let mut output = Vec::new();
|
||||
let mut buffer = ();
|
||||
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
|
||||
|
||||
assert_eq!(
|
||||
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
|
||||
&[Some(10), Some(20), None]
|
||||
);
|
||||
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_fast_value_batch_with_filter() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_col = schema_builder.add_u64_field("field", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => 10u64))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => 20u64))
|
||||
.unwrap();
|
||||
index_writer.add_document(crate::doc!()).unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
let sorter = SortByStaticFastValue::<u64>::for_field("field");
|
||||
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
|
||||
|
||||
let mut docs = vec![0, 1, 2];
|
||||
let mut output = Vec::new();
|
||||
let mut buffer = ();
|
||||
computer.segment_sort_keys(
|
||||
&mut docs,
|
||||
&mut output,
|
||||
&mut buffer,
|
||||
ValueRange::GreaterThan(Some(15u64), false /* inclusive */),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
|
||||
&[Some(20)]
|
||||
);
|
||||
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
|
||||
}
|
||||
}
|
||||
185
src/collector/sort_key/sort_by_string.rs
Normal file
185
src/collector/sort_key/sort_by_string.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use columnar::{StrColumn, ValueRange};
|
||||
|
||||
use crate::collector::sort_key::sort_key_computer::{
|
||||
convert_optional_u64_range_to_u64_range, range_contains_none,
|
||||
};
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::termdict::TermOrdinal;
|
||||
use crate::{DocId, Score};
|
||||
|
||||
/// Sort by the first value of a string column.
|
||||
///
|
||||
/// The string can be dynamic (coming from a json field)
|
||||
/// or static (being specificaly defined in the configuration).
|
||||
///
|
||||
/// If the field is multivalued, only the first value is considered.
|
||||
///
|
||||
/// Documents that do not have this value are still considered.
|
||||
/// Their sort key will simply be `None`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SortByString {
|
||||
column_name: String,
|
||||
}
|
||||
|
||||
impl SortByString {
|
||||
/// Creates a new sort by string sort key computer.
|
||||
pub fn for_field(column_name: impl ToString) -> Self {
|
||||
SortByString {
|
||||
column_name: column_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SortKeyComputer for SortByString {
|
||||
type SortKey = Option<String>;
|
||||
type Child = ByStringColumnSegmentSortKeyComputer;
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
|
||||
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ByStringColumnSegmentSortKeyComputer {
|
||||
str_column_opt: Option<StrColumn>,
|
||||
}
|
||||
|
||||
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
|
||||
type SortKey = Option<String>;
|
||||
type SegmentSortKey = Option<TermOrdinal>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
type Buffer = ();
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
|
||||
let str_column = self.str_column_opt.as_ref()?;
|
||||
str_column.ords().first(doc)
|
||||
}
|
||||
|
||||
fn segment_sort_keys(
|
||||
&mut self,
|
||||
input_docs: &[DocId],
|
||||
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
|
||||
_buffer: &mut Self::Buffer,
|
||||
filter: ValueRange<Self::SegmentSortKey>,
|
||||
) {
|
||||
if let Some(str_column) = &self.str_column_opt {
|
||||
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
|
||||
str_column
|
||||
.ords()
|
||||
.first_vals_in_value_range(input_docs, output, u64_filter);
|
||||
} else if range_contains_none(&filter) {
|
||||
for &doc in input_docs {
|
||||
output.push(ComparableDoc {
|
||||
doc,
|
||||
sort_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
|
||||
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
|
||||
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
|
||||
let term_ord = term_ord_opt?;
|
||||
let str_column = self.str_column_opt.as_ref()?;
|
||||
let mut bytes = Vec::new();
|
||||
str_column
|
||||
.dictionary()
|
||||
.ord_to_term(term_ord, &mut bytes)
|
||||
.ok()?;
|
||||
String::try_from(bytes).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::schema::{Schema, FAST, TEXT};
|
||||
use crate::Index;
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_string_batch() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => "a"))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => "c"))
|
||||
.unwrap();
|
||||
index_writer.add_document(crate::doc!()).unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
let sorter = SortByString::for_field("field");
|
||||
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
|
||||
|
||||
let mut docs = vec![0, 1, 2];
|
||||
let mut output = Vec::new();
|
||||
let mut buffer = ();
|
||||
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
|
||||
|
||||
assert_eq!(
|
||||
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
|
||||
&[Some(0), Some(1), None]
|
||||
);
|
||||
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_string_batch_with_filter() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => "a"))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(crate::doc!(field_col => "c"))
|
||||
.unwrap();
|
||||
index_writer.add_document(crate::doc!()).unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
|
||||
let sorter = SortByString::for_field("field");
|
||||
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
|
||||
|
||||
let mut docs = vec![0, 1, 2];
|
||||
let mut output = Vec::new();
|
||||
// Filter: > "b". "a" is 0, "c" is 1.
|
||||
// We want > "a" (ord 0). So we filter > ord 0.
|
||||
// 0 is "a", 1 is "c".
|
||||
let mut buffer = ();
|
||||
computer.segment_sort_keys(
|
||||
&mut docs,
|
||||
&mut output,
|
||||
&mut buffer,
|
||||
ValueRange::GreaterThan(Some(0), false /* inclusive */),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
|
||||
&[Some(1)]
|
||||
);
|
||||
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
|
||||
}
|
||||
}
|
||||
1069
src/collector/sort_key/sort_key_computer.rs
Normal file
1069
src/collector/sort_key/sort_key_computer.rs
Normal file
File diff suppressed because it is too large
Load Diff
203
src/collector/sort_key_top_collector.rs
Normal file
203
src/collector/sort_key_top_collector.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::collector::sort_key::{Comparator, SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::collector::{Collector, SegmentCollector, TopNComputer};
|
||||
use crate::query::Weight;
|
||||
use crate::schema::Schema;
|
||||
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
|
||||
|
||||
pub(crate) struct TopBySortKeyCollector<TSortKeyComputer> {
|
||||
sort_key_computer: TSortKeyComputer,
|
||||
doc_range: Range<usize>,
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> TopBySortKeyCollector<TSortKeyComputer> {
|
||||
pub fn new(sort_key_computer: TSortKeyComputer, doc_range: Range<usize>) -> Self {
|
||||
TopBySortKeyCollector {
|
||||
sort_key_computer,
|
||||
doc_range,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSortKeyComputer> Collector for TopBySortKeyCollector<TSortKeyComputer>
|
||||
where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
|
||||
{
|
||||
type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>;
|
||||
|
||||
type Child =
|
||||
TopBySortKeySegmentCollector<TSortKeyComputer::Child, TSortKeyComputer::Comparator>;
|
||||
|
||||
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
|
||||
self.sort_key_computer.check_schema(schema)
|
||||
}
|
||||
|
||||
fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
let segment_sort_key_computer = self
|
||||
.sort_key_computer
|
||||
.segment_sort_key_computer(segment_reader)?;
|
||||
let topn_computer = TopNComputer::new_with_comparator(
|
||||
self.doc_range.end,
|
||||
self.sort_key_computer.comparator(),
|
||||
);
|
||||
Ok(TopBySortKeySegmentCollector {
|
||||
topn_computer,
|
||||
segment_ord,
|
||||
segment_sort_key_computer,
|
||||
})
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
self.sort_key_computer.requires_scoring()
|
||||
}
|
||||
|
||||
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
|
||||
Ok(merge_top_k(
|
||||
segment_fruits.into_iter().flatten(),
|
||||
self.doc_range.clone(),
|
||||
self.sort_key_computer.comparator(),
|
||||
))
|
||||
}
|
||||
|
||||
fn collect_segment(
|
||||
&self,
|
||||
weight: &dyn Weight,
|
||||
segment_ord: u32,
|
||||
reader: &SegmentReader,
|
||||
) -> crate::Result<Vec<(TSortKeyComputer::SortKey, DocAddress)>> {
|
||||
let k = self.doc_range.end;
|
||||
let docs = self
|
||||
.sort_key_computer
|
||||
.collect_segment_top_k(k, weight, reader, segment_ord)?;
|
||||
Ok(docs)
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_top_k<D: Ord, TSortKey: Clone + std::fmt::Debug, C: Comparator<TSortKey>>(
|
||||
sort_key_docs: impl Iterator<Item = (TSortKey, D)>,
|
||||
doc_range: Range<usize>,
|
||||
comparator: C,
|
||||
) -> Vec<(TSortKey, D)> {
|
||||
if doc_range.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
let mut top_collector: TopNComputer<TSortKey, D, C> =
|
||||
TopNComputer::new_with_comparator(doc_range.end, comparator);
|
||||
for (sort_key, doc) in sort_key_docs {
|
||||
top_collector.push(sort_key, doc);
|
||||
}
|
||||
top_collector
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.skip(doc_range.start)
|
||||
.map(|cdoc| (cdoc.sort_key, cdoc.doc))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub struct TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
|
||||
where
|
||||
TSegmentSortKeyComputer: SegmentSortKeyComputer,
|
||||
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
|
||||
{
|
||||
pub(crate) topn_computer: TopNComputer<
|
||||
TSegmentSortKeyComputer::SegmentSortKey,
|
||||
DocId,
|
||||
C,
|
||||
TSegmentSortKeyComputer::Buffer,
|
||||
>,
|
||||
pub(crate) segment_ord: u32,
|
||||
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
|
||||
}
|
||||
|
||||
impl<TSegmentSortKeyComputer, C> SegmentCollector
|
||||
for TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
|
||||
where
|
||||
TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer,
|
||||
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey> + 'static,
|
||||
{
|
||||
type Fruit = Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)>;
|
||||
|
||||
fn collect(&mut self, doc: DocId, score: Score) {
|
||||
self.segment_sort_key_computer.compute_sort_key_and_collect(
|
||||
doc,
|
||||
score,
|
||||
&mut self.topn_computer,
|
||||
);
|
||||
}
|
||||
|
||||
fn collect_block(&mut self, docs: &[DocId]) {
|
||||
self.segment_sort_key_computer
|
||||
.compute_sort_keys_and_collect(docs, &mut self.topn_computer);
|
||||
}
|
||||
|
||||
fn harvest(self) -> Self::Fruit {
|
||||
let segment_ord = self.segment_ord;
|
||||
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self
|
||||
.topn_computer
|
||||
.into_vec()
|
||||
.into_iter()
|
||||
.map(|comparable_doc| {
|
||||
let sort_key = self
|
||||
.segment_sort_key_computer
|
||||
.convert_segment_sort_key(comparable_doc.sort_key);
|
||||
(
|
||||
sort_key,
|
||||
DocAddress {
|
||||
segment_ord,
|
||||
doc_id: comparable_doc.doc,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
segment_hits
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Range;
|
||||
|
||||
use rand;
|
||||
use rand::seq::SliceRandom as _;
|
||||
|
||||
use super::merge_top_k;
|
||||
use crate::collector::sort_key::ComparatorEnum;
|
||||
use crate::Order;
|
||||
|
||||
fn test_merge_top_k_aux(
|
||||
order: Order,
|
||||
doc_range: Range<usize>,
|
||||
expected: &[(crate::Score, usize)],
|
||||
) {
|
||||
let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
|
||||
vals.shuffle(&mut rand::thread_rng());
|
||||
let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
|
||||
assert_eq!(&vals_merged, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_top_k() {
|
||||
test_merge_top_k_aux(Order::Asc, 0..0, &[]);
|
||||
test_merge_top_k_aux(Order::Asc, 3..3, &[]);
|
||||
test_merge_top_k_aux(Order::Asc, 0..3, &[(0.0f32, 0), (1.0f32, 1), (2.0f32, 2)]);
|
||||
test_merge_top_k_aux(
|
||||
Order::Asc,
|
||||
0..11,
|
||||
&[
|
||||
(0.0f32, 0),
|
||||
(1.0f32, 1),
|
||||
(2.0f32, 2),
|
||||
(3.0f32, 3),
|
||||
(4.0f32, 4),
|
||||
(5.0f32, 5),
|
||||
(6.0f32, 6),
|
||||
(7.0f32, 7),
|
||||
(8.0f32, 8),
|
||||
(9.0f32, 9),
|
||||
],
|
||||
);
|
||||
test_merge_top_k_aux(Order::Asc, 1..3, &[(1.0f32, 1), (2.0f32, 2)]);
|
||||
test_merge_top_k_aux(Order::Desc, 0..2, &[(9.0f32, 9), (8.0f32, 8)]);
|
||||
test_merge_top_k_aux(Order::Desc, 2..4, &[(7.0f32, 7), (6.0f32, 6)]);
|
||||
}
|
||||
}
|
||||
@@ -40,7 +40,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
|
||||
let filter_some_collector = FilterCollector::new(
|
||||
"price".to_string(),
|
||||
&|value: u64| value > 20_120u64,
|
||||
TopDocs::with_limit(2),
|
||||
TopDocs::with_limit(2).order_by_score(),
|
||||
);
|
||||
let top_docs = searcher.search(&query, &filter_some_collector)?;
|
||||
|
||||
@@ -50,7 +50,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
|
||||
let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(
|
||||
"price".to_string(),
|
||||
&|value| value < 5u64,
|
||||
TopDocs::with_limit(2),
|
||||
TopDocs::with_limit(2).order_by_score(),
|
||||
);
|
||||
let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
|
||||
|
||||
@@ -62,8 +62,11 @@ pub fn test_filter_collector() -> crate::Result<()> {
|
||||
> 0
|
||||
}
|
||||
|
||||
let filter_dates_collector =
|
||||
FilterCollector::new("date".to_string(), &date_filter, TopDocs::with_limit(5));
|
||||
let filter_dates_collector = FilterCollector::new(
|
||||
"date".to_string(),
|
||||
&date_filter,
|
||||
TopDocs::with_limit(5).order_by_score(),
|
||||
);
|
||||
let filtered_date_docs = searcher.search(&query, &filter_dates_collector)?;
|
||||
|
||||
assert_eq!(filtered_date_docs.len(), 2);
|
||||
|
||||
@@ -1,374 +0,0 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::top_score_collector::TopNComputer;
|
||||
use crate::index::SegmentReader;
|
||||
use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
|
||||
/// Contains a feature (field, score, etc.) of a document along with the document address.
|
||||
///
|
||||
/// It guarantees stable sorting: in case of a tie on the feature, the document
|
||||
/// address is used.
|
||||
///
|
||||
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
|
||||
/// should be reversed, which is useful for achieving for example largest-first
|
||||
/// semantics without having to wrap the feature in a `Reverse`.
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
|
||||
/// The feature of the document. In practice, this is
|
||||
/// is any type that implements `PartialOrd`.
|
||||
pub feature: T,
|
||||
/// The document address. In practice, this is any
|
||||
/// type that implements `PartialOrd`, and is guaranteed
|
||||
/// to be unique for each document.
|
||||
pub doc: D,
|
||||
}
|
||||
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
|
||||
for ComparableDoc<T, D, R>
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
|
||||
.field("feature", &self.feature)
|
||||
.field("doc", &self.doc)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
|
||||
#[inline]
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
let by_feature = self
|
||||
.feature
|
||||
.partial_cmp(&other.feature)
|
||||
.map(|ord| if R { ord.reverse() } else { ord })
|
||||
.unwrap_or(Ordering::Equal);
|
||||
|
||||
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
|
||||
|
||||
// In case of a tie on the feature, we sort by ascending
|
||||
// `DocAddress` in order to ensure a stable sorting of the
|
||||
// documents.
|
||||
by_feature.then_with(lazy_by_doc_address)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cmp(other) == Ordering::Equal
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}
|
||||
|
||||
pub(crate) struct TopCollector<T> {
|
||||
pub limit: usize,
|
||||
pub offset: usize,
|
||||
_marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> TopCollector<T>
|
||||
where T: PartialOrd + Clone
|
||||
{
|
||||
/// Creates a top collector, with a number of documents equal to "limit".
|
||||
///
|
||||
/// # Panics
|
||||
/// The method panics if limit is 0
|
||||
pub fn with_limit(limit: usize) -> TopCollector<T> {
|
||||
assert!(limit >= 1, "Limit must be strictly greater than 0.");
|
||||
Self {
|
||||
limit,
|
||||
offset: 0,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Skip the first "offset" documents when collecting.
|
||||
///
|
||||
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
|
||||
/// Lucene's TopDocsCollector.
|
||||
pub fn and_offset(mut self, offset: usize) -> TopCollector<T> {
|
||||
self.offset = offset;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn merge_fruits(
|
||||
&self,
|
||||
children: Vec<Vec<(T, DocAddress)>>,
|
||||
) -> crate::Result<Vec<(T, DocAddress)>> {
|
||||
if self.limit == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset);
|
||||
for child_fruit in children {
|
||||
for (feature, doc) in child_fruit {
|
||||
top_collector.push(feature, doc);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(top_collector
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.skip(self.offset)
|
||||
.map(|cdoc| (cdoc.feature, cdoc.doc))
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub(crate) fn for_segment<F: PartialOrd + Clone>(
|
||||
&self,
|
||||
segment_id: SegmentOrdinal,
|
||||
_: &SegmentReader,
|
||||
) -> TopSegmentCollector<F> {
|
||||
TopSegmentCollector::new(segment_id, self.limit + self.offset)
|
||||
}
|
||||
|
||||
/// Create a new TopCollector with the same limit and offset.
|
||||
///
|
||||
/// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits
|
||||
/// to fail.
|
||||
#[doc(hidden)]
|
||||
pub(crate) fn into_tscore<TScore: PartialOrd + Clone>(self) -> TopCollector<TScore> {
|
||||
TopCollector {
|
||||
limit: self.limit,
|
||||
offset: self.offset,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The Top Collector keeps track of the K documents
|
||||
/// sorted by type `T`.
|
||||
///
|
||||
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
|
||||
/// The theoretical complexity for collecting the top `K` out of `n` documents
|
||||
/// is `O(n + K)`.
|
||||
pub(crate) struct TopSegmentCollector<T> {
|
||||
/// We reverse the order of the feature in order to
|
||||
/// have top-semantics instead of bottom semantics.
|
||||
topn_computer: TopNComputer<T, DocId>,
|
||||
segment_ord: u32,
|
||||
}
|
||||
|
||||
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
|
||||
fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
|
||||
TopSegmentCollector {
|
||||
topn_computer: TopNComputer::new(limit),
|
||||
segment_ord,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
|
||||
pub fn harvest(self) -> Vec<(T, DocAddress)> {
|
||||
let segment_ord = self.segment_ord;
|
||||
self.topn_computer
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.map(|comparable_doc| {
|
||||
(
|
||||
comparable_doc.feature,
|
||||
DocAddress {
|
||||
segment_ord,
|
||||
doc_id: comparable_doc.doc,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Collects a document scored by the given feature
|
||||
///
|
||||
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
|
||||
/// will compare the lowest scoring item with the given one and keep whichever is greater.
|
||||
#[inline]
|
||||
pub fn collect(&mut self, doc: DocId, feature: T) {
|
||||
self.topn_computer.push(feature, doc);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{TopCollector, TopSegmentCollector};
|
||||
use crate::DocAddress;
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_not_at_capacity() {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 4);
|
||||
top_collector.collect(1, 0.8);
|
||||
top_collector.collect(3, 0.2);
|
||||
top_collector.collect(5, 0.3);
|
||||
assert_eq!(
|
||||
top_collector.harvest(),
|
||||
vec![
|
||||
(0.8, DocAddress::new(0, 1)),
|
||||
(0.3, DocAddress::new(0, 5)),
|
||||
(0.2, DocAddress::new(0, 3))
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_at_capacity() {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 4);
|
||||
top_collector.collect(1, 0.8);
|
||||
top_collector.collect(3, 0.2);
|
||||
top_collector.collect(5, 0.3);
|
||||
top_collector.collect(7, 0.9);
|
||||
top_collector.collect(9, -0.2);
|
||||
assert_eq!(
|
||||
top_collector.harvest(),
|
||||
vec![
|
||||
(0.9, DocAddress::new(0, 7)),
|
||||
(0.8, DocAddress::new(0, 1)),
|
||||
(0.3, DocAddress::new(0, 5)),
|
||||
(0.2, DocAddress::new(0, 3))
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_segment_collector_stable_ordering_for_equal_feature() {
|
||||
// given that the documents are collected in ascending doc id order,
|
||||
// when harvesting we have to guarantee stable sorting in case of a tie
|
||||
// on the score
|
||||
let doc_ids_collection = [4, 5, 6];
|
||||
let score = 3.3f32;
|
||||
|
||||
let mut top_collector_limit_2 = TopSegmentCollector::new(0, 2);
|
||||
for id in &doc_ids_collection {
|
||||
top_collector_limit_2.collect(*id, score);
|
||||
}
|
||||
|
||||
let mut top_collector_limit_3 = TopSegmentCollector::new(0, 3);
|
||||
for id in &doc_ids_collection {
|
||||
top_collector_limit_3.collect(*id, score);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
top_collector_limit_2.harvest(),
|
||||
top_collector_limit_3.harvest()[..2].to_vec(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_with_limit_and_offset() {
|
||||
let collector = TopCollector::with_limit(2).and_offset(1);
|
||||
|
||||
let results = collector
|
||||
.merge_fruits(vec![vec![
|
||||
(0.9, DocAddress::new(0, 1)),
|
||||
(0.8, DocAddress::new(0, 2)),
|
||||
(0.7, DocAddress::new(0, 3)),
|
||||
(0.6, DocAddress::new(0, 4)),
|
||||
(0.5, DocAddress::new(0, 5)),
|
||||
]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![(0.8, DocAddress::new(0, 2)), (0.7, DocAddress::new(0, 3)),]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_with_limit_larger_than_set_and_offset() {
|
||||
let collector = TopCollector::with_limit(2).and_offset(1);
|
||||
|
||||
let results = collector
|
||||
.merge_fruits(vec![vec![
|
||||
(0.9, DocAddress::new(0, 1)),
|
||||
(0.8, DocAddress::new(0, 2)),
|
||||
]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results, vec![(0.8, DocAddress::new(0, 2)),]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_with_limit_and_offset_larger_than_set() {
|
||||
let collector = TopCollector::with_limit(2).and_offset(20);
|
||||
|
||||
let results = collector
|
||||
.merge_fruits(vec![vec![
|
||||
(0.9, DocAddress::new(0, 1)),
|
||||
(0.8, DocAddress::new(0, 2)),
|
||||
]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results, vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
use test::Bencher;
|
||||
|
||||
use super::TopSegmentCollector;
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_not_at_capacity(b: &mut Bencher) {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 400);
|
||||
|
||||
b.iter(|| {
|
||||
for i in 0..100 {
|
||||
top_collector.collect(i, 0.8);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 100);
|
||||
|
||||
for i in 0..100 {
|
||||
top_collector.collect(i, 0.8);
|
||||
}
|
||||
|
||||
b.iter(|| {
|
||||
for i in 0..100 {
|
||||
top_collector.collect(i, 0.8);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_and_harvest_many_ties(b: &mut Bencher) {
|
||||
b.iter(|| {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 100);
|
||||
|
||||
for i in 0..100 {
|
||||
top_collector.collect(i, 0.8);
|
||||
}
|
||||
|
||||
// it would be nice to be able to do the setup N times but still
|
||||
// measure only harvest(). We can't since harvest() consumes
|
||||
// the top_collector.
|
||||
top_collector.harvest()
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_top_segment_collector_collect_and_harvest_no_tie(b: &mut Bencher) {
|
||||
b.iter(|| {
|
||||
let mut top_collector = TopSegmentCollector::new(0, 100);
|
||||
let mut score = 1.0;
|
||||
|
||||
for i in 0..100 {
|
||||
score += 1.0;
|
||||
top_collector.collect(i, score);
|
||||
}
|
||||
|
||||
// it would be nice to be able to do the setup N times but still
|
||||
// measure only harvest(). We can't since harvest() consumes
|
||||
// the top_collector.
|
||||
top_collector.harvest()
|
||||
});
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,124 +0,0 @@
|
||||
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
|
||||
use crate::collector::{Collector, SegmentCollector};
|
||||
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
|
||||
|
||||
pub(crate) struct TweakedScoreTopCollector<TScoreTweaker, TScore = Score> {
|
||||
score_tweaker: TScoreTweaker,
|
||||
collector: TopCollector<TScore>,
|
||||
}
|
||||
|
||||
impl<TScoreTweaker, TScore> TweakedScoreTopCollector<TScoreTweaker, TScore>
|
||||
where TScore: Clone + PartialOrd
|
||||
{
|
||||
pub fn new(
|
||||
score_tweaker: TScoreTweaker,
|
||||
collector: TopCollector<TScore>,
|
||||
) -> TweakedScoreTopCollector<TScoreTweaker, TScore> {
|
||||
TweakedScoreTopCollector {
|
||||
score_tweaker,
|
||||
collector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A `ScoreSegmentTweaker` makes it possible to modify the default score
|
||||
/// for a given document belonging to a specific segment.
|
||||
///
|
||||
/// It is the segment local version of the [`ScoreTweaker`].
|
||||
pub trait ScoreSegmentTweaker<TScore>: 'static {
|
||||
/// Tweak the given `score` for the document `doc`.
|
||||
fn score(&mut self, doc: DocId, score: Score) -> TScore;
|
||||
}
|
||||
|
||||
/// `ScoreTweaker` makes it possible to tweak the score
|
||||
/// emitted by the scorer into another one.
|
||||
///
|
||||
/// The `ScoreTweaker` itself does not make much of the computation itself.
|
||||
/// Instead, it helps constructing `Self::Child` instances that will compute
|
||||
/// the score at a segment scale.
|
||||
pub trait ScoreTweaker<TScore>: Sync {
|
||||
/// Type of the associated [`ScoreSegmentTweaker`].
|
||||
type Child: ScoreSegmentTweaker<TScore>;
|
||||
|
||||
/// Builds a child tweaker for a specific segment. The child scorer is associated with
|
||||
/// a specific segment.
|
||||
fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
|
||||
}
|
||||
|
||||
impl<TScoreTweaker, TScore> Collector for TweakedScoreTopCollector<TScoreTweaker, TScore>
|
||||
where
|
||||
TScoreTweaker: ScoreTweaker<TScore> + Send + Sync,
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync,
|
||||
{
|
||||
type Fruit = Vec<(TScore, DocAddress)>;
|
||||
|
||||
type Child = TopTweakedScoreSegmentCollector<TScoreTweaker::Child, TScore>;
|
||||
|
||||
fn for_segment(
|
||||
&self,
|
||||
segment_local_id: u32,
|
||||
segment_reader: &SegmentReader,
|
||||
) -> Result<Self::Child> {
|
||||
let segment_scorer = self.score_tweaker.segment_tweaker(segment_reader)?;
|
||||
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
|
||||
Ok(TopTweakedScoreSegmentCollector {
|
||||
segment_collector,
|
||||
segment_scorer,
|
||||
})
|
||||
}
|
||||
|
||||
fn requires_scoring(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
|
||||
self.collector.merge_fruits(segment_fruits)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
|
||||
where
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
|
||||
TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
|
||||
{
|
||||
segment_collector: TopSegmentCollector<TScore>,
|
||||
segment_scorer: TSegmentScoreTweaker,
|
||||
}
|
||||
|
||||
impl<TSegmentScoreTweaker, TScore> SegmentCollector
|
||||
for TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
|
||||
where
|
||||
TScore: 'static + PartialOrd + Clone + Send + Sync,
|
||||
TSegmentScoreTweaker: 'static + ScoreSegmentTweaker<TScore>,
|
||||
{
|
||||
type Fruit = Vec<(TScore, DocAddress)>;
|
||||
|
||||
fn collect(&mut self, doc: DocId, score: Score) {
|
||||
let score = self.segment_scorer.score(doc, score);
|
||||
self.segment_collector.collect(doc, score);
|
||||
}
|
||||
|
||||
fn harvest(self) -> Vec<(TScore, DocAddress)> {
|
||||
self.segment_collector.harvest()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, TScore, TSegmentScoreTweaker> ScoreTweaker<TScore> for F
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker,
|
||||
TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
|
||||
{
|
||||
type Child = TSegmentScoreTweaker;
|
||||
|
||||
fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
|
||||
Ok((self)(segment_reader))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, TScore> ScoreSegmentTweaker<TScore> for F
|
||||
where F: 'static + FnMut(DocId, Score) -> TScore
|
||||
{
|
||||
fn score(&mut self, doc: DocId, score: Score) -> TScore {
|
||||
(self)(doc, score)
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,7 @@ fn path_for_version(version: &str) -> String {
|
||||
/// feature flag quickwit uses a different dictionary type
|
||||
#[test]
|
||||
#[cfg(not(feature = "quickwit"))]
|
||||
#[ignore = "test incompatible with fixed-width footer changes"]
|
||||
fn test_format_6() {
|
||||
let path = path_for_version("6");
|
||||
|
||||
@@ -47,6 +48,7 @@ fn test_format_6() {
|
||||
/// feature flag quickwit uses a different dictionary type
|
||||
#[test]
|
||||
#[cfg(not(feature = "quickwit"))]
|
||||
#[ignore = "test incompatible with fixed-width footer changes"]
|
||||
fn test_format_7() {
|
||||
let path = path_for_version("7");
|
||||
|
||||
@@ -69,7 +71,7 @@ fn assert_date_time_precision(index: &Index, doc_store_precision: DateTimePrecis
|
||||
.parse_query("dateformat")
|
||||
.expect("Failed to parse query");
|
||||
let top_docs = searcher
|
||||
.search(&query, &TopDocs::with_limit(1))
|
||||
.search(&query, &TopDocs::with_limit(1).order_by_score())
|
||||
.expect("Search failed");
|
||||
|
||||
assert_eq!(top_docs.len(), 1, "Expected 1 search result");
|
||||
|
||||
@@ -3,6 +3,7 @@ use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP};
|
||||
use common::{replace_in_place, JsonPathWriter};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::indexer::indexing_term::IndexingTerm;
|
||||
use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter};
|
||||
use crate::schema::document::{ReferenceValue, ReferenceValueLeaf, Value};
|
||||
use crate::schema::{Type, DATE_TIME_PRECISION_INDEXED};
|
||||
@@ -77,7 +78,7 @@ fn index_json_object<'a, V: Value<'a>>(
|
||||
doc: DocId,
|
||||
json_visitor: V::ObjectIter,
|
||||
text_analyzer: &mut TextAnalyzer,
|
||||
term_buffer: &mut Term,
|
||||
term_buffer: &mut IndexingTerm,
|
||||
json_path_writer: &mut JsonPathWriter,
|
||||
postings_writer: &mut dyn PostingsWriter,
|
||||
ctx: &mut IndexingContext,
|
||||
@@ -107,17 +108,17 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
|
||||
doc: DocId,
|
||||
json_value: V,
|
||||
text_analyzer: &mut TextAnalyzer,
|
||||
term_buffer: &mut Term,
|
||||
term_buffer: &mut IndexingTerm,
|
||||
json_path_writer: &mut JsonPathWriter,
|
||||
postings_writer: &mut dyn PostingsWriter,
|
||||
ctx: &mut IndexingContext,
|
||||
positions_per_path: &mut IndexingPositionsPerPath,
|
||||
) {
|
||||
let set_path_id = |term_buffer: &mut Term, unordered_id: u32| {
|
||||
let set_path_id = |term_buffer: &mut IndexingTerm, unordered_id: u32| {
|
||||
term_buffer.truncate_value_bytes(0);
|
||||
term_buffer.append_bytes(&unordered_id.to_be_bytes());
|
||||
};
|
||||
let set_type = |term_buffer: &mut Term, typ: Type| {
|
||||
let set_type = |term_buffer: &mut IndexingTerm, typ: Type| {
|
||||
term_buffer.append_bytes(&[typ.to_code()]);
|
||||
};
|
||||
|
||||
@@ -405,7 +406,7 @@ mod tests {
|
||||
let mut term = Term::from_field_json_path(field, "color", false);
|
||||
term.append_type_and_str("red");
|
||||
|
||||
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00sred")
|
||||
assert_eq!(term.serialized_value_bytes(), b"color\x00sred".to_vec())
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -415,8 +416,8 @@ mod tests {
|
||||
term.append_type_and_fast_value(-4i64);
|
||||
|
||||
assert_eq!(
|
||||
term.serialized_term(),
|
||||
b"\x00\x00\x00\x01jcolor\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc"
|
||||
term.serialized_value_bytes(),
|
||||
b"color\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc".to_vec()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -427,8 +428,8 @@ mod tests {
|
||||
term.append_type_and_fast_value(4u64);
|
||||
|
||||
assert_eq!(
|
||||
term.serialized_term(),
|
||||
b"\x00\x00\x00\x01jcolor\x00u\x00\x00\x00\x00\x00\x00\x00\x04"
|
||||
term.serialized_value_bytes(),
|
||||
b"color\x00u\x00\x00\x00\x00\x00\x00\x00\x04".to_vec()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -438,8 +439,8 @@ mod tests {
|
||||
let mut term = Term::from_field_json_path(field, "color", false);
|
||||
term.append_type_and_fast_value(4.0f64);
|
||||
assert_eq!(
|
||||
term.serialized_term(),
|
||||
b"\x00\x00\x00\x01jcolor\x00f\xc0\x10\x00\x00\x00\x00\x00\x00"
|
||||
term.serialized_value_bytes(),
|
||||
b"color\x00f\xc0\x10\x00\x00\x00\x00\x00\x00".to_vec()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -449,8 +450,8 @@ mod tests {
|
||||
let mut term = Term::from_field_json_path(field, "color", false);
|
||||
term.append_type_and_fast_value(true);
|
||||
assert_eq!(
|
||||
term.serialized_term(),
|
||||
b"\x00\x00\x00\x01jcolor\x00o\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
term.serialized_value_bytes(),
|
||||
b"color\x00o\x00\x00\x00\x00\x00\x00\x00\x01".to_vec()
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::{fmt, io};
|
||||
|
||||
use crate::collector::Collector;
|
||||
@@ -86,7 +86,7 @@ impl Searcher {
|
||||
/// The searcher uses the segment ordinal to route the
|
||||
/// request to the right `Segment`.
|
||||
pub fn doc<D: DocumentDeserialize>(&self, doc_address: DocAddress) -> crate::Result<D> {
|
||||
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
|
||||
let store_reader = &self.inner.store_readers()[doc_address.segment_ord as usize];
|
||||
store_reader.get(doc_address.doc_id)
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ impl Searcher {
|
||||
pub fn doc_store_cache_stats(&self) -> CacheStats {
|
||||
let cache_stats: CacheStats = self
|
||||
.inner
|
||||
.store_readers
|
||||
.store_readers()
|
||||
.iter()
|
||||
.map(|reader| reader.cache_stats())
|
||||
.sum();
|
||||
@@ -110,7 +110,7 @@ impl Searcher {
|
||||
doc_address: DocAddress,
|
||||
) -> crate::Result<D> {
|
||||
let executor = self.inner.index.search_executor();
|
||||
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
|
||||
let store_reader = &self.inner.store_readers()[doc_address.segment_ord as usize];
|
||||
store_reader.get_async(doc_address.doc_id, executor).await
|
||||
}
|
||||
|
||||
@@ -225,6 +225,7 @@ impl Searcher {
|
||||
enabled_scoring: EnableScoring,
|
||||
) -> crate::Result<C::Fruit> {
|
||||
let weight = query.weight(enabled_scoring)?;
|
||||
collector.check_schema(self.schema())?;
|
||||
let segment_readers = self.segment_readers();
|
||||
let fruits = executor.map(
|
||||
|(segment_ord, segment_reader)| {
|
||||
@@ -258,8 +259,9 @@ impl From<Arc<SearcherInner>> for Searcher {
|
||||
pub(crate) struct SearcherInner {
|
||||
schema: Schema,
|
||||
index: Index,
|
||||
doc_store_cache_num_blocks: usize,
|
||||
segment_readers: Vec<SegmentReader>,
|
||||
store_readers: Vec<StoreReader>,
|
||||
store_readers: OnceLock<Vec<StoreReader>>,
|
||||
generation: TrackedObject<SearcherGeneration>,
|
||||
}
|
||||
|
||||
@@ -280,19 +282,30 @@ impl SearcherInner {
|
||||
generation.segments(),
|
||||
"Set of segments referenced by this Searcher and its SearcherGeneration must match"
|
||||
);
|
||||
let store_readers: Vec<StoreReader> = segment_readers
|
||||
.iter()
|
||||
.map(|segment_reader| segment_reader.get_store_reader(doc_store_cache_num_blocks))
|
||||
.collect::<io::Result<Vec<_>>>()?;
|
||||
|
||||
Ok(SearcherInner {
|
||||
schema,
|
||||
index,
|
||||
doc_store_cache_num_blocks,
|
||||
segment_readers,
|
||||
store_readers,
|
||||
store_readers: OnceLock::default(),
|
||||
generation,
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn store_readers(&self) -> &[StoreReader] {
|
||||
self.store_readers.get_or_init(|| {
|
||||
self.segment_readers
|
||||
.iter()
|
||||
.map(|segment_reader| {
|
||||
segment_reader
|
||||
.get_store_reader(self.doc_store_cache_num_blocks)
|
||||
.expect("should be able to get store reader")
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Searcher {
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::ops::Range;
|
||||
use common::{BinarySerializable, CountingWriter, HasLen, VInt};
|
||||
|
||||
use crate::directory::{FileSlice, TerminatingWrite, WritePtr};
|
||||
use crate::schema::Field;
|
||||
use crate::schema::{Field, Schema};
|
||||
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
|
||||
|
||||
#[derive(Eq, PartialEq, Hash, Copy, Ord, PartialOrd, Clone, Debug)]
|
||||
@@ -167,10 +167,11 @@ impl CompositeFile {
|
||||
.map(|byte_range| self.data.slice(byte_range.clone()))
|
||||
}
|
||||
|
||||
pub fn space_usage(&self) -> PerFieldSpaceUsage {
|
||||
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
|
||||
let mut fields = Vec::new();
|
||||
for (&field_addr, byte_range) in &self.offsets_index {
|
||||
let mut field_usage = FieldUsage::empty(field_addr.field);
|
||||
let field_name = schema.get_field_name(field_addr.field).to_string();
|
||||
let mut field_usage = FieldUsage::empty(field_name);
|
||||
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
|
||||
fields.push(field_usage);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
use std::any::Any;
|
||||
use std::collections::HashSet;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, io, thread};
|
||||
|
||||
use log::Level;
|
||||
|
||||
use crate::directory::directory_lock::Lock;
|
||||
use crate::directory::error::{DeleteError, LockError, OpenReadError, OpenWriteError};
|
||||
use crate::directory::{FileHandle, FileSlice, WatchCallback, WatchHandle, WritePtr};
|
||||
use crate::directory::{
|
||||
FileHandle, FileSlice, TerminatingWrite, WatchCallback, WatchHandle, WritePtr,
|
||||
};
|
||||
use crate::index::SegmentMetaInventory;
|
||||
use crate::IndexMeta;
|
||||
|
||||
/// Retry the logic of acquiring locks is pretty simple.
|
||||
/// We just retry `n` times after a given `duratio`, both
|
||||
@@ -56,7 +64,7 @@ impl<T: Send + Sync + 'static> From<Box<T>> for DirectoryLock {
|
||||
impl Drop for DirectoryLockGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Err(e) = self.directory.delete(&self.path) {
|
||||
error!("Failed to remove the lock file. {e:?}");
|
||||
error!("Failed to remove the lock file. {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -97,6 +105,8 @@ fn retry_policy(is_blocking: bool) -> RetryPolicy {
|
||||
}
|
||||
}
|
||||
|
||||
pub type DirectoryPanicHandler = Arc<dyn Fn(Box<dyn Any + Send>) + Send + Sync + 'static>;
|
||||
|
||||
/// Write-once read many (WORM) abstraction for where
|
||||
/// tantivy's data should be stored.
|
||||
///
|
||||
@@ -108,7 +118,7 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
|
||||
/// Opens a file and returns a boxed `FileHandle`.
|
||||
///
|
||||
/// Users of `Directory` should typically call `Directory::open_read(...)`,
|
||||
/// while `Directory` implementor should implement `get_file_handle()`.
|
||||
/// while `Directory` implementer should implement `get_file_handle()`.
|
||||
fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError>;
|
||||
|
||||
/// Once a virtual file is open, its data may not
|
||||
@@ -135,6 +145,10 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
|
||||
/// Returns true if and only if the file exists
|
||||
fn exists(&self, path: &Path) -> Result<bool, OpenReadError>;
|
||||
|
||||
/// Returns a boxed `TerminatingWrite` object, to be passed into `open_write`
|
||||
/// which wraps it in a `BufWriter`
|
||||
fn open_write_inner(&self, path: &Path) -> Result<Box<dyn TerminatingWrite>, OpenWriteError>;
|
||||
|
||||
/// Opens a writer for the *virtual file* associated with
|
||||
/// a [`Path`].
|
||||
///
|
||||
@@ -161,7 +175,12 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
|
||||
/// panic! if `flush` was not called.
|
||||
///
|
||||
/// The file may not previously exist.
|
||||
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError>;
|
||||
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
|
||||
Ok(io::BufWriter::with_capacity(
|
||||
self.bufwriter_capacity(),
|
||||
self.open_write_inner(path)?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Reads the full content file that has been written using
|
||||
/// [`Directory::atomic_write()`].
|
||||
@@ -223,6 +242,75 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
|
||||
/// `OnCommitWithDelay` `ReloadPolicy`. Not implementing watch in a `Directory` only prevents
|
||||
/// the `OnCommitWithDelay` `ReloadPolicy` to work properly.
|
||||
fn watch(&self, watch_callback: WatchCallback) -> crate::Result<WatchHandle>;
|
||||
|
||||
/// Allows the directory to list managed files, overriding the ManagedDirectory's default
|
||||
/// list_managed_files
|
||||
fn list_managed_files(&self) -> crate::Result<HashSet<PathBuf>> {
|
||||
Err(crate::TantivyError::InternalError(
|
||||
"list_managed_files not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Allows the directory to register a file as managed, overriding the ManagedDirectory's
|
||||
/// default register_file_as_managed
|
||||
fn register_files_as_managed(
|
||||
&self,
|
||||
_files: Vec<PathBuf>,
|
||||
_overwrite: bool,
|
||||
) -> crate::Result<()> {
|
||||
Err(crate::TantivyError::InternalError(
|
||||
"register_files_as_managed not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Allows the directory to save IndexMeta, overriding the SegmentUpdater's default save_meta
|
||||
fn save_metas(
|
||||
&self,
|
||||
_metas: &IndexMeta,
|
||||
_previous_metas: &IndexMeta,
|
||||
_payload: &mut (dyn Any + '_),
|
||||
) -> crate::Result<()> {
|
||||
Err(crate::TantivyError::InternalError(
|
||||
"save_meta not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Allows the directory to load IndexMeta, overriding the SegmentUpdater's default load_meta
|
||||
fn load_metas(&self, _inventory: &SegmentMetaInventory) -> crate::Result<IndexMeta> {
|
||||
Err(crate::TantivyError::InternalError(
|
||||
"load_metas not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Returns true if this directory supports garbage collection. The default assumption is
|
||||
/// `true`
|
||||
fn supports_garbage_collection(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Return a panic handler to be assigned to the various thread pools that may be created
|
||||
///
|
||||
/// The default is [`None`], which indicates that an unhandled panic from a thread pool will
|
||||
/// abort the process
|
||||
fn panic_handler(&self) -> Option<DirectoryPanicHandler> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns true if this directory is in a position of requiring that tantivy cancel
|
||||
/// whatever operation(s) it might be doing Typically this is just for the background
|
||||
/// merge processes but could be used for anything
|
||||
fn wants_cancel(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Send a logging message to the Directory to handle in its own way
|
||||
fn log(&self, message: &str) {
|
||||
log!(Level::Info, "{message}");
|
||||
}
|
||||
|
||||
fn bufwriter_capacity(&self) -> usize {
|
||||
8192
|
||||
}
|
||||
}
|
||||
|
||||
/// DirectoryClone
|
||||
|
||||
@@ -58,3 +58,9 @@ pub static META_LOCK: Lazy<Lock> = Lazy::new(|| Lock {
|
||||
filepath: PathBuf::from(".tantivy-meta.lock"),
|
||||
is_blocking: true,
|
||||
});
|
||||
|
||||
#[allow(missing_docs)]
|
||||
pub static MANAGED_LOCK: Lazy<Lock> = Lazy::new(|| Lock {
|
||||
filepath: PathBuf::from(".tantivy-managed.lock"),
|
||||
is_blocking: true,
|
||||
});
|
||||
|
||||
@@ -9,6 +9,7 @@ use crc32fast::Hasher;
|
||||
|
||||
use crate::directory::{WatchCallback, WatchCallbackList, WatchHandle};
|
||||
|
||||
#[allow(dead_code)]
|
||||
const POLLING_INTERVAL: Duration = Duration::from_millis(if cfg!(test) { 1 } else { 500 });
|
||||
|
||||
// Watches a file and executes registered callbacks when the file is modified.
|
||||
@@ -18,6 +19,7 @@ pub struct FileWatcher {
|
||||
state: Arc<AtomicUsize>, // 0: new, 1: runnable, 2: terminated
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl FileWatcher {
|
||||
pub fn new(path: &Path) -> FileWatcher {
|
||||
FileWatcher {
|
||||
|
||||
@@ -7,15 +7,14 @@
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
|
||||
use common::{BinarySerializable, CountingWriter, DeserializeFrom, FixedSize, HasLen};
|
||||
use common::{BinarySerializable, HasLen};
|
||||
use crc32fast::Hasher;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::directory::error::Incompatibility;
|
||||
use crate::directory::{AntiCallToken, FileSlice, TerminatingWrite};
|
||||
use crate::{Version, INDEX_FORMAT_OLDEST_SUPPORTED_VERSION, INDEX_FORMAT_VERSION};
|
||||
|
||||
const FOOTER_MAX_LEN: u32 = 50_000;
|
||||
pub const FOOTER_LEN: usize = 24;
|
||||
|
||||
/// The magic byte of the footer to identify corruption
|
||||
/// or an old version of the footer.
|
||||
@@ -24,7 +23,7 @@ const FOOTER_MAGIC_NUMBER: u32 = 1337;
|
||||
type CrcHashU32 = u32;
|
||||
|
||||
/// A Footer is appended to every file
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Footer {
|
||||
/// The version of the index format
|
||||
pub version: Version,
|
||||
@@ -41,34 +40,45 @@ impl Footer {
|
||||
pub(crate) fn crc(&self) -> CrcHashU32 {
|
||||
self.crc
|
||||
}
|
||||
pub(crate) fn append_footer<W: io::Write>(&self, mut write: &mut W) -> io::Result<()> {
|
||||
let mut counting_write = CountingWriter::wrap(&mut write);
|
||||
counting_write.write_all(serde_json::to_string(&self)?.as_ref())?;
|
||||
let footer_payload_len = counting_write.written_bytes();
|
||||
BinarySerializable::serialize(&(footer_payload_len as u32), write)?;
|
||||
pub fn append_footer<W: io::Write>(&self, write: &mut W) -> io::Result<()> {
|
||||
// 24 bytes
|
||||
BinarySerializable::serialize(&self.version.major, write)?;
|
||||
BinarySerializable::serialize(&self.version.minor, write)?;
|
||||
BinarySerializable::serialize(&self.version.patch, write)?;
|
||||
BinarySerializable::serialize(&self.version.index_format_version, write)?;
|
||||
BinarySerializable::serialize(&self.crc, write)?;
|
||||
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, write)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extracts the tantivy Footer from the file and returns the footer and the rest of the file
|
||||
pub fn extract_footer(file: FileSlice) -> io::Result<(Footer, FileSlice)> {
|
||||
if file.len() < 4 {
|
||||
if file.len() < FOOTER_LEN {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!(
|
||||
"File corrupted. The file is smaller than 4 bytes (len={}).",
|
||||
"File corrupted. The file is too small to contain the {FOOTER_LEN} byte \
|
||||
footer (len={}).",
|
||||
file.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let footer_metadata_len = <(u32, u32)>::SIZE_IN_BYTES;
|
||||
let (footer_len, footer_magic_byte): (u32, u32) = file
|
||||
.slice_from_end(footer_metadata_len)
|
||||
.read_bytes()?
|
||||
.as_ref()
|
||||
.deserialize()?;
|
||||
let (body_slice, footer_slice) = file.split_from_end(FOOTER_LEN);
|
||||
let footer_bytes = footer_slice.read_bytes()?;
|
||||
let mut footer_bytes = footer_bytes.as_slice();
|
||||
|
||||
let footer = Footer {
|
||||
version: Version {
|
||||
major: u32::deserialize(&mut footer_bytes)?,
|
||||
minor: u32::deserialize(&mut footer_bytes)?,
|
||||
patch: u32::deserialize(&mut footer_bytes)?,
|
||||
index_format_version: u32::deserialize(&mut footer_bytes)?,
|
||||
},
|
||||
crc: u32::deserialize(&mut footer_bytes)?,
|
||||
};
|
||||
|
||||
let footer_magic_byte = u32::deserialize(&mut footer_bytes)?;
|
||||
if footer_magic_byte != FOOTER_MAGIC_NUMBER {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
@@ -78,38 +88,12 @@ impl Footer {
|
||||
));
|
||||
}
|
||||
|
||||
if footer_len > FOOTER_MAX_LEN {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"Footer seems invalid as it suggests a footer len of {footer_len}. File is \
|
||||
corrupted, or the index was created with a different & old version of \
|
||||
tantivy."
|
||||
),
|
||||
));
|
||||
}
|
||||
let total_footer_size = footer_len as usize + footer_metadata_len;
|
||||
if file.len() < total_footer_size {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!(
|
||||
"File corrupted. The file is smaller than it's footer bytes \
|
||||
(len={total_footer_size})."
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let footer: Footer =
|
||||
serde_json::from_slice(&file.read_bytes_slice(
|
||||
file.len() - total_footer_size..file.len() - footer_metadata_len,
|
||||
)?)?;
|
||||
|
||||
let body = file.slice_to(file.len() - total_footer_size);
|
||||
Ok((footer, body))
|
||||
Ok((footer, body_slice))
|
||||
}
|
||||
|
||||
/// Confirms that the index will be read correctly by this version of tantivy
|
||||
/// Has to be called after `extract_footer` to make sure it's not accessing uninitialised memory
|
||||
#[allow(dead_code)]
|
||||
pub fn is_compatible(&self) -> Result<(), Incompatibility> {
|
||||
const SUPPORTED_INDEX_FORMAT_VERSION_RANGE: std::ops::RangeInclusive<u32> =
|
||||
INDEX_FORMAT_OLDEST_SUPPORTED_VERSION..=INDEX_FORMAT_VERSION;
|
||||
@@ -188,6 +172,10 @@ mod tests {
|
||||
fn test_deserialize_footer_missing_magic_byte() {
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
|
||||
let wrong_magic_byte: u32 = 5555;
|
||||
BinarySerializable::serialize(&wrong_magic_byte, &mut buf).unwrap();
|
||||
|
||||
@@ -205,7 +193,6 @@ mod tests {
|
||||
#[test]
|
||||
fn test_deserialize_footer_wrong_filesize() {
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
BinarySerializable::serialize(&100_u32, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, &mut buf).unwrap();
|
||||
|
||||
let owned_bytes = OwnedBytes::new(buf);
|
||||
@@ -215,27 +202,7 @@ mod tests {
|
||||
assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"File corrupted. The file is smaller than it\'s footer bytes (len=108)."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_too_large_footer() {
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
|
||||
let footer_length = super::FOOTER_MAX_LEN + 1;
|
||||
BinarySerializable::serialize(&footer_length, &mut buf).unwrap();
|
||||
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, &mut buf).unwrap();
|
||||
|
||||
let owned_bytes = OwnedBytes::new(buf);
|
||||
|
||||
let fileslice = FileSlice::new(Arc::new(owned_bytes));
|
||||
let err = Footer::extract_footer(fileslice).unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"Footer seems invalid as it suggests a footer len of 50001. File is corrupted, or the \
|
||||
index was created with a different & old version of tantivy."
|
||||
"File corrupted. The file is too small to contain the 24 byte footer (len=4)."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user