Compare commits

..

14 Commits

Author SHA1 Message Date
Pascal Seitz
2edda73656 fix name, add comment 2026-01-05 15:21:12 +01:00
Pascal Seitz
b269fd8bb8 move partitions to heap 2026-01-05 15:07:00 +01:00
Pascal Seitz
8d3a7abbe9 split subaggcache into two trait impls 2026-01-05 14:39:29 +01:00
Pascal Seitz
3ddca31292 simplify fetch block in column_block_accessor 2025-12-29 10:56:50 +01:00
Pascal Seitz
87fe3a311f share column block accessor 2025-12-12 16:54:44 +08:00
Pascal Seitz
71dc08424c add comment 2025-12-12 16:19:06 +08:00
Pascal Seitz
15913446b8 cleanup
remove clone
move data in term req, single doc opt for stats
2025-12-10 14:34:43 +08:00
Pascal Seitz
78bd3826dc remove stacktrace bloat, use &mut helper
increase cache to 2048
2025-12-08 10:35:23 +08:00
Pascal Seitz
1b56487307 specialize columntype in stats 2025-12-08 10:20:09 +08:00
Pascal Seitz
030554d544 use radix map, fix prepare_max_bucket
use paged term map in term agg
use special no sub agg term map impl
2025-12-08 10:20:09 +08:00
Pascal Seitz
c852bac532 reduce dynamic dispatch, faster term agg 2025-12-08 10:20:09 +08:00
Pascal Seitz
2ce4da8b66 one collector per agg request instead per bucket
In this refactoring a collector knows in which bucket of the parent
their data is in. This allows to convert the previous approach of one
collector per bucket to one collector per request.

low card bucket optimization
2025-12-08 10:20:09 +08:00
Pascal Seitz
0dd6a958f8 add more tests for new collection type 2025-12-08 10:20:09 +08:00
Pascal Seitz
254314a4a3 improve bench 2025-12-07 10:05:30 +08:00
190 changed files with 3889 additions and 13446 deletions

29
.github/workflows/coverage.yml vendored Normal file
View File

@@ -0,0 +1,29 @@
name: Coverage
on:
push:
branches: [main]
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
coverage:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
files: lcov.info
fail_ci_if_error: true

View File

@@ -76,9 +76,7 @@ jobs:
profile: minimal
override: true
- uses: taiki-e/install-action@v2
with:
tool: 'nextest'
- uses: taiki-e/install-action@nextest
- uses: Swatinem/rust-cache@v2
- name: Run tests

5
.gitignore vendored
View File

@@ -6,6 +6,7 @@ target
target/debug
.vscode
target/release
Cargo.lock
benchmark
.DS_Store
*.bk
@@ -14,7 +15,3 @@ trace.dat
cargo-timing*
control
variable
# for `sample record -p`
profile.json
profile.json.gz

2361
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -21,11 +21,11 @@ byteorder = "1.4.3"
crc32fast = "1.3.2"
once_cell = "1.10.0"
regex = { version = "1.5.5", default-features = false, features = [
"std",
"unicode",
"std",
"unicode",
] }
aho-corasick = "1.0"
tantivy-fst = { git = "https://github.com/paradedb/fst.git" }
tantivy-fst = "0.5"
memmap2 = { version = "0.9.0", optional = true }
lz4_flex = { version = "0.11", default-features = false, optional = true }
zstd = { version = "0.13", optional = true, default-features = false }
@@ -38,10 +38,9 @@ levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4"
rust-stemmers = "1.2.0"
tantivy-stemmers = { version = "0.4.0", default-features = false, features = ["polish_yarovoy"] }
downcast-rs = "2.0.1"
bitpacking = { version = "0.9.2", default-features = false, features = [
"bitpacker4x",
"bitpacker4x",
] }
census = "0.4.2"
rustc-hash = "2.0.0"
@@ -49,10 +48,6 @@ thiserror = "2.0.1"
htmlescape = "0.3.1"
fail = { version = "0.5.0", optional = true }
time = { version = "0.3.35", features = ["serde-well-known"] }
# TODO: We have integer wrappers with PartialOrd, and a misfeature of
# `deranged` causes inference to fail in a bunch of cases. See
# https://github.com/jhpratt/deranged/issues/18#issuecomment-2746844093
deranged = "=0.4.0"
smallvec = "1.8.0"
rayon = "1.5.2"
lru = "0.12.0"
@@ -74,7 +69,6 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
parking_lot = "0.12.4"
typetag = "0.2.21"
[target.'cfg(windows)'.dependencies]
@@ -94,7 +88,7 @@ more-asserts = "0.3.1"
rand_distr = "0.4.3"
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
"use-std",
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
@@ -141,14 +135,14 @@ compare_hash_only = ["stacker/compare_hash_only"]
[workspace]
members = [
"query-grammar",
"bitpacker",
"common",
"ownedbytes",
"stacker",
"sstable",
"tokenizer-api",
"columnar",
"query-grammar",
"bitpacker",
"common",
"ownedbytes",
"stacker",
"sstable",
"tokenizer-api",
"columnar",
]
# Following the "fail" crate best practises, we isolate

View File

@@ -54,33 +54,33 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, stats_f64);
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_few);
register!(group, terms_7);
register!(group, terms_all_unique);
register!(group, terms_many);
register!(group, terms_150_000);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_few_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status);
register!(group, terms_few_with_histogram);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_few);
register!(group, range_agg_with_term_agg_status);
register!(group, range_agg_with_term_agg_many);
register!(group, histogram);
register!(group, histogram_hard_bounds);
register!(group, histogram_with_avg_sub_agg);
register!(group, histogram_with_term_agg_few);
register!(group, histogram_with_term_agg_status);
register!(group, avg_and_range_with_avg_sub_agg);
// Filter aggregation benchmarks
@@ -159,10 +159,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_few_with_cardinality_agg(index: &Index) {
fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"terms": { "field": "text_few_terms_status" },
"aggs": {
"cardinality": {
"cardinality": {
@@ -175,13 +175,7 @@ fn terms_few_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_few(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_status(index: &Index) {
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
});
@@ -194,7 +188,7 @@ fn terms_all_unique(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_many(index: &Index) {
fn terms_150_000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
});
@@ -253,17 +247,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_few_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -276,17 +259,18 @@ fn terms_status_with_histogram(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_few_with_avg_sub_agg(index: &Index) {
fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
},
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -299,6 +283,25 @@ fn terms_status_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -354,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn range_agg_with_term_agg_few(index: &Index) {
fn range_agg_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"range": {
@@ -369,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) {
]
},
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } },
"my_texts": { "terms": { "field": "text_few_terms_status" } },
}
},
});
@@ -425,12 +428,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram_with_term_agg_few(index: &Index) {
fn histogram_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 10 },
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } }
"my_texts": { "terms": { "field": "text_few_terms_status" } }
}
}
});
@@ -475,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
// Flag to use existing index
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
if reuse_index && std::path::Path::new("agg_bench").exists() {
return Index::open_in_dir("agg_bench");
}
// crreate dir
std::fs::create_dir_all("agg_bench")?;
let mut schema_builder = Schema::builder();
let text_fieldtype = tantivy::schema::TextOptions::default()
.set_indexing_options(
@@ -486,24 +496,44 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
// Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare.
let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap();
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
} else {
Index::create_from_tempdir(schema_builder.build())?
};
// Approximate log proportions
let status_field_data = [
("INFO", 8000),
("ERROR", 300),
("WARN", 1200),
("DEBUG", 500),
("OK", 500),
("CRITICAL", 20),
("EMERGENCY", 1),
];
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
.map(|num| format!("author{num}"))
.collect::<Vec<_>>();
// Prepare 1000 unique terms sampled using a Zipf distribution.
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
{
let mut rng = StdRng::from_seed([1u8; 32]);
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -513,8 +543,12 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)];
let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)];
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
let term_1000_a = &terms_1000[idx_a];
let term_1000_b = &terms_1000[idx_b];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
@@ -524,10 +558,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
text_field_1000_terms_zipf => term_1000_b.as_str(),
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -554,8 +588,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)],
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
@@ -607,7 +641,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
"terms": { "field": "text_few_terms_status" }
}
}
}
@@ -623,7 +657,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
"terms": { "field": "text_few_terms_status" }
}
}
}

View File

@@ -11,6 +11,9 @@ keywords = []
documentation = "https://docs.rs/tantivy-bitpacker/latest/tantivy_bitpacker"
homepage = "https://github.com/quickwit-oss/tantivy"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }

View File

