From 0e16ed9ef794635668ab67e203174c53cbfdc62d Mon Sep 17 00:00:00 2001 From: PSeitz Date: Wed, 7 Feb 2024 12:52:06 +0100 Subject: [PATCH] Fix serde for TopNComputer (#2313) * Fix serde for TopNComputer The top hits aggregation changed the TopNComputer to be serializable, but capacity needs to be carried over, as it contains logic which is checked against when pushing elements (capacity == 0 is not allowed). * use serde from deser * remove pub, clippy --- src/aggregation/metric/top_hits.rs | 2 +- src/collector/top_score_collector.rs | 60 ++++++++++++++++++++++++---- src/index/index.rs | 5 ++- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index fe3e7ba7f..3aaa87907 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -151,7 +151,7 @@ impl RetrievalFields { return Ok(vec![field.to_owned()]); } - let pattern = globbed_string_to_regex(&field)?; + let pattern = globbed_string_to_regex(field)?; let fields = reader .iter_columns()? .map(|(name, _)| { diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index b6312a3bd..834428bdb 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use std::sync::Arc; use columnar::ColumnValues; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use super::Collector; @@ -720,17 +721,43 @@ impl SegmentCollector for TopScoreSegmentCollector { /// /// For TopN == 0, it will be relative expensive. #[derive(Clone, Serialize, Deserialize)] -pub struct TopNComputer { +#[serde(from = "TopNComputerDeser")] +pub struct TopNComputer { /// The buffer reverses sort order to get top-semantics instead of bottom-semantics - buffer: Vec>, + buffer: Vec>, top_n: usize, pub(crate) threshold: Option, } +// Intermediate struct for TopNComputer for deserialization, to fix vec capacity +#[derive(Deserialize)] +struct TopNComputerDeser { + buffer: Vec>, + top_n: usize, + threshold: Option, +} -impl TopNComputer +impl From> for TopNComputer { + fn from(mut value: TopNComputerDeser) -> Self { + let expected_cap = value.top_n.max(1) * 2; + let current_cap = value.buffer.capacity(); + if current_cap < expected_cap { + value.buffer.reserve_exact(expected_cap - current_cap); + } else { + value.buffer.shrink_to(expected_cap); + } + + TopNComputer { + buffer: value.buffer, + top_n: value.top_n, + threshold: value.threshold, + } + } +} + +impl TopNComputer where Score: PartialOrd + Clone, - DocId: Ord + Clone, + D: Serialize + DeserializeOwned + Ord + Clone, { /// Create a new `TopNComputer`. /// Internally it will allocate a buffer of size `2 * top_n`. @@ -746,7 +773,7 @@ where /// Push a new document to the top n. /// If the document is below the current threshold, it will be ignored. #[inline] - pub fn push(&mut self, feature: Score, doc: DocId) { + pub fn push(&mut self, feature: Score, doc: D) { if let Some(last_median) = self.threshold.clone() { if feature < last_median { return; @@ -783,7 +810,7 @@ where } /// Returns the top n elements in sorted order. - pub fn into_sorted_vec(mut self) -> Vec> { + pub fn into_sorted_vec(mut self) -> Vec> { if self.buffer.len() > self.top_n { self.truncate_top_n(); } @@ -794,7 +821,7 @@ where /// Returns the top n elements in stored order. /// Useful if you do not need the elements in sorted order, /// for example when merging the results of multiple segments. - pub fn into_vec(mut self) -> Vec> { + pub fn into_vec(mut self) -> Vec> { if self.buffer.len() > self.top_n { self.truncate_top_n(); } @@ -833,6 +860,25 @@ mod tests { crate::assert_nearly_equals!(result.0, expected.0); } } + #[test] + fn test_topn_computer_serde() { + let computer: TopNComputer = TopNComputer::new(1); + + let computer_ser = serde_json::to_string(&computer).unwrap(); + let mut computer: TopNComputer = serde_json::from_str(&computer_ser).unwrap(); + + computer.push(1u32, 5u32); + computer.push(1u32, 0u32); + computer.push(1u32, 7u32); + + assert_eq!( + computer.into_sorted_vec(), + &[ComparableDoc { + feature: 1u32, + doc: 0u32, + },] + ); + } #[test] fn test_empty_topn_computer() { diff --git a/src/index/index.rs b/src/index/index.rs index e7095f47e..ce68d4b13 100644 --- a/src/index/index.rs +++ b/src/index/index.rs @@ -323,7 +323,10 @@ impl Index { } /// Custom thread pool by a outer thread pool. - pub fn set_shared_multithread_executor(&mut self, shared_thread_pool: Arc) -> crate::Result<()> { + pub fn set_shared_multithread_executor( + &mut self, + shared_thread_pool: Arc, + ) -> crate::Result<()> { self.executor = shared_thread_pool.clone(); Ok(()) }