Add a ValueRange filter to SegmentSortKeyComputer::segment_sort_keys.

This commit is contained in:
Stu Hood
2025-12-26 11:30:36 -07:00
parent 996fc936f6
commit 0c920dfc61
8 changed files with 337 additions and 56 deletions

View File

@@ -483,4 +483,67 @@ mod tests {
);
}
}
#[test]
fn test_order_by_compound_filtering_with_none() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let city = schema_builder.add_text_field("city", TEXT | FAST);
let altitude = schema_builder.add_u64_field("altitude", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
// Add enough docs to trigger thresholding.
// We want to sort by City Asc, Altitude Asc.
// Note: In NaturalComparator, None < Some.
// So Ascending order should be: None, then "a", then "b", then "c".
// Docs:
// 0: "c", 10
// 1: "b", 10
// 2: "a", 20
// 3: "a", 10
// 4: None, 5
// Expected Ascending Order (None is Last in Tantivy's Order::Asc):
// 1. Doc 3 ("a", 10)
// 2. Doc 2 ("a", 20)
// 3. Doc 1 ("b", 10)
// 4. Doc 0 ("c", 10)
// 5. Doc 4 (None, 5)
index_writer.add_document(doc!(city => "c", altitude => 10u64))?;
index_writer.add_document(doc!(city => "b", altitude => 10u64))?;
index_writer.add_document(doc!(city => "a", altitude => 20u64))?;
index_writer.add_document(doc!(city => "a", altitude => 10u64))?;
index_writer.add_document(doc!(altitude => 5u64))?; // City is None
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Use limit(2) to force a threshold update after the first few docs.
// The collector should eventually establish a threshold around ("a", 20) (Top 2: "a" 10,
// "a" 20). Then when seeing "b" and "c", it should filter them out based on the
// head key "city". This confirms that when filtering happens, the DocIds are
// preserved correctly.
let top_collector = TopDocs::with_limit(2).order_by((
(SortByString::for_field("city"), Order::Asc),
(
SortByStaticFastValue::<u64>::for_field("altitude"),
Order::Asc,
),
));
let results: Vec<DocAddress> = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(_, doc)| doc)
.collect();
// Doc 3 is ("a", 10). Doc 2 is ("a", 20).
assert_eq!(results, vec![DocAddress::new(0, 3), DocAddress::new(0, 2)]);
Ok(())
}
}

View File

@@ -645,8 +645,13 @@ where
self.segment_sort_key_computer.segment_sort_key(doc, score)
}
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
self.segment_sort_key_computer.segment_sort_keys(docs)
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.segment_sort_key_computer
.segment_sort_keys(docs, filter)
}
#[inline(always)]

View File

@@ -1,4 +1,4 @@
use columnar::{ColumnType, MonotonicallyMappableToU64};
use columnar::{ColumnType, MonotonicallyMappableToU64, ValueRange};
use crate::collector::sort_key::sort_by_score::SortBySimilarityScoreSegmentComputer;
use crate::collector::sort_key::{
@@ -37,7 +37,11 @@ impl SortByErasedType {
trait ErasedSegmentSortKeyComputer: Send + Sync {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Option<u64>>;
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)>;
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}
@@ -55,8 +59,12 @@ where
self.inner.segment_sort_key(doc, score)
}
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Option<u64>> {
self.inner.segment_sort_keys(docs)
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)> {
self.inner.segment_sort_keys(docs, filter)
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
@@ -75,7 +83,11 @@ impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
Some(score_value.to_u64())
}
fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &mut Vec<Option<u64>> {
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_filter: ValueRange<Option<u64>>,
) -> &mut Vec<(DocId, Option<u64>)> {
unimplemented!("Batch computation not supported for score sorting")
}
@@ -206,8 +218,12 @@ impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
self.inner.segment_sort_key(doc, score)
}
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
self.inner.segment_sort_keys(docs)
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.inner.segment_sort_keys(docs, filter)
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {

View File

@@ -1,3 +1,5 @@
use columnar::ValueRange;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
@@ -73,7 +75,11 @@ impl SegmentSortKeyComputer for SortBySimilarityScoreSegmentComputer {
score
}
fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
unimplemented!("Batch computation not supported for score sorting")
}

View File

