Compare commits

...

1 Commits

Author SHA1 Message Date
Paul Masurel
b8636e707c Seek into the danger zoner 2026-01-03 18:00:05 +01:00
9 changed files with 166 additions and 56 deletions

View File

@@ -1,6 +1,7 @@
use std::borrow::{Borrow, BorrowMut};
use crate::fastfield::AliveBitSet;
use crate::query::SeekAntiCallToken;
use crate::DocId;
/// Sentinel value returned when a [`DocSet`] has been entirely consumed.
@@ -14,6 +15,15 @@ pub const TERMINATED: DocId = i32::MAX as u32;
/// exactly this size as long as we can fill the buffer.
pub const COLLECT_BLOCK_BUFFER_LEN: usize = 64;
/// Outcome of a call to [`DocSet::seek_into_the_danger_zone`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum SeekDangerResult {
    /// The seek operation was successful: the target document belongs to the docset.
    Success,
    /// The seek operation was unsuccessful: the target document was not found.
    ///
    /// The payload is a doc id reported by the docset for the caller to resume from.
    /// The default implementation reports the doc the docset ended up on, and the
    /// intersection treats the value as a lower bound for its next seek.
    NotFound(DocId),
}
/// Represents an iterable set of sorted doc ids.
pub trait DocSet: Send {
/// Goes to the next element.
@@ -70,12 +80,20 @@ pub trait DocSet: Send {
///
/// # Warning
/// This is an advanced API used by intersection. The API contract is tricky, avoid using it.
///
/// The default implementation calls [`DocSet::seek`] only when the current doc is smaller
/// than `target`, then compares [`DocSet::doc`] with `target`. The token argument is left
/// unnamed because this default implementation does not use it.
///
/// Returns [`SeekDangerResult::Success`] when the docset is positioned on `target`, and
/// [`SeekDangerResult::NotFound`] carrying the current doc id otherwise.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
_: SeekAntiCallToken,
) -> SeekDangerResult {
let current_doc = self.doc();
// Seek only when the docset is still behind `target`.
if current_doc < target {
self.seek(target);
}
self.doc() == target
if self.doc() == target {
SeekDangerResult::Success
} else {
// Report the doc the docset is currently positioned on.
SeekDangerResult::NotFound(self.doc())
}
}
/// Fills a given mutable buffer with the next doc ids from the
@@ -175,8 +193,12 @@ impl DocSet for &mut dyn DocSet {
(**self).seek(target)
}
/// Forwards the call to the underlying `dyn DocSet`, passing `target` and the token
/// through unchanged and returning its result as is.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
(**self).seek_into_the_danger_zone(target)
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
token: SeekAntiCallToken,
) -> SeekDangerResult {
(**self).seek_into_the_danger_zone(target, token)
}
fn doc(&self) -> u32 {
@@ -211,9 +233,13 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
unboxed.seek(target)
}
/// Forwards the call to the boxed docset, passing `target` and the token through
/// unchanged and returning its result as is.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
token: SeekAntiCallToken,
) -> SeekDangerResult {
// Borrow the boxed value mutably so the call is dispatched on `TDocSet` itself.
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.seek_into_the_danger_zone(target)
unboxed.seek_into_the_danger_zone(target, token)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {

View File

@@ -1,8 +1,8 @@
use std::fmt;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::docset::{SeekDangerResult, COLLECT_BLOCK_BUFFER_LEN};
use crate::fastfield::AliveBitSet;
use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
use crate::query::{EnableScoring, Explanation, Query, Scorer, SeekAntiCallToken, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, Term};
/// `BoostQuery` is a wrapper over a query used to boost its score.
@@ -104,8 +104,14 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
fn seek(&mut self, target: DocId) -> DocId {
self.underlying.seek(target)
}
/// Forwards the call to the wrapped `underlying` scorer, passing `target` and the token
/// through unchanged and returning its result as is.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
self.underlying.seek_into_the_danger_zone(target)
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
anti_call_token: SeekAntiCallToken,
) -> SeekDangerResult {
self.underlying
.seek_into_the_danger_zone(target, anti_call_token)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {

View File

@@ -1,8 +1,9 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::docset::SeekDangerResult;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::{ScoreCombiner, Scorer};
use crate::query::{ScoreCombiner, Scorer, SeekAntiCallToken};
use crate::{DocId, DocSet, Score, TERMINATED};
/// `Disjunction` is responsible for merging `DocSet` from multiple
@@ -67,10 +68,16 @@ impl<T: Scorer> DocSet for ScorerWrapper<T> {
self.current_doc = doc_id;
doc_id
}
/// Forwards the call to the wrapped `scorer`, then refreshes the cached `current_doc`
/// from the scorer before returning the scorer's result unchanged.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
let found = self.scorer.seek_into_the_danger_zone(target);
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
anti_call_token: SeekAntiCallToken,
) -> SeekDangerResult {
let result = self
.scorer
.seek_into_the_danger_zone(target, anti_call_token);
// Keep the cached doc id in sync with wherever the scorer ended up, hit or miss.
self.current_doc = self.scorer.doc();
found
result
}
fn doc(&self) -> DocId {

View File

@@ -1,9 +1,16 @@
use super::size_hint::estimate_intersection;
use crate::docset::{DocSet, TERMINATED};
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::query::term_query::TermScorer;
use crate::query::{EmptyScorer, Scorer};
use crate::{DocId, Score};
/// This is a token used to prevent calls to `seek_into_the_danger_zone`
/// outside of the intersection.
///
/// Its only field is private, so a value can only be built inside this module.
/// Other code can merely forward the token it has been handed.
///
/// This is zero-cost: the type is zero-sized.
#[derive(Debug, Clone, Copy)]
pub struct SeekAntiCallToken(());
/// Returns the intersection scorer.
///
/// The score associated with the documents is the sum of the
@@ -113,15 +120,19 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
return TERMINATED;
}
const ANTI_CALL_TOKEN: SeekAntiCallToken = SeekAntiCallToken(());
loop {
// In the first part we look for a document in the intersection
// of the two rarest `DocSet` in the intersection.
loop {
if right.seek_into_the_danger_zone(candidate) {
break;
}
let right_doc = right.doc();
let right_doc = match right.seek_into_the_danger_zone(candidate, ANTI_CALL_TOKEN) {
SeekDangerResult::Success => {
break;
}
SeekDangerResult::NotFound(seek_lower_bound) => seek_lower_bound,
};
// TODO: Think about which value would make sense here
// It depends on the DocSet implementation, when a seek would outweigh an advance.
if right_doc > candidate.wrapping_add(100) {
@@ -136,11 +147,10 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
debug_assert_eq!(left.doc(), right.doc());
// test the remaining scorers
if self
.others
.iter_mut()
.all(|docset| docset.seek_into_the_danger_zone(candidate))
{
if self.others.iter_mut().all(|docset| {
docset.seek_into_the_danger_zone(candidate, ANTI_CALL_TOKEN)
== SeekDangerResult::Success
}) {
debug_assert_eq!(candidate, self.left.doc());
debug_assert_eq!(candidate, self.right.doc());
debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate));
@@ -166,13 +176,29 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
///
/// Some implementations may choose to advance past the target if beneficial for performance.
/// The return value is [`SeekDangerResult::Success`] if the target is in the docset, and
/// [`SeekDangerResult::NotFound`] otherwise.
///
/// The docsets are probed in order: `left`, then `right`, then each of `others`. The first
/// `NotFound` is returned immediately with its payload, and the remaining docsets are not
/// called.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
self.left.seek_into_the_danger_zone(target)
&& self.right.seek_into_the_danger_zone(target)
&& self
.others
.iter_mut()
.all(|docset| docset.seek_into_the_danger_zone(target))
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
token: SeekAntiCallToken,
) -> SeekDangerResult {
// A miss on `left` ends the call; `right` and `others` are left untouched.
if let SeekDangerResult::NotFound(seek_doc) =
self.left.seek_into_the_danger_zone(target, token)
{
return SeekDangerResult::NotFound(seek_doc);
}
// A miss on `right` ends the call; `others` are left untouched.
if let SeekDangerResult::NotFound(seek_doc) =
self.right.seek_into_the_danger_zone(target, token)
{
return SeekDangerResult::NotFound(seek_doc);
}
for other in self.others.iter_mut() {
if let SeekDangerResult::NotFound(seek_doc) =
other.seek_into_the_danger_zone(target, token)
{
return SeekDangerResult::NotFound(seek_doc);
}
}
// Every docset reported a hit on `target`.
SeekDangerResult::Success
}
#[inline]

View File

@@ -32,6 +32,7 @@ mod weight;
mod vec_docset;
pub(crate) mod score_combiner;
pub use intersection::SeekAntiCallToken;
pub use query_grammar::Occur;
pub use self::all_query::{AllQuery, AllScorer, AllWeight};

View File

@@ -1,9 +1,9 @@
use crate::docset::{DocSet, TERMINATED};
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings;
use crate::query::bm25::Bm25Weight;
use crate::query::phrase_query::{intersection_count, PhraseScorer};
use crate::query::Scorer;
use crate::query::{Scorer, SeekAntiCallToken};
use crate::{DocId, Score};
// MultiPrefix is the larger variant, and also the one we expect most often. PhraseScorer is > 1kB
@@ -194,11 +194,20 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
self.advance()
}
/// Checks `target` against the inner phrase scorer first, then against the prefix.
///
/// A miss reported by `phrase_scorer` is returned as is. When the phrase scorer hits but
/// `matches_prefix()` is false, the result is `NotFound(target + 1)`.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
if self.phrase_scorer.seek_into_the_danger_zone(target) {
self.matches_prefix()
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
token: SeekAntiCallToken,
) -> SeekDangerResult {
// Propagate the phrase scorer's miss, payload included.
if let SeekDangerResult::NotFound(seek_doc) =
self.phrase_scorer.seek_into_the_danger_zone(target, token)
{
return SeekDangerResult::NotFound(seek_doc);
}
if self.matches_prefix() {
SeekDangerResult::Success
} else {
false
// The phrase matched `target` but the prefix did not, so the next candidate is
// reported as `target + 1`.
SeekDangerResult::NotFound(target + 1)
}
}

View File

@@ -1,10 +1,10 @@
use std::cmp::Ordering;
use crate::docset::{DocSet, TERMINATED};
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings;
use crate::query::bm25::Bm25Weight;
use crate::query::{Intersection, Scorer};
use crate::query::{Intersection, Scorer, SeekAntiCallToken};
use crate::{DocId, Score};
struct PostingsWithOffset<TPostings> {
@@ -530,12 +530,23 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
self.advance()
}
/// Checks `target` against the intersection of the term postings first, then verifies the
/// phrase itself.
///
/// A miss reported by `intersection_docset` is returned as is. When the intersection hits
/// but `phrase_match()` is false, the result is `NotFound(target + 1)`.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
token: SeekAntiCallToken,
) -> SeekDangerResult {
// In debug builds, check that the caller never seeks backwards.
debug_assert!(target >= self.doc());
if self.intersection_docset.seek_into_the_danger_zone(target) && self.phrase_match() {
return true;
// Propagate the intersection's miss, payload included.
if let SeekDangerResult::NotFound(seek_doc) = self
.intersection_docset
.seek_into_the_danger_zone(target, token)
{
return SeekDangerResult::NotFound(seek_doc);
}
if self.phrase_match() {
SeekDangerResult::Success
} else {
// All terms are present in `target` but not as a phrase, so the next candidate is
// reported as `target + 1`.
SeekDangerResult::NotFound(target + 1)
}
false
}
fn doc(&self) -> DocId {

View File

@@ -1,8 +1,8 @@
use std::marker::PhantomData;
use crate::docset::DocSet;
use crate::docset::{DocSet, SeekDangerResult};
use crate::query::score_combiner::ScoreCombiner;
use crate::query::Scorer;
use crate::query::{Scorer, SeekAntiCallToken};
use crate::{DocId, Score};
/// Given a required scorer and an optional scorer
@@ -56,9 +56,13 @@ where
self.req_scorer.seek(target)
}
/// Clears the cached score, then forwards the call to the required scorer and returns its
/// result unchanged.
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
token: SeekAntiCallToken,
) -> SeekDangerResult {
// Reset the score cache before seeking.
// NOTE(review): presumably because the cached score belongs to the previous doc; confirm.
self.score_cache = None;
self.req_scorer.seek_into_the_danger_zone(target)
self.req_scorer.seek_into_the_danger_zone(target, token)
}
fn doc(&self) -> DocId {

View File

@@ -1,9 +1,9 @@
use common::TinySet;
use crate::docset::{DocSet, TERMINATED};
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::size_hint::estimate_union;
use crate::query::Scorer;
use crate::query::{Scorer, SeekAntiCallToken};
use crate::{DocId, Score};
// The buffered union looks ahead within a fixed-size sliding window
@@ -225,25 +225,45 @@ where
}
}
/// Checks whether `target` belongs to the union.
///
/// When `target` lies within the buffered horizon, a regular `seek` is performed and its
/// result is compared with `target`. Otherwise the underlying docsets are probed one by one
/// until the first hit. On a miss, the smallest doc id reported by the underlying docsets
/// is returned in `NotFound`, with `TERMINATED` as the upper bound.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
fn seek_into_the_danger_zone(
&mut self,
target: DocId,
token: SeekAntiCallToken,
) -> SeekDangerResult {
if self.is_in_horizon(target) {
// Our value is within the buffered horizon and the docset may already have been
// processed and removed, so we need to use seek, which uses the regular advance.
self.seek(target) == target
// Seek exactly once and reuse the returned doc id for both the comparison and the
// `NotFound` payload.
let seek_doc = self.seek(target);
if seek_doc == target {
SeekDangerResult::Success
} else {
SeekDangerResult::NotFound(seek_doc)
}
} else {
// The docsets are not in the buffered range, so we can use seek_into_the_danger_zone
// of the underlying docsets
let is_hit = self
.docsets
.iter_mut()
.any(|docset| docset.seek_into_the_danger_zone(target));
let mut is_hit = false;
// Start from the `TERMINATED` sentinel so the reported doc id never exceeds it, even
// when there is no underlying docset left to report anything.
let mut seek_doc_min = TERMINATED;
for docset in self.docsets.iter_mut() {
match docset.seek_into_the_danger_zone(target, token) {
SeekDangerResult::Success => {
is_hit = true;
break;
}
SeekDangerResult::NotFound(seek_doc) => {
seek_doc_min = seek_doc.min(seek_doc_min);
}
}
}
// The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone`
// returns `SeekDangerResult::Success`.
if is_hit {
self.seek(target);
SeekDangerResult::Success
} else {
SeekDangerResult::NotFound(seek_doc_min)
}
is_hit
}
}