diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index 48bd86122..04725a698 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -483,4 +483,67 @@ mod tests { ); } } + + #[test] + fn test_order_by_compound_filtering_with_none() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let city = schema_builder.add_text_field("city", TEXT | FAST); + let altitude = schema_builder.add_u64_field("altitude", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests()?; + + // Add enough docs to trigger thresholding. + // We want to sort by City Asc, Altitude Asc. + // Note: In NaturalComparator, None < Some, but Order::Asc still ranks None last. + // So Ascending order should be: "a", then "b", then "c", then None. + + // Docs: + // 0: "c", 10 + // 1: "b", 10 + // 2: "a", 20 + // 3: "a", 10 + // 4: None, 5 + + // Expected Ascending Order (None is Last in Tantivy's Order::Asc): + // 1. Doc 3 ("a", 10) + // 2. Doc 2 ("a", 20) + // 3. Doc 1 ("b", 10) + // 4. Doc 0 ("c", 10) + // 5. Doc 4 (None, 5) + + index_writer.add_document(doc!(city => "c", altitude => 10u64))?; + index_writer.add_document(doc!(city => "b", altitude => 10u64))?; + index_writer.add_document(doc!(city => "a", altitude => 20u64))?; + index_writer.add_document(doc!(city => "a", altitude => 10u64))?; + index_writer.add_document(doc!(altitude => 5u64))?; // City is None + + index_writer.commit()?; + + let searcher = index.reader()?.searcher(); + + // Use limit(2) to force a threshold update after the first few docs. + // The collector should eventually establish a threshold around ("a", 20) (Top 2: "a" 10, + // "a" 20). Then when seeing "b" and "c", it should filter them out based on the + // head key "city". This confirms that when filtering happens, the DocIds are + // preserved correctly. 
+ let top_collector = TopDocs::with_limit(2).order_by(( + (SortByString::for_field("city"), Order::Asc), + ( + SortByStaticFastValue::::for_field("altitude"), + Order::Asc, + ), + )); + + let results: Vec = searcher + .search(&AllQuery, &top_collector)? + .into_iter() + .map(|(_, doc)| doc) + .collect(); + + // Doc 3 is ("a", 10). Doc 2 is ("a", 20). + assert_eq!(results, vec![DocAddress::new(0, 3), DocAddress::new(0, 2)]); + + Ok(()) + } } diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index 123aa7c14..ac5a45d44 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -645,8 +645,13 @@ where self.segment_sort_key_computer.segment_sort_key(doc, score) } - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec { - self.segment_sort_key_computer.segment_sort_keys(docs) + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + self.segment_sort_key_computer + .segment_sort_keys(docs, filter) } #[inline(always)] diff --git a/src/collector/sort_key/sort_by_erased_type.rs b/src/collector/sort_key/sort_by_erased_type.rs index 370be0021..71afbb293 100644 --- a/src/collector/sort_key/sort_by_erased_type.rs +++ b/src/collector/sort_key/sort_by_erased_type.rs @@ -1,4 +1,4 @@ -use columnar::{ColumnType, MonotonicallyMappableToU64}; +use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange}; use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer; use crate::collector::sort_key::{ @@ -37,7 +37,11 @@ impl SortByErasedType { trait ErasedSegmentSortKeyComputer: Send + Sync { fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option; - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec>; + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange>, + ) -> &mut Vec<(DocId, Option)>; fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue; } @@ -55,8 +59,12 @@ where 
self.inner.segment_sort_key(doc, score) } - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec> { - self.inner.segment_sort_keys(docs) + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange>, + ) -> &mut Vec<(DocId, Option)> { + self.inner.segment_sort_keys(docs, filter) } fn convert_segment_sort_key(&self, sort_key: Option) -> OwnedValue { @@ -75,7 +83,11 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer { Some(score_value.to_u64()) } - fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &mut Vec> { + fn segment_sort_keys( + &mut self, + _docs: &[DocId], + _filter: ValueRange>, + ) -> &mut Vec<(DocId, Option)> { unimplemented!("Batch computation not supported for score sorting") } @@ -206,8 +218,12 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer { self.inner.segment_sort_key(doc, score) } - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec { - self.inner.segment_sort_keys(docs) + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + self.inner.segment_sort_keys(docs, filter) } fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue { diff --git a/src/collector/sort_key/sort_by_score.rs b/src/collector/sort_key/sort_by_score.rs index b9072669e..4da9643df 100644 --- a/src/collector/sort_key/sort_by_score.rs +++ b/src/collector/sort_key/sort_by_score.rs @@ -1,3 +1,5 @@ +use columnar::ValueRange; + use crate::collector::sort_key::NaturalComparator; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer}; use crate::{DocAddress, DocId, Score}; @@ -73,7 +75,11 @@ impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer { score } - fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &mut Vec { + fn segment_sort_keys( + &mut self, + _docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { unimplemented!("Batch computation 
not supported for score sorting") } diff --git a/src/collector/sort_key/sort_by_static_fast_value.rs b/src/collector/sort_key/sort_by_static_fast_value.rs index 2f43b3a8f..8608155f6 100644 --- a/src/collector/sort_key/sort_by_static_fast_value.rs +++ b/src/collector/sort_key/sort_by_static_fast_value.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use columnar::{Column, ValueRange}; +use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range; use crate::collector::sort_key::NaturalComparator; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; use crate::fastfield::{FastFieldNotAvailableError, FastValue}; @@ -80,7 +81,7 @@ impl SortKeyComputer for SortByStaticFastValue { pub struct SortByFastValueSegmentSortKeyComputer { sort_column: Column, typ: PhantomData, - buffer: Vec>, + buffer: Vec<(DocId, Option)>, fetch_buffer: Vec>>, } @@ -94,14 +95,22 @@ impl SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu self.sort_column.first(doc) } - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec { + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { self.fetch_buffer.resize(docs.len(), None); + let u64_filter = convert_optional_u64_range_to_u64_range(filter); self.sort_column - .first_vals_in_value_range(docs, &mut self.fetch_buffer, ValueRange::All); + .first_vals_in_value_range(docs, &mut self.fetch_buffer, u64_filter); self.buffer.clear(); - self.buffer - .extend(self.fetch_buffer.iter().map(|val| val.flatten())); + for (&doc, val) in docs.iter().zip(self.fetch_buffer.iter()) { + if let Some(val) = val { + self.buffer.push((doc, *val)); + } + } &mut self.buffer } @@ -130,6 +139,7 @@ mod tests { index_writer .add_document(crate::doc!(field_col => 20u64)) .unwrap(); + index_writer.add_document(crate::doc!()).unwrap(); index_writer.commit().unwrap(); let reader = index.reader().unwrap(); @@ -139,9 +149,45 @@ mod tests { let sorter = 
SortByStaticFastValue::::for_field("field"); let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); - let docs = vec![0, 1]; - let output = computer.segment_sort_keys(&docs); + let docs = vec![0, 1, 2]; + let output = computer.segment_sort_keys(&docs, ValueRange::All); - assert_eq!(output, &[Some(10), Some(20)]); + assert_eq!(output, &[(0, Some(10)), (1, Some(20)), (2, None)]); + } + + #[test] + fn test_sort_by_fast_value_batch_with_filter() { + let mut schema_builder = Schema::builder(); + let field_col = schema_builder.add_u64_field("field", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + + index_writer + .add_document(crate::doc!(field_col => 10u64)) + .unwrap(); + index_writer + .add_document(crate::doc!(field_col => 20u64)) + .unwrap(); + index_writer.add_document(crate::doc!()).unwrap(); + index_writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let sorter = SortByStaticFastValue::::for_field("field"); + let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); + + let docs = vec![0, 1, 2]; + let output = computer.segment_sort_keys( + &docs, + ValueRange::GreaterThan(Some(15u64), false /* match_nulls */), + ); + + // Should contain only the document with value 20. 
+ // Doc 0 (10) < 15 + // Doc 2 (None) < 15 + assert_eq!(output, &[(1, Some(20))]); } } diff --git a/src/collector/sort_key/sort_by_string.rs b/src/collector/sort_key/sort_by_string.rs index c1e28bcb1..232d10149 100644 --- a/src/collector/sort_key/sort_by_string.rs +++ b/src/collector/sort_key/sort_by_string.rs @@ -1,5 +1,8 @@ use columnar::{StrColumn, ValueRange}; +use crate::collector::sort_key::sort_key_computer::{ + convert_optional_u64_range_to_u64_range, range_contains_none, +}; use crate::collector::sort_key::NaturalComparator; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; use crate::termdict::TermOrdinal; @@ -48,7 +51,7 @@ impl SortKeyComputer for SortByString { pub struct ByStringColumnSegmentSortKeyComputer { str_column_opt: Option, - buffer: Vec>, + buffer: Vec<(DocId, Option)>, fetch_buffer: Vec>>, } @@ -63,18 +66,29 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { str_column.ords().first(doc) } - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec { + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { self.fetch_buffer.resize(docs.len(), None); if let Some(str_column) = &self.str_column_opt { - str_column.ords().first_vals_in_value_range( - docs, - &mut self.fetch_buffer, - ValueRange::All, - ); + let u64_filter = convert_optional_u64_range_to_u64_range(filter); + str_column + .ords() + .first_vals_in_value_range(docs, &mut self.fetch_buffer, u64_filter); + } else if range_contains_none(&filter) { + self.fetch_buffer.fill(Some(None)); + } else { + self.fetch_buffer.fill(None); } + self.buffer.clear(); - self.buffer - .extend(self.fetch_buffer.iter().map(|val| val.flatten())); + for (&doc, val) in docs.iter().zip(self.fetch_buffer.iter()) { + if let Some(val) = val { + self.buffer.push((doc, *val)); + } + } &mut self.buffer } @@ -91,3 +105,80 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { 
String::try_from(bytes).ok() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{Schema, FAST, TEXT}; + use crate::Index; + + #[test] + fn test_sort_by_string_batch() { + let mut schema_builder = Schema::builder(); + let field_col = schema_builder.add_text_field("field", FAST | TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + + index_writer + .add_document(crate::doc!(field_col => "a")) + .unwrap(); + index_writer + .add_document(crate::doc!(field_col => "c")) + .unwrap(); + index_writer.add_document(crate::doc!()).unwrap(); + index_writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let sorter = SortByString::for_field("field"); + let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); + + let docs = vec![0, 1, 2]; + let output = computer.segment_sort_keys(&docs, ValueRange::All); + + // We expect ordinals. 
+ // "a" -> 0 + // "c" -> 1 + assert_eq!(output, &[(0, Some(0)), (1, Some(1)), (2, None)]); + } + + #[test] + fn test_sort_by_string_batch_with_filter() { + let mut schema_builder = Schema::builder(); + let field_col = schema_builder.add_text_field("field", FAST | TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + + index_writer + .add_document(crate::doc!(field_col => "a")) + .unwrap(); + index_writer + .add_document(crate::doc!(field_col => "c")) + .unwrap(); + index_writer.add_document(crate::doc!()).unwrap(); + index_writer.commit().unwrap(); + + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let segment_reader = searcher.segment_reader(0); + + let sorter = SortByString::for_field("field"); + let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap(); + + let docs = vec![0, 1, 2]; + // Filter: strictly greater than "a". "a" is 0, "c" is 1. + // "a" has ord 0, so we filter on ordinals > 0. + // The doc without a value (None) is excluded as well. + let output = computer.segment_sort_keys( + &docs, + ValueRange::GreaterThan(Some(0), false /* match_nulls */), + ); + + // Should contain only the document with value "c" (ord 1). + assert_eq!(output, &[(1, Some(1))]); + } +} diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index fabaee50d..d21de75aa 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -1,5 +1,7 @@ use std::cmp::Ordering; +use columnar::ValueRange; + use crate::collector::sort_key::{Comparator, NaturalComparator}; use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector; use crate::collector::top_score_collector::push_assuming_capacity; @@ -38,7 +40,11 @@ pub trait SegmentSortKeyComputer: 'static { /// /// The computed sort keys are stored in an internal buffer and returned as a slice. 
/// Subsequent calls to this method may reuse and overwrite the internal buffer. - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec; + fn segment_sort_keys( + &mut self, + docs: &[DocId], + filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)>; /// Computes the sort key and pushes the document in a TopN Computer. /// @@ -63,14 +69,18 @@ pub trait SegmentSortKeyComputer: 'static { // should always be able to `reserve` space for the entire block. top_n_computer.reserve(docs.len()); - if let Some(threshold) = &top_n_computer.threshold { - // TODO: Would need to split the borrow of the TopNComputer to avoid cloning the - // threshold here. - let threshold = threshold.clone(); - let comparator = self.segment_comparator(); - let sort_keys = self.segment_sort_keys(docs); + let comparator = self.segment_comparator(); + let value_range = if let Some(threshold) = &top_n_computer.threshold { + comparator.threshold_to_valuerange(threshold.clone()) + } else { + ValueRange::All + }; - for (&doc, sort_key) in docs.iter().zip(sort_keys.drain(..)) { + let sort_keys = self.segment_sort_keys(docs, value_range); + + if let Some(threshold) = &top_n_computer.threshold { + let threshold = threshold.clone(); + for (doc, sort_key) in sort_keys.drain(..) { let cmp = comparator.compare(&sort_key, &threshold); if cmp == Ordering::Greater { // We validated at the top of the method that we have capacity. @@ -79,8 +89,7 @@ pub trait SegmentSortKeyComputer: 'static { } } else { // Eagerly push, without a threshold to compare to. - let sort_keys = self.segment_sort_keys(docs); - for (&doc, sort_key) in docs.iter().zip(sort_keys.drain(..)) { + for (doc, sort_key) in sort_keys.drain(..) { // We validated at the top of the method that we have capacity. 
top_n_computer.append_doc_unchecked(doc, sort_key); } @@ -302,7 +311,11 @@ where .then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1)) } - fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &mut Vec { + fn segment_sort_keys( + &mut self, + _docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { unimplemented!("The head and the tail are accessed independently."); } @@ -339,11 +352,12 @@ where let (head_threshold, tail_threshold) = threshold.clone(); let head_cmp = self.head.segment_comparator(); let tail_cmp = self.tail.segment_comparator(); + let head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone()); - let head_keys = self.head.segment_sort_keys(docs); + let head_keys = self.head.segment_sort_keys(docs, head_filter); self.doc_buffer.clear(); self.head_key_buffer.clear(); - for (head_key, &doc) in head_keys.drain(..).zip(docs) { + for (doc, head_key) in head_keys.drain(..) { let cmp = head_cmp.compare(&head_key, &head_threshold); if cmp != Ordering::Less { self.doc_buffer.push(doc); @@ -352,11 +366,13 @@ where } if !self.doc_buffer.is_empty() { - let tail_keys = self.tail.segment_sort_keys(&self.doc_buffer); + let tail_keys = self + .tail + .segment_sort_keys(&self.doc_buffer, ValueRange::All); for ((head_key, tail_key), &doc) in self .head_key_buffer .drain(..) - .zip(tail_keys.drain(..)) + .zip(tail_keys.drain(..).map(|(_, k)| k)) .zip(self.doc_buffer.iter()) { let head_ord = head_cmp.compare(&head_key, &head_threshold); @@ -372,15 +388,11 @@ where } } else { // Eagerly push, without a threshold to compare to. 
- let head_keys = self.head.segment_sort_keys(docs); - let tail_keys = self.tail.segment_sort_keys(docs); - for ((doc, head_key), tail_key) in docs - .iter() - .zip(head_keys.drain(..)) - .zip(tail_keys.drain(..)) - { + let head_keys = self.head.segment_sort_keys(docs, ValueRange::All); + let tail_keys = self.tail.segment_sort_keys(docs, ValueRange::All); + for ((doc, head_key), (_, tail_key)) in head_keys.drain(..).zip(tail_keys.drain(..)) { // We validated at the top of the method that we have capacity. - top_n_computer.append_doc_unchecked(*doc, (head_key, tail_key)); + top_n_computer.append_doc_unchecked(doc, (head_key, tail_key)); } } } @@ -427,8 +439,13 @@ where self.sort_key_computer.segment_sort_key(doc, score) } - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec { - self.sort_key_computer.segment_sort_keys(docs) + fn segment_sort_keys( + &mut self, + docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { + self.sort_key_computer + .segment_sort_keys(docs, ValueRange::All) } #[inline(always)] @@ -605,7 +622,7 @@ where pub struct FuncSegmentSortKeyComputer { func: F, - buffer: Vec, + buffer: Vec<(DocId, TSortKey)>, } impl SortKeyComputer for F @@ -639,11 +656,15 @@ where (self.func)(doc) } - fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec { + fn segment_sort_keys( + &mut self, + docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { self.buffer.clear(); self.buffer.reserve(docs.len()); for &doc in docs { - self.buffer.push((self.func)(doc)); + self.buffer.push((doc, (self.func)(doc))); } &mut self.buffer } @@ -654,6 +675,34 @@ where } } +pub(crate) fn range_contains_none(range: &ValueRange>) -> bool { + match range { + ValueRange::All => true, + ValueRange::Inclusive(r) => r.contains(&None), + ValueRange::GreaterThan(threshold, match_nulls) => *match_nulls || (None > *threshold), + ValueRange::LessThan(threshold, match_nulls) => *match_nulls || (None < *threshold), + } 
+} + +pub(crate) fn convert_optional_u64_range_to_u64_range( + range: ValueRange>, +) -> ValueRange { + if range_contains_none(&range) { + return ValueRange::All; + } + match range { + ValueRange::Inclusive(r) => { + let start = r.start().unwrap_or(0); + let end = r.end().unwrap_or(u64::MAX); + ValueRange::Inclusive(start..=end) + } + ValueRange::GreaterThan(Some(val), _match_nulls) => ValueRange::GreaterThan(val, false), + ValueRange::GreaterThan(None, _match_nulls) => ValueRange::Inclusive(u64::MIN..=u64::MAX), + ValueRange::LessThan(None, _match_nulls) => ValueRange::Inclusive(1..=0), + _ => ValueRange::All, + } +} + #[cfg(test)] mod tests { use std::cmp::Ordering; diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index b022ec780..94c081bec 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use std::fmt; use std::ops::Range; +use columnar::ValueRange; use serde::{Deserialize, Serialize}; use super::Collector; @@ -486,7 +487,11 @@ where (self.sort_key_fn)(doc, score) } - fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &mut Vec { + fn segment_sort_keys( + &mut self, + _docs: &[DocId], + _filter: ValueRange, + ) -> &mut Vec<(DocId, Self::SegmentSortKey)> { unimplemented!("Batch computation is not supported for tweak score.") }