Use a Buffer generic scratch buffer parameter on TopNComputer and push directly from ColumnValues into a TopNComputer buffer in some cases.

This commit is contained in:
Stu Hood
2025-12-26 17:52:40 -07:00
parent 041c6f01a3
commit af53ffe5df
16 changed files with 478 additions and 352 deletions

View File

@@ -16,7 +16,7 @@ stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
common = { version= "0.10", path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
serde = "1.0.152"
serde = { version = "1.0.152", features = ["derive"] }
downcast-rs = "2.0.1"
[dev-dependencies]

View File

@@ -146,13 +146,18 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
// Separate impl block for methods requiring `Default` for `T`.
impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
/// Load the first value for each docid in the provided slice.
///
/// The `docids` vector is mutated: documents that do not match the `value_range` are removed.
/// The `values` vector is populated with the values of the remaining documents.
#[inline]
pub fn first_vals_in_value_range(
&self,
docids: &mut Vec<DocId>,
values: &mut Vec<Option<T>>,
input_docs: &[DocId],
output: &mut Vec<crate::ComparableDoc<Option<T>, DocId>>,
value_range: ValueRange<T>,
) {
// TODO: Move `COLLECT_BLOCK_BUFFER_LEN` to allow for use here, or use a different constant
// in this context.
const BLOCK_LEN: usize = 64; // Corresponds to COLLECT_BLOCK_BUFFER_LEN in tantivy's docset
match (&self.index, value_range) {
(ColumnIndex::Empty { .. }, value_range) => {
@@ -163,19 +168,20 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
ValueRange::LessThan(_, nulls_match) => *nulls_match,
};
if nulls_match {
for _ in 0..docids.len() {
values.push(None);
for &doc in input_docs {
output.push(crate::ComparableDoc {
doc,
sort_key: None,
});
}
} else {
docids.clear();
}
}
(ColumnIndex::Full, value_range) => {
self.values
.get_vals_in_value_range(docids, values, value_range);
.get_vals_in_value_range(input_docs, output, value_range);
}
(ColumnIndex::Optional(optional_index), value_range) => {
let len = docids.len();
let len = input_docs.len();
// Ensure the input docids length does not exceed BLOCK_LEN for stack allocation
// safety. If it does, we might need to handle this with multiple
// chunks or fallback to heap. For now, an assert is used to confirm
@@ -188,7 +194,7 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
);
let mut input_docs_buffer = [0u32; BLOCK_LEN];
input_docs_buffer[..len].copy_from_slice(docids);
input_docs_buffer[..len].copy_from_slice(input_docs);
let mut dense_row_ids_buffer = [0u32; BLOCK_LEN];
let mut dense_values_buffer = [T::default(); BLOCK_LEN];
@@ -221,9 +227,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
};
// Phase 3: Filter and merge results, reconstructing docids and values
docids.clear();
values.clear();
let mut dense_values_cursor = 0;
for i in 0..len {
let original_doc_id = input_docs_buffer[i];
@@ -241,13 +244,17 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
};
if value_matches {
docids.push(original_doc_id);
values.push(Some(val));
output.push(crate::ComparableDoc {
doc: original_doc_id,
sort_key: Some(val),
});
}
} else if nulls_match {
// This doc_id was not present in the optional index (null) and nulls match
docids.push(original_doc_id);
values.push(None);
output.push(crate::ComparableDoc {
doc: original_doc_id,
sort_key: None,
});
}
}
}
@@ -258,9 +265,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
ValueRange::GreaterThan(_, nulls_match) => *nulls_match,
ValueRange::LessThan(_, nulls_match) => *nulls_match,
};
let mut write_head = 0;
for i in 0..docids.len() {
let docid = docids[i];
for i in 0..input_docs.len() {
let docid = input_docs[i];
let row_range = multivalued_index.range(docid);
let is_empty = row_range.start == row_range.end;
if !is_empty {
@@ -272,17 +278,18 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static + Default> Column<T> {
ValueRange::LessThan(t, _) => val < *t,
};
if matches {
docids[write_head] = docid;
values.push(Some(val));
write_head += 1;
output.push(crate::ComparableDoc {
doc: docid,
sort_key: Some(val),
});
}
} else if nulls_match {
docids[write_head] = docid;
values.push(None);
write_head += 1;
output.push(crate::ComparableDoc {
doc: docid,
sort_key: None,
});
}
}
docids.truncate(write_head);
}
}
}

View File

