From 1afc432df8e2d7c6604e5fc78d00a9a878385486 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Tue, 23 Dec 2025 17:23:10 -0700 Subject: [PATCH] Use an internal buffer in the SegmentSortKeyComputer. --- src/collector/sort_key/order.rs | 10 +- src/collector/sort_key/sort_by_erased_type.rs | 25 ++- src/collector/sort_key/sort_by_score.rs | 12 +- .../sort_key/sort_by_static_fast_value.rs | 15 +- src/collector/sort_key/sort_by_string.rs | 14 +- src/collector/sort_key/sort_key_computer.rs | 212 ++++++++++-------- src/collector/top_score_collector.rs | 4 + 7 files changed, 172 insertions(+), 120 deletions(-) diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 7f9315b3c..a063fa11a 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -506,16 +506,24 @@ impl SegmentSortKeyComput where TSegmentSortKeyComputer: SegmentSortKeyComputer, TSegmentSortKey: Clone + 'static + Sync + Send, - TComparator: Comparator + 'static + Sync + Send, + TComparator: Comparator + Clone + 'static + Sync + Send, { type SortKey = TSegmentSortKeyComputer::SortKey; type SegmentSortKey = TSegmentSortKey; type SegmentComparator = TComparator; + fn segment_comparator(&self) -> Self::SegmentComparator { + self.comparator.clone() + } + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.segment_sort_key_computer.segment_sort_key(doc, score) } + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Self::SegmentSortKey] { + self.segment_sort_key_computer.segment_sort_keys(docs) + } + #[inline(always)] fn compare_segment_sort_key( &self, diff --git a/src/collector/sort_key/sort_by_erased_type.rs b/src/collector/sort_key/sort_by_erased_type.rs index df41069ea..89eed733a 100644 --- a/src/collector/sort_key/sort_by_erased_type.rs +++ b/src/collector/sort_key/sort_by_erased_type.rs @@ -1,5 +1,6 @@ use columnar::{ColumnType, MonotonicallyMappableToU64}; +use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer; use crate::collector::sort_key::{ NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString, }; @@ -36,12 +37,7 @@ impl SortByErasedType { trait ErasedSegmentSortKeyComputer: Send + Sync { fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option; - fn segment_sort_keys(&mut self, docs: &[DocId], output: &mut Vec>) { - output.reserve(docs.len()); - for &doc in docs { - output.push(self.segment_sort_key(doc, 0.0)); - } - } + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Option]; fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue; } @@ -59,8 +55,8 @@ where self.inner.segment_sort_key(doc, score) } - fn segment_sort_keys(&mut self, docs: &[DocId], output: &mut Vec>) { - self.inner.segment_sort_keys(docs, output); + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Option] { + self.inner.segment_sort_keys(docs) } fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { @@ -70,7 +66,7 @@ where } struct ScoreSegmentSortKeyComputer { - segment_computer: SortBySimilarityScore, + segment_computer: SortBySimilarityScoreSegmentComputer, } impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { @@ -79,6 +75,10 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { Some(score_value.to_u64()) } + fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &[Option] { + unimplemented!("Batch computation not supported for score sorting") + } + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { let score_value: u64 = sort_key.expect("This implementation always produces a score."); OwnedValue::F64(f64::from_u64(score_value)) @@ -184,7 +184,8 @@ impl SortKeyComputer for SortByErasedType { } } Self::Score => Box::new(ScoreSegmentSortKeyComputer { - segment_computer: SortBySimilarityScore, + segment_computer: SortBySimilarityScore + .segment_sort_key_computer(segment_reader)?, }), }; Ok(ErasedColumnSegmentSortKeyComputer { inner }) @@ -205,8 +206,8 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer { self.inner.segment_sort_key(doc, score) } - fn segment_sort_keys(&mut self, docs: &[DocId], output: &mut Vec) { - self.inner.segment_sort_keys(docs, output); + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Self::SegmentSortKey] { + self.inner.segment_sort_keys(docs) } fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue { diff --git a/src/collector/sort_key/sort_by_score.rs b/src/collector/sort_key/sort_by_score.rs index a23660e56..a162d70ae 100644 --- a/src/collector/sort_key/sort_by_score.rs +++ b/src/collector/sort_key/sort_by_score.rs @@ -9,7 +9,7 @@ pub struct SortBySimilarityScore; impl SortKeyComputer for SortBySimilarityScore { type SortKey = Score; - type Child = SortBySimilarityScore; + type Child = SortBySimilarityScoreSegmentComputer; type Comparator = NaturalComparator; @@ -21,7 +21,7 @@ impl SortKeyComputer for SortBySimilarityScore { &self, _segment_reader: &crate::SegmentReader, ) -> crate::Result { - Ok(SortBySimilarityScore) + Ok(SortBySimilarityScoreSegmentComputer) } // Sorting by score is special in that it allows for the Block-Wand optimization. @@ -61,7 +61,9 @@ impl SortKeyComputer for SortBySimilarityScore { } } -impl SegmentSortKeyComputer for SortBySimilarityScore { +pub struct SortBySimilarityScoreSegmentComputer; + +impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer { type SortKey = Score; type SegmentSortKey = Score; type SegmentComparator = NaturalComparator; @@ -71,6 +73,10 @@ impl SegmentSortKeyComputer for SortBySimilarityScore { score } + fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &[Self::SegmentSortKey] { + unimplemented!("Batch computation not supported for score sorting") + } + fn convert_segment_sort_key(&self, score: Score) -> Score { score } diff --git a/src/collector/sort_key/sort_by_static_fast_value.rs b/src/collector/sort_key/sort_by_static_fast_value.rs index 3acba3983..fbcfe41ce 100644 --- a/src/collector/sort_key/sort_by_static_fast_value.rs +++ b/src/collector/sort_key/sort_by_static_fast_value.rs @@ -71,6 +71,7 @@ impl SortKeyComputer for SortByStaticFastValue { Ok(SortByFastValueSegmentSortKeyComputer { sort_column, typ: PhantomData, + buffer: Vec::new(), }) } } @@ -78,6 +79,7 @@ impl SortKeyComputer for SortByStaticFastValue { pub struct SortByFastValueSegmentSortKeyComputer { sort_column: Column, typ: PhantomData, + buffer: Vec>, } impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer { @@ -90,10 +92,10 @@ impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu self.sort_column.first(doc) } - fn segment_sort_keys(&mut self, docs: &[DocId], output: &mut Vec) { - let start = output.len(); - output.resize(start + docs.len(), None); - self.sort_column.first_vals(docs, &mut output[start..]); + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Self::SegmentSortKey] { + self.buffer.resize(docs.len(), None); + self.sort_column.first_vals(docs, &mut self.buffer); + &self.buffer } fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { @@ -131,9 +133,8 @@ mod tests { let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); let docs = vec![0, 1]; - let mut output = Vec::new(); - computer.segment_sort_keys(&docs, &mut output); + let output = computer.segment_sort_keys(&docs); - assert_eq!(output, vec![Some(10), Some(20)]); + assert_eq!(output, &[Some(10), Some(20)]); } } diff --git a/src/collector/sort_key/sort_by_string.rs b/src/collector/sort_key/sort_by_string.rs index 08a5161f6..92272eba9 100644 --- a/src/collector/sort_key/sort_by_string.rs +++ b/src/collector/sort_key/sort_by_string.rs @@ -38,12 +38,16 @@ impl SortKeyComputer for SortByString { segment_reader: &crate::SegmentReader, ) -> crate::Result { let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?; - Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt }) + Ok(ByStringColumnSegmentSortKeyComputer { + str_column_opt, + buffer: Vec::new(), + }) } } pub struct ByStringColumnSegmentSortKeyComputer { str_column_opt: Option, + buffer: Vec>, } impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { @@ -57,12 +61,12 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { str_column.ords().first(doc) } - fn segment_sort_keys(&mut self, docs: &[DocId], output: &mut Vec) { - let start = output.len(); - output.resize(start + docs.len(), None); + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Self::SegmentSortKey] { + self.buffer.resize(docs.len(), None); if let Some(str_column) = &self.str_column_opt { - str_column.ords().first_vals(docs, &mut output[start..]); + str_column.ords().first_vals(docs, &mut self.buffer); } + &self.buffer } fn convert_segment_sort_key(&self, term_ord_opt: Option) -> Option { diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index d6e7e3aa6..20fc9330d 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -24,7 +24,7 @@ pub trait SegmentSortKeyComputer: 'static { type SegmentSortKey: 'static + Clone + Send + Sync + Clone; /// Comparator type. - type SegmentComparator: Comparator + 'static; + type SegmentComparator: Comparator + Clone + 'static; /// Returns the segment sort key comparator. fn segment_comparator(&self) -> Self::SegmentComparator { @@ -36,13 +36,9 @@ pub trait SegmentSortKeyComputer: 'static { /// Computes the sort keys for a batch of documents. /// - /// The computed sort keys are appended to the `output` buffer. - fn segment_sort_keys(&mut self, docs: &[DocId], output: &mut Vec) { - output.reserve(docs.len()); - for &doc in docs { - output.push(self.segment_sort_key(doc, 0.0)); - } - } + /// The computed sort keys are stored in an internal buffer and returned as a slice. + /// Subsequent calls to this method may reuse and overwrite the internal buffer. + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Self::SegmentSortKey]; /// Computes the sort key and pushes the document in a TopN Computer. /// @@ -71,23 +67,22 @@ pub trait SegmentSortKeyComputer: 'static { // TODO: Would need to split the borrow of the TopNComputer to avoid cloning the // threshold here. let threshold = threshold.clone(); + let comparator = self.segment_comparator(); + let sort_keys = self.segment_sort_keys(docs); - let mut sort_keys = Vec::with_capacity(docs.len()); - self.segment_sort_keys(docs, &mut sort_keys); for (&doc, sort_key) in docs.iter().zip(sort_keys) { - let cmp = self.compare_segment_sort_key(&sort_key, &threshold); + let cmp = comparator.compare(sort_key, &threshold); if cmp == Ordering::Greater { // We validated at the top of the method that we have capacity. - top_n_computer.append_doc_unchecked(doc, sort_key); + top_n_computer.append_doc_unchecked(doc, sort_key.clone()); } } } else { // Eagerly push, without a threshold to compare to. - let mut sort_keys = Vec::with_capacity(docs.len()); - self.segment_sort_keys(docs, &mut sort_keys); + let sort_keys = self.segment_sort_keys(docs); for (&doc, sort_key) in docs.iter().zip(sort_keys) { // We validated at the top of the method that we have capacity. - top_n_computer.append_doc_unchecked(doc, sort_key); + top_n_computer.append_doc_unchecked(doc, sort_key.clone()); } } } @@ -136,9 +131,10 @@ pub trait SegmentSortKeyComputer: 'static { threshold: &Self::SegmentSortKey, output: &mut Vec>, ) { + let comparator = self.segment_comparator(); for &doc in docs { let sort_key = self.segment_sort_key(doc, 0.0); - let cmp = self.compare_segment_sort_key(&sort_key, threshold); + let cmp = comparator.compare(&sort_key, threshold); if cmp != Ordering::Less { push_assuming_capacity(ComparableDoc { sort_key, doc }, output); } @@ -212,7 +208,8 @@ where TailSortKeyComputer: SortKeyComputer, { type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey); - type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); + type Child = + ChainSegmentSortKeyComputer; type Comparator = ( HeadSortKeyComputer::Comparator, @@ -224,10 +221,11 @@ where } fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - Ok(( - self.0.segment_sort_key_computer(segment_reader)?, - self.1.segment_sort_key_computer(segment_reader)?, - )) + Ok(ChainSegmentSortKeyComputer { + head: self.0.segment_sort_key_computer(segment_reader)?, + tail: self.1.segment_sort_key_computer(segment_reader)?, + buffer: Vec::new(), + }) } /// Checks whether the schema is compatible with the sort key computer. @@ -245,25 +243,32 @@ where } } -impl SegmentSortKeyComputer - for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer) +pub struct ChainSegmentSortKeyComputer where - HeadSegmentSortKeyComputer: SegmentSortKeyComputer, - TailSegmentSortKeyComputer: SegmentSortKeyComputer, + Head: SegmentSortKeyComputer, + Tail: SegmentSortKeyComputer, { - type SortKey = ( - HeadSegmentSortKeyComputer::SortKey, - TailSegmentSortKeyComputer::SortKey, - ); - type SegmentSortKey = ( - HeadSegmentSortKeyComputer::SegmentSortKey, - TailSegmentSortKeyComputer::SegmentSortKey, - ); + head: Head, + tail: Tail, + buffer: Vec<(Head::SegmentSortKey, Tail::SegmentSortKey)>, +} - type SegmentComparator = ( - HeadSegmentSortKeyComputer::SegmentComparator, - TailSegmentSortKeyComputer::SegmentComparator, - ); +impl SegmentSortKeyComputer for ChainSegmentSortKeyComputer +where + Head: SegmentSortKeyComputer, + Tail: SegmentSortKeyComputer, +{ + type SortKey = (Head::SortKey, Tail::SortKey); + type SegmentSortKey = (Head::SegmentSortKey, Tail::SegmentSortKey); + + type SegmentComparator = (Head::SegmentComparator, Tail::SegmentComparator); + + fn segment_comparator(&self) -> Self::SegmentComparator { + ( + self.head.segment_comparator(), + self.tail.segment_comparator(), + ) + } /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on /// its ordering. @@ -275,22 +280,18 @@ where left: &Self::SegmentSortKey, right: &Self::SegmentSortKey, ) -> Ordering { - self.0 + self.head .compare_segment_sort_key(&left.0, &right.0) - .then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1)) + .then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1)) } - fn segment_sort_keys(&mut self, docs: &[DocId], output: &mut Vec) { - let mut head_keys = Vec::with_capacity(docs.len()); - self.0.segment_sort_keys(docs, &mut head_keys); - - let mut tail_keys = Vec::with_capacity(docs.len()); - self.1.segment_sort_keys(docs, &mut tail_keys); - - output.reserve(docs.len()); - for (head, tail) in head_keys.into_iter().zip(tail_keys) { - output.push((head, tail)); - } + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Self::SegmentSortKey] { + let head_keys = self.head.segment_sort_keys(docs); + let tail_keys = self.tail.segment_sort_keys(docs); + self.buffer.clear(); + self.buffer + .extend(head_keys.iter().cloned().zip(tail_keys.iter().cloned())); + &self.buffer } #[inline(always)] @@ -326,31 +327,30 @@ where // TODO: Would need to split the borrow of the TopNComputer to avoid cloning the // threshold here. let threshold = threshold.clone(); + let comparator = self.segment_comparator(); + let sort_keys = self.segment_sort_keys(docs); - let mut sort_keys = Vec::with_capacity(docs.len()); - self.segment_sort_keys(docs, &mut sort_keys); for (&doc, sort_key) in docs.iter().zip(sort_keys) { - let cmp = self.compare_segment_sort_key(&sort_key, &threshold); + let cmp = comparator.compare(sort_key, &threshold); if cmp == Ordering::Greater { // We validated at the top of the method that we have capacity. - top_n_computer.append_doc_unchecked(doc, sort_key); + top_n_computer.append_doc_unchecked(doc, sort_key.clone()); } } } else { // Eagerly push, without a threshold to compare to. - let mut sort_keys = Vec::with_capacity(docs.len()); - self.segment_sort_keys(docs, &mut sort_keys); + let sort_keys = self.segment_sort_keys(docs); for (&doc, sort_key) in docs.iter().zip(sort_keys) { // We validated at the top of the method that we have capacity. - top_n_computer.append_doc_unchecked(doc, sort_key); + top_n_computer.append_doc_unchecked(doc, sort_key.clone()); } } } #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { - let head_sort_key = self.0.segment_sort_key(doc, score); - let tail_sort_key = self.1.segment_sort_key(doc, score); + let head_sort_key = self.head.segment_sort_key(doc, score); + let tail_sort_key = self.tail.segment_sort_key(doc, score); (head_sort_key, tail_sort_key) } @@ -362,13 +362,15 @@ where ) -> Option<(Ordering, Self::SegmentSortKey)> { let (head_threshold, tail_threshold) = threshold; let (head_cmp, head_sort_key) = - self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?; + self.head + .accept_sort_key_lazy(doc_id, score, head_threshold)?; if head_cmp == Ordering::Equal { let (tail_cmp, tail_sort_key) = - self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?; + self.tail + .accept_sort_key_lazy(doc_id, score, tail_threshold)?; Some((tail_cmp, (head_sort_key, tail_sort_key))) } else { - let tail_sort_key = self.1.segment_sort_key(doc_id, score); + let tail_sort_key = self.tail.segment_sort_key(doc_id, score); Some((head_cmp, (head_sort_key, tail_sort_key))) } } @@ -376,21 +378,21 @@ where fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { let (head_sort_key, tail_sort_key) = sort_key; ( - self.0.convert_segment_sort_key(head_sort_key), - self.1.convert_segment_sort_key(tail_sort_key), + self.head.convert_segment_sort_key(head_sort_key), + self.tail.convert_segment_sort_key(tail_sort_key), ) } } /// This struct is used as an adapter to take a sort key computer and map its score to another /// new sort key. -pub struct MappedSegmentSortKeyComputer { +pub struct MappedSegmentSortKeyComputer { sort_key_computer: T, - map: fn(PreviousSortKey) -> NewSortKey, + map: fn(T::SortKey) -> NewSortKey, } impl SegmentSortKeyComputer - for MappedSegmentSortKeyComputer + for MappedSegmentSortKeyComputer where T: SegmentSortKeyComputer, PreviousScore: 'static + Clone + Send + Sync, @@ -400,12 +402,16 @@ where type SegmentSortKey = T::SegmentSortKey; type SegmentComparator = T::SegmentComparator; + fn segment_comparator(&self) -> Self::SegmentComparator { + self.sort_key_computer.segment_comparator() + } + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.sort_key_computer.segment_sort_key(doc, score) } - fn segment_sort_keys(&mut self, docs: &[DocId], output: &mut Vec) { - self.sort_key_computer.segment_sort_keys(docs, output); + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Self::SegmentSortKey] { + self.sort_key_computer.segment_sort_keys(docs) } fn accept_sort_key_lazy( @@ -463,10 +469,6 @@ where ); type Child = MappedSegmentSortKeyComputer< <(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child, - ( - SortKeyComputer1::SortKey, - (SortKeyComputer2::SortKey, SortKeyComputer3::SortKey), - ), Self::SortKey, >; @@ -490,7 +492,15 @@ where let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3); Ok(MappedSegmentSortKeyComputer { - sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)), + sort_key_computer: ChainSegmentSortKeyComputer { + head: sort_key_computer1, + tail: ChainSegmentSortKeyComputer { + head: sort_key_computer2, + tail: sort_key_computer3, + buffer: Vec::new(), + }, + buffer: Vec::new(), + }, map, }) } @@ -525,13 +535,6 @@ where SortKeyComputer1, (SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)), ) as SortKeyComputer>::Child, - ( - SortKeyComputer1::SortKey, - ( - SortKeyComputer2::SortKey, - (SortKeyComputer3::SortKey, SortKeyComputer4::SortKey), - ), - ), Self::SortKey, >; type SortKey = ( @@ -553,10 +556,19 @@ where let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?; Ok(MappedSegmentSortKeyComputer { - sort_key_computer: ( - sort_key_computer1, - (sort_key_computer2, (sort_key_computer3, sort_key_computer4)), - ), + sort_key_computer: ChainSegmentSortKeyComputer { + head: sort_key_computer1, + tail: ChainSegmentSortKeyComputer { + head: sort_key_computer2, + tail: ChainSegmentSortKeyComputer { + head: sort_key_computer3, + tail: sort_key_computer4, + buffer: Vec::new(), + }, + buffer: Vec::new(), + }, + buffer: Vec::new(), + }, map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| { (sort_key1, sort_key2, sort_key3, sort_key4) }, @@ -579,6 +591,11 @@ where } } +pub struct FuncSegmentSortKeyComputer { + func: F, + buffer: Vec, +} + impl SortKeyComputer for F where F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF, @@ -586,15 +603,18 @@ where TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug, { type SortKey = TSortKey; - type Child = SegmentF; + type Child = FuncSegmentSortKeyComputer; type Comparator = NaturalComparator; fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - Ok((self)(segment_reader)) + Ok(FuncSegmentSortKeyComputer { + func: (self)(segment_reader), + buffer: Vec::new(), + }) } } -impl SegmentSortKeyComputer for F +impl SegmentSortKeyComputer for FuncSegmentSortKeyComputer where F: 'static + FnMut(DocId) -> TSortKey, TSortKey: 'static + PartialOrd + Clone + Send + Sync, @@ -604,7 +624,16 @@ where type SegmentComparator = NaturalComparator; fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { - (self)(doc) + (self.func)(doc) + } + + fn segment_sort_keys(&mut self, docs: &[DocId]) -> &[Self::SegmentSortKey] { + self.buffer.clear(); + self.buffer.reserve(docs.len()); + for &doc in docs { + self.buffer.push((self.func)(doc)); + } + &self.buffer } /// Convert a segment level score into the global level score. @@ -647,10 +676,9 @@ mod tests { .unwrap(); let docs = vec![1, 2, 3]; - let mut output = Vec::new(); - segment_sort_key_computer.segment_sort_keys(&docs, &mut output); + let output = segment_sort_key_computer.segment_sort_keys(&docs); - assert_eq!(output, vec![(200, 10), (200, 10), (200, 10)]); + assert_eq!(output, &[(200, 10), (200, 10), (200, 10)]); } #[test] diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 118c6e750..271386bed 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -486,6 +486,10 @@ where (self.sort_key_fn)(doc, score) } + fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &[Self::SegmentSortKey] { + unimplemented!("Batch computation is not supported for tweak score.") + } + /// Convert a segment level score into the global level score. fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { sort_key