Implement collect_block for lazy scorers using SegmentSortKeyComputer::segment_sort_keys.

This commit is contained in:
Stu Hood
2025-11-01 15:46:46 -07:00
parent 77505c3d03
commit 9615eb73b8
21 changed files with 1552 additions and 200 deletions

View File

@@ -1,6 +1,6 @@
use binggan::{InputGroup, black_box};
use common::*;
use tantivy_columnar::Column;
use tantivy_columnar::{Column, ValueRange};
pub mod common;
@@ -46,16 +46,16 @@ fn bench_group(mut runner: InputGroup<Column>) {
runner.register("access_first_vals", |column| {
let mut sum = 0;
const BLOCK_SIZE: usize = 32;
let mut docs = vec![0; BLOCK_SIZE];
let mut buffer = vec![None; BLOCK_SIZE];
let mut docs = Vec::with_capacity(BLOCK_SIZE);
let mut buffer = Vec::with_capacity(BLOCK_SIZE);
for i in (0..NUM_DOCS).step_by(BLOCK_SIZE) {
// fill docs
#[allow(clippy::needless_range_loop)]
docs.clear();
for idx in 0..BLOCK_SIZE {
docs[idx] = idx as u32 + i;
docs.push(idx as u32 + i);
}
column.first_vals(&docs, &mut buffer);
buffer.clear();
column.first_vals_in_value_range(&mut docs, &mut buffer, ValueRange::All);
for val in buffer.iter() {
let Some(val) = val else { continue };
sum += *val;

View File

@@ -89,31 +89,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
self.values_for_doc(row_id).next()
}
/// Load the first value for each docid in the provided slice.
///
/// `output[i]` receives the first value of `docids[i]`, or `None` when that document has no
/// value. Every slot matching a docid is written, so a buffer that is reused across calls
/// never exposes values left over from a previous block.
///
/// # Panics
///
/// For optional and multivalued columns, panics if `output` is shorter than `docids`.
#[inline]
pub fn first_vals(&self, docids: &[DocId], output: &mut [Option<T>]) {
    match &self.index {
        ColumnIndex::Empty { .. } => {
            // No document has a value in an empty column.
            output
                .iter_mut()
                .take(docids.len())
                .for_each(|slot| *slot = None);
        }
        ColumnIndex::Full => self.values.get_vals_opt(docids, output),
        ColumnIndex::Optional(optional_index) => {
            for (i, docid) in docids.iter().enumerate() {
                output[i] = optional_index
                    .rank_if_exists(*docid)
                    .map(|rowid| self.values.get_val(rowid));
            }
        }
        ColumnIndex::Multivalued(multivalued_index) => {
            for (i, docid) in docids.iter().enumerate() {
                let range = multivalued_index.range(*docid);
                // Write `None` explicitly for documents without values, like the
                // `Optional` branch does, instead of leaving the slot untouched.
                output[i] = if range.start == range.end {
                    None
                } else {
                    Some(self.values.get_val(range.start))
                };
            }
        }
    }
}
/// Translates a block of docids to row_ids.
///
/// returns the row_ids and the matching docids on the same index
@@ -143,7 +118,7 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
#[inline]
pub fn get_docids_for_value_range(
&self,
value_range: RangeInclusive<T>,
value_range: ValueRange<T>,
selected_docid_range: Range<u32>,
doc_ids: &mut Vec<u32>,
) {
@@ -168,6 +143,182 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
}
}
// Separate impl block for methods requiring `Default` for `T`.
impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
    /// Load the first value for each docid, keeping only the docids selected by `value_range`.
    ///
    /// On return, `docids` holds the docids that were kept (in their original order) and
    /// `values` holds exactly one entry per kept docid, at the same position:
    /// `Some(first_value)`, or `None` for a document without a value when `value_range`
    /// selects nulls.
    ///
    /// `values` is cleared on entry: `docids` is compacted in place starting at index 0, so
    /// the two buffers can only stay aligned if `values` starts empty.
    ///
    /// # Panics
    ///
    /// For an optional column, panics if `docids.len()` exceeds 64 (the block size used for
    /// the on-stack scratch buffers).
    #[inline]
    pub fn first_vals_in_value_range(
        &self,
        docids: &mut Vec<DocId>,
        values: &mut Vec<Option<T>>,
        value_range: ValueRange<T>,
    ) {
        // Corresponds to COLLECT_BLOCK_BUFFER_LEN in tantivy's docset
        const BLOCK_LEN: usize = 64;

        values.clear();
        let nulls_match = value_range_selects_nulls(&value_range);

        match &self.index {
            ColumnIndex::Empty { .. } => {
                // Every document is a null: either all of them are kept, or none.
                if nulls_match {
                    values.extend(docids.iter().map(|_| None));
                } else {
                    docids.clear();
                }
            }
            ColumnIndex::Full => {
                self.values
                    .get_vals_in_value_range(docids, values, value_range);
            }
            ColumnIndex::Optional(optional_index) => {
                let len = docids.len();
                // The scratch buffers below live on the stack and the presence mask is a
                // single u64, hence the hard limit on the batch size.
                assert!(
                    len <= BLOCK_LEN,
                    "Input docids length ({}) exceeds BLOCK_LEN ({})",
                    len,
                    BLOCK_LEN
                );

                let mut input_docs_buffer = [0u32; BLOCK_LEN];
                input_docs_buffer[..len].copy_from_slice(docids);
                let mut dense_row_ids_buffer = [0u32; BLOCK_LEN];
                let mut dense_values_buffer = [T::default(); BLOCK_LEN];
                // Bit `i` is set when `input_docs_buffer[i]` has a value.
                let mut presence_mask: u64 = 0;
                let mut num_present = 0;

                // Phase 1: resolve the row id of every docid that has a value.
                for (i, &doc_id) in input_docs_buffer[..len].iter().enumerate() {
                    if let Some(row_id) = optional_index.rank_if_exists(doc_id) {
                        dense_row_ids_buffer[num_present] = row_id;
                        presence_mask |= 1u64 << i;
                        num_present += 1;
                    }
                }

                // Phase 2: fetch the values of the present docs in a single batch.
                if num_present > 0 {
                    self.values.get_vals(
                        &dense_row_ids_buffer[..num_present],
                        &mut dense_values_buffer[..num_present],
                    );
                }

                // Phase 3: rebuild `docids` / `values` with the entries that are selected.
                docids.clear();
                let mut dense_values_cursor = 0;
                for (i, &original_doc_id) in input_docs_buffer[..len].iter().enumerate() {
                    if presence_mask & (1u64 << i) != 0 {
                        let val = dense_values_buffer[dense_values_cursor];
                        dense_values_cursor += 1;
                        if value_range_selects_value(&value_range, &val) {
                            docids.push(original_doc_id);
                            values.push(Some(val));
                        }
                    } else if nulls_match {
                        docids.push(original_doc_id);
                        values.push(None);
                    }
                }
            }
            ColumnIndex::Multivalued(multivalued_index) => {
                // In-place compaction: `write_head` never overtakes `read_head`.
                let mut write_head = 0;
                for read_head in 0..docids.len() {
                    let docid = docids[read_head];
                    let row_range = multivalued_index.range(docid);
                    let first_val = if row_range.start == row_range.end {
                        None
                    } else {
                        Some(self.values.get_val(row_range.start))
                    };
                    let keep = match &first_val {
                        Some(val) => value_range_selects_value(&value_range, val),
                        None => nulls_match,
                    };
                    if keep {
                        docids[write_head] = docid;
                        values.push(first_val);
                        write_head += 1;
                    }
                }
                docids.truncate(write_head);
            }
        }
    }
}

/// Returns true if documents without a value (nulls) are selected by `value_range`.
#[inline]
fn value_range_selects_nulls<T>(value_range: &ValueRange<T>) -> bool {
    match value_range {
        ValueRange::All => true,
        ValueRange::Inclusive(_) => false,
        ValueRange::GreaterThan(_, nulls_match) | ValueRange::LessThan(_, nulls_match) => {
            *nulls_match
        }
    }
}

/// Returns true if the (non-null) value `val` is selected by `value_range`.
#[inline]
fn value_range_selects_value<T: PartialOrd>(value_range: &ValueRange<T>, val: &T) -> bool {
    match value_range {
        ValueRange::All => true,
        ValueRange::Inclusive(range) => range.contains(val),
        ValueRange::GreaterThan(threshold, _) => val > threshold,
        ValueRange::LessThan(threshold, _) => val < threshold,
    }
}
/// A range of values.
///
/// This type is intended to be used in batch APIs, where the cost of unpacking the enum
/// is outweighed by the time spent processing a batch.
///
/// Implementers should pattern match on the variants to use optimized loops for each case.
///
/// Null handling, as implemented by `Column::first_vals_in_value_range`: `All` selects
/// documents without a value, `Inclusive` never does, and `GreaterThan` / `LessThan` follow
/// their boolean flag.
#[derive(Clone, Debug)]
pub enum ValueRange<T> {
    /// A range that includes both start and end.
    Inclusive(RangeInclusive<T>),
    /// A range that matches all values.
    All,
    /// A range that matches all values strictly greater than the threshold.
    /// The boolean flag indicates if null values should be included.
    GreaterThan(T, bool),
    /// A range that matches all values strictly less than the threshold.
    /// The boolean flag indicates if null values should be included.
    LessThan(T, bool),
}
impl<T: PartialOrd> ValueRange<T> {
    /// Returns true if this range may select a value lying in the closed interval
    /// `[min, max]`.
    ///
    /// Only values are considered: the null flag of `GreaterThan` / `LessThan` is ignored.
    /// An empty inclusive range (start > end) selects no value, so it never intersects.
    pub fn intersects(&self, min: T, max: T) -> bool {
        match self {
            ValueRange::Inclusive(range) => {
                !range.is_empty() && *range.start() <= max && *range.end() >= min
            }
            ValueRange::All => true,
            // Thresholds are exclusive: the bound of the interval itself has to pass.
            ValueRange::GreaterThan(val, _) => max > *val,
            ValueRange::LessThan(val, _) => min < *val,
        }
    }
}
impl BinarySerializable for Cardinality {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> std::io::Result<()> {
self.to_code().serialize(writer)

View File

@@ -333,7 +333,7 @@ mod tests {
use std::ops::Range;
use super::MultiValueIndex;
use crate::{ColumnarReader, DynamicColumn};
use crate::{ColumnarReader, DynamicColumn, ValueRange};
fn index_to_pos_helper(
index: &MultiValueIndex,
@@ -413,7 +413,7 @@ mod tests {
assert_eq!(row_id_range, 0..4);
let check = |range, expected| {
let full_range = 0..=u64::MAX;
let full_range = ValueRange::All;
let mut docids = Vec::new();
column.get_docids_for_value_range(full_range, range, &mut docids);
assert_eq!(docids, expected);

View File

@@ -7,13 +7,15 @@
//! - Monotonically map values to u64/u128
use std::fmt::Debug;
use std::ops::{Range, RangeInclusive};
use std::ops::Range;
use std::sync::Arc;
use downcast_rs::DowncastSync;
pub use monotonic_mapping::{MonotonicallyMappableToU64, StrictlyMonotonicFn};
pub use monotonic_mapping_u128::MonotonicallyMappableToU128;
use crate::column::ValueRange;
mod merge;
pub(crate) mod monotonic_mapping;
pub(crate) mod monotonic_mapping_u128;
@@ -109,6 +111,178 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
}
}
/// Load the values for the provided docids.
///
/// The values are filtered by the provided value range.
fn get_vals_in_value_range(
&self,
indexes: &mut Vec<u32>,
output: &mut Vec<Option<T>>,
value_range: ValueRange<T>,
) {
let mut write_head = 0;
let mut read_head = 0;
let len = indexes.len();
match value_range {
ValueRange::All => {
while read_head + 3 < len {
let idx0 = indexes[read_head];
let idx1 = indexes[read_head + 1];
let idx2 = indexes[read_head + 2];
let idx3 = indexes[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
indexes[write_head] = idx0;
output.push(Some(val0));
write_head += 1;
indexes[write_head] = idx1;
output.push(Some(val1));
write_head += 1;
indexes[write_head] = idx2;
output.push(Some(val2));
write_head += 1;
indexes[write_head] = idx3;
output.push(Some(val3));
write_head += 1;
read_head += 4;
}
}
ValueRange::Inclusive(ref range) => {
while read_head + 3 < len {
let idx0 = indexes[read_head];
let idx1 = indexes[read_head + 1];
let idx2 = indexes[read_head + 2];
let idx3 = indexes[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if range.contains(&val0) {
indexes[write_head] = idx0;
output.push(Some(val0));
write_head += 1;
}
if range.contains(&val1) {
indexes[write_head] = idx1;
output.push(Some(val1));
write_head += 1;
}
if range.contains(&val2) {
indexes[write_head] = idx2;
output.push(Some(val2));
write_head += 1;
}
if range.contains(&val3) {
indexes[write_head] = idx3;
output.push(Some(val3));
write_head += 1;
}
read_head += 4;
}
}
ValueRange::GreaterThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = indexes[read_head];
let idx1 = indexes[read_head + 1];
let idx2 = indexes[read_head + 2];
let idx3 = indexes[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 > *threshold {
indexes[write_head] = idx0;
output.push(Some(val0));
write_head += 1;
}
if val1 > *threshold {
indexes[write_head] = idx1;
output.push(Some(val1));
write_head += 1;
}
if val2 > *threshold {
indexes[write_head] = idx2;
output.push(Some(val2));
write_head += 1;
}
if val3 > *threshold {
indexes[write_head] = idx3;
output.push(Some(val3));
write_head += 1;
}
read_head += 4;
}
}
ValueRange::LessThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = indexes[read_head];
let idx1 = indexes[read_head + 1];
let idx2 = indexes[read_head + 2];
let idx3 = indexes[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
if val0 < *threshold {
indexes[write_head] = idx0;
output.push(Some(val0));
write_head += 1;
}
if val1 < *threshold {
indexes[write_head] = idx1;
output.push(Some(val1));
write_head += 1;
}
if val2 < *threshold {
indexes[write_head] = idx2;
output.push(Some(val2));
write_head += 1;
}
if val3 < *threshold {
indexes[write_head] = idx3;
output.push(Some(val3));
write_head += 1;
}
read_head += 4;
}
}
}
// Process remaining elements (0 to 3)
while read_head < len {
let idx = indexes[read_head];
let val = self.get_val(idx);
let matches = match value_range {
// 'value_range' is still moved here. This is the outer `value_range`
ValueRange::All => true,
ValueRange::Inclusive(ref r) => r.contains(&val),
ValueRange::GreaterThan(ref t, _) => val > *t,
ValueRange::LessThan(ref t, _) => val < *t,
};
if matches {
indexes[write_head] = idx;
output.push(Some(val));
write_head += 1;
}
read_head += 1;
}
indexes.truncate(write_head);
}
/// Fills an output buffer with the fast field values
/// associated with the `DocId` going from
/// `start` to `start + output.len()`.
@@ -129,15 +303,38 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
/// Note that position == docid for single value fast fields
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<T>,
value_range: ValueRange<T>,
row_id_range: Range<RowId>,
row_id_hits: &mut Vec<RowId>,
) {
let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals());
for idx in row_id_range {
let val = self.get_val(idx);
if value_range.contains(&val) {
row_id_hits.push(idx);
match value_range {
ValueRange::Inclusive(range) => {
for idx in row_id_range {
let val = self.get_val(idx);
if range.contains(&val) {
row_id_hits.push(idx);
}
}
}
ValueRange::GreaterThan(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val > threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::LessThan(threshold, _) => {
for idx in row_id_range {
let val = self.get_val(idx);
if val < threshold {
row_id_hits.push(idx);
}
}
}
ValueRange::All => {
row_id_hits.extend(row_id_range);
}
}
}
@@ -193,6 +390,16 @@ impl<T: PartialOrd + Default> ColumnValues<T> for EmptyColumnValues {
fn num_vals(&self) -> u32 {
0
}
/// An empty column holds no value, so asking it for values is an internal error.
fn get_vals_in_value_range(
    &self,
    indexes: &mut Vec<u32>,
    output: &mut Vec<Option<T>>,
    value_range: ValueRange<T>,
) {
    // Consume the arguments so that they do not trigger unused-variable warnings.
    drop((indexes, output, value_range));
    panic!("Internal Error: Called get_vals_in_value_range of empty column.")
}
}
impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnValues<T>> {
@@ -206,6 +413,17 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
self.as_ref().get_vals_opt(indexes, output)
}
#[inline(always)]
fn get_vals_in_value_range(
&self,
indexes: &mut Vec<u32>,
output: &mut Vec<Option<T>>,
value_range: ValueRange<T>,
) {
self.as_ref()
.get_vals_in_value_range(indexes, output, value_range)
}
#[inline(always)]
fn min_value(&self) -> T {
self.as_ref().min_value()
@@ -234,7 +452,7 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
#[inline(always)]
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<T>,
range: ValueRange<T>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {

View File

@@ -1,8 +1,9 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Range, RangeInclusive};
use std::ops::Range;
use crate::ColumnValues;
use crate::column::ValueRange;
use crate::column_values::monotonic_mapping::StrictlyMonotonicFn;
struct MonotonicMappingColumn<C, T, Input> {
@@ -80,16 +81,35 @@ where
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<Output>,
range: ValueRange<Output>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
self.from_column.get_row_ids_for_value_range(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
doc_id_range,
positions,
)
match range {
ValueRange::Inclusive(range) => self.from_column.get_row_ids_for_value_range(
ValueRange::Inclusive(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
),
doc_id_range,
positions,
),
ValueRange::All => self.from_column.get_row_ids_for_value_range(
ValueRange::All,
doc_id_range,
positions,
),
ValueRange::GreaterThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
ValueRange::GreaterThan(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
),
ValueRange::LessThan(threshold, _) => self.from_column.get_row_ids_for_value_range(
ValueRange::LessThan(self.monotonic_mapping.inverse(threshold), false),
doc_id_range,
positions,
),
}
}
// We voluntarily do not implement get_range as it yields a regression,

View File

@@ -25,6 +25,7 @@ use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128};
use tantivy_bitpacker::{BitPacker, BitUnpacker};
use crate::RowId;
use crate::column::ValueRange;
use crate::column_values::ColumnValues;
/// The cost per blank is quite hard actually, since blanks are delta encoded, the actual cost of
@@ -338,14 +339,36 @@ impl ColumnValues<u64> for CompactSpaceU64Accessor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<u64>,
value_range: ValueRange<u64>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let value_range = self.0.compact_to_u128(*value_range.start() as u32)
..=self.0.compact_to_u128(*value_range.end() as u32);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
match value_range {
ValueRange::Inclusive(value_range) => {
let value_range = ValueRange::Inclusive(
self.0.compact_to_u128(*value_range.start() as u32)
..=self.0.compact_to_u128(*value_range.end() as u32),
);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::All => {
let position_range = position_range.start..position_range.end.min(self.num_vals());
positions.extend(position_range);
}
ValueRange::GreaterThan(threshold, _) => {
let value_range =
ValueRange::GreaterThan(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
ValueRange::LessThan(threshold, _) => {
let value_range =
ValueRange::LessThan(self.0.compact_to_u128(threshold as u32), false);
self.0
.get_row_ids_for_value_range(value_range, position_range, positions)
}
}
}
}
@@ -375,10 +398,33 @@ impl ColumnValues<u128> for CompactSpaceDecompressor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<u128>,
value_range: ValueRange<u128>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let value_range = match value_range {
ValueRange::Inclusive(value_range) => value_range,
ValueRange::All => {
let position_range = position_range.start..position_range.end.min(self.num_vals());
positions.extend(position_range);
return;
}
ValueRange::GreaterThan(threshold, _) => {
let max = self.max_value();
if threshold >= max {
return;
}
(threshold + 1)..=max
}
ValueRange::LessThan(threshold, _) => {
let min = self.min_value();
if threshold <= min {
return;
}
min..=(threshold - 1)
}
};
if value_range.start() > value_range.end() {
return;
}
@@ -560,7 +606,7 @@ mod tests {
.collect::<Vec<_>>();
let mut positions = Vec::new();
decompressor.get_row_ids_for_value_range(
range,
ValueRange::Inclusive(range),
0..decompressor.num_vals(),
&mut positions,
);
@@ -604,7 +650,11 @@ mod tests {
let val = *val;
let pos = pos as u32;
let mut positions = Vec::new();
decomp.get_row_ids_for_value_range(val..=val, pos..pos + 1, &mut positions);
decomp.get_row_ids_for_value_range(
ValueRange::Inclusive(val..=val),
pos..pos + 1,
&mut positions,
);
assert_eq!(positions, vec![pos]);
}
@@ -746,7 +796,11 @@ mod tests {
doc_id_range: Range<u32>,
) -> Vec<u32> {
let mut positions = Vec::new();
column.get_row_ids_for_value_range(value_range, doc_id_range, &mut positions);
column.get_row_ids_for_value_range(
ValueRange::Inclusive(value_range),
doc_id_range,
&mut positions,
);
positions
}

View File

@@ -6,6 +6,7 @@ use common::{BinarySerializable, OwnedBytes};
use fastdivide::DividerU64;
use tantivy_bitpacker::{BitPacker, BitUnpacker, compute_num_bits};
use crate::column::ValueRange;
use crate::column_values::u64_based::{ColumnCodec, ColumnCodecEstimator, ColumnStats};
use crate::{ColumnValues, RowId};
@@ -66,24 +67,173 @@ impl ColumnValues for BitpackedReader {
self.stats.num_rows
}
fn get_vals_in_value_range(
&self,
indexes: &mut Vec<u32>,
output: &mut Vec<Option<u64>>,
value_range: ValueRange<u64>,
) {
let mut write_head = 0;
match value_range {
ValueRange::All => {
for i in 0..indexes.len() {
let idx = indexes[i];
indexes[write_head] = idx;
output.push(Some(self.get_val(idx)));
write_head += 1;
}
}
ValueRange::Inclusive(range) => {
if let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
{
for i in 0..indexes.len() {
let doc = indexes[i];
let raw_val = self.get_val(doc);
if transformed_range.contains(&raw_val) {
indexes[write_head] = doc;
output
.push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val));
write_head += 1;
}
}
}
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
for i in 0..indexes.len() {
let idx = indexes[i];
indexes[write_head] = idx;
output.push(Some(self.get_val(idx)));
write_head += 1;
}
} else if threshold >= self.stats.max_value {
// All filtered out
} else {
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
for i in 0..indexes.len() {
let doc = indexes[i];
let raw_val = self.get_val(doc);
if raw_val > raw_threshold {
indexes[write_head] = doc;
output
.push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val));
write_head += 1;
}
}
}
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
for i in 0..indexes.len() {
let idx = indexes[i];
indexes[write_head] = idx;
output.push(Some(self.get_val(idx)));
write_head += 1;
}
} else if threshold <= self.stats.min_value {
// All filtered out
} else {
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
let raw_threshold = if diff % gcd == 0 {
diff / gcd
} else {
diff / gcd + 1
};
for i in 0..indexes.len() {
let doc = indexes[i];
let raw_val = self.get_val(doc);
if raw_val < raw_threshold {
indexes[write_head] = doc;
output
.push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val));
write_head += 1;
}
}
}
}
}
indexes.truncate(write_head);
}
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<u64>,
range: ValueRange<u64>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
match range {
ValueRange::All => {
positions.extend(doc_id_range);
return;
}
ValueRange::Inclusive(range) => {
let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
positions.extend(doc_id_range);
return;
}
if threshold >= self.stats.max_value {
return;
}
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
let max_raw = (self.stats.max_value - self.stats.min_value) / self.stats.gcd.get();
let transformed_range = (raw_threshold + 1)..=max_raw;
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
positions.extend(doc_id_range);
return;
}
if threshold <= self.stats.min_value {
return;
}
let diff = threshold - self.stats.min_value;
let gcd = self.stats.gcd.get();
// We want raw < raw_threshold_limit
// raw <= raw_threshold_limit - 1
let raw_threshold_limit = if diff % gcd == 0 {
diff / gcd
} else {
diff / gcd + 1
};
if raw_threshold_limit == 0 {
return;
}
let transformed_range = 0..=(raw_threshold_limit - 1);
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
}
}
}

