mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-06-23 02:40:44 +00:00
Compare commits
7 Commits
faster_uni
...
trinity.po
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4031d97bac | ||
|
|
799f7b4646 | ||
|
|
fc88d80726 | ||
|
|
6a684e7c38 | ||
|
|
94fe52cc67 | ||
|
|
2ff39f6f7f | ||
|
|
1d06328cb3 |
@@ -241,6 +241,28 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::column_values::u64_based::tests::create_and_validate;
|
||||
|
||||
// A block boundary where a high run ends and a low run begins: y0 ≈ 2^32, y511 ≈ 0.
|
||||
// This large jump used to cause an overflow which made us render all value on 64b
|
||||
// when 32 was enough.
|
||||
fn large_descending_jump_vals() -> Vec<u64> {
|
||||
let high_start: u64 = 4_294_967_039; // ≈ 2^32 - 257
|
||||
(0u64..256)
|
||||
.map(|i| high_start + i)
|
||||
.chain(0u64..256)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_blockwise_linear_large_descending_jump_uses_at_most_32bit() {
|
||||
let vals = large_descending_jump_vals();
|
||||
let (_, actual_rate) =
|
||||
create_and_validate::<BlockwiseLinearCodec>(&vals, "large descending jump").unwrap();
|
||||
assert!(
|
||||
actual_rate <= 0.6,
|
||||
"compression rate {actual_rate:.3} is too high (bug: 64-bit residuals)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_codec_data_sets_simple() {
|
||||
create_and_validate::<BlockwiseLinearCodec>(
|
||||
|
||||
@@ -37,7 +37,7 @@ fn compute_slope(y0: u64, y1: u64, num_vals: NonZeroU32) -> u64 {
|
||||
} else {
|
||||
y0.wrapping_sub(y1)
|
||||
};
|
||||
if abs_dy >= 1 << 32 {
|
||||
if abs_dy >= 1 << 31 {
|
||||
// This is outside of realm we handle.
|
||||
// Let's just bail.
|
||||
return 0u64;
|
||||
|
||||
@@ -299,6 +299,12 @@ impl AggregationVariants {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub(crate) fn as_sum(&self) -> Option<&SumAggregation> {
|
||||
match &self {
|
||||
AggregationVariants::Sum(sum) => Some(sum),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -377,7 +377,22 @@ impl IntermediateMetricResult {
|
||||
MetricResult::ExtendedStats(intermediate_stats.finalize())
|
||||
}
|
||||
IntermediateMetricResult::Sum(intermediate_sum) => {
|
||||
MetricResult::Sum(intermediate_sum.finalize().into())
|
||||
// By default match Elasticsearch: empty / all-missing sum
|
||||
// buckets serialize as `"value": 0`, not `"value": null`.
|
||||
// The non-ES `none_if_no_match` flag on `SumAggregation`
|
||||
// opts into SQL-style `null` for downstream consumers.
|
||||
let none_if_no_match = req
|
||||
.agg
|
||||
.as_sum()
|
||||
.and_then(|sum| sum.none_if_no_match)
|
||||
.unwrap_or(false);
|
||||
let value = intermediate_sum.finalize();
|
||||
if none_if_no_match {
|
||||
MetricResult::Sum(value.into())
|
||||
} else {
|
||||
let value = Some(value.unwrap_or(0.0));
|
||||
MetricResult::Sum(value.into())
|
||||
}
|
||||
}
|
||||
IntermediateMetricResult::Percentiles(percentiles) => MetricResult::Percentiles(
|
||||
percentiles
|
||||
|
||||
@@ -27,6 +27,16 @@ pub struct SumAggregation {
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(default, deserialize_with = "deserialize_option_f64")]
|
||||
pub missing: Option<f64>,
|
||||
/// Non-Elasticsearch extension. When `Some(true)`, the serialized result
|
||||
/// returns `"value": null` if no values were collected (all documents had
|
||||
/// missing/NULL values for the field), matching the behavior of `min`,
|
||||
/// `max`, and `avg`. When `None` or `Some(false)` (the default) the
|
||||
/// result returns `"value": 0`, matching Elasticsearch.
|
||||
///
|
||||
/// Intended for SQL-style consumers where `SUM` of zero rows is `NULL`
|
||||
/// and must be distinguishable from a bucket that genuinely sums to `0`.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub none_if_no_match: Option<bool>,
|
||||
}
|
||||
|
||||
impl SumAggregation {
|
||||
@@ -35,6 +45,7 @@ impl SumAggregation {
|
||||
Self {
|
||||
field: field_name,
|
||||
missing: None,
|
||||
none_if_no_match: None,
|
||||
}
|
||||
}
|
||||
/// Returns the field name the aggregation is computed on.
|
||||
@@ -59,8 +70,104 @@ impl IntermediateSum {
|
||||
pub fn merge_fruits(&mut self, other: IntermediateSum) {
|
||||
self.stats.merge_fruits(other.stats);
|
||||
}
|
||||
/// Computes the final minimum value.
|
||||
/// Computes the final sum value.
|
||||
///
|
||||
/// Returns `None` when no values were collected, matching the Rust-side
|
||||
/// behavior of `IntermediateMin`, `IntermediateMax`, and
|
||||
/// `IntermediateAvg`. The Elasticsearch-vs-SQL choice for the
|
||||
/// user-visible result is made at the boundary in
|
||||
/// [`IntermediateMetricResult::into_final_metric_result`]: by default
|
||||
/// `None` is coerced to `Some(0.0)` to match Elasticsearch
|
||||
/// (`"value": 0`), and the [`SumAggregation::none_if_no_match`] flag
|
||||
/// opts out of that coercion for SQL-style consumers.
|
||||
pub fn finalize(&self) -> Option<f64> {
|
||||
Some(self.stats.finalize().sum)
|
||||
let stats = self.stats.finalize();
|
||||
if stats.count == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(stats.sum)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sum_finalize_returns_none_when_no_values() {
|
||||
// Default IntermediateSum has count=0 — finalize should return None,
|
||||
// matching MIN/MAX/AVG behavior for all-NULL groups.
|
||||
let sum = IntermediateSum::default();
|
||||
assert_eq!(sum.finalize(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_finalize_returns_value_when_has_values() {
|
||||
let mut sum = IntermediateSum::default();
|
||||
// Merge in a result that has actual values
|
||||
let stats = IntermediateStats {
|
||||
count: 3,
|
||||
sum: 42.0,
|
||||
min: 10.0,
|
||||
max: 20.0,
|
||||
..Default::default()
|
||||
};
|
||||
let other = IntermediateSum::from_stats(stats);
|
||||
sum.merge_fruits(other);
|
||||
assert_eq!(sum.finalize(), Some(42.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_merge_two_empty_still_none() {
|
||||
let mut a = IntermediateSum::default();
|
||||
let b = IntermediateSum::default();
|
||||
a.merge_fruits(b);
|
||||
assert_eq!(a.finalize(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_aggregation_empty_index_default_matches_es() -> crate::Result<()> {
|
||||
use serde_json::json;
|
||||
|
||||
use crate::aggregation::agg_req::Aggregations;
|
||||
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
|
||||
|
||||
// Empty index — sum has no values to collect.
|
||||
let values: Vec<Vec<&str>> = vec![];
|
||||
let index = get_test_index_from_terms(false, &values)?;
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"score_sum": { "sum": { "field": "score" } }
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
// Default: match Elasticsearch — empty sum serializes as 0, not null.
|
||||
assert_eq!(res["score_sum"]["value"], 0.0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sum_aggregation_empty_index_none_if_no_match_opt_in() -> crate::Result<()> {
|
||||
use serde_json::json;
|
||||
|
||||
use crate::aggregation::agg_req::Aggregations;
|
||||
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
|
||||
|
||||
let values: Vec<Vec<&str>> = vec![];
|
||||
let index = get_test_index_from_terms(false, &values)?;
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"score_sum": { "sum": { "field": "score", "none_if_no_match": true } }
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
// Opt-in non-ES extension — empty sum serializes as null.
|
||||
assert!(
|
||||
res["score_sum"]["value"].is_null(),
|
||||
"expected null, got {:?}",
|
||||
res["score_sum"]["value"]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,31 +138,6 @@ pub trait DocSet: Send {
|
||||
buffer.len()
|
||||
}
|
||||
|
||||
/// Fills a given mutable buffer with the next doc ids smaller than `horizon`.
|
||||
///
|
||||
/// Unlike [`DocSet::fill_buffer`], this method must not advance past a doc id greater than or
|
||||
/// equal to `horizon`.
|
||||
fn fill_buffer_up_to(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
if self.doc() == TERMINATED {
|
||||
return 0;
|
||||
}
|
||||
for (pos, buffer_val) in buffer.iter_mut().enumerate() {
|
||||
let doc = self.doc();
|
||||
if doc >= horizon {
|
||||
return pos;
|
||||
}
|
||||
*buffer_val = doc;
|
||||
if self.advance() == TERMINATED {
|
||||
return pos + 1;
|
||||
}
|
||||
}
|
||||
buffer.len()
|
||||
}
|
||||
|
||||
/// Returns the current document
|
||||
/// Right after creating a new `DocSet`, the docset points to the first document.
|
||||
///
|
||||
@@ -276,14 +251,6 @@ impl DocSet for &mut dyn DocSet {
|
||||
(**self).fill_buffer(buffer)
|
||||
}
|
||||
|
||||
fn fill_buffer_up_to(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
(**self).fill_buffer_up_to(horizon, buffer)
|
||||
}
|
||||
|
||||
fn fill_bitset_block(
|
||||
&mut self,
|
||||
min_doc: DocId,
|
||||
|
||||
@@ -240,42 +240,6 @@ impl BlockSegmentPostings {
|
||||
self.freq_decoder.output_array()
|
||||
}
|
||||
|
||||
pub(crate) fn copy_docs_and_term_freqs(
|
||||
&self,
|
||||
block_offset: usize,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId],
|
||||
term_freqs: &mut [u32],
|
||||
) -> usize {
|
||||
debug_assert_eq!(docs.len(), term_freqs.len());
|
||||
let block_docs = self.docs();
|
||||
let remaining_docs_in_block = block_docs.len().saturating_sub(block_offset);
|
||||
let max_len = remaining_docs_in_block.min(docs.len());
|
||||
if max_len == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let source_docs = &block_docs[block_offset..block_offset + max_len];
|
||||
let len = if source_docs[max_len - 1] < horizon {
|
||||
max_len
|
||||
} else {
|
||||
source_docs
|
||||
.iter()
|
||||
.position(|&doc| doc >= horizon)
|
||||
.unwrap_or(max_len)
|
||||
};
|
||||
|
||||
docs[..len].copy_from_slice(&source_docs[..len]);
|
||||
|
||||
let block_freqs = self.freq_output_array();
|
||||
if block_freqs.len() >= block_offset + len {
|
||||
term_freqs[..len].copy_from_slice(&block_freqs[block_offset..block_offset + len]);
|
||||
} else {
|
||||
term_freqs[..len].fill(1);
|
||||
}
|
||||
len
|
||||
}
|
||||
|
||||
/// Return the frequency at index `idx` of the block.
|
||||
#[inline]
|
||||
pub fn freq(&self, idx: usize) -> u32 {
|
||||
@@ -323,6 +287,33 @@ impl BlockSegmentPostings {
|
||||
doc
|
||||
}
|
||||
|
||||
/// Returns the number of documents with a doc id strictly smaller than `target`
|
||||
/// (i.e. the *rank* of `target` in this posting list).
|
||||
///
|
||||
/// This jumps to the block that may contain `target` through the skip list, so no
|
||||
/// skipped block is decoded; a single block is then decoded to locate `target`
|
||||
/// within it. The cost is therefore `O(number_of_skip_list_entries)` plus one block
|
||||
/// decode, rather than `O(doc_freq)`.
|
||||
///
|
||||
/// Like [`Self::seek`], the underlying cursor only ever moves forward. This method
|
||||
/// must be called with **non-decreasing** `target` values (galloping); calling it
|
||||
/// with a `target` smaller than a previous one yields an incorrect result. `target`
|
||||
/// must be a valid doc id (i.e. `target <= TERMINATED`), exactly as for `seek`.
|
||||
///
|
||||
/// Edge cases: returns `0` when `target` is smaller than every doc id, and
|
||||
/// `doc_freq()` when `target` is larger than every doc id.
|
||||
pub fn rank(&mut self, target: DocId) -> u32 {
|
||||
if self.doc_freq == 0 {
|
||||
return 0;
|
||||
}
|
||||
// `within` = number of docs in the landed block with a doc id < target.
|
||||
let within = self.seek(target);
|
||||
// `remaining_docs` counts the landed block and everything after it, so the
|
||||
// difference is the number of docs in all blocks strictly before it.
|
||||
let docs_before_block = self.doc_freq - self.skip_reader.remaining_docs();
|
||||
docs_before_block + within as u32
|
||||
}
|
||||
|
||||
pub(crate) fn position_offset(&self) -> u64 {
|
||||
self.skip_reader.position_offset()
|
||||
}
|
||||
@@ -604,4 +595,38 @@ mod tests {
|
||||
assert_eq!(block_segments.docs(), &[1, 3, 5]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_segment_postings_rank() -> crate::Result<()> {
|
||||
// ~8 blocks worth of docs so the skip list is actually exercised.
|
||||
let docs: Vec<DocId> = (0..1000u32).map(|i| i * 3).collect();
|
||||
let mut block_postings = build_block_postings(&docs[..])?;
|
||||
let doc_freq = block_postings.doc_freq();
|
||||
|
||||
// rank(target) must equal the number of docs strictly below target.
|
||||
// Targets are queried in non-decreasing order, as the API requires.
|
||||
// `target` values must be a valid doc id (<= TERMINATED) and non-decreasing.
|
||||
let targets = [
|
||||
0u32, 1, 2, 3, 4, 299, 300, 301, 1500, 2996, 2997, 3000, 10_000,
|
||||
];
|
||||
for &target in &targets {
|
||||
let expected = docs.iter().filter(|&&d| d < target).count() as u32;
|
||||
assert_eq!(
|
||||
block_postings.rank(target),
|
||||
expected,
|
||||
"rank({target}) mismatch"
|
||||
);
|
||||
}
|
||||
|
||||
// Edge cases: below the first doc -> 0, above the last doc -> doc_freq.
|
||||
let mut fresh = build_block_postings(&docs[..])?;
|
||||
assert_eq!(fresh.rank(0), 0);
|
||||
let mut fresh = build_block_postings(&docs[..])?;
|
||||
assert_eq!(fresh.rank(1_000_000), doc_freq);
|
||||
|
||||
// Empty postings: rank is always 0.
|
||||
let mut empty = BlockSegmentPostings::empty();
|
||||
assert_eq!(empty.rank(42), 0);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -532,16 +532,6 @@ pub(crate) mod tests {
|
||||
fn score(&mut self) -> Score {
|
||||
self.0.score()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
self.0.can_score_doc()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
|
||||
self.0.score_doc(doc, term_freq)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_skip_against_unoptimized<F: Fn() -> Box<dyn DocSet>>(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use common::HasLen;
|
||||
|
||||
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
|
||||
use crate::docset::DocSet;
|
||||
use crate::fastfield::AliveBitSet;
|
||||
use crate::positions::PositionReader;
|
||||
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
|
||||
@@ -151,34 +151,6 @@ impl SegmentPostings {
|
||||
position_reader,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
let mut num_elems = 0;
|
||||
while num_elems < COLLECT_BLOCK_BUFFER_LEN && self.doc() < horizon {
|
||||
let copied = self.block_cursor.copy_docs_and_term_freqs(
|
||||
self.cur,
|
||||
horizon,
|
||||
&mut docs[num_elems..],
|
||||
&mut term_freqs[num_elems..],
|
||||
);
|
||||
if copied == 0 {
|
||||
break;
|
||||
}
|
||||
num_elems += copied;
|
||||
self.cur += copied;
|
||||
|
||||
if self.cur == COMPRESSION_BLOCK_SIZE {
|
||||
self.cur = 0;
|
||||
self.block_cursor.advance();
|
||||
}
|
||||
}
|
||||
num_elems
|
||||
}
|
||||
}
|
||||
|
||||
impl DocSet for SegmentPostings {
|
||||
|
||||
@@ -187,6 +187,12 @@ impl SkipReader {
|
||||
self.last_doc_in_block
|
||||
}
|
||||
|
||||
/// Number of docs from the start of the current block to the end of the postings
|
||||
/// (i.e. the current block plus every block after it).
|
||||
pub(crate) fn remaining_docs(&self) -> u32 {
|
||||
self.remaining_docs
|
||||
}
|
||||
|
||||
pub fn position_offset(&self) -> u64 {
|
||||
self.position_offset
|
||||
}
|
||||
|
||||
@@ -109,16 +109,6 @@ impl Scorer for AllScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
1.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
use std::cell::RefCell;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use lru::LruCache;
|
||||
|
||||
use crate::fieldnorm::FieldNormReader;
|
||||
use crate::query::Explanation;
|
||||
use crate::schema::Field;
|
||||
@@ -63,9 +59,7 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
|
||||
K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
|
||||
}
|
||||
|
||||
const BM25_TF_CACHE_CAPACITY: usize = 64;
|
||||
|
||||
fn compute_tf_cache_uncached(average_fieldnorm: Score) -> Arc<[Score; 256]> {
|
||||
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
|
||||
let mut cache: [Score; 256] = [0.0; 256];
|
||||
for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
|
||||
let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
|
||||
@@ -74,36 +68,6 @@ fn compute_tf_cache_uncached(average_fieldnorm: Score) -> Arc<[Score; 256]> {
|
||||
Arc::new(cache)
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static TF_CACHES: RefCell<LruCache<u32, Arc<[Score; 256]>>> = RefCell::new(LruCache::new(
|
||||
NonZeroUsize::new(BM25_TF_CACHE_CAPACITY).unwrap(),
|
||||
));
|
||||
}
|
||||
|
||||
/// The cache is shared across all [Bm25Weight] with the same average fieldnorm on the same thread.
|
||||
/// It is stored in a thread local LRU cache.
|
||||
///
|
||||
/// On one query all terms on the same field will share the same average fieldnorm, and thus the
|
||||
/// same cache. This will lower cache pressure.
|
||||
///
|
||||
/// Even between queries (on the same thread), the cache will be reused, which allows the cache to
|
||||
/// better learn the memory address of the cache and access patterns.
|
||||
///
|
||||
/// Thread local is used in order to be defensive about potential contention on the cache.
|
||||
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
|
||||
let cache_key = average_fieldnorm.to_bits();
|
||||
TF_CACHES.with(|cache_by_average_fieldnorm| {
|
||||
let mut cache_by_average_fieldnorm = cache_by_average_fieldnorm.borrow_mut();
|
||||
if let Some(cache) = cache_by_average_fieldnorm.get(&cache_key) {
|
||||
return cache.clone();
|
||||
}
|
||||
|
||||
let cache = compute_tf_cache_uncached(average_fieldnorm);
|
||||
cache_by_average_fieldnorm.put(cache_key, cache.clone());
|
||||
cache
|
||||
})
|
||||
}
|
||||
|
||||
/// A struct used for computing BM25 scores.
|
||||
#[derive(Clone)]
|
||||
pub struct Bm25Weight {
|
||||
@@ -265,7 +229,7 @@ impl Bm25Weight {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::{idf, Bm25Weight};
|
||||
use super::idf;
|
||||
use crate::{assert_nearly_equals, Score};
|
||||
|
||||
#[test]
|
||||
@@ -273,12 +237,4 @@ mod tests {
|
||||
let score: Score = 2.0;
|
||||
assert_nearly_equals!(idf(1, 2), score.ln());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bm25_tf_cache_is_shared_for_same_average_fieldnorm() {
|
||||
let weight1 = Bm25Weight::for_one_term(1, 10, 3.0);
|
||||
let weight2 = Bm25Weight::for_one_term(2, 10, 3.0);
|
||||
|
||||
assert!(std::sync::Arc::ptr_eq(&weight1.cache, &weight2.cache));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,14 +91,10 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
num_docs: u32,
|
||||
) -> Box<dyn Scorer> {
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(mut term_scorers) => {
|
||||
if term_scorers.len() == 1 {
|
||||
Box::new(term_scorers.pop().unwrap())
|
||||
} else {
|
||||
let union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
|
||||
Box::new(union_scorer)
|
||||
}
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
|
||||
Box::new(union_scorer)
|
||||
}
|
||||
SpecializedScorer::TermIntersection(term_scorers) => {
|
||||
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
|
||||
|
||||
@@ -112,14 +112,6 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
|
||||
self.underlying.fill_buffer(buffer)
|
||||
}
|
||||
|
||||
fn fill_buffer_up_to(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
self.underlying.fill_buffer_up_to(horizon, buffer)
|
||||
}
|
||||
|
||||
fn doc(&self) -> u32 {
|
||||
self.underlying.doc()
|
||||
}
|
||||
@@ -146,27 +138,6 @@ impl<S: Scorer> Scorer for BoostScorer<S> {
|
||||
fn score(&mut self) -> Score {
|
||||
self.underlying.score() * self.boost
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
self.underlying.can_score_doc()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
|
||||
self.underlying.score_doc(doc, term_freq) * self.boost
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
self.underlying
|
||||
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -141,16 +141,6 @@ impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
|
||||
fn score(&mut self) -> Score {
|
||||
self.score
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
self.score
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -315,20 +315,6 @@ mod tests {
|
||||
fn score(&mut self) -> Score {
|
||||
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, _term_freq: u32) -> Score {
|
||||
self.foo
|
||||
.iter()
|
||||
.find(|(candidate_doc, _)| *candidate_doc == doc)
|
||||
.map(|(_, score)| *score)
|
||||
.unwrap_or(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -59,16 +59,6 @@ impl Scorer for EmptyScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
0.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,40 +1,5 @@
|
||||
use crate::docset::{DocSet, TERMINATED};
|
||||
use crate::query::Scorer;
|
||||
use crate::{DocId, Score};
|
||||
|
||||
struct ScoreOnlyScorer {
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
}
|
||||
|
||||
impl DocSet for ScoreOnlyScorer {
|
||||
fn advance(&mut self) -> DocId {
|
||||
self.doc = TERMINATED;
|
||||
TERMINATED
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.doc
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl Scorer for ScoreOnlyScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
self.score
|
||||
}
|
||||
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
self.score
|
||||
}
|
||||
}
|
||||
use crate::Score;
|
||||
|
||||
/// The `ScoreCombiner` trait defines how to compute
|
||||
/// an overall score given a list of scores.
|
||||
@@ -45,17 +10,6 @@ pub trait ScoreCombiner: Default + Clone + Send + Copy + 'static {
|
||||
/// or not.
|
||||
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer);
|
||||
|
||||
/// Aggregates the score combiner with an already computed score.
|
||||
fn update_score(&mut self, doc: DocId, score: Score) {
|
||||
let mut scorer = ScoreOnlyScorer { doc, score };
|
||||
self.update(&mut scorer);
|
||||
}
|
||||
|
||||
/// Returns true if this combiner needs scorer scores to compute its state.
|
||||
fn requires_scoring() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Clears the score combiner state back to its initial state.
|
||||
fn clear(&mut self);
|
||||
|
||||
@@ -73,12 +27,6 @@ pub struct DoNothingCombiner;
|
||||
impl ScoreCombiner for DoNothingCombiner {
|
||||
fn update<TScorer: Scorer>(&mut self, _scorer: &mut TScorer) {}
|
||||
|
||||
fn update_score(&mut self, _doc: DocId, _score: Score) {}
|
||||
|
||||
fn requires_scoring() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn clear(&mut self) {}
|
||||
|
||||
#[inline]
|
||||
@@ -94,16 +42,10 @@ pub struct SumCombiner {
|
||||
}
|
||||
|
||||
impl ScoreCombiner for SumCombiner {
|
||||
#[inline]
|
||||
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
|
||||
self.score += scorer.score();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_score(&mut self, _doc: DocId, score: Score) {
|
||||
self.score += score;
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.score = 0.0;
|
||||
}
|
||||
@@ -135,19 +77,12 @@ impl DisjunctionMaxCombiner {
|
||||
}
|
||||
|
||||
impl ScoreCombiner for DisjunctionMaxCombiner {
|
||||
#[inline]
|
||||
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
|
||||
let score = scorer.score();
|
||||
self.max = Score::max(score, self.max);
|
||||
self.sum += score;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_score(&mut self, _doc: DocId, score: Score) {
|
||||
self.max = Score::max(score, self.max);
|
||||
self.sum += score;
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.max = 0.0;
|
||||
self.sum = 0.0;
|
||||
|
||||
@@ -2,8 +2,8 @@ use std::ops::DerefMut;
|
||||
|
||||
use downcast_rs::impl_downcast;
|
||||
|
||||
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
|
||||
use crate::{DocId, Score};
|
||||
use crate::docset::DocSet;
|
||||
use crate::Score;
|
||||
|
||||
/// Scored set of documents matching a query within a specific segment.
|
||||
///
|
||||
@@ -13,36 +13,6 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
|
||||
///
|
||||
/// This method will perform a bit of computation and is not cached.
|
||||
fn score(&mut self) -> Score;
|
||||
|
||||
/// Returns true if [`Scorer::score_doc`] can score buffered docs without
|
||||
/// repositioning the scorer.
|
||||
///
|
||||
/// Scorers whose [`Scorer::score_doc`] needs term frequencies must also override
|
||||
/// [`Scorer::fill_buffer_up_to_with_term_freqs`].
|
||||
fn can_score_doc(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Returns the score for `doc` with its term frequency.
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
panic!(
|
||||
"score_doc is not supported by this scorer. You need check can_score_doc() before \
|
||||
calling this method."
|
||||
)
|
||||
}
|
||||
|
||||
/// Fills docs up to `horizon`.
|
||||
///
|
||||
/// The default implementation does not fill `term_freqs`. Scorers whose
|
||||
/// [`Scorer::score_doc`] reads term frequencies must override this method.
|
||||
fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
_term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
DocSet::fill_buffer_up_to(self, horizon, docs)
|
||||
}
|
||||
}
|
||||
|
||||
impl_downcast!(Scorer);
|
||||
@@ -52,25 +22,4 @@ impl Scorer for Box<dyn Scorer> {
|
||||
fn score(&mut self) -> Score {
|
||||
self.deref_mut().score()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
self.as_ref().can_score_doc()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
|
||||
self.deref_mut().score_doc(doc, term_freq)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
self.deref_mut()
|
||||
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
|
||||
use crate::docset::DocSet;
|
||||
use crate::fieldnorm::FieldNormReader;
|
||||
use crate::postings::{BlockSegmentPostings, FreqReadingOption, Postings, SegmentPostings};
|
||||
use crate::query::bm25::Bm25Weight;
|
||||
@@ -147,27 +147,6 @@ impl Scorer for TermScorer {
|
||||
let term_freq = self.term_freq();
|
||||
self.similarity_weight.score(fieldnorm_id, term_freq)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
|
||||
let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);
|
||||
self.similarity_weight.score(fieldnorm_id, term_freq)
|
||||
}
|
||||
|
||||
fn fill_buffer_up_to_with_term_freqs(
|
||||
&mut self,
|
||||
horizon: DocId,
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
) -> usize {
|
||||
self.postings
|
||||
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -10,7 +10,23 @@ use crate::{DocId, Score};
|
||||
// of upcoming document IDs (the "horizon").
|
||||
const HORIZON_NUM_TINYBITSETS: usize = HORIZON as usize / 64;
|
||||
const HORIZON: u32 = 64u32 * 64u32;
|
||||
const GROUPED_INSERT_MAX_BUCKET_SPAN: u32 = 2;
|
||||
|
||||
// `drain_filter` is not stable yet.
|
||||
// This function is similar except that it does is not unstable, and
|
||||
// it does not keep the original vector ordering.
|
||||
//
|
||||
// Elements are dropped and not yielded.
|
||||
fn unordered_drain_filter<T, P>(v: &mut Vec<T>, mut predicate: P)
|
||||
where P: FnMut(&mut T) -> bool {
|
||||
let mut i = 0;
|
||||
while i < v.len() {
|
||||
if predicate(&mut v[i]) {
|
||||
v.swap_remove(i);
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a `DocSet` that iterate through the union of two or more `DocSet`s.
|
||||
pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
|
||||
@@ -37,213 +53,31 @@ pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
|
||||
score: Score,
|
||||
/// Number of documents in the segment.
|
||||
num_docs: u32,
|
||||
/// Scratch buffer for block-based refill.
|
||||
refill_docs: [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
/// Scratch buffer for term frequencies matching `refill_docs`.
|
||||
refill_term_freqs: [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
/// Whether all children support scoring buffered docs after advancing.
|
||||
use_score_doc_refill: bool,
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn union_bucket(
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
bucket_pos: u32,
|
||||
tinyset: TinySet,
|
||||
) {
|
||||
debug_assert!((bucket_pos as usize) < HORIZON_NUM_TINYBITSETS);
|
||||
// `bucket` comes from a doc delta below `HORIZON`; there are exactly
|
||||
// `HORIZON / 64` buckets in the refill window.
|
||||
bitsets[bucket_pos as usize] = bitsets[bucket_pos as usize].union(tinyset);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn insert_delta(bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], delta: DocId) {
|
||||
debug_assert!(delta < HORIZON);
|
||||
// `delta < HORIZON`, so `delta / 64` is in the bitset array. The bit
|
||||
// offset is reduced modulo 64 before being inserted in the TinySet.
|
||||
bitsets[delta as usize / 64].insert_mut(delta % 64u32);
|
||||
}
|
||||
|
||||
fn insert_and_score_full_buffer<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorer: &mut TScorer,
|
||||
docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
min_doc: DocId,
|
||||
) {
|
||||
debug_assert!(docs.windows(2).all(|pair| pair[0] < pair[1]));
|
||||
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc < HORIZON);
|
||||
|
||||
let first_delta = docs[0] - min_doc;
|
||||
let last_delta = docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc;
|
||||
let first_bucket = first_delta / 64;
|
||||
let last_bucket = last_delta / 64;
|
||||
|
||||
// Common for very dense scorers: 64 distinct doc ids in one 64-doc bucket
|
||||
// means all bits in that bucket are present.
|
||||
if first_bucket == last_bucket {
|
||||
union_bucket(bitsets, first_bucket, TinySet::full());
|
||||
score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc);
|
||||
return;
|
||||
}
|
||||
|
||||
// 64 sorted distinct integers spanning exactly 64 values are consecutive.
|
||||
// If they cross a TinySet boundary, this is just the suffix of the first
|
||||
// bucket plus the prefix of the second bucket.
|
||||
if last_delta - first_delta == COLLECT_BLOCK_BUFFER_LEN as u32 - 1 {
|
||||
union_bucket(
|
||||
bitsets,
|
||||
first_bucket,
|
||||
TinySet::range_greater_or_equal(first_delta % 64u32),
|
||||
);
|
||||
union_bucket(
|
||||
bitsets,
|
||||
last_bucket,
|
||||
TinySet::range_lower((last_delta + 1) % 64u32),
|
||||
);
|
||||
score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc);
|
||||
return;
|
||||
}
|
||||
|
||||
// Grouping wins only for very dense buffers that hit the same TinySet many
|
||||
// times. Once the 64 docs are spread farther, a straight pass is cheaper.
|
||||
if last_bucket - first_bucket <= GROUPED_INSERT_MAX_BUCKET_SPAN {
|
||||
let mut bucket = first_bucket;
|
||||
let mut tinyset = TinySet::empty();
|
||||
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
|
||||
let delta = doc - min_doc;
|
||||
let delta_bucket = delta / 64;
|
||||
if delta_bucket != bucket {
|
||||
union_bucket(bitsets, bucket, tinyset);
|
||||
bucket = delta_bucket;
|
||||
tinyset = TinySet::empty();
|
||||
}
|
||||
tinyset.insert_mut(delta % 64u32);
|
||||
let score = scorer.score_doc(doc, term_freq);
|
||||
update_score_combiner(score_combiner, delta, doc, score);
|
||||
}
|
||||
union_bucket(bitsets, bucket, tinyset);
|
||||
} else {
|
||||
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
|
||||
let delta = doc - min_doc;
|
||||
insert_delta(bitsets, delta);
|
||||
// TODO: score_doc access the field_norm reader for each _term_, instead of once per
|
||||
// doc. We could optimize this by caching the field norm for the doc, and
|
||||
// reusing it for all terms in the doc.
|
||||
let score = scorer.score_doc(doc, term_freq);
|
||||
update_score_combiner(score_combiner, delta, doc, score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn update_score_combiner<TScoreCombiner: ScoreCombiner>(
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
delta: DocId,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
) {
|
||||
debug_assert!(delta < HORIZON);
|
||||
// Full and partial refill only buffer docs below `horizon`, so their
|
||||
// deltas are always in the score-combiner window.
|
||||
score_combiner[delta as usize].update_score(doc, score);
|
||||
}
|
||||
|
||||
fn score_full_buffer<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorer: &mut TScorer,
|
||||
docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
min_doc: DocId,
|
||||
) {
|
||||
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
|
||||
let score = scorer.score_doc(doc, term_freq);
|
||||
update_score_combiner(score_combiner, doc - min_doc, doc, score);
|
||||
}
|
||||
}
|
||||
|
||||
fn refill_scorer_with_score_docs<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorer: &mut TScorer,
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
min_doc: DocId,
|
||||
horizon: DocId,
|
||||
) {
|
||||
loop {
|
||||
let len = scorer.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs);
|
||||
if len == COLLECT_BLOCK_BUFFER_LEN {
|
||||
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] != TERMINATED);
|
||||
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] < horizon);
|
||||
insert_and_score_full_buffer(
|
||||
scorer,
|
||||
docs,
|
||||
term_freqs,
|
||||
bitsets,
|
||||
score_combiner,
|
||||
min_doc,
|
||||
);
|
||||
} else {
|
||||
for (&doc, &term_freq) in docs[..len].iter().zip(term_freqs[..len].iter()) {
|
||||
let delta = doc - min_doc;
|
||||
insert_delta(bitsets, delta);
|
||||
let score = scorer.score_doc(doc, term_freq);
|
||||
update_score_combiner(score_combiner, delta, doc, score);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn refill_scorer_from_current_doc<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorer: &mut TScorer,
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
min_doc: DocId,
|
||||
horizon: DocId,
|
||||
) {
|
||||
loop {
|
||||
let doc = scorer.doc();
|
||||
if doc >= horizon {
|
||||
break;
|
||||
}
|
||||
let delta = doc - min_doc;
|
||||
insert_delta(bitsets, delta);
|
||||
debug_assert!(delta < HORIZON);
|
||||
score_combiner[delta as usize].update(scorer);
|
||||
scorer.advance();
|
||||
}
|
||||
}
|
||||
|
||||
fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorers: &mut Vec<TScorer>,
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
|
||||
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
|
||||
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
min_doc: DocId,
|
||||
use_score_doc_refill: bool,
|
||||
) {
|
||||
let horizon = min_doc + HORIZON;
|
||||
for scorer in scorers.iter_mut() {
|
||||
if use_score_doc_refill {
|
||||
refill_scorer_with_score_docs(
|
||||
scorer,
|
||||
bitsets,
|
||||
score_combiner,
|
||||
docs,
|
||||
term_freqs,
|
||||
min_doc,
|
||||
horizon,
|
||||
);
|
||||
} else {
|
||||
refill_scorer_from_current_doc(scorer, bitsets, score_combiner, min_doc, horizon);
|
||||
unordered_drain_filter(scorers, |scorer| {
|
||||
let horizon = min_doc + HORIZON;
|
||||
loop {
|
||||
let doc = scorer.doc();
|
||||
if doc >= horizon {
|
||||
return false;
|
||||
}
|
||||
// add this document
|
||||
let delta = doc - min_doc;
|
||||
bitsets[(delta / 64) as usize].insert_mut(delta % 64u32);
|
||||
score_combiner[delta as usize].update(scorer);
|
||||
if scorer.advance() == TERMINATED {
|
||||
// remove the docset, it has been entirely consumed.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
scorers.retain(|scorer| scorer.doc() != TERMINATED);
|
||||
});
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
|
||||
@@ -253,8 +87,6 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
score_combiner_fn: impl FnOnce() -> TScoreCombiner,
|
||||
num_docs: u32,
|
||||
) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
|
||||
let use_score_doc_refill =
|
||||
TScoreCombiner::requires_scoring() && docsets.iter().all(Scorer::can_score_doc);
|
||||
let non_empty_docsets: Vec<TScorer> = docsets
|
||||
.into_iter()
|
||||
.filter(|docset| docset.doc() != TERMINATED)
|
||||
@@ -268,9 +100,6 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
doc: 0,
|
||||
score: 0.0,
|
||||
num_docs,
|
||||
refill_docs: [TERMINATED; COLLECT_BLOCK_BUFFER_LEN],
|
||||
refill_term_freqs: [1u32; COLLECT_BLOCK_BUFFER_LEN],
|
||||
use_score_doc_refill,
|
||||
};
|
||||
if union.refill() {
|
||||
union.advance();
|
||||
@@ -291,10 +120,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
&mut self.docsets,
|
||||
&mut self.bitsets,
|
||||
&mut self.scores,
|
||||
&mut self.refill_docs,
|
||||
&mut self.refill_term_freqs,
|
||||
min_doc,
|
||||
self.use_score_doc_refill,
|
||||
);
|
||||
true
|
||||
} else {
|
||||
@@ -422,12 +248,12 @@ where
|
||||
|
||||
// The target is outside of the buffered horizon.
|
||||
// advance all docsets to a doc >= to the target.
|
||||
for docset in &mut self.docsets {
|
||||
unordered_drain_filter(&mut self.docsets, |docset| {
|
||||
if docset.doc() < target {
|
||||
docset.seek(target);
|
||||
}
|
||||
}
|
||||
self.docsets.retain(|docset| docset.doc() != TERMINATED);
|
||||
docset.doc() == TERMINATED
|
||||
});
|
||||
|
||||
// at this point all of the docsets
|
||||
// are positioned on a doc >= to the target.
|
||||
|
||||
@@ -10,8 +10,6 @@ pub use simple_union::SimpleUnion;
|
||||
mod tests {
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::BitSet;
|
||||
|
||||
@@ -20,8 +18,8 @@ mod tests {
|
||||
use crate::postings::tests::test_skip_against_unoptimized;
|
||||
use crate::query::score_combiner::DoNothingCombiner;
|
||||
use crate::query::union::bitset_union::BitSetPostingUnion;
|
||||
use crate::query::{BitSetDocSet, ConstScorer, Scorer, VecDocSet};
|
||||
use crate::{tests, DocId, Score};
|
||||
use crate::query::{BitSetDocSet, ConstScorer, VecDocSet};
|
||||
use crate::{tests, DocId};
|
||||
|
||||
fn vec_doc_set_from_docs_list(
|
||||
docs_list: &[Vec<DocId>],
|
||||
@@ -68,61 +66,6 @@ mod tests {
|
||||
}
|
||||
BitSetDocSet::from(doc_bitset)
|
||||
}
|
||||
|
||||
struct CountingScorer {
|
||||
docset: VecDocSet,
|
||||
score_calls: Arc<AtomicUsize>,
|
||||
score_doc_calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl CountingScorer {
|
||||
fn new(
|
||||
doc_ids: Vec<DocId>,
|
||||
score_calls: Arc<AtomicUsize>,
|
||||
score_doc_calls: Arc<AtomicUsize>,
|
||||
) -> Self {
|
||||
CountingScorer {
|
||||
docset: VecDocSet::from(doc_ids),
|
||||
score_calls,
|
||||
score_doc_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DocSet for CountingScorer {
|
||||
fn advance(&mut self) -> DocId {
|
||||
self.docset.advance()
|
||||
}
|
||||
|
||||
fn seek(&mut self, target: DocId) -> DocId {
|
||||
self.docset.seek(target)
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.docset.doc()
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.docset.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl Scorer for CountingScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
self.score_calls.fetch_add(1, Ordering::SeqCst);
|
||||
1.0
|
||||
}
|
||||
|
||||
fn can_score_doc(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
|
||||
self.score_doc_calls.fetch_add(1, Ordering::SeqCst);
|
||||
1.0
|
||||
}
|
||||
}
|
||||
|
||||
fn aux_test_union(docs_list: &[Vec<DocId>]) {
|
||||
for constructor in [
|
||||
posting_list_union_from_docs_list,
|
||||
@@ -225,22 +168,6 @@ mod tests {
|
||||
]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_do_nothing_combiner_does_not_score_buffered_docs() {
|
||||
let score_calls = Arc::new(AtomicUsize::new(0));
|
||||
let score_doc_calls = Arc::new(AtomicUsize::new(0));
|
||||
let scorers = vec![
|
||||
CountingScorer::new(vec![1, 3, 5], score_calls.clone(), score_doc_calls.clone()),
|
||||
CountingScorer::new(vec![2, 3, 6], score_calls.clone(), score_doc_calls.clone()),
|
||||
];
|
||||
|
||||
let mut union = BufferedUnionScorer::build(scorers, DoNothingCombiner::default, 10);
|
||||
|
||||
assert_eq!(union.count_including_deleted(), 5);
|
||||
assert_eq!(score_calls.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(score_doc_calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) {
|
||||
for constructor in [
|
||||
posting_list_union_from_docs_list,
|
||||
|
||||
Reference in New Issue
Block a user