@@ -116,49 +116,52 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
/// The values are filtered by the provided value range.
fn get_vals_in_value_range(
&self,
indexes: &mut Vec<u32>,
output: &mut Vec<Option<T>>,
input_indexes: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
let mut write_head = 0;
let len = input_indexes.len();
let mut read_head = 0;
let len = indexes.len();
match value_range {
ValueRange::All => {
while read_head + 3 < len {
let idx0 = indexes[read_head];
let idx1 = indexes[read_head + 1];
let idx2 = indexes[read_head + 2];
let idx3 = indexes[read_head + 3];
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
let val2 = self.get_val(idx2);
let val3 = self.get_val(idx3);
indexes[write_head] = idx0;
output.push(Some(val0));
write_head += 1;
indexes[write_head] = idx1;
output.push(Some(val1));
write_head += 1;
indexes[write_head] = idx2;
output.push(Some(val2));
write_head += 1;
indexes[write_head] = idx3;
output.push(Some(val3));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx0,
sort_key: Some(val0),
});
output.push(crate::ComparableDoc {
doc: idx1,
sort_key: Some(val1),
});
output.push(crate::ComparableDoc {
doc: idx2,
sort_key: Some(val2),
});
output.push(crate::ComparableDoc {
doc: idx3,
sort_key: Some(val3),
});
read_head += 4;
}
}
ValueRange::Inclusive(ref range) => {
while read_head + 3 < len {
let idx0 = indexes[read_head];
let idx1 = indexes[read_head + 1];
let idx2 = indexes[read_head + 2];
let idx3 = indexes[read_head + 3];
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
@@ -166,24 +169,28 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
let val3 = self.get_val(idx3);
if range.contains(&val0) {
indexes[write_head] = idx0;
output.push(Some(val0));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx0,
sort_key: Some(val0),
});
}
if range.contains(&val1) {
indexes[write_head] = idx1;
output.push(Some(val1));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx1,
sort_key: Some(val1),
});
}
if range.contains(&val2) {
indexes[write_head] = idx2;
output.push(Some(val2));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx2,
sort_key: Some(val2),
});
}
if range.contains(&val3) {
indexes[write_head] = idx3;
output.push(Some(val3));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx3,
sort_key: Some(val3),
});
}
read_head += 4;
@@ -191,10 +198,10 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
}
ValueRange::GreaterThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = indexes[read_head];
let idx1 = indexes[read_head + 1];
let idx2 = indexes[read_head + 2];
let idx3 = indexes[read_head + 3];
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
@@ -202,24 +209,28 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
let val3 = self.get_val(idx3);
if val0 > *threshold {
indexes[write_head] = idx0;
output.push(Some(val0));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx0,
sort_key: Some(val0),
});
}
if val1 > *threshold {
indexes[write_head] = idx1;
output.push(Some(val1));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx1,
sort_key: Some(val1),
});
}
if val2 > *threshold {
indexes[write_head] = idx2;
output.push(Some(val2));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx2,
sort_key: Some(val2),
});
}
if val3 > *threshold {
indexes[write_head] = idx3;
output.push(Some(val3));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx3,
sort_key: Some(val3),
});
}
read_head += 4;
@@ -227,10 +238,10 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
}
ValueRange::LessThan(ref threshold, _) => {
while read_head + 3 < len {
let idx0 = indexes[read_head];
let idx1 = indexes[read_head + 1];
let idx2 = indexes[read_head + 2];
let idx3 = indexes[read_head + 3];
let idx0 = input_indexes[read_head];
let idx1 = input_indexes[read_head + 1];
let idx2 = input_indexes[read_head + 2];
let idx3 = input_indexes[read_head + 3];
let val0 = self.get_val(idx0);
let val1 = self.get_val(idx1);
@@ -238,24 +249,28 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
let val3 = self.get_val(idx3);
if val0 < *threshold {
indexes[write_head] = idx0;
output.push(Some(val0));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx0,
sort_key: Some(val0),
});
}
if val1 < *threshold {
indexes[write_head] = idx1;
output.push(Some(val1));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx1,
sort_key: Some(val1),
});
}
if val2 < *threshold {
indexes[write_head] = idx2;
output.push(Some(val2));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx2,
sort_key: Some(val2),
});
}
if val3 < *threshold {
indexes[write_head] = idx3;
output.push(Some(val3));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx3,
sort_key: Some(val3),
});
}
read_head += 4;
@@ -264,7 +279,7 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
}
// Process remaining elements (0 to 3)
while read_head < len {
let idx = indexes[read_head];
let idx = input_indexes[read_head];
let val = self.get_val(idx);
let matches = match value_range {
// 'value_range' is still moved here. This is the outer `value_range`
@@ -274,13 +289,13 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
ValueRange::LessThan(ref t, _) => val < *t,
};
if matches {
indexes[write_head] = idx;
output.push(Some(val));
write_head += 1;
output.push(crate::ComparableDoc {
doc: idx,
sort_key: Some(val),
});
}
read_head += 1;
}
indexes.truncate(write_head);
}
/// Fills an output buffer with the fast field values
@@ -393,11 +408,11 @@ impl<T: PartialOrd + Default> ColumnValues<T> for EmptyColumnValues {
fn get_vals_in_value_range(
&self,
indexes: &mut Vec<u32>,
output: &mut Vec<Option<T>>,
input_indexes: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
let _ = (indexes, output, value_range);
let _ = (input_indexes, output, value_range);
panic!("Internal Error: Called get_vals_in_value_range of empty column.")
}
}
@@ -416,12 +431,12 @@ impl<T: Copy + PartialOrd + Debug + 'static> ColumnValues<T> for Arc<dyn ColumnV
#[inline(always)]
fn get_vals_in_value_range(
&self,
indexes: &mut Vec<u32>,
output: &mut Vec<Option<T>>,
input_indexes: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<T>, crate::DocId>>,
value_range: ValueRange<T>,
) {
self.as_ref()
.get_vals_in_value_range(indexes, output, value_range)
.get_vals_in_value_range(input_indexes, output, value_range)
}
#[inline(always)]