View File

@@ -131,7 +131,7 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
.collect();
let mut positions = Vec::new();
reader.get_row_ids_for_value_range(
vals[test_rand_idx]..=vals[test_rand_idx],
crate::column::ValueRange::Inclusive(vals[test_rand_idx]..=vals[test_rand_idx]),
0..vals.len() as u32,
&mut positions,
);

View File

@@ -36,7 +36,7 @@ pub(crate) mod utils;
mod value;
pub use block_accessor::ColumnBlockAccessor;
pub use column::{BytesColumn, Column, StrColumn};
pub use column::{BytesColumn, Column, StrColumn, ValueRange};
pub use column_index::ColumnIndex;
pub use column_values::{
ColumnValues, EmptyColumnValues, MonotonicallyMappableToU64, MonotonicallyMappableToU128,

View File

@@ -389,6 +389,52 @@ pub(crate) mod tests {
Ok(())
}
/// Sorts by city (ascending) and then altitude (descending) and checks both the
/// compound sort keys and the order of the returned documents.
#[test]
fn test_order_by_compound_fast_fields() -> crate::Result<()> {
    type CompoundSortKey = (Option<String>, Option<f64>);

    let index = make_index()?;
    let searcher = index.reader()?.searcher();
    let ids = id_mapping(&searcher);

    let by_city = (SortByString::for_field("city"), Order::Asc);
    let by_altitude = (
        SortByStaticFastValue::<f64>::for_field("altitude"),
        Order::Desc,
    );
    let top_collector = TopDocs::with_limit(4).order_by((by_city, by_altitude));

    let expected: Vec<(CompoundSortKey, u64)> = vec![
        ((Some("austin".to_owned()), Some(149.0)), 0),
        ((Some("greenville".to_owned()), Some(27.0)), 1),
        ((Some("tokyo".to_owned()), Some(40.0)), 2),
        ((None, Some(0.0)), 3),
    ];

    let mut actual = Vec::new();
    for (key, doc) in searcher.search(&AllQuery, &top_collector)? {
        actual.push((key, ids[&doc]));
    }
    assert_eq!(actual, expected);

    Ok(())
}
use proptest::prelude::*;
proptest! {
@@ -451,4 +497,67 @@ pub(crate) mod tests {
);
}
}
/// Checks that a compound (city, altitude) top-2 query returns the right doc ids when one
/// document has no city.
#[test]
fn test_order_by_compound_filtering_with_none() -> crate::Result<()> {
    let mut schema_builder = Schema::builder();
    let city = schema_builder.add_text_field("city", TEXT | FAST);
    let altitude = schema_builder.add_u64_field("altitude", FAST);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema);
    let mut index_writer = index.writer_for_tests()?;
    // We sort by City Asc, Altitude Asc, and index more docs than the collector limit so
    // that some of them have to be rejected.
    // Docs:
    // 0: "c", 10
    // 1: "b", 10
    // 2: "a", 20
    // 3: "a", 10
    // 4: None, 5
    // Expected ascending order (the assertion below relies on the doc without a city
    // ranking after every doc that has one):
    // 1. Doc 3 ("a", 10)
    // 2. Doc 2 ("a", 20)
    // 3. Doc 1 ("b", 10)
    // 4. Doc 0 ("c", 10)
    // 5. Doc 4 (None, 5)
    index_writer.add_document(doc!(city => "c", altitude => 10u64))?;
    index_writer.add_document(doc!(city => "b", altitude => 10u64))?;
    index_writer.add_document(doc!(city => "a", altitude => 20u64))?;
    index_writer.add_document(doc!(city => "a", altitude => 10u64))?;
    index_writer.add_document(doc!(altitude => 5u64))?; // City is None
    index_writer.commit()?;
    let searcher = index.reader()?.searcher();
    // With a limit of 2 the final top docs are ("a", 10) and ("a", 20): the "b", "c" and
    // city-less docs must all be rejected, whether that happens through the threshold on
    // the head key "city" or through plain comparison. Asserting on the doc addresses
    // confirms that DocIds are preserved correctly when filtering happens.
    let top_collector = TopDocs::with_limit(2).order_by((
        (SortByString::for_field("city"), Order::Asc),
        (
            SortByStaticFastValue::<u64>::for_field("altitude"),
            Order::Asc,
        ),
    ));
    let results: Vec<DocAddress> = searcher
        .search(&AllQuery, &top_collector)?
        .into_iter()
        .map(|(_, doc)| doc)
        .collect();
    // Doc 3 is ("a", 10). Doc 2 is ("a", 20).
    assert_eq!(results, vec![DocAddress::new(0, 3), DocAddress::new(0, 2)]);
    Ok(())
}
}

