Implement collect_block for lazy scorers.

This commit is contained in:
Stu Hood
2025-11-01 15:46:46 -07:00
parent 03f09a2b5b
commit 9ec5750c25
3 changed files with 99 additions and 13 deletions

View File

@@ -2,7 +2,10 @@ use std::cmp::Ordering;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::collector::top_score_collector::push_assuming_capacity;
use crate::collector::{
default_collect_segment_impl, ComparableDoc, SegmentCollector as _, TopNComputer,
};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
@@ -45,6 +48,38 @@ pub trait SegmentSortKeyComputer: 'static {
top_n_computer.push(sort_key, doc);
}
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
docs: &[DocId],
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
// The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we
// should always be able to `reserve` space for the entire block.
top_n_computer.reserve(docs.len());
if let Some(threshold) = &top_n_computer.threshold {
// TODO: Would need to split the borrow of the TopNComputer to avoid cloning the
// threshold here.
let threshold = threshold.clone();
// Eagerly push, with a threshold to compare to.
for &doc in docs {
let sort_key = self.segment_sort_key(doc, 0.0);
let cmp = self.compare_segment_sort_key(&sort_key, &threshold);
if cmp == Ordering::Greater {
// We validated at the top of the method that we have capacity.
top_n_computer.append_doc_unchecked(doc, sort_key);
}
}
} else {
// Eagerly push, without a threshold to compare to.
for &doc in docs {
let sort_key = self.segment_sort_key(doc, 0.0);
// We validated at the top of the method that we have capacity.
top_n_computer.append_doc_unchecked(doc, sort_key);
}
}
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
@@ -78,6 +113,26 @@ pub trait SegmentSortKeyComputer: 'static {
}
}
/// Similar to `accept_sort_key_lazy`, but pushes results directly into the given buffer. Does
/// not support scoring.
///
/// The buffer must have at least enough capacity for `docs` matches, or this method will
/// panic.
fn accept_sort_key_block_lazy(
&mut self,
docs: &[DocId],
threshold: &Self::SegmentSortKey,
output: &mut Vec<ComparableDoc<Self::SegmentSortKey, DocId>>,
) {
for &doc in docs {
let sort_key = self.segment_sort_key(doc, 0.0);
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
if cmp != Ordering::Less {
push_assuming_capacity(ComparableDoc { sort_key, doc }, output);
}
}
}
/// Convert a segment level sort key into the global sort key.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
}
@@ -233,6 +288,14 @@ where
top_n_computer.append_doc(doc, sort_key);
}
fn compute_sort_keys_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
_docs: &[DocId],
_top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
todo!("Override for laziness.");
}
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.0.segment_sort_key(doc, score);

View File

@@ -120,6 +120,11 @@ where
);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.segment_sort_key_computer
.compute_sort_keys_and_collect(docs, &mut self.topn_computer);
}
fn harvest(self) -> Self::Fruit {
let segment_ord = self.segment_ord;
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self

View File

@@ -592,9 +592,12 @@ where
C: Comparator<TSortKey>,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
/// Internally it will allocate a buffer of size `(top_n.max(1) * 2) +
/// COLLECT_BLOCK_BUFFER_LEN`.
pub fn new_with_comparator(top_n: usize, comparator: C) -> Self {
let vec_cap = top_n.max(1) * 2;
// We ensure that there is always enough space to include an entire block in the buffer if
// need be, so that `push_block_lazy` can avoid checking capacity inside its loop.
let vec_cap = (top_n.max(1) * 2) + crate::COLLECT_BLOCK_BUFFER_LEN;
TopNComputer {
buffer: Vec::with_capacity(vec_cap),
top_n,
@@ -623,16 +626,31 @@ where
// At this point, we need to have established that the doc is above the threshold.
#[inline(always)]
pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) {
if self.buffer.len() == self.buffer.capacity() {
let median = self.truncate_top_n();
self.threshold = Some(median);
}
// This cannot panic, because we truncate_median will at least remove one element, since
// the min capacity is 2.
self.reserve(1);
// This cannot panic, because we've reserved room for one element.
self.append_doc_unchecked(doc, sort_key);
}
// Append a document to the top n. `reserve` must already have been called to ensure that there
// is capacity, or this method will panic.
//
// At this point, we need to have established that the doc is above the threshold.
#[inline(always)]
pub(crate) fn append_doc_unchecked(&mut self, doc: D, sort_key: TSortKey) {
let comparable_doc = ComparableDoc { doc, sort_key };
push_assuming_capacity(comparable_doc, &mut self.buffer);
}
// Ensure that there is capacity to push `additional` more elements without resizing.
#[inline(always)]
pub(crate) fn reserve(&mut self, additional: usize) {
if self.buffer.len() + additional > self.buffer.capacity() {
let median = self.truncate_top_n();
debug_assert!(self.buffer.len() + additional <= self.buffer.capacity());
self.threshold = Some(median);
}
}
#[inline(never)]
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
@@ -672,7 +690,7 @@ where
//
// Panics if there is not enough capacity to add an element.
#[inline(always)]
fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
pub fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
let prev_len = buf.len();
assert!(prev_len < buf.capacity());
// This is mimicking the current (non-stabilized) implementation in std.
@@ -1398,11 +1416,11 @@ mod tests {
#[test]
fn test_top_field_collect_string_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..256_usize,
offset in 0..256_usize,
limit in 1..32_usize,
offset in 0..32_usize,
segments_terms in
proptest::collection::vec(
proptest::collection::vec(0..32_u8, 1..32_usize),
proptest::collection::vec(0..64_u8, 1..256_usize),
0..8_usize,
)
) {