diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index a66115633..3c3ddbd5c 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -1,10 +1,12 @@ mod order; +mod sort_by_erased_type; mod sort_by_score; mod sort_by_static_fast_value; mod sort_by_string; mod sort_key_computer; pub use order::*; +pub use sort_by_erased_type::SortByErasedType; pub use sort_by_score::SortBySimilarityScore; pub use sort_by_static_fast_value::SortByStaticFastValue; pub use sort_by_string::SortByString; @@ -15,11 +17,13 @@ mod tests { use std::collections::HashMap; use std::ops::Range; - use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString}; + use crate::collector::sort_key::{ + SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString, + }; use crate::collector::{ComparableDoc, DocSetCollector, TopDocs}; use crate::indexer::NoMergePolicy; use crate::query::{AllQuery, QueryParser}; - use crate::schema::{Schema, FAST, TEXT}; + use crate::schema::{OwnedValue, Schema, FAST, TEXT}; use crate::{DocAddress, Document, Index, Order, Score, Searcher}; fn make_index() -> crate::Result { @@ -294,11 +298,9 @@ mod tests { (SortBySimilarityScore, score_order), (SortByString::for_field("city"), city_order), )); - Ok(searcher - .search(&AllQuery, &top_collector)? - .into_iter() - .map(|(f, doc)| (f, ids[&doc])) - .collect()) + let results: Vec<((Score, Option), DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; + Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect()) } assert_eq!( @@ -323,6 +325,51 @@ mod tests { Ok(()) } + #[test] + fn test_order_by_score_then_owned_value() -> crate::Result<()> { + let index = make_index()?; + + type SortKey = (Score, OwnedValue); + + fn query( + index: &Index, + score_order: Order, + city_order: Order, + ) -> crate::Result> { + let searcher = index.reader()?.searcher(); + let ids = id_mapping(&searcher); + + let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>(( + (SortBySimilarityScore, score_order), + (SortByErasedType::for_field("city"), city_order), + )); + let results: Vec<((Score, OwnedValue), DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; + Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect()) + } + + assert_eq!( + &query(&index, Order::Asc, Order::Asc)?, + &[ + ((1.0, OwnedValue::Str("austin".to_owned())), 0), + ((1.0, OwnedValue::Str("greenville".to_owned())), 1), + ((1.0, OwnedValue::Str("tokyo".to_owned())), 2), + ((1.0, OwnedValue::Null), 3), + ] + ); + + assert_eq!( + &query(&index, Order::Asc, Order::Desc)?, + &[ + ((1.0, OwnedValue::Str("tokyo".to_owned())), 2), + ((1.0, OwnedValue::Str("greenville".to_owned())), 1), + ((1.0, OwnedValue::Str("austin".to_owned())), 0), + ((1.0, OwnedValue::Null), 3), + ] + ); + Ok(()) + } + use proptest::prelude::*; proptest! { diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 923d5cb8e..db5e4d56d 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -1,11 +1,70 @@ use std::cmp::Ordering; +use columnar::MonotonicallyMappableToU64; use serde::{Deserialize, Serialize}; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; -use crate::schema::Schema; +use crate::schema::{OwnedValue, Schema}; use crate::{DocId, Order, Score}; +fn compare_owned_value(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + match (lhs, rhs) { + (OwnedValue::Null, OwnedValue::Null) => Ordering::Equal, + (OwnedValue::Null, _) => { + if NULLS_FIRST { + Ordering::Less + } else { + Ordering::Greater + } + } + (_, OwnedValue::Null) => { + if NULLS_FIRST { + Ordering::Greater + } else { + Ordering::Less + } + } + (OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b), + (OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b), + (OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b), + (OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b), + (OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()), + (OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b), + (OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b), + (OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b), + (OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b), + (OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b), + (OwnedValue::U64(a), OwnedValue::I64(b)) => { + if *b < 0 { + Ordering::Greater + } else { + a.cmp(&(*b as u64)) + } + } + (OwnedValue::I64(a), OwnedValue::U64(b)) => { + if *a < 0 { + Ordering::Less + } else { + (*a as u64).cmp(b) + } + } + (OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()), + (OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()), + (OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()), + (OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()), + (a, b) => { + let ord = a.discriminant_value().cmp(&b.discriminant_value()); + // If the discriminant is equal, it's because a new type was added, but hasn't been + // included in this `match` statement. + assert!( + ord != Ordering::Equal, + "Unimplemented comparison for type of {a:?}, {b:?}" + ); + ord + } + } +} + /// Comparator trait defining the order in which documents should be ordered. pub trait Comparator: Send + Sync + std::fmt::Debug + Default { /// Return the order between two values. @@ -24,6 +83,17 @@ impl Comparator for NaturalComparator { } } +/// A (partial) implementation of comparison for OwnedValue. +/// +/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with +/// mismatched types. The one exception is Null, for which we do define all comparisons. +impl Comparator for NaturalComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(lhs, rhs) + } +} + /// Sorts document in reverse order. /// /// If the sort key is None, it will considered as the lowest value, and will therefore appear @@ -107,6 +177,13 @@ impl Comparator for ReverseNoneIsLowerComparator { } } +impl Comparator for ReverseNoneIsLowerComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(rhs, lhs) + } +} + /// An enum representing the different sort orders. #[derive(Debug, Clone, Copy, Eq, PartialEq, Default)] pub enum ComparatorEnum { @@ -322,11 +399,12 @@ impl SegmentSortKeyComput for SegmentSortKeyComputerWithComparator where TSegmentSortKeyComputer: SegmentSortKeyComputer, - TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send, + TSegmentSortKey: Clone + 'static + Sync + Send, TComparator: Comparator + 'static + Sync + Send, { type SortKey = TSegmentSortKeyComputer::SortKey; type SegmentSortKey = TSegmentSortKey; + type SegmentComparator = TComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.segment_sort_key_computer.segment_sort_key(doc, score) @@ -346,3 +424,32 @@ where .convert_segment_sort_key(sort_key) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::OwnedValue; + + #[test] + fn test_mixed_ownedvalue_compare() { + let u = OwnedValue::U64(10); + let i = OwnedValue::I64(10); + let f = OwnedValue::F64(10.0); + + let nc = NaturalComparator::default(); + assert_eq!(nc.compare(&u, &i), Ordering::Equal); + assert_eq!(nc.compare(&u, &f), Ordering::Equal); + assert_eq!(nc.compare(&i, &f), Ordering::Equal); + + let u2 = OwnedValue::U64(11); + assert_eq!(nc.compare(&u2, &f), Ordering::Greater); + + let s = OwnedValue::Str("a".to_string()); + // Str < U64 + assert_eq!(nc.compare(&s, &u), Ordering::Less); + // Str < I64 + assert_eq!(nc.compare(&s, &i), Ordering::Less); + // Str < F64 + assert_eq!(nc.compare(&s, &f), Ordering::Less); + } +} diff --git a/src/collector/sort_key/sort_by_erased_type.rs b/src/collector/sort_key/sort_by_erased_type.rs new file mode 100644 index 000000000..4f9365d06 --- /dev/null +++ b/src/collector/sort_key/sort_by_erased_type.rs @@ -0,0 +1,361 @@ +use columnar::{ColumnType, MonotonicallyMappableToU64}; + +use crate::collector::sort_key::{ + NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString, +}; +use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::fastfield::FastFieldNotAvailableError; +use crate::schema::OwnedValue; +use crate::{DateTime, DocId, Score}; + +/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score. +/// +/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders +/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to +/// use a SortKeyComputer implementation with a known-type at compile time. +#[derive(Debug, Clone)] +pub enum SortByErasedType { + /// Sort by a fast field + Field(String), + /// Sort by score + Score, +} + +impl SortByErasedType { + /// Creates a new sort key computer which will sort by the given fast field column, with type + /// erasure. + pub fn for_field(column_name: impl ToString) -> Self { + Self::Field(column_name.to_string()) + } + + /// Creates a new sort key computer which will sort by score, with type erasure. + pub fn for_score() -> Self { + Self::Score + } +} + +trait ErasedSegmentSortKeyComputer: Send + Sync { + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option; + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue; +} + +struct ErasedSegmentSortKeyComputerWrapper { + inner: C, + converter: F, +} + +impl ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper +where + C: SegmentSortKeyComputer> + Send + Sync, + F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static, +{ + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + self.inner.segment_sort_key(doc, score) + } + + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { + let val = self.inner.convert_segment_sort_key(sort_key); + (self.converter)(val) + } +} + +struct ScoreSegmentSortKeyComputer { + segment_computer: SortBySimilarityScore, +} + +impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into(); + Some(score_value.to_u64()) + } + + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { + let score_value: u64 = sort_key.expect("This implementation always produces a score."); + OwnedValue::F64(f64::from_u64(score_value)) + } +} + +impl SortKeyComputer for SortByErasedType { + type SortKey = OwnedValue; + type Child = ErasedColumnSegmentSortKeyComputer; + type Comparator = NaturalComparator; + + fn requires_scoring(&self) -> bool { + matches!(self, Self::Score) + } + + fn segment_sort_key_computer( + &self, + segment_reader: &crate::SegmentReader, + ) -> crate::Result { + let inner: Box = match self { + Self::Field(column_name) => { + let fast_fields = segment_reader.fast_fields(); + // TODO: We currently double-open the column to avoid relying on the implementation + // details of `SortByString` or `SortByStaticFastValue`. Once + // https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should + // consider directly constructing the appropriate `SegmentSortKeyComputer` type for + // the column that we open here. + let (_column, column_type) = + fast_fields.u64_lenient(column_name)?.ok_or_else(|| { + FastFieldNotAvailableError { + field_name: column_name.to_owned(), + } + })?; + + match column_type { + ColumnType::Str => { + let computer = SortByString::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::U64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::I64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::F64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::Bool => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null) + }, + }) + } + ColumnType::DateTime => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(ErasedSegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null) + }, + }) + } + column_type => { + return Err(crate::TantivyError::SchemaError(format!( + "Field `{}` is of type {column_type:?}, which is not supported for \ + sorting by owned value yet.", + column_name + ))) + } + } + } + Self::Score => Box::new(ScoreSegmentSortKeyComputer { + segment_computer: SortBySimilarityScore, + }), + }; + Ok(ErasedColumnSegmentSortKeyComputer { inner }) + } +} + +pub struct ErasedColumnSegmentSortKeyComputer { + inner: Box, +} + +impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer { + type SortKey = OwnedValue; + type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; + + #[inline(always)] + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + self.inner.segment_sort_key(doc, score) + } + + fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue { + self.inner.convert_segment_sort_key(segment_sort_key) + } +} + +#[cfg(test)] +mod tests { + use crate::collector::sort_key::{ComparatorEnum, SortByErasedType}; + use crate::collector::TopDocs; + use crate::query::AllQuery; + use crate::schema::{OwnedValue, Schema, FAST, TEXT}; + use crate::Index; + + #[test] + fn test_sort_by_owned_u64() { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(id_field => 10u64)).unwrap(); + writer.add_document(doc!(id_field => 2u64)).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural)); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null] + ); + + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_field("id"), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null] + ); + } + + #[test] + fn test_sort_by_owned_string() { + let mut schema_builder = Schema::builder(); + let city_field = schema_builder.add_text_field("city", FAST | TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(city_field => "tokyo")).unwrap(); + writer.add_document(doc!(city_field => "austin")).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_field("city"), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![ + OwnedValue::Str("austin".to_string()), + OwnedValue::Str("tokyo".to_string()), + OwnedValue::Null + ] + ); + } + + #[test] + fn test_sort_by_owned_reverse() { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(id_field => 10u64)).unwrap(); + writer.add_document(doc!(id_field => 2u64)).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse)); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)] + ); + } + + #[test] + fn test_sort_by_owned_score() { + let mut schema_builder = Schema::builder(); + let body_field = schema_builder.add_text_field("body", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(body_field => "a a")).unwrap(); + writer.add_document(doc!(body_field => "a")).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]); + let query = query_parser.parse_query("a").unwrap(); + + // Sort by score descending (Natural) + let collector = TopDocs::with_limit(10) + .order_by((SortByErasedType::for_score(), ComparatorEnum::Natural)); + let top_docs = searcher.search(&query, &collector).unwrap(); + + let values: Vec = top_docs + .into_iter() + .map(|(key, _)| match key { + OwnedValue::F64(val) => val, + _ => panic!("Wrong type {:?}", key), + }) + .collect(); + + assert_eq!(values.len(), 2); + assert!(values[0] > values[1]); + + // Sort by score ascending (ReverseNoneLower) + let collector = TopDocs::with_limit(10).order_by(( + SortByErasedType::for_score(), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&query, &collector).unwrap(); + + let values: Vec = top_docs + .into_iter() + .map(|(key, _)| match key { + OwnedValue::F64(val) => val, + _ => panic!("Wrong type {:?}", key), + }) + .collect(); + + assert_eq!(values.len(), 2); + assert!(values[0] < values[1]); + } +} diff --git a/src/collector/sort_key/sort_by_score.rs b/src/collector/sort_key/sort_by_score.rs index df8b0dd75..a23660e56 100644 --- a/src/collector/sort_key/sort_by_score.rs +++ b/src/collector/sort_key/sort_by_score.rs @@ -63,8 +63,8 @@ impl SortKeyComputer for SortBySimilarityScore { impl SegmentSortKeyComputer for SortBySimilarityScore { type SortKey = Score; - type SegmentSortKey = Score; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score { diff --git a/src/collector/sort_key/sort_by_static_fast_value.rs b/src/collector/sort_key/sort_by_static_fast_value.rs index b38b8b034..44a4e1d8d 100644 --- a/src/collector/sort_key/sort_by_static_fast_value.rs +++ b/src/collector/sort_key/sort_by_static_fast_value.rs @@ -34,9 +34,7 @@ impl SortByStaticFastValue { impl SortKeyComputer for SortByStaticFastValue { type Child = SortByFastValueSegmentSortKeyComputer; - type SortKey = Option; - type Comparator = NaturalComparator; fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> { @@ -84,8 +82,8 @@ pub struct SortByFastValueSegmentSortKeyComputer { impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer { type SortKey = Option; - type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey { diff --git a/src/collector/sort_key/sort_by_string.rs b/src/collector/sort_key/sort_by_string.rs index 41ef22e9b..2dd0b4592 100644 --- a/src/collector/sort_key/sort_by_string.rs +++ b/src/collector/sort_key/sort_by_string.rs @@ -30,9 +30,7 @@ impl SortByString { impl SortKeyComputer for SortByString { type SortKey = Option; - type Child = ByStringColumnSegmentSortKeyComputer; - type Comparator = NaturalComparator; fn segment_sort_key_computer( @@ -50,8 +48,8 @@ pub struct ByStringColumnSegmentSortKeyComputer { impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { type SortKey = Option; - type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; #[inline(always)] fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option { @@ -60,6 +58,8 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { } fn convert_segment_sort_key(&self, term_ord_opt: Option) -> Option { + // TODO: Individual lookups to the dictionary like this are very likely to repeatedly + // decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776 let term_ord = term_ord_opt?; let str_column = self.str_column_opt.as_ref()?; let mut bytes = Vec::new(); diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index d56fa7cd0..6aab919a9 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -12,13 +12,21 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader}; /// It is the segment local version of the [`SortKeyComputer`]. pub trait SegmentSortKeyComputer: 'static { /// The final score being emitted. - type SortKey: 'static + PartialOrd + Send + Sync + Clone; + type SortKey: 'static + Send + Sync + Clone; /// Sort key used by at the segment level by the `SegmentSortKeyComputer`. /// /// It is typically small like a `u64`, and is meant to be converted /// to the final score at the end of the collection of the segment. - type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone; + type SegmentSortKey: 'static + Clone + Send + Sync + Clone; + + /// Comparator type. + type SegmentComparator: Comparator + 'static; + + /// Returns the segment sort key comparator. + fn segment_comparator(&self) -> Self::SegmentComparator { + Self::SegmentComparator::default() + } /// Computes the sort key for the given document and score. fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey; @@ -47,7 +55,7 @@ pub trait SegmentSortKeyComputer: 'static { left: &Self::SegmentSortKey, right: &Self::SegmentSortKey, ) -> Ordering { - NaturalComparator.compare(left, right) + self.segment_comparator().compare(left, right) } /// Implementing this method makes it possible to avoid computing @@ -81,7 +89,7 @@ pub trait SegmentSortKeyComputer: 'static { /// the sort key at a segment scale. pub trait SortKeyComputer: Sync { /// The sort key type. - type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug; + type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug; /// Type of the associated [`SegmentSortKeyComputer`]. type Child: SegmentSortKeyComputer; /// Comparator type. @@ -136,10 +144,7 @@ where HeadSortKeyComputer: SortKeyComputer, TailSortKeyComputer: SortKeyComputer, { - type SortKey = ( - ::SortKey, - ::SortKey, - ); + type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey); type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); type Comparator = ( @@ -188,6 +193,11 @@ where TailSegmentSortKeyComputer::SegmentSortKey, ); + type SegmentComparator = ( + HeadSegmentSortKeyComputer::SegmentComparator, + TailSegmentSortKeyComputer::SegmentComparator, + ); + /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on /// its ordering. /// @@ -269,11 +279,12 @@ impl SegmentSortKeyComputer for MappedSegmentSortKeyComputer where T: SegmentSortKeyComputer, - PreviousScore: 'static + Clone + Send + Sync + PartialOrd, - NewScore: 'static + Clone + Send + Sync + PartialOrd, + PreviousScore: 'static + Clone + Send + Sync, + NewScore: 'static + Clone + Send + Sync, { type SortKey = NewScore; type SegmentSortKey = T::SegmentSortKey; + type SegmentComparator = T::SegmentComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { self.sort_key_computer.segment_sort_key(doc, score) @@ -463,6 +474,7 @@ where { type SortKey = TSortKey; type SegmentSortKey = TSortKey; + type SegmentComparator = NaturalComparator; fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { (self)(doc) diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 78c344dbe..631169bc0 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -325,7 +325,7 @@ impl TopDocs { sort_key_computer: impl SortKeyComputer + Send + 'static, ) -> impl Collector> where - TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug, + TSortKey: 'static + Clone + Send + Sync + std::fmt::Debug, { TopBySortKeyCollector::new(sort_key_computer, self.doc_range()) } @@ -446,7 +446,7 @@ where F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn, TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey, TweakScoreSegmentSortKeyComputer: - SegmentSortKeyComputer, + SegmentSortKeyComputer, TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug, { type SortKey = TSortKey; @@ -481,6 +481,7 @@ where { type SortKey = TSortKey; type SegmentSortKey = TSortKey; + type SegmentComparator = NaturalComparator; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { (self.sort_key_fn)(doc, score) diff --git a/src/schema/document/owned_value.rs b/src/schema/document/owned_value.rs index 9fbf1f8c2..49a6b1ac7 100644 --- a/src/schema/document/owned_value.rs +++ b/src/schema/document/owned_value.rs @@ -58,6 +58,31 @@ impl AsRef for OwnedValue { } } +impl OwnedValue { + /// Returns a u8 discriminant value for the `OwnedValue` variant. + /// + /// This can be used to sort `OwnedValue` instances by their type. + pub fn discriminant_value(&self) -> u8 { + match self { + OwnedValue::Null => 0, + OwnedValue::Str(_) => 1, + OwnedValue::PreTokStr(_) => 2, + // It is key to make sure U64, I64, F64 are grouped together in there, otherwise we + // might be breaking transivity. + OwnedValue::U64(_) => 3, + OwnedValue::I64(_) => 4, + OwnedValue::F64(_) => 5, + OwnedValue::Bool(_) => 6, + OwnedValue::Date(_) => 7, + OwnedValue::Facet(_) => 8, + OwnedValue::Bytes(_) => 9, + OwnedValue::Array(_) => 10, + OwnedValue::Object(_) => 11, + OwnedValue::IpAddr(_) => 12, + } + } +} + impl<'a> Value<'a> for &'a OwnedValue { type ArrayIter = std::slice::Iter<'a, OwnedValue>; type ObjectIter = ObjectMapIter<'a>;