From 9615eb73b859ccbd44f485d6aeea03c05aaa7977 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Sat, 1 Nov 2025 15:46:46 -0700 Subject: [PATCH] Implement `collect_block` for lazy scorers using `SegmentSortKeyComputer::segment_sort_keys`. --- columnar/benches/bench_access.rs | 14 +- columnar/src/column/mod.rs | 203 +++++++-- .../src/column_index/multivalued_index.rs | 4 +- columnar/src/column_values/mod.rs | 232 +++++++++- .../src/column_values/monotonic_column.rs | 36 +- .../u128_based/compact_space/mod.rs | 72 +++- .../src/column_values/u64_based/bitpacked.rs | 176 +++++++- columnar/src/column_values/u64_based/tests.rs | 2 +- columnar/src/lib.rs | 2 +- src/collector/sort_key/mod.rs | 109 +++++ src/collector/sort_key/order.rs | 142 ++++++- src/collector/sort_key/sort_by_erased_type.rs | 37 +- src/collector/sort_key/sort_by_score.rs | 18 +- .../sort_key/sort_by_static_fast_value.rs | 104 ++++- src/collector/sort_key/sort_by_string.rs | 123 +++++- src/collector/sort_key/sort_key_computer.rs | 401 ++++++++++++++---- src/collector/sort_key_top_collector.rs | 5 + src/collector/top_score_collector.rs | 51 ++- src/fastfield/mod.rs | 6 +- .../range_query/fast_field_range_doc_set.rs | 9 +- .../range_query/range_query_fastfield.rs | 6 +- 21 files changed, 1552 insertions(+), 200 deletions(-) diff --git a/columnar/benches/bench_access.rs b/columnar/benches/bench_access.rs index 397a35af0..20ed759f3 100644 --- a/columnar/benches/bench_access.rs +++ b/columnar/benches/bench_access.rs @@ -1,6 +1,6 @@ use binggan::{InputGroup, black_box}; use common::*; -use tantivy_columnar::Column; +use tantivy_columnar::{Column, ValueRange}; pub mod common; @@ -46,16 +46,16 @@ fn bench_group(mut runner: InputGroup) { runner.register("access_first_vals", |column| { let mut sum = 0; const BLOCK_SIZE: usize = 32; - let mut docs = vec![0; BLOCK_SIZE]; - let mut buffer = vec![None; BLOCK_SIZE]; + let mut docs = Vec::with_capacity(BLOCK_SIZE); + let mut buffer = Vec::with_capacity(BLOCK_SIZE); for i in 
(0..NUM_DOCS).step_by(BLOCK_SIZE) { - // fill docs - #[allow(clippy::needless_range_loop)] + docs.clear(); for idx in 0..BLOCK_SIZE { - docs[idx] = idx as u32 + i; + docs.push(idx as u32 + i); } - column.first_vals(&docs, &mut buffer); + buffer.clear(); + column.first_vals_in_value_range(&mut docs, &mut buffer, ValueRange::All); for val in buffer.iter() { let Some(val) = val else { continue }; sum += *val; diff --git a/columnar/src/column/mod.rs b/columnar/src/column/mod.rs index cc2938bb8..9954dce01 100644 --- a/columnar/src/column/mod.rs +++ b/columnar/src/column/mod.rs @@ -89,31 +89,6 @@ impl Column { self.values_for_doc(row_id).next() } - /// Load the first value for each docid in the provided slice. - #[inline] - pub fn first_vals(&self, docids: &[DocId], output: &mut [Option]) { - match &self.index { - ColumnIndex::Empty { .. } => {} - ColumnIndex::Full => self.values.get_vals_opt(docids, output), - ColumnIndex::Optional(optional_index) => { - for (i, docid) in docids.iter().enumerate() { - output[i] = optional_index - .rank_if_exists(*docid) - .map(|rowid| self.values.get_val(rowid)); - } - } - ColumnIndex::Multivalued(multivalued_index) => { - for (i, docid) in docids.iter().enumerate() { - let range = multivalued_index.range(*docid); - let is_empty = range.start == range.end; - if !is_empty { - output[i] = Some(self.values.get_val(range.start)); - } - } - } - } - } - /// Translates a block of docids to row_ids. /// /// returns the row_ids and the matching docids on the same index @@ -143,7 +118,7 @@ impl Column { #[inline] pub fn get_docids_for_value_range( &self, - value_range: RangeInclusive, + value_range: ValueRange, selected_docid_range: Range, doc_ids: &mut Vec, ) { @@ -168,6 +143,182 @@ impl Column { } } +// Separate impl block for methods requiring `Default` for `T`. +impl Column { + /// Load the first value for each docid in the provided slice. 
+ #[inline] + pub fn first_vals_in_value_range( + &self, + docids: &mut Vec, + values: &mut Vec>, + value_range: ValueRange, + ) { + const BLOCK_LEN: usize = 64; // Corresponds to COLLECT_BLOCK_BUFFER_LEN in tantivy's docset + match (&self.index, value_range) { + (ColumnIndex::Empty { .. }, value_range) => { + let nulls_match = match &value_range { + ValueRange::All => true, + ValueRange::Inclusive(_) => false, + ValueRange::GreaterThan(_, nulls_match) => *nulls_match, + ValueRange::LessThan(_, nulls_match) => *nulls_match, + }; + if nulls_match { + for _ in 0..docids.len() { + values.push(None); + } + } else { + docids.clear(); + } + } + (ColumnIndex::Full, value_range) => { + self.values + .get_vals_in_value_range(docids, values, value_range); + } + (ColumnIndex::Optional(optional_index), value_range) => { + let len = docids.len(); + // Ensure the input docids length does not exceed BLOCK_LEN for stack allocation + // safety. If it does, we might need to handle this with multiple + // chunks or fallback to heap. For now, an assert is used to confirm + // expected usage within batch processing limits. 
+ assert!( + len <= BLOCK_LEN, + "Input docids length ({}) exceeds BLOCK_LEN ({})", + len, + BLOCK_LEN + ); + + let mut input_docs_buffer = [0u32; BLOCK_LEN]; + input_docs_buffer[..len].copy_from_slice(docids); + + let mut dense_row_ids_buffer = [0u32; BLOCK_LEN]; + let mut dense_values_buffer = [T::default(); BLOCK_LEN]; + let mut presence_mask: u64 = 0; // Bitmask to track which input_docs have a value + let mut num_present = 0; + + // Phase 1: Identify existing RowIds and build dense_row_ids_buffer + for (i, &doc_id) in input_docs_buffer[..len].iter().enumerate() { + if let Some(row_id) = optional_index.rank_if_exists(doc_id) { + dense_row_ids_buffer[num_present] = row_id; + presence_mask |= 1u64 << i; // Set bit for present docid + num_present += 1; + } + } + + // Phase 2: Batch fetch values for present docs + if num_present > 0 { + self.values.get_vals( + &dense_row_ids_buffer[..num_present], + &mut dense_values_buffer[..num_present], + ); + } + + // Determine if nulls match the value range + let nulls_match = match &value_range { + ValueRange::All => true, + ValueRange::Inclusive(_) => false, + ValueRange::GreaterThan(_, nulls_match) => *nulls_match, + ValueRange::LessThan(_, nulls_match) => *nulls_match, + }; + + // Phase 3: Filter and merge results, reconstructing docids and values + docids.clear(); + values.clear(); + + let mut dense_values_cursor = 0; + for i in 0..len { + let original_doc_id = input_docs_buffer[i]; + if (presence_mask & (1u64 << i)) != 0 { + // This doc_id was present in the optional index and has a value + let val = dense_values_buffer[dense_values_cursor]; + dense_values_cursor += 1; + + // Check if the value matches the value range + let value_matches = match &value_range { + ValueRange::All => true, + ValueRange::Inclusive(r) => r.contains(&val), + ValueRange::GreaterThan(t, _) => val > *t, + ValueRange::LessThan(t, _) => val < *t, + }; + + if value_matches { + docids.push(original_doc_id); + values.push(Some(val)); + } + } else if 
nulls_match { + // This doc_id was not present in the optional index (null) and nulls match + docids.push(original_doc_id); + values.push(None); + } + } + } + (ColumnIndex::Multivalued(multivalued_index), value_range) => { + let nulls_match = match &value_range { + ValueRange::All => true, + ValueRange::Inclusive(_) => false, + ValueRange::GreaterThan(_, nulls_match) => *nulls_match, + ValueRange::LessThan(_, nulls_match) => *nulls_match, + }; + let mut write_head = 0; + for i in 0..docids.len() { + let docid = docids[i]; + let row_range = multivalued_index.range(docid); + let is_empty = row_range.start == row_range.end; + if !is_empty { + let val = self.values.get_val(row_range.start); + let matches = match &value_range { + ValueRange::All => true, + ValueRange::Inclusive(r) => r.contains(&val), + ValueRange::GreaterThan(t, _) => val > *t, + ValueRange::LessThan(t, _) => val < *t, + }; + if matches { + docids[write_head] = docid; + values.push(Some(val)); + write_head += 1; + } + } else if nulls_match { + docids[write_head] = docid; + values.push(None); + write_head += 1; + } + } + docids.truncate(write_head); + } + } + } +} + +/// A range of values. +/// +/// This type is intended to be used in batch APIs, where the cost of unpacking the enum +/// is outweighed by the time spent processing a batch. +/// +/// Implementers should pattern match on the variants to use optimized loops for each case. +#[derive(Clone, Debug)] +pub enum ValueRange { + /// A range that includes both start and end. + Inclusive(RangeInclusive), + /// A range that matches all values. + All, + /// A range that matches all values greater than the threshold. + /// The boolean flag indicates if null values should be included. + GreaterThan(T, bool), + /// A range that matches all values less than the threshold. + /// The boolean flag indicates if null values should be included. 
+ LessThan(T, bool), +} + +impl ValueRange { + pub fn intersects(&self, min: T, max: T) -> bool { + match self { + ValueRange::Inclusive(range) => *range.start() <= max && *range.end() >= min, + ValueRange::All => true, + ValueRange::GreaterThan(val, _) => max > *val, + ValueRange::LessThan(val, _) => min < *val, + } + } +} + impl BinarySerializable for Cardinality { fn serialize(&self, writer: &mut W) -> std::io::Result<()> { self.to_code().serialize(writer) diff --git a/columnar/src/column_index/multivalued_index.rs b/columnar/src/column_index/multivalued_index.rs index 883475a1b..336e83e31 100644 --- a/columnar/src/column_index/multivalued_index.rs +++ b/columnar/src/column_index/multivalued_index.rs @@ -333,7 +333,7 @@ mod tests { use std::ops::Range; use super::MultiValueIndex; - use crate::{ColumnarReader, DynamicColumn}; + use crate::{ColumnarReader, DynamicColumn, ValueRange}; fn index_to_pos_helper( index: &MultiValueIndex, @@ -413,7 +413,7 @@ mod tests { assert_eq!(row_id_range, 0..4); let check = |range, expected| { - let full_range = 0..=u64::MAX; + let full_range = ValueRange::All; let mut docids = Vec::new(); column.get_docids_for_value_range(full_range, range, &mut docids); assert_eq!(docids, expected); diff --git a/columnar/src/column_values/mod.rs b/columnar/src/column_values/mod.rs index f26bf6d33..eacb83c68 100644 --- a/columnar/src/column_values/mod.rs +++ b/columnar/src/column_values/mod.rs @@ -7,13 +7,15 @@ //! 
- Monotonically map values to u64/u128 use std::fmt::Debug; -use std::ops::{Range, RangeInclusive}; +use std::ops::Range; use std::sync::Arc; use downcast_rs::DowncastSync; pub use monotonic_mapping::{MonotonicallyMappableToU64, StrictlyMonotonicFn}; pub use monotonic_mapping_u128::MonotonicallyMappableToU128; +use crate::column::ValueRange; + mod merge; pub(crate) mod monotonic_mapping; pub(crate) mod monotonic_mapping_u128; @@ -109,6 +111,178 @@ pub trait ColumnValues: Send + Sync + DowncastSync { } } + /// Load the values for the provided docids. + /// + /// The values are filtered by the provided value range. + fn get_vals_in_value_range( + &self, + indexes: &mut Vec, + output: &mut Vec>, + value_range: ValueRange, + ) { + let mut write_head = 0; + let mut read_head = 0; + let len = indexes.len(); + + match value_range { + ValueRange::All => { + while read_head + 3 < len { + let idx0 = indexes[read_head]; + let idx1 = indexes[read_head + 1]; + let idx2 = indexes[read_head + 2]; + let idx3 = indexes[read_head + 3]; + + let val0 = self.get_val(idx0); + let val1 = self.get_val(idx1); + let val2 = self.get_val(idx2); + let val3 = self.get_val(idx3); + + indexes[write_head] = idx0; + output.push(Some(val0)); + write_head += 1; + indexes[write_head] = idx1; + output.push(Some(val1)); + write_head += 1; + indexes[write_head] = idx2; + output.push(Some(val2)); + write_head += 1; + indexes[write_head] = idx3; + output.push(Some(val3)); + write_head += 1; + + read_head += 4; + } + } + ValueRange::Inclusive(ref range) => { + while read_head + 3 < len { + let idx0 = indexes[read_head]; + let idx1 = indexes[read_head + 1]; + let idx2 = indexes[read_head + 2]; + let idx3 = indexes[read_head + 3]; + + let val0 = self.get_val(idx0); + let val1 = self.get_val(idx1); + let val2 = self.get_val(idx2); + let val3 = self.get_val(idx3); + + if range.contains(&val0) { + indexes[write_head] = idx0; + output.push(Some(val0)); + write_head += 1; + } + if range.contains(&val1) { + 
indexes[write_head] = idx1; + output.push(Some(val1)); + write_head += 1; + } + if range.contains(&val2) { + indexes[write_head] = idx2; + output.push(Some(val2)); + write_head += 1; + } + if range.contains(&val3) { + indexes[write_head] = idx3; + output.push(Some(val3)); + write_head += 1; + } + + read_head += 4; + } + } + ValueRange::GreaterThan(ref threshold, _) => { + while read_head + 3 < len { + let idx0 = indexes[read_head]; + let idx1 = indexes[read_head + 1]; + let idx2 = indexes[read_head + 2]; + let idx3 = indexes[read_head + 3]; + + let val0 = self.get_val(idx0); + let val1 = self.get_val(idx1); + let val2 = self.get_val(idx2); + let val3 = self.get_val(idx3); + + if val0 > *threshold { + indexes[write_head] = idx0; + output.push(Some(val0)); + write_head += 1; + } + if val1 > *threshold { + indexes[write_head] = idx1; + output.push(Some(val1)); + write_head += 1; + } + if val2 > *threshold { + indexes[write_head] = idx2; + output.push(Some(val2)); + write_head += 1; + } + if val3 > *threshold { + indexes[write_head] = idx3; + output.push(Some(val3)); + write_head += 1; + } + + read_head += 4; + } + } + ValueRange::LessThan(ref threshold, _) => { + while read_head + 3 < len { + let idx0 = indexes[read_head]; + let idx1 = indexes[read_head + 1]; + let idx2 = indexes[read_head + 2]; + let idx3 = indexes[read_head + 3]; + + let val0 = self.get_val(idx0); + let val1 = self.get_val(idx1); + let val2 = self.get_val(idx2); + let val3 = self.get_val(idx3); + + if val0 < *threshold { + indexes[write_head] = idx0; + output.push(Some(val0)); + write_head += 1; + } + if val1 < *threshold { + indexes[write_head] = idx1; + output.push(Some(val1)); + write_head += 1; + } + if val2 < *threshold { + indexes[write_head] = idx2; + output.push(Some(val2)); + write_head += 1; + } + if val3 < *threshold { + indexes[write_head] = idx3; + output.push(Some(val3)); + write_head += 1; + } + + read_head += 4; + } + } + } + // Process remaining elements (0 to 3) + while read_head < 
len { + let idx = indexes[read_head]; + let val = self.get_val(idx); + let matches = match value_range { + // 'value_range' is still moved here. This is the outer `value_range` + ValueRange::All => true, + ValueRange::Inclusive(ref r) => r.contains(&val), + ValueRange::GreaterThan(ref t, _) => val > *t, + ValueRange::LessThan(ref t, _) => val < *t, + }; + if matches { + indexes[write_head] = idx; + output.push(Some(val)); + write_head += 1; + } + read_head += 1; + } + indexes.truncate(write_head); + } + /// Fills an output buffer with the fast field values /// associated with the `DocId` going from /// `start` to `start + output.len()`. @@ -129,15 +303,38 @@ pub trait ColumnValues: Send + Sync + DowncastSync { /// Note that position == docid for single value fast fields fn get_row_ids_for_value_range( &self, - value_range: RangeInclusive, + value_range: ValueRange, row_id_range: Range, row_id_hits: &mut Vec, ) { let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals()); - for idx in row_id_range { - let val = self.get_val(idx); - if value_range.contains(&val) { - row_id_hits.push(idx); + match value_range { + ValueRange::Inclusive(range) => { + for idx in row_id_range { + let val = self.get_val(idx); + if range.contains(&val) { + row_id_hits.push(idx); + } + } + } + ValueRange::GreaterThan(threshold, _) => { + for idx in row_id_range { + let val = self.get_val(idx); + if val > threshold { + row_id_hits.push(idx); + } + } + } + ValueRange::LessThan(threshold, _) => { + for idx in row_id_range { + let val = self.get_val(idx); + if val < threshold { + row_id_hits.push(idx); + } + } + } + ValueRange::All => { + row_id_hits.extend(row_id_range); } } } @@ -193,6 +390,16 @@ impl ColumnValues for EmptyColumnValues { fn num_vals(&self) -> u32 { 0 } + + fn get_vals_in_value_range( + &self, + indexes: &mut Vec, + output: &mut Vec>, + value_range: ValueRange, + ) { + let _ = (indexes, output, value_range); + panic!("Internal Error: Called 
get_vals_in_value_range of empty column.") + } } impl ColumnValues for Arc> { @@ -206,6 +413,17 @@ impl ColumnValues for Arc, + output: &mut Vec>, + value_range: ValueRange, + ) { + self.as_ref() + .get_vals_in_value_range(indexes, output, value_range) + } + #[inline(always)] fn min_value(&self) -> T { self.as_ref().min_value() @@ -234,7 +452,7 @@ impl ColumnValues for Arc, + range: ValueRange, doc_id_range: Range, positions: &mut Vec, ) { diff --git a/columnar/src/column_values/monotonic_column.rs b/columnar/src/column_values/monotonic_column.rs index 35de3787a..e8586ad7a 100644 --- a/columnar/src/column_values/monotonic_column.rs +++ b/columnar/src/column_values/monotonic_column.rs @@ -1,8 +1,9 @@ use std::fmt::Debug; use std::marker::PhantomData; -use std::ops::{Range, RangeInclusive}; +use std::ops::Range; use crate::ColumnValues; +use crate::column::ValueRange; use crate::column_values::monotonic_mapping::StrictlyMonotonicFn; struct MonotonicMappingColumn { @@ -80,16 +81,35 @@ where fn get_row_ids_for_value_range( &self, - range: RangeInclusive, + range: ValueRange, doc_id_range: Range, positions: &mut Vec, ) { - self.from_column.get_row_ids_for_value_range( - self.monotonic_mapping.inverse(range.start().clone()) - ..=self.monotonic_mapping.inverse(range.end().clone()), - doc_id_range, - positions, - ) + match range { + ValueRange::Inclusive(range) => self.from_column.get_row_ids_for_value_range( + ValueRange::Inclusive( + self.monotonic_mapping.inverse(range.start().clone()) + ..=self.monotonic_mapping.inverse(range.end().clone()), + ), + doc_id_range, + positions, + ), + ValueRange::All => self.from_column.get_row_ids_for_value_range( + ValueRange::All, + doc_id_range, + positions, + ), + ValueRange::GreaterThan(threshold, _) => self.from_column.get_row_ids_for_value_range( + ValueRange::GreaterThan(self.monotonic_mapping.inverse(threshold), false), + doc_id_range, + positions, + ), + ValueRange::LessThan(threshold, _) => 
self.from_column.get_row_ids_for_value_range( + ValueRange::LessThan(self.monotonic_mapping.inverse(threshold), false), + doc_id_range, + positions, + ), + } } // We voluntarily do not implement get_range as it yields a regression, diff --git a/columnar/src/column_values/u128_based/compact_space/mod.rs b/columnar/src/column_values/u128_based/compact_space/mod.rs index 2c815bdce..c851d7f49 100644 --- a/columnar/src/column_values/u128_based/compact_space/mod.rs +++ b/columnar/src/column_values/u128_based/compact_space/mod.rs @@ -25,6 +25,7 @@ use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128}; use tantivy_bitpacker::{BitPacker, BitUnpacker}; use crate::RowId; +use crate::column::ValueRange; use crate::column_values::ColumnValues; /// The cost per blank is quite hard actually, since blanks are delta encoded, the actual cost of @@ -338,14 +339,36 @@ impl ColumnValues for CompactSpaceU64Accessor { #[inline] fn get_row_ids_for_value_range( &self, - value_range: RangeInclusive, + value_range: ValueRange, position_range: Range, positions: &mut Vec, ) { - let value_range = self.0.compact_to_u128(*value_range.start() as u32) - ..=self.0.compact_to_u128(*value_range.end() as u32); - self.0 - .get_row_ids_for_value_range(value_range, position_range, positions) + match value_range { + ValueRange::Inclusive(value_range) => { + let value_range = ValueRange::Inclusive( + self.0.compact_to_u128(*value_range.start() as u32) + ..=self.0.compact_to_u128(*value_range.end() as u32), + ); + self.0 + .get_row_ids_for_value_range(value_range, position_range, positions) + } + ValueRange::All => { + let position_range = position_range.start..position_range.end.min(self.num_vals()); + positions.extend(position_range); + } + ValueRange::GreaterThan(threshold, _) => { + let value_range = + ValueRange::GreaterThan(self.0.compact_to_u128(threshold as u32), false); + self.0 + .get_row_ids_for_value_range(value_range, position_range, positions) + } + 
ValueRange::LessThan(threshold, _) => { + let value_range = + ValueRange::LessThan(self.0.compact_to_u128(threshold as u32), false); + self.0 + .get_row_ids_for_value_range(value_range, position_range, positions) + } + } } } @@ -375,10 +398,33 @@ impl ColumnValues for CompactSpaceDecompressor { #[inline] fn get_row_ids_for_value_range( &self, - value_range: RangeInclusive, + value_range: ValueRange, position_range: Range, positions: &mut Vec, ) { + let value_range = match value_range { + ValueRange::Inclusive(value_range) => value_range, + ValueRange::All => { + let position_range = position_range.start..position_range.end.min(self.num_vals()); + positions.extend(position_range); + return; + } + ValueRange::GreaterThan(threshold, _) => { + let max = self.max_value(); + if threshold >= max { + return; + } + (threshold + 1)..=max + } + ValueRange::LessThan(threshold, _) => { + let min = self.min_value(); + if threshold <= min { + return; + } + min..=(threshold - 1) + } + }; + if value_range.start() > value_range.end() { return; } @@ -560,7 +606,7 @@ mod tests { .collect::>(); let mut positions = Vec::new(); decompressor.get_row_ids_for_value_range( - range, + ValueRange::Inclusive(range), 0..decompressor.num_vals(), &mut positions, ); @@ -604,7 +650,11 @@ mod tests { let val = *val; let pos = pos as u32; let mut positions = Vec::new(); - decomp.get_row_ids_for_value_range(val..=val, pos..pos + 1, &mut positions); + decomp.get_row_ids_for_value_range( + ValueRange::Inclusive(val..=val), + pos..pos + 1, + &mut positions, + ); assert_eq!(positions, vec![pos]); } @@ -746,7 +796,11 @@ mod tests { doc_id_range: Range, ) -> Vec { let mut positions = Vec::new(); - column.get_row_ids_for_value_range(value_range, doc_id_range, &mut positions); + column.get_row_ids_for_value_range( + ValueRange::Inclusive(value_range), + doc_id_range, + &mut positions, + ); positions } diff --git a/columnar/src/column_values/u64_based/bitpacked.rs 
b/columnar/src/column_values/u64_based/bitpacked.rs index 71319cbec..647728863 100644 --- a/columnar/src/column_values/u64_based/bitpacked.rs +++ b/columnar/src/column_values/u64_based/bitpacked.rs @@ -6,6 +6,7 @@ use common::{BinarySerializable, OwnedBytes}; use fastdivide::DividerU64; use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits}; +use crate::column::ValueRange; use crate::column_values::u64_based::{ColumnCodec, ColumnCodecEstimator, ColumnStats}; use crate::{ColumnValues, RowId}; @@ -66,24 +67,173 @@ impl ColumnValues for BitpackedReader { self.stats.num_rows } + fn get_vals_in_value_range( + &self, + indexes: &mut Vec, + output: &mut Vec>, + value_range: ValueRange, + ) { + let mut write_head = 0; + match value_range { + ValueRange::All => { + for i in 0..indexes.len() { + let idx = indexes[i]; + indexes[write_head] = idx; + output.push(Some(self.get_val(idx))); + write_head += 1; + } + } + ValueRange::Inclusive(range) => { + if let Some(transformed_range) = + transform_range_before_linear_transformation(&self.stats, range) + { + for i in 0..indexes.len() { + let doc = indexes[i]; + let raw_val = self.get_val(doc); + if transformed_range.contains(&raw_val) { + indexes[write_head] = doc; + output + .push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val)); + write_head += 1; + } + } + } + } + ValueRange::GreaterThan(threshold, _) => { + if threshold < self.stats.min_value { + for i in 0..indexes.len() { + let idx = indexes[i]; + indexes[write_head] = idx; + output.push(Some(self.get_val(idx))); + write_head += 1; + } + } else if threshold >= self.stats.max_value { + // All filtered out + } else { + let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get(); + for i in 0..indexes.len() { + let doc = indexes[i]; + let raw_val = self.get_val(doc); + if raw_val > raw_threshold { + indexes[write_head] = doc; + output + .push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val)); + write_head += 1; + } + } + } + } 
+ ValueRange::LessThan(threshold, _) => { + if threshold > self.stats.max_value { + for i in 0..indexes.len() { + let idx = indexes[i]; + indexes[write_head] = idx; + output.push(Some(self.get_val(idx))); + write_head += 1; + } + } else if threshold <= self.stats.min_value { + // All filtered out + } else { + let diff = threshold - self.stats.min_value; + let gcd = self.stats.gcd.get(); + let raw_threshold = if diff % gcd == 0 { + diff / gcd + } else { + diff / gcd + 1 + }; + + for i in 0..indexes.len() { + let doc = indexes[i]; + let raw_val = self.get_val(doc); + if raw_val < raw_threshold { + indexes[write_head] = doc; + output + .push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val)); + write_head += 1; + } + } + } + } + } + indexes.truncate(write_head); + } fn get_row_ids_for_value_range( &self, - range: RangeInclusive, + range: ValueRange, doc_id_range: Range, positions: &mut Vec, ) { - let Some(transformed_range) = - transform_range_before_linear_transformation(&self.stats, range) - else { - positions.clear(); - return; - }; - self.bit_unpacker.get_ids_for_value_range( - transformed_range, - doc_id_range, - &self.data, - positions, - ); + match range { + ValueRange::All => { + positions.extend(doc_id_range); + return; + } + ValueRange::Inclusive(range) => { + let Some(transformed_range) = + transform_range_before_linear_transformation(&self.stats, range) + else { + positions.clear(); + return; + }; + + self.bit_unpacker.get_ids_for_value_range( + transformed_range, + doc_id_range, + &self.data, + positions, + ); + } + ValueRange::GreaterThan(threshold, _) => { + if threshold < self.stats.min_value { + positions.extend(doc_id_range); + return; + } + if threshold >= self.stats.max_value { + return; + } + let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get(); + let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get(); + let transformed_range = (raw_threshold + 1)..=max_raw; + + 
self.bit_unpacker.get_ids_for_value_range( + transformed_range, + doc_id_range, + &self.data, + positions, + ); + } + ValueRange::LessThan(threshold, _) => { + if threshold > self.stats.max_value { + positions.extend(doc_id_range); + return; + } + if threshold <= self.stats.min_value { + return; + } + + let diff = threshold - self.stats.min_value; + let gcd = self.stats.gcd.get(); + // We want raw < raw_threshold_limit + // raw <= raw_threshold_limit - 1 + let raw_threshold_limit = if diff % gcd == 0 { + diff / gcd + } else { + diff / gcd + 1 + }; + + if raw_threshold_limit == 0 { + return; + } + let transformed_range = 0..=(raw_threshold_limit - 1); + + self.bit_unpacker.get_ids_for_value_range( + transformed_range, + doc_id_range, + &self.data, + positions, + ); + } + } } } diff --git a/columnar/src/column_values/u64_based/tests.rs b/columnar/src/column_values/u64_based/tests.rs index 6b2697263..5c20791f4 100644 --- a/columnar/src/column_values/u64_based/tests.rs +++ b/columnar/src/column_values/u64_based/tests.rs @@ -131,7 +131,7 @@ pub(crate) fn create_and_validate( .collect(); let mut positions = Vec::new(); reader.get_row_ids_for_value_range( - vals[test_rand_idx]..=vals[test_rand_idx], + crate::column::ValueRange::Inclusive(vals[test_rand_idx]..=vals[test_rand_idx]), 0..vals.len() as u32, &mut positions, ); diff --git a/columnar/src/lib.rs b/columnar/src/lib.rs index 537c52562..ce499f72b 100644 --- a/columnar/src/lib.rs +++ b/columnar/src/lib.rs @@ -36,7 +36,7 @@ pub(crate) mod utils; mod value; pub use block_accessor::ColumnBlockAccessor; -pub use column::{BytesColumn, Column, StrColumn}; +pub use column::{BytesColumn, Column, StrColumn, ValueRange}; pub use column_index::ColumnIndex; pub use column_values::{ ColumnValues, EmptyColumnValues, MonotonicallyMappableToU64, MonotonicallyMappableToU128, diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index 391873298..95f8901b6 100644 --- a/src/collector/sort_key/mod.rs +++ 
b/src/collector/sort_key/mod.rs @@ -389,6 +389,52 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn test_order_by_compound_fast_fields() -> crate::Result<()> { + let index = make_index()?; + + type CompoundSortKey = (Option, Option); + + fn assert_query( + index: &Index, + city_order: Order, + altitude_order: Order, + expected: Vec<(CompoundSortKey, u64)>, + ) -> crate::Result<()> { + let searcher = index.reader()?.searcher(); + let ids = id_mapping(&searcher); + + let top_collector = TopDocs::with_limit(4).order_by(( + (SortByString::for_field("city"), city_order), + ( + SortByStaticFastValue::::for_field("altitude"), + altitude_order, + ), + )); + let actual = searcher + .search(&AllQuery, &top_collector)? + .into_iter() + .map(|(key, doc)| (key, ids[&doc])) + .collect::>(); + assert_eq!(actual, expected); + Ok(()) + } + + assert_query( + &index, + Order::Asc, + Order::Desc, + vec![ + ((Some("austin".to_owned()), Some(149.0)), 0), + ((Some("greenville".to_owned()), Some(27.0)), 1), + ((Some("tokyo".to_owned()), Some(40.0)), 2), + ((None, Some(0.0)), 3), + ], + )?; + + Ok(()) + } + use proptest::prelude::*; proptest! { @@ -451,4 +497,67 @@ pub(crate) mod tests { ); } } + + #[test] + fn test_order_by_compound_filtering_with_none() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let city = schema_builder.add_text_field("city", TEXT | FAST); + let altitude = schema_builder.add_u64_field("altitude", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests()?; + + // Add enough docs to trigger thresholding. + // We want to sort by City Asc, Altitude Asc. + // Note: In NaturalComparator, None < Some. + // So Ascending order should be: None, then "a", then "b", then "c". + + // Docs: + // 0: "c", 10 + // 1: "b", 10 + // 2: "a", 20 + // 3: "a", 10 + // 4: None, 5 + + // Expected Ascending Order (None is Last in Tantivy's Order::Asc): + // 1. Doc 3 ("a", 10) + // 2. 
Doc 2 ("a", 20) + // 3. Doc 1 ("b", 10) + // 4. Doc 0 ("c", 10) + // 5. Doc 4 (None, 5) + + index_writer.add_document(doc!(city => "c", altitude => 10u64))?; + index_writer.add_document(doc!(city => "b", altitude => 10u64))?; + index_writer.add_document(doc!(city => "a", altitude => 20u64))?; + index_writer.add_document(doc!(city => "a", altitude => 10u64))?; + index_writer.add_document(doc!(altitude => 5u64))?; // City is None + + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + // Use limit(2) to force a threshold update after the first few docs. + // The collector should eventually establish a threshold around ("a", 20) (Top 2: "a" 10, + // "a" 20). Then when seeing "b" and "c", it should filter them out based on the + // head key "city". This confirms that when filtering happens, the DocIds are + // preserved correctly. + let top_collector = TopDocs::with_limit(2).order_by(( + (SortByString::for_field("city"), Order::Asc), + ( + SortByStaticFastValue::::for_field("altitude"), + Order::Asc, + ), + )); + + let results: Vec = searcher + .search(&AllQuery, &top_collector)? + .into_iter() + .map(|(_, doc)| doc) + .collect(); + + // Doc 3 is ("a", 10). Doc 2 is ("a", 20). + assert_eq!(results, vec![DocAddress::new(0, 3), DocAddress::new(0, 2)]); + + Ok(()) + } } diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 3cac357ad..178c9a07e 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; -use columnar::MonotonicallyMappableToU64; +use columnar::{MonotonicallyMappableToU64, ValueRange}; use serde::{Deserialize, Serialize}; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; @@ -69,6 +69,10 @@ fn compare_owned_value(lhs: &OwnedValue, rhs: &OwnedVal pub trait Comparator: Send + Sync + std::fmt::Debug + Default { /// Return the order between two values. 
fn compare(&self, lhs: &T, rhs: &T) -> Ordering; + + /// Return a `ValueRange` that matches all values that are greater than the provided threshold. + #[allow(dead_code)] + fn threshold_to_valuerange(&self, threshold: T) -> ValueRange; } /// Compare values naturally (e.g. 1 < 2). @@ -86,6 +90,10 @@ impl Comparator for NaturalComparator { fn compare(&self, lhs: &T, rhs: &T) -> Ordering { lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal) } + + fn threshold_to_valuerange(&self, threshold: T) -> ValueRange { + ValueRange::GreaterThan(threshold, false) + } } /// A (partial) implementation of comparison for OwnedValue. @@ -97,6 +105,10 @@ impl Comparator for NaturalComparator { fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { compare_owned_value::(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange { + ValueRange::GreaterThan(threshold, false) + } } /// Compare values in reverse (e.g. 2 < 1). @@ -121,6 +133,10 @@ where NaturalComparator: Comparator fn compare(&self, lhs: &T, rhs: &T) -> Ordering { NaturalComparator.compare(rhs, lhs) } + + fn threshold_to_valuerange(&self, threshold: T) -> ValueRange { + ValueRange::LessThan(threshold, true) + } } /// Compare values in reverse, but treating `None` as lower than `Some`. 
@@ -147,6 +163,10 @@ where ReverseComparator: Comparator (Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs), } } + + fn threshold_to_valuerange(&self, threshold: Option) -> ValueRange> { + ValueRange::LessThan(threshold, false) + } } impl Comparator for ReverseNoneIsLowerComparator { @@ -154,6 +174,10 @@ impl Comparator for ReverseNoneIsLowerComparator { fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering { ReverseComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange { + ValueRange::LessThan(threshold, false) + } } impl Comparator for ReverseNoneIsLowerComparator { @@ -161,6 +185,10 @@ impl Comparator for ReverseNoneIsLowerComparator { fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering { ReverseComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange { + ValueRange::LessThan(threshold, false) + } } impl Comparator for ReverseNoneIsLowerComparator { @@ -168,6 +196,10 @@ impl Comparator for ReverseNoneIsLowerComparator { fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering { ReverseComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange { + ValueRange::LessThan(threshold, false) + } } impl Comparator for ReverseNoneIsLowerComparator { @@ -175,6 +207,10 @@ impl Comparator for ReverseNoneIsLowerComparator { fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering { ReverseComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange { + ValueRange::LessThan(threshold, false) + } } impl Comparator for ReverseNoneIsLowerComparator { @@ -182,6 +218,10 @@ impl Comparator for ReverseNoneIsLowerComparator { fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering { ReverseComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange { + ValueRange::LessThan(threshold, false) + } } impl Comparator for ReverseNoneIsLowerComparator { @@ -189,6 +229,10 @@ impl Comparator for 
ReverseNoneIsLowerComparator { fn compare(&self, lhs: &String, rhs: &String) -> Ordering { ReverseComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: String) -> ValueRange { + ValueRange::LessThan(threshold, false) + } } impl Comparator for ReverseNoneIsLowerComparator { @@ -196,6 +240,10 @@ impl Comparator for ReverseNoneIsLowerComparator { fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { compare_owned_value::(rhs, lhs) } + + fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange { + ValueRange::LessThan(threshold, false) + } } /// Compare values naturally, but treating `None` as higher than `Some`. @@ -218,6 +266,10 @@ where NaturalComparator: Comparator (Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs), } } + + fn threshold_to_valuerange(&self, threshold: Option) -> ValueRange> { + ValueRange::GreaterThan(threshold, true) + } } impl Comparator for NaturalNoneIsHigherComparator { @@ -225,6 +277,10 @@ impl Comparator for NaturalNoneIsHigherComparator { fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering { NaturalComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange { + ValueRange::GreaterThan(threshold, true) + } } impl Comparator for NaturalNoneIsHigherComparator { @@ -232,6 +288,10 @@ impl Comparator for NaturalNoneIsHigherComparator { fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering { NaturalComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange { + ValueRange::GreaterThan(threshold, true) + } } impl Comparator for NaturalNoneIsHigherComparator { @@ -239,6 +299,10 @@ impl Comparator for NaturalNoneIsHigherComparator { fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering { NaturalComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange { + ValueRange::GreaterThan(threshold, true) + } } impl Comparator for NaturalNoneIsHigherComparator { @@ -246,6 +310,10 @@ 
impl Comparator for NaturalNoneIsHigherComparator { fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering { NaturalComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange { + ValueRange::GreaterThan(threshold, true) + } } impl Comparator for NaturalNoneIsHigherComparator { @@ -253,6 +321,10 @@ impl Comparator for NaturalNoneIsHigherComparator { fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering { NaturalComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange { + ValueRange::GreaterThan(threshold, true) + } } impl Comparator for NaturalNoneIsHigherComparator { @@ -260,6 +332,10 @@ impl Comparator for NaturalNoneIsHigherComparator { fn compare(&self, lhs: &String, rhs: &String) -> Ordering { NaturalComparator.compare(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: String) -> ValueRange { + ValueRange::GreaterThan(threshold, true) + } } impl Comparator for NaturalNoneIsHigherComparator { @@ -267,6 +343,10 @@ impl Comparator for NaturalNoneIsHigherComparator { fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { compare_owned_value::(lhs, rhs) } + + fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange { + ValueRange::GreaterThan(threshold, true) + } } /// An enum representing the different sort orders. 
@@ -308,6 +388,19 @@ where ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs), } } + + fn threshold_to_valuerange(&self, threshold: T) -> ValueRange { + match self { + ComparatorEnum::Natural => NaturalComparator.threshold_to_valuerange(threshold), + ComparatorEnum::Reverse => ReverseComparator.threshold_to_valuerange(threshold), + ComparatorEnum::ReverseNoneLower => { + ReverseNoneIsLowerComparator.threshold_to_valuerange(threshold) + } + ComparatorEnum::NaturalNoneHigher => { + NaturalNoneIsHigherComparator.threshold_to_valuerange(threshold) + } + } + } } impl Comparator<(Head, Tail)> @@ -322,6 +415,10 @@ where .compare(&lhs.0, &rhs.0) .then_with(|| self.1.compare(&lhs.1, &rhs.1)) } + + fn threshold_to_valuerange(&self, threshold: (Head, Tail)) -> ValueRange<(Head, Tail)> { + ValueRange::GreaterThan(threshold, false) + } } impl Comparator<(Type1, (Type2, Type3))> @@ -338,6 +435,13 @@ where .then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0)) .then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1)) } + + fn threshold_to_valuerange( + &self, + threshold: (Type1, (Type2, Type3)), + ) -> ValueRange<(Type1, (Type2, Type3))> { + ValueRange::GreaterThan(threshold, false) + } } impl Comparator<(Type1, Type2, Type3)> @@ -354,6 +458,13 @@ where .then_with(|| self.1.compare(&lhs.1, &rhs.1)) .then_with(|| self.2.compare(&lhs.2, &rhs.2)) } + + fn threshold_to_valuerange( + &self, + threshold: (Type1, Type2, Type3), + ) -> ValueRange<(Type1, Type2, Type3)> { + ValueRange::GreaterThan(threshold, false) + } } impl @@ -377,6 +488,13 @@ where .then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0)) .then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1)) } + + fn threshold_to_valuerange( + &self, + threshold: (Type1, (Type2, (Type3, Type4))), + ) -> ValueRange<(Type1, (Type2, (Type3, Type4)))> { + ValueRange::GreaterThan(threshold, false) + } } impl @@ -400,6 +518,13 @@ where .then_with(|| self.2.compare(&lhs.2, &rhs.2)) .then_with(|| 
self.3.compare(&lhs.3, &rhs.3)) } + + fn threshold_to_valuerange( + &self, + threshold: (Type1, Type2, Type3, Type4), + ) -> ValueRange<(Type1, Type2, Type3, Type4)> { + ValueRange::GreaterThan(threshold, false) + } } impl SortKeyComputer for (TSortKeyComputer, ComparatorEnum) @@ -489,16 +614,29 @@ impl SegmentSortKeyComput where TSegmentSortKeyComputer: SegmentSortKeyComputer, TSegmentSortKey: Clone + 'static + Sync + Send, - TComparator: Comparator + 'static + Sync + Send, + TComparator: Comparator + Clone + 'static + Sync + Send, { type SortKey = TSegmentSortKeyComputer::SortKey; type SegmentSortKey = TSegmentSortKey; type SegmentComparator = TComparator; + fn segment_comparator(&self) -> Self::SegmentComparator { + self.comparator.clone() + } + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.segment_sort_key_computer.segment_sort_key(doc, score) } + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + self.segment_sort_key_computer + .segment_sort_keys(docs, filter) + } + #[inline(always)] fn compare_segment_sort_key( &self, diff --git a/src/collector/sort_key/sort_by_erased_type.rs b/src/collector/sort_key/sort_by_erased_type.rs index d15dd130c..8bdbe8057 100644 --- a/src/collector/sort_key/sort_by_erased_type.rs +++ b/src/collector/sort_key/sort_by_erased_type.rs @@ -1,5 +1,6 @@ -use columnar::{ColumnType, MonotonicallyMappableToU64}; +use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange}; +use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer; use crate::collector::sort_key::{ NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString, }; @@ -36,6 +37,11 @@ impl SortByErasedType { trait ErasedSegmentSortKeyComputer: Send + Sync { fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option; + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange>, + ) -> &mut 
Vec<(DocId, Option)>; fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue; } @@ -53,6 +59,14 @@ where self.inner.segment_sort_key(doc, score) } + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange>, + ) -> &mut Vec<(DocId, Option)> { + self.inner.segment_sort_keys(docs, filter) + } + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { let val = self.inner.convert_segment_sort_key(sort_key); (self.converter)(val) @@ -60,7 +74,7 @@ where } struct ScoreSegmentSortKeyComputer { - segment_computer: SortBySimilarityScore, + segment_computer: SortBySimilarityScoreSegmentComputer, } impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { @@ -69,6 +83,14 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { Some(score_value.to_u64()) } + fn segment_sort_keys( + &mut self, + _docs: &[DocId], + _filter: ValueRange>, + ) -> &mut Vec<(DocId, Option)> { + unimplemented!("Batch computation not supported for score sorting") + } + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { let score_value: u64 = sort_key.expect("This implementation always produces a score."); OwnedValue::F64(f64::from_u64(score_value)) @@ -174,7 +196,8 @@ impl SortKeyComputer for SortByErasedType { } } Self::Score => Box::new(ScoreSegmentSortKeyComputer { - segment_computer: SortBySimilarityScore, + segment_computer: SortBySimilarityScore + .segment_sort_key_computer(segment_reader)?, }), }; Ok(ErasedColumnSegmentSortKeyComputer { inner }) @@ -195,6 +218,14 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer { self.inner.segment_sort_key(doc, score) } + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + self.inner.segment_sort_keys(docs, filter) + } + fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue { self.inner.convert_segment_sort_key(segment_sort_key) } diff --git 
a/src/collector/sort_key/sort_by_score.rs b/src/collector/sort_key/sort_by_score.rs index a23660e56..4da9643df 100644 --- a/src/collector/sort_key/sort_by_score.rs +++ b/src/collector/sort_key/sort_by_score.rs @@ -1,3 +1,5 @@ +use columnar::ValueRange; + use crate::collector::sort_key::NaturalComparator; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer}; use crate::{DocAddress, DocId, Score}; @@ -9,7 +11,7 @@ pub struct SortBySimilarityScore; impl SortKeyComputer for SortBySimilarityScore { type SortKey = Score; - type Child = SortBySimilarityScore; + type Child = SortBySimilarityScoreSegmentComputer; type Comparator = NaturalComparator; @@ -21,7 +23,7 @@ impl SortKeyComputer for SortBySimilarityScore { &self, _segment_reader: &crate::SegmentReader, ) -> crate::Result { - Ok(SortBySimilarityScore) + Ok(SortBySimilarityScoreSegmentComputer) } // Sorting by score is special in that it allows for the Block-Wand optimization. @@ -61,7 +63,9 @@ impl SortKeyComputer for SortBySimilarityScore { } } -impl SegmentSortKeyComputer for SortBySimilarityScore { +pub struct SortBySimilarityScoreSegmentComputer; + +impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer { type SortKey = Score; type SegmentSortKey = Score; type SegmentComparator = NaturalComparator; @@ -71,6 +75,14 @@ impl SegmentSortKeyComputer for SortBySimilarityScore { score } + fn segment_sort_keys( + &mut self, + _docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + unimplemented!("Batch computation not supported for score sorting") + } + fn convert_segment_sort_key(&self, score: Score) -> Score { score } diff --git a/src/collector/sort_key/sort_by_static_fast_value.rs b/src/collector/sort_key/sort_by_static_fast_value.rs index 44a4e1d8d..01ddc45d8 100644 --- a/src/collector/sort_key/sort_by_static_fast_value.rs +++ b/src/collector/sort_key/sort_by_static_fast_value.rs @@ -1,7 +1,8 @@ use std::marker::PhantomData; -use 
columnar::Column; +use columnar::{Column, ValueRange}; +use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range; use crate::collector::sort_key::NaturalComparator; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; use crate::fastfield::{FastFieldNotAvailableError, FastValue}; @@ -71,6 +72,9 @@ impl SortKeyComputer for SortByStaticFastValue { Ok(SortByFastValueSegmentSortKeyComputer { sort_column, typ: PhantomData, + buffer: Vec::new(), + fetch_buffer: Vec::new(), + doc_buffer: Vec::new(), }) } } @@ -78,6 +82,9 @@ impl SortKeyComputer for SortByStaticFastValue { pub struct SortByFastValueSegmentSortKeyComputer { sort_column: Column, typ: PhantomData, + buffer: Vec<(DocId, Option)>, + fetch_buffer: Vec>, + doc_buffer: Vec, } impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer { @@ -90,7 +97,102 @@ impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu self.sort_column.first(doc) } + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + self.doc_buffer.clear(); + self.doc_buffer.extend_from_slice(docs); + self.fetch_buffer.clear(); + let u64_filter = convert_optional_u64_range_to_u64_range(filter); + self.sort_column.first_vals_in_value_range( + &mut self.doc_buffer, + &mut self.fetch_buffer, + u64_filter, + ); + + self.buffer.clear(); + for (&doc, &val) in self.doc_buffer.iter().zip(self.fetch_buffer.iter()) { + self.buffer.push((doc, val)); + } + &mut self.buffer + } + fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { sort_key.map(T::from_u64) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{Schema, FAST}; + use crate::Index; + + #[test] + fn test_sort_by_fast_value_batch() { + let mut schema_builder = Schema::builder(); + let field_col = schema_builder.add_u64_field("field", FAST); + let schema = schema_builder.build(); + let index = 
Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + + index_writer + .add_document(crate::doc!(field_col => 10u64)) + .unwrap(); + index_writer + .add_document(crate::doc!(field_col => 20u64)) + .unwrap(); + index_writer.add_document(crate::doc!()).unwrap(); + index_writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let sorter = SortByStaticFastValue::::for_field("field"); + let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); + + let docs = vec![0, 1, 2]; + let output = computer.segment_sort_keys(&docs, ValueRange::All); + + assert_eq!(output, &[(0, Some(10)), (1, Some(20)), (2, None)]); + } + + #[test] + fn test_sort_by_fast_value_batch_with_filter() { + let mut schema_builder = Schema::builder(); + let field_col = schema_builder.add_u64_field("field", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + + index_writer + .add_document(crate::doc!(field_col => 10u64)) + .unwrap(); + index_writer + .add_document(crate::doc!(field_col => 20u64)) + .unwrap(); + index_writer.add_document(crate::doc!()).unwrap(); + index_writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let sorter = SortByStaticFastValue::::for_field("field"); + let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); + + let docs = vec![0, 1, 2]; + let output = computer.segment_sort_keys( + &docs, + ValueRange::GreaterThan(Some(15u64), false /* inclusive */), + ); + + // Should contain only the document with value 20. 
+ // Doc 0 (10) < 15 + // Doc 2 (None) < 15 + assert_eq!(output, &[(1, Some(20))]); + } +} diff --git a/src/collector/sort_key/sort_by_string.rs b/src/collector/sort_key/sort_by_string.rs index 2dd0b4592..4582c1e22 100644 --- a/src/collector/sort_key/sort_by_string.rs +++ b/src/collector/sort_key/sort_by_string.rs @@ -1,5 +1,8 @@ -use columnar::StrColumn; +use columnar::{StrColumn, ValueRange}; +use crate::collector::sort_key::sort_key_computer::{ + convert_optional_u64_range_to_u64_range, range_contains_none, +}; use crate::collector::sort_key::NaturalComparator; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; use crate::termdict::TermOrdinal; @@ -38,12 +41,20 @@ impl SortKeyComputer for SortByString { segment_reader: &crate::SegmentReader, ) -> crate::Result { let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?; - Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt }) + Ok(ByStringColumnSegmentSortKeyComputer { + str_column_opt, + buffer: Vec::new(), + fetch_buffer: Vec::new(), + doc_buffer: Vec::new(), + }) } } pub struct ByStringColumnSegmentSortKeyComputer { str_column_opt: Option, + buffer: Vec<(DocId, Option)>, + fetch_buffer: Vec>, + doc_buffer: Vec, } impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { @@ -57,6 +68,37 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { str_column.ords().first(doc) } + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + self.doc_buffer.clear(); + self.doc_buffer.extend_from_slice(docs); + self.fetch_buffer.clear(); + + if let Some(str_column) = &self.str_column_opt { + let u64_filter = convert_optional_u64_range_to_u64_range(filter); + str_column.ords().first_vals_in_value_range( + &mut self.doc_buffer, + &mut self.fetch_buffer, + u64_filter, + ); + } else if range_contains_none(&filter) { + for _ in 0..docs.len() { + self.fetch_buffer.push(None); + } + } else { + 
self.doc_buffer.clear(); + } + + self.buffer.clear(); + for (&doc, &val) in self.doc_buffer.iter().zip(self.fetch_buffer.iter()) { + self.buffer.push((doc, val)); + } + &mut self.buffer + } + fn convert_segment_sort_key(&self, term_ord_opt: Option) -> Option { // TODO: Individual lookups to the dictionary like this are very likely to repeatedly // decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776 @@ -70,3 +112,80 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { String::try_from(bytes).ok() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{Schema, FAST, TEXT}; + use crate::Index; + + #[test] + fn test_sort_by_string_batch() { + let mut schema_builder = Schema::builder(); + let field_col = schema_builder.add_text_field("field", FAST | TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + + index_writer + .add_document(crate::doc!(field_col => "a")) + .unwrap(); + index_writer + .add_document(crate::doc!(field_col => "c")) + .unwrap(); + index_writer.add_document(crate::doc!()).unwrap(); + index_writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let sorter = SortByString::for_field("field"); + let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); + + let docs = vec![0, 1, 2]; + let output = computer.segment_sort_keys(&docs, ValueRange::All); + + // We expect ordinals. 
+ // "a" -> 0 + // "c" -> 1 + assert_eq!(output, &[(0, Some(0)), (1, Some(1)), (2, None)]); + } + + #[test] + fn test_sort_by_string_batch_with_filter() { + let mut schema_builder = Schema::builder(); + let field_col = schema_builder.add_text_field("field", FAST | TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + + index_writer + .add_document(crate::doc!(field_col => "a")) + .unwrap(); + index_writer + .add_document(crate::doc!(field_col => "c")) + .unwrap(); + index_writer.add_document(crate::doc!()).unwrap(); + index_writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let sorter = SortByString::for_field("field"); + let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); + + let docs = vec![0, 1, 2]; + // The filter is expressed in term ordinals: "a" is ord 0 and "c" is ord 1. + // Keeping only ordinals above ord 0 drops "a" (doc 0) and keeps "c" (doc 1). + // Doc 2 has no value and is expected to be dropped as well. + let output = computer.segment_sort_keys( + &docs, + ValueRange::GreaterThan(Some(0), false /* inclusive */), + ); + + // Should contain only the document with value "c" (ord 1). 
+ assert_eq!(output, &[(1, Some(1))]); + } +} diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index 6aab919a9..d21de75aa 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -1,8 +1,13 @@ use std::cmp::Ordering; +use columnar::ValueRange; + use crate::collector::sort_key::{Comparator, NaturalComparator}; use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector; -use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer}; +use crate::collector::top_score_collector::push_assuming_capacity; +use crate::collector::{ + default_collect_segment_impl, ComparableDoc, SegmentCollector as _, TopNComputer, +}; use crate::schema::Schema; use crate::{DocAddress, DocId, Result, Score, SegmentReader}; @@ -21,7 +26,7 @@ pub trait SegmentSortKeyComputer: 'static { type SegmentSortKey: 'static + Clone + Send + Sync + Clone; /// Comparator type. - type SegmentComparator: Comparator + 'static; + type SegmentComparator: Comparator + Clone + 'static; /// Returns the segment sort key comparator. fn segment_comparator(&self) -> Self::SegmentComparator { @@ -31,6 +36,16 @@ pub trait SegmentSortKeyComputer: 'static { /// Computes the sort key for the given document and score. fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey; + /// Computes the sort keys for a batch of documents. + /// + /// The computed sort keys are stored in an internal buffer, which is returned by mutable reference. + /// Subsequent calls to this method may reuse and overwrite the internal buffer. + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)>; + /// Computes the sort key and pushes the document in a TopN Computer. /// /// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner. 
@@ -45,6 +60,42 @@ pub trait SegmentSortKeyComputer: 'static { top_n_computer.push(sort_key, doc); } + fn compute_sort_keys_and_collect>( + &mut self, + docs: &[DocId], + top_n_computer: &mut TopNComputer, + ) { + // The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we + // should always be able to `reserve` space for the entire block. + top_n_computer.reserve(docs.len()); + + let comparator = self.segment_comparator(); + let value_range = if let Some(threshold) = &top_n_computer.threshold { + comparator.threshold_to_valuerange(threshold.clone()) + } else { + ValueRange::All + }; + + let sort_keys = self.segment_sort_keys(docs, value_range); + + if let Some(threshold) = &top_n_computer.threshold { + let threshold = threshold.clone(); + for (doc, sort_key) in sort_keys.drain(..) { + let cmp = comparator.compare(&sort_key, &threshold); + if cmp == Ordering::Greater { + // We validated at the top of the method that we have capacity. + top_n_computer.append_doc_unchecked(doc, sort_key); + } + } + } else { + // Eagerly push, without a threshold to compare to. + for (doc, sort_key) in sort_keys.drain(..) { + // We validated at the top of the method that we have capacity. + top_n_computer.append_doc_unchecked(doc, sort_key); + } + } + } + /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on /// its ordering. /// @@ -58,23 +109,24 @@ pub trait SegmentSortKeyComputer: 'static { self.segment_comparator().compare(left, right) } - /// Implementing this method makes it possible to avoid computing - /// a sort_key entirely if we can assess that it won't pass a threshold - /// with a partial computation. + /// Similar to `accept_sort_key_lazy`, but pushes results directly into the given buffer. Does + /// not support scoring. /// - /// This is currently used for lexicographic sorting. - fn accept_sort_key_lazy( + /// The buffer must have at least enough capacity for `docs` matches, or this method will + /// panic. 
+ fn accept_sort_key_block_lazy( &mut self, - doc_id: DocId, - score: Score, + docs: &[DocId], threshold: &Self::SegmentSortKey, - ) -> Option<(Ordering, Self::SegmentSortKey)> { - let sort_key = self.segment_sort_key(doc_id, score); - let cmp = self.compare_segment_sort_key(&sort_key, threshold); - if cmp == Ordering::Less { - None - } else { - Some((cmp, sort_key)) + output: &mut Vec>, + ) { + let comparator = self.segment_comparator(); + for &doc in docs { + let sort_key = self.segment_sort_key(doc, 0.0); + let cmp = comparator.compare(&sort_key, threshold); + if cmp != Ordering::Less { + push_assuming_capacity(ComparableDoc { sort_key, doc }, output); + } } } @@ -145,7 +197,8 @@ where TailSortKeyComputer: SortKeyComputer, { type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey); - type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); + type Child = + ChainSegmentSortKeyComputer; type Comparator = ( HeadSortKeyComputer::Comparator, @@ -157,10 +210,12 @@ where } fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - Ok(( - self.0.segment_sort_key_computer(segment_reader)?, - self.1.segment_sort_key_computer(segment_reader)?, - )) + Ok(ChainSegmentSortKeyComputer { + head: self.0.segment_sort_key_computer(segment_reader)?, + tail: self.1.segment_sort_key_computer(segment_reader)?, + head_key_buffer: Vec::new(), + doc_buffer: Vec::new(), + }) } /// Checks whether the schema is compatible with the sort key computer. 
@@ -178,25 +233,68 @@ where } } -impl SegmentSortKeyComputer - for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer) +pub struct ChainSegmentSortKeyComputer where - HeadSegmentSortKeyComputer: SegmentSortKeyComputer, - TailSegmentSortKeyComputer: SegmentSortKeyComputer, + Head: SegmentSortKeyComputer, + Tail: SegmentSortKeyComputer, { - type SortKey = ( - HeadSegmentSortKeyComputer::SortKey, - TailSegmentSortKeyComputer::SortKey, - ); - type SegmentSortKey = ( - HeadSegmentSortKeyComputer::SegmentSortKey, - TailSegmentSortKeyComputer::SegmentSortKey, - ); + head: Head, + tail: Tail, + head_key_buffer: Vec, + doc_buffer: Vec, +} - type SegmentComparator = ( - HeadSegmentSortKeyComputer::SegmentComparator, - TailSegmentSortKeyComputer::SegmentComparator, - ); +impl ChainSegmentSortKeyComputer +where + Head: SegmentSortKeyComputer, + Tail: SegmentSortKeyComputer, +{ + fn accept_sort_key_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &::SegmentSortKey, + ) -> Option<(Ordering, ::SegmentSortKey)> { + let (head_threshold, tail_threshold) = threshold; + let head_sort_key = self.head.segment_sort_key(doc_id, score); + let head_cmp = self + .head + .compare_segment_sort_key(&head_sort_key, head_threshold); + if head_cmp == Ordering::Less { + None + } else if head_cmp == Ordering::Equal { + let tail_sort_key = self.tail.segment_sort_key(doc_id, score); + let tail_cmp = self + .tail + .compare_segment_sort_key(&tail_sort_key, tail_threshold); + if tail_cmp == Ordering::Less { + None + } else { + Some((tail_cmp, (head_sort_key, tail_sort_key))) + } + } else { + let tail_sort_key = self.tail.segment_sort_key(doc_id, score); + Some((head_cmp, (head_sort_key, tail_sort_key))) + } + } +} + +impl SegmentSortKeyComputer for ChainSegmentSortKeyComputer +where + Head: SegmentSortKeyComputer, + Tail: SegmentSortKeyComputer, +{ + type SortKey = (Head::SortKey, Tail::SortKey); + type SegmentSortKey = (Head::SegmentSortKey, Tail::SegmentSortKey); + + type 
SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator); + + fn segment_comparator(&self) -> Self::SegmentComparator { + ( + self.head.segment_comparator(), + self.tail.segment_comparator(), + ) + } /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on /// its ordering. @@ -208,9 +306,17 @@ where left: &Self::SegmentSortKey, right: &Self::SegmentSortKey, ) -> Ordering { - self.0 + self.head .compare_segment_sort_key(&left.0, &right.0) - .then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1)) + .then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1)) + } + + fn segment_sort_keys( + &mut self, + _docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + unimplemented!("The head and the tail are accessed independently."); } #[inline(always)] @@ -233,50 +339,89 @@ where top_n_computer.append_doc(doc, sort_key); } - #[inline(always)] - fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { - let head_sort_key = self.0.segment_sort_key(doc, score); - let tail_sort_key = self.1.segment_sort_key(doc, score); - (head_sort_key, tail_sort_key) + fn compute_sort_keys_and_collect>( + &mut self, + docs: &[DocId], + top_n_computer: &mut TopNComputer, + ) { + // The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we + // should always be able to `reserve` space for the entire block. + top_n_computer.reserve(docs.len()); + + if let Some(threshold) = &top_n_computer.threshold { + let (head_threshold, tail_threshold) = threshold.clone(); + let head_cmp = self.head.segment_comparator(); + let tail_cmp = self.tail.segment_comparator(); + let head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone()); + + let head_keys = self.head.segment_sort_keys(docs, head_filter); + self.doc_buffer.clear(); + self.head_key_buffer.clear(); + for (doc, head_key) in head_keys.drain(..) 
{ + let cmp = head_cmp.compare(&head_key, &head_threshold); + if cmp != Ordering::Less { + self.doc_buffer.push(doc); + self.head_key_buffer.push(head_key); + } + } + + if !self.doc_buffer.is_empty() { + let tail_keys = self + .tail + .segment_sort_keys(&self.doc_buffer, ValueRange::All); + for ((head_key, tail_key), &doc) in self + .head_key_buffer + .drain(..) + .zip(tail_keys.drain(..).map(|(_, k)| k)) + .zip(self.doc_buffer.iter()) + { + let head_ord = head_cmp.compare(&head_key, &head_threshold); + let ord = if head_ord == Ordering::Equal { + tail_cmp.compare(&tail_key, &tail_threshold) + } else { + head_ord + }; + if ord == Ordering::Greater { + top_n_computer.append_doc_unchecked(doc, (head_key, tail_key)); + } + } + } + } else { + // Eagerly push, without a threshold to compare to. + let head_keys = self.head.segment_sort_keys(docs, ValueRange::All); + let tail_keys = self.tail.segment_sort_keys(docs, ValueRange::All); + for ((doc, head_key), (_, tail_key)) in head_keys.drain(..).zip(tail_keys.drain(..)) { + // We validated at the top of the method that we have capacity. 
+ top_n_computer.append_doc_unchecked(doc, (head_key, tail_key)); + } + } } - fn accept_sort_key_lazy( - &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentSortKey, - ) -> Option<(Ordering, Self::SegmentSortKey)> { - let (head_threshold, tail_threshold) = threshold; - let (head_cmp, head_sort_key) = - self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?; - if head_cmp == Ordering::Equal { - let (tail_cmp, tail_sort_key) = - self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?; - Some((tail_cmp, (head_sort_key, tail_sort_key))) - } else { - let tail_sort_key = self.1.segment_sort_key(doc_id, score); - Some((head_cmp, (head_sort_key, tail_sort_key))) - } + #[inline(always)] + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { + let head_sort_key = self.head.segment_sort_key(doc, score); + let tail_sort_key = self.tail.segment_sort_key(doc, score); + (head_sort_key, tail_sort_key) } fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { let (head_sort_key, tail_sort_key) = sort_key; ( - self.0.convert_segment_sort_key(head_sort_key), - self.1.convert_segment_sort_key(tail_sort_key), + self.head.convert_segment_sort_key(head_sort_key), + self.tail.convert_segment_sort_key(tail_sort_key), ) } } /// This struct is used as an adapter to take a sort key computer and map its score to another /// new sort key. 
-pub struct MappedSegmentSortKeyComputer { +pub struct MappedSegmentSortKeyComputer { sort_key_computer: T, - map: fn(PreviousSortKey) -> NewSortKey, + map: fn(T::SortKey) -> NewSortKey, } impl SegmentSortKeyComputer - for MappedSegmentSortKeyComputer + for MappedSegmentSortKeyComputer where T: SegmentSortKeyComputer, PreviousScore: 'static + Clone + Send + Sync, @@ -286,18 +431,21 @@ where type SegmentSortKey = T::SegmentSortKey; type SegmentComparator = T::SegmentComparator; + fn segment_comparator(&self) -> Self::SegmentComparator { + self.sort_key_computer.segment_comparator() + } + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.sort_key_computer.segment_sort_key(doc, score) } - fn accept_sort_key_lazy( + fn segment_sort_keys( &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentSortKey, - ) -> Option<(Ordering, Self::SegmentSortKey)> { + docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { self.sort_key_computer - .accept_sort_key_lazy(doc_id, score, threshold) + .segment_sort_keys(docs, ValueRange::All) } #[inline(always)] @@ -311,6 +459,15 @@ where .compute_sort_key_and_collect(doc, score, top_n_computer); } + fn compute_sort_keys_and_collect>( + &mut self, + docs: &[DocId], + top_n_computer: &mut TopNComputer, + ) { + self.sort_key_computer + .compute_sort_keys_and_collect(docs, top_n_computer); + } + fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey { (self.map)( self.sort_key_computer @@ -336,10 +493,6 @@ where ); type Child = MappedSegmentSortKeyComputer< <(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child, - ( - SortKeyComputer1::SortKey, - (SortKeyComputer2::SortKey, SortKeyComputer3::SortKey), - ), Self::SortKey, >; @@ -363,7 +516,17 @@ where let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, 
sort_key2, sort_key3); Ok(MappedSegmentSortKeyComputer { - sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)), + sort_key_computer: ChainSegmentSortKeyComputer { + head: sort_key_computer1, + tail: ChainSegmentSortKeyComputer { + head: sort_key_computer2, + tail: sort_key_computer3, + head_key_buffer: Vec::new(), + doc_buffer: Vec::new(), + }, + head_key_buffer: Vec::new(), + doc_buffer: Vec::new(), + }, map, }) } @@ -398,13 +561,6 @@ where SortKeyComputer1, (SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)), ) as SortKeyComputer>::Child, - ( - SortKeyComputer1::SortKey, - ( - SortKeyComputer2::SortKey, - (SortKeyComputer3::SortKey, SortKeyComputer4::SortKey), - ), - ), Self::SortKey, >; type SortKey = ( @@ -426,10 +582,22 @@ where let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?; Ok(MappedSegmentSortKeyComputer { - sort_key_computer: ( - sort_key_computer1, - (sort_key_computer2, (sort_key_computer3, sort_key_computer4)), - ), + sort_key_computer: ChainSegmentSortKeyComputer { + head: sort_key_computer1, + tail: ChainSegmentSortKeyComputer { + head: sort_key_computer2, + tail: ChainSegmentSortKeyComputer { + head: sort_key_computer3, + tail: sort_key_computer4, + head_key_buffer: Vec::new(), + doc_buffer: Vec::new(), + }, + head_key_buffer: Vec::new(), + doc_buffer: Vec::new(), + }, + head_key_buffer: Vec::new(), + doc_buffer: Vec::new(), + }, map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| { (sort_key1, sort_key2, sort_key3, sort_key4) }, @@ -452,6 +620,11 @@ where } } +pub struct FuncSegmentSortKeyComputer { + func: F, + buffer: Vec<(DocId, TSortKey)>, +} + impl SortKeyComputer for F where F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF, @@ -459,15 +632,18 @@ where TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug, { type SortKey = TSortKey; - type Child = SegmentF; + type Child = 
FuncSegmentSortKeyComputer; type Comparator = NaturalComparator; fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - Ok((self)(segment_reader)) + Ok(FuncSegmentSortKeyComputer { + func: (self)(segment_reader), + buffer: Vec::new(), + }) } } -impl SegmentSortKeyComputer for F +impl SegmentSortKeyComputer for FuncSegmentSortKeyComputer where F: 'static + FnMut(DocId) -> TSortKey, TSortKey: 'static + PartialOrd + Clone + Send + Sync, @@ -477,7 +653,20 @@ where type SegmentComparator = NaturalComparator; fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { - (self)(doc) + (self.func)(doc) + } + + fn segment_sort_keys( + &mut self, + docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + self.buffer.clear(); + self.buffer.reserve(docs.len()); + for &doc in docs { + self.buffer.push((doc, (self.func)(doc))); + } + &mut self.buffer } /// Convert a segment level score into the global level score. @@ -486,6 +675,34 @@ where } } +pub(crate) fn range_contains_none(range: &ValueRange>) -> bool { + match range { + ValueRange::All => true, + ValueRange::Inclusive(r) => r.contains(&None), + ValueRange::GreaterThan(threshold, match_nulls) => *match_nulls || (None > *threshold), + ValueRange::LessThan(threshold, match_nulls) => *match_nulls || (None < *threshold), + } +} + +pub(crate) fn convert_optional_u64_range_to_u64_range( + range: ValueRange>, +) -> ValueRange { + if range_contains_none(&range) { + return ValueRange::All; + } + match range { + ValueRange::Inclusive(r) => { + let start = r.start().unwrap_or(0); + let end = r.end().unwrap_or(u64::MAX); + ValueRange::Inclusive(start..=end) + } + ValueRange::GreaterThan(Some(val), _match_nulls) => ValueRange::GreaterThan(val, false), + ValueRange::GreaterThan(None, _match_nulls) => ValueRange::Inclusive(u64::MIN..=u64::MAX), + ValueRange::LessThan(None, _match_nulls) => ValueRange::Inclusive(1..=0), + _ => ValueRange::All, + } +} + #[cfg(test)] 
mod tests { use std::cmp::Ordering; diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs index 3ca27fc75..457854620 100644 --- a/src/collector/sort_key_top_collector.rs +++ b/src/collector/sort_key_top_collector.rs @@ -120,6 +120,11 @@ where ); } + fn collect_block(&mut self, docs: &[DocId]) { + self.segment_sort_key_computer + .compute_sort_keys_and_collect(docs, &mut self.topn_computer); + } + fn harvest(self) -> Self::Fruit { let segment_ord = self.segment_ord; let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 0ce1c611a..5ce928c47 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use std::fmt; use std::ops::Range; +use columnar::ValueRange; use serde::{Deserialize, Serialize}; use super::Collector; @@ -486,6 +487,14 @@ where (self.sort_key_fn)(doc, score) } + fn segment_sort_keys( + &mut self, + _docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + unimplemented!("Batch computation is not supported for tweak score.") + } + /// Convert a segment level score into the global level score. fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { sort_key @@ -604,9 +613,12 @@ where C: Comparator, { /// Create a new `TopNComputer`. - /// Internally it will allocate a buffer of size `2 * top_n`. + /// Internally it will allocate a buffer of size `(top_n.max(1) * 2) + + /// COLLECT_BLOCK_BUFFER_LEN`. pub fn new_with_comparator(top_n: usize, comparator: C) -> Self { - let vec_cap = top_n.max(1) * 2; + // We ensure that there is always enough space to include an entire block in the buffer if + // need be, so that `push_block_lazy` can avoid checking capacity inside its loop. 
+ let vec_cap = (top_n.max(1) * 2) + crate::COLLECT_BLOCK_BUFFER_LEN; TopNComputer { buffer: Vec::with_capacity(vec_cap), top_n, @@ -635,16 +647,31 @@ where // At this point, we need to have established that the doc is above the threshold. #[inline(always)] pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) { - if self.buffer.len() == self.buffer.capacity() { - let median = self.truncate_top_n(); - self.threshold = Some(median); - } - // This cannot panic, because we truncate_median will at least remove one element, since - // the min capacity is 2. + self.reserve(1); + // This cannot panic, because we've reserved room for one element. + self.append_doc_unchecked(doc, sort_key); + } + + // Append a document to the top n. `reserve` must already have been called to ensure that there + // is capacity, or this method will panic. + // + // At this point, we need to have established that the doc is above the threshold. + #[inline(always)] + pub(crate) fn append_doc_unchecked(&mut self, doc: D, sort_key: TSortKey) { let comparable_doc = ComparableDoc { doc, sort_key }; push_assuming_capacity(comparable_doc, &mut self.buffer); } + // Ensure that there is capacity to push `additional` more elements without resizing. + #[inline(always)] + pub(crate) fn reserve(&mut self, additional: usize) { + if self.buffer.len() + additional > self.buffer.capacity() { + let median = self.truncate_top_n(); + debug_assert!(self.buffer.len() + additional <= self.buffer.capacity()); + self.threshold = Some(median); + } + } + #[inline(never)] fn truncate_top_n(&mut self) -> TSortKey { // Use select_nth_unstable to find the top nth score @@ -684,7 +711,7 @@ where // // Panics if there is not enough capacity to add an element. #[inline(always)] -fn push_assuming_capacity(el: T, buf: &mut Vec) { +pub fn push_assuming_capacity(el: T, buf: &mut Vec) { let prev_len = buf.len(); assert!(prev_len < buf.capacity()); // This is mimicking the current (non-stabilized) implementation in std. 
@@ -1408,11 +1435,11 @@ mod tests { #[test] fn test_top_field_collect_string_prop( order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)), - limit in 1..256_usize, - offset in 0..256_usize, + limit in 1..32_usize, + offset in 0..32_usize, segments_terms in proptest::collection::vec( - proptest::collection::vec(0..32_u8, 1..32_usize), + proptest::collection::vec(0..64_u8, 1..256_usize), 0..8_usize, ) ) { diff --git a/src/fastfield/mod.rs b/src/fastfield/mod.rs index 726b9b76a..ecfe61bd0 100644 --- a/src/fastfield/mod.rs +++ b/src/fastfield/mod.rs @@ -79,7 +79,7 @@ mod tests { use std::ops::{Range, RangeInclusive}; use std::path::Path; - use columnar::StrColumn; + use columnar::{StrColumn, ValueRange}; use common::{ByteCount, DateTimePrecision, HasLen, TerminatingWrite}; use once_cell::sync::Lazy; use rand::prelude::SliceRandom; @@ -944,7 +944,7 @@ mod tests { let test_range = |range: RangeInclusive| { let expected_count = numbers.iter().filter(|num| range.contains(*num)).count(); let mut vec = vec![]; - field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec); + field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec); assert_eq!(vec.len(), expected_count); }; test_range(50..=50); @@ -1022,7 +1022,7 @@ mod tests { let test_range = |range: RangeInclusive| { let expected_count = numbers.iter().filter(|num| range.contains(*num)).count(); let mut vec = vec![]; - field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec); + field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec); assert_eq!(vec.len(), expected_count); }; let test_range_variant = |start, stop| { diff --git a/src/query/range_query/fast_field_range_doc_set.rs b/src/query/range_query/fast_field_range_doc_set.rs index 24d2b1fe3..3a606f221 100644 --- a/src/query/range_query/fast_field_range_doc_set.rs +++ b/src/query/range_query/fast_field_range_doc_set.rs @@ -1,7 +1,6 @@ use core::fmt::Debug; -use std::ops::RangeInclusive; -use 
columnar::Column; +use columnar::{Column, ValueRange}; use crate::{DocId, DocSet, TERMINATED}; @@ -41,7 +40,7 @@ impl VecCursor { pub(crate) struct RangeDocSet { /// The range filter on the values. - value_range: RangeInclusive, + value_range: ValueRange, column: Column, /// The next docid start range to fetch (inclusive). next_fetch_start: u32, @@ -61,8 +60,8 @@ pub(crate) struct RangeDocSet { const DEFAULT_FETCH_HORIZON: u32 = 128; impl RangeDocSet { - pub(crate) fn new(value_range: RangeInclusive, column: Column) -> Self { - if *value_range.start() > column.max_value() || *value_range.end() < column.min_value() { + pub(crate) fn new(value_range: ValueRange, column: Column) -> Self { + if !value_range.intersects(column.min_value(), column.max_value()) { return Self { value_range, column, diff --git a/src/query/range_query/range_query_fastfield.rs b/src/query/range_query/range_query_fastfield.rs index e379e108e..f4baa04cc 100644 --- a/src/query/range_query/range_query_fastfield.rs +++ b/src/query/range_query/range_query_fastfield.rs @@ -7,7 +7,7 @@ use std::ops::{Bound, RangeInclusive}; use columnar::{ Cardinality, Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64, - NumericalType, StrColumn, + NumericalType, StrColumn, ValueRange, }; use common::bounds::{BoundsRange, TransformBound}; @@ -154,7 +154,7 @@ impl Weight for FastFieldRangeWeight { ip_addr_column.min_value(), ip_addr_column.max_value(), ); - let docset = RangeDocSet::new(value_range, ip_addr_column); + let docset = RangeDocSet::new(ValueRange::Inclusive(value_range), ip_addr_column); Ok(Box::new(ConstScorer::new(docset, boost))) } else if field_type.is_str() { let Some(str_dict_column): Option = reader.fast_fields().str(&field_name)? @@ -426,7 +426,7 @@ fn search_on_u64_ff( } } - let docset = RangeDocSet::new(value_range, column); + let docset = RangeDocSet::new(ValueRange::Inclusive(value_range), column); Ok(Box::new(ConstScorer::new(docset, boost))) }