mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-09 02:22:54 +00:00
178 lines
5.5 KiB
Rust
178 lines
5.5 KiB
Rust
use crate::docset::DocSet;
|
|
use crate::query::score_combiner::ScoreCombiner;
|
|
use crate::query::Scorer;
|
|
use crate::DocId;
|
|
use crate::Score;
|
|
use std::marker::PhantomData;
|
|
|
|
/// Given a required scorer and an optional scorer
|
|
/// matches all document from the required scorer
|
|
/// and complements the score using the optional scorer.
|
|
///
|
|
/// This is useful for queries like `+somethingrequired somethingoptional`.
|
|
///
|
|
/// Note that `somethingoptional` has no impact on the `DocSet`.
|
|
pub struct RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner: ScoreCombiner> {
|
|
req_scorer: TReqScorer,
|
|
opt_scorer: TOptScorer,
|
|
score_cache: Option<Score>,
|
|
_phantom: PhantomData<TScoreCombiner>,
|
|
}
|
|
|
|
impl<TReqScorer, TOptScorer, TScoreCombiner>
|
|
RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner>
|
|
where
|
|
TOptScorer: DocSet,
|
|
TScoreCombiner: ScoreCombiner,
|
|
{
|
|
/// Creates a new `RequiredOptionalScorer`.
|
|
pub fn new(
|
|
req_scorer: TReqScorer,
|
|
opt_scorer: TOptScorer,
|
|
) -> RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner> {
|
|
RequiredOptionalScorer {
|
|
req_scorer,
|
|
opt_scorer,
|
|
score_cache: None,
|
|
_phantom: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<TReqScorer, TOptScorer, TScoreCombiner> DocSet
|
|
for RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner>
|
|
where
|
|
TReqScorer: DocSet,
|
|
TOptScorer: DocSet,
|
|
TScoreCombiner: ScoreCombiner,
|
|
{
|
|
fn advance(&mut self) -> DocId {
|
|
self.score_cache = None;
|
|
self.req_scorer.advance()
|
|
}
|
|
|
|
fn doc(&self) -> DocId {
|
|
self.req_scorer.doc()
|
|
}
|
|
|
|
fn size_hint(&self) -> u32 {
|
|
self.req_scorer.size_hint()
|
|
}
|
|
}
|
|
|
|
impl<TReqScorer, TOptScorer, TScoreCombiner> Scorer
|
|
for RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner>
|
|
where
|
|
TReqScorer: Scorer,
|
|
TOptScorer: Scorer,
|
|
TScoreCombiner: ScoreCombiner,
|
|
{
|
|
fn score(&mut self) -> Score {
|
|
if let Some(score) = self.score_cache {
|
|
return score;
|
|
}
|
|
let doc = self.doc();
|
|
let mut score_combiner = TScoreCombiner::default();
|
|
score_combiner.update(&mut self.req_scorer);
|
|
if self.opt_scorer.doc() <= doc && self.opt_scorer.seek(doc) == doc {
|
|
score_combiner.update(&mut self.opt_scorer);
|
|
}
|
|
let score = score_combiner.score();
|
|
self.score_cache = Some(score);
|
|
score
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::RequiredOptionalScorer;
|
|
use crate::docset::{DocSet, TERMINATED};
|
|
use crate::postings::tests::test_skip_against_unoptimized;
|
|
use crate::query::score_combiner::{DoNothingCombiner, SumCombiner};
|
|
use crate::query::ConstScorer;
|
|
use crate::query::Scorer;
|
|
use crate::query::VecDocSet;
|
|
use crate::tests::sample_with_seed;
|
|
|
|
#[test]
|
|
fn test_reqopt_scorer_empty() {
|
|
let req = vec![1, 3, 7];
|
|
let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> =
|
|
RequiredOptionalScorer::new(
|
|
ConstScorer::from(VecDocSet::from(req.clone())),
|
|
ConstScorer::from(VecDocSet::from(vec![])),
|
|
);
|
|
let mut docs = vec![];
|
|
while reqoptscorer.doc() != TERMINATED {
|
|
docs.push(reqoptscorer.doc());
|
|
reqoptscorer.advance();
|
|
}
|
|
assert_eq!(docs, req);
|
|
}
|
|
|
|
#[test]
|
|
fn test_reqopt_scorer() {
|
|
let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> =
|
|
RequiredOptionalScorer::new(
|
|
ConstScorer::new(VecDocSet::from(vec![1, 3, 7, 8, 9, 10, 13, 15]), 1.0),
|
|
ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15]), 1.0),
|
|
);
|
|
{
|
|
assert_eq!(reqoptscorer.doc(), 1);
|
|
assert_eq!(reqoptscorer.score(), 2.0);
|
|
}
|
|
{
|
|
assert_eq!(reqoptscorer.advance(), 3);
|
|
assert_eq!(reqoptscorer.doc(), 3);
|
|
assert_eq!(reqoptscorer.score(), 1.0);
|
|
}
|
|
{
|
|
assert_eq!(reqoptscorer.advance(), 7);
|
|
assert_eq!(reqoptscorer.doc(), 7);
|
|
assert_eq!(reqoptscorer.score(), 2.0);
|
|
}
|
|
{
|
|
assert_eq!(reqoptscorer.advance(), 8);
|
|
assert_eq!(reqoptscorer.doc(), 8);
|
|
assert_eq!(reqoptscorer.score(), 1.0);
|
|
}
|
|
{
|
|
assert_eq!(reqoptscorer.advance(), 9);
|
|
assert_eq!(reqoptscorer.doc(), 9);
|
|
assert_eq!(reqoptscorer.score(), 1.0);
|
|
}
|
|
{
|
|
assert_eq!(reqoptscorer.advance(), 10);
|
|
assert_eq!(reqoptscorer.doc(), 10);
|
|
assert_eq!(reqoptscorer.score(), 1.0);
|
|
}
|
|
{
|
|
assert_eq!(reqoptscorer.advance(), 13);
|
|
assert_eq!(reqoptscorer.doc(), 13);
|
|
assert_eq!(reqoptscorer.score(), 1.0);
|
|
}
|
|
{
|
|
assert_eq!(reqoptscorer.advance(), 15);
|
|
assert_eq!(reqoptscorer.doc(), 15);
|
|
assert_eq!(reqoptscorer.score(), 2.0);
|
|
}
|
|
assert_eq!(reqoptscorer.advance(), TERMINATED);
|
|
}
|
|
|
|
#[test]
|
|
fn test_reqopt_scorer_skip() {
|
|
let req_docs = sample_with_seed(10_000, 0.02, 1);
|
|
let opt_docs = sample_with_seed(10_000, 0.02, 2);
|
|
let skip_docs = sample_with_seed(10_000, 0.001, 3);
|
|
test_skip_against_unoptimized(
|
|
|| {
|
|
Box::new(RequiredOptionalScorer::<_, _, DoNothingCombiner>::new(
|
|
ConstScorer::from(VecDocSet::from(req_docs.clone())),
|
|
ConstScorer::from(VecDocSet::from(opt_docs.clone())),
|
|
))
|
|
},
|
|
skip_docs,
|
|
);
|
|
}
|
|
}
|