From b7121d4f8e469d40e7fb464db01e6cd7fdda083d Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Wed, 3 Dec 2025 21:07:12 -0800 Subject: [PATCH] Add a SortByOwnedValue implementation to provide a type-erased column. --- src/collector/sort_key/mod.rs | 61 +++- src/collector/sort_key/order.rs | 59 +++- src/collector/sort_key/sort_by_owned.rs | 352 ++++++++++++++++++++++++ 3 files changed, 464 insertions(+), 8 deletions(-) create mode 100644 src/collector/sort_key/sort_by_owned.rs diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index 3bfb3b1c8..038de25e8 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -1,10 +1,12 @@ mod order; +mod sort_by_owned; mod sort_by_score; mod sort_by_static_fast_value; mod sort_by_string; mod sort_key_computer; pub use order::*; +pub use sort_by_owned::SortByOwnedValue; pub use sort_by_score::SortBySimilarityScore; pub use sort_by_static_fast_value::SortByStaticFastValue; pub use sort_by_string::SortByString; @@ -34,11 +36,13 @@ pub(crate) mod tests { use std::collections::HashMap; use std::ops::Range; - use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString}; + use crate::collector::sort_key::{ + SortByOwnedValue, SortBySimilarityScore, SortByStaticFastValue, SortByString, + }; use crate::collector::{ComparableDoc, DocSetCollector, TopDocs}; use crate::indexer::NoMergePolicy; use crate::query::{AllQuery, QueryParser}; - use crate::schema::{Schema, FAST, TEXT}; + use crate::schema::{OwnedValue, Schema, FAST, TEXT}; use crate::{DocAddress, Document, Index, Order, Score, Searcher}; fn make_index() -> crate::Result { @@ -313,11 +317,9 @@ pub(crate) mod tests { (SortBySimilarityScore, score_order), (SortByString::for_field("city"), city_order), )); - Ok(searcher - .search(&AllQuery, &top_collector)? - .into_iter() - .map(|(f, doc)| (f, ids[&doc])) - .collect()) + let results: Vec<((Score, Option), DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; + Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect()) } assert_eq!( @@ -342,6 +344,51 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn test_order_by_score_then_owned_value() -> crate::Result<()> { + let index = make_index()?; + + type SortKey = (Score, OwnedValue); + + fn query( + index: &Index, + score_order: Order, + city_order: Order, + ) -> crate::Result> { + let searcher = index.reader()?.searcher(); + let ids = id_mapping(&searcher); + + let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>(( + (SortBySimilarityScore, score_order), + (SortByOwnedValue::for_field("city"), city_order), + )); + let results: Vec<((Score, OwnedValue), DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; + Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect()) + } + + assert_eq!( + &query(&index, Order::Asc, Order::Asc)?, + &[ + ((1.0, OwnedValue::Str("austin".to_owned())), 0), + ((1.0, OwnedValue::Str("greenville".to_owned())), 1), + ((1.0, OwnedValue::Str("tokyo".to_owned())), 2), + ((1.0, OwnedValue::Null), 3), + ] + ); + + assert_eq!( + &query(&index, Order::Asc, Order::Desc)?, + &[ + ((1.0, OwnedValue::Str("tokyo".to_owned())), 2), + ((1.0, OwnedValue::Str("greenville".to_owned())), 1), + ((1.0, OwnedValue::Str("austin".to_owned())), 0), + ((1.0, OwnedValue::Null), 3), + ] + ); + Ok(()) + } + use proptest::prelude::*; proptest! { diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 28d5a5343..319bb736e 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -1,11 +1,43 @@ use std::cmp::Ordering; +use columnar::MonotonicallyMappableToU64; use serde::{Deserialize, Serialize}; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; -use crate::schema::Schema; +use crate::schema::{OwnedValue, Schema}; use crate::{DocId, Order, Score}; +fn compare_owned_value(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + match (lhs, rhs) { + (OwnedValue::Null, OwnedValue::Null) => Ordering::Equal, + (OwnedValue::Null, _) => { + if NULLS_FIRST { + Ordering::Less + } else { + Ordering::Greater + } + } + (_, OwnedValue::Null) => { + if NULLS_FIRST { + Ordering::Greater + } else { + Ordering::Less + } + } + (OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b), + (OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b), + (OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b), + (OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b), + (OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()), + (OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b), + (OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b), + (OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b), + (OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b), + (OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b), + x => panic!("Unsupported comparison: {x:?}"), + } +} + /// Comparator trait defining the order in which documents should be ordered. pub trait Comparator: Send + Sync + std::fmt::Debug + Default { /// Return the order between two values. @@ -29,6 +61,17 @@ impl Comparator for NaturalComparator { } } +/// A (partial) implementation of comparison for OwnedValue. +/// +/// Intended for use within columns of homogenous types, and so will panic for OwnedValues with +/// mismatched types. The one exception is Null, for which we do define all comparisons. +impl Comparator for NaturalComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(lhs, rhs) + } +} + /// Compare values in reverse (e.g. 2 < 1). /// /// When used with `TopDocs`, which reverses the order, this results in an @@ -121,6 +164,13 @@ impl Comparator for ReverseNoneIsLowerComparator { } } +impl Comparator for ReverseNoneIsLowerComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(rhs, lhs) + } +} + /// Compare values naturally, but treating `None` as higher than `Some`. /// /// When used with `TopDocs`, which reverses the order, this results in a @@ -185,6 +235,13 @@ impl Comparator for NaturalNoneIsHigherComparator { } } +impl Comparator for NaturalNoneIsHigherComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + compare_owned_value::(lhs, rhs) + } +} + /// An enum representing the different sort orders. #[derive(Debug, Clone, Copy, Eq, PartialEq, Default)] pub enum ComparatorEnum { diff --git a/src/collector/sort_key/sort_by_owned.rs b/src/collector/sort_key/sort_by_owned.rs new file mode 100644 index 000000000..77e275eb1 --- /dev/null +++ b/src/collector/sort_key/sort_by_owned.rs @@ -0,0 +1,352 @@ +use columnar::MonotonicallyMappableToU64; + +use crate::collector::sort_key::{ + NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString, +}; +use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::schema::{OwnedValue, Type}; +use crate::{DateTime, DocId, Score}; + +/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score. +/// +/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders +/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to +/// use a SortKeyComputer implementation with a known-type at compile time. +#[derive(Debug, Clone)] +pub enum SortByOwnedValue { + Field(String), + Score, +} + +impl SortByOwnedValue { + /// Creates a new sort key computer which will sort by the given fast field column, with type + /// erasure. + pub fn for_field(column_name: impl ToString) -> Self { + Self::Field(column_name.to_string()) + } + + /// Creates a new sort key computer which will sort by score, with type erasure. + pub fn for_score() -> Self { + Self::Score + } +} + +/// TODO: Rename to Boxed...? Or Owned? +trait AnySegmentSortKeyComputer: Send + Sync { + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option; + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue; +} + +struct AnySegmentSortKeyComputerWrapper { + inner: C, + converter: F, +} + +impl AnySegmentSortKeyComputer for AnySegmentSortKeyComputerWrapper +where + C: SegmentSortKeyComputer> + Send + Sync, + F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static, +{ + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + self.inner.segment_sort_key(doc, score) + } + + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { + let val = self.inner.convert_segment_sort_key(sort_key); + (self.converter)(val) + } +} + +struct ScoreSegmentSortKeyComputer { + segment_computer: SortBySimilarityScore, +} + +impl AnySegmentSortKeyComputer for ScoreSegmentSortKeyComputer { + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into(); + Some(score_value.to_u64()) + } + + fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { + let score_value: u64 = sort_key.expect("This implementation always produces a score."); + OwnedValue::F64(f64::from_u64(score_value)) + } +} + +impl SortKeyComputer for SortByOwnedValue { + type SortKey = OwnedValue; + type Child = ByOwnedValueColumnSegmentSortKeyComputer; + type Comparator = NaturalComparator; + + fn requires_scoring(&self) -> bool { + matches!(self, Self::Score) + } + + fn segment_sort_key_computer( + &self, + segment_reader: &crate::SegmentReader, + ) -> crate::Result { + let schema = segment_reader.schema(); + let inner: Box = match self { + Self::Field(column_name) => { + let field = schema.get_field(column_name)?; + let field_entry = schema.get_field_entry(field); + let field_type = field_entry.field_type(); + + match field_type.value_type() { + Type::Str => { + let computer = SortByString::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(AnySegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null) + }, + }) + } + Type::U64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(AnySegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null) + }, + }) + } + Type::I64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(AnySegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null) + }, + }) + } + Type::F64 => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(AnySegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null) + }, + }) + } + Type::Bool => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(AnySegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null) + }, + }) + } + Type::Date => { + let computer = SortByStaticFastValue::::for_field(column_name); + let inner = computer.segment_sort_key_computer(segment_reader)?; + Box::new(AnySegmentSortKeyComputerWrapper { + inner, + converter: |val: Option| { + val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null) + }, + }) + } + _ => { + return Err(crate::TantivyError::SchemaError(format!( + "Field `{}` is of type {:?}, which is not supported for sorting by \ + owned value yet.", + column_name, + field_type.value_type() + ))) + } + } + } + Self::Score => Box::new(ScoreSegmentSortKeyComputer { + segment_computer: SortBySimilarityScore, + }), + }; + Ok(ByOwnedValueColumnSegmentSortKeyComputer { inner }) + } +} + +pub struct ByOwnedValueColumnSegmentSortKeyComputer { + inner: Box, +} + +impl SegmentSortKeyComputer for ByOwnedValueColumnSegmentSortKeyComputer { + type SortKey = OwnedValue; + type SegmentSortKey = Option; + type SegmentComparator = NaturalComparator; + + #[inline(always)] + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option { + self.inner.segment_sort_key(doc, score) + } + + fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue { + self.inner.convert_segment_sort_key(segment_sort_key) + } +} + +#[cfg(test)] +mod tests { + use crate::collector::sort_key::{ComparatorEnum, SortByOwnedValue}; + use crate::collector::TopDocs; + use crate::query::AllQuery; + use crate::schema::{OwnedValue, Schema, FAST, TEXT}; + use crate::Index; + + #[test] + fn test_sort_by_owned_u64() { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(id_field => 10u64)).unwrap(); + writer.add_document(doc!(id_field => 2u64)).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10) + .order_by((SortByOwnedValue::for_field("id"), ComparatorEnum::Natural)); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null] + ); + + let collector = TopDocs::with_limit(10).order_by(( + SortByOwnedValue::for_field("id"), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null] + ); + } + + #[test] + fn test_sort_by_owned_string() { + let mut schema_builder = Schema::builder(); + let city_field = schema_builder.add_text_field("city", FAST | TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(city_field => "tokyo")).unwrap(); + writer.add_document(doc!(city_field => "austin")).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10).order_by(( + SortByOwnedValue::for_field("city"), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![ + OwnedValue::Str("austin".to_string()), + OwnedValue::Str("tokyo".to_string()), + OwnedValue::Null + ] + ); + } + + #[test] + fn test_sort_by_owned_reverse() { + let mut schema_builder = Schema::builder(); + let id_field = schema_builder.add_u64_field("id", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(id_field => 10u64)).unwrap(); + writer.add_document(doc!(id_field => 2u64)).unwrap(); + writer.add_document(doc!()).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + let collector = TopDocs::with_limit(10) + .order_by((SortByOwnedValue::for_field("id"), ComparatorEnum::Reverse)); + let top_docs = searcher.search(&AllQuery, &collector).unwrap(); + + let values: Vec = top_docs.into_iter().map(|(key, _)| key).collect(); + + assert_eq!( + values, + vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)] + ); + } + + #[test] + fn test_sort_by_owned_score() { + let mut schema_builder = Schema::builder(); + let body_field = schema_builder.add_text_field("body", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + writer.add_document(doc!(body_field => "a a")).unwrap(); + writer.add_document(doc!(body_field => "a")).unwrap(); + writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]); + let query = query_parser.parse_query("a").unwrap(); + + // Sort by score descending (Natural) + let collector = TopDocs::with_limit(10) + .order_by((SortByOwnedValue::for_score(), ComparatorEnum::Natural)); + let top_docs = searcher.search(&query, &collector).unwrap(); + + let values: Vec = top_docs + .into_iter() + .map(|(key, _)| match key { + OwnedValue::F64(val) => val, + _ => panic!("Wrong type {:?}", key), + }) + .collect(); + + assert_eq!(values.len(), 2); + assert!(values[0] > values[1]); + + // Sort by score ascending (ReverseNoneLower) + let collector = TopDocs::with_limit(10).order_by(( + SortByOwnedValue::for_score(), + ComparatorEnum::ReverseNoneLower, + )); + let top_docs = searcher.search(&query, &collector).unwrap(); + + let values: Vec = top_docs + .into_iter() + .map(|(key, _)| match key { + OwnedValue::F64(val) => val, + _ => panic!("Wrong type {:?}", key), + }) + .collect(); + + assert_eq!(values.len(), 2); + assert!(values[0] < values[1]); + } +}