View File

@@ -69,67 +69,68 @@ impl ColumnValues for BitpackedReader {
fn get_vals_in_value_range(
&self,
indexes: &mut Vec<u32>,
output: &mut Vec<Option<u64>>,
input_indexes: &[u32],
output: &mut Vec<crate::ComparableDoc<Option<u64>, crate::DocId>>,
value_range: ValueRange<u64>,
) {
let mut write_head = 0;
match value_range {
ValueRange::All => {
for i in 0..indexes.len() {
let idx = indexes[i];
indexes[write_head] = idx;
output.push(Some(self.get_val(idx)));
write_head += 1;
for &idx in input_indexes {
output.push(crate::ComparableDoc {
doc: idx,
sort_key: Some(self.get_val(idx)),
});
}
}
ValueRange::Inclusive(range) => {
if let Some(transformed_range) =
transform_range_before_linear_transformation(&self.stats, range)
{
for i in 0..indexes.len() {
let doc = indexes[i];
for &doc in input_indexes {
let raw_val = self.get_val(doc);
if transformed_range.contains(&raw_val) {
indexes[write_head] = doc;
output
.push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val));
write_head += 1;
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::GreaterThan(threshold, _) => {
if threshold < self.stats.min_value {
for i in 0..indexes.len() {
let idx = indexes[i];
indexes[write_head] = idx;
output.push(Some(self.get_val(idx)));
write_head += 1;
for &idx in input_indexes {
output.push(crate::ComparableDoc {
doc: idx,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold >= self.stats.max_value {
// All filtered out
} else {
let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get();
for i in 0..indexes.len() {
let doc = indexes[i];
for &doc in input_indexes {
let raw_val = self.get_val(doc);
if raw_val > raw_threshold {
indexes[write_head] = doc;
output
.push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val));
write_head += 1;
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
ValueRange::LessThan(threshold, _) => {
if threshold > self.stats.max_value {
for i in 0..indexes.len() {
let idx = indexes[i];
indexes[write_head] = idx;
output.push(Some(self.get_val(idx)));
write_head += 1;
for &idx in input_indexes {
output.push(crate::ComparableDoc {
doc: idx,
sort_key: Some(self.get_val(idx)),
});
}
} else if threshold <= self.stats.min_value {
// All filtered out
@@ -142,20 +143,20 @@ impl ColumnValues for BitpackedReader {
diff / gcd + 1
};
for i in 0..indexes.len() {
let doc = indexes[i];
for &doc in input_indexes {
let raw_val = self.get_val(doc);
if raw_val < raw_threshold {
indexes[write_head] = doc;
output
.push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val));
write_head += 1;
output.push(crate::ComparableDoc {
doc,
sort_key: Some(
self.stats.min_value + self.stats.gcd.get() * raw_val,
),
});
}
}
}
}
}
indexes.truncate(write_head);
}
fn get_row_ids_for_value_range(
&self,

View File

@@ -29,6 +29,7 @@ mod column;
pub mod column_index;
pub mod column_values;
mod columnar;
mod comparable_doc;
mod dictionary;
mod dynamic_column;
mod iterable;
@@ -45,6 +46,7 @@ pub use columnar::{
CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar,
};
pub use comparable_doc::ComparableDoc;
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};

View File

@@ -96,10 +96,9 @@ mod histogram_collector;
pub use histogram_collector::HistogramCollector;
mod multi_collector;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
pub use columnar::ComparableDoc;
mod top_collector;
pub use self::top_collector::ComparableDoc;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_score_collector;
pub use self::top_score_collector::{TopDocs, TopNComputer};

View File

@@ -39,6 +39,7 @@ pub(crate) mod tests {
use crate::collector::sort_key::{
SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::top_score_collector::compare_for_top_k;
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, QueryParser};
@@ -594,7 +595,7 @@ pub(crate) mod tests {
.map(|(sort_key, doc)| ComparableDoc { sort_key, doc })
.collect();
comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r));
comparable_docs.sort_by(|l, r| compare_for_top_k(&comparator, l, r));
let expected_results = comparable_docs
.into_iter()

View File

@@ -1,6 +1,6 @@
use std::cmp::Ordering;
use columnar::{MonotonicallyMappableToU64, ValueRange};
use columnar::{ComparableDoc, MonotonicallyMappableToU64, ValueRange};
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
@@ -619,6 +619,7 @@ where
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
type SegmentComparator = TComparator;
type Buffer = TSegmentSortKeyComputer::Buffer;
fn segment_comparator(&self) -> Self::SegmentComparator {
self.comparator.clone()
@@ -630,11 +631,13 @@ where
fn segment_sort_keys(
&mut self,
docs: &[DocId],
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
) {
self.segment_sort_key_computer
.segment_sort_keys(docs, filter)
.segment_sort_keys(input_docs, output, buffer, filter)
}
#[inline(always)]

View File

@@ -4,7 +4,7 @@ use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentCompu
use crate::collector::sort_key::{
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};
@@ -39,15 +39,21 @@ trait ErasedSegmentSortKeyComputer: Send + Sync {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
fn segment_sort_keys(
&mut self,
docs: &[DocId],
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)>;
);
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}
struct ErasedSegmentSortKeyComputerWrapper<C, F> {
struct ErasedSegmentSortKeyComputerWrapper<C, F>
where
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
inner: C,
converter: F,
buffer: C::Buffer,
}
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
@@ -61,10 +67,12 @@ where
fn segment_sort_keys(
&mut self,
docs: &[DocId],
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)> {
self.inner.segment_sort_keys(docs, filter)
) {
self.inner
.segment_sort_keys(input_docs, output, &mut self.buffer, filter)
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
@@ -85,9 +93,10 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Option<u64>, DocId>>,
_filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)> {
) {
unimplemented!("Batch computation not supported for score sorting")
}
@@ -134,6 +143,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<String>| {
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::U64 => {
@@ -144,6 +154,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<u64>| {
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::I64 => {
@@ -154,6 +165,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<i64>| {
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::F64 => {
@@ -164,6 +176,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<f64>| {
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::Bool => {
@@ -174,6 +187,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<bool>| {
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
ColumnType::DateTime => {
@@ -184,6 +198,7 @@ impl SortKeyComputer for SortByErasedType {
converter: |val: Option<DateTime>| {
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
},
buffer: Default::default(),
})
}
column_type => {
@@ -212,6 +227,7 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
type SortKey = OwnedValue;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
@@ -220,10 +236,12 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
fn segment_sort_keys(
&mut self,
docs: &[DocId],
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.inner.segment_sort_keys(docs, filter)
) {
self.inner.segment_sort_keys(input_docs, output, filter)
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {

View File

@@ -1,7 +1,7 @@
use columnar::ValueRange;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
@@ -69,6 +69,7 @@ impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
type SortKey = Score;
type SegmentSortKey = Score;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
@@ -77,9 +78,11 @@ impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
) {
unimplemented!("Batch computation not supported for score sorting")
}

View File

@@ -4,7 +4,7 @@ use columnar::{Column, ValueRange};
use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::{DocId, Score, SegmentReader};
@@ -72,9 +72,6 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
Ok(SortByFastValueSegmentSortKeyComputer {
sort_column,
typ: PhantomData,
buffer: Vec::new(),
fetch_buffer: Vec::new(),
doc_buffer: Vec::new(),
})
}
}
@@ -82,15 +79,13 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
pub struct SortByFastValueSegmentSortKeyComputer<T> {
sort_column: Column<u64>,
typ: PhantomData<T>,
buffer: Vec<(DocId, Option<u64>)>,
fetch_buffer: Vec<Option<u64>>,
doc_buffer: Vec<DocId>,
}
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
type SortKey = Option<T>;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
@@ -99,24 +94,14 @@ impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu
fn segment_sort_keys(
&mut self,
docs: &[DocId],
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.doc_buffer.clear();
self.doc_buffer.extend_from_slice(docs);
self.fetch_buffer.clear();
) {
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
self.sort_column.first_vals_in_value_range(
&mut self.doc_buffer,
&mut self.fetch_buffer,
u64_filter,
);
self.buffer.clear();
for (&doc, &val) in self.doc_buffer.iter().zip(self.fetch_buffer.iter()) {
self.buffer.push((doc, val));
}
&mut self.buffer
self.sort_column
.first_vals_in_value_range(input_docs, output, u64_filter);
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
@@ -154,10 +139,16 @@ mod tests {
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let docs = vec![0, 1, 2];
let output = computer.segment_sort_keys(&docs, ValueRange::All);
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
assert_eq!(output, &[(0, Some(10)), (1, Some(20)), (2, None)]);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(10), Some(20), None]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
}
#[test]
@@ -184,15 +175,20 @@ mod tests {
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let docs = vec![0, 1, 2];
let output = computer.segment_sort_keys(
&docs,
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(
&mut docs,
&mut output,
&mut buffer,
ValueRange::GreaterThan(Some(15u64), false /* inclusive */),
);
// Should contain only the document with value 20.
// Doc 0 (10) < 15
// Doc 2 (None) < 15
assert_eq!(output, &[(1, Some(20))]);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(20)]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
}
}

View File

@@ -4,7 +4,7 @@ use crate::collector::sort_key::sort_key_computer::{
convert_optional_u64_range_to_u64_range, range_contains_none,
};
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
@@ -41,26 +41,19 @@ impl SortKeyComputer for SortByString {
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
Ok(ByStringColumnSegmentSortKeyComputer {
str_column_opt,
buffer: Vec::new(),
fetch_buffer: Vec::new(),
doc_buffer: Vec::new(),
})
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
}
}
pub struct ByStringColumnSegmentSortKeyComputer {
str_column_opt: Option<StrColumn>,
buffer: Vec<(DocId, Option<TermOrdinal>)>,
fetch_buffer: Vec<Option<TermOrdinal>>,
doc_buffer: Vec<DocId>,
}
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
type Buffer = ();
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
@@ -70,33 +63,24 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
fn segment_sort_keys(
&mut self,
docs: &[DocId],
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.doc_buffer.clear();
self.doc_buffer.extend_from_slice(docs);
self.fetch_buffer.clear();
) {
if let Some(str_column) = &self.str_column_opt {
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
str_column.ords().first_vals_in_value_range(
&mut self.doc_buffer,
&mut self.fetch_buffer,
u64_filter,
);
str_column
.ords()
.first_vals_in_value_range(input_docs, output, u64_filter);
} else if range_contains_none(&filter) {
for _ in 0..docs.len() {
self.fetch_buffer.push(None);
for &doc in input_docs {
output.push(ComparableDoc {
doc,
sort_key: None,
});
}
} else {
self.doc_buffer.clear();
}
self.buffer.clear();
for (&doc, &val) in self.doc_buffer.iter().zip(self.fetch_buffer.iter()) {
self.buffer.push((doc, val));
}
&mut self.buffer
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
@@ -143,13 +127,16 @@ mod tests {
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let docs = vec![0, 1, 2];
let output = computer.segment_sort_keys(&docs, ValueRange::All);
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
let mut buffer = ();
computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All);
// We expect ordinals.
// "a" -> 0
// "c" -> 1
assert_eq!(output, &[(0, Some(0)), (1, Some(1)), (2, None)]);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(0), Some(1), None]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[0, 1, 2]);
}
#[test]
@@ -176,16 +163,23 @@ mod tests {
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let docs = vec![0, 1, 2];
let mut docs = vec![0, 1, 2];
let mut output = Vec::new();
// Filter: > "b". "a" is 0, "c" is 1.
// We want > "a" (ord 0). So we filter > ord 0.
// 0 is "a", 1 is "c".
let output = computer.segment_sort_keys(
&docs,
let mut buffer = ();
computer.segment_sort_keys(
&mut docs,
&mut output,
&mut buffer,
ValueRange::GreaterThan(Some(0), false /* inclusive */),
);
// Should contain only the document with value "c" (ord 1).
assert_eq!(output, &[(1, Some(1))]);
assert_eq!(
output.iter().map(|c| c.sort_key).collect::<Vec<_>>(),
&[Some(1)]
);
assert_eq!(output.iter().map(|c| c.doc).collect::<Vec<_>>(), &[1]);
}
}

View File

@@ -28,6 +28,9 @@ pub trait SegmentSortKeyComputer: 'static {
/// Comparator type.
type SegmentComparator: Comparator<Self::SegmentSortKey> + Clone + 'static;
/// Buffer type used for scratch space.
type Buffer: Default + Send + Sync + 'static;
/// Returns the segment sort key comparator.
fn segment_comparator(&self) -> Self::SegmentComparator {
Self::SegmentComparator::default()
@@ -38,13 +41,15 @@ pub trait SegmentSortKeyComputer: 'static {
/// Computes the sort keys for a batch of documents.
///
/// The computed sort keys are stored in an internal buffer and returned as a slice.
/// Subsequent calls to this method may reuse and overwrite the internal buffer.
/// The computed sort keys and document IDs are pushed into the `output` vector.
/// The `buffer` is used for scratch space.
fn segment_sort_keys(
&mut self,
docs: &[DocId],
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)>;
);
/// Computes the sort key and pushes the document in a TopN Computer.
///
@@ -54,7 +59,7 @@ pub trait SegmentSortKeyComputer: 'static {
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
let sort_key = self.segment_sort_key(doc, score);
top_n_computer.push(sort_key, doc);
@@ -63,7 +68,7 @@ pub trait SegmentSortKeyComputer: 'static {
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
// The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
// should always be able to `reserve` space for the entire block.
@@ -76,24 +81,8 @@ pub trait SegmentSortKeyComputer: 'static {
ValueRange::All
};
let sort_keys = self.segment_sort_keys(docs, value_range);
if let Some(threshold) = &top_n_computer.threshold {
let threshold = threshold.clone();
for (doc, sort_key) in sort_keys.drain(..) {
let cmp = comparator.compare(&sort_key, &threshold);
if cmp == Ordering::Greater {
// We validated at the top of the method that we have capacity.
top_n_computer.append_doc_unchecked(doc, sort_key);
}
}
} else {
// Eagerly push, without a threshold to compare to.
for (doc, sort_key) in sort_keys.drain(..) {
// We validated at the top of the method that we have capacity.
top_n_computer.append_doc_unchecked(doc, sort_key);
}
}
let (buffer, scratch) = top_n_computer.buffer_and_scratch();
self.segment_sort_keys(docs, buffer, scratch, value_range);
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
@@ -213,8 +202,6 @@ where
Ok(ChainSegmentSortKeyComputer {
head: self.0.segment_sort_key_computer(segment_reader)?,
tail: self.1.segment_sort_key_computer(segment_reader)?,
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
})
}
@@ -240,8 +227,28 @@ where
{
head: Head,
tail: Tail,
head_key_buffer: Vec<Head::SegmentSortKey>,
doc_buffer: Vec<DocId>,
}
pub struct ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey> {
pub head: HeadBuffer,
pub tail: TailBuffer,
pub head_output: Vec<ComparableDoc<HeadKey, DocId>>,
pub tail_output: Vec<ComparableDoc<TailKey, DocId>>,
pub tail_input_docs: Vec<DocId>,
}
impl<HeadBuffer: Default, TailBuffer: Default, HeadKey, TailKey> Default
for ChainBuffer<HeadBuffer, TailBuffer, HeadKey, TailKey>
{
fn default() -> Self {
ChainBuffer {
head: HeadBuffer::default(),
tail: TailBuffer::default(),
head_output: Vec::new(),
tail_output: Vec::new(),
tail_input_docs: Vec::new(),
}
}
}
impl<Head, Tail> ChainSegmentSortKeyComputer<Head, Tail>
@@ -289,6 +296,9 @@ where
type SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator);
type Buffer =
ChainBuffer<Head::Buffer, Tail::Buffer, Head::SegmentSortKey, Tail::SegmentSortKey>;
fn segment_comparator(&self) -> Self::SegmentComparator {
(
self.head.segment_comparator(),
@@ -313,9 +323,11 @@ where
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
) {
unimplemented!("The head and the tail are accessed independently.");
}
@@ -324,7 +336,7 @@ where
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
let sort_key: Self::SegmentSortKey;
if let Some(threshold) = &top_n_computer.threshold {
@@ -342,39 +354,52 @@ where
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
// The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
// should always be able to `reserve` space for the entire block.
top_n_computer.reserve(docs.len());
let mut scratch = std::mem::take(&mut top_n_computer.scratch);
if let Some(threshold) = &top_n_computer.threshold {
let (head_threshold, tail_threshold) = threshold.clone();
let head_cmp = self.head.segment_comparator();
let tail_cmp = self.tail.segment_comparator();
let head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone());
let head_keys = self.head.segment_sort_keys(docs, head_filter);
self.doc_buffer.clear();
self.head_key_buffer.clear();
for (doc, head_key) in head_keys.drain(..) {
let cmp = head_cmp.compare(&head_key, &head_threshold);
if cmp != Ordering::Less {
self.doc_buffer.push(doc);
self.head_key_buffer.push(head_key);
}
}
scratch.head_output.clear();
self.head.segment_sort_keys(
docs,
&mut scratch.head_output,
&mut scratch.head,
head_filter,
);
if !self.doc_buffer.is_empty() {
let tail_keys = self
.tail
.segment_sort_keys(&self.doc_buffer, ValueRange::All);
for ((head_key, tail_key), &doc) in self
.head_key_buffer
if !scratch.head_output.is_empty() {
scratch.tail_output.clear();
scratch.tail_input_docs.clear();
for cd in &scratch.head_output {
scratch.tail_input_docs.push(cd.doc);
}
self.tail.segment_sort_keys(
&scratch.tail_input_docs,
&mut scratch.tail_output,
&mut scratch.tail,
ValueRange::All,
);
for (head_doc, tail_doc) in scratch
.head_output
.drain(..)
.zip(tail_keys.drain(..).map(|(_, k)| k))
.zip(self.doc_buffer.iter())
.zip(scratch.tail_output.drain(..))
{
debug_assert_eq!(head_doc.doc, tail_doc.doc);
let doc = head_doc.doc;
let head_key = head_doc.sort_key;
let tail_key = tail_doc.sort_key;
let head_ord = head_cmp.compare(&head_key, &head_threshold);
let ord = if head_ord == Ordering::Equal {
tail_cmp.compare(&tail_key, &tail_threshold)
@@ -382,19 +407,60 @@ where
head_ord
};
if ord == Ordering::Greater {
top_n_computer.append_doc_unchecked(doc, (head_key, tail_key));
push_assuming_capacity(
ComparableDoc {
sort_key: (head_key, tail_key),
doc,
},
top_n_computer.buffer(),
);
}
}
}
} else {
// Eagerly push, without a threshold to compare to.
let head_keys = self.head.segment_sort_keys(docs, ValueRange::All);
let tail_keys = self.tail.segment_sort_keys(docs, ValueRange::All);
for ((doc, head_key), (_, tail_key)) in head_keys.drain(..).zip(tail_keys.drain(..)) {
scratch.head_output.clear();
self.head.segment_sort_keys(
docs,
&mut scratch.head_output,
&mut scratch.head,
ValueRange::All,
);
scratch.tail_output.clear();
scratch.tail_input_docs.clear();
for cd in &scratch.head_output {
scratch.tail_input_docs.push(cd.doc);
}
self.tail.segment_sort_keys(
&scratch.tail_input_docs,
&mut scratch.tail_output,
&mut scratch.tail,
ValueRange::All,
);
for (head_doc, tail_doc) in scratch
.head_output
.drain(..)
.zip(scratch.tail_output.drain(..))
{
debug_assert_eq!(head_doc.doc, tail_doc.doc);
let doc = head_doc.doc;
let head_key = head_doc.sort_key;
let tail_key = tail_doc.sort_key;
// We validated at the top of the method that we have capacity.
top_n_computer.append_doc_unchecked(doc, (head_key, tail_key));
push_assuming_capacity(
ComparableDoc {
sort_key: (head_key, tail_key),
doc,
},
top_n_computer.buffer(),
);
}
}
top_n_computer.scratch = scratch;
}
#[inline(always)]
@@ -430,6 +496,7 @@ where
type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey;
type SegmentComparator = T::SegmentComparator;
type Buffer = T::Buffer;
fn segment_comparator(&self) -> Self::SegmentComparator {
self.sort_key_computer.segment_comparator()
@@ -441,11 +508,13 @@ where
fn segment_sort_keys(
&mut self,
docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
buffer: &mut Self::Buffer,
filter: ValueRange<Self::SegmentSortKey>,
) {
self.sort_key_computer
.segment_sort_keys(docs, ValueRange::All)
.segment_sort_keys(input_docs, output, buffer, filter)
}
#[inline(always)]
@@ -453,7 +522,7 @@ where
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
self.sort_key_computer
.compute_sort_key_and_collect(doc, score, top_n_computer);
@@ -462,7 +531,7 @@ where
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C, Self::Buffer>,
) {
self.sort_key_computer
.compute_sort_keys_and_collect(docs, top_n_computer);
@@ -521,11 +590,7 @@ where
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer2,
tail: sort_key_computer3,
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
map,
})
@@ -589,14 +654,8 @@ where
tail: ChainSegmentSortKeyComputer {
head: sort_key_computer3,
tail: sort_key_computer4,
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
head_key_buffer: Vec::new(),
doc_buffer: Vec::new(),
},
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
(sort_key1, sort_key2, sort_key3, sort_key4)
@@ -620,9 +679,11 @@ where
}
}
use std::marker::PhantomData;
pub struct FuncSegmentSortKeyComputer<F, TSortKey> {
func: F,
buffer: Vec<(DocId, TSortKey)>,
_phantom: PhantomData<TSortKey>,
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
@@ -638,7 +699,7 @@ where
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok(FuncSegmentSortKeyComputer {
func: (self)(segment_reader),
buffer: Vec::new(),
_phantom: PhantomData,
})
}
}
@@ -651,6 +712,7 @@ where
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
type Buffer = ();
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self.func)(doc)
@@ -658,15 +720,17 @@ where
fn segment_sort_keys(
&mut self,
docs: &[DocId],
input_docs: &[DocId],
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.buffer.clear();
self.buffer.reserve(docs.len());
for &doc in docs {
self.buffer.push((doc, (self.func)(doc)));
) {
for &doc in input_docs {
output.push(ComparableDoc {
sort_key: (self.func)(doc),
doc,
});
}
&mut self.buffer
}
/// Convert a segment level score into the global level score.

View File

@@ -99,7 +99,12 @@ where
TSegmentSortKeyComputer: SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
{
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
pub(crate) topn_computer: TopNComputer<
TSegmentSortKeyComputer::SegmentSortKey,
DocId,
C,
TSegmentSortKeyComputer::Buffer,
>,
pub(crate) segment_ord: u32,
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
}

View File

@@ -11,8 +11,7 @@ use crate::collector::sort_key::{
SortByStaticFastValue, SortByString,
};
use crate::collector::sort_key_top_collector::TopBySortKeyCollector;
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastValue;
use crate::{DocAddress, DocId, Order, Score, SegmentReader};
@@ -482,6 +481,7 @@ where
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
type Buffer = ();
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey {
(self.sort_key_fn)(doc, score)
@@ -489,9 +489,11 @@ where
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_input_docs: &[DocId],
_output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
_buffer: &mut Self::Buffer,
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
) {
unimplemented!("Batch computation is not supported for tweak score.")
}
@@ -518,12 +520,14 @@ where
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
#[derive(Serialize, Deserialize)]
#[serde(from = "TopNComputerDeser<Score, D, C>")]
pub struct TopNComputer<Score, D, C> {
pub struct TopNComputer<Score, D, C, Buffer = ()> {
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
buffer: Vec<ComparableDoc<Score, D>>,
top_n: usize,
pub(crate) threshold: Option<Score>,
comparator: C,
#[serde(skip)]
pub scratch: Buffer,
}
// Intermediate struct for TopNComputer for deserialization, to keep vec capacity
@@ -535,7 +539,9 @@ struct TopNComputerDeser<Score, D, C> {
comparator: C,
}
impl<Score, D, C> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C> {
impl<Score, D, C, Buffer> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D, C, Buffer>
where Buffer: Default
{
fn from(mut value: TopNComputerDeser<Score, D, C>) -> Self {
let expected_cap = value.top_n.max(1) * 2;
let current_cap = value.buffer.capacity();
@@ -550,12 +556,15 @@ impl<Score, D, C> From<TopNComputerDeser<Score, D, C>> for TopNComputer<Score, D
top_n: value.top_n,
threshold: value.threshold,
comparator: value.comparator,
scratch: Buffer::default(),
}
}
}
impl<Score: std::fmt::Debug, D, C> std::fmt::Debug for TopNComputer<Score, D, C>
where C: Comparator<Score>
impl<Score: std::fmt::Debug, D, C, Buffer> std::fmt::Debug for TopNComputer<Score, D, C, Buffer>
where
C: Comparator<Score>,
Buffer: std::fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result {
f.debug_struct("TopNComputer")
@@ -563,12 +572,13 @@ where C: Comparator<Score>
.field("top_n", &self.top_n)
.field("current_threshold", &self.threshold)
.field("comparator", &self.comparator)
.field("scratch", &self.scratch)
.finish()
}
}
// Custom clone to keep capacity
impl<Score: Clone, D: Clone, C: Clone> Clone for TopNComputer<Score, D, C> {
impl<Score: Clone, D: Clone, C: Clone, Buffer: Clone> Clone for TopNComputer<Score, D, C, Buffer> {
fn clone(&self) -> Self {
let mut buffer_clone = Vec::with_capacity(self.buffer.capacity());
buffer_clone.extend(self.buffer.iter().cloned());
@@ -577,11 +587,12 @@ impl<Score: Clone, D: Clone, C: Clone> Clone for TopNComputer<Score, D, C> {
top_n: self.top_n,
threshold: self.threshold.clone(),
comparator: self.comparator.clone(),
scratch: self.scratch.clone(),
}
}
}
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator>
impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator, ()>
where
D: Ord,
TSortKey: Clone,
@@ -595,7 +606,7 @@ where
}
#[inline(always)]
fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
pub fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
c: &C,
lhs: &ComparableDoc<TSortKey, D>,
rhs: &ComparableDoc<TSortKey, D>,
@@ -606,11 +617,12 @@ fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
// sort by doc id
}
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
impl<TSortKey, D, C, Buffer> TopNComputer<TSortKey, D, C, Buffer>
where
D: Ord,
TSortKey: Clone,
C: Comparator<TSortKey>,
Buffer: Default,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `(top_n.max(1) * 2) +
@@ -624,6 +636,7 @@ where
top_n,
threshold: None,
comparator,
scratch: Buffer::default(),
}
}
@@ -649,15 +662,6 @@ where
pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) {
self.reserve(1);
// This cannot panic, because we've reserved room for one element.
self.append_doc_unchecked(doc, sort_key);
}
// Append a document to the top n. `reserve` must already have been called to ensure that there
// is capacity, or this method will panic.
//
// At this point, we need to have established that the doc is above the threshold.
#[inline(always)]
pub(crate) fn append_doc_unchecked(&mut self, doc: D, sort_key: TSortKey) {
let comparable_doc = ComparableDoc { doc, sort_key };
push_assuming_capacity(comparable_doc, &mut self.buffer);
}
@@ -672,6 +676,16 @@ where
}
}
pub(crate) fn buffer(&mut self) -> &mut Vec<ComparableDoc<TSortKey, D>> {
&mut self.buffer
}
pub(crate) fn buffer_and_scratch(
&mut self,
) -> (&mut Vec<ComparableDoc<TSortKey, D>>, &mut Buffer) {
(&mut self.buffer, &mut self.scratch)
}
#[inline(never)]
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
@@ -729,8 +743,7 @@ mod tests {
use super::{TopDocs, TopNComputer};
use crate::collector::sort_key::{ComparatorEnum, NaturalComparator, ReverseComparator};
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{Collector, DocSetCollector};
use crate::collector::{Collector, ComparableDoc, DocSetCollector};
use crate::query::{AllQuery, Query, QueryParser};
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
use crate::time::format_description::well_known::Rfc3339;
@@ -1760,7 +1773,8 @@ mod tests {
#[test]
fn test_top_n_computer_not_at_capacity() {
let mut top_n_computer = TopNComputer::new_with_comparator(4, NaturalComparator);
let mut top_n_computer: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(4, NaturalComparator);
top_n_computer.append_doc(1, 0.8);
top_n_computer.append_doc(3, 0.2);
top_n_computer.append_doc(5, 0.3);
@@ -1785,7 +1799,8 @@ mod tests {
#[test]
fn test_top_n_computer_at_capacity() {
let mut top_collector = TopNComputer::new_with_comparator(4, NaturalComparator);
let mut top_collector: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(4, NaturalComparator);
top_collector.append_doc(1, 0.8);
top_collector.append_doc(3, 0.2);
top_collector.append_doc(5, 0.3);
@@ -1822,12 +1837,14 @@ mod tests {
let doc_ids_collection = [4, 5, 6];
let score = 3.3f32;
let mut top_collector_limit_2 = TopNComputer::new_with_comparator(2, NaturalComparator);
let mut top_collector_limit_2: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(2, NaturalComparator);
for id in &doc_ids_collection {
top_collector_limit_2.append_doc(*id, score);
}
let mut top_collector_limit_3 = TopNComputer::new_with_comparator(3, NaturalComparator);
let mut top_collector_limit_3: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(3, NaturalComparator);
for id in &doc_ids_collection {
top_collector_limit_3.append_doc(*id, score);
}
@@ -1848,15 +1865,16 @@ mod bench {
#[bench]
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
let mut top_collector = TopNComputer::new_with_comparator(100, NaturalComparator);
let mut top_collector: TopNComputer<f32, u32, _, ()> =
TopNComputer::new_with_comparator(100, NaturalComparator);
for i in 0..100 {
top_collector.append_doc(i, 0.8);
top_collector.append_doc(i as u32, 0.8);
}
b.iter(|| {
for i in 0..100 {
top_collector.append_doc(i, 0.8);
top_collector.append_doc(i as u32, 0.8);
}
});
}