@@ -2,6 +2,7 @@ use std::marker::PhantomData;
use columnar::{Column, ValueRange};
use crate::collector::sort_key::sort_key_computer::convert_optional_u64_range_to_u64_range;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
@@ -80,7 +81,7 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
pub struct SortByFastValueSegmentSortKeyComputer<T> {
sort_column: Column<u64>,
typ: PhantomData<T>,
buffer: Vec<Option<u64>>,
buffer: Vec<(DocId, Option<u64>)>,
fetch_buffer: Vec<Option<Option<u64>>>,
}
@@ -94,14 +95,22 @@ impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyCompu
self.sort_column.first(doc)
}
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.fetch_buffer.resize(docs.len(), None);
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
self.sort_column
.first_vals_in_value_range(docs, &mut self.fetch_buffer, ValueRange::All);
.first_vals_in_value_range(docs, &mut self.fetch_buffer, u64_filter);
self.buffer.clear();
self.buffer
.extend(self.fetch_buffer.iter().map(|val| val.flatten()));
for (&doc, val) in docs.iter().zip(self.fetch_buffer.iter()) {
if let Some(val) = val {
self.buffer.push((doc, *val));
}
}
&mut self.buffer
}
@@ -130,6 +139,7 @@ mod tests {
index_writer
.add_document(crate::doc!(field_col => 20u64))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
@@ -139,9 +149,45 @@ mod tests {
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let docs = vec![0, 1];
let output = computer.segment_sort_keys(&docs);
let docs = vec![0, 1, 2];
let output = computer.segment_sort_keys(&docs, ValueRange::All);
assert_eq!(output, &[Some(10), Some(20)]);
assert_eq!(output, &[(0, Some(10)), (1, Some(20)), (2, None)]);
}
#[test]
fn test_sort_by_fast_value_batch_with_filter() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_u64_field("field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => 10u64))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => 20u64))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByStaticFastValue::<u64>::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let docs = vec![0, 1, 2];
let output = computer.segment_sort_keys(
&docs,
ValueRange::GreaterThan(Some(15u64), false /* inclusive */),
);
// Should contain only the document with value 20.
// Doc 0 (10) < 15
// Doc 2 (None) < 15
assert_eq!(output, &[(1, Some(20))]);
}
}

View File

@@ -1,5 +1,8 @@
use columnar::{StrColumn, ValueRange};
use crate::collector::sort_key::sort_key_computer::{
convert_optional_u64_range_to_u64_range, range_contains_none,
};
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
@@ -48,7 +51,7 @@ impl SortKeyComputer for SortByString {
pub struct ByStringColumnSegmentSortKeyComputer {
str_column_opt: Option<StrColumn>,
buffer: Vec<Option<TermOrdinal>>,
buffer: Vec<(DocId, Option<TermOrdinal>)>,
fetch_buffer: Vec<Option<Option<TermOrdinal>>>,
}
@@ -63,18 +66,29 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
str_column.ords().first(doc)
}
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.fetch_buffer.resize(docs.len(), None);
if let Some(str_column) = &self.str_column_opt {
str_column.ords().first_vals_in_value_range(
docs,
&mut self.fetch_buffer,
ValueRange::All,
);
let u64_filter = convert_optional_u64_range_to_u64_range(filter);
str_column
.ords()
.first_vals_in_value_range(docs, &mut self.fetch_buffer, u64_filter);
} else if range_contains_none(&filter) {
self.fetch_buffer.fill(Some(None));
} else {
self.fetch_buffer.fill(None);
}
self.buffer.clear();
self.buffer
.extend(self.fetch_buffer.iter().map(|val| val.flatten()));
for (&doc, val) in docs.iter().zip(self.fetch_buffer.iter()) {
if let Some(val) = val {
self.buffer.push((doc, *val));
}
}
&mut self.buffer
}
@@ -91,3 +105,80 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
String::try_from(bytes).ok()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{Schema, FAST, TEXT};
use crate::Index;
#[test]
fn test_sort_by_string_batch() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => "a"))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => "c"))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let docs = vec![0, 1, 2];
let output = computer.segment_sort_keys(&docs, ValueRange::All);
// We expect ordinals.
// "a" -> 0
// "c" -> 1
assert_eq!(output, &[(0, Some(0)), (1, Some(1)), (2, None)]);
}
#[test]
fn test_sort_by_string_batch_with_filter() {
let mut schema_builder = Schema::builder();
let field_col = schema_builder.add_text_field("field", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::doc!(field_col => "a"))
.unwrap();
index_writer
.add_document(crate::doc!(field_col => "c"))
.unwrap();
index_writer.add_document(crate::doc!()).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let sorter = SortByString::for_field("field");
let mut computer = sorter.segment_sort_key_computer(segment_reader).unwrap();
let docs = vec![0, 1, 2];
// Filter: > "b". "a" is 0, "c" is 1.
// We want > "a" (ord 0). So we filter > ord 0.
// 0 is "a", 1 is "c".
let output = computer.segment_sort_keys(
&docs,
ValueRange::GreaterThan(Some(0), false /* inclusive */),
);
// Should contain only the document with value "c" (ord 1).
assert_eq!(output, &[(1, Some(1))]);
}
}