@@ -48,7 +48,7 @@ impl BitPacker {
pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
if self.mini_buffer_written > 0 {
let num_bytes = (self.mini_buffer_written + 7) / 8;
let num_bytes = self.mini_buffer_written.div_ceil(8);
let bytes = self.mini_buffer.to_le_bytes();
output.write_all(&bytes[..num_bytes])?;
self.mini_buffer_written = 0;
@@ -65,16 +65,10 @@ impl BitPacker {
#[derive(Clone, Debug, Default, Copy)]
pub struct BitUnpacker {
num_bits: u32,
num_bits: usize,
mask: u64,
}
pub type BlockNumber = usize;
// 16k
const BLOCK_SIZE_MIN_POW: u8 = 14;
const BLOCK_SIZE_MIN: usize = 2 << BLOCK_SIZE_MIN_POW;
impl BitUnpacker {
/// Creates a bit unpacker, that assumes the same bitwidth for all values.
///
@@ -88,9 +82,8 @@ impl BitUnpacker {
} else {
(1u64 << num_bits) - 1u64
};
BitUnpacker {
num_bits: u32::from(num_bits),
num_bits: usize::from(num_bits),
mask,
}
}
@@ -99,69 +92,16 @@ impl BitUnpacker {
self.num_bits as u8
}
/// Calculates a block number for the given `idx`.
#[inline]
pub fn block_num(&self, idx: u32) -> BlockNumber {
// Find the address in bits of the index.
let addr_in_bits = (idx * self.num_bits) as usize;
// Then round down to the nearest byte.
let addr_in_bytes = addr_in_bits >> 3;
// And compute the containing BlockNumber.
addr_in_bytes >> (BLOCK_SIZE_MIN_POW + 1)
}
/// Given a block number and dataset length, calculates a data Range for the block.
pub fn block(&self, block: BlockNumber, data_len: usize) -> Range<usize> {
let block_addr = block << (BLOCK_SIZE_MIN_POW + 1);
// We extend the end of the block by a constant factor, so that it overlaps the next
// block. That ensures that we never need to read on a block boundary.
block_addr..(std::cmp::min(block_addr + BLOCK_SIZE_MIN + 8, data_len))
}
/// Calculates the number of blocks for the given data_len.
///
/// Usually only called at startup to pre-allocate structures.
pub fn block_count(&self, data_len: usize) -> usize {
let block_count = data_len / (BLOCK_SIZE_MIN as usize);
if data_len % (BLOCK_SIZE_MIN as usize) == 0 {
block_count
} else {
block_count + 1
}
}
/// Returns a range within the data which covers the given id_range.
///
/// NOTE: This method is used for batch reads which bypass blocks to avoid dealing with block
/// boundaries.
#[inline]
pub fn block_oblivious_range(&self, id_range: Range<u32>, data_len: usize) -> Range<usize> {
let start_in_bits = id_range.start * self.num_bits;
let start = (start_in_bits >> 3) as usize;
let end_in_bits = id_range.end * self.num_bits;
let end = (end_in_bits >> 3) as usize;
// TODO: We fetch more than we need and then truncate.
start..(std::cmp::min(end + 8, data_len))
}
#[inline]
pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
self.get_from_subset(idx, 0, data)
}
/// Get the value at the given idx, which must exist within the given subset of the data.
#[inline]
pub fn get_from_subset(&self, idx: u32, data_offset: usize, data: &[u8]) -> u64 {
let addr_in_bits = idx * self.num_bits;
let addr = (addr_in_bits >> 3) as usize - data_offset;
let addr_in_bits = idx as usize * self.num_bits;
let addr = addr_in_bits >> 3;
if addr + 8 > data.len() {
if self.num_bits == 0 {
return 0;
}
let bit_shift = addr_in_bits & 7;
return self.get_slow_path(addr, bit_shift, data);
return self.get_slow_path(addr, bit_shift as u32, data);
}
let bit_shift = addr_in_bits & 7;
let bytes: [u8; 8] = (&data[addr..addr + 8]).try_into().unwrap();
@@ -173,7 +113,6 @@ impl BitUnpacker {
#[inline(never)]
fn get_slow_path(&self, addr: usize, bit_shift: u32, data: &[u8]) -> u64 {
let mut bytes: [u8; 8] = [0u8; 8];
let available_bytes = data.len() - addr;
// This function is meant to only be called if we did not have 8 bytes to load.
debug_assert!(available_bytes < 8);
@@ -189,25 +128,26 @@ impl BitUnpacker {
// #Panics
//
// This methods panics if `num_bits` is > 32.
fn get_batch_u32s(&self, start_idx: u32, data_offset: usize, data: &[u8], output: &mut [u32]) {
fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
assert!(
self.bit_width() <= 32,
"Bitwidth must be <= 32 to use this method."
);
let end_idx = start_idx + output.len() as u32;
let end_idx: u32 = start_idx + output.len() as u32;
let end_bit_read = end_idx * self.num_bits;
let end_byte_read = (end_bit_read + 7) / 8;
// We use `usize` here to avoid overflow issues.
let end_bit_read = (end_idx as usize) * self.num_bits;
let end_byte_read = end_bit_read.div_ceil(8);
assert!(
end_byte_read as usize <= data_offset + data.len(),
end_byte_read <= data.len(),
"Requested index is out of bounds."
);
// Simple slow implementation of get_batch_u32s, to deal with our ramps.
let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
for (out, idx) in output.iter_mut().zip(start_idx..) {
*out = self.get_from_subset(idx, data_offset, data) as u32;
*out = self.get(idx, data) as u32;
}
};
@@ -220,24 +160,24 @@ impl BitUnpacker {
// We want the start of the fast track to start align with bytes.
// A sufficient condition is to start with an idx that is a multiple of 8,
// so highway start is the closest multiple of 8 that is >= start_idx.
let entrance_ramp_len = 8 - (start_idx % 8) % 8;
let entrance_ramp_len: u32 = 8 - (start_idx % 8) % 8;
let highway_start: u32 = start_idx + entrance_ramp_len;
if highway_start + BitPacker1x::BLOCK_LEN as u32 > end_idx {
if highway_start + (BitPacker1x::BLOCK_LEN as u32) > end_idx {
// We don't have enough values to have even a single block of highway.
// Let's just supply the values the simple way.
get_batch_ramp(start_idx, output);
return;
}
let num_blocks: u32 = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN as u32;
let num_blocks: usize = (end_idx - highway_start) as usize / BitPacker1x::BLOCK_LEN;
// Entrance ramp
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);
// Highway
let mut offset = ((highway_start * self.num_bits) as usize / 8) - data_offset;
let mut offset = (highway_start as usize * self.num_bits) / 8;
let mut output_cursor = (highway_start - start_idx) as usize;
for _ in 0..num_blocks {
offset += BitPacker1x.decompress(
@@ -249,7 +189,7 @@ impl BitUnpacker {
}
// Exit ramp
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN as u32;
let highway_end: u32 = highway_start + (num_blocks * BitPacker1x::BLOCK_LEN) as u32;
get_batch_ramp(highway_end, &mut output[output_cursor..]);
}
@@ -259,27 +199,16 @@ impl BitUnpacker {
id_range: Range<u32>,
data: &[u8],
positions: &mut Vec<u32>,
) {
self.get_ids_for_value_range_from_subset(range, id_range, 0, data, positions)
}
pub fn get_ids_for_value_range_from_subset(
&self,
range: RangeInclusive<u64>,
id_range: Range<u32>,
data_offset: usize,
data: &[u8],
positions: &mut Vec<u32>,
) {
if self.bit_width() > 32 {
self.get_ids_for_value_range_slow(range, id_range, data_offset, data, positions)
self.get_ids_for_value_range_slow(range, id_range, data, positions)
} else {
if *range.start() > u32::MAX as u64 {
positions.clear();
return;
}
let range_u32 = (*range.start() as u32)..=(*range.end()).min(u32::MAX as u64) as u32;
self.get_ids_for_value_range_fast(range_u32, id_range, data_offset, data, positions)
self.get_ids_for_value_range_fast(range_u32, id_range, data, positions)
}
}
@@ -287,7 +216,6 @@ impl BitUnpacker {
&self,
range: RangeInclusive<u64>,
id_range: Range<u32>,
data_offset: usize,
data: &[u8],
positions: &mut Vec<u32>,
) {
@@ -295,7 +223,7 @@ impl BitUnpacker {
for i in id_range {
// If we cared we could make this branchless, but the slow implementation should rarely
// kick in.
let val = self.get_from_subset(i, data_offset, data);
let val = self.get(i, data);
if range.contains(&val) {
positions.push(i);
}
@@ -306,12 +234,11 @@ impl BitUnpacker {
&self,
value_range: RangeInclusive<u32>,
id_range: Range<u32>,
data_offset: usize,
data: &[u8],
positions: &mut Vec<u32>,
) {
positions.resize(id_range.len(), 0u32);
self.get_batch_u32s(id_range.start, data_offset, data, positions);
self.get_batch_u32s(id_range.start, data, positions);
crate::filter_vec::filter_vec_in_place(value_range, id_range.start, positions)
}
}
@@ -402,14 +329,14 @@ mod test {
fn test_get_batch_panics_over_32_bits() {
let bitunpacker = BitUnpacker::new(33);
let mut output: [u32; 1] = [0u32];
bitunpacker.get_batch_u32s(0, 0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
bitunpacker.get_batch_u32s(0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
}
#[test]
fn test_get_batch_limit() {
let bitunpacker = BitUnpacker::new(1);
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
bitunpacker.get_batch_u32s(8 * 4 - 3, 0, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
bitunpacker.get_batch_u32s(8 * 4 - 3, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
}
#[test]
@@ -418,7 +345,7 @@ mod test {
let bitunpacker = BitUnpacker::new(1);
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
// We are missing exactly one bit.
bitunpacker.get_batch_u32s(8 * 4 - 2, 0, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
bitunpacker.get_batch_u32s(8 * 4 - 2, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
}
proptest::proptest! {
@@ -441,7 +368,7 @@ mod test {
for len in [0, 1, 2, 32, 33, 34, 64] {
for start_idx in 0u32..32u32 {
output.resize(len, 0);
bitunpacker.get_batch_u32s(start_idx, 0, &buffer, &mut output);
bitunpacker.get_batch_u32s(start_idx, &buffer, &mut output);
for (i, output_byte) in output.iter().enumerate() {
let expected = (start_idx + i as u32) & mask;
assert_eq!(*output_byte, expected);

View File

@@ -16,7 +16,7 @@ stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
common = { version= "0.10", path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
serde = { version = "1.0.152", features = ["derive"] }
serde = "1.0.152"
downcast-rs = "2.0.1"
[dev-dependencies]

View File

@@ -1,6 +1,6 @@
use binggan::{InputGroup, black_box};
use common::*;
use tantivy_columnar::{Column, ValueRange};
use tantivy_columnar::Column;
pub mod common;
@@ -46,16 +46,16 @@ fn bench_group(mut runner: InputGroup<Column>) {
runner.register("access_first_vals", |column| {
let mut sum = 0;
const BLOCK_SIZE: usize = 32;
let mut docs = Vec::with_capacity(BLOCK_SIZE);
let mut buffer = Vec::with_capacity(BLOCK_SIZE);
let mut docs = vec![0; BLOCK_SIZE];
let mut buffer = vec![None; BLOCK_SIZE];
for i in (0..NUM_DOCS).step_by(BLOCK_SIZE) {
docs.clear();
// fill docs
#[allow(clippy::needless_range_loop)]
for idx in 0..BLOCK_SIZE {
docs.push(idx as u32 + i);
docs[idx] = idx as u32 + i;
}
buffer.clear();
column.first_vals_in_value_range(&mut docs, &mut buffer, ValueRange::All);
column.first_vals(&docs, &mut buffer);
for val in buffer.iter() {
let Some(val) = val else { continue };
sum += *val;

View File

@@ -40,14 +40,7 @@ fn main() {
let columnar_readers = columnar_readers.iter().collect::<Vec<_>>();
let merge_row_order = StackMergeOrder::stack(&columnar_readers[..]);
merge_columnar(
&columnar_readers,
&[],
merge_row_order.into(),
&mut out,
|| false,
)
.unwrap();
merge_columnar(&columnar_readers, &[], merge_row_order.into(), &mut out).unwrap();
Some(out.len() as u64)
},
);

View File

@@ -29,12 +29,20 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
#[inline]
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
pub fn fetch_block_with_missing(
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing else {
return;
};
// We can compare docid_cache length with docs to find missing docs
// For multi value columns we can't rely on the length and always need to scan

View File

@@ -1,7 +1,6 @@
mod dictionary_encoded;
mod serialize;
use std::cell::RefCell;
use std::fmt::{self, Debug};
use std::io::Write;
use std::ops::{Range, RangeInclusive};
@@ -20,11 +19,6 @@ use crate::column_values::monotonic_mapping::StrictlyMonotonicMappingToInternal;
use crate::column_values::{ColumnValues, monotonic_map_column};
use crate::{Cardinality, DocId, EmptyColumnValues, MonotonicallyMappableToU64, RowId};
thread_local! {
static ROWS: RefCell<Vec<RowId>> = const { RefCell::new(Vec::new()) };
static DOCS: RefCell<Vec<DocId>> = const { RefCell::new(Vec::new()) };
}
#[derive(Clone)]
pub struct Column<T = u64> {
pub index: ColumnIndex,
@@ -95,6 +89,31 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
self.values_for_doc(row_id).next()
}
/// Load the first value for each docid in the provided slice.
#[inline]
pub fn first_vals(&self, docids: &[DocId], output: &mut [Option<T>]) {
match &self.index {
ColumnIndex::Empty { .. } => {}
ColumnIndex::Full => self.values.get_vals_opt(docids, output),
ColumnIndex::Optional(optional_index) => {
for (i, docid) in docids.iter().enumerate() {
output[i] = optional_index
.rank_if_exists(*docid)
.map(|rowid| self.values.get_val(rowid));
}
}
ColumnIndex::Multivalued(multivalued_index) => {
for (i, docid) in docids.iter().enumerate() {
let range = multivalued_index.range(*docid);
let is_empty = range.start == range.end;
if !is_empty {
output[i] = Some(self.values.get_val(range.start));
}
}
}
}
}
/// Translates a block of docids to row_ids.
///
/// returns the row_ids and the matching docids on the same index
@@ -124,7 +143,7 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
#[inline]
pub fn get_docids_for_value_range(
&self,
value_range: ValueRange<T>,
value_range: RangeInclusive<T>,
selected_docid_range: Range<u32>,
doc_ids: &mut Vec<u32>,
) {
@@ -149,181 +168,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
}
}
// Separate impl block for methods requiring `Default` for `T`.
impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
/// Load the first value for each docid in the provided slice.
///
/// The `docids` vector is mutated: documents that do not match the `value_range` are removed.
/// The `values` vector is populated with the values of the remaining documents.
#[inline]
pub fn first_vals_in_value_range(
&self,
input_docs: &[DocId],
output: &mut Vec<crate::ComparableDoc<Option<T>, DocId>>,
value_range: ValueRange<T>,
) {
match (&self.index, value_range) {
(ColumnIndex::Empty { .. }, value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
if nulls_match {
for &doc in input_docs {
output.push(crate::ComparableDoc {
doc,
sort_key: None,
});
}
}
}
(ColumnIndex::Full, value_range) => {
self.values
.get_vals_in_value_range(input_docs, input_docs, output, value_range);
}
(ColumnIndex::Optional(optional_index), value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
let fallback_needed = ROWS.with(|rows_cell| {
DOCS.with(|docs_cell| {
let mut rows = rows_cell.borrow_mut();
let mut docs = docs_cell.borrow_mut();
rows.clear();
docs.clear();
let mut has_nulls = false;
for &doc_id in input_docs {
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
rows.push(row_id);
docs.push(doc_id);
} else {
has_nulls = true;
if nulls_match {
break;
}
}
}
if !has_nulls || !nulls_match {
self.values.get_vals_in_value_range(
&rows,
&docs,
output,
value_range.clone(),
);
return false;
}
true
})
});
if fallback_needed {
for &doc_id in input_docs {
if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
let val = self.values.get_val(row_id);
let value_matches = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&val),
ValueRange::GreaterThan(t, _) => val > *t,
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
ValueRange::LessThan(t, _) => val < *t,
ValueRange::LessThanOrEqual(t, _) => val <= *t,
};
if value_matches {
output.push(crate::ComparableDoc {
doc: doc_id,
sort_key: Some(val),
});
}
} else if nulls_match {
output.push(crate::ComparableDoc {
doc: doc_id,
sort_key: None,
});
}
}
}
}
(ColumnIndex::Multivalued(multivalued_index), value_range) => {
let nulls_match = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(_) => false,
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::GreaterThanOrEqual(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
ValueRange::LessThanOrEqual(_, nulls_match) => *nulls_match,
};
for i in 0..input_docs.len() {
let docid = input_docs[i];
let row_range = multivalued_index.range(docid);
let is_empty = row_range.start == row_range.end;
if !is_empty {
let val = self.values.get_val(row_range.start);
let matches = match &value_range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&val),
ValueRange::GreaterThan(t, _) => val > *t,
ValueRange::GreaterThanOrEqual(t, _) => val >= *t,
ValueRange::LessThan(t, _) => val < *t,
ValueRange::LessThanOrEqual(t, _) => val <= *t,
};
if matches {
output.push(crate::ComparableDoc {
doc: docid,
sort_key: Some(val),
});
}
} else if nulls_match {
output.push(crate::ComparableDoc {
doc: docid,
sort_key: None,
});
}
}
}
}
}
}
/// A range of values.
///
/// This type is intended to be used in batch APIs, where the cost of unpacking the enum
/// is outweighed by the time spent processing a batch.
///
/// Implementers should pattern match on the variants to use optimized loops for each case.
#[derive(Clone, Debug)]
pub enum ValueRange<T> {
/// A range that includes both start and end.
Inclusive(RangeInclusive<T>),
/// A range that matches all values.
All,
/// A range that matches all values greater than the threshold.
/// The boolean flag indicates if null values should be included.
GreaterThan(T, bool),
/// A range that matches all values greater than or equal to the threshold.
/// The boolean flag indicates if null values should be included.
GreaterThanOrEqual(T, bool),
/// A range that matches all values less than the threshold.
/// The boolean flag indicates if null values should be included.
LessThan(T, bool),
/// A range that matches all values less than or equal to the threshold.
/// The boolean flag indicates if null values should be included.
LessThanOrEqual(T, bool),
}
impl BinarySerializable for Cardinality {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> std::io::Result<()> {
self.to_code().serialize(writer)

View File

@@ -2,7 +2,7 @@ use std::io;
use std::io::Write;
use std::sync::Arc;
use common::file_slice::FileSlice;
use common::OwnedBytes;
use sstable::Dictionary;
use crate::column::{BytesColumn, Column};
@@ -41,13 +41,12 @@ pub fn serialize_column_mappable_to_u64<T: MonotonicallyMappableToU64>(
}
pub fn open_column_u64<T: MonotonicallyMappableToU64>(
file_slice: FileSlice,
bytes: OwnedBytes,
format_version: Version,
) -> io::Result<Column<T>> {
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
let column_index_num_bytes = u32::from_le_bytes(
column_index_num_bytes_payload
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
@@ -62,13 +61,12 @@ pub fn open_column_u64<T: MonotonicallyMappableToU64>(
}
pub fn open_column_u128<T: MonotonicallyMappableToU128>(
file_slice: FileSlice,
bytes: OwnedBytes,
format_version: Version,
) -> io::Result<Column<T>> {
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
let column_index_num_bytes = u32::from_le_bytes(
column_index_num_bytes_payload
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
@@ -86,13 +84,12 @@ pub fn open_column_u128<T: MonotonicallyMappableToU128>(
///
/// See [`open_u128_as_compact_u64`] for more details.
pub fn open_column_u128_as_compact_u64(
file_slice: FileSlice,
bytes: OwnedBytes,
format_version: Version,
) -> io::Result<Column<u64>> {
let (body, column_index_num_bytes_payload) = file_slice.split_from_end(4);
let (body, column_index_num_bytes_payload) = bytes.rsplit(4);
let column_index_num_bytes = u32::from_le_bytes(
column_index_num_bytes_payload
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
@@ -106,21 +103,11 @@ pub fn open_column_u128_as_compact_u64(
})
}
pub fn open_column_bytes(
file_slice: FileSlice,
format_version: Version,
) -> io::Result<BytesColumn> {
let (body, dictionary_len_bytes) = file_slice.split_from_end(4);
let dictionary_len = u32::from_le_bytes(
dictionary_len_bytes
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
);
pub fn open_column_bytes(data: OwnedBytes, format_version: Version) -> io::Result<BytesColumn> {
let (body, dictionary_len_bytes) = data.rsplit(4);
let dictionary_len = u32::from_le_bytes(dictionary_len_bytes.as_slice().try_into().unwrap());
let (dictionary_bytes, column_bytes) = body.split(dictionary_len as usize);
let dictionary = Arc::new(Dictionary::open(dictionary_bytes)?);
let dictionary = Arc::new(Dictionary::from_bytes(dictionary_bytes)?);
let term_ord_column = crate::column::open_column_u64::<u64>(column_bytes, format_version)?;
Ok(BytesColumn {
dictionary,
@@ -128,7 +115,7 @@ pub fn open_column_bytes(
})
}
pub fn open_column_str(file_slice: FileSlice, format_version: Version) -> io::Result<StrColumn> {
let bytes_column = open_column_bytes(file_slice, format_version)?;
pub fn open_column_str(data: OwnedBytes, format_version: Version) -> io::Result<StrColumn> {
let bytes_column = open_column_bytes(data, format_version)?;
Ok(StrColumn::wrap(bytes_column))
}

View File

@@ -95,7 +95,7 @@ pub fn merge_column_index<'a>(
#[cfg(test)]
mod tests {
use common::file_slice::FileSlice;
use common::OwnedBytes;
use crate::column_index::merge::detect_cardinality;
use crate::column_index::multivalued_index::{
@@ -178,7 +178,7 @@ mod tests {
let mut output = Vec::new();
serialize_multivalued_index(&start_index_iterable, &mut output).unwrap();
let multivalue =
open_multivalued_index(FileSlice::from(output), crate::Version::V2).unwrap();
open_multivalued_index(OwnedBytes::new(output), crate::Version::V2).unwrap();
let start_indexes: Vec<RowId> = multivalue.get_start_index_column().iter().collect();
assert_eq!(&start_indexes, &[0, 3, 5]);
}
@@ -216,7 +216,7 @@ mod tests {
let mut output = Vec::new();
serialize_multivalued_index(&start_index_iterable, &mut output).unwrap();
let multivalue =
open_multivalued_index(FileSlice::from(output), crate::Version::V2).unwrap();
open_multivalued_index(OwnedBytes::new(output), crate::Version::V2).unwrap();
let start_indexes: Vec<RowId> = multivalue.get_start_index_column().iter().collect();
assert_eq!(&start_indexes, &[0, 3, 5, 6]);
}

View File

@@ -3,8 +3,7 @@ use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use common::CountingWriter;
use common::file_slice::FileSlice;
use common::{CountingWriter, OwnedBytes};
use super::optional_index::{open_optional_index, serialize_optional_index};
use super::{OptionalIndex, SerializableOptionalIndex, Set};
@@ -45,26 +44,21 @@ pub fn serialize_multivalued_index(
}
pub fn open_multivalued_index(
file_slice: FileSlice,
bytes: OwnedBytes,
format_version: Version,
) -> io::Result<MultiValueIndex> {
match format_version {
Version::V1 => {
let start_index_column: Arc<dyn ColumnValues<RowId>> =
load_u64_based_column_values(file_slice)?;
load_u64_based_column_values(bytes)?;
Ok(MultiValueIndex::MultiValueIndexV1(MultiValueIndexV1 {
start_index_column,
}))
}
Version::V2 => {
let (body_bytes, optional_index_len) = file_slice.split_from_end(4);
let optional_index_len = u32::from_le_bytes(
optional_index_len
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
);
let (body_bytes, optional_index_len) = bytes.rsplit(4);
let optional_index_len =
u32::from_le_bytes(optional_index_len.as_slice().try_into().unwrap());
let (optional_index_bytes, start_index_bytes) =
body_bytes.split(optional_index_len as usize);
let optional_index = open_optional_index(optional_index_bytes)?;
@@ -191,8 +185,8 @@ impl MultiValueIndex {
};
let mut buffer = Vec::new();
serialize_multivalued_index(&serializable_multivalued_index, &mut buffer).unwrap();
let file_slice = FileSlice::from(buffer);
open_multivalued_index(file_slice, Version::V2).unwrap()
let bytes = OwnedBytes::new(buffer);
open_multivalued_index(bytes, Version::V2).unwrap()
}
pub fn get_start_index_column(&self) -> &Arc<dyn crate::ColumnValues<RowId>> {
@@ -339,7 +333,7 @@ mod tests {
use std::ops::Range;
use super::MultiValueIndex;
use crate::{ColumnarReader, DynamicColumn, ValueRange};
use crate::{ColumnarReader, DynamicColumn};
fn index_to_pos_helper(
index: &MultiValueIndex,
@@ -419,7 +413,7 @@ mod tests {
assert_eq!(row_id_range, 0..4);
let check = |range, expected| {
let full_range = ValueRange::All;
let full_range = 0..=u64::MAX;
let mut docids = Vec::new();
column.get_docids_for_value_range(full_range, range, &mut docids);
assert_eq!(docids, expected);

View File

@@ -4,7 +4,6 @@ use std::sync::Arc;
mod set;
mod set_block;
use common::file_slice::FileSlice;
use common::{BinarySerializable, OwnedBytes, VInt};
pub use set::{SelectCursor, Set, SetCodec};
use set_block::{
@@ -269,8 +268,8 @@ impl OptionalIndex {
);
let mut buffer = Vec::new();
serialize_optional_index(&row_ids, num_rows, &mut buffer).unwrap();
let file_slice = FileSlice::from(buffer);
open_optional_index(file_slice).unwrap()
let bytes = OwnedBytes::new(buffer);
open_optional_index(bytes).unwrap()
}
pub fn num_docs(&self) -> RowId {
@@ -487,17 +486,10 @@ fn deserialize_optional_index_block_metadatas(
(block_metas.into_boxed_slice(), non_null_rows_before_block)
}
pub fn open_optional_index(file_slice: FileSlice) -> io::Result<OptionalIndex> {
let (bytes, num_non_empty_blocks_bytes) = file_slice.split_from_end(2);
let num_non_empty_block_bytes = u16::from_le_bytes(
num_non_empty_blocks_bytes
.read_bytes()?
.as_slice()
.try_into()
.unwrap(),
);
let mut bytes = bytes.read_bytes()?;
pub fn open_optional_index(bytes: OwnedBytes) -> io::Result<OptionalIndex> {
let (mut bytes, num_non_empty_blocks_bytes) = bytes.rsplit(2);
let num_non_empty_block_bytes =
u16::from_le_bytes(num_non_empty_blocks_bytes.as_slice().try_into().unwrap());
let num_docs = VInt::deserialize_u64(&mut bytes)? as u32;
let block_metas_num_bytes =
num_non_empty_block_bytes as usize * SERIALIZED_BLOCK_META_NUM_BYTES;

View File

@@ -59,7 +59,7 @@ fn test_with_random_sets_simple() {
let vals = 10..ELEMENTS_PER_BLOCK * 2;
let mut out: Vec<u8> = Vec::new();
serialize_optional_index(&vals, 100, &mut out).unwrap();
let null_index = open_optional_index(FileSlice::from(out)).unwrap();
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
let ranks: Vec<u32> = (65_472u32..65_473u32).collect();
let els: Vec<u32> = ranks.iter().copied().map(|rank| rank + 10).collect();
let mut select_cursor = null_index.select_cursor();
@@ -102,7 +102,7 @@ impl<'a> Iterable<RowId> for &'a [bool] {
fn test_null_index(data: &[bool]) {
let mut out: Vec<u8> = Vec::new();
serialize_optional_index(&data, data.len() as RowId, &mut out).unwrap();
let null_index = open_optional_index(FileSlice::from(out)).unwrap();
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
let orig_idx_with_value: Vec<u32> = data
.iter()
.enumerate()
@@ -223,170 +223,3 @@ fn test_optional_index_for_tests() {
assert!(!optional_index.contains(3));
assert_eq!(optional_index.num_docs(), 4);
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::*;
const TOTAL_NUM_VALUES: u32 = 1_000_000;
fn gen_bools(fill_ratio: f64) -> OptionalIndex {
let mut out = Vec::new();
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
let vals: Vec<RowId> = (0..TOTAL_NUM_VALUES)
.map(|_| rng.gen_bool(fill_ratio))
.enumerate()
.filter(|(_pos, val)| *val)
.map(|(pos, _)| pos as RowId)
.collect();
serialize_optional_index(&&vals[..], TOTAL_NUM_VALUES, &mut out).unwrap();
open_optional_index(FileSlice::from(out)).unwrap()
}
fn random_range_iterator(
start: u32,
end: u32,
avg_step_size: u32,
avg_deviation: u32,
) -> impl Iterator<Item = u32> {
let mut rng: StdRng = StdRng::from_seed([1u8; 32]);
let mut current = start;
std::iter::from_fn(move || {
current += rng.gen_range(avg_step_size - avg_deviation..=avg_step_size + avg_deviation);
if current >= end { None } else { Some(current) }
})
}
fn n_percent_step_iterator(percent: f32, num_values: u32) -> impl Iterator<Item = u32> {
let ratio = percent / 100.0;
let step_size = (1f32 / ratio) as u32;
let deviation = step_size - 1;
random_range_iterator(0, num_values, step_size, deviation)
}
fn walk_over_data(codec: &OptionalIndex, avg_step_size: u32) -> Option<u32> {
walk_over_data_from_positions(
codec,
random_range_iterator(0, TOTAL_NUM_VALUES, avg_step_size, 0),
)
}
fn walk_over_data_from_positions(
codec: &OptionalIndex,
positions: impl Iterator<Item = u32>,
) -> Option<u32> {
let mut dense_idx: Option<u32> = None;
for idx in positions {
dense_idx = dense_idx.or(codec.rank_if_exists(idx));
}
dense_idx
}
#[bench]
fn bench_translate_orig_to_codec_1percent_filled_10percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.01f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_orig_to_codec_5percent_filled_10percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.05f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_orig_to_codec_5percent_filled_1percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.05f64);
bench.iter(|| walk_over_data(&codec, 1000));
}
#[bench]
fn bench_translate_orig_to_codec_full_scan_1percent_filled(bench: &mut Bencher) {
let codec = gen_bools(0.01f64);
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
}
#[bench]
fn bench_translate_orig_to_codec_full_scan_10percent_filled(bench: &mut Bencher) {
let codec = gen_bools(0.1f64);
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
}
#[bench]
fn bench_translate_orig_to_codec_full_scan_90percent_filled(bench: &mut Bencher) {
let codec = gen_bools(0.9f64);
bench.iter(|| walk_over_data_from_positions(&codec, 0..TOTAL_NUM_VALUES));
}
#[bench]
fn bench_translate_orig_to_codec_10percent_filled_1percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.1f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_orig_to_codec_50percent_filled_1percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.5f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_orig_to_codec_90percent_filled_1percent_hit(bench: &mut Bencher) {
let codec = gen_bools(0.9f64);
bench.iter(|| walk_over_data(&codec, 100));
}
#[bench]
fn bench_translate_codec_to_orig_1percent_filled_0comma005percent_hit(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.01f64, 0.005f32, bench);
}
#[bench]
fn bench_translate_codec_to_orig_10percent_filled_0comma005percent_hit(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.1f64, 0.005f32, bench);
}
#[bench]
fn bench_translate_codec_to_orig_1percent_filled_10percent_hit(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.01f64, 10f32, bench);
}
#[bench]
fn bench_translate_codec_to_orig_1percent_filled_full_scan(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.01f64, 100f32, bench);
}
fn bench_translate_codec_to_orig_util(
percent_filled: f64,
percent_hit: f32,
bench: &mut Bencher,
) {
let codec = gen_bools(percent_filled);
let num_non_nulls = codec.num_non_nulls();
let idxs: Vec<u32> = if percent_hit == 100.0f32 {
(0..num_non_nulls).collect()
} else {
n_percent_step_iterator(percent_hit, num_non_nulls).collect()
};
let mut output = vec![0u32; idxs.len()];
bench.iter(|| {
output.copy_from_slice(&idxs[..]);
codec.select_batch(&mut output);
});
}
#[bench]
fn bench_translate_codec_to_orig_90percent_filled_0comma005percent_hit(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.9f64, 0.005, bench);
}
#[bench]
fn bench_translate_codec_to_orig_90percent_filled_full_scan(bench: &mut Bencher) {
bench_translate_codec_to_orig_util(0.9f64, 100.0f32, bench);
}
}

View File

@@ -1,8 +1,7 @@
use std::io;
use std::io::Write;
use common::file_slice::FileSlice;
use common::{CountingWriter, HasLen};
use common::{CountingWriter, OwnedBytes};
use super::OptionalIndex;
use super::multivalued_index::SerializableMultivalueIndex;
@@ -66,28 +65,27 @@ pub fn serialize_column_index(
/// Open a serialized column index.
pub fn open_column_index(
file_slice: FileSlice,
mut bytes: OwnedBytes,
format_version: Version,
) -> io::Result<ColumnIndex> {
if file_slice.len() == 0 {
if bytes.is_empty() {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Failed to deserialize column index. Empty buffer.",
));
}
let (header, body) = file_slice.split(1);
let cardinality_code = header.read_bytes()?.as_slice()[0];
let cardinality_code = bytes[0];
let cardinality = Cardinality::try_from_code(cardinality_code)?;
bytes.advance(1);
match cardinality {
Cardinality::Full => Ok(ColumnIndex::Full),
Cardinality::Optional => {
let optional_index = super::optional_index::open_optional_index(body)?;
let optional_index = super::optional_index::open_optional_index(bytes)?;
Ok(ColumnIndex::Optional(optional_index))
}
Cardinality::Multivalued => {
let multivalue_index =
super::multivalued_index::open_multivalued_index(body, format_version)?;
super::multivalued_index::open_multivalued_index(bytes, format_version)?;
Ok(ColumnIndex::Multivalued(multivalue_index))
}
}

View File

@@ -7,15 +7,13 @@
//! - Monotonically map values to u64/u128
use std::fmt::Debug;
use std::ops::Range;
use std::ops::{Range, RangeInclusive};
use std::sync::Arc;
use downcast_rs::DowncastSync;
pub use monotonic_mapping::{MonotonicallyMappableToU64, StrictlyMonotonicFn};
pub use monotonic_mapping_u128::MonotonicallyMappableToU128;
use crate::column::ValueRange;
mod merge;
pub(crate) mod monotonic_mapping;
pub(crate) mod monotonic_mapping_u128;
@@ -29,7 +27,8 @@ mod monotonic_column;
pub(crate) use merge::MergedColumnValues;
pub use stats::ColumnStats;
pub use u64_based::{
ALL_U64_CODEC_TYPES, CodecType, load_u64_based_column_values, serialize_u64_based_column_values,
ALL_U64_CODEC_TYPES, CodecType, load_u64_based_column_values,
serialize_and_load_u64_based_column_values, serialize_u64_based_column_values,
};
pub use u128_based::{
CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
@@ -110,307 +109,6 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
}
}
/// Load the values for the provided docids.
///
/// The values are filtered by the provided value range.
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
let len = input_indexes.len();
let mut read_head = 0;
match value_range {
ValueRange::All => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
read_head += 4;
}
}
ValueRange::Inclusive(ref range) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if range.contains(&val0) {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if range.contains(&val1) {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if range.contains(&val2) {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if range.contains(&val3) {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::GreaterThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 > *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 > *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 > *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 > *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::GreaterThanOrEqual(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 >= *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::LessThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 < *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 < *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 < *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 < *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
ValueRange::LessThanOrEqual(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let doc0 = input_doc_ids[read_head];
let doc1 = input_doc_ids[read_head + 1];
let doc2 = input_doc_ids[read_head + 2];
let doc3 = input_doc_ids[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc0,
sort_key: Some(val0),
});
}
if val1 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc1,
sort_key: Some(val1),
});
}
if val2 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc2,
sort_key: Some(val2),
});
}
if val3 <= *threshold {
output.push(crate::ComparableDoc {
doc: doc3,
sort_key: Some(val3),
});
}
read_head += 4;
}
}
}
// Process remaining elements (0 to 3)
while read_head < len {
let idx = input_indexes[read_head];
let doc = input_doc_ids[read_head];
let val = self.get_val(idx);
let matches = match value_range {
// 'value_range' is still moved here. This is the outer `value_range`
ValueRange::All => true,
ValueRange::Inclusive(ref r) => r.contains(&val),
ValueRange::GreaterThan(ref t, _) => val > *t,
ValueRange::GreaterThanOrEqual(ref t, _) => val >= *t,
ValueRange::LessThan(ref t, _) => val < *t,
ValueRange::LessThanOrEqual(ref t, _) => val <= *t,
};
if matches {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(val),
});
}
read_head += 1;
}
}
/// Fills an output buffer with the fast field values
/// associated with the `DocId` going from
/// `start` to `start + output.len()`.
@@ -431,54 +129,15 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
/// Note that position == docid for single value fast fields
fn get_row_ids_for_value_range(
&self,
value_range: ValueRange<T>,
value_range: RangeInclusive<T>,
row_id_range: Range<RowId>,
row_id_hits: &mut Vec<RowId>,
) {
let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals());
match value_range {
ValueRange::Inclusive(range) => {
for idx in row_id_range {
let val = self.get_val(idx);
if range.contains(&val) {
row_id_hits.push(idx);
}
}
}
ValueRange::GreaterThan(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val > threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val >= threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::LessThan(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val < threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::LessThanOrEqual(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val <= threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::All => {
row_id_hits.extend(row_id_range);
for idx in row_id_range {
let val = self.get_val(idx);
if value_range.contains(&val) {
row_id_hits.push(idx);
}
}
}
@@ -534,17 +193,6 @@ impl<T: PartialOrd + Default> ColumnValues<T> for EmptyColumnValues {
fn num_vals(&self) -> u32 {
0
}
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
let _ = (input_indexes, input_doc_ids, output, value_range);
panic!("Internal Error: Called get_vals_in_value_range of empty column.")
}
}
impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnValues<T>> {
@@ -558,18 +206,6 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
self.as_ref().get_vals_opt(indexes, output)
}
#[inline(always)]
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
self.as_ref()
.get_vals_in_value_range(input_indexes, input_doc_ids, output, value_range)
}
#[inline(always)]
fn min_value(&self) -> T {
self.as_ref().min_value()
@@ -598,7 +234,7 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
#[inline(always)]
fn get_row_ids_for_value_range(
&self,
range: ValueRange<T>,
range: RangeInclusive<T>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {

View File

@@ -1,9 +1,8 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::Range;
use std::ops::{Range, RangeInclusive};
use crate::ColumnValues;
use crate::column::ValueRange;
use crate::column_values::monotonic_mapping::StrictlyMonotonicFn;
struct MonotonicMappingColumn<C, T, Input> {
@@ -81,52 +80,16 @@ where
fn get_row_ids_for_value_range(
&self,
range: ValueRange<Output>,
range: RangeInclusive<Output>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
match range {
ValueRange::Inclusive(range) => self.from_column.get_row_ids_for_value_range(
ValueRange::Inclusive(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
),
doc_id_range,
positions,
),
ValueRange::All => self.from_column.get_row_ids_for_value_range(
ValueRange::All,
doc_id_range,
positions,
),
ValueRange::GreaterThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
ValueRange::GreaterThan(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
),
ValueRange::GreaterThanOrEqual(threshold, _) => {
self.from_column.get_row_ids_for_value_range(
ValueRange::GreaterThanOrEqual(
self.monotonic_mapping.inverse(threshold),
false,
),
doc_id_range,
positions,
)
}
ValueRange::LessThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
ValueRange::LessThan(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
),
ValueRange::LessThanOrEqual(threshold, _) => {
self.from_column.get_row_ids_for_value_range(
ValueRange::LessThanOrEqual(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
)
}
}
self.from_column.get_row_ids_for_value_range(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
doc_id_range,
positions,
)
}
// We voluntarily do not implement get_range as it yields a regression,

View File

@@ -2,8 +2,7 @@ use std::io;
use std::io::Write;
use std::num::NonZeroU64;
use common::file_slice::FileSlice;
use common::{BinarySerializable, HasLen, VInt};
use common::{BinarySerializable, VInt};
use crate::RowId;
@@ -28,55 +27,6 @@ impl ColumnStats {
}
}
impl ColumnStats {
/// Deserialize from the tail of the given FileSlice, and return the stats and remaining prefix
/// FileSlice.
pub fn deserialize_from_tail(file_slice: FileSlice) -> io::Result<(Self, FileSlice)> {
// [`deserialize_with_size`] deserializes 4 variable-width encoded u64s, which
// could end up being, in the worst case, 9 bytes each. this is where the 36 comes from
let (stats, _) = file_slice.clone().split(36.min(file_slice.len())); // hope that's enough bytes
let mut stats = stats.read_bytes()?;
let (stats, stats_nbytes) = ColumnStats::deserialize_with_size(&mut stats)?;
let (_, remainder) = file_slice.split(stats_nbytes);
Ok((stats, remainder))
}
/// Same as [`BinarySeerializable::deserialize`] but also returns the number of bytes
/// consumed from the reader `R`
fn deserialize_with_size<R: io::Read>(reader: &mut R) -> io::Result<(Self, usize)> {
let mut nbytes = 0;
let (min_value, len) = VInt::deserialize_with_size(reader)?;
let min_value = min_value.0;
nbytes += len;
let (gcd, len) = VInt::deserialize_with_size(reader)?;
let gcd = gcd.0;
let gcd = NonZeroU64::new(gcd)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "GCD of 0 is forbidden"))?;
nbytes += len;
let (amplitude, len) = VInt::deserialize_with_size(reader)?;
let amplitude = amplitude.0 * gcd.get();
let max_value = min_value + amplitude;
nbytes += len;
let (num_rows, len) = VInt::deserialize_with_size(reader)?;
let num_rows = num_rows.0 as RowId;
nbytes += len;
Ok((
ColumnStats {
min_value,
max_value,
num_rows,
gcd,
},
nbytes,
))
}
}
impl BinarySerializable for ColumnStats {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
VInt(self.min_value).serialize(writer)?;

View File

@@ -25,7 +25,6 @@ use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128};
use tantivy_bitpacker::{BitPacker, BitUnpacker};
use crate::RowId;
use crate::column::ValueRange;
use crate::column_values::ColumnValues;
/// The cost per blank is quite hard actually, since blanks are delta encoded, the actual cost of
@@ -339,48 +338,14 @@ impl ColumnValues<u64> for CompactSpaceU64Accessor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: ValueRange<u64>,
value_range: RangeInclusive<u64>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
match value_range {
ValueRange::Inclusive(value_range) => {
let value_range = ValueRange::Inclusive(
self.0.compact_to_u128(*value_range.start() as u32)
..=self.0.compact_to_u128(*value_range.end() as u32),
);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::All => {
let position_range = position_range.start..position_range.end.min(self.num_vals());
positions.extend(position_range);
}
ValueRange::GreaterThan(threshold, _) => {
let value_range =
ValueRange::GreaterThan(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
let value_range =
ValueRange::GreaterThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::LessThan(threshold, _) => {
let value_range =
ValueRange::LessThan(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::LessThanOrEqual(threshold, _) => {
let value_range =
ValueRange::LessThanOrEqual(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
}
let value_range = self.0.compact_to_u128(*value_range.start() as u32)
..=self.0.compact_to_u128(*value_range.end() as u32);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
}
@@ -410,47 +375,10 @@ impl ColumnValues<u128> for CompactSpaceDecompressor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: ValueRange<u128>,
value_range: RangeInclusive<u128>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let value_range = match value_range {
ValueRange::Inclusive(value_range) => value_range,
ValueRange::All => {
let position_range = position_range.start..position_range.end.min(self.num_vals());
positions.extend(position_range);
return;
}
ValueRange::GreaterThan(threshold, _) => {
let max = self.max_value();
if threshold >= max {
return;
}
(threshold + 1)..=max
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
let max = self.max_value();
if threshold > max {
return;
}
threshold..=max
}
ValueRange::LessThan(threshold, _) => {
let min = self.min_value();
if threshold <= min {
return;
}
min..=(threshold - 1)
}
ValueRange::LessThanOrEqual(threshold, _) => {
let min = self.min_value();
if threshold < min {
return;
}
min..=threshold
}
};
if value_range.start() > value_range.end() {
return;
}
@@ -632,7 +560,7 @@ mod tests {
.collect::<Vec<_>>();
let mut positions = Vec::new();
decompressor.get_row_ids_for_value_range(
ValueRange::Inclusive(range),
range,
0..decompressor.num_vals(),
&mut positions,
);
@@ -676,11 +604,7 @@ mod tests {
let val = *val;
let pos = pos as u32;
let mut positions = Vec::new();
decomp.get_row_ids_for_value_range(
ValueRange::Inclusive(val..=val),
pos..pos + 1,
&mut positions,
);
decomp.get_row_ids_for_value_range(val..=val, pos..pos + 1, &mut positions);
assert_eq!(positions, vec![pos]);
}
@@ -822,11 +746,7 @@ mod tests {
doc_id_range: Range<u32>,
) -> Vec<u32> {
let mut positions = Vec::new();
column.get_row_ids_for_value_range(
ValueRange::Inclusive(value_range),
doc_id_range,
&mut positions,
);
column.get_row_ids_for_value_range(value_range, doc_id_range, &mut positions);
positions
}
@@ -849,7 +769,7 @@ mod tests {
];
let mut out = Vec::new();
serialize_column_values_u128(&&vals[..], &mut out).unwrap();
let decomp = open_u128_mapped(FileSlice::from(out)).unwrap();
let decomp = open_u128_mapped(OwnedBytes::new(out)).unwrap();
let complete_range = 0..vals.len() as u32;
assert_eq!(
@@ -903,7 +823,6 @@ mod tests {
let _data = test_aux_vals(vals);
}
use common::file_slice::FileSlice;
use proptest::prelude::*;
fn num_strategy() -> impl Strategy<Value = u128> {

View File

@@ -5,8 +5,7 @@ use std::sync::Arc;
mod compact_space;
use common::file_slice::FileSlice;
use common::{BinarySerializable, VInt};
use common::{BinarySerializable, OwnedBytes, VInt};
pub use compact_space::{
CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
};
@@ -102,9 +101,8 @@ impl U128FastFieldCodecType {
/// Returns the correct codec reader wrapped in the `Arc` for the data.
pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
file_slice: FileSlice,
mut bytes: OwnedBytes,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
let mut bytes = file_slice.read_bytes()?;
let header = U128Header::deserialize(&mut bytes)?;
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
let reader = CompactSpaceDecompressor::open(bytes)?;
@@ -122,8 +120,7 @@ pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
/// # Notice
/// In case there are new codecs added, check for usages of `CompactSpaceDecompressorU64` and
/// also handle the new codecs.
pub fn open_u128_as_compact_u64(file_slice: FileSlice) -> io::Result<Arc<dyn ColumnValues<u64>>> {
let mut bytes = file_slice.read_bytes()?;
pub fn open_u128_as_compact_u64(mut bytes: OwnedBytes) -> io::Result<Arc<dyn ColumnValues<u64>>> {
let header = U128Header::deserialize(&mut bytes)?;
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
let reader = CompactSpaceU64Accessor::open(bytes)?;

View File

@@ -1,14 +1,11 @@
use std::io::{self, Write};
use std::num::NonZeroU64;
use std::ops::{Range, RangeInclusive};
use std::sync::{Arc, OnceLock};
use common::file_slice::FileSlice;
use common::{BinarySerializable, HasLen, OwnedBytes};
use common::{BinarySerializable, OwnedBytes};
use fastdivide::DividerU64;
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
use crate::column::ValueRange;
use crate::column_values::u64_based::{ColumnCodec, ColumnCodecEstimator, ColumnStats};
use crate::{ColumnValues, RowId};
@@ -16,40 +13,9 @@ use crate::{ColumnValues, RowId};
/// fast field is required.
#[derive(Clone)]
pub struct BitpackedReader {
data: FileSlice,
data: OwnedBytes,
bit_unpacker: BitUnpacker,
stats: ColumnStats,
blocks: Arc<[OnceLock<Block>]>,
}
impl BitpackedReader {
#[inline(always)]
fn unpack_val(&self, doc: u32) -> u64 {
let block_num = self.bit_unpacker.block_num(doc);
if block_num == 0 && self.blocks.len() == 0 {
return 0;
}
let block = self.blocks[block_num].get_or_init(|| {
let block_range = self.bit_unpacker.block(block_num, self.data.len());
let offset = block_range.start;
let data = self
.data
.slice(block_range)
.read_bytes()
.expect("Failed to read column values.");
Block { offset, data }
});
self.bit_unpacker
.get_from_subset(doc, block.offset, &block.data)
}
}
struct Block {
offset: usize,
data: OwnedBytes,
}
#[inline(always)]
@@ -91,9 +57,8 @@ fn transform_range_before_linear_transformation(
impl ColumnValues for BitpackedReader {
#[inline(always)]
fn get_val(&self, doc: u32) -> u64 {
self.stats.min_value + self.stats.gcd.get() * self.unpack_val(doc)
self.stats.min_value + self.stats.gcd.get() * self.bit_unpacker.get(doc, &self.data)
}
#[inline]
fn min_value(&self) -> u64 {
self.stats.min_value
@@ -107,329 +72,24 @@ impl ColumnValues for BitpackedReader {
self.stats.num_rows
}
fn get_vals_in_value_range(
&self,
input_indexes: &[u32],
input_doc_ids: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<u64>, crate::DocId>>,
value_range: ValueRange<u64>,
) {
match value_range {
ValueRange::All => {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
}
ValueRange::Inclusive(range) => {
if let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
{
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if transformed_range.contains(&raw_val) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold >= self.stats.max_value {
// All filtered out
} else {
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if raw_val > raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
if threshold <= self.stats.min_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold > self.stats.max_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = (diff + gcd - 1) / gcd;
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if raw_val >= raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold <= self.stats.min_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = if diff % gcd == 0 {
diff / gcd
} else {
diff / gcd + 1
};
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if raw_val < raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::LessThanOrEqual(threshold, _) => {
if threshold >= self.stats.max_value {
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold < self.stats.min_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = diff / gcd;
for (&idx, &doc) in input_indexes.iter().zip(input_doc_ids.iter()) {
let raw_val = self.unpack_val(idx);
if raw_val <= raw_threshold {
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
}
}
fn get_row_ids_for_value_range(
&self,
range: ValueRange<u64>,
range: RangeInclusive<u64>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
match range {
ValueRange::All => {
positions.extend(doc_id_range);
return;
}
ValueRange::Inclusive(range) => {
let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
// TODO: This does not use the `self.blocks` cache, because callers are usually
// already doing sequential, and fairly dense reads. Fix it to
// iterate over blocks if that assumption turns out to be incorrect!
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
positions.extend(doc_id_range);
return;
}
if threshold >= self.stats.max_value {
return;
}
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
// We want raw > raw_threshold.
// bit_unpacker.get_ids_for_value_range_from_subset takes a RangeInclusive.
// We can construct a RangeInclusive: (raw_threshold + 1) ..= u64::MAX
// But max raw value is known? (max_value - min_value) / gcd.
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
let transformed_range = (raw_threshold + 1)..=max_raw;
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
ValueRange::GreaterThanOrEqual(threshold, _) => {
if threshold <= self.stats.min_value {
positions.extend(doc_id_range);
return;
}
if threshold > self.stats.max_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = (diff + gcd - 1) / gcd;
// We want raw >= raw_threshold.
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
let transformed_range = raw_threshold..=max_raw;
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
positions.extend(doc_id_range);
return;
}
if threshold <= self.stats.min_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
// We want raw < raw_threshold_limit
// raw <= raw_threshold_limit - 1
let raw_threshold_limit = if diff % gcd == 0 {
diff / gcd
} else {
diff / gcd + 1
};
if raw_threshold_limit == 0 {
return;
}
let transformed_range = 0..=(raw_threshold_limit - 1);
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
ValueRange::LessThanOrEqual(threshold, _) => {
if threshold >= self.stats.max_value {
positions.extend(doc_id_range);
return;
}
if threshold < self.stats.min_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
// We want raw <= raw_threshold.
let raw_threshold = diff / gcd;
let transformed_range = 0..=raw_threshold;
let data_range = self
.bit_unpacker
.block_oblivious_range(doc_id_range.clone(), self.data.len());
let data_offset = data_range.start;
let data_subset = self
.data
.slice(data_range)
.read_bytes()
.expect("Failed to read column values.");
self.bit_unpacker.get_ids_for_value_range_from_subset(
transformed_range,
doc_id_range,
data_offset,
&data_subset,
positions,
);
}
}
let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
}
@@ -473,20 +133,14 @@ impl ColumnCodec for BitpackedCodec {
type Estimator = BitpackedCodecEstimator;
/// Opens a fast field given a file.
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
let (stats, data) = ColumnStats::deserialize_from_tail(file_slice)?;
fn load(mut data: OwnedBytes) -> io::Result<Self::ColumnValues> {
let stats = ColumnStats::deserialize(&mut data)?;
let num_bits = num_bits(&stats);
let bit_unpacker = BitUnpacker::new(num_bits);
let block_count = bit_unpacker.block_count(data.len());
Ok(BitpackedReader {
data,
bit_unpacker,
stats,
blocks: (0..block_count)
.into_iter()
.map(|_| OnceLock::new())
.collect(),
})
}
}

View File

@@ -1,10 +1,8 @@
use std::io;
use std::io::Write;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, OnceLock};
use std::sync::Arc;
use std::{io, iter};
use common::file_slice::FileSlice;
use common::{BinarySerializable, CountingWriter, DeserializeFrom, HasLen, OwnedBytes};
use common::{BinarySerializable, CountingWriter, DeserializeFrom, OwnedBytes};
use fastdivide::DividerU64;
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
@@ -174,63 +172,32 @@ impl ColumnCodec<u64> for BlockwiseLinearCodec {
type Estimator = BlockwiseLinearEstimator;
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
let (stats, body) = ColumnStats::deserialize_from_tail(file_slice)?;
let (_, footer) = body.clone().split_from_end(4);
let footer_len: u32 = footer.read_bytes()?.as_slice().deserialize()?;
let (data, footer) = body.split_from_end(footer_len as usize + 4);
let mut footer = footer.read_bytes()?;
fn load(mut bytes: OwnedBytes) -> io::Result<Self::ColumnValues> {
let stats = ColumnStats::deserialize(&mut bytes)?;
let footer_len: u32 = (&bytes[bytes.len() - 4..]).deserialize()?;
let footer_offset = bytes.len() - 4 - footer_len as usize;
let (data, mut footer) = bytes.split(footer_offset);
let num_blocks = compute_num_blocks(stats.num_rows);
let mut blocks: Vec<Block> = iter::repeat_with(|| Block::deserialize(&mut footer))
.take(num_blocks as usize)
.collect::<io::Result<_>>()?;
let mut start_offset = 0;
let mut blocks = Vec::with_capacity(num_blocks as usize);
for _ in 0..num_blocks {
let mut block = Block::deserialize(&mut footer)?;
let len = (block.bit_unpacker.bit_width() as usize) * BLOCK_SIZE as usize / 8;
for block in &mut blocks {
block.data_start_offset = start_offset;
blocks.push(BlockWithData {
block,
file_slice: data.slice(start_offset..(start_offset + len).min(data.len())),
data: Default::default(),
});
start_offset += len;
start_offset += (block.bit_unpacker.bit_width() as usize) * BLOCK_SIZE as usize / 8;
}
Ok(BlockwiseLinearReader {
blocks: blocks.into_boxed_slice().into(),
data,
stats,
})
}
}
struct BlockWithData {
block: Block,
file_slice: FileSlice,
data: OnceLock<OwnedBytes>,
}
impl Deref for BlockWithData {
type Target = Block;
fn deref(&self) -> &Self::Target {
&self.block
}
}
impl DerefMut for BlockWithData {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.block
}
}
#[derive(Clone)]
pub struct BlockwiseLinearReader {
blocks: Arc<[BlockWithData]>,
blocks: Arc<[Block]>,
data: OwnedBytes,
stats: ColumnStats,
}
@@ -241,9 +208,7 @@ impl ColumnValues for BlockwiseLinearReader {
let idx_within_block = idx % BLOCK_SIZE;
let block = &self.blocks[block_id];
let interpoled_val: u64 = block.line.eval(idx_within_block);
let block_bytes = block
.data
.get_or_init(|| block.file_slice.read_bytes().unwrap());
let block_bytes = &self.data[block.data_start_offset..];
let bitpacked_diff = block.bit_unpacker.get(idx_within_block, block_bytes);
// TODO optimize me! the line parameters could be tweaked to include the multiplication and
// remove the dependency.

View File

@@ -1,6 +1,5 @@
use std::io;
use common::file_slice::FileSlice;
use common::{BinarySerializable, OwnedBytes};
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
@@ -191,8 +190,7 @@ impl ColumnCodec for LinearCodec {
type Estimator = LinearCodecEstimator;
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues> {
let mut data = file_slice.read_bytes()?;
fn load(mut data: OwnedBytes) -> io::Result<Self::ColumnValues> {
let stats = ColumnStats::deserialize(&mut data)?;
let linear_params = LinearParams::deserialize(&mut data)?;
Ok(LinearReader {

View File

@@ -8,8 +8,7 @@ use std::io;
use std::io::Write;
use std::sync::Arc;
use common::BinarySerializable;
use common::file_slice::FileSlice;
use common::{BinarySerializable, OwnedBytes};
use crate::column_values::monotonic_mapping::{
StrictlyMonotonicMappingInverter, StrictlyMonotonicMappingToInternal,
@@ -61,7 +60,7 @@ pub trait ColumnCodec<T: PartialOrd = u64> {
type Estimator: ColumnCodecEstimator + Default;
/// Loads a column that has been serialized using this codec.
fn load(file_slice: FileSlice) -> io::Result<Self::ColumnValues>;
fn load(bytes: OwnedBytes) -> io::Result<Self::ColumnValues>;
/// Returns an estimator.
fn estimator() -> Self::Estimator {
@@ -112,22 +111,20 @@ impl CodecType {
fn load<T: MonotonicallyMappableToU64>(
&self,
file_slice: FileSlice,
bytes: OwnedBytes,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
match self {
CodecType::Bitpacked => load_specific_codec::<BitpackedCodec, T>(file_slice),
CodecType::Linear => load_specific_codec::<LinearCodec, T>(file_slice),
CodecType::BlockwiseLinear => {
load_specific_codec::<BlockwiseLinearCodec, T>(file_slice)
}
CodecType::Bitpacked => load_specific_codec::<BitpackedCodec, T>(bytes),
CodecType::Linear => load_specific_codec::<LinearCodec, T>(bytes),
CodecType::BlockwiseLinear => load_specific_codec::<BlockwiseLinearCodec, T>(bytes),
}
}
}
fn load_specific_codec<C: ColumnCodec, T: MonotonicallyMappableToU64>(
file_slice: FileSlice,
bytes: OwnedBytes,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
let reader = C::load(file_slice)?;
let reader = C::load(bytes)?;
let reader_typed = monotonic_map_column(
reader,
StrictlyMonotonicMappingInverter::from(StrictlyMonotonicMappingToInternal::<T>::new()),
@@ -192,28 +189,25 @@ pub fn serialize_u64_based_column_values<T: MonotonicallyMappableToU64>(
///
/// This method first identifies the codec off the first byte.
pub fn load_u64_based_column_values<T: MonotonicallyMappableToU64>(
file_slice: FileSlice,
mut bytes: OwnedBytes,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
let (header, body) = file_slice.split(1);
let codec_type: CodecType = header
.read_bytes()?
.as_slice()
.get(0)
.cloned()
let codec_type: CodecType = bytes
.first()
.copied()
.and_then(CodecType::try_from_code)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Failed to read codec type"))?;
codec_type.load(body)
bytes.advance(1);
codec_type.load(bytes)
}
/// Helper function to serialize a column (autodetect from all codecs) and then open it
#[cfg(test)]
pub fn serialize_and_load_u64_based_column_values<T: MonotonicallyMappableToU64>(
vals: &dyn Iterable,
codec_types: &[CodecType],
) -> Arc<dyn ColumnValues<T>> {
let mut buffer = Vec::new();
serialize_u64_based_column_values(vals, codec_types, &mut buffer).unwrap();
load_u64_based_column_values::<T>(FileSlice::from(buffer)).unwrap()
load_u64_based_column_values::<T>(OwnedBytes::new(buffer)).unwrap()
}
#[cfg(test)]

View File

@@ -1,4 +1,3 @@
use common::HasLen;
use proptest::prelude::*;
use proptest::{prop_oneof, proptest};
use rand::Rng;
@@ -14,7 +13,7 @@ fn test_serialize_and_load_simple() {
)
.unwrap();
assert_eq!(buffer.len(), 7);
let col = load_u64_based_column_values::<u64>(FileSlice::from(buffer)).unwrap();
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
assert_eq!(col.num_vals(), 3);
assert_eq!(col.get_val(0), 1);
assert_eq!(col.get_val(1), 2);
@@ -31,7 +30,7 @@ fn test_empty_column_i64() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<i64>(FileSlice::from(buffer)).unwrap();
let col = load_u64_based_column_values::<i64>(OwnedBytes::new(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
assert_eq!(col.min_value(), i64::MIN);
assert_eq!(col.max_value(), i64::MIN);
@@ -49,7 +48,7 @@ fn test_empty_column_u64() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<u64>(FileSlice::from(buffer)).unwrap();
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
assert_eq!(col.min_value(), u64::MIN);
assert_eq!(col.max_value(), u64::MIN);
@@ -67,7 +66,7 @@ fn test_empty_column_f64() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<f64>(FileSlice::from(buffer)).unwrap();
let col = load_u64_based_column_values::<f64>(OwnedBytes::new(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
// FIXME. f64::MIN would be better!
assert!(col.min_value().is_nan());
@@ -98,7 +97,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
let actual_compression = buffer.len() as u64;
let reader = TColumnCodec::load(FileSlice::from(buffer)).unwrap();
let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
assert_eq!(reader.num_vals(), vals.len() as u32);
let mut buffer = Vec::new();
for (doc, orig_val) in vals.iter().copied().enumerate() {
@@ -132,7 +131,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
.collect();
let mut positions = Vec::new();
reader.get_row_ids_for_value_range(
crate::column::ValueRange::Inclusive(vals[test_rand_idx]..=vals[test_rand_idx]),
vals[test_rand_idx]..=vals[test_rand_idx],
0..vals.len() as u32,
&mut positions,
);
@@ -327,7 +326,7 @@ fn test_fastfield_gcd_i64_with_codec(codec_type: CodecType, num_vals: usize) ->
&[codec_type],
&mut buffer,
)?;
let buffer = FileSlice::from(buffer);
let buffer = OwnedBytes::new(buffer);
let column = crate::column_values::load_u64_based_column_values::<i64>(buffer.clone())?;
assert_eq!(column.get_val(0), -4000i64);
assert_eq!(column.get_val(1), -3000i64);
@@ -344,7 +343,7 @@ fn test_fastfield_gcd_i64_with_codec(codec_type: CodecType, num_vals: usize) ->
&[codec_type],
&mut buffer_without_gcd,
)?;
let buffer_without_gcd = FileSlice::from(buffer_without_gcd);
let buffer_without_gcd = OwnedBytes::new(buffer_without_gcd);
assert!(buffer_without_gcd.len() > buffer.len());
Ok(())
@@ -370,7 +369,7 @@ fn test_fastfield_gcd_u64_with_codec(codec_type: CodecType, num_vals: usize) ->
&[codec_type],
&mut buffer,
)?;
let buffer = FileSlice::from(buffer);
let buffer = OwnedBytes::new(buffer);
let column = crate::column_values::load_u64_based_column_values::<u64>(buffer.clone())?;
assert_eq!(column.get_val(0), 1000u64);
assert_eq!(column.get_val(1), 2000u64);
@@ -387,7 +386,7 @@ fn test_fastfield_gcd_u64_with_codec(codec_type: CodecType, num_vals: usize) ->
&[codec_type],
&mut buffer_without_gcd,
)?;
let buffer_without_gcd = FileSlice::from(buffer_without_gcd);
let buffer_without_gcd = OwnedBytes::new(buffer_without_gcd);
assert!(buffer_without_gcd.len() > buffer.len());
Ok(())
}
@@ -406,7 +405,7 @@ fn test_fastfield_gcd_u64() -> io::Result<()> {
#[test]
pub fn test_fastfield2() {
let test_fastfield = serialize_and_load_u64_based_column_values::<u64>(
let test_fastfield = crate::column_values::serialize_and_load_u64_based_column_values::<u64>(
&&[100u64, 200u64, 300u64][..],
&ALL_U64_CODEC_TYPES,
);

View File

@@ -4,7 +4,6 @@ mod term_merger;
use std::collections::{BTreeMap, HashSet};
use std::io;
use std::io::ErrorKind;
use std::net::Ipv6Addr;
use std::sync::Arc;
@@ -79,7 +78,6 @@ pub fn merge_columnar(
required_columns: &[(String, ColumnType)],
merge_row_order: MergeRowOrder,
output: &mut impl io::Write,
cancel: impl Fn() -> bool,
) -> io::Result<()> {
let mut serializer = ColumnarSerializer::new(output);
let num_docs_per_columnar = columnar_readers
@@ -89,9 +87,6 @@ pub fn merge_columnar(
let columns_to_merge = group_columns_for_merge(columnar_readers, required_columns)?;
for res in columns_to_merge {
if cancel() {
return Err(io::Error::new(ErrorKind::Interrupted, "Merge cancelled"));
}
let ((column_name, _column_type_category), grouped_columns) = res;
let grouped_columns = grouped_columns.open(&merge_row_order)?;
if grouped_columns.is_empty() {

View File

@@ -205,7 +205,6 @@ fn test_merge_columnar_numbers() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -234,7 +233,6 @@ fn test_merge_columnar_texts() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -284,7 +282,6 @@ fn test_merge_columnar_byte() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -341,7 +338,6 @@ fn test_merge_columnar_byte_with_missing() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -394,7 +390,6 @@ fn test_merge_columnar_different_types() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -460,7 +455,6 @@ fn test_merge_columnar_different_empty_cardinality() {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
|| false,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
@@ -571,7 +565,6 @@ proptest! {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut out,
|| false,
).unwrap();
let merged_reader = ColumnarReader::open(out).unwrap();
@@ -589,7 +582,6 @@ proptest! {
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut out,
|| false,
).unwrap();
}

View File

@@ -1,22 +0,0 @@
use serde::{Deserialize, Serialize};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ComparableDoc<T, D> {
/// The feature of the document. In practice, this is
/// is a type which can be compared with a `Comparator<T>`.
pub sort_key: T,
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("ComparableDoc")
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}

View File

@@ -71,14 +71,7 @@ fn test_format(path: &str) {
let columnar_readers = vec![&reader, &reader2];
let merge_row_order = StackMergeOrder::stack(&columnar_readers[..]);
let mut out = Vec::new();
merge_columnar(
&columnar_readers,
&[],
merge_row_order.into(),
&mut out,
|| false,
)
.unwrap();
merge_columnar(&columnar_readers, &[], merge_row_order.into(), &mut out).unwrap();
let reader = ColumnarReader::open(out).unwrap();
check_columns(&reader);
}

View File

@@ -3,8 +3,7 @@ use std::sync::Arc;
use std::{fmt, io};
use common::file_slice::FileSlice;
use common::{ByteCount, DateTime};
use serde::{Deserialize, Serialize};
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
@@ -239,7 +238,8 @@ pub struct DynamicColumnHandle {
impl DynamicColumnHandle {
// TODO rename load
pub fn open(&self) -> io::Result<DynamicColumn> {
self.open_internal(self.file_slice.clone())
let column_bytes: OwnedBytes = self.file_slice.read_bytes()?;
self.open_internal(column_bytes)
}
#[doc(hidden)]
@@ -258,15 +258,16 @@ impl DynamicColumnHandle {
/// If not, the fastfield reader will returns the u64-value associated with the original
/// FastValue.
pub fn open_u64_lenient(&self) -> io::Result<Option<Column<u64>>> {
let column_bytes = self.file_slice.read_bytes()?;
match self.column_type {
ColumnType::Str | ColumnType::Bytes => {
let column: BytesColumn =
crate::column::open_column_bytes(self.file_slice.clone(), self.format_version)?;
crate::column::open_column_bytes(column_bytes, self.format_version)?;
Ok(Some(column.term_ord_column))
}
ColumnType::IpAddr => {
let column = crate::column::open_column_u128_as_compact_u64(
self.file_slice.clone(),
column_bytes,
self.format_version,
)?;
Ok(Some(column))
@@ -276,129 +277,50 @@ impl DynamicColumnHandle {
| ColumnType::U64
| ColumnType::F64
| ColumnType::DateTime => {
let column = crate::column::open_column_u64::<u64>(
self.file_slice.clone(),
self.format_version,
)?;
let column =
crate::column::open_column_u64::<u64>(column_bytes, self.format_version)?;
Ok(Some(column))
}
}
}
fn open_internal(&self, file_slice: FileSlice) -> io::Result<DynamicColumn> {
fn open_internal(&self, column_bytes: OwnedBytes) -> io::Result<DynamicColumn> {
let dynamic_column: DynamicColumn = match self.column_type {
ColumnType::Bytes => {
crate::column::open_column_bytes(file_slice, self.format_version)?.into()
crate::column::open_column_bytes(column_bytes, self.format_version)?.into()
}
ColumnType::Str => {
crate::column::open_column_str(file_slice, self.format_version)?.into()
crate::column::open_column_str(column_bytes, self.format_version)?.into()
}
ColumnType::I64 => {
crate::column::open_column_u64::<i64>(file_slice, self.format_version)?.into()
crate::column::open_column_u64::<i64>(column_bytes, self.format_version)?.into()
}
ColumnType::U64 => {
crate::column::open_column_u64::<u64>(file_slice, self.format_version)?.into()
crate::column::open_column_u64::<u64>(column_bytes, self.format_version)?.into()
}
ColumnType::F64 => {
crate::column::open_column_u64::<f64>(file_slice, self.format_version)?.into()
crate::column::open_column_u64::<f64>(column_bytes, self.format_version)?.into()
}
ColumnType::Bool => {
crate::column::open_column_u64::<bool>(file_slice, self.format_version)?.into()
crate::column::open_column_u64::<bool>(column_bytes, self.format_version)?.into()
}
ColumnType::IpAddr => {
crate::column::open_column_u128::<Ipv6Addr>(file_slice, self.format_version)?.into()
crate::column::open_column_u128::<Ipv6Addr>(column_bytes, self.format_version)?
.into()
}
ColumnType::DateTime => {
crate::column::open_column_u64::<DateTime>(file_slice, self.format_version)?.into()
crate::column::open_column_u64::<DateTime>(column_bytes, self.format_version)?
.into()
}
};
Ok(dynamic_column)
}
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.num_bytes()
}
/// Legacy helper returning the column space usage.
pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
self.space_usage()
}
/// Return the space usage of the column, optionally broken down by dictionary and column
/// values.
///
/// For dictionary encoded columns (strings and bytes), this splits the total footprint into
/// the dictionary and the remaining column data (including index and values).
/// For all other column types, the dictionary size is `None` and the column size
/// equals the total bytes.
pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
let total_num_bytes = self.num_bytes();
let dynamic_column = self.open()?;
let dictionary_num_bytes = match &dynamic_column {
DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
_ => {
return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
}
};
assert!(dictionary_num_bytes <= total_num_bytes);
let column_num_bytes =
ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
Ok(ColumnSpaceUsage::new(
column_num_bytes,
Some(dictionary_num_bytes),
))
self.file_slice.len().into()
}
pub fn column_type(&self) -> ColumnType {
self.column_type
}
}
/// Represents space usage of a column.
///
/// `column_num_bytes` tracks the column payload (index, values and footer).
/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ColumnSpaceUsage {
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
}
impl ColumnSpaceUsage {
pub(crate) fn new(
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
) -> Self {
ColumnSpaceUsage {
column_num_bytes,
dictionary_num_bytes,
}
}
pub fn column_num_bytes(&self) -> ByteCount {
self.column_num_bytes
}
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
self.dictionary_num_bytes
}
pub fn total_num_bytes(&self) -> ByteCount {
self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
}
/// Merge two space usage values by summing their components.
pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
(Some(val), None) | (None, Some(val)) => Some(val),
(None, None) => None,
};
ColumnSpaceUsage {
column_num_bytes: self.column_num_bytes + other.column_num_bytes,
dictionary_num_bytes,
}
}
}

View File

@@ -29,7 +29,6 @@ mod column;
pub mod column_index;
pub mod column_values;
mod columnar;
mod comparable_doc;
mod dictionary;
mod dynamic_column;
mod iterable;
@@ -37,7 +36,7 @@ pub(crate) mod utils;
mod value;
pub use block_accessor::ColumnBlockAccessor;
pub use column::{BytesColumn, Column, StrColumn, ValueRange};
pub use column::{BytesColumn, Column, StrColumn};
pub use column_index::ColumnIndex;
pub use column_values::{
ColumnValues, EmptyColumnValues, MonotonicallyMappableToU64, MonotonicallyMappableToU128,
@@ -46,11 +45,10 @@ pub use columnar::{
CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar,
};
pub use comparable_doc::ComparableDoc;
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};
pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};
pub type RowId = u32;
pub type DocId = u32;

View File

@@ -641,7 +641,7 @@ proptest! {
let columnar_readers_arr: Vec<&ColumnarReader> = columnar_readers.iter().collect();
let mut output: Vec<u8> = Vec::new();
let stack_merge_order = StackMergeOrder::stack(&columnar_readers_arr[..]).into();
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output, || false,).unwrap();
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output).unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
let concat_rows: Vec<Vec<(&'static str, ColumnValue)>> = columnar_docs.iter().flatten().cloned().collect();
let expected_merged_columnar = build_columnar(&concat_rows[..]);
@@ -665,7 +665,6 @@ fn test_columnar_merging_empty_columnar() {
&[],
crate::MergeRowOrder::Stack(stack_merge_order),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -703,7 +702,6 @@ fn test_columnar_merging_number_columns() {
&[],
crate::MergeRowOrder::Stack(stack_merge_order),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -777,7 +775,6 @@ fn test_columnar_merge_and_remap(
&[],
shuffle_merge_order.into(),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -820,7 +817,6 @@ fn test_columnar_merge_empty() {
&[],
shuffle_merge_order.into(),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -847,7 +843,6 @@ fn test_columnar_merge_single_str_column() {
&[],
shuffle_merge_order.into(),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
@@ -880,7 +875,6 @@ fn test_delete_decrease_cardinality() {
&[],
shuffle_merge_order.into(),
&mut output,
|| false,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();

View File

@@ -181,6 +181,14 @@ pub struct BitSet {
len: u64,
max_value: u32,
}
impl std::fmt::Debug for BitSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BitSet")
.field("len", &self.len)
.field("max_value", &self.max_value)
.finish()
}
}
fn num_buckets(max_val: u32) -> u32 {
max_val.div_ceil(64u32)

View File

@@ -1,106 +0,0 @@
use std::cell::RefCell;
use std::cmp::min;
use std::io;
use std::ops::Range;
use super::file_slice::FileSlice;
use super::{HasLen, OwnedBytes};
const DEFAULT_BUFFER_MAX_SIZE: usize = 512 * 1024; // 512K
/// A buffered reader for a FileSlice.
///
/// Reads the underlying `FileSlice` in large, sequential chunks to amortize
/// the cost of `read_bytes` calls, while keeping peak memory usage under control.
///
/// TODO: Rather than wrapping a `FileSlice` in buffering, it will usually be better to adjust a
/// `FileHandle` to directly handle buffering itself.
/// TODO: See: https://github.com/paradedb/paradedb/issues/3374
pub struct BufferedFileSlice {
file_slice: FileSlice,
buffer: RefCell<OwnedBytes>,
buffer_range: RefCell<Range<u64>>,
buffer_max_size: usize,
}
impl BufferedFileSlice {
/// Creates a new `BufferedFileSlice`.
///
/// The `buffer_max_size` is the amount of data that will be read from the
/// `FileSlice` on a buffer miss.
pub fn new(file_slice: FileSlice, buffer_max_size: usize) -> Self {
Self {
file_slice,
buffer: RefCell::new(OwnedBytes::empty()),
buffer_range: RefCell::new(0..0),
buffer_max_size,
}
}
/// Creates a new `BufferedFileSlice` with a default buffer max size.
pub fn new_with_default_buffer_size(file_slice: FileSlice) -> Self {
Self::new(file_slice, DEFAULT_BUFFER_MAX_SIZE)
}
/// Creates an empty `BufferedFileSlice`.
pub fn empty() -> Self {
Self::new(FileSlice::empty(), 0)
}
/// Returns an `OwnedBytes` corresponding to the given `required_range`.
///
/// If the requested range is not in the buffer, this will trigger a read
/// from the underlying `FileSlice`.
///
/// If the requested range is larger than the buffer_max_size, it will be read directly from the
/// source without buffering.
///
/// # Errors
///
/// Returns an `io::Error` if the underlying read fails or the range is
/// out of bounds.
pub fn get_bytes(&self, required_range: Range<u64>) -> io::Result<OwnedBytes> {
let buffer_range = self.buffer_range.borrow();
// Cache miss condition: the required range is not fully contained in the current buffer.
if required_range.start < buffer_range.start || required_range.end > buffer_range.end {
drop(buffer_range); // release borrow before mutating
if required_range.end > self.file_slice.len() as u64 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Requested range extends beyond the end of the file slice.",
));
}
if (required_range.end - required_range.start) as usize > self.buffer_max_size {
// This read is larger than our buffer max size.
// Read it directly and bypass the buffer to avoid churning.
return self
.file_slice
.read_bytes_slice(required_range.start as usize..required_range.end as usize);
}
let new_buffer_start = required_range.start;
let new_buffer_end = min(
new_buffer_start + self.buffer_max_size as u64,
self.file_slice.len() as u64,
);
let read_range = new_buffer_start..new_buffer_end;
let new_buffer = self
.file_slice
.read_bytes_slice(read_range.start as usize..read_range.end as usize)?;
self.buffer.replace(new_buffer);
self.buffer_range.replace(read_range);
}
// Now the data is guaranteed to be in the buffer.
let buffer = self.buffer.borrow();
let buffer_range = self.buffer_range.borrow();
let local_start = (required_range.start - buffer_range.start) as usize;
let local_end = (required_range.end - buffer_range.start) as usize;
Ok(buffer.slice(local_start..local_end))
}
}

View File

@@ -1,7 +1,7 @@
use std::fs::File;
use std::ops::{Deref, Range, RangeBounds};
use std::path::Path;
use std::sync::{Arc, OnceLock};
use std::sync::Arc;
use std::{fmt, io};
use async_trait::async_trait;
@@ -339,27 +339,6 @@ impl FileHandle for OwnedBytes {
}
}
pub struct DeferredFileSlice {
opener: Arc<dyn Fn() -> io::Result<FileSlice> + Send + Sync + 'static>,
file_slice: OnceLock<std::io::Result<FileSlice>>,
}
impl DeferredFileSlice {
pub fn new(opener: impl Fn() -> io::Result<FileSlice> + Send + Sync + 'static) -> Self {
DeferredFileSlice {
opener: Arc::new(opener),
file_slice: OnceLock::default(),
}
}
pub fn open(&self) -> io::Result<&FileSlice> {
match self.file_slice.get_or_init(|| (self.opener)()) {
Ok(file_slice) => Ok(file_slice),
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())),
}
}
}
#[cfg(test)]
mod tests {
use std::io;

View File

@@ -6,7 +6,6 @@ pub use byteorder::LittleEndian as Endianness;
mod bitset;
pub mod bounds;
pub mod buffered_file_slice;
mod byte_count;
mod datetime;
pub mod file_slice;

View File

@@ -58,33 +58,6 @@ impl BinarySerializable for VIntU128 {
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct VInt(pub u64);
impl VInt {
pub fn deserialize_with_size<R: Read>(reader: &mut R) -> io::Result<(Self, usize)> {
let mut nbytes = 0;
let mut bytes = reader.bytes();
let mut result = 0u64;
let mut shift = 0u64;
loop {
match bytes.next() {
Some(Ok(b)) => {
nbytes += 1;
result |= u64::from(b % 128u8) << shift;
if b >= STOP_BIT {
return Ok((VInt(result), nbytes));
}
shift += 7;
}
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Reach end of buffer while reading VInt",
));
}
}
}
}
}
const STOP_BIT: u8 = 128;
#[inline]
@@ -252,6 +225,7 @@ impl BinarySerializable for VInt {
#[cfg(test)]
mod tests {
use super::{BinarySerializable, VInt, serialize_vint_u32};
fn aux_test_vint(val: u64) {

View File

@@ -1,86 +0,0 @@
// # Multiple Snippets Example
//
// This example demonstrates how to return multiple text fragments
// from a document, useful for long documents with matches in different locations.
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::*;
use tantivy::snippet::SnippetGenerator;
use tantivy::{doc, Index, IndexWriter};
use tempfile::TempDir;
fn main() -> tantivy::Result<()> {
let index_path = TempDir::new()?;
// Define the schema
let mut schema_builder = Schema::builder();
let title = schema_builder.add_text_field("title", TEXT | STORED);
let body = schema_builder.add_text_field("body", TEXT | STORED);
let schema = schema_builder.build();
// Create the index
let index = Index::create_in_dir(&index_path, schema)?;
let mut index_writer: IndexWriter = index.writer(50_000_000)?;
// Index a long document with multiple occurrences of "rust"
index_writer.add_document(doc!(
title => "The Rust Programming Language",
body => "Rust is a systems programming language that runs blazingly fast, prevents \
segfaults, and guarantees thread safety. Lorem ipsum dolor sit amet, \
consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore. \
Rust empowers everyone to build reliable and efficient software. More filler \
text to create distance between matches. Ut enim ad minim veniam, quis nostrud \
exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. \
The Rust compiler is known for its helpful error messages. Duis aute irure \
dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla \
pariatur. Rust has a strong type system and ownership model."
))?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![title, body]);
let query = query_parser.parse_query("rust")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
// Create snippet generator
let mut snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
println!("=== Single Snippet (Default Behavior) ===\n");
for (score, doc_address) in &top_docs {
let doc = searcher.doc::<TantivyDocument>(*doc_address)?;
let snippet = snippet_generator.snippet_from_doc(&doc);
println!("Document score: {}", score);
println!("Title: {}", doc.get_first(title).unwrap().as_str().unwrap());
println!("Single snippet: {}\n", snippet.to_html());
}
println!("\n=== Multiple Snippets (New Feature) ===\n");
// Configure to return multiple snippets
// Get up to 3 snippets
snippet_generator.set_snippets_limit(3);
// Smaller fragments
snippet_generator.set_max_num_chars(80);
// By default, multiple snippets are sorted by score. You can change this to sort by position.
// snippet_generator.set_sort_order(SnippetSortOrder::Position);
for (score, doc_address) in top_docs {
let doc = searcher.doc::<TantivyDocument>(doc_address)?;
let snippets = snippet_generator.snippets_from_doc(&doc);
println!("Document score: {}", score);
println!("Title: {}", doc.get_first(title).unwrap().as_str().unwrap());
println!("Found {} snippets:", snippets.len());
for (i, snippet) in snippets.iter().enumerate() {
println!(" Snippet {}: {}", i + 1, snippet.to_html());
}
println!();
}
Ok(())
}

View File

@@ -1,3 +0,0 @@
#! /bin/bash
cargo +stable nextest run --features quickwit,mmap,stopwords,lz4-compression,zstd-compression,failpoints --verbose --workspace

View File

@@ -1,4 +1,4 @@
use columnar::{Column, ColumnType, StrColumn};
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
use common::BitSet;
use rustc_hash::FxHashSet;
use serde::Serialize;
@@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
build_segment_filter_collector, build_segment_range_collector, FilterAggReqData,
HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
};
use crate::aggregation::segment_agg_result::{
@@ -35,6 +35,7 @@ pub struct AggregationsSegmentCtx {
/// Request data for each aggregation type.
pub per_request: PerRequestAggSegCtx,
pub context: AggContextParams,
pub column_block_accessor: ColumnBlockAccessor<u64>,
}
impl AggregationsSegmentCtx {
@@ -107,21 +108,14 @@ impl AggregationsSegmentCtx {
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
self.per_request.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
#[inline]
pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
self.per_request.term_req_data[idx]
.as_deref_mut()
.expect("term_req_data slot is empty (taken)")
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_cardinality_req_data_mut(
&mut self,
@@ -129,10 +123,7 @@ impl AggregationsSegmentCtx {
) -> &mut CardinalityAggReqData {
&mut self.per_request.cardinality_req_data[idx]
}
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
self.per_request.histogram_req_data[idx]
@@ -142,21 +133,6 @@ impl AggregationsSegmentCtx {
// ---------- take / put (terms, histogram, range) ----------
/// Move out the boxed Terms request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
self.per_request.term_req_data[idx]
.take()
.expect("term_req_data slot is empty (taken)")
}
/// Put back a Terms request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
debug_assert!(self.per_request.term_req_data[idx].is_none());
self.per_request.term_req_data[idx] = Some(value);
}
/// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
@@ -320,6 +296,7 @@ impl PerRequestAggSegCtx {
/// Convert the aggregation tree into a serializable struct representation.
/// Each node contains: { name, kind, children }.
#[allow(dead_code)]
pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
let mut children: Vec<AggTreeViewNode> =
@@ -345,12 +322,19 @@ impl PerRequestAggSegCtx {
pub(crate) fn build_segment_agg_collectors_root(
req: &mut AggregationsSegmentCtx,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
}
pub(crate) fn build_segment_agg_collectors(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors_generic(req, nodes)
}
fn build_segment_agg_collectors_generic(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let mut collectors = Vec::new();
for node in nodes.iter() {
@@ -388,6 +372,8 @@ pub(crate) fn build_segment_agg_collector(
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
}
AggKind::StatsKind(stats_type) => {
@@ -398,20 +384,21 @@ pub(crate) fn build_segment_agg_collector(
| StatsType::Count
| StatsType::Max
| StatsType::Min
| StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
node.idx_in_req_data,
))),
StatsType::ExtendedStats(sigma) => {
Ok(Box::new(SegmentExtendedStatsCollector::from_req(
req_data.field_type,
sigma,
node.idx_in_req_data,
req_data.missing,
)))
}
StatsType::Percentiles => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
| StatsType::Stats => build_segment_stats_collector(req_data),
StatsType::ExtendedStats(sigma) => Ok(Box::new(
SegmentExtendedStatsCollector::from_req(req_data, sigma),
)),
StatsType::Percentiles => {
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
req_data.field_type,
req_data.missing_u64,
req_data.accessor.clone(),
node.idx_in_req_data,
),
))
}
}
}
AggKind::TopHits => {
@@ -428,12 +415,8 @@ pub(crate) fn build_segment_agg_collector(
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => build_segment_filter_collector(req, node),
}
}
@@ -493,6 +476,7 @@ pub(crate) fn build_aggregations_data_from_req(
let mut data = AggregationsSegmentCtx {
per_request: Default::default(),
context,
column_block_accessor: ColumnBlockAccessor::default(),
};
for (name, agg) in aggs.iter() {
@@ -521,9 +505,9 @@ fn build_nodes(
let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: range_req.clone(),
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
@@ -541,9 +525,7 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req.clone(),
is_date_histogram: false,
bounds: HistogramBounds {
@@ -568,9 +550,7 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req,
is_date_histogram: true,
bounds: HistogramBounds {
@@ -650,7 +630,6 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for,
missing: *missing,
@@ -678,7 +657,6 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for: StatsType::Percentiles,
missing: percentiles_req.missing,
@@ -753,6 +731,7 @@ fn build_nodes(
segment_reader: reader.clone(),
evaluator,
matching_docs_buffer,
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
@@ -895,7 +874,7 @@ fn build_terms_or_cardinality_nodes(
});
}
// Add one node per accessor to mirror previous behavior and allow per-type missing handling.
// Add one node per accessor
for (accessor, column_type) in column_and_types {
let missing_value_for_accessor = if use_special_missing_agg {
None
@@ -926,11 +905,8 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: TermsAggregationInternal::from_req(req),
// Will be filled later when building collectors
sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(),
allowed_term_ids,
is_top_level,
@@ -943,7 +919,6 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: req.clone(),
});

View File

@@ -2,15 +2,441 @@ use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};
// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);
// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);
Ok(())
}
#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 5 buckets
// Child: coarse range with 3 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 2, "from": 20.0}
]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: range with 4 buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several ranges
// Child: histogram with large interval (single bucket per parent)
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text -> 2 buckets
// Child: histogram with small interval -> multiple buckets including empties
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 4},
{"key": 10.0, "doc_count": 2},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 1},
{"key": 10.0, "doc_count": 0},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several buckets
// Child: date_histogram with 30d -> single bucket per parent
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
// Verify each parent bucket has exactly one child date bucket with matching doc_count
for bucket in buckets {
let parent_count = bucket["doc_count"].as_u64().unwrap();
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(child_buckets.len(), 1);
assert_eq!(child_buckets[0]["doc_count"], parent_count);
}
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: date_histogram with 1d -> multiple buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
// cool bucket
assert_eq!(buckets[0]["key"], "cool");
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(cool_buckets.len(), 3);
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
// nohit bucket
assert_eq!(buckets[1]["key"], "nohit");
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(nohit_buckets.len(), 2);
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
Ok(())
}
fn get_avg_req(field_name: &str) -> Aggregation {
serde_json::from_value(json!({
"avg": {
@@ -25,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushng part of these tests are outdated, since the buffering change after converting
// the collection into one collector per request instead of per bucket.
//
// However they are useful as they test a complex aggregation requests.
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
@@ -37,8 +467,9 @@ fn test_aggregation_flushing(
let reader = index.reader()?;
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
// block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)

View File

@@ -6,10 +6,14 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
};
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::docset::DocSet;
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
use crate::schema::Schema;
@@ -404,15 +408,18 @@ pub struct FilterAggReqData {
pub evaluator: DocumentQueryEvaluator,
/// Reusable buffer for matching documents to minimize allocations during collection
pub matching_docs_buffer: Vec<DocId>,
/// True if this filter aggregation is at the top level of the aggregation tree (not nested).
pub is_top_level: bool,
}
impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<bool>()
}
}
@@ -489,17 +496,24 @@ impl Debug for DocumentQueryEvaluator {
}
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector {
/// Document count in this bucket
#[derive(Debug, Clone, PartialEq, Copy)]
struct DocCount {
doc_count: u64,
bucket_id: BucketId,
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector<C: SubAggCache> {
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
/// Sub-aggregation collectors
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
sub_aggregations: Option<CachedSubAggs<C>>,
bucket_id_provider: BucketIdProvider,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
impl SegmentFilterCollector {
impl<C: SubAggCache> SegmentFilterCollector<C> {
/// Create a new filter segment collector following the new agg_data pattern
pub(crate) fn from_req_and_validate(
req: &mut AggregationsSegmentCtx,
@@ -511,47 +525,75 @@ impl SegmentFilterCollector {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
Ok(SegmentFilterCollector {
doc_count: 0,
parent_buckets: Vec::new(),
sub_aggregations: sub_agg_collector,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}
impl Debug for SegmentFilterCollector {
pub(crate) fn build_segment_filter_collector(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data]
.as_ref()
.expect("filter_req_data slot is empty")
.is_top_level;
if is_top_level {
Ok(Box::new(
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
))
} else {
Ok(Box::new(
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
))
}
}
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("doc_count", &self.doc_count)
.field("buckets", &self.parent_buckets)
.field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("accessor_idx", &self.accessor_idx)
.finish()
}
}
impl CollectorClone for SegmentFilterCollector {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
// For now, panic - this needs proper implementation with weight recreation
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
}
}
impl SegmentAggregationCollector for SegmentFilterCollector {
impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let mut sub_results = IntermediateAggregationResults::default();
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
if let Some(sub_aggs) = self.sub_aggregations {
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_results,
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
// exist, so that sub-aggregations can still produce results (e.g., zero doc
// count)
bucket_opt
.map(|bucket| bucket.bucket_id)
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
)?;
}
// Create the filter bucket result
let filter_bucket_result = IntermediateBucketResult::Filter {
doc_count: self.doc_count,
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
sub_aggregations: sub_results,
};
@@ -570,32 +612,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
Ok(())
}
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Access the evaluator from FilterAggReqData
let req_data = agg_data.get_filter_req_data(self.accessor_idx);
// O(1) BitSet lookup to check if document matches filter
if req_data.evaluator.matches_document(doc) {
self.doc_count += 1;
// If we have sub-aggregations, collect on them for this filtered document
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.collect(doc, agg_data)?;
}
}
Ok(())
}
#[inline]
fn collect_block(
fn collect(
&mut self,
docs: &[DocId],
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.is_empty() {
return Ok(());
}
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
@@ -604,18 +631,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);
self.doc_count += req.matching_docs_buffer.len() as u64;
bucket.doc_count += req.matching_docs_buffer.len() as u64;
// Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
// Use collect_block for better sub-aggregation performance
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
for &doc_id in &req.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
}
}
// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
// put back bucket
self.parent_buckets[parent_bucket_id as usize] = bucket;
Ok(())
}
@@ -626,6 +659,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.parent_buckets.push(DocCount {
doc_count: 0,
bucket_id,
});
}
Ok(())
}
}
/// Intermediate result for filter aggregation
@@ -1519,9 +1567,9 @@ mod tests {
let searcher = reader.searcher();
let agg = json!({
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});

View File

@@ -1,6 +1,6 @@
use std::cmp::Ordering;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;
@@ -26,13 +26,8 @@ pub struct HistogramAggReqData {
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
@@ -257,18 +252,24 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -278,27 +279,38 @@ impl SegmentHistogramBucketEntry {
}
}
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardCachedSubAggs>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
let bucket = self.into_intermediate_bucket_result(agg_data)?;
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -307,44 +319,40 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in req
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, &req.field_type);
let val = f64_from_fastfield_u64(val, req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry { key, doc_count: 0 }
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
});
bucket.doc_count += 1;
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, agg_data)?;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
@@ -358,14 +366,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregation in self.sub_aggregations.values_mut() {
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
}
@@ -373,22 +397,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
}
/// Converts the collector result into a intermediate bucket result.
pub fn into_intermediate_bucket_result(
self,
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(self.buckets.len());
let mut buckets = Vec::with_capacity(histogram.buckets.len());
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
agg_data,
);
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
buckets.push(bucket_res?);
}
@@ -408,7 +429,7 @@ impl SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let blueprint = if !node.children.is_empty() {
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
@@ -423,13 +444,13 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
req_data.sub_aggregation_blueprint = blueprint;
let sub_agg = sub_agg.map(CachedSubAggs::new);
Ok(Self {
buckets: Default::default(),
sub_aggregations: Default::default(),
parent_buckets: Default::default(),
sub_agg,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}

View File

@@ -1,18 +1,22 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;
@@ -23,12 +27,12 @@ pub struct RangeAggReqData {
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}
impl RangeAggReqData {
@@ -151,19 +155,47 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
pub struct SegmentRangeCollector<C: SubAggCache> {
/// The buckets containing the aggregation data.
buckets: Vec<SegmentRangeAndBucketEntry>,
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<C>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -184,48 +216,50 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};
let sub_aggregation = IntermediateAggregationResults::default();
Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: sub_aggregation_res,
sub_aggregation_res: sub_aggregation,
from: self.from,
to: self.to,
})
}
}
impl SegmentAggregationCollector for SegmentRangeCollector {
impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
.into_iter()
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(agg_data)?,
))
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
})
.collect::<crate::Result<_>>()?;
@@ -242,73 +276,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
// Take request data to avoid borrow conflicts during sub-aggregation
let mut req = agg_data.take_range_req_data(self.accessor_idx);
let req = agg_data.take_range_req_data(self.accessor_idx);
req.column_block_accessor.fetch_block(docs, &req.accessor);
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in req
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.collect(doc, agg_data)?;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(agg_data)?;
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
// E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
// we can are still in low cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
}
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let accessor_idx = node.idx_in_req_data;
let (field_type, ranges) = {
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};
impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let sub_agg_prototype = if !node.children.is_empty() {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -317,20 +392,20 @@ impl SegmentRangeCollector {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
let sub_aggregation = sub_agg_prototype.clone();
// let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation,
bucket_id,
key,
from,
to,
@@ -339,27 +414,20 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
req_data.context.limits.add_memory_consumed(
self.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
Ok(buckets)
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}
/// Converts the user provided f64 range value to fast field value space.
///
@@ -456,7 +524,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val);
format_date(val)
} else {
Ok(f64_from_fastfield_u64(val, field_type).to_string())
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
}
};
@@ -486,7 +554,7 @@ mod tests {
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: ColumnType,
) -> SegmentRangeCollector {
) -> SegmentRangeCollector<HighCardSubAggCache> {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
@@ -506,30 +574,33 @@ mod tests {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation: None,
key,
from,
to,
bucket_id: 0,
},
}
})
.collect();
SegmentRangeCollector {
buckets,
parent_buckets: vec![buckets],
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
}
@@ -776,7 +847,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -799,7 +870,7 @@ mod tests {
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -814,7 +885,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
@@ -823,7 +894,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -832,7 +903,7 @@ mod tests {
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| collector.get_bucket_pos(val);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -878,7 +949,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| collector.get_bucket_pos(val);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);
@@ -890,63 +961,3 @@ mod tests {
// the max value
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}
#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}
#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
@@ -35,41 +37,55 @@ impl MissingTermAggReqData {
}
}
/// The specialized missing term aggregation.
#[derive(Default, Debug, Clone)]
pub struct TermMissingAgg {
struct MissingCount {
missing_count: u32,
bucket_id: BucketId,
}
/// The specialized missing term aggregation.
#[derive(Default, Debug)]
pub struct TermMissingAgg {
accessor_idx: usize,
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
sub_agg: Option<HighCardCachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
}
impl TermMissingAgg {
pub(crate) fn new(
req_data: &mut AggregationsSegmentCtx,
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data;
let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
accessor_idx,
sub_agg,
..Default::default()
missing_count_per_bucket: Vec::new(),
bucket_id_provider,
})
}
}
impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req;
let missing = term_agg
@@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg {
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
Default::default();
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
let mut missing_entry = IntermediateTermBucketEntry {
doc_count: self.missing_count,
doc_count: missing_count.missing_count,
sub_aggregation: Default::default(),
};
if let Some(sub_agg) = self.sub_agg {
if let Some(sub_agg) = &mut self.sub_agg {
let mut res = IntermediateAggregationResults::default();
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
missing_entry.sub_aggregation = res;
}
entries.insert(missing.into(), missing_entry);
@@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg {
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
self.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, agg_data)?;
for doc in docs {
let doc = *doc;
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
bucket.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for doc in docs {
self.collect(*doc, agg_data)?;
while self.missing_count_per_bucket.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.missing_count_per_bucket.push(MissingCount {
missing_count: 0,
bucket_id,
});
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}

View File

@@ -1,87 +0,0 @@
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
#[cfg(test)]
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
#[cfg(not(test))]
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_data)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
self.collector.flush(agg_data)?;
Ok(())
}
}

View File

@@ -0,0 +1,245 @@
use std::fmt::Debug;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
#[derive(Debug)]
pub(crate) struct CachedSubAggs<C: SubAggCache> {
cache: C,
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
const FLUSH_THRESHOLD: usize = 2048;
/// A trait for caching sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggCache: Debug {
fn new() -> Self;
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()>;
}
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
cache: Backend::new(),
sub_agg_collector: sub_agg,
num_docs: 0,
}
}
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
self.cache.push(bucket_id, doc_id);
self.num_docs += 1;
}
/// Check if we need to flush based on the number of documents cached.
/// If so, flushes the cache to the provided aggregation collector.
pub fn check_flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
self.num_docs = 0;
}
Ok(())
}
/// Note: this _does_ flush the sub aggregations.
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
self.num_docs = 0;
}
self.sub_agg_collector.flush(agg_data)?;
Ok(())
}
}
/// Number of partitions for high cardinality sub-aggregation cache.
const NUM_PARTITIONS: usize = 16;
#[derive(Debug)]
pub(crate) struct HighCardSubAggCache {
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
///
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
/// buckets, and there may be nothing to group.
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}
impl HighCardSubAggCache {
#[inline]
fn clear(&mut self) {
for partition in self.partitions.iter_mut() {
partition.clear();
}
}
}
#[derive(Debug, Clone, Default)]
struct PartitionEntry {
bucket_ids: Vec<BucketId>,
docs: Vec<DocId>,
}
impl PartitionEntry {
#[inline]
fn clear(&mut self) {
self.bucket_ids.clear();
self.docs.clear();
}
}
impl SubAggCache for HighCardSubAggCache {
fn new() -> Self {
Self {
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
slot.bucket_ids.push(bucket_id);
slot.docs.push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
_force: bool,
) -> crate::Result<()> {
let mut max_bucket = 0u32;
for partition in self.partitions.iter() {
if let Some(&local_max) = partition.bucket_ids.iter().max() {
max_bucket = max_bucket.max(local_max);
}
}
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
for slot in self.partitions.iter() {
if !slot.bucket_ids.is_empty() {
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?;
}
}
self.clear();
Ok(())
}
}
#[derive(Debug)]
pub(crate) struct LowCardSubAggCache {
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
}
impl LowCardSubAggCache {
#[inline]
fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
v.clear();
}
}
}
impl SubAggCache for LowCardSubAggCache {
fn new() -> Self {
Self {
per_bucket_docs: Vec::new(),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
}
self.per_bucket_docs[idx].push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()> {
// Pre-aggregated: call collect per bucket.
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
// The threshold above which we flush buckets individually.
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
// the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
const _: () = {
// MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations
// Note: There may be other flexible values, for other aggregations, but we can use the
// const value here as a upper bound. (better than nothing)
let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
assert!(
bucket_treshold_limit > 0,
"Bucket threshold must be greater than 0"
);
};
if force {
bucket_treshold = 0;
}
for (bucket_id, docs) in self
.per_bucket_docs
.iter()
.enumerate()
.filter(|(_, docs)| docs.len() > bucket_treshold)
{
sub_agg.collect(bucket_id as BucketId, docs, agg_data)?;
}
self.clear();
Ok(())
}
}

View File

@@ -1,9 +1,9 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: BufAggregationCollector,
agg_collector: LowCardCachedSubAggs,
error: Option<TantivyError>,
}
@@ -151,8 +151,11 @@ impl AggregationSegmentCollector {
) -> crate::Result<Self> {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let result =
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
let mut result =
LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
Ok(AggregationSegmentCollector {
aggs_with_accessor: agg_data,
@@ -170,26 +173,31 @@ impl SegmentCollector for AggregationSegmentCollector {
if self.error.is_some() {
return;
}
if let Err(err) = self
self.agg_collector.push(0, doc);
match self
.agg_collector
.collect(doc, &mut self.aggs_with_accessor)
.check_flush_local(&mut self.aggs_with_accessor)
{
self.error = Some(err);
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &mut self.aggs_with_accessor)
{
self.error = Some(err);
match self.agg_collector.get_sub_agg_collector().collect(
0,
docs,
&mut self.aggs_with_accessor,
) {
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
}
}
@@ -200,10 +208,13 @@ impl SegmentCollector for AggregationSegmentCollector {
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
Box::new(self.agg_collector).add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
)?;
self.agg_collector
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
0,
)?;
Ok(sub_aggregation_res)
}

View File

@@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
/// The number of documents in the bucket.
pub doc_count: u64,
/// The sub_aggregation in this bucket.
pub sub_aggregation: IntermediateAggregationResults,
pub sub_aggregation_res: IntermediateAggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.sub_aggregation_res
.into_final_result_internal(req, limits)?,
to: self.to,
from: self.from,
@@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count;
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
self.sub_aggregation_res
.merge_fruits(other.sub_aggregation_res)?;
Ok(())
}
}
@@ -887,7 +888,7 @@ mod tests {
IntermediateRangeBucketEntry {
key: IntermediateKey::Str(key.to_string()),
doc_count: *doc_count,
sub_aggregation: Default::default(),
sub_aggregation_res: Default::default(),
from: None,
to: None,
},
@@ -920,7 +921,7 @@ mod tests {
doc_count: *doc_count,
from: None,
to: None,
sub_aggregation: get_sub_test_tree(&[(
sub_aggregation_res: get_sub_test_tree(&[(
sub_aggregation_key.to_string(),
*sub_aggregation_count,
)]),

View File

@@ -52,10 +52,8 @@ pub struct IntermediateAverage {
impl IntermediateAverage {
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {

View File

@@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
@@ -106,8 +106,6 @@ pub struct CardinalityAggReqData {
pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>,
/// The column block accessor to access the fast field values.
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The aggregation request.
@@ -135,45 +133,34 @@ impl CardinalityAggregationReq {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
buckets: Vec<SegmentCardinalityCollectorBucket>,
accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
/// The column_type of the field.
column_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
}
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
}
impl SegmentCardinalityCollectorBucket {
pub fn new(column_type: ColumnType) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(),
accessor_idx,
entries: FxHashSet::default(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut CardinalityAggReqData,
) {
if let Some(missing) = agg_data.missing_value_for_accessor {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_data.accessor,
missing,
);
} else {
agg_data
.column_block_accessor
.fetch_block(docs, &agg_data.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
agg_data: &AggregationsSegmentCtx,
req_data: &CardinalityAggReqData,
) -> crate::Result<IntermediateMetricResult> {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = req_data
@@ -194,6 +181,7 @@ impl SegmentCardinalityCollector {
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
@@ -227,16 +215,49 @@ impl SegmentCardinalityCollector {
}
}
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
) -> Self {
Self {
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
column_type,
accessor_idx,
accessor,
missing_value_for_accessor,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_value_for_accessor,
);
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string();
// take the bucket in buckets and replace it with a new empty one
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
@@ -247,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
self.fetch_block_with_field(docs, req_data);
self.fetch_block_with_field(docs, agg_data);
let bucket = &mut self.buckets[parent_bucket_id as usize];
let col_block_accessor = &req_data.column_block_accessor;
if req_data.column_type == ColumnType::Str {
let col_block_accessor = &agg_data.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord);
bucket.entries.insert(term_ord);
}
} else if req_data.column_type == ColumnType::IpAddr {
let compact_space_accessor = req_data
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
.accessor
.values
.clone()
@@ -282,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
self.cardinality.sketch.insert_any(&val);
bucket.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
self.cardinality.sketch.insert_any(&val);
bucket.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
self.buckets.resize_with(max_bucket as usize + 1, || {
SegmentCardinalityCollectorBucket::new(self.column_type)
});
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateCount {
impl IntermediateCount {
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateCount) {

View File

@@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -318,51 +317,28 @@ impl IntermediateExtendedStats {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentExtendedStatsCollector {
name: String,
missing: Option<u64>,
field_type: ColumnType,
pub(crate) extended_stats: IntermediateExtendedStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
accessor: columnar::Column<u64>,
buckets: Vec<IntermediateExtendedStats>,
sigma: Option<f64>,
}
impl SegmentExtendedStatsCollector {
pub fn from_req(
field_type: ColumnType,
sigma: Option<f64>,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
Self {
field_type,
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
accessor_idx,
name: req.name.clone(),
field_type: req.field_type,
accessor: req.accessor.clone(),
missing,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
sigma,
}
}
}
@@ -370,15 +346,18 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
results.push(
name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
self.extended_stats,
extended_stats,
)),
)?;
@@ -388,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = self.missing {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
has_val = true;
}
if !has_val {
self.extended_stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
agg_data
.column_block_accessor
.fetch_block_with_missing(docs, &self.accessor, self.missing);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
extended_stats.collect(val1);
}
// store back
self.buckets[parent_bucket_id as usize] = extended_stats;
Ok(())
}
#[inline]
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
if self.buckets.len() <= max_bucket as usize {
self.buckets.resize_with(max_bucket as usize + 1, || {
IntermediateExtendedStats::with_sigma(self.sigma)
});
}
Ok(())
}
}

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMax {
impl IntermediateMax {
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMax) {

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMin {
impl IntermediateMin {
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMin) {

View File

@@ -31,7 +31,7 @@ use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -55,8 +55,6 @@ pub struct MetricAggReqData {
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// Used when converting to intermediate result

View File

@@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// # Percentiles
///
@@ -131,10 +130,16 @@ impl PercentilesAggregationReq {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentPercentilesCollector {
pub(crate) percentiles: PercentilesCollector,
pub(crate) buckets: Vec<PercentilesCollector>,
pub(crate) accessor_idx: usize,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -229,33 +234,18 @@ impl PercentilesCollector {
}
impl SegmentPercentilesCollector {
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
Ok(Self {
percentiles: PercentilesCollector::new(),
pub fn from_req_and_validate(
field_type: ColumnType,
missing_u64: Option<u64>,
accessor: Column<u64>,
accessor_idx: usize,
) -> Self {
Self {
buckets: Vec::with_capacity(64),
field_type,
missing_u64,
accessor,
accessor_idx,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
}
@@ -263,12 +253,18 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
// Swap collector with an empty one to avoid cloning
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_metric_result =
IntermediateMetricResult::Percentiles(percentiles_collector);
results.push(
name,
@@ -281,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
let percentiles = &mut self.buckets[parent_bucket_id as usize];
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
percentiles.collect(val1);
}
Ok(())
}
#[inline]
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
while self.buckets.len() <= max_bucket as usize {
self.buckets.push(PercentilesCollector::new());
}
Ok(())
}
}

View File

@@ -1,5 +1,6 @@
use std::fmt::Debug;
use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};
use super::*;
@@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -83,7 +83,7 @@ impl Stats {
/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
/// The number of extracted values.
pub(crate) count: u64,
@@ -187,75 +187,75 @@ pub enum StatsType {
Percentiles,
}
fn create_collector<const TYPE_ID: u8>(
req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
Box::new(SegmentStatsCollector::<TYPE_ID> {
name: req.name.clone(),
collecting_for: req.collecting_for,
is_number_or_date_type: req.is_number_or_date_type,
missing_u64: req.missing_u64,
accessor: req.accessor.clone(),
buckets: vec![IntermediateStats::default()],
})
}
/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match req.field_type {
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
}
}
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector {
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
pub(crate) missing_u64: Option<u64>,
pub(crate) accessor: Column<u64>,
pub(crate) is_number_or_date_type: bool,
pub(crate) buckets: Vec<IntermediateStats>,
pub(crate) name: String,
pub(crate) collecting_for: StatsType,
}
impl SegmentStatsCollector {
pub fn from_req(accessor_idx: usize) -> Self {
Self {
stats: IntermediateStats::default(),
accessor_idx,
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
if req_data.is_number_or_date_type {
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
} else {
for _val in req_data.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
}
}
}
impl SegmentAggregationCollector for SegmentStatsCollector {
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
for SegmentStatsCollector<COLUMN_TYPE_ID>
{
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req = agg_data.get_metric_req_data(self.accessor_idx);
let name = req.name.clone();
let name = self.name.clone();
let intermediate_metric_result = match req.collecting_for {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let stats = self.buckets[parent_bucket_id as usize];
let intermediate_metric_result = match self.collecting_for {
StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
}
StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
}
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
StatsType::Stats => IntermediateMetricResult::Stats(stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
_ => {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}",
req.collecting_for
self.collecting_for
)))
}
};
@@ -271,41 +271,67 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
// TODO: remove once we fetch all values for all bucket ids in one go
if docs.len() == 1 && self.missing_u64.is_none() {
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
self.accessor.values_for_doc(docs[0]),
self.is_number_or_date_type,
)?;
return Ok(());
}
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
agg_data.column_block_accessor.iter_vals(),
self.is_number_or_date_type,
)?;
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_buckets = (max_bucket as usize) + 1;
if self.buckets.len() < required_buckets {
self.buckets
.resize_with(required_buckets, IntermediateStats::default);
}
Ok(())
}
}
#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
stats: &mut IntermediateStats,
vals: impl Iterator<Item = u64>,
is_number_or_date_type: bool,
) -> crate::Result<()> {
if is_number_or_date_type {
for val in vals {
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
stats.collect(val1);
}
} else {
for _val in vals {
// we ignore the value and simply record that we got something
stats.collect(0.0);
}
}
Ok(())
}
#[cfg(test)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateSum {
impl IntermediateSum {
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateSum) {

View File

@@ -1,8 +1,7 @@
use std::cmp::Ordering;
use std::collections::HashMap;
use std::net::Ipv6Addr;
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn, ValueRange};
use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn};
use common::json_path_writer::JSON_PATH_SEGMENT_SEP_STR;
use common::DateTime;
use regex::Regex;
@@ -16,12 +15,11 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::AggregationError;
use crate::collector::sort_key::{Comparator, ReverseComparator};
use crate::aggregation::{AggregationError, BucketId};
use crate::collector::sort_key::ReverseComparator;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
// duplicate import removed; already imported above
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
@@ -384,7 +382,7 @@ impl From<FastFieldValue> for OwnedValue {
/// Holds a fast field value in its u64 representation, and the order in which it should be sorted.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub(crate) struct DocValueAndOrder {
struct DocValueAndOrder {
/// A fast field value in its u64 representation.
value: Option<u64>,
/// Sort order for the value
@@ -456,37 +454,6 @@ impl PartialEq for DocSortValuesAndFields {
impl Eq for DocSortValuesAndFields {}
impl Comparator<DocSortValuesAndFields> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &DocSortValuesAndFields, rhs: &DocSortValuesAndFields) -> Ordering {
rhs.cmp(lhs)
}
fn threshold_to_valuerange(
&self,
threshold: DocSortValuesAndFields,
) -> ValueRange<DocSortValuesAndFields> {
ValueRange::LessThan(threshold, true)
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct TopHitsSegmentSortKey(pub Vec<DocValueAndOrder>);
impl Comparator<TopHitsSegmentSortKey> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &TopHitsSegmentSortKey, rhs: &TopHitsSegmentSortKey) -> Ordering {
rhs.cmp(lhs)
}
fn threshold_to_valuerange(
&self,
threshold: TopHitsSegmentSortKey,
) -> ValueRange<TopHitsSegmentSortKey> {
ValueRange::LessThan(threshold, true)
}
}
/// The TopHitsCollector used for collecting over segments and merging results.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
@@ -504,7 +471,10 @@ impl TopHitsTopNComputer {
/// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self {
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
top_n: TopNComputer::new_with_comparator(
req.size + req.from.unwrap_or(0),
ReverseComparator,
),
req: req.clone(),
}
}
@@ -550,7 +520,8 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
top_n: TopNComputer<TopHitsSegmentSortKey, DocAddress, ReverseComparator>,
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
num_hits: usize,
}
impl TopHitsSegmentCollector {
@@ -559,27 +530,35 @@ impl TopHitsSegmentCollector {
accessor_idx: usize,
segment_ordinal: SegmentOrdinal,
) -> Self {
let num_hits = req.size + req.from.unwrap_or(0);
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
num_hits,
segment_ordinal,
accessor_idx,
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
}
}
fn into_top_hits_collector(
self,
fn get_top_hits_computer(
&mut self,
parent_bucket_id: BucketId,
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
if parent_bucket_id as usize >= self.buckets.len() {
return TopHitsTopNComputer::new(req);
}
let top_n = std::mem::replace(
&mut self.buckets[parent_bucket_id as usize],
TopNComputer::new(0),
);
let mut top_hits_computer = TopHitsTopNComputer::new(req);
// Map TopHitsSegmentSortKey back to Vec<DocValueAndOrder> if needed or use directly
// The TopNComputer here stores TopHitsSegmentSortKey.
let top_results = self.top_n.into_vec();
let top_results = top_n.into_vec();
for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
top_hits_computer.collect(
DocSortValuesAndFields {
sorts: res.sort_key.0,
sorts: res.sort_key,
doc_value_fields,
},
res.doc,
@@ -588,54 +567,24 @@ impl TopHitsSegmentCollector {
top_hits_computer
}
/// TODO add a specialized variant for a single sort field
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
self.top_n.push(
TopHitsSegmentSortKey(sorts),
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
Ok(())
}
}
impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let value_accessors = &req_data.value_accessors;
let intermediate_result = IntermediateMetricResult::TopHits(
self.into_top_hits_collector(value_accessors, &req_data.req),
);
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
parent_bucket_id,
value_accessors,
&req_data.req,
));
results.push(
req_data.name.to_string(),
IntermediateAggregationResult::Metric(intermediate_result),
@@ -645,26 +594,56 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
/// TODO: Consider a caching layer to reduce the call overhead
fn collect(
&mut self,
doc_id: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
Ok(())
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let top_n = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
// TODO: Consider getting fields with the column block accessor.
for doc in docs {
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
let req = &req_data.req;
let accessors = &req_data.accessors;
for &doc_id in docs {
// TODO: this is terrible, a new vec is allocated for every doc
// We can fetch blocks instead
// We don't need to store the order for every value
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
self.buckets.resize(
(max_bucket as usize) + 1,
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
);
Ok(())
}
}
#[cfg(test)]
@@ -780,7 +759,7 @@ mod tests {
],
"from": 0,
}
}
}
}))
.unwrap();
@@ -909,7 +888,7 @@ mod tests {
"mixed.*",
],
}
}
}
}))?;
let collector = AggregationCollector::from_aggs(d, Default::default());

View File

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
mod buf_collector;
pub(crate) mod cached_sub_aggs;
mod collector;
mod date;
mod error;
@@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that hold per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows to have a single AggregationCollector instance per aggregation,
/// that can handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
pub type BucketId = u32;
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
@@ -335,19 +348,37 @@ impl Display for Key {
}
}
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
val as f64
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
{
i64::from_u64(val) as f64
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
f64::from_u64(val)
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
val as f64
} else {
panic!(
"ColumnType ID {} cannot be converted to f64 metric",
COLUMN_TYPE_ID
)
}
}
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
match field_type {
ColumnType::U64 => val as f64,
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
ColumnType::F64 => f64::from_u64(val),
ColumnType::Bool => val as f64,
_ => {
panic!("unexpected type {field_type:?}. This should not happen")
}
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
_ => panic!("unexpected type {field_type:?}. This should not happen"),
}
}

View File

@@ -8,25 +8,67 @@ use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;
/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
/// Get the next BucketId.
pub fn next_bucket_id(&mut self) -> BucketId {
let bucket_id = self.0;
self.0 += 1;
bucket_id
}
}
/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: CollectorClone + Debug {
pub trait SegmentAggregationCollector: Debug {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()>;
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>;
fn collect_block(
/// Collect docs for multiple buckets in one call.
/// Minimizes dynamic dispatch overhead when collecting many buckets.
///
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect_multiple(
&mut self,
bucket_ids: &[BucketId],
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(bucket_ids.len(), docs.len());
let mut start = 0;
while start < bucket_ids.len() {
let bucket_id = bucket_ids[start];
let mut end = start + 1;
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
end += 1;
}
self.collect(bucket_id, &docs[start..end], agg_data)?;
start = end;
}
Ok(())
}
/// Prepare the collector for collecting up to BucketId `max_bucket`.
/// This is useful so we can split allocation ahead of time of collecting.
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()>;
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
@@ -36,26 +78,7 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
}
}
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
pub trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn SegmentAggregationCollector> {
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
self.clone_box()
}
}
#[derive(Clone, Default)]
#[derive(Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead.
@@ -73,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
for agg in self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results)?;
for agg in &mut self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
}
Ok(())
@@ -86,23 +110,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)?;
Ok(())
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.collect_block(docs, agg_data)?;
collector.collect(parent_bucket_id, docs, agg_data)?;
}
Ok(())
}
@@ -112,4 +126,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.prepare_max_bucket(max_bucket, agg_data)?;
}
Ok(())
}
}

View File

@@ -821,6 +821,7 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::seq::SliceRandom;
use rand::thread_rng;
use test::Bencher;

View File

@@ -96,10 +96,11 @@ mod histogram_collector;
pub use histogram_collector::HistogramCollector;
mod multi_collector;
pub use columnar::ComparableDoc;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_collector;
pub use self::top_collector::ComparableDoc;
mod top_score_collector;
pub use self::top_score_collector::{TopDocs, TopNComputer};

View File

@@ -281,6 +281,7 @@ impl SegmentCollector for MultiCollectorChild {
#[cfg(test)]
mod tests {
use super::*;
use crate::collector::{Count, TopDocs};
use crate::query::TermQuery;

View File

@@ -1,12 +1,10 @@
mod order;
mod sort_by_erased_type;
mod sort_by_score;
mod sort_by_static_fast_value;
mod sort_by_string;
mod sort_key_computer;
pub use order::*;
pub use sort_by_erased_type::SortByErasedType;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString;
@@ -17,14 +15,11 @@ mod tests {
use std::collections::HashMap;
use std::ops::Range;
use crate::collector::sort_key::{
Comparator, NaturalComparator, ReverseComparator, SortByErasedType, SortBySimilarityScore,
SortByStaticFastValue, SortByString,
};
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, QueryParser};
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
use crate::schema::{Schema, FAST, TEXT};
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
fn make_index() -> crate::Result<Index> {
@@ -299,121 +294,32 @@ mod tests {
(SortBySimilarityScore, score_order),
(SortByString::for_field("city"), city_order),
));
let results: Vec<((Score, Option<String>), DocAddress)> =
searcher.search(&AllQuery, &top_collector)?;
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
}
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, Some("austin".to_owned())), 0),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("tokyo".to_owned())), 2),
((1.0, None), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, Some("tokyo".to_owned())), 2),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("austin".to_owned())), 0),
((1.0, None), 3),
]
);
Ok(())
}
#[test]
fn test_order_by_score_then_owned_value() -> crate::Result<()> {
let index = make_index()?;
type SortKey = (Score, OwnedValue);
fn query(
index: &Index,
score_order: Order,
city_order: Order,
) -> crate::Result<Vec<(SortKey, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>((
(SortBySimilarityScore, score_order),
(SortByErasedType::for_field("city"), city_order),
));
let results: Vec<((Score, OwnedValue), DocAddress)> =
searcher.search(&AllQuery, &top_collector)?;
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
}
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, OwnedValue::Str("austin".to_owned())), 0),
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
((1.0, OwnedValue::Null), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
((1.0, OwnedValue::Str("austin".to_owned())), 0),
((1.0, OwnedValue::Null), 3),
]
);
Ok(())
}
#[test]
fn test_order_by_compound_fast_fields() -> crate::Result<()> {
let index = make_index()?;
type CompoundSortKey = (Option<String>, Option<f64>);
fn assert_query(
index: &Index,
city_order: Order,
altitude_order: Order,
expected: Vec<(CompoundSortKey, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((
(SortByString::for_field("city"), city_order),
(
SortByStaticFastValue::<f64>::for_field("altitude"),
altitude_order,
),
));
let actual = searcher
Ok(searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(key, doc)| (key, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
.map(|(f, doc)| (f, ids[&doc]))
.collect())
}
assert_query(
&index,
Order::Asc,
Order::Desc,
vec![
((Some("austin".to_owned()), Some(149.0)), 0),
((Some("greenville".to_owned()), Some(27.0)), 1),
((Some("tokyo".to_owned()), Some(40.0)), 2),
((None, Some(0.0)), 3),
],
)?;
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, Some("austin".to_owned())), 0),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("tokyo".to_owned())), 2),
((1.0, None), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, Some("tokyo".to_owned())), 2),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("austin".to_owned())), 0),
((1.0, None), 3),
]
);
Ok(())
}
@@ -466,14 +372,15 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
if order.is_desc() {
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
} else {
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
}
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
@@ -483,197 +390,4 @@ mod tests {
);
}
}
proptest! {
#[test]
fn test_order_by_compound_prop(
city_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
altitude_order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..20_usize,
offset in 0..20_usize,
segments_data in proptest::collection::vec(
proptest::collection::vec(
(proptest::option::of("[a-c]"), proptest::option::of(0..50u64)),
1..10_usize // segment size
),
1..4_usize // num segments
)
) {
use crate::collector::sort_key::ComparatorEnum;
use crate::TantivyDocument;
let mut schema_builder = Schema::builder();
let city = schema_builder.add_text_field("city", TEXT | FAST);
let altitude = schema_builder.add_u64_field("altitude", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
for segment_data in segments_data.into_iter() {
for (city_val, altitude_val) in segment_data.into_iter() {
let mut doc = TantivyDocument::default();
if let Some(c) = city_val {
doc.add_text(city, c);
}
if let Some(a) = altitude_val {
doc.add_u64(altitude, a);
}
index_writer.add_document(doc).unwrap();
}
index_writer.commit().unwrap();
}
let searcher = index.reader().unwrap().searcher();
let top_collector = TopDocs::with_limit(limit)
.and_offset(offset)
.order_by((
(SortByString::for_field("city"), city_order),
(
SortByStaticFastValue::<u64>::for_field("altitude"),
altitude_order,
),
));
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
let actual_doc_ids: Vec<DocAddress> =
actual_results.into_iter().map(|(_, doc)| doc).collect();
// Verification logic
let all_docs_collector = DocSetCollector;
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
let docs_with_keys: Vec<((Option<String>, Option<u64>), DocAddress)> = all_docs
.into_iter()
.map(|doc_addr| {
let reader = searcher.segment_reader(doc_addr.segment_ord);
let city_val = if let Some(col) = reader.fast_fields().str("city").unwrap() {
let ord = col.ords().first(doc_addr.doc_id);
if let Some(ord) = ord {
let mut out = Vec::new();
col.dictionary().ord_to_term(ord, &mut out).unwrap();
String::from_utf8(out).ok()
} else {
None
}
} else {
None
};
let alt_val = if let Some((col, _)) = reader.fast_fields().u64_lenient("altitude").unwrap() {
col.first(doc_addr.doc_id)
} else {
None
};
((city_val, alt_val), doc_addr)
})
.collect();
let city_comparator = ComparatorEnum::from(city_order);
let alt_comparator = ComparatorEnum::from(altitude_order);
let comparator = (city_comparator, alt_comparator);
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
.into_iter()
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
.collect();
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
let expected_results = comparable_docs
.into_iter()
.skip(offset)
.take(limit)
.collect::<Vec<_>>();
let expected_doc_ids: Vec<DocAddress> =
expected_results.into_iter().map(|cd| cd.doc).collect();
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
}
}
proptest! {
#[test]
fn test_order_by_u64_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..20_usize,
offset in 0..20_usize,
segments_data in proptest::collection::vec(
proptest::collection::vec(
proptest::option::of(0..100u64),
1..1000_usize // segment size
),
1..4_usize // num segments
)
) {
use crate::collector::sort_key::ComparatorEnum;
use crate::TantivyDocument;
let mut schema_builder = Schema::builder();
let field = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
for segment_data in segments_data.into_iter() {
for val in segment_data.into_iter() {
let mut doc = TantivyDocument::default();
if let Some(v) = val {
doc.add_u64(field, v);
}
index_writer.add_document(doc).unwrap();
}
index_writer.commit().unwrap();
}
let searcher = index.reader().unwrap().searcher();
let top_collector = TopDocs::with_limit(limit)
.and_offset(offset)
.order_by((SortByStaticFastValue::<u64>::for_field("field"), order));
let actual_results = searcher.search(&AllQuery, &top_collector).unwrap();
let actual_doc_ids: Vec<DocAddress> =
actual_results.into_iter().map(|(_, doc)| doc).collect();
// Verification logic
let all_docs_collector = DocSetCollector;
let all_docs = searcher.search(&AllQuery, &all_docs_collector).unwrap();
let docs_with_keys: Vec<(Option<u64>, DocAddress)> = all_docs
.into_iter()
.map(|doc_addr| {
let reader = searcher.segment_reader(doc_addr.segment_ord);
let val = if let Some((col, _)) = reader.fast_fields().u64_lenient("field").unwrap() {
col.first(doc_addr.doc_id)
} else {
None
};
(val, doc_addr)
})
.collect();
let comparator = ComparatorEnum::from(order);
let mut comparable_docs: Vec<ComparableDoc<_, _>> = docs_with_keys
.into_iter()
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
.collect();
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
let expected_results = comparable_docs
.into_iter()
.skip(offset)
.take(limit)
.collect::<Vec<_>>();
let expected_doc_ids: Vec<DocAddress> =
expected_results.into_iter().map(|cd| cd.doc).collect();
prop_assert_eq!(actual_doc_ids, expected_doc_ids);
}
}
}

View File

@@ -1,103 +1,19 @@
use std::cmp::Ordering;
use columnar::{MonotonicallyMappableToU64, ValueRange};
use serde::{Deserialize, Serialize};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::{OwnedValue, Schema};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Order, Score};
fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
match (lhs, rhs) {
(OwnedValue::Null, OwnedValue::Null) => Ordering::Equal,
(OwnedValue::Null, _) => {
if NULLS_FIRST {
Ordering::Less
} else {
Ordering::Greater
}
}
(_, OwnedValue::Null) => {
if NULLS_FIRST {
Ordering::Greater
} else {
Ordering::Less
}
}
(OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b),
(OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b),
(OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b),
(OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b),
(OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()),
(OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b),
(OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b),
(OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b),
(OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b),
(OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b),
(OwnedValue::U64(a), OwnedValue::I64(b)) => {
if *b < 0 {
Ordering::Greater
} else {
a.cmp(&(*b as u64))
}
}
(OwnedValue::I64(a), OwnedValue::U64(b)) => {
if *a < 0 {
Ordering::Less
} else {
(*a as u64).cmp(b)
}
}
(OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
(OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
(OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
(OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
(a, b) => {
let ord = a.discriminant_value().cmp(&b.discriminant_value());
// If the discriminant is equal, it's because a new type was added, but hasn't been
// included in this `match` statement.
assert!(
ord != Ordering::Equal,
"Unimplemented comparison for type of {a:?}, {b:?}"
);
ord
}
}
}
/// Comparator trait defining the order in which documents should be ordered.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values.
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
/// Return the order between two ComparableDoc values, using the semantics which are
/// implemented by TopNComputer.
#[inline(always)]
fn compare_doc<D: Ord>(
&self,
lhs: &ComparableDoc<T, D>,
rhs: &ComparableDoc<T, D>,
) -> Ordering {
// TopNComputer sorts in descending order of the SortKey by default: we apply that ordering
// here to ease comparison in testing.
self.compare(&rhs.sort_key, &lhs.sort_key).then_with(|| {
// In case of a tie on the sort key, we always sort by ascending `DocAddress` in order
// to ensure a stable sorting of the documents, regardless of the sort key's order.
// See the TopNComputer docs for more information.
lhs.doc.cmp(&rhs.doc)
})
}
/// Return a `ValueRange` that matches all values that are greater than the provided threshold.
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T>;
}
/// Compare values naturally (e.g. 1 < 2).
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first).
///
/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value,
/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`).
/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;
@@ -106,116 +22,32 @@ impl<T: PartialOrd> Comparator<T> for NaturalComparator {
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap()
}
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
ValueRange::GreaterThan(threshold, false)
}
}
/// A (partial) implementation of comparison for OwnedValue.
/// Sorts document in reverse order.
///
/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with
/// mismatched types. The one exception is Null, for which we do define all comparisons.
impl Comparator<OwnedValue> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::GreaterThan(threshold, false)
}
}
/// Compare values in reverse (e.g. 2 < 1).
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first).
///
/// `None` is considered smaller than `Some` in the underlying comparator, but because the
/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting
/// Ascending sort (e.g. `[None, Some(10), Some(20)]`).
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order.
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;
macro_rules! impl_reverse_comparator_primitive {
($($t:ty),*) => {
$(
impl Comparator<$t> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &$t, rhs: &$t) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: $t) -> ValueRange<$t> {
ValueRange::LessThan(threshold, true)
}
}
)*
}
}
impl_reverse_comparator_primitive!(
bool,
u8,
u16,
u32,
u64,
u128,
usize,
i8,
i16,
i32,
i64,
i128,
isize,
f32,
f64,
String,
crate::DateTime,
Vec<u8>,
crate::schema::Facet
);
impl<T: PartialOrd + Send + Sync + std::fmt::Debug + Clone + 'static> Comparator<Option<T>>
for ReverseComparator
impl<T> Comparator<T> for ReverseComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs: &Option<T>, rhs: &Option<T>) -> Ordering {
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
let is_some = threshold.is_some();
ValueRange::LessThan(threshold, is_some)
}
}
impl Comparator<OwnedValue> for ReverseComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
let is_not_null = !matches!(threshold, OwnedValue::Null);
ValueRange::LessThan(threshold, is_not_null)
}
}
/// Compare values in reverse, but treating `None` as lower than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first), but with `None` values appearing last
/// (e.g. `[Some(10), Some(20), None]`).
/// Sorts document in reverse order, but considers None as having the lowest value.
///
/// This is usually what is wanted when sorting by a field in an ascending order.
/// For instance, in an e-commerce website, if sorting by price ascending,
/// the cheapest items would appear first, and items without a price would appear last.
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price at last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;
@@ -231,14 +63,6 @@ where ReverseComparator: Comparator<T>
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
if threshold.is_some() {
ValueRange::LessThan(threshold, false)
} else {
ValueRange::GreaterThan(threshold, false)
}
}
}
impl Comparator<u32> for ReverseNoneIsLowerComparator {
@@ -246,10 +70,6 @@ impl Comparator<u32> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<u64> for ReverseNoneIsLowerComparator {
@@ -257,10 +77,6 @@ impl Comparator<u64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<f64> for ReverseNoneIsLowerComparator {
@@ -268,10 +84,6 @@ impl Comparator<f64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<f32> for ReverseNoneIsLowerComparator {
@@ -279,10 +91,6 @@ impl Comparator<f32> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<i64> for ReverseNoneIsLowerComparator {
@@ -290,10 +98,6 @@ impl Comparator<i64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<String> for ReverseNoneIsLowerComparator {
@@ -301,129 +105,6 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
ValueRange::LessThan(threshold, false)
}
}
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::LessThan(threshold, false)
}
}
/// Compare values naturally, but treating `None` as higher than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first), but with `None` values appearing first
/// (e.g. `[None, Some(20), Some(10)]`).
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalNoneIsHigherComparator;
impl<T> Comparator<Option<T>> for NaturalNoneIsHigherComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
match (lhs_opt, rhs_opt) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Greater,
(Some(_), None) => Ordering::Less,
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
if threshold.is_some() {
let is_some = threshold.is_some();
ValueRange::GreaterThan(threshold, is_some)
} else {
ValueRange::LessThan(threshold, false)
}
}
}
impl Comparator<u32> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<u64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<f64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<f32> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<i64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<String> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
}
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::GreaterThan(threshold, true)
}
}
/// An enum representing the different sort orders.
@@ -434,10 +115,8 @@ pub enum ComparatorEnum {
Natural,
/// Reverse order (See [ReverseComparator])
Reverse,
/// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator])
/// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
ReverseNoneLower,
/// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator])
NaturalNoneHigher,
}
impl From<Order> for ComparatorEnum {
@@ -454,7 +133,6 @@ where
ReverseNoneIsLowerComparator: Comparator<T>,
NaturalComparator: Comparator<T>,
ReverseComparator: Comparator<T>,
NaturalNoneIsHigherComparator: Comparator<T>,
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
@@ -462,20 +140,6 @@ where
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
}
}
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
match self {
ComparatorEnum::Natural => NaturalComparator.threshold_to_valuerange(threshold),
ComparatorEnum::Reverse => ReverseComparator.threshold_to_valuerange(threshold),
ComparatorEnum::ReverseNoneLower => {
ReverseNoneIsLowerComparator.threshold_to_valuerange(threshold)
}
ComparatorEnum::NaturalNoneHigher => {
NaturalNoneIsHigherComparator.threshold_to_valuerange(threshold)
}
}
}
}
@@ -492,10 +156,6 @@ where
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
}
fn threshold_to_valuerange(&self, threshold: (Head, Tail)) -> ValueRange<(Head, Tail)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
@@ -512,13 +172,6 @@ where
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, (Type2, Type3)),
) -> ValueRange<(Type1, (Type2, Type3))> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
@@ -535,13 +188,6 @@ where
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, Type2, Type3),
) -> ValueRange<(Type1, Type2, Type3)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
@@ -565,13 +211,6 @@ where
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, (Type2, (Type3, Type4))),
) -> ValueRange<(Type1, (Type2, (Type3, Type4)))> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
@@ -595,13 +234,6 @@ where
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
}
fn threshold_to_valuerange(
&self,
threshold: (Type1, Type2, Type3, Type4),
) -> ValueRange<(Type1, Type2, Type3, Type4)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
@@ -690,33 +322,16 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComput
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + Clone + 'static + Sync + Send,
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
{
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
type SegmentComparator = TComparator;
type Buffer = TSegmentSortKeyComputer::Buffer;
fn segment_comparator(&self) -> Self::SegmentComparator {
self.comparator.clone()
}
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.segment_sort_key_computer
.segment_sort_keys(input_docs, output, buffer, filter)
}
#[inline(always)]
fn compare_segment_sort_key(
&self,
@@ -731,55 +346,3 @@ where
.convert_segment_sort_key(sort_key)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::OwnedValue;
#[test]
fn test_mixed_ownedvalue_compare() {
let u = OwnedValue::U64(10);
let i = OwnedValue::I64(10);
let f = OwnedValue::F64(10.0);
let nc = NaturalComparator::default();
assert_eq!(nc.compare(&u, &i), Ordering::Equal);
assert_eq!(nc.compare(&u, &f), Ordering::Equal);
assert_eq!(nc.compare(&i, &f), Ordering::Equal);
let u2 = OwnedValue::U64(11);
assert_eq!(nc.compare(&u2, &f), Ordering::Greater);
let s = OwnedValue::Str("a".to_string());
// Str < U64
assert_eq!(nc.compare(&s, &u), Ordering::Less);
// Str < I64
assert_eq!(nc.compare(&s, &i), Ordering::Less);
// Str < F64
assert_eq!(nc.compare(&s, &f), Ordering::Less);
}
#[test]
fn test_natural_none_is_higher() {
let comp = NaturalNoneIsHigherComparator;
let null = OwnedValue::Null;
let v1 = OwnedValue::U64(1);
let v2 = OwnedValue::U64(2);
// NaturalNoneIsGreaterComparator logic:
// 1. Delegates to NaturalComparator for non-nulls.
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
// 2. Treats None (Null) as Greater than any value.
// compare(Null, 2) should be Greater.
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
// compare(1, Null) should be Less.
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
// compare(Null, Null) should be Equal.
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
}
}

View File

@@ -1,410 +0,0 @@
use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange};
use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer;
use crate::collector::sort_key::{
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};
/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
///
/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders
/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to
/// use a SortKeyComputer implementation with a known-type at compile time.
#[derive(Debug, Clone)]
pub enum SortByErasedType {
/// Sort by a fast field
Field(String),
/// Sort by score
Score,
}
impl SortByErasedType {
/// Creates a new sort key computer which will sort by the given fast field column, with type
/// erasure.
pub fn for_field(column_name: impl ToString) -> Self {
Self::Field(column_name.to_string())
}
/// Creates a new sort key computer which will sort by score, with type erasure.
pub fn for_score() -> Self {
Self::Score
}
}
trait ErasedSegmentSortKeyComputer: Send + Sync {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
filter: ValueRange<Option<u64>>,
);
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}
struct ErasedSegmentSortKeyComputerWrapper<C, F>
where
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
inner: C,
converter: F,
buffer: C::Buffer,
}
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
where
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
filter: ValueRange<Option<u64>>,
) {
self.inner
.segment_sort_keys(input_docs, output, &mut self.buffer, filter)
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let val = self.inner.convert_segment_sort_key(sort_key);
(self.converter)(val)
}
}
struct ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScoreSegmentComputer,
}
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into();
Some(score_value.to_u64())
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
_filter: ValueRange<Option<u64>>,
) {
unimplemented!("Batch computation not supported for score sorting")
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
OwnedValue::F64(f64::from_u64(score_value))
}
}
impl SortKeyComputer for SortByErasedType {
type SortKey = OwnedValue;
type Child = ErasedColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn requires_scoring(&self) -> bool {
matches!(self, Self::Score)
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
Self::Field(column_name) => {
let fast_fields = segment_reader.fast_fields();
// TODO: We currently double-open the column to avoid relying on the implementation
// details of `SortByString` or `SortByStaticFastValue`. Once
// https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should
// consider directly constructing the appropriate `SegmentSortKeyComputer` type for
// the column that we open here.
let (_column, column_type) =
fast_fields.u64_lenient(column_name)?.ok_or_else(|| {
FastFieldNotAvailableError {
field_name: column_name.to_owned(),
}
})?;
match column_type {
ColumnType::Str => {
let computer = SortByString::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<String>| {
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::U64 => {
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<u64>| {
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::I64 => {
let computer = SortByStaticFastValue::<i64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<i64>| {
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::F64 => {
let computer = SortByStaticFastValue::<f64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<f64>| {
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::Bool => {
let computer = SortByStaticFastValue::<bool>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<bool>| {
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::DateTime => {
let computer = SortByStaticFastValue::<DateTime>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<DateTime>| {
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
column_type => {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {column_type:?}, which is not supported for \
sorting by owned value yet.",
column_name
)))
}
}
}
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore
.segment_sort_key_computer(segment_reader)?,
}),
};
Ok(ErasedColumnSegmentSortKeyComputer { inner })
}
}
pub struct ErasedColumnSegmentSortKeyComputer {
inner: Box<dyn ErasedSegmentSortKeyComputer>,
}
impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
type SortKey = OwnedValue;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.inner.segment_sort_keys(input_docs, output, filter)
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
self.inner.convert_segment_sort_key(segment_sort_key)
}
}
#[cfg(test)]
mod tests {
use crate::collector::sort_key::{ComparatorEnum, SortByErasedType};
use crate::collector::TopDocs;
use crate::query::AllQuery;
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
use crate::Index;
#[test]
fn test_sort_by_owned_u64() {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(id_field => 10u64)).unwrap();
writer.add_document(doc!(id_field => 2u64)).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null]
);
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("id"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null]
);
}
#[test]
fn test_sort_by_owned_string() {
let mut schema_builder = Schema::builder();
let city_field = schema_builder.add_text_field("city", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(city_field => "tokyo")).unwrap();
writer.add_document(doc!(city_field => "austin")).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("city"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Str("austin".to_string()),
OwnedValue::Str("tokyo".to_string()),
OwnedValue::Null
]
);
}
#[test]
fn test_sort_by_owned_reverse() {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(id_field => 10u64)).unwrap();
writer.add_document(doc!(id_field => 2u64)).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)]
);
}
#[test]
fn test_sort_by_owned_score() {
let mut schema_builder = Schema::builder();
let body_field = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(body_field => "a a")).unwrap();
writer.add_document(doc!(body_field => "a")).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]);
let query = query_parser.parse_query("a").unwrap();
// Sort by score descending (Natural)
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_score(), ComparatorEnum::Natural));
let top_docs = searcher.search(&query, &collector).unwrap();
let values: Vec<f64> = top_docs
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {:?}", key),
})
.collect();
assert_eq!(values.len(), 2);
assert!(values[0] > values[1]);
// Sort by score ascending (ReverseNoneLower)
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_score(),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&query, &collector).unwrap();
let values: Vec<f64> = top_docs
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {:?}", key),
})
.collect();
assert_eq!(values.len(), 2);
assert!(values[0] < values[1]);
}
}

View File

@@ -1,7 +1,5 @@
use columnar::ValueRange;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
@@ -11,7 +9,7 @@ pub struct SortBySimilarityScore;
impl SortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type Child = SortBySimilarityScoreSegmentComputer;
type Child = SortBySimilarityScore;
type Comparator = NaturalComparator;
@@ -23,7 +21,7 @@ impl SortKeyComputer for SortBySimilarityScore {
&self,
_segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScoreSegmentComputer)
Ok(SortBySimilarityScore)
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
@@ -63,29 +61,16 @@ impl SortKeyComputer for SortBySimilarityScore {
}
}
pub struct SortBySimilarityScoreSegmentComputer;
impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
impl SegmentSortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type SegmentSortKey = Score;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
score
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
unimplemented!("Batch computation not supported for score sorting")
}
fn convert_segment_sort_key(&self, score: Score) -> Score {
score
}

View File

@@ -1,10 +1,9 @@
use std::marker::PhantomData;
use columnar::{Column, ValueRange};
use columnar::Column;
use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::{DocId, Score, SegmentReader};
@@ -35,7 +34,9 @@ impl<T: FastValue> SortByStaticFastValue<T> {
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
type Child = SortByFastValueSegmentSortKeyComputer<T>;
type SortKey = Option<T>;
type Comparator = NaturalComparator;
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
@@ -83,112 +84,15 @@ pub struct SortByFastValueSegmentSortKeyComputer<T> {
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
type SortKey = Option<T>;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
self.sort_column.first(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
self.sort_column
.first_vals_in_value_range(input_docs, output, u64_filter);
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key.map(T::from_u64)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{Schema, FAST};
use crate::Index;
#[test]
fn test_sort_by_fast_value_batch() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => 10u64))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => 20u64))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(10), Some(20), None]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
}
#[test]
fn test_sort_by_fast_value_batch_with_filter() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => 10u64))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => 20u64))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(
&mut docs,
&mut output,
&mut buffer,
ValueRange::GreaterThan(Some(15u64), false /* inclusive */),
);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(20)]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
}
}

View File

@@ -1,10 +1,7 @@
use columnar::{StrColumn, ValueRange};
use columnar::StrColumn;
use crate::collector::sort_key::sort_key_computer::{
convert_optional_u64_range_to_u64_range, range_contains_none,
};
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
@@ -33,7 +30,9 @@ impl SortByString {
impl SortKeyComputer for SortByString {
type SortKey = Option<String>;
type Child = ByStringColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
@@ -51,9 +50,8 @@ pub struct ByStringColumnSegmentSortKeyComputer {
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
@@ -61,31 +59,7 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
str_column.ords().first(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
if let Some(str_column) = &self.str_column_opt {
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
str_column
.ords()
.first_vals_in_value_range(input_docs, output, u64_filter);
} else if range_contains_none(&filter) {
for &doc in input_docs {
output.push(ComparableDoc {
doc,
sort_key: None,
});
}
}
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
let term_ord = term_ord_opt?;
let str_column = self.str_column_opt.as_ref()?;
let mut bytes = Vec::new();
@@ -96,90 +70,3 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
String::try_from(bytes).ok()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{Schema, FAST, TEXT};
use crate::Index;
#[test]
fn test_sort_by_string_batch() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => "a"))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => "c"))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(0), Some(1), None]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
}
#[test]
fn test_sort_by_string_batch_with_filter() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => "a"))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => "c"))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
// Filter: > "b". "a" is 0, "c" is 1.
// We want > "a" (ord 0). So we filter > ord 0.
// 0 is "a", 1 is "c".
let mut buffer = ();
computer.segment_sort_keys(
&mut docs,
&mut output,
&mut buffer,
ValueRange::GreaterThan(Some(0), false /* inclusive */),
);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(1)]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
}
}

View File

@@ -1,12 +1,8 @@
use std::cmp::Ordering;
use columnar::ValueRange;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{
default_collect_segment_impl, ComparableDoc, SegmentCollector as _, TopNComputer,
};
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
@@ -16,40 +12,17 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader};
/// It is the segment local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static {
/// The final score being emitted.
type SortKey: 'static + Send + Sync + Clone;
type SortKey: 'static + PartialOrd + Send + Sync + Clone;
/// Sort key used by at the segment level by the `SegmentSortKeyComputer`.
///
/// It is typically small like a `u64`, and is meant to be converted
/// to the final score at the end of the collection of the segment.
type SegmentSortKey: 'static + Clone + Send + Sync + Clone;
/// Comparator type.
type SegmentComparator: Comparator<Self::SegmentSortKey> + Clone + 'static;
/// Buffer type used for scratch space.
type Buffer: Default + Send + Sync + 'static;
/// Returns the segment sort key comparator.
fn segment_comparator(&self) -> Self::SegmentComparator {
Self::SegmentComparator::default()
}
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;
/// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
/// Computes the sort keys for a batch of documents.
///
/// The computed sort keys and document IDs are pushed into the `output` vector.
/// The `buffer` is used for scratch space.
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
);
/// Computes the sort key and pushes the document in a TopN Computer.
///
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
@@ -58,32 +31,12 @@ pub trait SegmentSortKeyComputer: 'static {
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
let sort_key = self.segment_sort_key(doc, score);
top_n_computer.push(sort_key, doc);
}
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
// The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
// should always be able to `reserve` space for the entire block.
top_n_computer.reserve(docs.len());
let comparator = self.segment_comparator();
let value_range = if let Some(threshold) = &top_n_computer.threshold {
comparator.threshold_to_valuerange(threshold.clone())
} else {
ValueRange::All
};
let (buffer, scratch) = top_n_computer.buffer_and_scratch();
self.segment_sort_keys(docs, buffer, scratch, value_range);
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
@@ -94,7 +47,27 @@ pub trait SegmentSortKeyComputer: 'static {
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.segment_comparator().compare(left, right)
NaturalComparator.compare(left, right)
}
/// Implementing this method makes it possible to avoid computing
/// a sort_key entirely if we can assess that it won't pass a threshold
/// with a partial computation.
///
/// This is currently used for lexicographic sorting.
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let sort_key = self.segment_sort_key(doc_id, score);
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
if cmp == Ordering::Less {
None
} else {
Some((cmp, sort_key))
}
}
/// Convert a segment level sort key into the global sort key.
@@ -108,7 +81,7 @@ pub trait SegmentSortKeyComputer: 'static {
/// the sort key at a segment scale.
pub trait SortKeyComputer: Sync {
/// The sort key type.
type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug;
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
/// Type of the associated [`SegmentSortKeyComputer`].
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
/// Comparator type.
@@ -163,9 +136,11 @@ where
HeadSortKeyComputer: SortKeyComputer,
TailSortKeyComputer: SortKeyComputer,
{
type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
type Child =
ChainSegmentSortKeyComputer<HeadSortKeyComputer::Child, TailSortKeyComputer::Child>;
type SortKey = (
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type Comparator = (
HeadSortKeyComputer::Comparator,
@@ -177,10 +152,10 @@ where
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok(ChainSegmentSortKeyComputer {
head: self.0.segment_sort_key_computer(segment_reader)?,
tail: self.1.segment_sort_key_computer(segment_reader)?,
})
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
))
}
/// Checks whether the schema is compatible with the sort key computer.
@@ -198,91 +173,20 @@ where
}
}
pub struct ChainSegmentSortKeyComputer<Head, Tail>
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
where
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
{
head: Head,
tail: Tail,
}
pub struct ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey> {
pub head: HeadBuffer,
pub tail: TailBuffer,
pub head_output: Vec<ComparableDoc<HeadKey, DocId>>,
pub tail_output: Vec<ComparableDoc<TailKey, DocId>>,
pub tail_input_docs: Vec<DocId>,
}
impl<HeadBuffer: Default, TailBuffer: Default, HeadKey, TailKey> Default
for ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey>
{
fn default() -> Self {
ChainBuffer {
head: HeadBuffer::default(),
tail: TailBuffer::default(),
head_output: Vec::new(),
tail_output: Vec::new(),
tail_input_docs: Vec::new(),
}
}
}
impl<Head, Tail> ChainSegmentSortKeyComputer<Head, Tail>
where
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &<Self as SegmentSortKeyComputer>::SegmentSortKey,
) -> Option<(Ordering, <Self as SegmentSortKeyComputer>::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let head_sort_key = self.head.segment_sort_key(doc_id, score);
let head_cmp = self
.head
.compare_segment_sort_key(&head_sort_key, head_threshold);
if head_cmp == Ordering::Less {
None
} else if head_cmp == Ordering::Equal {
let tail_sort_key = self.tail.segment_sort_key(doc_id, score);
let tail_cmp = self
.tail
.compare_segment_sort_key(&tail_sort_key, tail_threshold);
if tail_cmp == Ordering::Less {
None
} else {
Some((tail_cmp, (head_sort_key, tail_sort_key)))
}
} else {
let tail_sort_key = self.tail.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
}
impl<Head, Tail> SegmentSortKeyComputer for ChainSegmentSortKeyComputer<Head, Tail>
where
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
type SortKey = (Head::SortKey, Tail::SortKey);
type SegmentSortKey = (Head::SegmentSortKey, Tail::SegmentSortKey);
type SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator);
type Buffer =
ChainBuffer<Head::Buffer, Tail::Buffer, Head::SegmentSortKey, Tail::SegmentSortKey>;
fn segment_comparator(&self) -> Self::SegmentComparator {
(
self.head.segment_comparator(),
self.tail.segment_comparator(),
)
}
type SortKey = (
HeadSegmentSortKeyComputer::SortKey,
TailSegmentSortKeyComputer::SortKey,
);
type SegmentSortKey = (
HeadSegmentSortKeyComputer::SegmentSortKey,
TailSegmentSortKeyComputer::SegmentSortKey,
);
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
@@ -294,90 +198,9 @@ where
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.head
self.0
.compare_segment_sort_key(&left.0, &right.0)
.then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1))
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
let (head_filter, threshold) = match filter {
ValueRange::GreaterThan((head_threshold, tail_threshold), _)
| ValueRange::LessThan((head_threshold, tail_threshold), _) => {
let head_cmp = self.head.segment_comparator();
let strict_head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone());
let head_filter = match strict_head_filter {
ValueRange::GreaterThan(t, m) => ValueRange::GreaterThanOrEqual(t, m),
ValueRange::LessThan(t, m) => ValueRange::LessThanOrEqual(t, m),
other => other,
};
(head_filter, Some((head_threshold, tail_threshold)))
}
_ => (ValueRange::All, None),
};
buffer.head_output.clear();
self.head.segment_sort_keys(
input_docs,
&mut buffer.head_output,
&mut buffer.head,
head_filter,
);
if buffer.head_output.is_empty() {
return;
}
buffer.tail_output.clear();
buffer.tail_input_docs.clear();
for cd in &buffer.head_output {
buffer.tail_input_docs.push(cd.doc);
}
self.tail.segment_sort_keys(
&buffer.tail_input_docs,
&mut buffer.tail_output,
&mut buffer.tail,
ValueRange::All,
);
let head_cmp = self.head.segment_comparator();
let tail_cmp = self.tail.segment_comparator();
for (head_doc, tail_doc) in buffer
.head_output
.drain(..)
.zip(buffer.tail_output.drain(..))
{
debug_assert_eq!(head_doc.doc, tail_doc.doc);
let doc = head_doc.doc;
let head_key = head_doc.sort_key;
let tail_key = tail_doc.sort_key;
let accept = if let Some((head_threshold, tail_threshold)) = &threshold {
let head_ord = head_cmp.compare(&head_key, head_threshold);
let ord = if head_ord == Ordering::Equal {
tail_cmp.compare(&tail_key, tail_threshold)
} else {
head_ord
};
ord == Ordering::Greater
} else {
true
};
if accept {
output.push(ComparableDoc {
sort_key: (head_key, tail_key),
doc,
});
}
}
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
}
#[inline(always)]
@@ -385,7 +208,7 @@ where
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
let sort_key: Self::SegmentSortKey;
if let Some(threshold) = &top_n_computer.threshold {
@@ -402,56 +225,68 @@ where
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.head.segment_sort_key(doc, score);
let tail_sort_key = self.tail.segment_sort_key(doc, score);
let head_sort_key = self.0.segment_sort_key(doc, score);
let tail_sort_key = self.1.segment_sort_key(doc, score);
(head_sort_key, tail_sort_key)
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let (head_cmp, head_sort_key) =
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
if head_cmp == Ordering::Equal {
let (tail_cmp, tail_sort_key) =
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
Some((tail_cmp, (head_sort_key, tail_sort_key)))
} else {
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
let (head_sort_key, tail_sort_key) = sort_key;
(
self.head.convert_segment_sort_key(head_sort_key),
self.tail.convert_segment_sort_key(tail_sort_key),
self.0.convert_segment_sort_key(head_sort_key),
self.1.convert_segment_sort_key(tail_sort_key),
)
}
}
/// This struct is used as an adapter to take a sort key computer and map its score to another
/// new sort key.
pub struct MappedSegmentSortKeyComputer<T: SegmentSortKeyComputer, NewSortKey> {
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
sort_key_computer: T,
map: fn(T::SortKey) -> NewSortKey,
map: fn(PreviousSortKey) -> NewSortKey,
}
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, NewScore>
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
where
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync,
NewScore: 'static + Clone + Send + Sync,
PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
NewScore: 'static + Clone + Send + Sync + PartialOrd,
{
type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey;
type SegmentComparator = T::SegmentComparator;
type Buffer = T::Buffer;
fn segment_comparator(&self) -> Self::SegmentComparator {
self.sort_key_computer.segment_comparator()
}
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score)
}
fn segment_sort_keys(
fn accept_sort_key_lazy(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
self.sort_key_computer
.segment_sort_keys(input_docs, output, buffer, filter)
.accept_sort_key_lazy(doc_id, score, threshold)
}
#[inline(always)]
@@ -459,21 +294,12 @@ where
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
self.sort_key_computer
.compute_sort_key_and_collect(doc, score, top_n_computer);
}
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
self.sort_key_computer
.compute_sort_keys_and_collect(docs, top_n_computer);
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
(self.map)(
self.sort_key_computer
@@ -499,6 +325,10 @@ where
);
type Child = MappedSegmentSortKeyComputer<
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
),
Self::SortKey,
>;
@@ -522,13 +352,7 @@ where
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: ChainSegmentSortKeyComputer {
head: sort_key_computer1,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: sort_key_computer3,
},
},
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
map,
})
}
@@ -563,6 +387,13 @@ where
SortKeyComputer1,
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(
SortKeyComputer2::SortKey,
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
),
),
Self::SortKey,
>;
type SortKey = (
@@ -584,16 +415,10 @@ where
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: ChainSegmentSortKeyComputer {
head: sort_key_computer1,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer3,
tail: sort_key_computer4,
},
},
},
sort_key_computer: (
sort_key_computer1,
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
),
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
(sort_key1, sort_key2, sort_key3, sort_key4)
},
@@ -616,13 +441,6 @@ where
}
}
use std::marker::PhantomData;
pub struct FuncSegmentSortKeyComputer<F, TSortKey> {
func: F,
_phantom: PhantomData<TSortKey>,
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
@@ -630,44 +448,24 @@ where
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
type Child = FuncSegmentSortKeyComputer<SegmentF, TSortKey>;
type Child = SegmentF;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok(FuncSegmentSortKeyComputer {
func: (self)(segment_reader),
_phantom: PhantomData,
})
Ok((self)(segment_reader))
}
}
impl<F, TSortKey> SegmentSortKeyComputer for FuncSegmentSortKeyComputer<F, TSortKey>
impl<F, TSortKey> SegmentSortKeyComputer for F
where
F: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
type Buffer = ();
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self.func)(doc)
}
fn segment_sort_keys(
&mut self,
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
for &doc in input_docs {
output.push(ComparableDoc {
sort_key: (self.func)(doc),
doc,
});
}
(self)(doc)
}
/// Convert a segment level score into the global level score.
@@ -676,75 +474,13 @@ where
}
}
pub(crate) fn range_contains_none(range: &ValueRange<Option<u64>>) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&None),
ValueRange::GreaterThan(_threshold, match_nulls) => *match_nulls,
ValueRange::GreaterThanOrEqual(_threshold, match_nulls) => *match_nulls,
ValueRange::LessThan(_threshold, match_nulls) => *match_nulls,
ValueRange::LessThanOrEqual(_threshold, match_nulls) => *match_nulls,
}
}
pub(crate) fn convert_optional_u64_range_to_u64_range(
range: ValueRange<Option<u64>>,
) -> ValueRange<u64> {
match range {
ValueRange::Inclusive(r) => {
let start = r.start().unwrap_or(0);
let end = r.end().unwrap_or(u64::MAX);
ValueRange::Inclusive(start..=end)
}
ValueRange::GreaterThan(Some(val), match_nulls) => {
ValueRange::GreaterThan(val, match_nulls)
}
ValueRange::GreaterThan(None, match_nulls) => {
if match_nulls {
ValueRange::All
} else {
ValueRange::Inclusive(u64::MIN..=u64::MAX)
}
}
ValueRange::GreaterThanOrEqual(Some(val), match_nulls) => {
ValueRange::GreaterThanOrEqual(val, match_nulls)
}
ValueRange::GreaterThanOrEqual(None, match_nulls) => {
if match_nulls {
ValueRange::All
} else {
ValueRange::Inclusive(u64::MIN..=u64::MAX)
}
}
ValueRange::LessThan(None, match_nulls) => {
if match_nulls {
ValueRange::LessThan(u64::MIN, true)
} else {
ValueRange::Inclusive(1..=0)
}
}
ValueRange::LessThan(Some(val), match_nulls) => ValueRange::LessThan(val, match_nulls),
ValueRange::LessThanOrEqual(None, match_nulls) => {
if match_nulls {
ValueRange::LessThan(u64::MIN, true)
} else {
ValueRange::Inclusive(1..=0)
}
}
ValueRange::LessThanOrEqual(Some(val), match_nulls) => {
ValueRange::LessThanOrEqual(val, match_nulls)
}
ValueRange::All => ValueRange::All,
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Index, Order, SegmentReader};
@@ -892,178 +628,4 @@ mod tests {
(200u32, 2u32)
);
}
#[test]
fn test_batch_score_computer_edge_case() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let score_computer_secondary = |_segment_reader: &SegmentReader| |_doc: DocId| "b";
let lazy_score_computer = (score_computer_primary, score_computer_secondary);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let mut top_n_computer =
TopNComputer::new_with_comparator(10, lazy_score_computer.comparator());
// Threshold (200, "a"). Doc is (200, "b"). 200 == 200, "b" > "a". Should be accepted.
top_n_computer.threshold = Some((200, "a"));
let docs = vec![0];
segment_sort_key_computer.compute_sort_keys_and_collect(&docs, &mut top_n_computer);
let results = top_n_computer.into_sorted_vec();
assert_eq!(results.len(), 1);
let result = &results[0];
assert_eq!(result.doc, 0);
assert_eq!(result.sort_key, (200, "b"));
}
}
#[cfg(test)]
mod proptest_tests {
use proptest::prelude::*;
use super::*;
use crate::collector::sort_key::order::*;
// Re-implement logic to interpret ValueRange<Option<u64>> manually to verify expectations
fn range_contains_opt(range: &ValueRange<Option<u64>>, val: &Option<u64>) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(val),
ValueRange::GreaterThan(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val > t
}
}
ValueRange::GreaterThanOrEqual(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val >= t
}
}
ValueRange::LessThan(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val < t
}
}
ValueRange::LessThanOrEqual(t, match_nulls) => {
if val.is_none() {
*match_nulls
} else {
val <= t
}
}
}
}
fn range_contains_u64(range: &ValueRange<u64>, val: &u64) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(val),
ValueRange::GreaterThan(t, _) => val > t,
ValueRange::GreaterThanOrEqual(t, _) => val >= t,
ValueRange::LessThan(t, _) => val < t,
ValueRange::LessThanOrEqual(t, _) => val <= t,
}
}
proptest! {
#[test]
fn test_comparator_consistency_natural_none_is_lower(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<NaturalComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_reverse(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<ReverseComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_reverse_none_is_lower(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<ReverseNoneIsLowerComparator>(threshold, val)?;
}
#[test]
fn test_comparator_consistency_natural_none_is_higher(
threshold in any::<Option<u64>>(),
val in any::<Option<u64>>()
) {
check_comparator::<NaturalNoneIsHigherComparator>(threshold, val)?;
}
}
fn check_comparator<C: Comparator<Option<u64>>>(
threshold: Option<u64>,
val: Option<u64>,
) -> std::result::Result<(), proptest::test_runner::TestCaseError> {
let comparator = C::default();
let range = comparator.threshold_to_valuerange(threshold);
let ordering = comparator.compare(&val, &threshold);
let should_be_in_range = ordering == Ordering::Greater;
let in_range_opt = range_contains_opt(&range, &val);
prop_assert_eq!(
in_range_opt,
should_be_in_range,
"Comparator consistency failed for {:?}. Threshold: {:?}, Val: {:?}, Range: {:?}, \
Ordering: {:?}. range_contains_opt says {}, but compare says {}",
std::any::type_name::<C>(),
threshold,
val,
range,
ordering,
in_range_opt,
should_be_in_range
);
// Check range_contains_none
let expected_none_in_range = range_contains_opt(&range, &None);
let actual_none_in_range = range_contains_none(&range);
prop_assert_eq!(
actual_none_in_range,
expected_none_in_range,
"range_contains_none failed for {:?}. Range: {:?}. Expected (from \
range_contains_opt): {}, Actual: {}",
std::any::type_name::<C>(),
range,
expected_none_in_range,
actual_none_in_range
);
// Check convert_optional_u64_range_to_u64_range
let u64_range = convert_optional_u64_range_to_u64_range(range.clone());
if let Some(v) = val {
let in_u64_range = range_contains_u64(&u64_range, &v);
let in_opt_range = range_contains_opt(&range, &Some(v));
prop_assert_eq!(
in_u64_range,
in_opt_range,
"convert_optional_u64_range_to_u64_range failed for {:?}. Val: {:?}, OptRange: \
{:?}, U64Range: {:?}. Opt says {}, U64 says {}",
std::any::type_name::<C>(),
v,
range,
u64_range,
in_opt_range,
in_u64_range
);
}
Ok(())
}
}

View File

@@ -99,12 +99,7 @@ where
TSegmentSortKeyComputer: SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
{
pub(crate) topn_computer: TopNComputer<
TSegmentSortKeyComputer::SegmentSortKey,
DocId,
C,
TSegmentSortKeyComputer::Buffer,
>,
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
pub(crate) segment_ord: u32,
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
}
@@ -125,11 +120,6 @@ where
);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.segment_sort_key_computer
.compute_sort_keys_and_collect(docs, &mut self.topn_computer);
}
fn harvest(self) -> Self::Fruit {
let segment_ord = self.segment_ord;
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self

View File

@@ -0,0 +1,64 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// It guarantees stable sorting: in case of a tie on the feature, the document
/// address is used.
///
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
/// should be reversed, which is useful for achieving for example largest-first
/// semantics without having to wrap the feature in a `Reverse`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// The feature of the document. In practice, this is
/// is any type that implements `PartialOrd`.
pub sort_key: T,
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.sort_key
.partial_cmp(&other.sort_key)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
// In case of a tie on the feature, we sort by ascending
// `DocAddress` in order to ensure a stable sorting of the
// documents.
by_feature.then_with(lazy_by_doc_address)
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}

View File

@@ -2,7 +2,6 @@ use std::cmp::Ordering;
use std::fmt;
use std::ops::Range;
use columnar::ValueRange;
use serde::{Deserialize, Serialize};
use super::Collector;
@@ -11,7 +10,8 @@ use crate::collector::sort_key::{
SortByStaticFastValue, SortByString,
};
use crate::collector::sort_key_top_collector::TopBySortKeyCollector;
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastValue;
use crate::{DocAddress, DocId, Order, Score, SegmentReader};
@@ -23,9 +23,10 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader};
/// The theoretical complexity for collecting the top `K` out of `N` documents
/// is `O(N + K)`.
///
/// This collector guarantees a stable sorting in case of a tie on the
/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker.
/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`.
/// This collector does not guarantee a stable sorting in case of a tie on the
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
/// like docid in case of score equality.
/// Only then, it is suitable for pagination.
///
/// ```rust
/// use tantivy::collector::TopDocs;
@@ -324,7 +325,7 @@ impl TopDocs {
sort_key_computer: impl SortKeyComputer<SortKey = TSortKey> + Send + 'static,
) -> impl Collector<Fruit = Vec<(TSortKey, DocAddress)>>
where
TSortKey: 'static + Clone + Send + Sync + std::fmt::Debug,
TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug,
{
TopBySortKeyCollector::new(sort_key_computer, self.doc_range())
}
@@ -445,7 +446,7 @@ where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn,
TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey,
TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>:
SegmentSortKeyComputer<SortKey = TSortKey, SegmentSortKey = TSortKey>,
SegmentSortKeyComputer<SortKey = TSortKey>,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
@@ -480,23 +481,11 @@ where
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
type Buffer = ();
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey {
(self.sort_key_fn)(doc, score)
}
fn segment_sort_keys(
&mut self,
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) {
unimplemented!("Batch computation is not supported for tweak score.")
}
/// Convert a segment level score into the global level score.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key
@@ -511,23 +500,16 @@ where
///
/// For TopN == 0, it will be relative expensive.
///
/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress):
/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in
/// ascending order, regardless of the `Comparator` used for the `Score` type.
///
/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the
/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
/// When using the natural comparator, the top N computer returns the top N elements in
/// descending order, as expected for a top N.
#[derive(Serialize, Deserialize)]
#[serde(from = "TopNComputerDeser<Score, D, C>")]
pub struct TopNComputer<Score, D, C, Buffer = ()> {
pub struct TopNComputer<Score, D, C> {
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
buffer: Vec<ComparableDoc<Score, D>>,
top_n: usize,
pub(crate) threshold: Option<Score>,
comparator: C,
#[serde(skip)]
scratch: Buffer,
}
// Intermediate struct for TopNComputer for deserialization, to keep vec capacity
@@ -539,9 +521,7 @@ struct TopNComputerDeser<Score, D, C> {
comparator: C,
}
impl<Score, D, C, Buffer> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C, Buffer>
where Buffer: Default
{
impl<Score, D, C> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C> {
fn from(mut value: TopNComputerDeser<Score, D, C>) -> Self {
let expected_cap = value.top_n.max(1) * 2;
let current_cap = value.buffer.capacity();
@@ -556,15 +536,12 @@ where Buffer: Default
top_n: value.top_n,
threshold: value.threshold,
comparator: value.comparator,
scratch: Buffer::default(),
}
}
}
impl<Score: std::fmt::Debug, D, C, Buffer> std::fmt::Debug for TopNComputer<Score, D, C, Buffer>
where
C: Comparator<Score>,
Buffer: std::fmt::Debug,
impl<Score: std::fmt::Debug, D, C> std::fmt::Debug for TopNComputer<Score, D, C>
where C: Comparator<Score>
{
fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result {
f.debug_struct("TopNComputer")
@@ -572,13 +549,12 @@ where
.field("top_n", &self.top_n)
.field("current_threshold", &self.threshold)
.field("comparator", &self.comparator)
.field("scratch", &self.scratch)
.finish()
}
}
// Custom clone to keep capacity
impl<Score: Clone, D: Clone, C: Clone, Buffer: Clone> Clone for TopNComputer<Score, D, C, Buffer> {
impl<Score: Clone, D: Clone, C: Clone> Clone for TopNComputer<Score, D, C> {
fn clone(&self) -> Self {
let mut buffer_clone = Vec::with_capacity(self.buffer.capacity());
buffer_clone.extend(self.buffer.iter().cloned());
@@ -587,17 +563,15 @@ impl<Score: Clone, D: Clone, C: Clone, Buffer: Clone> Clone for TopNComputer<Sco
top_n: self.top_n,
threshold: self.threshold.clone(),
comparator: self.comparator.clone(),
scratch: self.scratch.clone(),
}
}
}
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator, ()>
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator>
where
D: Ord,
TSortKey: Clone,
NaturalComparator: Comparator<TSortKey>,
ReverseComparator: Comparator<TSortKey>,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
@@ -606,38 +580,30 @@ where
}
}
impl<TSortKey, D, C, Buffer> TopNComputer<TSortKey, D, C, Buffer>
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
where
D: Ord,
TSortKey: Clone,
C: Comparator<TSortKey>,
Buffer: Default,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `(top_n.max(1) * 2) +
/// COLLECT_BLOCK_BUFFER_LEN`.
/// Internally it will allocate a buffer of size `2 * top_n`.
pub fn new_with_comparator(top_n: usize, comparator: C) -> Self {
// We ensure that there is always enough space to include an entire block in the buffer if
// need be, so that `push_block_lazy` can avoid checking capacity inside its loop.
let vec_cap = (top_n.max(1) * 2) + crate::COLLECT_BLOCK_BUFFER_LEN;
let vec_cap = top_n.max(1) * 2;
TopNComputer {
buffer: Vec::with_capacity(vec_cap),
top_n,
threshold: None,
comparator,
scratch: Buffer::default(),
}
}
/// Push a new document to the top n.
/// If the document is below the current threshold, it will be ignored.
///
/// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order.
#[inline]
pub fn push(&mut self, sort_key: TSortKey, doc: D) {
if let Some(last_median) = &self.threshold {
// See the struct docs for an explanation of why this comparison is strict.
if self.comparator.compare(&sort_key, last_median) != Ordering::Greater {
if self.comparator.compare(&sort_key, last_median) == Ordering::Less {
return;
}
}
@@ -649,34 +615,24 @@ where
// At this point, we need to have established that the doc is above the threshold.
#[inline(always)]
pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) {
self.reserve(1);
// This cannot panic, because we've reserved room for one element.
let comparable_doc = ComparableDoc { doc, sort_key };
push_assuming_capacity(comparable_doc, &mut self.buffer);
}
// Ensure that there is capacity to push `additional` more elements without resizing.
#[inline(always)]
pub(crate) fn reserve(&mut self, additional: usize) {
if self.buffer.len() + additional > self.buffer.capacity() {
if self.buffer.len() == self.buffer.capacity() {
let median = self.truncate_top_n();
debug_assert!(self.buffer.len() + additional <= self.buffer.capacity());
self.threshold = Some(median);
}
}
pub(crate) fn buffer_and_scratch(
&mut self,
) -> (&mut Vec<ComparableDoc<TSortKey, D>>, &mut Buffer) {
(&mut self.buffer, &mut self.scratch)
// This cannot panic, because we truncate_median will at least remove one element, since
// the min capacity is 2.
let comparable_doc = ComparableDoc { doc, sort_key };
push_assuming_capacity(comparable_doc, &mut self.buffer);
}
#[inline(never)]
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
let (_, median_el, _) = self
.buffer
.select_nth_unstable_by(self.top_n, |lhs, rhs| self.comparator.compare_doc(lhs, rhs));
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
self.comparator
.compare(&rhs.sort_key, &lhs.sort_key)
.then_with(|| lhs.doc.cmp(&rhs.doc))
});
let median_score = median_el.sort_key.clone();
// Remove all elements below the top_n
@@ -690,8 +646,11 @@ where
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
self.buffer
.sort_unstable_by(|left, right| self.comparator.compare_doc(left, right));
self.buffer.sort_unstable_by(|left, right| {
self.comparator
.compare(&right.sort_key, &left.sort_key)
.then_with(|| left.doc.cmp(&right.doc))
});
self.buffer
}
@@ -710,7 +669,7 @@ where
//
// Panics if there is not enough capacity to add an element.
#[inline(always)]
pub fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
let prev_len = buf.len();
assert!(prev_len < buf.capacity());
// This is mimicking the current (non-stabilized) implementation in std.
@@ -727,10 +686,9 @@ mod tests {
use proptest::prelude::*;
use super::{TopDocs, TopNComputer};
use crate::collector::sort_key::{
Comparator, ComparatorEnum, NaturalComparator, ReverseComparator,
};
use crate::collector::{Collector, ComparableDoc, DocSetCollector};
use crate::collector::sort_key::{ComparatorEnum, NaturalComparator, ReverseComparator};
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{Collector, DocSetCollector};
use crate::query::{AllQuery, Query, QueryParser};
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
use crate::time::format_description::well_known::Rfc3339;
@@ -797,33 +755,6 @@ mod tests {
);
}
#[test]
fn test_topn_computer_duplicates() {
let mut computer: TopNComputer<u32, u32, NaturalComparator> =
TopNComputer::new_with_comparator(2, NaturalComparator);
computer.push(1u32, 1u32);
computer.push(1u32, 2u32);
computer.push(1u32, 3u32);
computer.push(1u32, 4u32);
computer.push(1u32, 5u32);
// In the presence of duplicates, DocIds are always ascending order.
assert_eq!(
computer.into_sorted_vec(),
&[
ComparableDoc {
sort_key: 1u32,
doc: 1u32,
},
ComparableDoc {
sort_key: 1u32,
doc: 2u32,
}
]
);
}
#[test]
fn test_topn_computer_no_panic() {
for top_n in 0..10 {
@@ -841,17 +772,14 @@ mod tests {
#[test]
fn test_topn_computer_asc_prop(
limit in 0..10_usize,
mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
) {
// NB: TopNComputer must receive inputs in ascending DocId order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator);
for (feature, doc) in &docs {
computer.push(*feature, *doc);
}
let mut comparable_docs =
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort();
comparable_docs.truncate(limit);
prop_assert_eq!(
computer.into_sorted_vec(),
@@ -1435,11 +1363,11 @@ mod tests {
#[test]
fn test_top_field_collect_string_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..32_usize,
offset in 0..32_usize,
limit in 1..256_usize,
offset in 0..256_usize,
segments_terms in
proptest::collection::vec(
proptest::collection::vec(0..64_u8, 1..256_usize),
proptest::collection::vec(0..32_u8, 1..32_usize),
0..8_usize,
)
) {
@@ -1478,14 +1406,15 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
if order.is_desc() {
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
} else {
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
}
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
@@ -1764,8 +1693,7 @@ mod tests {
#[test]
fn test_top_n_computer_not_at_capacity() {
let mut top_n_computer: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(4, NaturalComparator);
let mut top_n_computer = TopNComputer::new_with_comparator(4, NaturalComparator);
top_n_computer.append_doc(1, 0.8);
top_n_computer.append_doc(3, 0.2);
top_n_computer.append_doc(5, 0.3);
@@ -1790,8 +1718,7 @@ mod tests {
#[test]
fn test_top_n_computer_at_capacity() {
let mut top_collector: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(4, NaturalComparator);
let mut top_collector = TopNComputer::new_with_comparator(4, NaturalComparator);
top_collector.append_doc(1, 0.8);
top_collector.append_doc(3, 0.2);
top_collector.append_doc(5, 0.3);
@@ -1828,14 +1755,12 @@ mod tests {
let doc_ids_collection = [4, 5, 6];
let score = 3.3f32;
let mut top_collector_limit_2: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(2, NaturalComparator);
let mut top_collector_limit_2 = TopNComputer::new_with_comparator(2, NaturalComparator);
for id in &doc_ids_collection {
top_collector_limit_2.append_doc(*id, score);
}
let mut top_collector_limit_3: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(3, NaturalComparator);
let mut top_collector_limit_3 = TopNComputer::new_with_comparator(3, NaturalComparator);
for id in &doc_ids_collection {
top_collector_limit_3.append_doc(*id, score);
}
@@ -1856,16 +1781,15 @@ mod bench {
#[bench]
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
let mut top_collector: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(100, NaturalComparator);
let mut top_collector = TopNComputer::new_with_comparator(100, NaturalComparator);
for i in 0..100 {
top_collector.append_doc(i as u32, 0.8);
top_collector.append_doc(i, 0.8);
}
b.iter(|| {
for i in 0..100 {
top_collector.append_doc(i as u32, 0.8);
top_collector.append_doc(i, 0.8);
}
});
}

View File

@@ -36,7 +36,6 @@ fn path_for_version(version: &str) -> String {
/// feature flag quickwit uses a different dictionary type
#[test]
#[cfg(not(feature = "quickwit"))]
#[ignore = "test incompatible with fixed-width footer changes"]
fn test_format_6() {
let path = path_for_version("6");
@@ -48,7 +47,6 @@ fn test_format_6() {
/// feature flag quickwit uses a different dictionary type
#[test]
#[cfg(not(feature = "quickwit"))]
#[ignore = "test incompatible with fixed-width footer changes"]
fn test_format_7() {
let path = path_for_version("7");

View File

@@ -48,7 +48,15 @@ impl Executor {
F: Sized + Sync + Fn(A) -> crate::Result<R>,
{
match self {
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
Executor::SingleThread => {
// Avoid `collect`, since the stacktrace is blown up by it, which makes profiling
// harder.
let mut result = Vec::with_capacity(args.size_hint().0);
for arg in args {
result.push(f(arg)?);
}
Ok(result)
}
Executor::ThreadPool(pool) => {
let args: Vec<A> = args.collect();
let num_fruits = args.len();

View File

@@ -406,7 +406,7 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_str("red");
assert_eq!(term.serialized_value_bytes(), b"color\x00sred".to_vec())
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00sred")
}
#[test]
@@ -416,8 +416,8 @@ mod tests {
term.append_type_and_fast_value(-4i64);
assert_eq!(
term.serialized_value_bytes(),
b"color\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc".to_vec()
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc"
)
}
@@ -428,8 +428,8 @@ mod tests {
term.append_type_and_fast_value(4u64);
assert_eq!(
term.serialized_value_bytes(),
b"color\x00u\x00\x00\x00\x00\x00\x00\x00\x04".to_vec()
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00u\x00\x00\x00\x00\x00\x00\x00\x04"
)
}
@@ -439,8 +439,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(4.0f64);
assert_eq!(
term.serialized_value_bytes(),
b"color\x00f\xc0\x10\x00\x00\x00\x00\x00\x00".to_vec()
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00f\xc0\x10\x00\x00\x00\x00\x00\x00"
)
}
@@ -450,8 +450,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(true);
assert_eq!(
term.serialized_value_bytes(),
b"color\x00o\x00\x00\x00\x00\x00\x00\x00\x01".to_vec()
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00o\x00\x00\x00\x00\x00\x00\x00\x01"
)
}

View File

@@ -1,5 +1,5 @@
use std::collections::BTreeMap;
use std::sync::{Arc, OnceLock};
use std::sync::Arc;
use std::{fmt, io};
use crate::collector::Collector;
@@ -86,7 +86,7 @@ impl Searcher {
/// The searcher uses the segment ordinal to route the
/// request to the right `Segment`.
pub fn doc<D: DocumentDeserialize>(&self, doc_address: DocAddress) -> crate::Result<D> {
let store_reader = &self.inner.store_readers()[doc_address.segment_ord as usize];
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
store_reader.get(doc_address.doc_id)
}
@@ -96,7 +96,7 @@ impl Searcher {
pub fn doc_store_cache_stats(&self) -> CacheStats {
let cache_stats: CacheStats = self
.inner
.store_readers()
.store_readers
.iter()
.map(|reader| reader.cache_stats())
.sum();
@@ -110,7 +110,7 @@ impl Searcher {
doc_address: DocAddress,
) -> crate::Result<D> {
let executor = self.inner.index.search_executor();
let store_reader = &self.inner.store_readers()[doc_address.segment_ord as usize];
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
store_reader.get_async(doc_address.doc_id, executor).await
}
@@ -259,9 +259,8 @@ impl From<Arc<SearcherInner>> for Searcher {
pub(crate) struct SearcherInner {
schema: Schema,
index: Index,
doc_store_cache_num_blocks: usize,
segment_readers: Vec<SegmentReader>,
store_readers: OnceLock<Vec<StoreReader>>,
store_readers: Vec<StoreReader>,
generation: TrackedObject<SearcherGeneration>,
}
@@ -282,30 +281,19 @@ impl SearcherInner {
generation.segments(),
"Set of segments referenced by this Searcher and its SearcherGeneration must match"
);
let store_readers: Vec<StoreReader> = segment_readers
.iter()
.map(|segment_reader| segment_reader.get_store_reader(doc_store_cache_num_blocks))
.collect::<io::Result<Vec<_>>>()?;
Ok(SearcherInner {
schema,
index,
doc_store_cache_num_blocks,
segment_readers,
store_readers: OnceLock::default(),
store_readers,
generation,
})
}
#[inline]
fn store_readers(&self) -> &[StoreReader] {
self.store_readers.get_or_init(|| {
self.segment_readers
.iter()
.map(|segment_reader| {
segment_reader
.get_store_reader(self.doc_store_cache_num_blocks)
.expect("should be able to get store reader")
})
.collect()
})
}
}
impl fmt::Debug for Searcher {

View File

@@ -5,7 +5,7 @@ use std::ops::Range;
use common::{BinarySerializable, CountingWriter, HasLen, VInt};
use crate::directory::{FileSlice, TerminatingWrite, WritePtr};
use crate::schema::{Field, Schema};
use crate::schema::Field;
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
#[derive(Eq, PartialEq, Hash, Copy, Ord, PartialOrd, Clone, Debug)]
@@ -167,11 +167,10 @@ impl CompositeFile {
.map(|byte_range| self.data.slice(byte_range.clone()))
}
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
pub fn space_usage(&self) -> PerFieldSpaceUsage {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {
let field_name = schema.get_field_name(field_addr.field).to_string();
let mut field_usage = FieldUsage::empty(field_name);
let mut field_usage = FieldUsage::empty(field_addr.field);
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
fields.push(field_usage);
}

View File

@@ -1,20 +1,12 @@
use std::any::Any;
use std::collections::HashSet;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io, thread};
use log::Level;
use crate::directory::directory_lock::Lock;
use crate::directory::error::{DeleteError, LockError, OpenReadError, OpenWriteError};
use crate::directory::{
FileHandle, FileSlice, TerminatingWrite, WatchCallback, WatchHandle, WritePtr,
};
use crate::index::SegmentMetaInventory;
use crate::IndexMeta;
use crate::directory::{FileHandle, FileSlice, WatchCallback, WatchHandle, WritePtr};
/// Retry the logic of acquiring locks is pretty simple.
/// We just retry `n` times after a given `duratio`, both
@@ -64,7 +56,7 @@ impl<T: Send + Sync + 'static> From<Box<T>> for DirectoryLock {
impl Drop for DirectoryLockGuard {
fn drop(&mut self) {
if let Err(e) = self.directory.delete(&self.path) {
error!("Failed to remove the lock file. {:?}", e);
error!("Failed to remove the lock file. {e:?}");
}
}
}
@@ -105,8 +97,6 @@ fn retry_policy(is_blocking: bool) -> RetryPolicy {
}
}
pub type DirectoryPanicHandler = Arc<dyn Fn(Box<dyn Any + Send>) + Send + Sync + 'static>;
/// Write-once read many (WORM) abstraction for where
/// tantivy's data should be stored.
///
@@ -145,10 +135,6 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// Returns true if and only if the file exists
fn exists(&self, path: &Path) -> Result<bool, OpenReadError>;
/// Returns a boxed `TerminatingWrite` object, to be passed into `open_write`
/// which wraps it in a `BufWriter`
fn open_write_inner(&self, path: &Path) -> Result<Box<dyn TerminatingWrite>, OpenWriteError>;
/// Opens a writer for the *virtual file* associated with
/// a [`Path`].
///
@@ -175,12 +161,7 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// panic! if `flush` was not called.
///
/// The file may not previously exist.
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
Ok(io::BufWriter::with_capacity(
self.bufwriter_capacity(),
self.open_write_inner(path)?,
))
}
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError>;
/// Reads the full content file that has been written using
/// [`Directory::atomic_write()`].
@@ -242,75 +223,6 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// `OnCommitWithDelay` `ReloadPolicy`. Not implementing watch in a `Directory` only prevents
/// the `OnCommitWithDelay` `ReloadPolicy` to work properly.
fn watch(&self, watch_callback: WatchCallback) -> crate::Result<WatchHandle>;
/// Allows the directory to list managed files, overriding the ManagedDirectory's default
/// list_managed_files
fn list_managed_files(&self) -> crate::Result<HashSet<PathBuf>> {
Err(crate::TantivyError::InternalError(
"list_managed_files not implemented".to_string(),
))
}
/// Allows the directory to register a file as managed, overriding the ManagedDirectory's
/// default register_file_as_managed
fn register_files_as_managed(
&self,
_files: Vec<PathBuf>,
_overwrite: bool,
) -> crate::Result<()> {
Err(crate::TantivyError::InternalError(
"register_files_as_managed not implemented".to_string(),
))
}
/// Allows the directory to save IndexMeta, overriding the SegmentUpdater's default save_meta
fn save_metas(
&self,
_metas: &IndexMeta,
_previous_metas: &IndexMeta,
_payload: &mut (dyn Any + '_),
) -> crate::Result<()> {
Err(crate::TantivyError::InternalError(
"save_meta not implemented".to_string(),
))
}
/// Allows the directory to load IndexMeta, overriding the SegmentUpdater's default load_meta
fn load_metas(&self, _inventory: &SegmentMetaInventory) -> crate::Result<IndexMeta> {
Err(crate::TantivyError::InternalError(
"load_metas not implemented".to_string(),
))
}
/// Returns true if this directory supports garbage collection. The default assumption is
/// `true`
fn supports_garbage_collection(&self) -> bool {
true
}
/// Return a panic handler to be assigned to the various thread pools that may be created
///
/// The default is [`None`], which indicates that an unhandled panic from a thread pool will
/// abort the process
fn panic_handler(&self) -> Option<DirectoryPanicHandler> {
None
}
/// Returns true if this directory is in a position of requiring that tantivy cancel
/// whatever operation(s) it might be doing Typically this is just for the background
/// merge processes but could be used for anything
fn wants_cancel(&self) -> bool {
false
}
/// Send a logging message to the Directory to handle in its own way
fn log(&self, message: &str) {
log!(Level::Info, "{message}");
}
fn bufwriter_capacity(&self) -> usize {
8192
}
}
/// DirectoryClone

View File

@@ -58,9 +58,3 @@ pub static META_LOCK: Lazy<Lock> = Lazy::new(|| Lock {
filepath: PathBuf::from(".tantivy-meta.lock"),
is_blocking: true,
});
#[allow(missing_docs)]
pub static MANAGED_LOCK: Lazy<Lock> = Lazy::new(|| Lock {
filepath: PathBuf::from(".tantivy-managed.lock"),
is_blocking: true,
});

View File

@@ -9,7 +9,6 @@ use crc32fast::Hasher;
use crate::directory::{WatchCallback, WatchCallbackList, WatchHandle};
#[allow(dead_code)]
const POLLING_INTERVAL: Duration = Duration::from_millis(if cfg!(test) { 1 } else { 500 });
// Watches a file and executes registered callbacks when the file is modified.
@@ -19,7 +18,6 @@ pub struct FileWatcher {
state: Arc<AtomicUsize>, // 0: new, 1: runnable, 2: terminated
}
#[allow(dead_code)]
impl FileWatcher {
pub fn new(path: &Path) -> FileWatcher {
FileWatcher {

View File

@@ -7,14 +7,15 @@
use std::io;
use std::io::Write;
use common::{BinarySerializable, HasLen};
use common::{BinarySerializable, CountingWriter, DeserializeFrom, FixedSize, HasLen};
use crc32fast::Hasher;
use serde::{Deserialize, Serialize};
use crate::directory::error::Incompatibility;
use crate::directory::{AntiCallToken, FileSlice, TerminatingWrite};
use crate::{Version, INDEX_FORMAT_OLDEST_SUPPORTED_VERSION, INDEX_FORMAT_VERSION};
pub const FOOTER_LEN: usize = 24;
const FOOTER_MAX_LEN: u32 = 50_000;
/// The magic byte of the footer to identify corruption
/// or an old version of the footer.
@@ -23,7 +24,7 @@ const FOOTER_MAGIC_NUMBER: u32 = 1337;
type CrcHashU32 = u32;
/// A Footer is appended to every file
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Footer {
/// The version of the index format
pub version: Version,
@@ -40,45 +41,34 @@ impl Footer {
pub(crate) fn crc(&self) -> CrcHashU32 {
self.crc
}
pub fn append_footer<W: io::Write>(&self, write: &mut W) -> io::Result<()> {
// 24 bytes
BinarySerializable::serialize(&self.version.major, write)?;
BinarySerializable::serialize(&self.version.minor, write)?;
BinarySerializable::serialize(&self.version.patch, write)?;
BinarySerializable::serialize(&self.version.index_format_version, write)?;
BinarySerializable::serialize(&self.crc, write)?;
pub(crate) fn append_footer<W: io::Write>(&self, mut write: &mut W) -> io::Result<()> {
let mut counting_write = CountingWriter::wrap(&mut write);
counting_write.write_all(serde_json::to_string(&self)?.as_ref())?;
let footer_payload_len = counting_write.written_bytes();
BinarySerializable::serialize(&(footer_payload_len as u32), write)?;
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, write)?;
Ok(())
}
/// Extracts the tantivy Footer from the file and returns the footer and the rest of the file
pub fn extract_footer(file: FileSlice) -> io::Result<(Footer, FileSlice)> {
if file.len() < FOOTER_LEN {
if file.len() < 4 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!(
"File corrupted. The file is too small to contain the {FOOTER_LEN} byte \
footer (len={}).",
"File corrupted. The file is smaller than 4 bytes (len={}).",
file.len()
),
));
}
let (body_slice, footer_slice) = file.split_from_end(FOOTER_LEN);
let footer_bytes = footer_slice.read_bytes()?;
let mut footer_bytes = footer_bytes.as_slice();
let footer_metadata_len = <(u32, u32)>::SIZE_IN_BYTES;
let (footer_len, footer_magic_byte): (u32, u32) = file
.slice_from_end(footer_metadata_len)
.read_bytes()?
.as_ref()
.deserialize()?;
let footer = Footer {
version: Version {
major: u32::deserialize(&mut footer_bytes)?,
minor: u32::deserialize(&mut footer_bytes)?,
patch: u32::deserialize(&mut footer_bytes)?,
index_format_version: u32::deserialize(&mut footer_bytes)?,
},
crc: u32::deserialize(&mut footer_bytes)?,
};
let footer_magic_byte = u32::deserialize(&mut footer_bytes)?;
if footer_magic_byte != FOOTER_MAGIC_NUMBER {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -88,12 +78,38 @@ impl Footer {
));
}
Ok((footer, body_slice))
if footer_len > FOOTER_MAX_LEN {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Footer seems invalid as it suggests a footer len of {footer_len}. File is \
corrupted, or the index was created with a different & old version of \
tantivy."
),
));
}
let total_footer_size = footer_len as usize + footer_metadata_len;
if file.len() < total_footer_size {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!(
"File corrupted. The file is smaller than it's footer bytes \
(len={total_footer_size})."
),
));
}
let footer: Footer =
serde_json::from_slice(&file.read_bytes_slice(
file.len() - total_footer_size..file.len() - footer_metadata_len,
)?)?;
let body = file.slice_to(file.len() - total_footer_size);
Ok((footer, body))
}
/// Confirms that the index will be read correctly by this version of tantivy
/// Has to be called after `extract_footer` to make sure it's not accessing uninitialised memory
#[allow(dead_code)]
pub fn is_compatible(&self) -> Result<(), Incompatibility> {
const SUPPORTED_INDEX_FORMAT_VERSION_RANGE: std::ops::RangeInclusive<u32> =
INDEX_FORMAT_OLDEST_SUPPORTED_VERSION..=INDEX_FORMAT_VERSION;
@@ -172,10 +188,6 @@ mod tests {
fn test_deserialize_footer_missing_magic_byte() {
let mut buf: Vec<u8> = vec![];
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
BinarySerializable::serialize(&0_u32, &mut buf).unwrap();
let wrong_magic_byte: u32 = 5555;
BinarySerializable::serialize(&wrong_magic_byte, &mut buf).unwrap();
@@ -193,6 +205,7 @@ mod tests {
#[test]
fn test_deserialize_footer_wrong_filesize() {
let mut buf: Vec<u8> = vec![];
BinarySerializable::serialize(&100_u32, &mut buf).unwrap();
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, &mut buf).unwrap();
let owned_bytes = OwnedBytes::new(buf);
@@ -202,7 +215,27 @@ mod tests {
assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
assert_eq!(
err.to_string(),
"File corrupted. The file is too small to contain the 24 byte footer (len=4)."
"File corrupted. The file is smaller than it\'s footer bytes (len=108)."
);
}
#[test]
fn test_deserialize_too_large_footer() {
let mut buf: Vec<u8> = vec![];
let footer_length = super::FOOTER_MAX_LEN + 1;
BinarySerializable::serialize(&footer_length, &mut buf).unwrap();
BinarySerializable::serialize(&FOOTER_MAGIC_NUMBER, &mut buf).unwrap();
let owned_bytes = OwnedBytes::new(buf);
let fileslice = FileSlice::new(Arc::new(owned_bytes));
let err = Footer::extract_footer(fileslice).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert_eq!(
err.to_string(),
"Footer seems invalid as it suggests a footer len of 50001. File is corrupted, or the \
index was created with a different & old version of tantivy."
);
}
}

View File

@@ -1,22 +1,20 @@
use std::any::Any;
use std::collections::HashSet;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::{Arc, RwLock, RwLockWriteGuard};
use std::{io, result};
use crc32fast::Hasher;
use crate::core::MANAGED_FILEPATH;
use crate::directory::error::{DeleteError, LockError, OpenReadError, OpenWriteError};
use crate::directory::footer::{Footer, FooterProxy, FOOTER_LEN};
use crate::directory::footer::{Footer, FooterProxy};
use crate::directory::{
DirectoryLock, DirectoryPanicHandler, FileHandle, FileSlice, GarbageCollectionResult, Lock,
TerminatingWrite, WatchCallback, WatchHandle, MANAGED_LOCK, META_LOCK,
DirectoryLock, FileHandle, FileSlice, GarbageCollectionResult, Lock, WatchCallback,
WatchHandle, WritePtr, META_LOCK,
};
use crate::error::DataCorruption;
use crate::index::SegmentMetaInventory;
use crate::{Directory, IndexMeta};
use crate::Directory;
/// Returns true if the file is "managed".
/// Non-managed file are not subject to garbage collection.
@@ -41,9 +39,9 @@ fn is_managed(path: &Path) -> bool {
#[derive(Debug)]
pub struct ManagedDirectory {
directory: Box<dyn Directory>,
meta_informations: Arc<RwLock<MetaInformation>>,
}
#[allow(unused)]
#[derive(Debug, Default)]
struct MetaInformation {
managed_paths: HashSet<PathBuf>,
@@ -53,9 +51,9 @@ struct MetaInformation {
/// that were created by tantivy.
fn save_managed_paths(
directory: &dyn Directory,
managed_paths: &HashSet<PathBuf>,
wlock: &RwLockWriteGuard<'_, MetaInformation>,
) -> io::Result<()> {
let mut w = serde_json::to_vec(managed_paths)?;
let mut w = serde_json::to_vec(&wlock.managed_paths)?;
writeln!(&mut w)?;
directory.atomic_write(&MANAGED_FILEPATH, &w[..])?;
Ok(())
@@ -64,37 +62,33 @@ fn save_managed_paths(
impl ManagedDirectory {
/// Wraps a directory as managed directory.
pub fn wrap(directory: Box<dyn Directory>) -> crate::Result<ManagedDirectory> {
Ok(ManagedDirectory { directory })
}
pub fn list_managed_files(&self) -> crate::Result<HashSet<PathBuf>> {
match self.directory.list_managed_files() {
Ok(managed_files) => Ok(managed_files),
Err(crate::TantivyError::InternalError(_)) => {
match self.directory.atomic_read(&MANAGED_FILEPATH) {
Ok(data) => {
let managed_files_json = String::from_utf8_lossy(&data);
let managed_files: HashSet<PathBuf> =
serde_json::from_str(&managed_files_json).map_err(|e| {
DataCorruption::new(
MANAGED_FILEPATH.to_path_buf(),
format!("Managed file cannot be deserialized: {e:?}. "),
)
})?;
Ok(managed_files)
}
Err(OpenReadError::FileDoesNotExist(_)) => Ok(HashSet::new()),
io_err @ Err(OpenReadError::IoError { .. }) => {
Err(io_err.err().unwrap().into())
}
Err(OpenReadError::IncompatibleIndex(incompatibility)) => {
// For the moment, this should never happen `meta.json`
// do not have any footer and cannot detect incompatibility.
Err(crate::TantivyError::IncompatibleIndex(incompatibility))
}
}
match directory.atomic_read(&MANAGED_FILEPATH) {
Ok(data) => {
let managed_files_json = String::from_utf8_lossy(&data);
let managed_files: HashSet<PathBuf> = serde_json::from_str(&managed_files_json)
.map_err(|e| {
DataCorruption::new(
MANAGED_FILEPATH.to_path_buf(),
format!("Managed file cannot be deserialized: {e:?}. "),
)
})?;
Ok(ManagedDirectory {
directory,
meta_informations: Arc::new(RwLock::new(MetaInformation {
managed_paths: managed_files,
})),
})
}
Err(OpenReadError::FileDoesNotExist(_)) => Ok(ManagedDirectory {
directory,
meta_informations: Arc::default(),
}),
io_err @ Err(OpenReadError::IoError { .. }) => Err(io_err.err().unwrap().into()),
Err(OpenReadError::IncompatibleIndex(incompatibility)) => {
// For the moment, this should never happen `meta.json`
// do not have any footer and cannot detect incompatibility.
Err(crate::TantivyError::IncompatibleIndex(incompatibility))
}
Err(err) => Err(err),
}
}
@@ -116,26 +110,9 @@ impl ManagedDirectory {
&mut self,
get_living_files: L,
) -> crate::Result<GarbageCollectionResult> {
if !self.supports_garbage_collection() {
// the underlying directory does not support garbage collection.
return Ok(GarbageCollectionResult {
deleted_files: vec![],
failed_to_delete_files: vec![],
});
}
info!("Garbage collect");
let mut files_to_delete = vec![];
// We're about to do an atomic write to managed.json, lock it down
let _lock = self.acquire_lock(&MANAGED_LOCK)?;
let managed_paths = match self.directory.list_managed_files() {
Ok(managed_paths) => managed_paths,
Err(crate::TantivyError::InternalError(_)) => {
// If the managed.json file does not exist, we consider
// that there is no managed file.
self.list_managed_files()?
}
Err(err) => return Err(err),
};
// It is crucial to get the living files after acquiring the
// read lock of meta information. That way, we
// avoid the following scenario.
@@ -147,6 +124,11 @@ impl ManagedDirectory {
//
// releasing the lock as .delete() will use it too.
{
let meta_informations_rlock = self
.meta_informations
.read()
.expect("Managed directory rlock poisoned in garbage collect.");
// The point of this second "file" lock is to enforce the following scenario
// 1) process B tries to load a new set of searcher.
// The list of segments is loaded
@@ -156,7 +138,7 @@ impl ManagedDirectory {
match self.acquire_lock(&META_LOCK) {
Ok(_meta_lock) => {
let living_files = get_living_files();
for managed_path in &managed_paths {
for managed_path in &meta_informations_rlock.managed_paths {
if !living_files.contains(managed_path) {
files_to_delete.push(managed_path.clone());
}
@@ -199,18 +181,16 @@ impl ManagedDirectory {
if !deleted_files.is_empty() {
// update the list of managed files by removing
// the file that were removed.
let mut managed_paths_write = managed_paths;
let mut meta_informations_wlock = self
.meta_informations
.write()
.expect("Managed directory wlock poisoned (2).");
let managed_paths_write = &mut meta_informations_wlock.managed_paths;
for delete_file in &deleted_files {
managed_paths_write.remove(delete_file);
}
self.directory.sync_directory()?;
if let Err(crate::TantivyError::InternalError(_)) = self
.directory
.register_files_as_managed(managed_paths_write.clone().into_iter().collect(), true)
{
save_managed_paths(self.directory.as_mut(), &managed_paths_write)?;
}
save_managed_paths(self.directory.as_mut(), &meta_informations_wlock)?;
}
Ok(GarbageCollectionResult {
@@ -235,39 +215,27 @@ impl ManagedDirectory {
if !is_managed(filepath) {
return Ok(());
}
// We're about to do an atomic write to managed.json, lock it down
let _lock = self
.acquire_lock(&MANAGED_LOCK)
.expect("must be able to acquire lock for managed.json");
if let Err(crate::TantivyError::InternalError(_)) = self
.directory
.register_files_as_managed(vec![filepath.to_owned()], false)
{
let mut managed_paths = self
.list_managed_files()
.expect("reading managed files should not fail");
let has_changed = managed_paths.insert(filepath.to_owned());
if !has_changed {
return Ok(());
}
save_managed_paths(self.directory.as_ref(), &managed_paths)?;
// This is not the first file we add.
// Therefore, we are sure that `.managed.json` has been already
// properly created and we do not need to sync its parent directory.
//
// (It might seem like a nicer solution to create the managed_json on the
// creation of the ManagedDirectory instance but it would actually
// prevent the use of read-only directories..)
let managed_file_definitely_already_exists = managed_paths.len() > 1;
if managed_file_definitely_already_exists {
return Ok(());
}
let mut meta_wlock = self
.meta_informations
.write()
.expect("Managed file lock poisoned");
let has_changed = meta_wlock.managed_paths.insert(filepath.to_owned());
if !has_changed {
return Ok(());
}
save_managed_paths(self.directory.as_ref(), &meta_wlock)?;
// This is not the first file we add.
// Therefore, we are sure that `.managed.json` has been already
// properly created and we do not need to sync its parent directory.
//
// (It might seem like a nicer solution to create the managed_json on the
// creation of the ManagedDirectory instance but it would actually
// prevent the use of read-only directories..)
let managed_file_definitely_already_exists = meta_wlock.managed_paths.len() > 1;
if managed_file_definitely_already_exists {
return Ok(());
}
self.directory.sync_directory()?;
Ok(())
}
@@ -287,6 +255,17 @@ impl ManagedDirectory {
let crc = hasher.finalize();
Ok(footer.crc() == crc)
}
/// List all managed files
pub fn list_managed_files(&self) -> HashSet<PathBuf> {
let managed_paths = self
.meta_informations
.read()
.expect("Managed directory rlock poisoned in list damaged.")
.managed_paths
.clone();
managed_paths
}
}
impl Directory for ManagedDirectory {
@@ -297,32 +276,22 @@ impl Directory for ManagedDirectory {
fn open_read(&self, path: &Path) -> result::Result<FileSlice, OpenReadError> {
let file_slice = self.directory.open_read(path)?;
debug_assert!(
{
use common::HasLen;
file_slice.len() >= FOOTER_LEN
},
"{} is too short",
path.display()
);
let (reader, _) = file_slice.split_from_end(FOOTER_LEN);
// NB: We do not read/validate the footer here -- we blindly skip it entirely
let (footer, reader) = Footer::extract_footer(file_slice)
.map_err(|io_error| OpenReadError::wrap_io_error(io_error, path.to_path_buf()))?;
footer.is_compatible()?;
Ok(reader)
}
fn open_write_inner(
&self,
path: &Path,
) -> result::Result<Box<dyn TerminatingWrite>, OpenWriteError> {
fn open_write(&self, path: &Path) -> result::Result<WritePtr, OpenWriteError> {
self.register_file_as_managed(path)
.map_err(|io_error| OpenWriteError::wrap_io_error(io_error, path.to_path_buf()))?;
Ok(Box::new(FooterProxy::new(
Ok(io::BufWriter::new(Box::new(FooterProxy::new(
self.directory
.open_write(path)?
.into_inner()
.map_err(|_| ())
.expect("buffer should be empty"),
)))
))))
}
fn atomic_write(&self, path: &Path, data: &[u8]) -> io::Result<()> {
@@ -354,45 +323,13 @@ impl Directory for ManagedDirectory {
self.directory.sync_directory()?;
Ok(())
}
fn save_metas(
&self,
metas: &IndexMeta,
previous_metas: &IndexMeta,
payload: &mut (dyn Any + '_),
) -> crate::Result<()> {
self.directory.save_metas(metas, previous_metas, payload)
}
fn load_metas(&self, inventory: &SegmentMetaInventory) -> crate::Result<IndexMeta> {
self.directory.load_metas(inventory)
}
fn supports_garbage_collection(&self) -> bool {
self.directory.supports_garbage_collection()
}
fn panic_handler(&self) -> Option<DirectoryPanicHandler> {
self.directory.panic_handler()
}
fn wants_cancel(&self) -> bool {
self.directory.wants_cancel()
}
fn log(&self, message: &str) {
self.directory.log(message);
}
fn bufwriter_capacity(&self) -> usize {
self.directory.bufwriter_capacity()
}
}
impl Clone for ManagedDirectory {
fn clone(&self) -> ManagedDirectory {
ManagedDirectory {
directory: self.directory.box_clone(),
meta_informations: Arc::clone(&self.meta_informations),
}
}
}
@@ -400,6 +337,7 @@ impl Clone for ManagedDirectory {
#[cfg(feature = "mmap")]
#[cfg(test)]
mod tests_mmap_specific {
use std::collections::HashSet;
use std::io::Write;
use std::path::{Path, PathBuf};

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap;
use std::fmt;
use std::fs::{self, File, OpenOptions};
use std::io::{self, Read, Write};
use std::io::{self, BufWriter, Read, Write};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock, Weak};
@@ -21,7 +21,7 @@ use crate::directory::error::{
use crate::directory::file_watcher::FileWatcher;
use crate::directory::{
AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite,
WatchCallback, WatchHandle,
WatchCallback, WatchHandle, WritePtr,
};
pub type ArcBytes = Arc<dyn Deref<Target = [u8]> + Send + Sync + 'static>;
@@ -413,8 +413,8 @@ impl Directory for MmapDirectory {
.map_err(|io_err| OpenReadError::wrap_io_error(io_err, path.to_path_buf()))
}
fn open_write_inner(&self, path: &Path) -> Result<Box<dyn TerminatingWrite>, OpenWriteError> {
debug!("Open Write {:?}", path);
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
debug!("Open Write {path:?}");
let full_path = self.resolve_path(path);
let open_res = OpenOptions::new()
@@ -443,7 +443,7 @@ impl Directory for MmapDirectory {
// sync_directory() is called.
let writer = SafeFileWriter::new(file);
Ok(Box::new(writer))
Ok(BufWriter::new(Box::new(writer)))
}
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {

View File

@@ -19,13 +19,12 @@ mod composite_file;
use std::io::BufWriter;
use std::path::PathBuf;
pub use common::buffered_file_slice::BufferedFileSlice;
pub use common::file_slice::{FileHandle, FileSlice};
pub use common::{AntiCallToken, OwnedBytes, TerminatingWrite};
pub(crate) use self::composite_file::{CompositeFile, CompositeWrite};
pub use self::directory::{Directory, DirectoryClone, DirectoryLock, DirectoryPanicHandler};
pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, MANAGED_LOCK, META_LOCK};
pub use self::directory::{Directory, DirectoryClone, DirectoryLock};
pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, META_LOCK};
pub use self::ram_directory::RamDirectory;
pub use self::watch_event_router::{WatchCallback, WatchCallbackList, WatchHandle};

View File

@@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::io::{self, Cursor, Write};
use std::io::{self, BufWriter, Cursor, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::{fmt, result};
@@ -11,7 +11,7 @@ use crate::core::META_FILEPATH;
use crate::directory::error::{DeleteError, OpenReadError, OpenWriteError};
use crate::directory::{
AntiCallToken, Directory, FileSlice, TerminatingWrite, WatchCallback, WatchCallbackList,
WatchHandle,
WatchHandle, WritePtr,
};
/// Writer associated with the [`RamDirectory`].
@@ -197,7 +197,7 @@ impl Directory for RamDirectory {
.exists(path))
}
fn open_write_inner(&self, path: &Path) -> Result<Box<dyn TerminatingWrite>, OpenWriteError> {
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
let mut fs = self.fs.write().unwrap();
let path_buf = PathBuf::from(path);
let vec_writer = VecWriter::new(path_buf.clone(), self.clone());
@@ -206,7 +206,7 @@ impl Directory for RamDirectory {
if exists {
Err(OpenWriteError::FileAlreadyExists(path_buf))
} else {
Ok(Box::new(vec_writer))
Ok(BufWriter::new(Box::new(vec_writer)))
}
}

View File

@@ -110,11 +110,6 @@ pub enum TantivyError {
#[error("Deserialize error: {0}")]
/// An error occurred while attempting to deserialize a document.
DeserializeError(DeserializeError),
/// The user requested the current operation be cancelled
#[error("User requested cancel")]
Cancelled,
#[error("Segment Merging failed: {0:#?}")]
MergeErrors(Vec<TantivyError>),
}
impl From<io::Error> for TantivyError {

View File

@@ -79,7 +79,7 @@ mod tests {
use std::ops::{Range, RangeInclusive};
use std::path::Path;
use columnar::{StrColumn, ValueRange};
use columnar::StrColumn;
use common::{ByteCount, DateTimePrecision, HasLen, TerminatingWrite};
use once_cell::sync::Lazy;
use rand::prelude::SliceRandom;
@@ -395,7 +395,7 @@ mod tests {
.unwrap()
.first_or_default_col(0);
for a in 0..n {
assert_eq!(col.get_val(a as u32), permutation[a], "for doc {a}");
assert_eq!(col.get_val(a as u32), permutation[a]);
}
}
@@ -944,7 +944,7 @@ mod tests {
let test_range = |range: RangeInclusive<u64>| {
let expected_count = numbers.iter().filter(|num| range.contains(*num)).count();
let mut vec = vec![];
field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec);
field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec);
assert_eq!(vec.len(), expected_count);
};
test_range(50..=50);
@@ -1022,7 +1022,7 @@ mod tests {
let test_range = |range: RangeInclusive<u64>| {
let expected_count = numbers.iter().filter(|num| range.contains(*num)).count();
let mut vec = vec![];
field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec);
field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec);
assert_eq!(vec.len(), expected_count);
};
let test_range_variant = |start, stop| {

View File

@@ -8,7 +8,7 @@ use columnar::{
};
use common::ByteCount;
use crate::core::json_utils::{encode_column_name, json_path_sep_to_dot};
use crate::core::json_utils::encode_column_name;
use crate::directory::FileSlice;
use crate::schema::{Field, FieldEntry, FieldType, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
@@ -39,15 +39,19 @@ impl FastFieldReaders {
self.resolve_column_name_given_default_field(column_name, default_field_opt)
}
pub(crate) fn space_usage(&self) -> io::Result<PerFieldSpaceUsage> {
pub(crate) fn space_usage(&self, schema: &Schema) -> io::Result<PerFieldSpaceUsage> {
let mut per_field_usages: Vec<FieldUsage> = Default::default();
for (mut field_name, column_handle) in self.columnar.iter_columns()? {
json_path_sep_to_dot(&mut field_name);
let space_usage = column_handle.space_usage()?;
let mut field_usage = FieldUsage::empty(field_name);
field_usage.set_column_usage(space_usage);
for (field, field_entry) in schema.fields() {
let column_handles = self.columnar.read_columns(field_entry.name())?;
let num_bytes: ByteCount = column_handles
.iter()
.map(|column_handle| column_handle.num_bytes())
.sum();
let mut field_usage = FieldUsage::empty(field);
field_usage.add_field_idx(0, num_bytes);
per_field_usages.push(field_usage);
}
// TODO fix space usage for JSON fields.
Ok(PerFieldSpaceUsage::new(per_field_usages))
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use super::{fieldnorm_to_id, id_to_fieldnorm};
use crate::directory::{CompositeFile, FileSlice, OwnedBytes};
use crate::schema::{Field, Schema};
use crate::schema::Field;
use crate::space_usage::PerFieldSpaceUsage;
use crate::DocId;
@@ -37,8 +37,8 @@ impl FieldNormReaders {
}
/// Return a break down of the space usage per field.
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
self.data.space_usage(schema)
pub fn space_usage(&self) -> PerFieldSpaceUsage {
self.data.space_usage()
}
/// Returns a handle to inner file

View File

@@ -30,30 +30,22 @@ fn load_metas(
directory: &dyn Directory,
inventory: &SegmentMetaInventory,
) -> crate::Result<IndexMeta> {
match directory.load_metas(inventory) {
Ok(metas) => Ok(metas),
Err(crate::TantivyError::InternalError(_)) => {
let meta_data = directory.atomic_read(&META_FILEPATH)?;
let meta_string = String::from_utf8(meta_data).map_err(|_utf8_err| {
error!("Meta data is not valid utf8.");
DataCorruption::new(
META_FILEPATH.to_path_buf(),
"Meta file does not contain valid utf8 file.".to_string(),
)
})?;
IndexMeta::deserialize(&meta_string, inventory)
.map_err(|e| {
DataCorruption::new(
META_FILEPATH.to_path_buf(),
format!(
"Meta file cannot be deserialized. {e:?}. Content: {meta_string:?}"
),
)
})
.map_err(From::from)
}
Err(err) => Err(err),
}
let meta_data = directory.atomic_read(&META_FILEPATH)?;
let meta_string = String::from_utf8(meta_data).map_err(|_utf8_err| {
error!("Meta data is not valid utf8.");
DataCorruption::new(
META_FILEPATH.to_path_buf(),
"Meta file does not contain valid utf8 file.".to_string(),
)
})?;
IndexMeta::deserialize(&meta_string, inventory)
.map_err(|e| {
DataCorruption::new(
META_FILEPATH.to_path_buf(),
format!("Meta file cannot be deserialized. {e:?}. Content: {meta_string:?}"),
)
})
.map_err(From::from)
}
/// Save the index meta file.
@@ -68,14 +60,16 @@ fn save_new_metas(
index_settings: IndexSettings,
directory: &dyn Directory,
) -> crate::Result<()> {
let empty_metas = IndexMeta {
index_settings,
segments: Vec::new(),
schema,
opstamp: 0u64,
payload: None,
};
save_metas(&empty_metas, &empty_metas, directory)?;
save_metas(
&IndexMeta {
index_settings,
segments: Vec::new(),
schema,
opstamp: 0u64,
payload: None,
},
directory,
)?;
directory.sync_directory()?;
Ok(())
}
@@ -588,7 +582,7 @@ impl Index {
num_threads: usize,
overall_memory_budget_in_bytes: usize,
) -> crate::Result<IndexWriter<D>> {
let memory_arena_in_bytes_per_thread = overall_memory_budget_in_bytes / num_threads.max(1);
let memory_arena_in_bytes_per_thread = overall_memory_budget_in_bytes / num_threads;
let options = IndexWriterOptions::builder()
.num_worker_threads(num_threads)
.memory_budget_per_thread(memory_arena_in_bytes_per_thread)
@@ -661,11 +655,9 @@ impl Index {
/// Creates a new segment.
pub fn new_segment(&self) -> Segment {
self.new_segment_with_id(SegmentId::generate_random())
}
pub fn new_segment_with_id(&self, segment_id: SegmentId) -> Segment {
let segment_meta = self.inventory.new_segment_meta(segment_id, 0);
let segment_meta = self
.inventory
.new_segment_meta(SegmentId::generate_random(), 0);
self.segment(segment_meta)
}
@@ -696,7 +688,7 @@ impl Index {
/// Returns the set of corrupted files
pub fn validate_checksum(&self) -> crate::Result<HashSet<PathBuf>> {
let managed_files = self.directory.list_managed_files()?;
let managed_files = self.directory.list_managed_files();
let active_segments_files: HashSet<PathBuf> = self
.searchable_segment_metas()?
.iter()

View File

@@ -1,6 +1,5 @@
use std::collections::HashSet;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
@@ -14,24 +13,16 @@ use crate::store::Compressor;
use crate::{Inventory, Opstamp, TrackedObject};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DeleteMeta {
pub num_deleted_docs: u32,
pub opstamp: Opstamp,
struct DeleteMeta {
num_deleted_docs: u32,
opstamp: Opstamp,
}
#[derive(Clone, Default)]
pub struct SegmentMetaInventory {
pub(crate) struct SegmentMetaInventory {
inventory: Inventory<InnerSegmentMeta>,
}
impl Debug for SegmentMetaInventory {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("SegmentMetaInventory")
.field("inventory", &self.inventory.list())
.finish()
}
}
impl SegmentMetaInventory {
/// Lists all living `SegmentMeta` object at the time of the call.
pub fn all(&self) -> Vec<SegmentMeta> {
@@ -59,7 +50,7 @@ impl SegmentMetaInventory {
/// how many are deleted, etc.
#[derive(Clone)]
pub struct SegmentMeta {
pub tracked: TrackedObject<InnerSegmentMeta>,
tracked: TrackedObject<InnerSegmentMeta>,
}
impl fmt::Debug for SegmentMeta {
@@ -219,15 +210,15 @@ impl SegmentMeta {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InnerSegmentMeta {
pub segment_id: SegmentId,
pub max_doc: u32,
pub deletes: Option<DeleteMeta>,
struct InnerSegmentMeta {
segment_id: SegmentId,
max_doc: u32,
deletes: Option<DeleteMeta>,
/// If you want to avoid the SegmentComponent::TempStore file to be covered by
/// garbage collection and deleted, set this to true. This is used during merge.
#[serde(skip)]
#[serde(default = "default_temp_store")]
pub include_temp_doc_store: Arc<AtomicBool>,
pub(crate) include_temp_doc_store: Arc<AtomicBool>,
}
fn default_temp_store() -> Arc<AtomicBool> {
Arc::new(AtomicBool::new(false))

View File

@@ -1,6 +1,5 @@
use std::io;
use common::file_slice::DeferredFileSlice;
use common::json_path_writer::JSON_END_OF_PATH;
use common::{BinarySerializable, ByteCount};
#[cfg(feature = "quickwit")]
@@ -31,7 +30,7 @@ use crate::termdict::TermDictionary;
pub struct InvertedIndexReader {
termdict: TermDictionary,
postings_file_slice: FileSlice,
positions_file_slice: DeferredFileSlice,
positions_file_slice: FileSlice,
record_option: IndexRecordOption,
total_num_tokens: u64,
}
@@ -67,7 +66,7 @@ impl InvertedIndexReader {
pub(crate) fn new(
termdict: TermDictionary,
postings_file_slice: FileSlice,
positions_file_slice: DeferredFileSlice,
positions_file_slice: FileSlice,
record_option: IndexRecordOption,
) -> io::Result<InvertedIndexReader> {
let (total_num_tokens_slice, postings_body) = postings_file_slice.split(8);
@@ -87,7 +86,7 @@ impl InvertedIndexReader {
InvertedIndexReader {
termdict: TermDictionary::empty(),
postings_file_slice: FileSlice::empty(),
positions_file_slice: DeferredFileSlice::new(|| Ok(FileSlice::empty())),
positions_file_slice: FileSlice::empty(),
record_option,
total_num_tokens: 0u64,
}
@@ -212,7 +211,7 @@ impl InvertedIndexReader {
.slice(term_info.postings_range.clone());
BlockSegmentPostings::open(
term_info.doc_freq,
postings_data.read_bytes()?,
postings_data,
self.record_option,
requested_option,
)
@@ -234,7 +233,6 @@ impl InvertedIndexReader {
if option.has_positions() {
let positions_data = self
.positions_file_slice
.open()?
.read_bytes_slice(term_info.positions_range.clone())?;
let position_reader = PositionReader::open(positions_data)?;
Some(position_reader)
@@ -344,7 +342,6 @@ impl InvertedIndexReader {
if with_positions {
let positions = self
.positions_file_slice
.open()?
.read_bytes_slice_async(term_info.positions_range.clone());
futures_util::future::try_join(postings, positions).await?;
} else {
@@ -387,7 +384,6 @@ impl InvertedIndexReader {
if with_positions {
let positions = self
.positions_file_slice
.open()?
.read_bytes_slice_async(positions_range);
futures_util::future::try_join(postings, positions).await?;
} else {
@@ -482,7 +478,7 @@ impl InvertedIndexReader {
pub async fn warm_postings_full(&self, with_positions: bool) -> io::Result<()> {
self.postings_file_slice.read_bytes_async().await?;
if with_positions {
self.positions_file_slice.open()?.read_bytes_async().await?;
self.positions_file_slice.read_bytes_async().await?;
}
Ok(())
}

Some files were not shown because too many files have changed in this diff Show More