diff --git a/Cargo.toml b/Cargo.toml index 5b26e2c2d..1f415f689 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,6 +78,9 @@ paste = "1.0.11" more-asserts = "0.3.1" rand_distr = "0.4.3" time = { version = "0.3.10", features = ["serde-well-known", "macros"] } +postcard = { version = "1.0.4", features = [ + "use-std", +], default-features = false } [target.'cfg(not(windows))'.dev-dependencies] criterion = { version = "0.5", default-features = false } diff --git a/common/src/datetime.rs b/common/src/datetime.rs index 3aeadad3e..945856e07 100644 --- a/common/src/datetime.rs +++ b/common/src/datetime.rs @@ -40,7 +40,7 @@ pub type DatePrecision = DateTimePrecision; /// All constructors and conversions are provided as explicit /// functions and not by implementing any `From`/`Into` traits /// to prevent unintended usage. -#[derive(Clone, Default, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Default, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct DateTime { // Timestamp in nanoseconds. pub(crate) timestamp_nanos: i64, diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 877560aeb..4f8a7c6f0 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -292,7 +292,7 @@ impl AggregationWithAccessor { add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } TopHits(ref mut top_hits) => { - top_hits.validate_and_resolve(reader.fast_fields().columnar())?; + top_hits.validate_and_resolve_field_names(reader.fast_fields().columnar())?; let accessors: Vec<(Column, ColumnType)> = top_hits .field_names() .iter() diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 75bb1655f..126c2240e 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -4,6 +4,7 @@ use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; +use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::segment_agg_result::AggregationLimits; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; @@ -66,6 +67,22 @@ fn test_aggregation_flushing( } } }, + "top_hits_test":{ + "terms": { + "field": "string_id" + }, + "aggs": { + "bucketsL2": { + "top_hits": { + "size": 2, + "sort": [ + { "score": "asc" } + ], + "docvalue_fields": ["score"] + } + } + } + }, "histogram_test":{ "histogram": { "field": "score", @@ -108,6 +125,16 @@ fn test_aggregation_flushing( let searcher = reader.searcher(); let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap(); + + // Test postcard roundtrip serialization + let intermediate_agg_result_bytes = postcard::to_allocvec(&intermediate_agg_result).expect( + "Postcard Serialization failed, flatten etc. is not supported in the intermediate \ + result", + ); + let intermediate_agg_result: IntermediateAggregationResults = + postcard::from_bytes(&intermediate_agg_result_bytes) + .expect("Post deserialization failed"); + intermediate_agg_result .into_final_result(agg_req, &Default::default()) .unwrap() diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 4f91c4705..d59c58c62 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -20,7 +20,7 @@ use super::bucket::{ }; use super::metric::{ IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats, - IntermediateSum, PercentilesCollector, TopHitsCollector, + IntermediateSum, PercentilesCollector, TopHitsTopNComputer, }; use super::segment_agg_result::AggregationLimits; use super::{format_date, AggregationError, Key, SerializedKey}; @@ -221,9 +221,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult Percentiles(_) => IntermediateAggregationResult::Metric( IntermediateMetricResult::Percentiles(PercentilesCollector::default()), ), - TopHits(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::TopHits( - TopHitsCollector::default(), - )), + TopHits(ref req) => IntermediateAggregationResult::Metric( + IntermediateMetricResult::TopHits(TopHitsTopNComputer::new(req.clone())), + ), } } @@ -285,7 +285,7 @@ pub enum IntermediateMetricResult { /// Intermediate sum result. Sum(IntermediateSum), /// Intermediate top_hits result - TopHits(TopHitsCollector), + TopHits(TopHitsTopNComputer), } impl IntermediateMetricResult { @@ -314,7 +314,7 @@ impl IntermediateMetricResult { .into_final_result(req.agg.as_percentile().expect("unexpected metric type")), ), IntermediateMetricResult::TopHits(top_hits) => { - MetricResult::TopHits(top_hits.finalize()) + MetricResult::TopHits(top_hits.into_final_result()) } } } diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index fd489922a..6da583a59 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -25,6 +25,8 @@ mod stats; mod sum; mod top_hits; +use std::collections::HashMap; + pub use average::*; pub use count::*; pub use max::*; @@ -36,6 +38,8 @@ pub use stats::*; pub use sum::*; pub use top_hits::*; +use crate::schema::OwnedValue; + /// Single-metric aggregations use this common result structure. /// /// Main reason to wrap it in value is to match elasticsearch output structure. @@ -92,8 +96,9 @@ pub struct TopHitsVecEntry { /// Search results, for queries that include field retrieval requests /// (`docvalue_fields`). - #[serde(flatten)] - pub search_results: FieldRetrivalResult, + #[serde(rename = "docvalue_fields")] + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub doc_value_fields: HashMap, } /// The top_hits metric aggregation results a list of top hits by sort criteria. diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index 3aaa87907..28f441864 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; -use std::fmt::Formatter; +use std::net::Ipv6Addr; use columnar::{ColumnarReader, DynamicColumn}; +use common::DateTime; use regex::Regex; use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -92,53 +93,61 @@ pub struct TopHitsAggregation { size: usize, from: Option, - #[serde(flatten)] - retrieval: RetrievalFields, -} - -const fn default_doc_value_fields() -> Vec { - Vec::new() -} - -/// Search query spec for each matched document -/// TODO: move this to a common module -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] -pub struct RetrievalFields { - /// The fast fields to return for each hit. - /// This is the only variant supported for now. - /// TODO: support the {field, format} variant for custom formatting. #[serde(rename = "docvalue_fields")] - #[serde(default = "default_doc_value_fields")] - pub doc_value_fields: Vec, + #[serde(default)] + doc_value_fields: Vec, } -/// Search query result for each matched document -/// TODO: move this to a common module -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] -pub struct FieldRetrivalResult { - /// The fast fields returned for each hit. - #[serde(rename = "docvalue_fields")] - #[serde(skip_serializing_if = "HashMap::is_empty")] - pub doc_value_fields: HashMap, +#[derive(Debug, Clone, PartialEq, Default)] +struct KeyOrder { + field: String, + order: Order, } -impl RetrievalFields { - fn get_field_names(&self) -> Vec<&str> { - self.doc_value_fields.iter().map(|s| s.as_str()).collect() +impl Serialize for KeyOrder { + fn serialize(&self, serializer: S) -> Result { + let KeyOrder { field, order } = self; + let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry(field, order)?; + map.end() } +} - fn resolve_field_names(&mut self, reader: &ColumnarReader) -> crate::Result<()> { - // Tranform a glob (`pattern*`, for example) into a regex::Regex (`^pattern.*$`) - let globbed_string_to_regex = |glob: &str| { - // Replace `*` glob with `.*` regex - let sanitized = format!("^{}$", regex::escape(glob).replace(r"\*", ".*")); - Regex::new(&sanitized.replace('*', ".*")).map_err(|e| { - crate::TantivyError::SchemaError(format!( - "Invalid regex '{}' in docvalue_fields: {}", - glob, e - )) - }) - }; +impl<'de> Deserialize<'de> for KeyOrder { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + let mut key_order = >::deserialize(deserializer)?.into_iter(); + let (field, order) = key_order.next().ok_or(serde::de::Error::custom( + "Expected exactly one key-value pair in sort parameter of top_hits, found none", + ))?; + if key_order.next().is_some() { + return Err(serde::de::Error::custom(format!( + "Expected exactly one key-value pair in sort parameter of top_hits, found {:?}", + key_order + ))); + } + Ok(Self { field, order }) + } +} + +// Tranform a glob (`pattern*`, for example) into a regex::Regex (`^pattern.*$`) +fn globbed_string_to_regex(glob: &str) -> Result { + // Replace `*` glob with `.*` regex + let sanitized = format!("^{}$", regex::escape(glob).replace(r"\*", ".*")); + Regex::new(&sanitized.replace('*', ".*")).map_err(|e| { + crate::TantivyError::SchemaError(format!( + "Invalid regex '{}' in docvalue_fields: {}", + glob, e + )) + }) +} + +impl TopHitsAggregation { + /// Validate and resolve field retrieval parameters + pub fn validate_and_resolve_field_names( + &mut self, + reader: &ColumnarReader, + ) -> crate::Result<()> { self.doc_value_fields = self .doc_value_fields .iter() @@ -175,12 +184,25 @@ impl RetrievalFields { Ok(()) } + /// Return fields accessed by the aggregator, in order. + pub fn field_names(&self) -> Vec<&str> { + self.sort + .iter() + .map(|KeyOrder { field, .. }| field.as_str()) + .collect() + } + + /// Return fields accessed by the aggregator's value retrieval. + pub fn value_field_names(&self) -> Vec<&str> { + self.doc_value_fields.iter().map(|s| s.as_str()).collect() + } + fn get_document_field_data( &self, accessors: &HashMap>, doc_id: DocId, - ) -> FieldRetrivalResult { - let dvf = self + ) -> HashMap { + let doc_value_fields = self .doc_value_fields .iter() .map(|field| { @@ -188,20 +210,20 @@ impl RetrievalFields { .get(field) .unwrap_or_else(|| panic!("field '{}' not found in accessors", field)); - let values: Vec = accessors + let values: Vec = accessors .iter() .flat_map(|accessor| match accessor { DynamicColumn::U64(accessor) => accessor .values_for_doc(doc_id) - .map(OwnedValue::U64) + .map(FastFieldValue::U64) .collect::>(), DynamicColumn::I64(accessor) => accessor .values_for_doc(doc_id) - .map(OwnedValue::I64) + .map(FastFieldValue::I64) .collect::>(), DynamicColumn::F64(accessor) => accessor .values_for_doc(doc_id) - .map(OwnedValue::F64) + .map(FastFieldValue::F64) .collect::>(), DynamicColumn::Bytes(accessor) => accessor .term_ords(doc_id) @@ -213,7 +235,7 @@ impl RetrievalFields { .expect("could not read term dictionary"), "term corresponding to term_ord does not exist" ); - OwnedValue::Bytes(buffer) + FastFieldValue::Bytes(buffer) }) .collect::>(), DynamicColumn::Str(accessor) => accessor @@ -226,94 +248,82 @@ impl RetrievalFields { .expect("could not read term dictionary"), "term corresponding to term_ord does not exist" ); - OwnedValue::Str(String::from_utf8(buffer).unwrap()) + FastFieldValue::Str(String::from_utf8(buffer).unwrap()) }) .collect::>(), DynamicColumn::Bool(accessor) => accessor .values_for_doc(doc_id) - .map(OwnedValue::Bool) + .map(FastFieldValue::Bool) .collect::>(), DynamicColumn::IpAddr(accessor) => accessor .values_for_doc(doc_id) - .map(OwnedValue::IpAddr) + .map(FastFieldValue::IpAddr) .collect::>(), DynamicColumn::DateTime(accessor) => accessor .values_for_doc(doc_id) - .map(OwnedValue::Date) + .map(FastFieldValue::Date) .collect::>(), }) .collect(); - (field.to_owned(), OwnedValue::Array(values)) + (field.to_owned(), FastFieldValue::Array(values)) }) .collect(); - FieldRetrivalResult { - doc_value_fields: dvf, + doc_value_fields + } +} + +/// A retrieved value from a fast field. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FastFieldValue { + /// The str type is used for any text information. + Str(String), + /// Unsigned 64-bits Integer `u64` + U64(u64), + /// Signed 64-bits Integer `i64` + I64(i64), + /// 64-bits Float `f64` + F64(f64), + /// Bool value + Bool(bool), + /// Date/time with nanoseconds precision + Date(DateTime), + /// Arbitrarily sized byte array + Bytes(Vec), + /// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`. + IpAddr(Ipv6Addr), + /// A list of values. + Array(Vec), +} + +impl From for OwnedValue { + fn from(value: FastFieldValue) -> Self { + match value { + FastFieldValue::Str(s) => OwnedValue::Str(s), + FastFieldValue::U64(u) => OwnedValue::U64(u), + FastFieldValue::I64(i) => OwnedValue::I64(i), + FastFieldValue::F64(f) => OwnedValue::F64(f), + FastFieldValue::Bool(b) => OwnedValue::Bool(b), + FastFieldValue::Date(d) => OwnedValue::Date(d), + FastFieldValue::Bytes(b) => OwnedValue::Bytes(b), + FastFieldValue::IpAddr(ip) => OwnedValue::IpAddr(ip), + FastFieldValue::Array(a) => { + OwnedValue::Array(a.into_iter().map(OwnedValue::from).collect()) + } } } } -#[derive(Debug, Clone, PartialEq, Default)] -struct KeyOrder { - field: String, - order: Order, -} - -impl Serialize for KeyOrder { - fn serialize(&self, serializer: S) -> Result { - let KeyOrder { field, order } = self; - let mut map = serializer.serialize_map(Some(1))?; - map.serialize_entry(field, order)?; - map.end() - } -} - -impl<'de> Deserialize<'de> for KeyOrder { - fn deserialize(deserializer: D) -> Result - where D: Deserializer<'de> { - let mut k_o = >::deserialize(deserializer)?.into_iter(); - let (k, v) = k_o.next().ok_or(serde::de::Error::custom( - "Expected exactly one key-value pair in KeyOrder, found none", - ))?; - if k_o.next().is_some() { - return Err(serde::de::Error::custom( - "Expected exactly one key-value pair in KeyOrder, found more", - )); - } - Ok(Self { field: k, order: v }) - } -} - -impl TopHitsAggregation { - /// Validate and resolve field retrieval parameters - pub fn validate_and_resolve(&mut self, reader: &ColumnarReader) -> crate::Result<()> { - self.retrieval.resolve_field_names(reader) - } - - /// Return fields accessed by the aggregator, in order. - pub fn field_names(&self) -> Vec<&str> { - self.sort - .iter() - .map(|KeyOrder { field, .. }| field.as_str()) - .collect() - } - - /// Return fields accessed by the aggregator's value retrieval. - pub fn value_field_names(&self) -> Vec<&str> { - self.retrieval.get_field_names() - } -} - -/// Holds a single comparable doc feature, and the order in which it should be sorted. +/// Holds a fast field value in its u64 representation, and the order in which it should be sorted. #[derive(Clone, Serialize, Deserialize, Debug)] -struct ComparableDocFeature { - /// Stores any u64-mappable feature. +struct DocValueAndOrder { + /// A fast field value in its u64 representation. value: Option, - /// Sort order for the doc feature + /// Sort order for the value order: Order, } -impl Ord for ComparableDocFeature { +impl Ord for DocValueAndOrder { fn cmp(&self, other: &Self) -> std::cmp::Ordering { let invert = |cmp: std::cmp::Ordering| match self.order { Order::Asc => cmp, @@ -329,26 +339,32 @@ impl Ord for ComparableDocFeature { } } -impl PartialOrd for ComparableDocFeature { +impl PartialOrd for DocValueAndOrder { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl PartialEq for ComparableDocFeature { +impl PartialEq for DocValueAndOrder { fn eq(&self, other: &Self) -> bool { self.value.cmp(&other.value) == std::cmp::Ordering::Equal } } -impl Eq for ComparableDocFeature {} +impl Eq for DocValueAndOrder {} #[derive(Clone, Serialize, Deserialize, Debug)] -struct ComparableDocFeatures(Vec, FieldRetrivalResult); +struct DocSortValuesAndFields { + sorts: Vec, -impl Ord for ComparableDocFeatures { + #[serde(rename = "docvalue_fields")] + #[serde(skip_serializing_if = "HashMap::is_empty")] + doc_value_fields: HashMap, +} + +impl Ord for DocSortValuesAndFields { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - for (self_feature, other_feature) in self.0.iter().zip(other.0.iter()) { + for (self_feature, other_feature) in self.sorts.iter().zip(other.sorts.iter()) { let cmp = self_feature.cmp(other_feature); if cmp != std::cmp::Ordering::Equal { return cmp; @@ -358,53 +374,43 @@ impl Ord for ComparableDocFeatures { } } -impl PartialOrd for ComparableDocFeatures { +impl PartialOrd for DocSortValuesAndFields { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl PartialEq for ComparableDocFeatures { +impl PartialEq for DocSortValuesAndFields { fn eq(&self, other: &Self) -> bool { self.cmp(other) == std::cmp::Ordering::Equal } } -impl Eq for ComparableDocFeatures {} +impl Eq for DocSortValuesAndFields {} /// The TopHitsCollector used for collecting over segments and merging results. -#[derive(Clone, Serialize, Deserialize)] -pub struct TopHitsCollector { +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct TopHitsTopNComputer { req: TopHitsAggregation, - top_n: TopNComputer, + top_n: TopNComputer, } -impl Default for TopHitsCollector { - fn default() -> Self { - Self { - req: TopHitsAggregation::default(), - top_n: TopNComputer::new(1), - } - } -} - -impl std::fmt::Debug for TopHitsCollector { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TopHitsCollector") - .field("req", &self.req) - .field("top_n_threshold", &self.top_n.threshold) - .finish() - } -} - -impl std::cmp::PartialEq for TopHitsCollector { +impl std::cmp::PartialEq for TopHitsTopNComputer { fn eq(&self, _other: &Self) -> bool { false } } -impl TopHitsCollector { - fn collect(&mut self, features: ComparableDocFeatures, doc: DocAddress) { +impl TopHitsTopNComputer { + /// Create a new TopHitsCollector + pub fn new(req: TopHitsAggregation) -> Self { + Self { + top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + req, + } + } + + fn collect(&mut self, features: DocSortValuesAndFields, doc: DocAddress) { self.top_n.push(features, doc); } @@ -416,14 +422,19 @@ impl TopHitsCollector { } /// Finalize by converting self into the final result form - pub fn finalize(self) -> TopHitsMetricResult { + pub fn into_final_result(self) -> TopHitsMetricResult { let mut hits: Vec = self .top_n .into_sorted_vec() .into_iter() .map(|doc| TopHitsVecEntry { - sort: doc.feature.0.iter().map(|f| f.value).collect(), - search_results: doc.feature.1, + sort: doc.feature.sorts.iter().map(|f| f.value).collect(), + doc_value_fields: doc + .feature + .doc_value_fields + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(), }) .collect(); @@ -436,48 +447,63 @@ impl TopHitsCollector { } } -#[derive(Clone)] -pub(crate) struct SegmentTopHitsCollector { +#[derive(Clone, Debug)] +pub(crate) struct TopHitsSegmentCollector { segment_ordinal: SegmentOrdinal, accessor_idx: usize, - inner_collector: TopHitsCollector, + req: TopHitsAggregation, + top_n: TopNComputer, DocAddress, false>, } -impl SegmentTopHitsCollector { +impl TopHitsSegmentCollector { pub fn from_req( req: &TopHitsAggregation, accessor_idx: usize, segment_ordinal: SegmentOrdinal, ) -> Self { Self { - inner_collector: TopHitsCollector { - req: req.clone(), - top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), - }, + req: req.clone(), + top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), segment_ordinal, accessor_idx, } } -} + fn into_top_hits_collector( + self, + value_accessors: &HashMap>, + ) -> TopHitsTopNComputer { + let mut top_hits_computer = TopHitsTopNComputer::new(self.req.clone()); + let top_results = self.top_n.into_vec(); -impl std::fmt::Debug for SegmentTopHitsCollector { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SegmentTopHitsCollector") - .field("segment_id", &self.segment_ordinal) - .field("accessor_idx", &self.accessor_idx) - .field("inner_collector", &self.inner_collector) - .finish() + for res in top_results { + let doc_value_fields = self + .req + .get_document_field_data(value_accessors, res.doc.doc_id); + top_hits_computer.collect( + DocSortValuesAndFields { + sorts: res.feature, + doc_value_fields, + }, + res.doc, + ); + } + + top_hits_computer } } -impl SegmentAggregationCollector for SegmentTopHitsCollector { +impl SegmentAggregationCollector for TopHitsSegmentCollector { fn add_intermediate_aggregation_result( self: Box, agg_with_accessor: &crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, ) -> crate::Result<()> { let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); - let intermediate_result = IntermediateMetricResult::TopHits(self.inner_collector); + + let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors; + + let intermediate_result = + IntermediateMetricResult::TopHits(self.into_top_hits_collector(value_accessors)); results.push( name, IntermediateAggregationResult::Metric(intermediate_result), @@ -490,9 +516,7 @@ impl SegmentAggregationCollector for SegmentTopHitsCollector { agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, ) -> crate::Result<()> { let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors; - let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors; - let features: Vec = self - .inner_collector + let sorts: Vec = self .req .sort .iter() @@ -505,18 +529,12 @@ impl SegmentAggregationCollector for SegmentTopHitsCollector { .0 .values_for_doc(doc_id) .next(); - ComparableDocFeature { value, order } + DocValueAndOrder { value, order } }) .collect(); - let retrieval_result = self - .inner_collector - .req - .retrieval - .get_document_field_data(value_accessors, doc_id); - - self.inner_collector.collect( - ComparableDocFeatures(features, retrieval_result), + self.top_n.push( + sorts, DocAddress { segment_ord: self.segment_ordinal, doc_id, @@ -530,11 +548,7 @@ impl SegmentAggregationCollector for SegmentTopHitsCollector { docs: &[crate::DocId], agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, ) -> crate::Result<()> { - // TODO: Consider getting fields with the column block accessor and refactor this. - // --- - // Would the additional complexity of getting fields with the column_block_accessor - // make sense here? Probably yes, but I want to get a first-pass review first - // before proceeding. + // TODO: Consider getting fields with the column block accessor. for doc in docs { self.collect(*doc, agg_with_accessor)?; } @@ -549,7 +563,7 @@ mod tests { use serde_json::Value; use time::macros::datetime; - use super::{ComparableDocFeature, ComparableDocFeatures, Order}; + use super::{DocSortValuesAndFields, DocValueAndOrder, Order}; use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::bucket::tests::get_test_index_from_docs; @@ -557,44 +571,44 @@ mod tests { use crate::aggregation::AggregationCollector; use crate::collector::ComparableDoc; use crate::query::AllQuery; - use crate::schema::OwnedValue as SchemaValue; + use crate::schema::OwnedValue; - fn invert_order(cmp_feature: ComparableDocFeature) -> ComparableDocFeature { - let ComparableDocFeature { value, order } = cmp_feature; + fn invert_order(cmp_feature: DocValueAndOrder) -> DocValueAndOrder { + let DocValueAndOrder { value, order } = cmp_feature; let order = match order { Order::Asc => Order::Desc, Order::Desc => Order::Asc, }; - ComparableDocFeature { value, order } + DocValueAndOrder { value, order } } - fn collector_with_capacity(capacity: usize) -> super::TopHitsCollector { - super::TopHitsCollector { + fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer { + super::TopHitsTopNComputer { top_n: super::TopNComputer::new(capacity), - ..Default::default() + req: Default::default(), } } - fn invert_order_features(cmp_features: ComparableDocFeatures) -> ComparableDocFeatures { - let ComparableDocFeatures(cmp_features, search_results) = cmp_features; - let cmp_features = cmp_features + fn invert_order_features(mut cmp_features: DocSortValuesAndFields) -> DocSortValuesAndFields { + cmp_features.sorts = cmp_features + .sorts .into_iter() .map(invert_order) .collect::>(); - ComparableDocFeatures(cmp_features, search_results) + cmp_features } #[test] fn test_comparable_doc_feature() -> crate::Result<()> { - let small = ComparableDocFeature { + let small = DocValueAndOrder { value: Some(1), order: Order::Asc, }; - let big = ComparableDocFeature { + let big = DocValueAndOrder { value: Some(2), order: Order::Asc, }; - let none = ComparableDocFeature { + let none = DocValueAndOrder { value: None, order: Order::Asc, }; @@ -616,21 +630,21 @@ mod tests { #[test] fn test_comparable_doc_features() -> crate::Result<()> { - let features_1 = ComparableDocFeatures( - vec![ComparableDocFeature { + let features_1 = DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { value: Some(1), order: Order::Asc, }], - Default::default(), - ); + doc_value_fields: Default::default(), + }; - let features_2 = ComparableDocFeatures( - vec![ComparableDocFeature { + let features_2 = DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { value: Some(2), order: Order::Asc, }], - Default::default(), - ); + doc_value_fields: Default::default(), + }; assert!(features_1 < features_2); @@ -689,39 +703,39 @@ mod tests { segment_ord: 0, doc_id: 0, }, - feature: ComparableDocFeatures( - vec![ComparableDocFeature { + feature: DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { value: Some(1), order: Order::Asc, }], - Default::default(), - ), + doc_value_fields: Default::default(), + }, }, ComparableDoc { doc: crate::DocAddress { segment_ord: 0, doc_id: 2, }, - feature: ComparableDocFeatures( - vec![ComparableDocFeature { + feature: DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { value: Some(3), order: Order::Asc, }], - Default::default(), - ), + doc_value_fields: Default::default(), + }, }, ComparableDoc { doc: crate::DocAddress { segment_ord: 0, doc_id: 1, }, - feature: ComparableDocFeatures( - vec![ComparableDocFeature { + feature: DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { value: Some(5), order: Order::Asc, }], - Default::default(), - ), + doc_value_fields: Default::default(), + }, }, ]; @@ -730,23 +744,23 @@ mod tests { collector.collect(doc.feature, doc.doc); } - let res = collector.finalize(); + let res = collector.into_final_result(); assert_eq!( res, super::TopHitsMetricResult { hits: vec![ super::TopHitsVecEntry { - sort: vec![docs[0].feature.0[0].value], - search_results: Default::default(), + sort: vec![docs[0].feature.sorts[0].value], + doc_value_fields: Default::default(), }, super::TopHitsVecEntry { - sort: vec![docs[1].feature.0[0].value], - search_results: Default::default(), + sort: vec![docs[1].feature.sorts[0].value], + doc_value_fields: Default::default(), }, super::TopHitsVecEntry { - sort: vec![docs[2].feature.0[0].value], - search_results: Default::default(), + sort: vec![docs[2].feature.sorts[0].value], + doc_value_fields: Default::default(), }, ] } @@ -803,7 +817,7 @@ mod tests { { "sort": [common::i64_to_u64(date_2017.unix_timestamp_nanos() as i64)], "docvalue_fields": { - "date": [ SchemaValue::Date(DateTime::from_utc(date_2017)) ], + "date": [ OwnedValue::Date(DateTime::from_utc(date_2017)) ], "text": [ "ccc" ], "text2": [ "ddd" ], "mixed.dyn_arr": [ 3, "4" ], @@ -812,7 +826,7 @@ mod tests { { "sort": [common::i64_to_u64(date_2016.unix_timestamp_nanos() as i64)], "docvalue_fields": { - "date": [ SchemaValue::Date(DateTime::from_utc(date_2016)) ], + "date": [ OwnedValue::Date(DateTime::from_utc(date_2016)) ], "text": [ "aaa" ], "text2": [ "bbb" ], "mixed.dyn_arr": [ 6, "7" ], diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 570dc3f03..76f0eb284 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -16,7 +16,7 @@ use super::metric::{ SumAggregation, }; use crate::aggregation::bucket::TermMissingAgg; -use crate::aggregation::metric::SegmentTopHitsCollector; +use crate::aggregation::metric::TopHitsSegmentCollector; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { fn add_intermediate_aggregation_result( @@ -161,7 +161,7 @@ pub(crate) fn build_single_agg_segment_collector( accessor_idx, )?, )), - TopHits(top_hits_req) => Ok(Box::new(SegmentTopHitsCollector::from_req( + TopHits(top_hits_req) => Ok(Box::new(TopHitsSegmentCollector::from_req( top_hits_req, accessor_idx, req.segment_ordinal, diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 917b2c3f7..415625bc1 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -732,6 +732,19 @@ pub struct TopNComputer { top_n: usize, pub(crate) threshold: Option, } + +impl std::fmt::Debug + for TopNComputer +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TopNComputer") + .field("buffer_len", &self.buffer.len()) + .field("top_n", &self.top_n) + .field("current_threshold", &self.threshold) + .finish() + } +} + // Intermediate struct for TopNComputer for deserialization, to keep vec capacity #[derive(Deserialize)] struct TopNComputerDeser {