View File

@@ -1,5 +1,7 @@
use std::cmp::Ordering;
use columnar::ValueRange;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::top_score_collector::push_assuming_capacity;
@@ -38,7 +40,11 @@ pub trait SegmentSortKeyComputer: 'static {
///
/// The computed sort keys are stored in an internal buffer and returned as a slice.
/// Subsequent calls to this method may reuse and overwrite the internal buffer.
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey>;
fn segment_sort_keys(
&mut self,
docs: &[DocId],
filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)>;
/// Computes the sort key and pushes the document in a TopN Computer.
///
@@ -63,14 +69,18 @@ pub trait SegmentSortKeyComputer: 'static {
// should always be able to `reserve` space for the entire block.
top_n_computer.reserve(docs.len());
if let Some(threshold) = &top_n_computer.threshold {
// TODO: Would need to split the borrow of the TopNComputer to avoid cloning the
// threshold here.
let threshold = threshold.clone();
let comparator = self.segment_comparator();
let sort_keys = self.segment_sort_keys(docs);
let comparator = self.segment_comparator();
let value_range = if let Some(threshold) = &top_n_computer.threshold {
comparator.threshold_to_valuerange(threshold.clone())
} else {
ValueRange::All
};
for (&doc, sort_key) in docs.iter().zip(sort_keys.drain(..)) {
let sort_keys = self.segment_sort_keys(docs, value_range);
if let Some(threshold) = &top_n_computer.threshold {
let threshold = threshold.clone();
for (doc, sort_key) in sort_keys.drain(..) {
let cmp = comparator.compare(&sort_key, &threshold);
if cmp == Ordering::Greater {
// We validated at the top of the method that we have capacity.
@@ -79,8 +89,7 @@ pub trait SegmentSortKeyComputer: 'static {
}
} else {
// Eagerly push, without a threshold to compare to.
let sort_keys = self.segment_sort_keys(docs);
for (&doc, sort_key) in docs.iter().zip(sort_keys.drain(..)) {
for (doc, sort_key) in sort_keys.drain(..) {
// We validated at the top of the method that we have capacity.
top_n_computer.append_doc_unchecked(doc, sort_key);
}
@@ -302,7 +311,11 @@ where
.then_with(|| self.tail.compare_segment_sort_key(&left.1, &right.1))
}
fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
unimplemented!("The head and the tail are accessed independently.");
}
@@ -339,11 +352,12 @@ where
let (head_threshold, tail_threshold) = threshold.clone();
let head_cmp = self.head.segment_comparator();
let tail_cmp = self.tail.segment_comparator();
let head_filter = head_cmp.threshold_to_valuerange(head_threshold.clone());
let head_keys = self.head.segment_sort_keys(docs);
let head_keys = self.head.segment_sort_keys(docs, head_filter);
self.doc_buffer.clear();
self.head_key_buffer.clear();
for (head_key, &doc) in head_keys.drain(..).zip(docs) {
for (doc, head_key) in head_keys.drain(..) {
let cmp = head_cmp.compare(&head_key, &head_threshold);
if cmp != Ordering::Less {
self.doc_buffer.push(doc);
@@ -352,11 +366,13 @@ where
}
if !self.doc_buffer.is_empty() {
let tail_keys = self.tail.segment_sort_keys(&self.doc_buffer);
let tail_keys = self
.tail
.segment_sort_keys(&self.doc_buffer, ValueRange::All);
for ((head_key, tail_key), &doc) in self
.head_key_buffer
.drain(..)
.zip(tail_keys.drain(..))
.zip(tail_keys.drain(..).map(|(_, k)| k))
.zip(self.doc_buffer.iter())
{
let head_ord = head_cmp.compare(&head_key, &head_threshold);
@@ -372,15 +388,11 @@ where
}
} else {
// Eagerly push, without a threshold to compare to.
let head_keys = self.head.segment_sort_keys(docs);
let tail_keys = self.tail.segment_sort_keys(docs);
for ((doc, head_key), tail_key) in docs
.iter()
.zip(head_keys.drain(..))
.zip(tail_keys.drain(..))
{
let head_keys = self.head.segment_sort_keys(docs, ValueRange::All);
let tail_keys = self.tail.segment_sort_keys(docs, ValueRange::All);
for ((doc, head_key), (_, tail_key)) in head_keys.drain(..).zip(tail_keys.drain(..)) {
// We validated at the top of the method that we have capacity.
top_n_computer.append_doc_unchecked(*doc, (head_key, tail_key));
top_n_computer.append_doc_unchecked(doc, (head_key, tail_key));
}
}
}
@@ -427,8 +439,13 @@ where
self.sort_key_computer.segment_sort_key(doc, score)
}
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
self.sort_key_computer.segment_sort_keys(docs)
fn segment_sort_keys(
&mut self,
docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.sort_key_computer
.segment_sort_keys(docs, ValueRange::All)
}
#[inline(always)]
@@ -605,7 +622,7 @@ where
pub struct FuncSegmentSortKeyComputer<F, TSortKey> {
func: F,
buffer: Vec<TSortKey>,
buffer: Vec<(DocId, TSortKey)>,
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
@@ -639,11 +656,15 @@ where
(self.func)(doc)
}
fn segment_sort_keys(&mut self, docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
fn segment_sort_keys(
&mut self,
docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
self.buffer.clear();
self.buffer.reserve(docs.len());
for &doc in docs {
self.buffer.push((self.func)(doc));
self.buffer.push((doc, (self.func)(doc)));
}
&mut self.buffer
}
@@ -654,6 +675,34 @@ where
}
}
pub(crate) fn range_contains_none(range: &ValueRange<Option<u64>>) -> bool {
match range {
ValueRange::All => true,
ValueRange::Inclusive(r) => r.contains(&None),
ValueRange::GreaterThan(threshold, match_nulls) => *match_nulls || (None > *threshold),
ValueRange::LessThan(threshold, match_nulls) => *match_nulls || (None < *threshold),
}
}
pub(crate) fn convert_optional_u64_range_to_u64_range(
range: ValueRange<Option<u64>>,
) -> ValueRange<u64> {
if range_contains_none(&range) {
return ValueRange::All;
}
match range {
ValueRange::Inclusive(r) => {
let start = r.start().unwrap_or(0);
let end = r.end().unwrap_or(u64::MAX);
ValueRange::Inclusive(start..=end)
}
ValueRange::GreaterThan(Some(val), _match_nulls) => ValueRange::GreaterThan(val, false),
ValueRange::GreaterThan(None, _match_nulls) => ValueRange::Inclusive(u64::MIN..=u64::MAX),
ValueRange::LessThan(None, _match_nulls) => ValueRange::Inclusive(1..=0),
_ => ValueRange::All,
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;

View File

@@ -2,6 +2,7 @@ use std::cmp::Ordering;
use std::fmt;
use std::ops::Range;
use columnar::ValueRange;
use serde::{Deserialize, Serialize};
use super::Collector;
@@ -486,7 +487,11 @@ where
(self.sort_key_fn)(doc, score)
}
fn segment_sort_keys(&mut self, _docs: &[DocId]) -> &mut Vec<Self::SegmentSortKey> {
fn segment_sort_keys(
&mut self,
_docs: &[DocId],
_filter: ValueRange<Self::SegmentSortKey>,
) -> &mut Vec<(DocId, Self::SegmentSortKey)> {
unimplemented!("Batch computation is not supported for tweak score.")
}