View File

@@ -1,6 +1,6 @@
use std::cmp::Ordering;
use columnar::MonotonicallyMappableToU64;
use columnar::{MonotonicallyMappableToU64, ValueRange};
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
@@ -69,6 +69,10 @@ fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedVal
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
    /// Return the order between two values.
    fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
    /// Return a `ValueRange` that matches all values that are greater than the provided
    /// threshold.
    ///
    /// "Greater" is relative to the order of this comparator: a reversed comparator returns
    /// a `ValueRange::LessThan`, since naturally smaller values rank higher for it. The
    /// boolean carried by the range tells whether null values belong to it.
    #[allow(dead_code)]
    fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T>;
}
/// Compare values naturally (e.g. 1 < 2).
@@ -86,6 +90,10 @@ impl<T: PartialOrd> Comparator<T> for NaturalComparator {
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
}
/// Natural order: selects the values strictly greater than `threshold`; the `false` flag
/// excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
    ValueRange::GreaterThan(threshold, false)
}
}
/// A (partial) implementation of comparison for OwnedValue.
@@ -97,6 +105,10 @@ impl Comparator<OwnedValue> for NaturalComparator {
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
}
/// Natural order on `OwnedValue`: selects the values strictly greater than `threshold`;
/// the `false` flag excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
    ValueRange::GreaterThan(threshold, false)
}
}
/// Compare values in reverse (e.g. 2 < 1).
@@ -121,6 +133,10 @@ where NaturalComparator: Comparator<T>
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
/// Reversed order: the values ranking above `threshold` are the naturally smaller ones,
/// hence `LessThan`; the `true` flag includes null values in the range.
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
    ValueRange::LessThan(threshold, true)
}
}
/// Compare values in reverse, but treating `None` as lower than `Some`.
@@ -147,6 +163,10 @@ where ReverseComparator: Comparator<T>
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
}
}
/// Reversed order: the values ranking above `threshold` are the naturally smaller ones,
/// hence `LessThan`; the `false` flag excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
    ValueRange::LessThan(threshold, false)
}
}
impl Comparator<u32> for ReverseNoneIsLowerComparator {
@@ -154,6 +174,10 @@ impl Comparator<u32> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
/// Reversed order on `u32`: selects the values strictly less than `threshold`; the `false`
/// flag excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
    ValueRange::LessThan(threshold, false)
}
}
impl Comparator<u64> for ReverseNoneIsLowerComparator {
@@ -161,6 +185,10 @@ impl Comparator<u64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
/// Reversed order on `u64`: selects the values strictly less than `threshold`; the `false`
/// flag excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
    ValueRange::LessThan(threshold, false)
}
}
impl Comparator<f64> for ReverseNoneIsLowerComparator {
@@ -168,6 +196,10 @@ impl Comparator<f64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
/// Reversed order on `f64`: selects the values strictly less than `threshold`; the `false`
/// flag excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
    ValueRange::LessThan(threshold, false)
}
}
impl Comparator<f32> for ReverseNoneIsLowerComparator {
@@ -175,6 +207,10 @@ impl Comparator<f32> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
/// Reversed order on `f32`: selects the values strictly less than `threshold`; the `false`
/// flag excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
    ValueRange::LessThan(threshold, false)
}
}
impl Comparator<i64> for ReverseNoneIsLowerComparator {
@@ -182,6 +218,10 @@ impl Comparator<i64> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
/// Reversed order on `i64`: selects the values strictly less than `threshold`; the `false`
/// flag excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
    ValueRange::LessThan(threshold, false)
}
}
impl Comparator<String> for ReverseNoneIsLowerComparator {
@@ -189,6 +229,10 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
/// Reversed order on `String`: selects the values strictly less than `threshold`; the
/// `false` flag excludes null values from the range.
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
    ValueRange::LessThan(threshold, false)
}
}
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
@@ -196,6 +240,10 @@ impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
}
/// Translates a top-N `threshold` into a `ValueRange` filter.
///
/// `compare` in this impl calls `compare_owned_value` with its arguments swapped (`rhs, lhs`),
/// so the range is expressed as `LessThan` the threshold.
// NOTE(review): the `false` flag presumably means "do not match missing values" — confirm
// against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::LessThan(threshold, false)
}
}
/// Compare values naturally, but treating `None` as higher than `Some`.
@@ -218,6 +266,10 @@ where NaturalComparator: Comparator<T>
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
}
}
/// Translates a top-N `threshold` (itself an `Option`) into a `ValueRange` filter.
///
/// `compare` in this impl uses `NaturalComparator` for two `Some` values, so the range is
/// expressed as `GreaterThan` the threshold.
// NOTE(review): the `true` flag presumably means "also match missing values", in line with
// `None` sorting higher here — confirm against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: Option<T>) -> ValueRange<Option<T>> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<u32> for NaturalNoneIsHigherComparator {
@@ -225,6 +277,10 @@ impl Comparator<u32> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
/// Translates a top-N `threshold` into a `ValueRange` filter.
///
/// `compare` in this impl delegates to `NaturalComparator`, so the range is expressed as
/// `GreaterThan` the threshold.
// NOTE(review): the `true` flag presumably means "also match missing values" — confirm
// against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: u32) -> ValueRange<u32> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<u64> for NaturalNoneIsHigherComparator {
@@ -232,6 +288,10 @@ impl Comparator<u64> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
/// Translates a top-N `threshold` into a `ValueRange` filter.
///
/// `compare` in this impl delegates to `NaturalComparator`, so the range is expressed as
/// `GreaterThan` the threshold.
// NOTE(review): the `true` flag presumably means "also match missing values" — confirm
// against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: u64) -> ValueRange<u64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<f64> for NaturalNoneIsHigherComparator {
@@ -239,6 +299,10 @@ impl Comparator<f64> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
/// Translates a top-N `threshold` into a `ValueRange` filter.
///
/// `compare` in this impl delegates to `NaturalComparator`, so the range is expressed as
/// `GreaterThan` the threshold.
// NOTE(review): the `true` flag presumably means "also match missing values" — confirm
// against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: f64) -> ValueRange<f64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<f32> for NaturalNoneIsHigherComparator {
@@ -246,6 +310,10 @@ impl Comparator<f32> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
/// Translates a top-N `threshold` into a `ValueRange` filter.
///
/// `compare` in this impl delegates to `NaturalComparator`, so the range is expressed as
/// `GreaterThan` the threshold.
// NOTE(review): the `true` flag presumably means "also match missing values" — confirm
// against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: f32) -> ValueRange<f32> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<i64> for NaturalNoneIsHigherComparator {
@@ -253,6 +321,10 @@ impl Comparator<i64> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
/// Translates a top-N `threshold` into a `ValueRange` filter.
///
/// `compare` in this impl delegates to `NaturalComparator`, so the range is expressed as
/// `GreaterThan` the threshold.
// NOTE(review): the `true` flag presumably means "also match missing values" — confirm
// against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: i64) -> ValueRange<i64> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<String> for NaturalNoneIsHigherComparator {
@@ -260,6 +332,10 @@ impl Comparator<String> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
/// Translates a top-N `threshold` into a `ValueRange` filter.
///
/// `compare` in this impl delegates to `NaturalComparator`, so the range is expressed as
/// `GreaterThan` the threshold.
// NOTE(review): the `true` flag presumably means "also match missing values" — confirm
// against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: String) -> ValueRange<String> {
ValueRange::GreaterThan(threshold, true)
}
}
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
@@ -267,6 +343,10 @@ impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
}
/// Translates a top-N `threshold` into a `ValueRange` filter.
///
/// `compare` in this impl calls `compare_owned_value` with its arguments in natural order
/// (`lhs, rhs`), so the range is expressed as `GreaterThan` the threshold.
// NOTE(review): the `true` flag presumably means "also match missing values" — confirm
// against the `ValueRange` definition in `columnar`.
fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange<OwnedValue> {
ValueRange::GreaterThan(threshold, true)
}
}
/// An enum representing the different sort orders.
@@ -308,6 +388,19 @@ where
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
}
}
/// Translates a top-N `threshold` into a `ValueRange` filter by delegating to the comparator
/// that corresponds to this enum variant.
fn threshold_to_valuerange(&self, threshold: T) -> ValueRange<T> {
match self {
ComparatorEnum::Natural => NaturalComparator.threshold_to_valuerange(threshold),
ComparatorEnum::Reverse => ReverseComparator.threshold_to_valuerange(threshold),
ComparatorEnum::ReverseNoneLower => {
ReverseNoneIsLowerComparator.threshold_to_valuerange(threshold)
}
ComparatorEnum::NaturalNoneHigher => {
NaturalNoneIsHigherComparator.threshold_to_valuerange(threshold)
}
}
}
}
impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
@@ -322,6 +415,10 @@ where
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
}
/// Translates a tuple `threshold` into a `ValueRange` filter over the whole tuple.
// NOTE(review): this always returns `GreaterThan(.., false)` and does not consult the
// component comparators (`self.0`, `self.1`) that `compare` uses; confirm this is intended
// when a component sorts in reverse.
fn threshold_to_valuerange(&self, threshold: (Head, Tail)) -> ValueRange<(Head, Tail)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
@@ -338,6 +435,13 @@ where
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
}
/// Translates a nested-tuple `threshold` into a `ValueRange` filter over the whole tuple.
// NOTE(review): this always returns `GreaterThan(.., false)` and does not consult the
// component comparators (`self.0`, `self.1`, `self.2`) that `compare` uses; confirm this is
// intended when a component sorts in reverse.
fn threshold_to_valuerange(
&self,
threshold: (Type1, (Type2, Type3)),
) -> ValueRange<(Type1, (Type2, Type3))> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
@@ -354,6 +458,13 @@ where
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
}
/// Translates a 3-tuple `threshold` into a `ValueRange` filter over the whole tuple.
// NOTE(review): this always returns `GreaterThan(.., false)` and does not consult the
// component comparators (`self.0`, `self.1`, `self.2`) that `compare` uses; confirm this is
// intended when a component sorts in reverse.
fn threshold_to_valuerange(
&self,
threshold: (Type1, Type2, Type3),
) -> ValueRange<(Type1, Type2, Type3)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
@@ -377,6 +488,13 @@ where
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
}
/// Translates a nested-tuple `threshold` into a `ValueRange` filter over the whole tuple.
// NOTE(review): this always returns `GreaterThan(.., false)` and does not consult the
// component comparators (`self.0` through `self.3`) that `compare` uses; confirm this is
// intended when a component sorts in reverse.
fn threshold_to_valuerange(
&self,
threshold: (Type1, (Type2, (Type3, Type4))),
) -> ValueRange<(Type1, (Type2, (Type3, Type4)))> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
@@ -400,6 +518,13 @@ where
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
}
/// Translates a 4-tuple `threshold` into a `ValueRange` filter over the whole tuple.
// NOTE(review): this always returns `GreaterThan(.., false)` and does not consult the
// component comparators (`self.0` through `self.3`) that `compare` uses; confirm this is
// intended when a component sorts in reverse.
fn threshold_to_valuerange(
&self,
threshold: (Type1, Type2, Type3, Type4),
) -> ValueRange<(Type1, Type2, Type3, Type4)> {
ValueRange::GreaterThan(threshold, false)
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
@@ -489,16 +614,29 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComput
where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + Clone + 'static + Sync + Send,
{
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
type SegmentComparator = TComparator;
/// Returns a clone of the comparator stored on this computer.
fn segment_comparator(&self) -> Self::SegmentComparator {
self.comparator.clone()
}
/// Computes the sort key of a single document by delegating to the wrapped computer.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score)
}
/// Computes the sort keys of a batch of documents by delegating to the wrapped computer.
///
/// `docs` and `filter` are forwarded unchanged; the returned buffer is the one owned by the
/// wrapped computer.
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.segment_sort_key_computer
.segment_sort_keys(docs, filter)
}
#[inline(always)]
fn compare_segment_sort_key(
&self,

View File

@@ -1,5 +1,6 @@
use columnar::{ColumnType, MonotonicallyMappableToU64};
use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange};
use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer;
use crate::collector::sort_key::{
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
@@ -36,6 +37,11 @@ impl SortByErasedType {
/// Type-erased segment sort key computer.
///
/// Every implementation exposes its segment sort key as an `Option<u64>` and knows how to turn
/// that key back into an `OwnedValue`.
trait ErasedSegmentSortKeyComputer: Send + Sync {
/// Computes the `Option<u64>` sort key for a single document.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
/// Computes the sort keys for a batch of documents, returning a buffer of
/// `(doc, sort key)` pairs owned by the implementation.
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)>;
/// Converts a segment sort key into its `OwnedValue` representation.
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}
@@ -53,6 +59,14 @@ where
self.inner.segment_sort_key(doc, score)
}
/// Computes the sort keys of a batch of documents by forwarding `docs` and `filter` to the
/// wrapped `inner` computer.
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)> {
self.inner.segment_sort_keys(docs, filter)
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let val = self.inner.convert_segment_sort_key(sort_key);
(self.converter)(val)
@@ -60,7 +74,7 @@ where
}
struct ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore,
segment_computer: SortBySimilarityScoreSegmentComputer,
}
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
@@ -69,6 +83,14 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
Some(score_value.to_u64())
}
/// Batch computation is not available when sorting by score.
///
/// # Panics
///
/// Always panics via `unimplemented!`.
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)> {
unimplemented!("Batch computation not supported for score sorting")
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
OwnedValue::F64(f64::from_u64(score_value))
@@ -174,7 +196,8 @@ impl SortKeyComputer for SortByErasedType {
}
}
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore,
segment_computer: SortBySimilarityScore
.segment_sort_key_computer(segment_reader)?,
}),
};
Ok(ErasedColumnSegmentSortKeyComputer { inner })
@@ -195,6 +218,14 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
self.inner.segment_sort_key(doc, score)
}
/// Computes the sort keys of a batch of documents by forwarding `docs` and `filter` to the
/// boxed `inner` computer.
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.inner.segment_sort_keys(docs, filter)
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
self.inner.convert_segment_sort_key(segment_sort_key)
}

