diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index fe3e7ba7f..3aaa87907 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -151,7 +151,7 @@ impl RetrievalFields { return Ok(vec![field.to_owned()]); } - let pattern = globbed_string_to_regex(&field)?; + let pattern = globbed_string_to_regex(field)?; let fields = reader .iter_columns()? .map(|(name, _)| { diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index b6312a3bd..834428bdb 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use std::sync::Arc; use columnar::ColumnValues; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use super::Collector; @@ -720,17 +721,43 @@ impl SegmentCollector for TopScoreSegmentCollector { /// /// For TopN == 0, it will be relative expensive. #[derive(Clone, Serialize, Deserialize)] -pub struct TopNComputer { +#[serde(from = "TopNComputerDeser")] +pub struct TopNComputer { /// The buffer reverses sort order to get top-semantics instead of bottom-semantics - buffer: Vec>, + buffer: Vec>, top_n: usize, pub(crate) threshold: Option, } +// Intermediate struct for TopNComputer for deserialization, to fix vec capacity +#[derive(Deserialize)] +struct TopNComputerDeser { + buffer: Vec>, + top_n: usize, + threshold: Option, +} -impl TopNComputer +impl From> for TopNComputer { + fn from(mut value: TopNComputerDeser) -> Self { + let expected_cap = value.top_n.max(1) * 2; + let current_cap = value.buffer.capacity(); + if current_cap < expected_cap { + value.buffer.reserve_exact(expected_cap - current_cap); + } else { + value.buffer.shrink_to(expected_cap); + } + + TopNComputer { + buffer: value.buffer, + top_n: value.top_n, + threshold: value.threshold, + } + } +} + +impl TopNComputer where Score: PartialOrd + Clone, - DocId: Ord + Clone, + D: Serialize + DeserializeOwned + Ord + Clone, { /// Create a new `TopNComputer`. /// Internally it will allocate a buffer of size `2 * top_n`. @@ -746,7 +773,7 @@ where /// Push a new document to the top n. /// If the document is below the current threshold, it will be ignored. #[inline] - pub fn push(&mut self, feature: Score, doc: DocId) { + pub fn push(&mut self, feature: Score, doc: D) { if let Some(last_median) = self.threshold.clone() { if feature < last_median { return; @@ -783,7 +810,7 @@ where } /// Returns the top n elements in sorted order. - pub fn into_sorted_vec(mut self) -> Vec> { + pub fn into_sorted_vec(mut self) -> Vec> { if self.buffer.len() > self.top_n { self.truncate_top_n(); } @@ -794,7 +821,7 @@ where /// Returns the top n elements in stored order. /// Useful if you do not need the elements in sorted order, /// for example when merging the results of multiple segments. - pub fn into_vec(mut self) -> Vec> { + pub fn into_vec(mut self) -> Vec> { if self.buffer.len() > self.top_n { self.truncate_top_n(); } @@ -833,6 +860,25 @@ mod tests { crate::assert_nearly_equals!(result.0, expected.0); } } + #[test] + fn test_topn_computer_serde() { + let computer: TopNComputer = TopNComputer::new(1); + + let computer_ser = serde_json::to_string(&computer).unwrap(); + let mut computer: TopNComputer = serde_json::from_str(&computer_ser).unwrap(); + + computer.push(1u32, 5u32); + computer.push(1u32, 0u32); + computer.push(1u32, 7u32); + + assert_eq!( + computer.into_sorted_vec(), + &[ComparableDoc { + feature: 1u32, + doc: 0u32, + },] + ); + } #[test] fn test_empty_topn_computer() { diff --git a/src/index/index.rs b/src/index/index.rs index e7095f47e..ce68d4b13 100644 --- a/src/index/index.rs +++ b/src/index/index.rs @@ -323,7 +323,10 @@ impl Index { } /// Custom thread pool by a outer thread pool. - pub fn set_shared_multithread_executor(&mut self, shared_thread_pool: Arc) -> crate::Result<()> { + pub fn set_shared_multithread_executor( + &mut self, + shared_thread_pool: Arc, + ) -> crate::Result<()> { self.executor = shared_thread_pool.clone(); Ok(()) }