From af53ffe5dfc1aa4c7ceadc82ea3b36572cdc8404 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Fri, 26 Dec 2025 17:52:40 -0700 Subject: [PATCH] Use a `Buffer` generic scratch buffer parameter on `TopNComputer` and push directly from `ColumnValues` into a `TopNComputer` buffer in some cases. --- columnar/Cargo.toml | 2 +- columnar/src/column/mod.rs | 59 +++-- columnar/src/column_values/mod.rs | 173 ++++++++------ .../src/column_values/u64_based/bitpacked.rs | 75 +++--- .../src/comparable_doc.rs | 0 columnar/src/lib.rs | 2 + src/collector/mod.rs | 5 +- src/collector/sort_key/mod.rs | 3 +- src/collector/sort_key/order.rs | 11 +- src/collector/sort_key/sort_by_erased_type.rs | 42 +++- src/collector/sort_key/sort_by_score.rs | 9 +- .../sort_key/sort_by_static_fast_value.rs | 62 +++-- src/collector/sort_key/sort_by_string.rs | 78 +++--- src/collector/sort_key/sort_key_computer.rs | 224 +++++++++++------- src/collector/sort_key_top_collector.rs | 7 +- src/collector/top_score_collector.rs | 78 +++--- 16 files changed, 478 insertions(+), 352 deletions(-) rename src/collector/top_collector.rs => columnar/src/comparable_doc.rs (100%) diff --git a/columnar/Cargo.toml b/columnar/Cargo.toml index 9eeafe2d0..d968d8e53 100644 --- a/columnar/Cargo.toml +++ b/columnar/Cargo.toml @@ -16,7 +16,7 @@ stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"} sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" } common = { version= "0.10", path = "../common", package = "tantivy-common" } tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" } -serde = "1.0.152" +serde = { version = "1.0.152", features = ["derive"] } downcast-rs = "2.0.1" [dev-dependencies] diff --git a/columnar/src/column/mod.rs b/columnar/src/column/mod.rs index 9954dce01..280b0b4bd 100644 --- a/columnar/src/column/mod.rs +++ b/columnar/src/column/mod.rs @@ -146,13 +146,18 @@ impl Column { // Separate impl block for methods requiring `Default` for `T`. impl Column { /// Load the first value for each docid in the provided slice. + /// + /// The `docids` vector is mutated: documents that do not match the `value_range` are removed. + /// The `values` vector is populated with the values of the remaining documents. #[inline] pub fn first_vals_in_value_range( &self, - docids: &mut Vec, - values: &mut Vec>, + input_docs: &[DocId], + output: &mut Vec, DocId>>, value_range: ValueRange, ) { + // TODO: Move `COLLECT_BLOCK_BUFFER_LEN` to allow for use here, or use a different constant + // in this context. const BLOCK_LEN: usize = 64; // Corresponds to COLLECT_BLOCK_BUFFER_LEN in tantivy's docset match (&self.index, value_range) { (ColumnIndex::Empty { .. }, value_range) => { @@ -163,19 +168,20 @@ impl Column { ValueRange::LessThan(_, nulls_match) => *nulls_match, }; if nulls_match { - for _ in 0..docids.len() { - values.push(None); + for &doc in input_docs { + output.push(crate::ComparableDoc { + doc, + sort_key: None, + }); } - } else { - docids.clear(); } } (ColumnIndex::Full, value_range) => { self.values - .get_vals_in_value_range(docids, values, value_range); + .get_vals_in_value_range(input_docs, output, value_range); } (ColumnIndex::Optional(optional_index), value_range) => { - let len = docids.len(); + let len = input_docs.len(); // Ensure the input docids length does not exceed BLOCK_LEN for stack allocation // safety. If it does, we might need to handle this with multiple // chunks or fallback to heap. For now, an assert is used to confirm @@ -188,7 +194,7 @@ impl Column { ); let mut input_docs_buffer = [0u32; BLOCK_LEN]; - input_docs_buffer[..len].copy_from_slice(docids); + input_docs_buffer[..len].copy_from_slice(input_docs); let mut dense_row_ids_buffer = [0u32; BLOCK_LEN]; let mut dense_values_buffer = [T::default(); BLOCK_LEN]; @@ -221,9 +227,6 @@ impl Column { }; // Phase 3: Filter and merge results, reconstructing docids and values - docids.clear(); - values.clear(); - let mut dense_values_cursor = 0; for i in 0..len { let original_doc_id = input_docs_buffer[i]; @@ -241,13 +244,17 @@ impl Column { }; if value_matches { - docids.push(original_doc_id); - values.push(Some(val)); + output.push(crate::ComparableDoc { + doc: original_doc_id, + sort_key: Some(val), + }); } } else if nulls_match { // This doc_id was not present in the optional index (null) and nulls match - docids.push(original_doc_id); - values.push(None); + output.push(crate::ComparableDoc { + doc: original_doc_id, + sort_key: None, + }); } } } @@ -258,9 +265,8 @@ impl Column { ValueRange::GreaterThan(_, nulls_match) => *nulls_match, ValueRange::LessThan(_, nulls_match) => *nulls_match, }; - let mut write_head = 0; - for i in 0..docids.len() { - let docid = docids[i]; + for i in 0..input_docs.len() { + let docid = input_docs[i]; let row_range = multivalued_index.range(docid); let is_empty = row_range.start == row_range.end; if !is_empty { @@ -272,17 +278,18 @@ impl Column { ValueRange::LessThan(t, _) => val < *t, }; if matches { - docids[write_head] = docid; - values.push(Some(val)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: docid, + sort_key: Some(val), + }); } } else if nulls_match { - docids[write_head] = docid; - values.push(None); - write_head += 1; + output.push(crate::ComparableDoc { + doc: docid, + sort_key: None, + }); } } - docids.truncate(write_head); } } } diff --git a/columnar/src/column_values/mod.rs b/columnar/src/column_values/mod.rs index eacb83c68..9646c2a81 100644 --- a/columnar/src/column_values/mod.rs +++ b/columnar/src/column_values/mod.rs @@ -116,49 +116,52 @@ pub trait ColumnValues: Send + Sync + DowncastSync { /// The values are filtered by the provided value range. fn get_vals_in_value_range( &self, - indexes: &mut Vec, - output: &mut Vec>, + input_indexes: &[u32], + output: &mut Vec, crate::DocId>>, value_range: ValueRange, ) { - let mut write_head = 0; + let len = input_indexes.len(); let mut read_head = 0; - let len = indexes.len(); match value_range { ValueRange::All => { while read_head + 3 < len { - let idx0 = indexes[read_head]; - let idx1 = indexes[read_head + 1]; - let idx2 = indexes[read_head + 2]; - let idx3 = indexes[read_head + 3]; + let idx0 = input_indexes[read_head]; + let idx1 = input_indexes[read_head + 1]; + let idx2 = input_indexes[read_head + 2]; + let idx3 = input_indexes[read_head + 3]; let val0 = self.get_val(idx0); let val1 = self.get_val(idx1); let val2 = self.get_val(idx2); let val3 = self.get_val(idx3); - indexes[write_head] = idx0; - output.push(Some(val0)); - write_head += 1; - indexes[write_head] = idx1; - output.push(Some(val1)); - write_head += 1; - indexes[write_head] = idx2; - output.push(Some(val2)); - write_head += 1; - indexes[write_head] = idx3; - output.push(Some(val3)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx0, + sort_key: Some(val0), + }); + output.push(crate::ComparableDoc { + doc: idx1, + sort_key: Some(val1), + }); + output.push(crate::ComparableDoc { + doc: idx2, + sort_key: Some(val2), + }); + output.push(crate::ComparableDoc { + doc: idx3, + sort_key: Some(val3), + }); read_head += 4; } } ValueRange::Inclusive(ref range) => { while read_head + 3 < len { - let idx0 = indexes[read_head]; - let idx1 = indexes[read_head + 1]; - let idx2 = indexes[read_head + 2]; - let idx3 = indexes[read_head + 3]; + let idx0 = input_indexes[read_head]; + let idx1 = input_indexes[read_head + 1]; + let idx2 = input_indexes[read_head + 2]; + let idx3 = input_indexes[read_head + 3]; let val0 = self.get_val(idx0); let val1 = self.get_val(idx1); @@ -166,24 +169,28 @@ pub trait ColumnValues: Send + Sync + DowncastSync { let val3 = self.get_val(idx3); if range.contains(&val0) { - indexes[write_head] = idx0; - output.push(Some(val0)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx0, + sort_key: Some(val0), + }); } if range.contains(&val1) { - indexes[write_head] = idx1; - output.push(Some(val1)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx1, + sort_key: Some(val1), + }); } if range.contains(&val2) { - indexes[write_head] = idx2; - output.push(Some(val2)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx2, + sort_key: Some(val2), + }); } if range.contains(&val3) { - indexes[write_head] = idx3; - output.push(Some(val3)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx3, + sort_key: Some(val3), + }); } read_head += 4; @@ -191,10 +198,10 @@ pub trait ColumnValues: Send + Sync + DowncastSync { } ValueRange::GreaterThan(ref threshold, _) => { while read_head + 3 < len { - let idx0 = indexes[read_head]; - let idx1 = indexes[read_head + 1]; - let idx2 = indexes[read_head + 2]; - let idx3 = indexes[read_head + 3]; + let idx0 = input_indexes[read_head]; + let idx1 = input_indexes[read_head + 1]; + let idx2 = input_indexes[read_head + 2]; + let idx3 = input_indexes[read_head + 3]; let val0 = self.get_val(idx0); let val1 = self.get_val(idx1); @@ -202,24 +209,28 @@ pub trait ColumnValues: Send + Sync + DowncastSync { let val3 = self.get_val(idx3); if val0 > *threshold { - indexes[write_head] = idx0; - output.push(Some(val0)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx0, + sort_key: Some(val0), + }); } if val1 > *threshold { - indexes[write_head] = idx1; - output.push(Some(val1)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx1, + sort_key: Some(val1), + }); } if val2 > *threshold { - indexes[write_head] = idx2; - output.push(Some(val2)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx2, + sort_key: Some(val2), + }); } if val3 > *threshold { - indexes[write_head] = idx3; - output.push(Some(val3)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx3, + sort_key: Some(val3), + }); } read_head += 4; @@ -227,10 +238,10 @@ pub trait ColumnValues: Send + Sync + DowncastSync { } ValueRange::LessThan(ref threshold, _) => { while read_head + 3 < len { - let idx0 = indexes[read_head]; - let idx1 = indexes[read_head + 1]; - let idx2 = indexes[read_head + 2]; - let idx3 = indexes[read_head + 3]; + let idx0 = input_indexes[read_head]; + let idx1 = input_indexes[read_head + 1]; + let idx2 = input_indexes[read_head + 2]; + let idx3 = input_indexes[read_head + 3]; let val0 = self.get_val(idx0); let val1 = self.get_val(idx1); @@ -238,24 +249,28 @@ pub trait ColumnValues: Send + Sync + DowncastSync { let val3 = self.get_val(idx3); if val0 < *threshold { - indexes[write_head] = idx0; - output.push(Some(val0)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx0, + sort_key: Some(val0), + }); } if val1 < *threshold { - indexes[write_head] = idx1; - output.push(Some(val1)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx1, + sort_key: Some(val1), + }); } if val2 < *threshold { - indexes[write_head] = idx2; - output.push(Some(val2)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx2, + sort_key: Some(val2), + }); } if val3 < *threshold { - indexes[write_head] = idx3; - output.push(Some(val3)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx3, + sort_key: Some(val3), + }); } read_head += 4; @@ -264,7 +279,7 @@ pub trait ColumnValues: Send + Sync + DowncastSync { } // Process remaining elements (0 to 3) while read_head < len { - let idx = indexes[read_head]; + let idx = input_indexes[read_head]; let val = self.get_val(idx); let matches = match value_range { // 'value_range' is still moved here. This is the outer `value_range` @@ -274,13 +289,13 @@ pub trait ColumnValues: Send + Sync + DowncastSync { ValueRange::LessThan(ref t, _) => val < *t, }; if matches { - indexes[write_head] = idx; - output.push(Some(val)); - write_head += 1; + output.push(crate::ComparableDoc { + doc: idx, + sort_key: Some(val), + }); } read_head += 1; } - indexes.truncate(write_head); } /// Fills an output buffer with the fast field values @@ -393,11 +408,11 @@ impl ColumnValues for EmptyColumnValues { fn get_vals_in_value_range( &self, - indexes: &mut Vec, - output: &mut Vec>, + input_indexes: &[u32], + output: &mut Vec, crate::DocId>>, value_range: ValueRange, ) { - let _ = (indexes, output, value_range); + let _ = (input_indexes, output, value_range); panic!("Internal Error: Called get_vals_in_value_range of empty column.") } } @@ -416,12 +431,12 @@ impl ColumnValues for Arc, - output: &mut Vec>, + input_indexes: &[u32], + output: &mut Vec, crate::DocId>>, value_range: ValueRange, ) { self.as_ref() - .get_vals_in_value_range(indexes, output, value_range) + .get_vals_in_value_range(input_indexes, output, value_range) } #[inline(always)] diff --git a/columnar/src/column_values/u64_based/bitpacked.rs b/columnar/src/column_values/u64_based/bitpacked.rs index 647728863..cc15a98c8 100644 --- a/columnar/src/column_values/u64_based/bitpacked.rs +++ b/columnar/src/column_values/u64_based/bitpacked.rs @@ -69,67 +69,68 @@ impl ColumnValues for BitpackedReader { fn get_vals_in_value_range( &self, - indexes: &mut Vec, - output: &mut Vec>, + input_indexes: &[u32], + output: &mut Vec, crate::DocId>>, value_range: ValueRange, ) { - let mut write_head = 0; match value_range { ValueRange::All => { - for i in 0..indexes.len() { - let idx = indexes[i]; - indexes[write_head] = idx; - output.push(Some(self.get_val(idx))); - write_head += 1; + for &idx in input_indexes { + output.push(crate::ComparableDoc { + doc: idx, + sort_key: Some(self.get_val(idx)), + }); } } ValueRange::Inclusive(range) => { if let Some(transformed_range) = transform_range_before_linear_transformation(&self.stats, range) { - for i in 0..indexes.len() { - let doc = indexes[i]; + for &doc in input_indexes { let raw_val = self.get_val(doc); if transformed_range.contains(&raw_val) { - indexes[write_head] = doc; - output - .push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val)); - write_head += 1; + output.push(crate::ComparableDoc { + doc, + sort_key: Some( + self.stats.min_value + self.stats.gcd.get() * raw_val, + ), + }); } } } } ValueRange::GreaterThan(threshold, _) => { if threshold < self.stats.min_value { - for i in 0..indexes.len() { - let idx = indexes[i]; - indexes[write_head] = idx; - output.push(Some(self.get_val(idx))); - write_head += 1; + for &idx in input_indexes { + output.push(crate::ComparableDoc { + doc: idx, + sort_key: Some(self.get_val(idx)), + }); } } else if threshold >= self.stats.max_value { // All filtered out } else { let raw_threshold = (threshold - self.stats.min_value) / self.stats.gcd.get(); - for i in 0..indexes.len() { - let doc = indexes[i]; + for &doc in input_indexes { let raw_val = self.get_val(doc); if raw_val > raw_threshold { - indexes[write_head] = doc; - output - .push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val)); - write_head += 1; + output.push(crate::ComparableDoc { + doc, + sort_key: Some( + self.stats.min_value + self.stats.gcd.get() * raw_val, + ), + }); } } } } ValueRange::LessThan(threshold, _) => { if threshold > self.stats.max_value { - for i in 0..indexes.len() { - let idx = indexes[i]; - indexes[write_head] = idx; - output.push(Some(self.get_val(idx))); - write_head += 1; + for &idx in input_indexes { + output.push(crate::ComparableDoc { + doc: idx, + sort_key: Some(self.get_val(idx)), + }); } } else if threshold <= self.stats.min_value { // All filtered out @@ -142,20 +143,20 @@ impl ColumnValues for BitpackedReader { diff / gcd + 1 }; - for i in 0..indexes.len() { - let doc = indexes[i]; + for &doc in input_indexes { let raw_val = self.get_val(doc); if raw_val < raw_threshold { - indexes[write_head] = doc; - output - .push(Some(self.stats.min_value + self.stats.gcd.get() * raw_val)); - write_head += 1; + output.push(crate::ComparableDoc { + doc, + sort_key: Some( + self.stats.min_value + self.stats.gcd.get() * raw_val, + ), + }); } } } } } - indexes.truncate(write_head); } fn get_row_ids_for_value_range( &self, diff --git a/src/collector/top_collector.rs b/columnar/src/comparable_doc.rs similarity index 100% rename from src/collector/top_collector.rs rename to columnar/src/comparable_doc.rs diff --git a/columnar/src/lib.rs b/columnar/src/lib.rs index ce499f72b..290db558e 100644 --- a/columnar/src/lib.rs +++ b/columnar/src/lib.rs @@ -29,6 +29,7 @@ mod column; pub mod column_index; pub mod column_values; mod columnar; +mod comparable_doc; mod dictionary; mod dynamic_column; mod iterable; @@ -45,6 +46,7 @@ pub use columnar::{ CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType, MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar, }; +pub use comparable_doc::ComparableDoc; use sstable::VoidSSTable; pub use value::{NumericalType, NumericalValue}; diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 0f8360d8d..ec65fe71e 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -96,10 +96,9 @@ mod histogram_collector; pub use histogram_collector::HistogramCollector; mod multi_collector; -pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit}; +pub use columnar::ComparableDoc; -mod top_collector; -pub use self::top_collector::ComparableDoc; +pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit}; mod top_score_collector; pub use self::top_score_collector::{TopDocs, TopNComputer}; diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index 31c6e7049..8bcf22321 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -39,6 +39,7 @@ pub(crate) mod tests { use crate::collector::sort_key::{ SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString, }; + use crate::collector::top_score_collector::compare_for_top_k; use crate::collector::{ComparableDoc, DocSetCollector, TopDocs}; use crate::indexer::NoMergePolicy; use crate::query::{AllQuery, QueryParser}; @@ -594,7 +595,7 @@ pub(crate) mod tests { .map(|(sort_key, doc)| ComparableDoc { sort_key, doc }) .collect(); - comparable_docs.sort_by(|l, r| comparator.compare_doc(l, r)); + comparable_docs.sort_by(|l, r| compare_for_top_k(&comparator, l, r)); let expected_results = comparable_docs .into_iter() diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 178c9a07e..6fa2d482e 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; -use columnar::{MonotonicallyMappableToU64, ValueRange}; +use columnar::{ComparableDoc, MonotonicallyMappableToU64, ValueRange}; use serde::{Deserialize, Serialize}; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; @@ -619,6 +619,7 @@ where type SortKey = TSegmentSortKeyComputer::SortKey; type SegmentSortKey = TSegmentSortKey; type SegmentComparator = TComparator; + type Buffer = TSegmentSortKeyComputer::Buffer; fn segment_comparator(&self) -> Self::SegmentComparator { self.comparator.clone() @@ -630,11 +631,13 @@ where fn segment_sort_keys( &mut self, - docs: &[DocId], + input_docs: &[DocId], + output: &mut Vec>, + buffer: &mut Self::Buffer, filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + ) { self.segment_sort_key_computer - .segment_sort_keys(docs, filter) + .segment_sort_keys(input_docs, output, buffer, filter) } #[inline(always)] diff --git a/src/collector/sort_key/sort_by_erased_type.rs b/src/collector/sort_key/sort_by_erased_type.rs index 8bdbe8057..42a73ceaf 100644 --- a/src/collector/sort_key/sort_by_erased_type.rs +++ b/src/collector/sort_key/sort_by_erased_type.rs @@ -4,7 +4,7 @@ use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentCompu use crate::collector::sort_key::{ NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString, }; -use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer}; use crate::fastfield::FastFieldNotAvailableError; use crate::schema::OwnedValue; use crate::{DateTime, DocId, Score}; @@ -39,15 +39,21 @@ trait ErasedSegmentSortKeyComputer: Send + Sync { fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option; fn segment_sort_keys( &mut self, - docs: &[DocId], + input_docs: &[DocId], + output: &mut Vec, DocId>>, filter: ValueRange>, - ) -> &mut Vec<(DocId, Option)>; + ); fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue; } -struct ErasedSegmentSortKeyComputerWrapper { +struct ErasedSegmentSortKeyComputerWrapper +where + C: SegmentSortKeyComputer> + Send + Sync, + F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static, +{ inner: C, converter: F, + buffer: C::Buffer, } impl ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper @@ -61,10 +67,12 @@ where fn segment_sort_keys( &mut self, - docs: &[DocId], + input_docs: &[DocId], + output: &mut Vec, DocId>>, filter: ValueRange>, - ) -> &mut Vec<(DocId, Option)> { - self.inner.segment_sort_keys(docs, filter) + ) { + self.inner + .segment_sort_keys(input_docs, output, &mut self.buffer, filter) } fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { @@ -85,9 +93,10 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { fn segment_sort_keys( &mut self, - _docs: &[DocId], + _input_docs: &[DocId], + _output: &mut Vec, DocId>>, _filter: ValueRange>, - ) -> &mut Vec<(DocId, Option)> { + ) { unimplemented!("Batch computation not supported for score sorting") } @@ -134,6 +143,7 @@ impl SortKeyComputer for SortByErasedType { converter: |val: Option| { val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null) }, + buffer: Default::default(), }) } ColumnType::U64 => { @@ -144,6 +154,7 @@ impl SortKeyComputer for SortByErasedType { converter: |val: Option| { val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null) }, + buffer: Default::default(), }) } ColumnType::I64 => { @@ -154,6 +165,7 @@ impl SortKeyComputer for SortByErasedType { converter: |val: Option| { val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null) }, + buffer: Default::default(), }) } ColumnType::F64 => { @@ -164,6 +176,7 @@ impl SortKeyComputer for SortByErasedType { converter: |val: Option| { val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null) }, + buffer: Default::default(), }) } ColumnType::Bool => { @@ -174,6 +187,7 @@ impl SortKeyComputer for SortByErasedType { converter: |val: Option| { val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null) }, + buffer: Default::default(), }) } ColumnType::DateTime => { @@ -184,6 +198,7 @@ impl SortKeyComputer for SortByErasedType { converter: |val: Option| { val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null) }, + buffer: Default::default(), }) } column_type => { @@ -212,6 +227,7 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer { type SortKey = OwnedValue; type SegmentSortKey = Option; type SegmentComparator = NaturalComparator; + type Buffer = (); #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { @@ -220,10 +236,12 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer { fn segment_sort_keys( &mut self, - docs: &[DocId], + input_docs: &[DocId], + output: &mut Vec>, + _buffer: &mut Self::Buffer, filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { - self.inner.segment_sort_keys(docs, filter) + ) { + self.inner.segment_sort_keys(input_docs, output, filter) } fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue { diff --git a/src/collector/sort_key/sort_by_score.rs b/src/collector/sort_key/sort_by_score.rs index 4da9643df..756a74470 100644 --- a/src/collector/sort_key/sort_by_score.rs +++ b/src/collector/sort_key/sort_by_score.rs @@ -1,7 +1,7 @@ use columnar::ValueRange; use crate::collector::sort_key::NaturalComparator; -use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer}; +use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer, TopNComputer}; use crate::{DocAddress, DocId, Score}; /// Sort by similarity score. @@ -69,6 +69,7 @@ impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer { type SortKey = Score; type SegmentSortKey = Score; type SegmentComparator = NaturalComparator; + type Buffer = (); #[inline(always)] fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score { @@ -77,9 +78,11 @@ impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer { fn segment_sort_keys( &mut self, - _docs: &[DocId], + _input_docs: &[DocId], + _output: &mut Vec>, + _buffer: &mut Self::Buffer, _filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + ) { unimplemented!("Batch computation not supported for score sorting") } diff --git a/src/collector/sort_key/sort_by_static_fast_value.rs b/src/collector/sort_key/sort_by_static_fast_value.rs index 01ddc45d8..08f2ea521 100644 --- a/src/collector/sort_key/sort_by_static_fast_value.rs +++ b/src/collector/sort_key/sort_by_static_fast_value.rs @@ -4,7 +4,7 @@ use columnar::{Column, ValueRange}; use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range; use crate::collector::sort_key::NaturalComparator; -use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer}; use crate::fastfield::{FastFieldNotAvailableError, FastValue}; use crate::{DocId, Score, SegmentReader}; @@ -72,9 +72,6 @@ impl SortKeyComputer for SortByStaticFastValue { Ok(SortByFastValueSegmentSortKeyComputer { sort_column, typ: PhantomData, - buffer: Vec::new(), - fetch_buffer: Vec::new(), - doc_buffer: Vec::new(), }) } } @@ -82,15 +79,13 @@ impl SortKeyComputer for SortByStaticFastValue { pub struct SortByFastValueSegmentSortKeyComputer { sort_column: Column, typ: PhantomData, - buffer: Vec<(DocId, Option)>, - fetch_buffer: Vec>, - doc_buffer: Vec, } impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer { type SortKey = Option; type SegmentSortKey = Option; type SegmentComparator = NaturalComparator; + type Buffer = (); #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey { @@ -99,24 +94,14 @@ impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu fn segment_sort_keys( &mut self, - docs: &[DocId], + input_docs: &[DocId], + output: &mut Vec>, + _buffer: &mut Self::Buffer, filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { - self.doc_buffer.clear(); - self.doc_buffer.extend_from_slice(docs); - self.fetch_buffer.clear(); + ) { let u64_filter = convert_optional_u64_range_to_u64_range(filter); - self.sort_column.first_vals_in_value_range( - &mut self.doc_buffer, - &mut self.fetch_buffer, - u64_filter, - ); - - self.buffer.clear(); - for (&doc, &val) in self.doc_buffer.iter().zip(self.fetch_buffer.iter()) { - self.buffer.push((doc, val)); - } - &mut self.buffer + self.sort_column + .first_vals_in_value_range(input_docs, output, u64_filter); } fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { @@ -154,10 +139,16 @@ mod tests { let sorter = SortByStaticFastValue::::for_field("field"); let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); - let docs = vec![0, 1, 2]; - let output = computer.segment_sort_keys(&docs, ValueRange::All); + let mut docs = vec![0, 1, 2]; + let mut output = Vec::new(); + let mut buffer = (); + computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All); - assert_eq!(output, &[(0, Some(10)), (1, Some(20)), (2, None)]); + assert_eq!( + output.iter().map(|c| c.sort_key).collect::>(), + &[Some(10), Some(20), None] + ); + assert_eq!(output.iter().map(|c| c.doc).collect::>(), &[0, 1, 2]); } #[test] @@ -184,15 +175,20 @@ mod tests { let sorter = SortByStaticFastValue::::for_field("field"); let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); - let docs = vec![0, 1, 2]; - let output = computer.segment_sort_keys( - &docs, + let mut docs = vec![0, 1, 2]; + let mut output = Vec::new(); + let mut buffer = (); + computer.segment_sort_keys( + &mut docs, + &mut output, + &mut buffer, ValueRange::GreaterThan(Some(15u64), false /* inclusive */), ); - // Should contain only the document with value 20. - // Doc 0 (10) < 15 - // Doc 2 (None) < 15 - assert_eq!(output, &[(1, Some(20))]); + assert_eq!( + output.iter().map(|c| c.sort_key).collect::>(), + &[Some(20)] + ); + assert_eq!(output.iter().map(|c| c.doc).collect::>(), &[1]); } } diff --git a/src/collector/sort_key/sort_by_string.rs b/src/collector/sort_key/sort_by_string.rs index 4582c1e22..8e3a383d7 100644 --- a/src/collector/sort_key/sort_by_string.rs +++ b/src/collector/sort_key/sort_by_string.rs @@ -4,7 +4,7 @@ use crate::collector::sort_key::sort_key_computer::{ convert_optional_u64_range_to_u64_range, range_contains_none, }; use crate::collector::sort_key::NaturalComparator; -use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer}; use crate::termdict::TermOrdinal; use crate::{DocId, Score}; @@ -41,26 +41,19 @@ impl SortKeyComputer for SortByString { segment_reader: &crate::SegmentReader, ) -> crate::Result { let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?; - Ok(ByStringColumnSegmentSortKeyComputer { - str_column_opt, - buffer: Vec::new(), - fetch_buffer: Vec::new(), - doc_buffer: Vec::new(), - }) + Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt }) } } pub struct ByStringColumnSegmentSortKeyComputer { str_column_opt: Option, - buffer: Vec<(DocId, Option)>, - fetch_buffer: Vec>, - doc_buffer: Vec, } impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { type SortKey = Option; type SegmentSortKey = Option; type SegmentComparator = NaturalComparator; + type Buffer = (); #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option { @@ -70,33 +63,24 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { fn segment_sort_keys( &mut self, - docs: &[DocId], + input_docs: &[DocId], + output: &mut Vec>, + _buffer: &mut Self::Buffer, filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { - self.doc_buffer.clear(); - self.doc_buffer.extend_from_slice(docs); - self.fetch_buffer.clear(); - + ) { if let Some(str_column) = &self.str_column_opt { let u64_filter = convert_optional_u64_range_to_u64_range(filter); - str_column.ords().first_vals_in_value_range( - &mut self.doc_buffer, - &mut self.fetch_buffer, - u64_filter, - ); + str_column + .ords() + .first_vals_in_value_range(input_docs, output, u64_filter); } else if range_contains_none(&filter) { - for _ in 0..docs.len() { - self.fetch_buffer.push(None); + for &doc in input_docs { + output.push(ComparableDoc { + doc, + sort_key: None, + }); } - } else { - self.doc_buffer.clear(); } - - self.buffer.clear(); - for (&doc, &val) in self.doc_buffer.iter().zip(self.fetch_buffer.iter()) { - self.buffer.push((doc, val)); - } - &mut self.buffer } fn convert_segment_sort_key(&self, term_ord_opt: Option) -> Option { @@ -143,13 +127,16 @@ mod tests { let sorter = SortByString::for_field("field"); let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); - let docs = vec![0, 1, 2]; - let output = computer.segment_sort_keys(&docs, ValueRange::All); + let mut docs = vec![0, 1, 2]; + let mut output = Vec::new(); + let mut buffer = (); + computer.segment_sort_keys(&mut docs, &mut output, &mut buffer, ValueRange::All); - // We expect ordinals. - // "a" -> 0 - // "c" -> 1 - assert_eq!(output, &[(0, Some(0)), (1, Some(1)), (2, None)]); + assert_eq!( + output.iter().map(|c| c.sort_key).collect::>(), + &[Some(0), Some(1), None] + ); + assert_eq!(output.iter().map(|c| c.doc).collect::>(), &[0, 1, 2]); } #[test] @@ -176,16 +163,23 @@ mod tests { let sorter = SortByString::for_field("field"); let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); - let docs = vec![0, 1, 2]; + let mut docs = vec![0, 1, 2]; + let mut output = Vec::new(); // Filter: > "b". "a" is 0, "c" is 1. // We want > "a" (ord 0). So we filter > ord 0. // 0 is "a", 1 is "c". - let output = computer.segment_sort_keys( - &docs, + let mut buffer = (); + computer.segment_sort_keys( + &mut docs, + &mut output, + &mut buffer, ValueRange::GreaterThan(Some(0), false /* inclusive */), ); - // Should contain only the document with value "c" (ord 1). - assert_eq!(output, &[(1, Some(1))]); + assert_eq!( + output.iter().map(|c| c.sort_key).collect::>(), + &[Some(1)] + ); + assert_eq!(output.iter().map(|c| c.doc).collect::>(), &[1]); } } diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index d21de75aa..875448f51 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -28,6 +28,9 @@ pub trait SegmentSortKeyComputer: 'static { /// Comparator type. type SegmentComparator: Comparator + Clone + 'static; + /// Buffer type used for scratch space. + type Buffer: Default + Send + Sync + 'static; + /// Returns the segment sort key comparator. fn segment_comparator(&self) -> Self::SegmentComparator { Self::SegmentComparator::default() @@ -38,13 +41,15 @@ pub trait SegmentSortKeyComputer: 'static { /// Computes the sort keys for a batch of documents. /// - /// The computed sort keys are stored in an internal buffer and returned as a slice. - /// Subsequent calls to this method may reuse and overwrite the internal buffer. + /// The computed sort keys and document IDs are pushed into the `output` vector. + /// The `buffer` is used for scratch space. fn segment_sort_keys( &mut self, - docs: &[DocId], + input_docs: &[DocId], + output: &mut Vec>, + buffer: &mut Self::Buffer, filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)>; + ); /// Computes the sort key and pushes the document in a TopN Computer. /// @@ -54,7 +59,7 @@ pub trait SegmentSortKeyComputer: 'static { &mut self, doc: DocId, score: Score, - top_n_computer: &mut TopNComputer, + top_n_computer: &mut TopNComputer, ) { let sort_key = self.segment_sort_key(doc, score); top_n_computer.push(sort_key, doc); @@ -63,7 +68,7 @@ pub trait SegmentSortKeyComputer: 'static { fn compute_sort_keys_and_collect>( &mut self, docs: &[DocId], - top_n_computer: &mut TopNComputer, + top_n_computer: &mut TopNComputer, ) { // The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we // should always be able to `reserve` space for the entire block. @@ -76,24 +81,8 @@ pub trait SegmentSortKeyComputer: 'static { ValueRange::All }; - let sort_keys = self.segment_sort_keys(docs, value_range); - - if let Some(threshold) = &top_n_computer.threshold { - let threshold = threshold.clone(); - for (doc, sort_key) in sort_keys.drain(..) { - let cmp = comparator.compare(&sort_key, &threshold); - if cmp == Ordering::Greater { - // We validated at the top of the method that we have capacity. - top_n_computer.append_doc_unchecked(doc, sort_key); - } - } - } else { - // Eagerly push, without a threshold to compare to. - for (doc, sort_key) in sort_keys.drain(..) { - // We validated at the top of the method that we have capacity. - top_n_computer.append_doc_unchecked(doc, sort_key); - } - } + let (buffer, scratch) = top_n_computer.buffer_and_scratch(); + self.segment_sort_keys(docs, buffer, scratch, value_range); } /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on @@ -213,8 +202,6 @@ where Ok(ChainSegmentSortKeyComputer { head: self.0.segment_sort_key_computer(segment_reader)?, tail: self.1.segment_sort_key_computer(segment_reader)?, - head_key_buffer: Vec::new(), - doc_buffer: Vec::new(), }) } @@ -240,8 +227,28 @@ where { head: Head, tail: Tail, - head_key_buffer: Vec, - doc_buffer: Vec, +} + +pub struct ChainBuffer { + pub head: HeadBuffer, + pub tail: TailBuffer, + pub head_output: Vec>, + pub tail_output: Vec>, + pub tail_input_docs: Vec, +} + +impl Default + for ChainBuffer +{ + fn default() -> Self { + ChainBuffer { + head: HeadBuffer::default(), + tail: TailBuffer::default(), + head_output: Vec::new(), + tail_output: Vec::new(), + tail_input_docs: Vec::new(), + } + } } impl ChainSegmentSortKeyComputer @@ -289,6 +296,9 @@ where type SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator); + type Buffer = + ChainBuffer; + fn segment_comparator(&self) -> Self::SegmentComparator { ( self.head.segment_comparator(), @@ -313,9 +323,11 @@ where fn segment_sort_keys( &mut self, - _docs: &[DocId], + _input_docs: &[DocId], + _output: &mut Vec>, + _buffer: &mut Self::Buffer, _filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + ) { unimplemented!("The head and the tail are accessed independently."); } @@ -324,7 +336,7 @@ where &mut self, doc: DocId, score: Score, - top_n_computer: &mut TopNComputer, + top_n_computer: &mut TopNComputer, ) { let sort_key: Self::SegmentSortKey; if let Some(threshold) = &top_n_computer.threshold { @@ -342,39 +354,52 @@ where fn compute_sort_keys_and_collect>( &mut self, docs: &[DocId], - top_n_computer: &mut TopNComputer, + top_n_computer: &mut TopNComputer, ) { // The capacity of a TopNComputer is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so we // should always be able to `reserve` space for the entire block. top_n_computer.reserve(docs.len()); + let mut scratch = std::mem::take(&mut top_n_computer.scratch); + if let Some(threshold) = &top_n_computer.threshold { let (head_threshold, tail_threshold) = threshold.clone(); let head_cmp = self.head.segment_comparator(); let tail_cmp = self.tail.segment_comparator(); let head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone()); - let head_keys = self.head.segment_sort_keys(docs, head_filter); - self.doc_buffer.clear(); - self.head_key_buffer.clear(); - for (doc, head_key) in head_keys.drain(..) { - let cmp = head_cmp.compare(&head_key, &head_threshold); - if cmp != Ordering::Less { - self.doc_buffer.push(doc); - self.head_key_buffer.push(head_key); - } - } + scratch.head_output.clear(); + self.head.segment_sort_keys( + docs, + &mut scratch.head_output, + &mut scratch.head, + head_filter, + ); - if !self.doc_buffer.is_empty() { - let tail_keys = self - .tail - .segment_sort_keys(&self.doc_buffer, ValueRange::All); - for ((head_key, tail_key), &doc) in self - .head_key_buffer + if !scratch.head_output.is_empty() { + scratch.tail_output.clear(); + scratch.tail_input_docs.clear(); + for cd in &scratch.head_output { + scratch.tail_input_docs.push(cd.doc); + } + + self.tail.segment_sort_keys( + &scratch.tail_input_docs, + &mut scratch.tail_output, + &mut scratch.tail, + ValueRange::All, + ); + + for (head_doc, tail_doc) in scratch + .head_output .drain(..) - .zip(tail_keys.drain(..).map(|(_, k)| k)) - .zip(self.doc_buffer.iter()) + .zip(scratch.tail_output.drain(..)) { + debug_assert_eq!(head_doc.doc, tail_doc.doc); + let doc = head_doc.doc; + let head_key = head_doc.sort_key; + let tail_key = tail_doc.sort_key; + let head_ord = head_cmp.compare(&head_key, &head_threshold); let ord = if head_ord == Ordering::Equal { tail_cmp.compare(&tail_key, &tail_threshold) @@ -382,19 +407,60 @@ where head_ord }; if ord == Ordering::Greater { - top_n_computer.append_doc_unchecked(doc, (head_key, tail_key)); + push_assuming_capacity( + ComparableDoc { + sort_key: (head_key, tail_key), + doc, + }, + top_n_computer.buffer(), + ); } } } } else { // Eagerly push, without a threshold to compare to. - let head_keys = self.head.segment_sort_keys(docs, ValueRange::All); - let tail_keys = self.tail.segment_sort_keys(docs, ValueRange::All); - for ((doc, head_key), (_, tail_key)) in head_keys.drain(..).zip(tail_keys.drain(..)) { + scratch.head_output.clear(); + self.head.segment_sort_keys( + docs, + &mut scratch.head_output, + &mut scratch.head, + ValueRange::All, + ); + + scratch.tail_output.clear(); + scratch.tail_input_docs.clear(); + for cd in &scratch.head_output { + scratch.tail_input_docs.push(cd.doc); + } + + self.tail.segment_sort_keys( + &scratch.tail_input_docs, + &mut scratch.tail_output, + &mut scratch.tail, + ValueRange::All, + ); + + for (head_doc, tail_doc) in scratch + .head_output + .drain(..) + .zip(scratch.tail_output.drain(..)) + { + debug_assert_eq!(head_doc.doc, tail_doc.doc); + let doc = head_doc.doc; + let head_key = head_doc.sort_key; + let tail_key = tail_doc.sort_key; + // We validated at the top of the method that we have capacity. - top_n_computer.append_doc_unchecked(doc, (head_key, tail_key)); + push_assuming_capacity( + ComparableDoc { + sort_key: (head_key, tail_key), + doc, + }, + top_n_computer.buffer(), + ); } } + top_n_computer.scratch = scratch; } #[inline(always)] @@ -430,6 +496,7 @@ where type SortKey = NewScore; type SegmentSortKey = T::SegmentSortKey; type SegmentComparator = T::SegmentComparator; + type Buffer = T::Buffer; fn segment_comparator(&self) -> Self::SegmentComparator { self.sort_key_computer.segment_comparator() @@ -441,11 +508,13 @@ where fn segment_sort_keys( &mut self, - docs: &[DocId], - _filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + input_docs: &[DocId], + output: &mut Vec>, + buffer: &mut Self::Buffer, + filter: ValueRange, + ) { self.sort_key_computer - .segment_sort_keys(docs, ValueRange::All) + .segment_sort_keys(input_docs, output, buffer, filter) } #[inline(always)] @@ -453,7 +522,7 @@ where &mut self, doc: DocId, score: Score, - top_n_computer: &mut TopNComputer, + top_n_computer: &mut TopNComputer, ) { self.sort_key_computer .compute_sort_key_and_collect(doc, score, top_n_computer); @@ -462,7 +531,7 @@ where fn compute_sort_keys_and_collect>( &mut self, docs: &[DocId], - top_n_computer: &mut TopNComputer, + top_n_computer: &mut TopNComputer, ) { self.sort_key_computer .compute_sort_keys_and_collect(docs, top_n_computer); @@ -521,11 +590,7 @@ where tail: ChainSegmentSortKeyComputer { head: sort_key_computer2, tail: sort_key_computer3, - head_key_buffer: Vec::new(), - doc_buffer: Vec::new(), }, - head_key_buffer: Vec::new(), - doc_buffer: Vec::new(), }, map, }) @@ -589,14 +654,8 @@ where tail: ChainSegmentSortKeyComputer { head: sort_key_computer3, tail: sort_key_computer4, - head_key_buffer: Vec::new(), - doc_buffer: Vec::new(), }, - head_key_buffer: Vec::new(), - doc_buffer: Vec::new(), }, - head_key_buffer: Vec::new(), - doc_buffer: Vec::new(), }, map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| { (sort_key1, sort_key2, sort_key3, sort_key4) @@ -620,9 +679,11 @@ where } } +use std::marker::PhantomData; + pub struct FuncSegmentSortKeyComputer { func: F, - buffer: Vec<(DocId, TSortKey)>, + _phantom: PhantomData, } impl SortKeyComputer for F @@ -638,7 +699,7 @@ where fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { Ok(FuncSegmentSortKeyComputer { func: (self)(segment_reader), - buffer: Vec::new(), + _phantom: PhantomData, }) } } @@ -651,6 +712,7 @@ where type SortKey = TSortKey; type SegmentSortKey = TSortKey; type SegmentComparator = NaturalComparator; + type Buffer = (); fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { (self.func)(doc) @@ -658,15 +720,17 @@ where fn segment_sort_keys( &mut self, - docs: &[DocId], + input_docs: &[DocId], + output: &mut Vec>, + _buffer: &mut Self::Buffer, _filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { - self.buffer.clear(); - self.buffer.reserve(docs.len()); - for &doc in docs { - self.buffer.push((doc, (self.func)(doc))); + ) { + for &doc in input_docs { + output.push(ComparableDoc { + sort_key: (self.func)(doc), + doc, + }); } - &mut self.buffer } /// Convert a segment level score into the global level score. diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs index 457854620..2287891b6 100644 --- a/src/collector/sort_key_top_collector.rs +++ b/src/collector/sort_key_top_collector.rs @@ -99,7 +99,12 @@ where TSegmentSortKeyComputer: SegmentSortKeyComputer, C: Comparator, { - pub(crate) topn_computer: TopNComputer, + pub(crate) topn_computer: TopNComputer< + TSegmentSortKeyComputer::SegmentSortKey, + DocId, + C, + TSegmentSortKeyComputer::Buffer, + >, pub(crate) segment_ord: u32, pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer, } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 5ce928c47..e03edd0ff 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -11,8 +11,7 @@ use crate::collector::sort_key::{ SortByStaticFastValue, SortByString, }; use crate::collector::sort_key_top_collector::TopBySortKeyCollector; -use crate::collector::top_collector::ComparableDoc; -use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer}; use crate::fastfield::FastValue; use crate::{DocAddress, DocId, Order, Score, SegmentReader}; @@ -482,6 +481,7 @@ where type SortKey = TSortKey; type SegmentSortKey = TSortKey; type SegmentComparator = NaturalComparator; + type Buffer = (); fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { (self.sort_key_fn)(doc, score) @@ -489,9 +489,11 @@ where fn segment_sort_keys( &mut self, - _docs: &[DocId], + _input_docs: &[DocId], + _output: &mut Vec>, + _buffer: &mut Self::Buffer, _filter: ValueRange, - ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + ) { unimplemented!("Batch computation is not supported for tweak score.") } @@ -518,12 +520,14 @@ where /// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons. #[derive(Serialize, Deserialize)] #[serde(from = "TopNComputerDeser")] -pub struct TopNComputer { +pub struct TopNComputer { /// The buffer reverses sort order to get top-semantics instead of bottom-semantics buffer: Vec>, top_n: usize, pub(crate) threshold: Option, comparator: C, + #[serde(skip)] + pub scratch: Buffer, } // Intermediate struct for TopNComputer for deserialization, to keep vec capacity @@ -535,7 +539,9 @@ struct TopNComputerDeser { comparator: C, } -impl From> for TopNComputer { +impl From> for TopNComputer +where Buffer: Default +{ fn from(mut value: TopNComputerDeser) -> Self { let expected_cap = value.top_n.max(1) * 2; let current_cap = value.buffer.capacity(); @@ -550,12 +556,15 @@ impl From> for TopNComputer std::fmt::Debug for TopNComputer -where C: Comparator +impl std::fmt::Debug for TopNComputer +where + C: Comparator, + Buffer: std::fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result { f.debug_struct("TopNComputer") @@ -563,12 +572,13 @@ where C: Comparator .field("top_n", &self.top_n) .field("current_threshold", &self.threshold) .field("comparator", &self.comparator) + .field("scratch", &self.scratch) .finish() } } // Custom clone to keep capacity -impl Clone for TopNComputer { +impl Clone for TopNComputer { fn clone(&self) -> Self { let mut buffer_clone = Vec::with_capacity(self.buffer.capacity()); buffer_clone.extend(self.buffer.iter().cloned()); @@ -577,11 +587,12 @@ impl Clone for TopNComputer { top_n: self.top_n, threshold: self.threshold.clone(), comparator: self.comparator.clone(), + scratch: self.scratch.clone(), } } } -impl TopNComputer +impl TopNComputer where D: Ord, TSortKey: Clone, @@ -595,7 +606,7 @@ where } #[inline(always)] -fn compare_for_top_k>( +pub fn compare_for_top_k>( c: &C, lhs: &ComparableDoc, rhs: &ComparableDoc, @@ -606,11 +617,12 @@ fn compare_for_top_k>( // sort by doc id } -impl TopNComputer +impl TopNComputer where D: Ord, TSortKey: Clone, C: Comparator, + Buffer: Default, { /// Create a new `TopNComputer`. /// Internally it will allocate a buffer of size `(top_n.max(1) * 2) + @@ -624,6 +636,7 @@ where top_n, threshold: None, comparator, + scratch: Buffer::default(), } } @@ -649,15 +662,6 @@ where pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) { self.reserve(1); // This cannot panic, because we've reserved room for one element. - self.append_doc_unchecked(doc, sort_key); - } - - // Append a document to the top n. `reserve` must already have been called to ensure that there - // is capacity, or this method will panic. - // - // At this point, we need to have established that the doc is above the threshold. - #[inline(always)] - pub(crate) fn append_doc_unchecked(&mut self, doc: D, sort_key: TSortKey) { let comparable_doc = ComparableDoc { doc, sort_key }; push_assuming_capacity(comparable_doc, &mut self.buffer); } @@ -672,6 +676,16 @@ where } } + pub(crate) fn buffer(&mut self) -> &mut Vec> { + &mut self.buffer + } + + pub(crate) fn buffer_and_scratch( + &mut self, + ) -> (&mut Vec>, &mut Buffer) { + (&mut self.buffer, &mut self.scratch) + } + #[inline(never)] fn truncate_top_n(&mut self) -> TSortKey { // Use select_nth_unstable to find the top nth score @@ -729,8 +743,7 @@ mod tests { use super::{TopDocs, TopNComputer}; use crate::collector::sort_key::{ComparatorEnum, NaturalComparator, ReverseComparator}; - use crate::collector::top_collector::ComparableDoc; - use crate::collector::{Collector, DocSetCollector}; + use crate::collector::{Collector, ComparableDoc, DocSetCollector}; use crate::query::{AllQuery, Query, QueryParser}; use crate::schema::{Field, Schema, FAST, STORED, TEXT}; use crate::time::format_description::well_known::Rfc3339; @@ -1760,7 +1773,8 @@ mod tests { #[test] fn test_top_n_computer_not_at_capacity() { - let mut top_n_computer = TopNComputer::new_with_comparator(4, NaturalComparator); + let mut top_n_computer: TopNComputer = + TopNComputer::new_with_comparator(4, NaturalComparator); top_n_computer.append_doc(1, 0.8); top_n_computer.append_doc(3, 0.2); top_n_computer.append_doc(5, 0.3); @@ -1785,7 +1799,8 @@ mod tests { #[test] fn test_top_n_computer_at_capacity() { - let mut top_collector = TopNComputer::new_with_comparator(4, NaturalComparator); + let mut top_collector: TopNComputer = + TopNComputer::new_with_comparator(4, NaturalComparator); top_collector.append_doc(1, 0.8); top_collector.append_doc(3, 0.2); top_collector.append_doc(5, 0.3); @@ -1822,12 +1837,14 @@ mod tests { let doc_ids_collection = [4, 5, 6]; let score = 3.3f32; - let mut top_collector_limit_2 = TopNComputer::new_with_comparator(2, NaturalComparator); + let mut top_collector_limit_2: TopNComputer = + TopNComputer::new_with_comparator(2, NaturalComparator); for id in &doc_ids_collection { top_collector_limit_2.append_doc(*id, score); } - let mut top_collector_limit_3 = TopNComputer::new_with_comparator(3, NaturalComparator); + let mut top_collector_limit_3: TopNComputer = + TopNComputer::new_with_comparator(3, NaturalComparator); for id in &doc_ids_collection { top_collector_limit_3.append_doc(*id, score); } @@ -1848,15 +1865,16 @@ mod bench { #[bench] fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) { - let mut top_collector = TopNComputer::new_with_comparator(100, NaturalComparator); + let mut top_collector: TopNComputer = + TopNComputer::new_with_comparator(100, NaturalComparator); for i in 0..100 { - top_collector.append_doc(i, 0.8); + top_collector.append_doc(i as u32, 0.8); } b.iter(|| { for i in 0..100 { - top_collector.append_doc(i, 0.8); + top_collector.append_doc(i as u32, 0.8); } }); }