View File

@@ -1,3 +1,5 @@
use columnar::ValueRange;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
@@ -9,7 +11,7 @@ pub struct SortBySimilarityScore;
impl SortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type Child = SortBySimilarityScore;
type Child = SortBySimilarityScoreSegmentComputer;
type Comparator = NaturalComparator;
@@ -21,7 +23,7 @@ impl SortKeyComputer for SortBySimilarityScore {
&self,
_segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScore)
Ok(SortBySimilarityScoreSegmentComputer)
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
@@ -61,7 +63,9 @@ impl SortKeyComputer for SortBySimilarityScore {
}
}
impl SegmentSortKeyComputer for SortBySimilarityScore {
pub struct SortBySimilarityScoreSegmentComputer;
impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
type SortKey = Score;
type SegmentSortKey = Score;
type SegmentComparator = NaturalComparator;
@@ -71,6 +75,14 @@ impl SegmentSortKeyComputer for SortBySimilarityScore {
score
}
/// Batch computation is not available when sorting by score.
///
/// # Panics
///
/// Always panics via `unimplemented!`.
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
unimplemented!("Batch computation not supported for score sorting")
}
fn convert_segment_sort_key(&self, score: Score) -> Score {
score
}

View File

@@ -1,7 +1,8 @@
use std::marker::PhantomData;
use columnar::Column;
use columnar::{Column, ValueRange};
use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
@@ -71,6 +72,9 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
Ok(SortByFastValueSegmentSortKeyComputer {
sort_column,
typ: PhantomData,
buffer: Vec::new(),
fetch_buffer: Vec::new(),
doc_buffer: Vec::new(),
})
}
}
@@ -78,6 +82,9 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
pub struct SortByFastValueSegmentSortKeyComputer<T> {
sort_column: Column<u64>,
typ: PhantomData<T>,
buffer: Vec<(DocId, Option<u64>)>,
fetch_buffer: Vec<Option<u64>>,
doc_buffer: Vec<DocId>,
}
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
@@ -90,7 +97,102 @@ impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu
self.sort_column.first(doc)
}
/// Computes the sort keys for a batch of documents from the fast-field column.
///
/// `docs` is copied into a scratch buffer which, together with the value scratch buffer and
/// the converted `filter`, is handed to `Column::first_vals_in_value_range`. The doc ids and
/// values left in the scratch buffers are then paired up into the returned buffer.
fn segment_sort_keys(
    &mut self,
    docs: &[DocId],
    filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
    // Reset all scratch buffers; they are reused across calls.
    self.buffer.clear();
    self.fetch_buffer.clear();
    self.doc_buffer.clear();
    self.doc_buffer.extend_from_slice(docs);

    // The column call receives `doc_buffer` mutably.
    // NOTE(review): presumably it drops the docs rejected by the filter so that docs and
    // values stay aligned — confirm against `first_vals_in_value_range`.
    self.sort_column.first_vals_in_value_range(
        &mut self.doc_buffer,
        &mut self.fetch_buffer,
        convert_optional_u64_range_to_u64_range(filter),
    );

    let docs_iter = self.doc_buffer.iter().copied();
    let vals_iter = self.fetch_buffer.iter().copied();
    self.buffer.extend(docs_iter.zip(vals_iter));
    &mut self.buffer
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key.map(T::from_u64)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Schema, FAST};
    use crate::Index;

    /// Builds an in-RAM index with three documents (values 10, 20 and no value for the fast
    /// `u64` field "field") and returns the sort key computer for its first segment.
    fn computer_for_fixture() -> SortByFastValueSegmentSortKeyComputer<u64> {
        let mut schema_builder = Schema::builder();
        let field_col = schema_builder.add_u64_field("field", FAST);
        let index = Index::create_in_ram(schema_builder.build());

        let mut index_writer = index.writer_for_tests().unwrap();
        for value in [10u64, 20u64] {
            index_writer
                .add_document(crate::doc!(field_col => value))
                .unwrap();
        }
        index_writer.add_document(crate::doc!()).unwrap();
        index_writer.commit().unwrap();

        let searcher = index.reader().unwrap().searcher();
        SortByStaticFastValue::<u64>::for_field("field")
            .segment_sort_key_computer(searcher.segment_reader(0))
            .unwrap()
    }

    #[test]
    fn test_sort_by_fast_value_batch() {
        let mut computer = computer_for_fixture();
        let docs = vec![0, 1, 2];
        // Without a filter, every document is returned, including the one without a value.
        let output = computer.segment_sort_keys(&docs, ValueRange::All);
        assert_eq!(output, &[(0, Some(10)), (1, Some(20)), (2, None)]);
    }

    #[test]
    fn test_sort_by_fast_value_batch_with_filter() {
        let mut computer = computer_for_fixture();
        let docs = vec![0, 1, 2];
        let filter = ValueRange::GreaterThan(Some(15u64), false);
        let output = computer.segment_sort_keys(&docs, filter);
        // Only doc 1 (value 20) is expected: doc 0 holds 10, and doc 2 has no value.
        assert_eq!(output, &[(1, Some(20))]);
    }
}

View File

@@ -1,5 +1,8 @@
use columnar::StrColumn;
use columnar::{StrColumn, ValueRange};
use crate::collector::sort_key::sort_key_computer::{
convert_optional_u64_range_to_u64_range, range_contains_none,
};
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
@@ -38,12 +41,20 @@ impl SortKeyComputer for SortByString {
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
Ok(ByStringColumnSegmentSortKeyComputer {
str_column_opt,
buffer: Vec::new(),
fetch_buffer: Vec::new(),
doc_buffer: Vec::new(),
})
}
}
pub struct ByStringColumnSegmentSortKeyComputer {
str_column_opt: Option<StrColumn>,
buffer: Vec<(DocId, Option<TermOrdinal>)>,
fetch_buffer: Vec<Option<TermOrdinal>>,
doc_buffer: Vec<DocId>,
}
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
@@ -57,6 +68,37 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
str_column.ords().first(doc)
}
/// Computes the term-ordinal sort keys for a batch of documents.
///
/// With a string column, the ordinals are fetched through
/// `first_vals_in_value_range` using the converted `filter`. Without a column, every document
/// gets `None` when `range_contains_none(&filter)` holds, and no document is returned
/// otherwise.
fn segment_sort_keys(
    &mut self,
    docs: &[DocId],
    filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
    // Reset the scratch buffers; they are reused across calls.
    self.buffer.clear();
    self.fetch_buffer.clear();
    self.doc_buffer.clear();
    self.doc_buffer.extend_from_slice(docs);

    match &self.str_column_opt {
        Some(str_column) => {
            str_column.ords().first_vals_in_value_range(
                &mut self.doc_buffer,
                &mut self.fetch_buffer,
                convert_optional_u64_range_to_u64_range(filter),
            );
        }
        // No column in this segment: every document has a missing value.
        None if range_contains_none(&filter) => {
            self.fetch_buffer.resize(docs.len(), None);
        }
        // Missing values are excluded by the filter, so nothing matches.
        None => self.doc_buffer.clear(),
    }

    let docs_iter = self.doc_buffer.iter().copied();
    let ords_iter = self.fetch_buffer.iter().copied();
    self.buffer.extend(docs_iter.zip(ords_iter));
    &mut self.buffer
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
@@ -70,3 +112,80 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
String::try_from(bytes).ok()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Schema, FAST, TEXT};
    use crate::Index;

    /// Builds an in-RAM index with three documents ("a", "c" and no value for the fast text
    /// field "field") and returns the sort key computer for its first segment.
    fn computer_for_fixture() -> ByStringColumnSegmentSortKeyComputer {
        let mut schema_builder = Schema::builder();
        let field_col = schema_builder.add_text_field("field", FAST | TEXT);
        let index = Index::create_in_ram(schema_builder.build());

        let mut index_writer = index.writer_for_tests().unwrap();
        for text in ["a", "c"] {
            index_writer
                .add_document(crate::doc!(field_col => text))
                .unwrap();
        }
        index_writer.add_document(crate::doc!()).unwrap();
        index_writer.commit().unwrap();

        let searcher = index.reader().unwrap().searcher();
        SortByString::for_field("field")
            .segment_sort_key_computer(searcher.segment_reader(0))
            .unwrap()
    }

    #[test]
    fn test_sort_by_string_batch() {
        let mut computer = computer_for_fixture();
        let docs = vec![0, 1, 2];
        let output = computer.segment_sort_keys(&docs, ValueRange::All);
        // Keys are term ordinals: "a" -> 0, "c" -> 1, and `None` for the empty document.
        assert_eq!(output, &[(0, Some(0)), (1, Some(1)), (2, None)]);
    }

    #[test]
    fn test_sort_by_string_batch_with_filter() {
        let mut computer = computer_for_fixture();
        let docs = vec![0, 1, 2];
        // The filter is expressed on ordinals: "a" is ord 0 and "c" is ord 1, so
        // "greater than ord 0" keeps only the terms after "a".
        let filter = ValueRange::GreaterThan(Some(0), false);
        let output = computer.segment_sort_keys(&docs, filter);
        // Only the document holding "c" (ord 1) is expected.
        assert_eq!(output, &[(1, Some(1))]);
    }
}

View File

@@ -1,8 +1,13 @@
use std::cmp::Ordering;
use columnar::ValueRange;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::collector::top_score_collector::push_assuming_capacity;
use crate::collector::{
default_collect_segment_impl, ComparableDoc, SegmentCollector as _, TopNComputer,
};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
@@ -21,7 +26,7 @@ pub trait SegmentSortKeyComputer: 'static {
type SegmentSortKey: 'static + Clone + Send + Sync + Clone;
/// Comparator type.
type SegmentComparator: Comparator<Self::SegmentSortKey> + 'static;
type SegmentComparator: Comparator<Self::SegmentSortKey> + Clone + 'static;
/// Returns the segment sort key comparator.
fn segment_comparator(&self) -> Self::SegmentComparator {
@@ -31,6 +36,16 @@ pub trait SegmentSortKeyComputer: 'static {
/// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
/// Computes the sort keys for a batch of documents.
///
/// `filter` describes the range of sort keys the caller is interested in. Implementations may
/// use it to leave out documents whose key falls outside the range, so the result can hold
/// fewer entries than `docs`.
///
/// The computed `(doc, sort key)` pairs are stored in an internal buffer, and a mutable
/// reference to that buffer is returned. Subsequent calls to this method may reuse and
/// overwrite the internal buffer.
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)>;
/// Computes the sort key and pushes the document in a TopN Computer.
///
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
@@ -45,6 +60,42 @@ pub trait SegmentSortKeyComputer: 'static {
top_n_computer.push(sort_key, doc);
}
/// Computes the sort keys for a whole block of `docs` and appends the competitive ones to
/// `top_n_computer`.
///
/// When the collector already has a threshold, it is translated into a `ValueRange` filter
/// for `segment_sort_keys`, and only the keys comparing strictly greater than the threshold
/// are appended. Without a threshold, every returned key is appended.
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
    &mut self,
    docs: &[DocId],
    top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
    // The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
    // should always be able to `reserve` space for the entire block.
    top_n_computer.reserve(docs.len());

    let comparator = self.segment_comparator();
    // Snapshot the threshold up front: it drives both the filter and the re-check below.
    let threshold_opt = top_n_computer.threshold.clone();
    let value_range = match &threshold_opt {
        Some(threshold) => comparator.threshold_to_valuerange(threshold.clone()),
        None => ValueRange::All,
    };
    let sort_keys = self.segment_sort_keys(docs, value_range);

    match threshold_opt {
        Some(threshold) => {
            let competitive = sort_keys.drain(..).filter(|(_, sort_key)| {
                comparator.compare(sort_key, &threshold) == Ordering::Greater
            });
            for (doc, sort_key) in competitive {
                // We validated at the top of the method that we have capacity.
                top_n_computer.append_doc_unchecked(doc, sort_key);
            }
        }
        None => {
            // Eagerly push, without a threshold to compare to.
            for (doc, sort_key) in sort_keys.drain(..) {
                // We validated at the top of the method that we have capacity.
                top_n_computer.append_doc_unchecked(doc, sort_key);
            }
        }
    }
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
@@ -58,23 +109,24 @@ pub trait SegmentSortKeyComputer: 'static {
self.segment_comparator().compare(left, right)
}
/// Implementing this method makes it possible to avoid computing
/// a sort_key entirely if we can assess that it won't pass a threshold
/// with a partial computation.
/// Similar to `accept_sort_key_lazy`, but pushes results directly into the given buffer. Does
/// not support scoring.
///
/// This is currently used for lexicographic sorting.
fn accept_sort_key_lazy(
/// The buffer must have at least enough capacity for `docs` matches, or this method will
/// panic.
fn accept_sort_key_block_lazy(
&mut self,
doc_id: DocId,
score: Score,
docs: &[DocId],
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let sort_key = self.segment_sort_key(doc_id, score);
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
if cmp == Ordering::Less {
None
} else {
Some((cmp, sort_key))
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
) {
let comparator = self.segment_comparator();
for &doc in docs {
let sort_key = self.segment_sort_key(doc, 0.0);
let cmp = comparator.compare(&sort_key, threshold);
if cmp != Ordering::Less {
push_assuming_capacity(ComparableDoc { sort_key, doc }, output);
}
}
}
@@ -145,7 +197,8 @@ where
TailSortKeyComputer: SortKeyComputer,
{
type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type Child =
ChainSegmentSortKeyComputer<HeadSortKeyComputer::Child, TailSortKeyComputer::Child>;
type Comparator = (
HeadSortKeyComputer::Comparator,
@@ -157,10 +210,12 @@ where
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
))
Ok(ChainSegmentSortKeyComputer {
head: self.0.segment_sort_key_computer(segment_reader)?,
tail: self.1.segment_sort_key_computer(segment_reader)?,
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
})
}
/// Checks whether the schema is compatible with the sort key computer.
@@ -178,25 +233,68 @@ where
}
}
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
pub struct ChainSegmentSortKeyComputer<Head, Tail>
where
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
type SortKey = (
HeadSegmentSortKeyComputer::SortKey,
TailSegmentSortKeyComputer::SortKey,
);
type SegmentSortKey = (
HeadSegmentSortKeyComputer::SegmentSortKey,
TailSegmentSortKeyComputer::SegmentSortKey,
);
head: Head,
tail: Tail,
head_key_buffer: Vec<Head::SegmentSortKey>,
doc_buffer: Vec<DocId>,
}
type SegmentComparator = (
HeadSegmentSortKeyComputer::SegmentComparator,
TailSegmentSortKeyComputer::SegmentComparator,
);
impl<Head, Tail> ChainSegmentSortKeyComputer<Head, Tail>
where
    Head: SegmentSortKeyComputer,
    Tail: SegmentSortKeyComputer,
{
    /// Lazily computes the `(head, tail)` sort key of `doc_id` against `threshold`.
    ///
    /// The head key is computed first. If it compares `Less` than the head threshold, the
    /// tail key is never computed and `None` is returned. If the heads are `Equal`, the tail
    /// decides: `None` when the tail compares `Less`, otherwise the tail ordering and the
    /// full key. If the head is `Greater`, the head ordering and the full key are returned.
    fn accept_sort_key_lazy(
        &mut self,
        doc_id: DocId,
        score: Score,
        threshold: &<Self as SegmentSortKeyComputer>::SegmentSortKey,
    ) -> Option<(Ordering, <Self as SegmentSortKeyComputer>::SegmentSortKey)> {
        let head_key = self.head.segment_sort_key(doc_id, score);
        match self.head.compare_segment_sort_key(&head_key, &threshold.0) {
            Ordering::Less => None,
            Ordering::Equal => {
                // Tie on the head: the tail breaks it.
                let tail_key = self.tail.segment_sort_key(doc_id, score);
                match self.tail.compare_segment_sort_key(&tail_key, &threshold.1) {
                    Ordering::Less => None,
                    tail_ord => Some((tail_ord, (head_key, tail_key))),
                }
            }
            Ordering::Greater => {
                let tail_key = self.tail.segment_sort_key(doc_id, score);
                Some((Ordering::Greater, (head_key, tail_key)))
            }
        }
    }
}
impl<Head, Tail> SegmentSortKeyComputer for ChainSegmentSortKeyComputer<Head, Tail>
where
Head: SegmentSortKeyComputer,
Tail: SegmentSortKeyComputer,
{
type SortKey = (Head::SortKey, Tail::SortKey);
type SegmentSortKey = (Head::SegmentSortKey, Tail::SegmentSortKey);
type SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator);
/// Returns the pair made of the head comparator and the tail comparator.
fn segment_comparator(&self) -> Self::SegmentComparator {
(
self.head.segment_comparator(),
self.tail.segment_comparator(),
)
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
@@ -208,9 +306,17 @@ where
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.0
self.head
.compare_segment_sort_key(&left.0, &right.0)
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
.then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1))
}
/// Not supported on the chained computer.
///
/// # Panics
///
/// Always panics via `unimplemented!`.
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
unimplemented!("The head and the tail are accessed independently.");
}
#[inline(always)]
@@ -233,50 +339,89 @@ where
top_n_computer.append_doc(doc, sort_key);
}
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.0.segment_sort_key(doc, score);
let tail_sort_key = self.1.segment_sort_key(doc, score);
(head_sort_key, tail_sort_key)
/// Computes `(head, tail)` sort keys for a block of `docs` and appends the competitive ones
/// to `top_n_computer`.
///
/// With a threshold, head keys are fetched first (filtered by the head threshold), docs whose
/// head compares `Less` are discarded, and tail keys are fetched only for the survivors. A
/// doc is appended when its head is `Greater`, or when the heads are `Equal` and its tail is
/// `Greater`. Without a threshold, head and tail keys are fetched for every doc and appended
/// unconditionally.
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
// The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
// should always be able to `reserve` space for the entire block.
top_n_computer.reserve(docs.len());
if let Some(threshold) = &top_n_computer.threshold {
let (head_threshold, tail_threshold) = threshold.clone();
let head_cmp = self.head.segment_comparator();
let tail_cmp = self.tail.segment_comparator();
// NOTE(review): docs whose head key ties with the threshold must survive this filter so
// that the tail can break the tie below. If the range returned by
// `threshold_to_valuerange` is strict, such docs would be dropped here — confirm the
// `ValueRange` semantics.
let head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone());
let head_keys = self.head.segment_sort_keys(docs, head_filter);
self.doc_buffer.clear();
self.head_key_buffer.clear();
// Keep the docs whose head key is not below the head threshold.
for (doc, head_key) in head_keys.drain(..) {
let cmp = head_cmp.compare(&head_key, &head_threshold);
if cmp != Ordering::Less {
self.doc_buffer.push(doc);
self.head_key_buffer.push(head_key);
}
}
if !self.doc_buffer.is_empty() {
// The tail is fetched unfiltered; the zip below assumes it yields exactly one key per
// surviving doc, in the same order as `doc_buffer`.
let tail_keys = self
.tail
.segment_sort_keys(&self.doc_buffer, ValueRange::All);
for ((head_key, tail_key), &doc) in self
.head_key_buffer
.drain(..)
.zip(tail_keys.drain(..).map(|(_, k)| k))
.zip(self.doc_buffer.iter())
{
// Lexicographic comparison: the tail only matters when the heads are equal.
let head_ord = head_cmp.compare(&head_key, &head_threshold);
let ord = if head_ord == Ordering::Equal {
tail_cmp.compare(&tail_key, &tail_threshold)
} else {
head_ord
};
if ord == Ordering::Greater {
top_n_computer.append_doc_unchecked(doc, (head_key, tail_key));
}
}
}
} else {
// Eagerly push, without a threshold to compare to.
let head_keys = self.head.segment_sort_keys(docs, ValueRange::All);
let tail_keys = self.tail.segment_sort_keys(docs, ValueRange::All);
for ((doc, head_key), (_, tail_key)) in head_keys.drain(..).zip(tail_keys.drain(..)) {
// We validated at the top of the method that we have capacity.
top_n_computer.append_doc_unchecked(doc, (head_key, tail_key));
}
}
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let (head_cmp, head_sort_key) =
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
if head_cmp == Ordering::Equal {
let (tail_cmp, tail_sort_key) =
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
Some((tail_cmp, (head_sort_key, tail_sort_key)))
} else {
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.head.segment_sort_key(doc, score);
let tail_sort_key = self.tail.segment_sort_key(doc, score);
(head_sort_key, tail_sort_key)
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
let (head_sort_key, tail_sort_key) = sort_key;
(
self.0.convert_segment_sort_key(head_sort_key),
self.1.convert_segment_sort_key(tail_sort_key),
self.head.convert_segment_sort_key(head_sort_key),
self.tail.convert_segment_sort_key(tail_sort_key),
)
}
}
/// This struct is used as an adapter to take a sort key computer and map its score to another
/// new sort key.
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
pub struct MappedSegmentSortKeyComputer<T: SegmentSortKeyComputer, NewSortKey> {
sort_key_computer: T,
map: fn(PreviousSortKey) -> NewSortKey,
map: fn(T::SortKey) -> NewSortKey,
}
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
for MappedSegmentSortKeyComputer<T, NewScore>
where
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync,
@@ -286,18 +431,21 @@ where
type SegmentSortKey = T::SegmentSortKey;
type SegmentComparator = T::SegmentComparator;
/// Returns the comparator of the wrapped sort key computer.
fn segment_comparator(&self) -> Self::SegmentComparator {
self.sort_key_computer.segment_comparator()
}
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score)
}
fn accept_sort_key_lazy(
fn segment_sort_keys(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.sort_key_computer
.accept_sort_key_lazy(doc_id, score, threshold)
.segment_sort_keys(docs, ValueRange::All)
}
#[inline(always)]
@@ -311,6 +459,15 @@ where
.compute_sort_key_and_collect(doc, score, top_n_computer);
}
/// Computes the sort keys for a block of documents and feeds them into `top_n_computer`,
/// by handing the whole block to the wrapped sort key computer.
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
    &mut self,
    docs: &[DocId],
    top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
    let inner = &mut self.sort_key_computer;
    inner.compute_sort_keys_and_collect(docs, top_n_computer);
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
(self.map)(
self.sort_key_computer
@@ -336,10 +493,6 @@ where
);
type Child = MappedSegmentSortKeyComputer<
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
),
Self::SortKey,
>;
@@ -363,7 +516,17 @@ where
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
sort_key_computer: ChainSegmentSortKeyComputer {
head: sort_key_computer1,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: sort_key_computer3,
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
map,
})
}
@@ -398,13 +561,6 @@ where
SortKeyComputer1,
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(
SortKeyComputer2::SortKey,
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
),
),
Self::SortKey,
>;
type SortKey = (
@@ -426,10 +582,22 @@ where
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (
sort_key_computer1,
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
),
sort_key_computer: ChainSegmentSortKeyComputer {
head: sort_key_computer1,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer3,
tail: sort_key_computer4,
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
(sort_key1, sort_key2, sort_key3, sort_key4)
},
@@ -452,6 +620,11 @@ where
}
}
/// Wraps a per-document closure so that it can be used as a segment-level sort key computer.
///
/// The closure produces the sort key of a single document; the struct additionally owns the
/// buffer that batch computation fills and hands back to the caller.
pub struct FuncSegmentSortKeyComputer<F, TSortKey> {
// Closure called with a `DocId` to produce that document's sort key.
func: F,
// Output buffer for batch computation: cleared and refilled with `(doc, sort key)` pairs on
// every `segment_sort_keys` call, so its allocation is reused between blocks.
buffer: Vec<(DocId, TSortKey)>,
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
@@ -459,15 +632,18 @@ where
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
type Child = SegmentF;
type Child = FuncSegmentSortKeyComputer<SegmentF, TSortKey>;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
Ok(FuncSegmentSortKeyComputer {
func: (self)(segment_reader),
buffer: Vec::new(),
})
}
}
impl<F, TSortKey> SegmentSortKeyComputer for F
impl<F, TSortKey> SegmentSortKeyComputer for FuncSegmentSortKeyComputer<F, TSortKey>
where
F: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
@@ -477,7 +653,20 @@ where
type SegmentComparator = NaturalComparator;
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self)(doc)
(self.func)(doc)
}
fn segment_sort_keys(
&mut self,
docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.buffer.clear();
self.buffer.reserve(docs.len());
for &doc in docs {
self.buffer.push((doc, (self.func)(doc)));
}
&mut self.buffer
}
/// Convert a segment level score into the global level score.
@@ -486,6 +675,34 @@ where
}
}
/// Reports whether `range` admits the `None` value.
///
/// Relies on the ordering of `Option`, in which `None` sorts before every `Some`.
pub(crate) fn range_contains_none(range: &ValueRange<Option<u64>>) -> bool {
    match range {
        ValueRange::All => true,
        ValueRange::Inclusive(bounds) => bounds.contains(&None),
        // `None` is the smallest value, so it is never strictly greater than the threshold:
        // only the explicit null flag can admit it.
        ValueRange::GreaterThan(_threshold, match_nulls) => *match_nulls,
        // `None` is strictly below the threshold exactly when the threshold is a `Some`.
        ValueRange::LessThan(threshold, match_nulls) => *match_nulls || threshold.is_some(),
    }
}
/// Narrows a filter over optional `u64` values into a filter over plain `u64` values.
///
/// If `range` admits `None` (see `range_contains_none`), no narrowing is attempted and
/// `ValueRange::All` is returned. Otherwise the result describes exactly the `Some` values
/// admitted by `range`, with the empty set encoded as the inverted range `1..=0`.
pub(crate) fn convert_optional_u64_range_to_u64_range(
    range: ValueRange<Option<u64>>,
) -> ValueRange<u64> {
    if range_contains_none(&range) {
        return ValueRange::All;
    }
    match range {
        ValueRange::Inclusive(bounds) => match (*bounds.start(), *bounds.end()) {
            (Some(start), Some(end)) => ValueRange::Inclusive(start..=end),
            // `None` sorts below every `Some`, so an inclusive range that has a `None` bound
            // and still excludes `None` (e.g. `Some(5)..=None`) is empty. Substituting
            // `u64::MAX` for a missing end bound would turn it into `5..=u64::MAX` instead.
            _ => ValueRange::Inclusive(1..=0),
        },
        ValueRange::GreaterThan(Some(val), _match_nulls) => ValueRange::GreaterThan(val, false),
        // Every non-null value is strictly greater than `None`.
        ValueRange::GreaterThan(None, _match_nulls) => ValueRange::Inclusive(u64::MIN..=u64::MAX),
        // Nothing is strictly less than `None`.
        ValueRange::LessThan(None, _match_nulls) => ValueRange::Inclusive(1..=0),
        // Both of these admit `None`, so the early return above already handled them; the arm
        // only keeps the match exhaustive without a catch-all.
        ValueRange::All | ValueRange::LessThan(Some(_), _) => ValueRange::All,
    }
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;

View File

@@ -120,6 +120,11 @@ where
);
}
/// Collects a whole block of documents at once: the segment sort key computer computes their
/// sort keys and pushes them into the top-n computer.
fn collect_block(&mut self, docs: &[DocId]) {
    let top_n = &mut self.topn_computer;
    self.segment_sort_key_computer
        .compute_sort_keys_and_collect(docs, top_n);
}
fn harvest(self) -> Self::Fruit {
let segment_ord = self.segment_ord;
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self

View File

@@ -2,6 +2,7 @@ use std::cmp::Ordering;
use std::fmt;
use std::ops::Range;
use columnar::ValueRange;
use serde::{Deserialize, Serialize};
use super::Collector;
@@ -486,6 +487,14 @@ where
(self.sort_key_fn)(doc, score)
}
/// Batch sort key computation is not available for this computer.
///
/// # Panics
///
/// Always panics through `unimplemented!`; both arguments are ignored.
// NOTE(review): any block-collection path that ends up calling this method would panic --
// confirm that such a path is never taken for tweak score.
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
unimplemented!("Batch computation is not supported for tweak score.")
}
/// Convert a segment level score into the global level score.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key
@@ -604,9 +613,12 @@ where
C: Comparator<TSortKey>,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
/// Internally it will allocate a buffer of size `(top_n.max(1) * 2) +
/// COLLECT_BLOCK_BUFFER_LEN`.
pub fn new_with_comparator(top_n: usize, comparator: C) -> Self {
let vec_cap = top_n.max(1) * 2;
// We ensure that there is always enough space to include an entire block in the buffer if
// need be, so that `push_block_lazy` can avoid checking capacity inside its loop.
let vec_cap = (top_n.max(1) * 2) + crate::COLLECT_BLOCK_BUFFER_LEN;
TopNComputer {
buffer: Vec::with_capacity(vec_cap),
top_n,
@@ -635,16 +647,31 @@ where
// At this point, we need to have established that the doc is above the threshold.
#[inline(always)]
pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) {
if self.buffer.len() == self.buffer.capacity() {
let median = self.truncate_top_n();
self.threshold = Some(median);
}
// This cannot panic, because we truncate_median will at least remove one element, since
// the min capacity is 2.
self.reserve(1);
// This cannot panic, because we've reserved room for one element.
self.append_doc_unchecked(doc, sort_key);
}
// Pushes a document onto the top-n buffer without making room for it first.
//
// The caller must have called `reserve` beforehand so that the buffer has spare capacity;
// otherwise this method panics.
//
// The caller must also have established that the doc is above the threshold.
#[inline(always)]
pub(crate) fn append_doc_unchecked(&mut self, doc: D, sort_key: TSortKey) {
    push_assuming_capacity(ComparableDoc { doc, sort_key }, &mut self.buffer);
}
// Makes sure `additional` more elements can be pushed without the buffer growing.
//
// When the buffer is too full, it is truncated to the top n and the threshold is updated to
// the value returned by that truncation.
#[inline(always)]
pub(crate) fn reserve(&mut self, additional: usize) {
    let required = self.buffer.len() + additional;
    if required <= self.buffer.capacity() {
        return;
    }
    let median = self.truncate_top_n();
    debug_assert!(self.buffer.len() + additional <= self.buffer.capacity());
    self.threshold = Some(median);
}
#[inline(never)]
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
@@ -684,7 +711,7 @@ where
//
// Panics if there is not enough capacity to add an element.
#[inline(always)]
fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
pub fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
let prev_len = buf.len();
assert!(prev_len < buf.capacity());
// This is mimicking the current (non-stabilized) implementation in std.
@@ -1408,11 +1435,11 @@ mod tests {
#[test]
fn test_top_field_collect_string_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..256_usize,
offset in 0..256_usize,
limit in 1..32_usize,
offset in 0..32_usize,
segments_terms in
proptest::collection::vec(
proptest::collection::vec(0..32_u8, 1..32_usize),
proptest::collection::vec(0..64_u8, 1..256_usize),
0..8_usize,
)
) {

View File

@@ -79,7 +79,7 @@ mod tests {
use std::ops::{Range, RangeInclusive};
use std::path::Path;
use columnar::StrColumn;
use columnar::{StrColumn, ValueRange};
use common::{ByteCount, DateTimePrecision, HasLen, TerminatingWrite};
use once_cell::sync::Lazy;
use rand::prelude::SliceRandom;
@@ -944,7 +944,7 @@ mod tests {
let test_range = |range: RangeInclusive<u64>| {
let expected_count = numbers.iter().filter(|num| range.contains(*num)).count();
let mut vec = vec![];
field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec);
field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec);
assert_eq!(vec.len(), expected_count);
};
test_range(50..=50);
@@ -1022,7 +1022,7 @@ mod tests {
let test_range = |range: RangeInclusive<u64>| {
let expected_count = numbers.iter().filter(|num| range.contains(*num)).count();
let mut vec = vec![];
field.get_row_ids_for_value_range(range, 0..u32::MAX, &mut vec);
field.get_row_ids_for_value_range(ValueRange::Inclusive(range), 0..u32::MAX, &mut vec);
assert_eq!(vec.len(), expected_count);
};
let test_range_variant = |start, stop| {

View File

@@ -1,7 +1,6 @@
use core::fmt::Debug;
use std::ops::RangeInclusive;
use columnar::Column;
use columnar::{Column, ValueRange};
use crate::{DocId, DocSet, TERMINATED};
@@ -41,7 +40,7 @@ impl VecCursor {
pub(crate) struct RangeDocSet<T> {
/// The range filter on the values.
value_range: RangeInclusive<T>,
value_range: ValueRange<T>,
column: Column<T>,
/// The next docid start range to fetch (inclusive).
next_fetch_start: u32,
@@ -61,8 +60,8 @@ pub(crate) struct RangeDocSet<T> {
const DEFAULT_FETCH_HORIZON: u32 = 128;
impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
pub(crate) fn new(value_range: RangeInclusive<T>, column: Column<T>) -> Self {
if *value_range.start() > column.max_value() || *value_range.end() < column.min_value() {
pub(crate) fn new(value_range: ValueRange<T>, column: Column<T>) -> Self {
if !value_range.intersects(column.min_value(), column.max_value()) {
return Self {
value_range,
column,

View File

@@ -7,7 +7,7 @@ use std::ops::{Bound, RangeInclusive};
use columnar::{
Cardinality, Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
NumericalType, StrColumn,
NumericalType, StrColumn, ValueRange,
};
use common::bounds::{BoundsRange, TransformBound};
@@ -154,7 +154,7 @@ impl Weight for FastFieldRangeWeight {
ip_addr_column.min_value(),
ip_addr_column.max_value(),
);
let docset = RangeDocSet::new(value_range, ip_addr_column);
let docset = RangeDocSet::new(ValueRange::Inclusive(value_range), ip_addr_column);
Ok(Box::new(ConstScorer::new(docset, boost)))
} else if field_type.is_str() {
let Some(str_dict_column): Option<StrColumn> = reader.fast_fields().str(&field_name)?
@@ -426,7 +426,7 @@ fn search_on_u64_ff(
}
}
let docset = RangeDocSet::new(value_range, column);
let docset = RangeDocSet::new(ValueRange::Inclusive(value_range), column);
Ok(Box::new(ConstScorer::new(docset, boost)))
}