feat: enable sort_by for Str/Bytes fast fields (#101)

Adds support for sorting by Str and Bytes fast fields on both
single-segment writes and cross-segment merges. Dictionary-encoded
fields use per-segment ordinals, so segment dictionaries are merged via
the columnar TermMerger to compute per-segment local-ord -> global-ord
mappings; the remapped u64 ordinals are then compared during kmerge.
This commit is contained in:
Mithun Chicklore Yogendra
2026-02-11 08:51:36 +05:30
committed by Luca Cominardi
parent 47ff79e2fc
commit b59fac74bb
11 changed files with 924 additions and 138 deletions

View File

@@ -33,6 +33,25 @@ pub fn merge_bytes_or_str_column(
Ok(())
}
/// Computes a per-segment mapping from old term ordinal to merged term ordinal.
///
/// Performs a streaming k-way merge of per-segment term dictionaries (SSTable-backed) to build
/// a unified ordering. For each segment, the output is a `Vec<TermOrdinal>` where index `i`
/// holds the merged global ordinal corresponding to segment-local ordinal `i`.
///
/// This is used by index sorting to compare terms from different segments without materializing
/// term bytes in memory — only ordinals are compared.
#[doc(hidden)]
pub fn compute_merged_term_ord_mapping(
bytes_columns: &[BytesColumn],
) -> io::Result<Vec<Vec<TermOrdinal>>> {
let bytes_columns_opt: Vec<Option<BytesColumn>> =
bytes_columns.iter().cloned().map(Some).collect();
let term_ord_mapping =
merge_dict_and_compute_term_ord_mapping(&bytes_columns_opt, |_| true, |_| Ok(()))?;
Ok(term_ord_mapping.into_per_segment_new_term_ordinals())
}
struct RemappedTermOrdinalsValues<'a> {
bytes_columns: &'a [Option<BytesColumn>],
term_ord_mapping: &'a TermOrdinalMapping,
@@ -118,14 +137,14 @@ fn is_term_present(bitsets: &[Option<BitSet>], term_merger: &TermMerger) -> bool
false
}
fn serialize_merged_dict(
fn merge_dict_and_compute_term_ord_mapping(
bytes_columns: &[Option<BytesColumn>],
merge_row_order: &MergeRowOrder,
output: &mut impl Write,
mut should_keep_term: impl FnMut(&TermMerger) -> bool,
mut emit_term: impl FnMut(&[u8]) -> io::Result<()>,
) -> io::Result<TermOrdinalMapping> {
let mut term_ord_mapping = TermOrdinalMapping::default();
let mut field_term_streams = Vec::new();
let mut field_term_streams = Vec::with_capacity(bytes_columns.len());
for (segment_ord, column_opt) in bytes_columns.iter().enumerate() {
if let Some(column) = column_opt {
term_ord_mapping.add_segment(column.dictionary.num_terms());
@@ -141,21 +160,33 @@ fn serialize_merged_dict(
}
let mut merged_terms = TermMerger::new(field_term_streams);
let mut sstable_builder = sstable::VoidSSTable::writer(output);
match merge_row_order {
MergeRowOrder::Stack(_) => {
let mut current_term_ord = 0;
while merged_terms.advance() {
let term_bytes: &[u8] = merged_terms.key();
sstable_builder.insert(term_bytes, &())?;
for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, current_term_ord);
}
current_term_ord += 1;
}
sstable_builder.finish()?;
let mut current_term_ord = 0;
while merged_terms.advance() {
if !should_keep_term(&merged_terms) {
continue;
}
emit_term(merged_terms.key())?;
for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, current_term_ord);
}
current_term_ord += 1;
}
Ok(term_ord_mapping)
}
fn serialize_merged_dict(
bytes_columns: &[Option<BytesColumn>],
merge_row_order: &MergeRowOrder,
output: &mut impl Write,
) -> io::Result<TermOrdinalMapping> {
let mut sstable_builder = sstable::VoidSSTable::writer(output);
let term_ord_mapping = match merge_row_order {
MergeRowOrder::Stack(_) => merge_dict_and_compute_term_ord_mapping(
bytes_columns,
|_| true,
|term_bytes| sstable_builder.insert(term_bytes, &()),
)?,
MergeRowOrder::Shuffled(shuffle_merge_order) => {
assert_eq!(shuffle_merge_order.alive_bitsets.len(), bytes_columns.len());
let mut term_bitsets: Vec<Option<BitSet>> = Vec::with_capacity(bytes_columns.len());
@@ -174,21 +205,14 @@ fn serialize_merged_dict(
}
}
}
let mut current_term_ord = 0;
while merged_terms.advance() {
let term_bytes: &[u8] = merged_terms.key();
if !is_term_present(&term_bitsets[..], &merged_terms) {
continue;
}
sstable_builder.insert(term_bytes, &())?;
for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, current_term_ord);
}
current_term_ord += 1;
}
sstable_builder.finish()?;
merge_dict_and_compute_term_ord_mapping(
bytes_columns,
|merged_terms| is_term_present(&term_bitsets[..], merged_terms),
|term_bytes| sstable_builder.insert(term_bytes, &()),
)?
}
}
};
sstable_builder.finish()?;
Ok(term_ord_mapping)
}
@@ -211,4 +235,8 @@ impl TermOrdinalMapping {
fn get_segment(&self, segment_ord: u32) -> &[TermOrdinal] {
&self.per_segment_new_term_ordinals[segment_ord as usize]
}
fn into_per_segment_new_term_ordinals(self) -> Vec<Vec<TermOrdinal>> {
self.per_segment_new_term_ordinals
}
}

View File

@@ -7,6 +7,7 @@ use std::io;
use std::net::Ipv6Addr;
use std::sync::Arc;
pub use merge_dict_column::compute_merged_term_ord_mapping;
pub use merge_mapping::{MergeRowOrder, ShuffleMergeOrder, StackMergeOrder};
use super::writer::ColumnarSerializer;

View File

@@ -8,6 +8,9 @@ pub use column_type::{ColumnType, HasAssociatedColumnType};
pub use format_version::{CURRENT_VERSION, Version};
#[cfg(test)]
pub(crate) use merge::ColumnTypeCategory;
pub use merge::{MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, merge_columnar};
pub use merge::{
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, compute_merged_term_ord_mapping,
merge_columnar,
};
pub use reader::ColumnarReader;
pub use writer::ColumnarWriter;

View File

@@ -93,44 +93,39 @@ impl ColumnarWriter {
.get::<NumericalColumnWriter>(sort_field.as_bytes())
})
else {
return Vec::new();
let str_or_bytes_column_opt = self
.str_field_hash_map
.get::<StrOrBytesColumnWriter>(sort_field.as_bytes())
.or_else(|| {
self.bytes_field_hash_map
.get::<StrOrBytesColumnWriter>(sort_field.as_bytes())
});
let Some(str_or_bytes_column) = str_or_bytes_column_opt else {
return Vec::new();
};
let dictionary_builder = &self.dictionaries[str_or_bytes_column.dictionary_id as usize];
let term_id_mapping = dictionary_builder.build_term_id_mapping(&self.arena);
let mut symbols_buffer = Vec::new();
return collect_sort_order_from_ops(
str_or_bytes_column.operation_iterator(&self.arena, None, &mut symbols_buffer),
num_docs,
reversed,
|uid| Some(term_id_mapping.to_ord(uid).0),
None,
|a, b| a.cmp(b),
);
};
let mut symbols_buffer = Vec::new();
let mut values = Vec::new();
let mut start_doc_check_fill = 0;
let mut current_doc_opt: Option<RowId> = None;
// Assumption: NewDoc will never call the same doc twice and is strictly increasing between
// calls
for op in numerical_col_writer.operation_iterator(&self.arena, None, &mut symbols_buffer) {
match op {
ColumnOperation::NewDoc(doc) => {
current_doc_opt = Some(doc);
}
ColumnOperation::Value(numerical_value) => {
if let Some(current_doc) = current_doc_opt {
// Fill up with 0.0 since last doc
values.extend((start_doc_check_fill..current_doc).map(|doc| (0.0, doc)));
start_doc_check_fill = current_doc + 1;
// handle multi values
current_doc_opt = None;
let score: f32 = f64::coerce(numerical_value) as f32;
values.push((score, current_doc));
}
}
}
}
for doc in values.len() as u32..num_docs {
values.push((0.0f32, doc));
}
values.sort_by(|(left_score, _), (right_score, _)| {
if reversed {
right_score.total_cmp(left_score)
} else {
left_score.total_cmp(right_score)
}
});
values.into_iter().map(|(_score, doc)| doc).collect()
collect_sort_order_from_ops(
numerical_col_writer.operation_iterator(&self.arena, None, &mut symbols_buffer),
num_docs,
reversed,
|nv| f64::coerce(nv) as f32,
0.0f32,
|a, b| a.total_cmp(b),
)
}
/// Records a column type. This is useful to bypass the coercion process,
@@ -470,6 +465,56 @@ impl ColumnarWriter {
}
}
/// Shared sorting pattern for both numeric and Str/Bytes sort fields.
///
/// Iterates column operations, fills gaps for missing docs with `default_key`, converts each value
/// to a sort key via `value_to_key`, then sorts by the key using `cmp_keys`. Returns the doc ids
/// in sorted order.
fn collect_sort_order_from_ops<V, K: Clone>(
ops: impl Iterator<Item = ColumnOperation<V>>,
num_docs: RowId,
reversed: bool,
value_to_key: impl Fn(V) -> K,
default_key: K,
cmp_keys: impl Fn(&K, &K) -> std::cmp::Ordering,
) -> Vec<u32> {
let mut doc_sort_keys: Vec<(K, RowId)> = Vec::with_capacity(num_docs as usize);
let mut start_doc_check_fill: RowId = 0;
let mut current_doc_opt: Option<RowId> = None;
for op in ops {
match op {
ColumnOperation::NewDoc(doc) => {
current_doc_opt = Some(doc);
}
ColumnOperation::Value(val) => {
if let Some(current_doc) = current_doc_opt {
// Fill gaps since the last doc with the default key.
doc_sort_keys.extend(
(start_doc_check_fill..current_doc).map(|doc| (default_key.clone(), doc)),
);
start_doc_check_fill = current_doc + 1;
// For multivalued fields, only the first value is used.
current_doc_opt = None;
doc_sort_keys.push((value_to_key(val), current_doc));
}
}
}
}
// Fill remaining docs at the tail.
doc_sort_keys.extend((start_doc_check_fill..num_docs).map(|doc| (default_key.clone(), doc)));
doc_sort_keys.sort_by(|(left_key, _), (right_key, _)| {
let cmp = cmp_keys(left_key, right_key);
if reversed { cmp.reverse() } else { cmp }
});
doc_sort_keys
.into_iter()
.map(|(_sort_key, doc)| doc)
.collect()
}
// Serialize [Dictionary, Column, dictionary num bytes U32::LE]
// Column: [Column Index, Column Values, column index num bytes U32::LE]
#[expect(clippy::too_many_arguments)]

View File

@@ -51,6 +51,16 @@ impl DictionaryBuilder {
UnorderedId(unordered_id)
}
fn build_sorted_terms<'a>(&'a self, arena: &'a MemoryArena) -> Vec<(&'a [u8], UnorderedId)> {
let mut terms: Vec<(&[u8], UnorderedId)> = self
.dict
.iter(arena)
.map(|(k, v)| (k, arena.read(v)))
.collect();
terms.sort_unstable_by_key(|(key, _)| *key);
terms
}
/// Serialize the dictionary into an fst, and returns the
/// `UnorderedId -> TermOrdinal` map.
pub fn serialize<'a, W: io::Write + 'a>(
@@ -58,12 +68,7 @@ impl DictionaryBuilder {
arena: &MemoryArena,
wrt: &mut W,
) -> io::Result<TermIdMapping> {
let mut terms: Vec<(&[u8], UnorderedId)> = self
.dict
.iter(arena)
.map(|(k, v)| (k, arena.read(v)))
.collect();
terms.sort_unstable_by_key(|(key, _)| *key);
let terms = self.build_sorted_terms(arena);
// TODO Remove the allocation.
let mut unordered_to_ord: Vec<OrderedId> = vec![OrderedId(0u32); terms.len()];
let mut sstable_builder = sstable::VoidSSTable::writer(wrt);
@@ -76,6 +81,16 @@ impl DictionaryBuilder {
Ok(TermIdMapping { unordered_to_ord })
}
/// Build the `UnorderedId -> OrderedId` mapping in memory without serializing.
pub fn build_term_id_mapping(&self, arena: &MemoryArena) -> TermIdMapping {
let terms = self.build_sorted_terms(arena);
let mut unordered_to_ord: Vec<OrderedId> = vec![OrderedId(0u32); terms.len()];
for (ord, (_key, unordered_id)) in terms.into_iter().enumerate() {
unordered_to_ord[unordered_id.0 as usize] = OrderedId(ord as u32);
}
TermIdMapping { unordered_to_ord }
}
pub(crate) fn mem_usage(&self) -> usize {
self.dict.mem_usage()
}

View File

@@ -43,7 +43,8 @@ pub use column_values::{
};
pub use columnar::{
CURRENT_VERSION, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, merge_columnar,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder, Version, compute_merged_term_ord_mapping,
merge_columnar,
};
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};

View File

@@ -272,6 +272,51 @@ fn test_dictionary_encoded_bytes() {
assert_eq!(term_buffer, b"b");
}
#[test]
fn test_sort_order_str_asc_desc() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_str(0, "s", "z");
dataframe_writer.record_str(2, "s", "a");
dataframe_writer.record_str(3, "s", "m");
let asc = dataframe_writer.sort_order("s", 4, false);
assert_eq!(asc, vec![1, 2, 3, 0]); // None, a, m, z
let desc = dataframe_writer.sort_order("s", 4, true);
assert_eq!(desc, vec![0, 3, 2, 1]); // z, m, a, None
}
#[test]
fn test_sort_order_str_empty_vs_missing() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_str(0, "s", "");
let asc = dataframe_writer.sort_order("s", 2, false);
assert_eq!(asc, vec![1, 0]); // None first, then empty string
}
#[test]
fn test_sort_order_str_multivalued_stable() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_str(0, "s", "z");
dataframe_writer.record_str(0, "s", "a"); // multivalued; first value wins
dataframe_writer.record_str(1, "s", "b");
dataframe_writer.record_str(2, "s", "b");
let asc = dataframe_writer.sort_order("s", 3, false);
assert_eq!(asc, vec![1, 2, 0]); // b, b (stable), z
}
#[test]
fn test_sort_order_bytes_asc() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_bytes(1, "b", &[0x01]);
dataframe_writer.record_bytes(3, "b", &[0x00]);
let asc = dataframe_writer.sort_order("b", 4, false);
assert_eq!(asc, vec![0, 2, 3, 1]); // None, None, 0x00, 0x01
}
fn num_strategy() -> impl Strategy<Value = NumericalValue> {
prop_oneof![
3 => Just(NumericalValue::U64(0u64)),

View File

@@ -248,7 +248,14 @@ impl IndexBuilder {
sort_by_field.field
)));
}
let supported_field_types = [Type::I64, Type::U64, Type::F64, Type::Date];
let supported_field_types = [
Type::I64,
Type::U64,
Type::F64,
Type::Date,
Type::Str,
Type::Bytes,
];
let field_type = entry.field_type().value_type();
if !supported_field_types.contains(&field_type) {
return Err(TantivyError::InvalidArgument(format!(

View File

@@ -562,10 +562,10 @@ mod tests_indexsorting {
#[test]
fn test_text_sort() -> crate::Result<()> {
let mut schema_builder = SchemaBuilder::new();
schema_builder.add_text_field("id", STRING | FAST | STORED);
let id_field = schema_builder.add_text_field("id", STRING | FAST | STORED);
schema_builder.add_text_field("name", TEXT | STORED);
let resp = IndexBuilder::new()
let index = IndexBuilder::new()
.schema(schema_builder.build())
.settings(IndexSettings {
sort_by_field: Some(IndexSortByField {
@@ -574,11 +574,26 @@ mod tests_indexsorting {
}),
..Default::default()
})
.create_in_ram();
assert!(resp
.unwrap_err()
.to_string()
.contains("Unsupported field type"));
.create_in_ram()?;
let mut index_writer = index.writer_for_tests()?;
index_writer.add_document(doc!(id_field => "z"))?;
index_writer.add_document(doc!(id_field => "a"))?;
index_writer.add_document(doc!(id_field => "m"))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
let str_col = segment_reader.fast_fields().str("id")?.unwrap();
let mut values = Vec::new();
for doc in 0..segment_reader.max_doc() {
if let Some(ord) = str_col.ords().first(doc) {
let mut s = String::new();
str_col.ord_to_str(ord, &mut s)?;
values.push(s);
}
}
assert_eq!(values, vec!["a", "m", "z"]);
Ok(())
}

View File

@@ -1,8 +1,8 @@
use std::sync::Arc;
use columnar::{
ColumnType, ColumnValues, ColumnarReader, MergeRowOrder, RowAddr, ShuffleMergeOrder,
StackMergeOrder,
compute_merged_term_ord_mapping, BytesColumn, Column, ColumnType, ColumnValues, ColumnarReader,
MergeRowOrder, RowAddr, ShuffleMergeOrder, StackMergeOrder,
};
use common::ReadOnlyBitSet;
use itertools::Itertools;
@@ -17,13 +17,47 @@ use crate::index::{Segment, SegmentComponent, SegmentReader};
use crate::indexer::doc_id_mapping::{MappingType, SegmentDocIdMapping};
use crate::indexer::SegmentSerializer;
use crate::postings::{InvertedIndexSerializer, Postings, SegmentPostings};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema, Type};
use crate::store::StoreWriter;
use crate::termdict::{TermMerger, TermOrdinal};
use crate::{
DocAddress, DocId, IndexSettings, IndexSortByField, InvertedIndexReader, Order, SegmentOrdinal,
};
/// Per-segment accessor for Str/Bytes sort fields during index merging.
///
/// Each segment stores its own term dictionary with segment-local ordinals. To compare terms
/// across segments we compute a merged global dictionary and map each segment's local ordinals
/// to the corresponding merged ordinal via `merged_term_ord_mapping`. This avoids materializing
/// the actual term bytes during the merge sort — ordinal comparison is sufficient because the
/// merged dictionary preserves lexicographic order.
struct StrBytesSortFieldAccessor {
ords: Column<u64>,
merged_term_ord_mapping: Vec<TermOrdinal>,
}
impl StrBytesSortFieldAccessor {
fn remapped_term_ord(&self, doc_id: DocId) -> Option<TermOrdinal> {
self.ords.first(doc_id).map(|old_ord| {
let old_ord = old_ord as usize;
debug_assert!(old_ord < self.merged_term_ord_mapping.len());
self.merged_term_ord_mapping[old_ord]
})
}
}
/// Owned per-segment sort-field accessors, kept alive for the duration of the merge.
///
/// - `Numeric`: direct column value access — all numeric/datetime types share a single u64 column
/// interface, so segments can be compared directly by value.
/// - `StrBytes`: ordinal-based access — each segment's local term ordinals are remapped to merged
/// global ordinals so that cross-segment lexicographic comparison works without loading term
/// bytes.
enum ReaderSortFieldAccessors {
Numeric(Vec<(SegmentOrdinal, Arc<dyn ColumnValues>)>),
StrBytes(Vec<(SegmentOrdinal, StrBytesSortFieldAccessor)>),
}
/// Segment's max doc must be `< MAX_DOC_LIMIT`.
///
/// We do not allow segments with more than
@@ -187,7 +221,10 @@ impl IndexMerger {
let max_doc = readers.iter().map(|reader| reader.num_docs()).sum();
if let Some(sort_by_field) = index_settings.sort_by_field.as_ref() {
readers = Self::sort_readers_by_min_sort_field(readers, sort_by_field)?;
let schema_field = schema.get_field(&sort_by_field.field)?;
let field_entry = schema.get_field_entry(schema_field);
let field_type = field_entry.field_type().value_type();
readers = Self::sort_readers_by_min_sort_field(readers, sort_by_field, field_type)?;
}
// sort segments by their natural sort setting
if max_doc >= MAX_DOC_LIMIT {
@@ -205,16 +242,29 @@ impl IndexMerger {
})
}
fn sort_by_field_type(&self, sort_by_field: &IndexSortByField) -> crate::Result<Type> {
let schema_field = self.schema.get_field(&sort_by_field.field)?;
let field_entry = self.schema.get_field_entry(schema_field);
Ok(field_entry.field_type().value_type())
}
fn sort_readers_by_min_sort_field(
readers: Vec<SegmentReader>,
sort_by_field: &IndexSortByField,
field_type: Type,
) -> crate::Result<Vec<SegmentReader>> {
if matches!(field_type, Type::Str | Type::Bytes) {
// Ordinals are per-segment and not directly comparable, so the "disjunct min/max"
// shortcut that works for numeric fields does not apply here.
return Ok(readers);
}
// presort the readers by their min_values, so that when they are disjunct, we can use
// the regular merge logic (implicitly sorted)
let mut readers_with_min_sort_values = readers
.into_iter()
.map(|reader| {
let accessor = Self::get_sort_field_accessor(&reader, sort_by_field)?;
let accessor = Self::get_numeric_accessor(&reader, sort_by_field)?;
Ok((reader, accessor.min_value()))
})
.collect::<crate::Result<Vec<_>>>()?;
@@ -282,12 +332,17 @@ impl IndexMerger {
&self,
sort_by_field: &IndexSortByField,
) -> crate::Result<bool> {
let reader_ordinal_and_field_accessors =
self.get_reader_with_sort_field_accessor(sort_by_field)?;
let field_type = self.sort_by_field_type(sort_by_field)?;
// Disjunct shortcut is invalid for Str/Bytes because ords are per-segment.
if matches!(field_type, Type::Str | Type::Bytes) {
return Ok(false);
}
let reader_ordinal_and_field_accessors = self.get_numeric_accessors(sort_by_field)?;
let everything_is_in_order = reader_ordinal_and_field_accessors
.into_iter()
.map(|(_, col)| Arc::new(col))
.map(|(_, col)| col)
.tuple_windows()
.all(|(field_accessor1, field_accessor2)| {
if sort_by_field.order.is_asc() {
@@ -299,7 +354,69 @@ impl IndexMerger {
Ok(everything_is_in_order)
}
pub(crate) fn get_sort_field_accessor(
fn get_str_bytes_column(
reader: &SegmentReader,
sort_by_field: &IndexSortByField,
field_type: Type,
) -> crate::Result<BytesColumn> {
let not_available = || -> crate::TantivyError {
FastFieldNotAvailableError {
field_name: sort_by_field.field.to_string(),
}
.into()
};
match field_type {
Type::Str => reader
.fast_fields()
.str(&sort_by_field.field)?
.map(Into::into)
.ok_or_else(not_available),
Type::Bytes => reader
.fast_fields()
.bytes(&sort_by_field.field)?
.ok_or_else(not_available),
_ => unreachable!("get_str_bytes_column called with non-Str/Bytes type"),
}
}
/// Builds per-segment [`StrBytesSortFieldAccessor`]s for Str/Bytes sort fields.
///
/// 1. Extracts each segment's `BytesColumn` (term dictionary + ordinal column).
/// 2. Computes a merged dictionary across all segments via [`compute_merged_term_ord_mapping`],
/// producing a per-segment mapping from local term ordinal → merged global ordinal.
/// 3. Wraps each segment's ordinal column and mapping into a `StrBytesSortFieldAccessor`.
fn get_str_bytes_accessors(
&self,
sort_by_field: &IndexSortByField,
field_type: Type,
) -> crate::Result<Vec<(SegmentOrdinal, StrBytesSortFieldAccessor)>> {
let bytes_columns = self
.readers
.iter()
.map(|reader| Self::get_str_bytes_column(reader, sort_by_field, field_type))
.collect::<crate::Result<Vec<_>>>()?;
let merged_term_ord_mappings = compute_merged_term_ord_mapping(&bytes_columns)?;
debug_assert_eq!(bytes_columns.len(), merged_term_ord_mappings.len());
let accessors = bytes_columns
.into_iter()
.zip(merged_term_ord_mappings)
.enumerate()
.map(
|(reader_ordinal, (bytes_column, merged_term_ord_mapping))| {
(
reader_ordinal as SegmentOrdinal,
StrBytesSortFieldAccessor {
ords: bytes_column.ords().clone(),
merged_term_ord_mapping,
},
)
},
)
.collect::<Vec<_>>();
Ok(accessors)
}
fn get_numeric_accessor(
reader: &SegmentReader,
sort_by_field: &IndexSortByField,
) -> crate::Result<Arc<dyn ColumnValues>> {
@@ -312,25 +429,67 @@ impl IndexMerger {
})?;
Ok(value_accessor.first_or_default_col(0u64))
}
/// Collecting value_accessors into a vec to bind the lifetime.
pub(crate) fn get_reader_with_sort_field_accessor(
fn get_numeric_accessors(
&self,
sort_by_field: &IndexSortByField,
) -> crate::Result<Vec<(SegmentOrdinal, Arc<dyn ColumnValues>)>> {
let reader_ordinal_and_field_accessors = self
.readers
self.readers
.iter()
.enumerate()
.map(|(reader_ordinal, _)| reader_ordinal as SegmentOrdinal)
.map(|reader_ordinal: SegmentOrdinal| {
let value_accessor = Self::get_sort_field_accessor(
&self.readers[reader_ordinal as usize],
sort_by_field,
)?;
Ok((reader_ordinal, value_accessor))
.map(|(reader_ordinal, reader)| {
let reader_ordinal = reader_ordinal as SegmentOrdinal;
let accessor = Self::get_numeric_accessor(reader, sort_by_field)?;
Ok((reader_ordinal, accessor))
})
.collect::<crate::Result<Vec<_>>>()?;
Ok(reader_ordinal_and_field_accessors)
.collect::<crate::Result<Vec<_>>>()
}
/// Builds owned per-segment sort accessors so they stay alive during merge.
///
/// Dispatches on the sort field's value type: numeric types use direct column value access,
/// while Str/Bytes types go through the ordinal-remapping path (see
/// [`StrBytesSortFieldAccessor`]).
fn get_reader_with_sort_field_accessor(
&self,
sort_by_field: &IndexSortByField,
) -> crate::Result<ReaderSortFieldAccessors> {
let field_type = self.sort_by_field_type(sort_by_field)?;
if matches!(field_type, Type::Str | Type::Bytes) {
let accessors = self.get_str_bytes_accessors(sort_by_field, field_type)?;
return Ok(ReaderSortFieldAccessors::StrBytes(accessors));
}
let accessors = self.get_numeric_accessors(sort_by_field)?;
Ok(ReaderSortFieldAccessors::Numeric(accessors))
}
fn extend_sorted_doc_ids<T, F>(
&self,
reader_ordinal_and_field_accessors: &[(SegmentOrdinal, T)],
mut is_less: F,
sorted_doc_ids: &mut Vec<DocAddress>,
) where
F: FnMut(&(DocId, &SegmentOrdinal, &T), &(DocId, &SegmentOrdinal, &T)) -> bool,
{
let doc_id_reader_pair =
reader_ordinal_and_field_accessors
.iter()
.map(|(reader_ord, ff_reader)| {
let reader = &self.readers[*reader_ord as usize];
reader
.doc_ids_alive()
.map(move |doc_id| (doc_id, reader_ord, ff_reader))
});
sorted_doc_ids.extend(
doc_id_reader_pair
.into_iter()
.kmerge_by(|a, b| is_less(a, b))
.map(|(doc_id, &segment_ord, _)| DocAddress {
doc_id,
segment_ord,
}),
);
}
/// Generates the doc_id mapping where position in the vec=new
@@ -341,21 +500,9 @@ impl IndexMerger {
&self,
sort_by_field: &IndexSortByField,
) -> crate::Result<SegmentDocIdMapping> {
let reader_ordinal_and_field_accessors =
self.get_reader_with_sort_field_accessor(sort_by_field)?;
let sort_field_accessors = self.get_reader_with_sort_field_accessor(sort_by_field)?;
// Loading the field accessor on demand causes a 15x regression
// create iterators over segment/sort_accessor/doc_id tuple
let doc_id_reader_pair =
reader_ordinal_and_field_accessors
.iter()
.map(|(reader_ord, ff_reader)| {
let reader = &self.readers[*reader_ord as usize];
reader
.doc_ids_alive()
.map(move |doc_id| (doc_id, reader_ord, ff_reader))
});
let total_num_new_docs = self
.readers
.iter()
@@ -364,24 +511,51 @@ impl IndexMerger {
let mut sorted_doc_ids: Vec<DocAddress> = Vec::with_capacity(total_num_new_docs);
// create iterator tuple of (old doc_id, reader) in order of the new doc_ids
sorted_doc_ids.extend(
doc_id_reader_pair
.into_iter()
.kmerge_by(|a, b| {
let val1 = a.2.get_val(a.0);
let val2 = b.2.get_val(b.0);
if sort_by_field.order == Order::Asc {
val1 < val2
} else {
val1 > val2
}
})
.map(|(doc_id, &segment_ord, _)| DocAddress {
doc_id,
segment_ord,
}),
);
// K-way merge of alive doc ids across segments, ordered by the sort field.
//
// Numeric: compare raw u64 column values directly.
// Str/Bytes: compare merged global ordinals obtained via `remapped_term_ord`.
// Documents without a value map to `None` — first in ascending, last in descending.
let asc = sort_by_field.order == Order::Asc;
match sort_field_accessors {
ReaderSortFieldAccessors::Numeric(reader_ordinal_and_field_accessors) => {
self.extend_sorted_doc_ids(
&reader_ordinal_and_field_accessors,
|a, b| {
let val1 = a.2.get_val(a.0);
let val2 = b.2.get_val(b.0);
if asc {
val1 < val2
} else {
val1 > val2
}
},
&mut sorted_doc_ids,
);
}
ReaderSortFieldAccessors::StrBytes(reader_ordinal_and_field_accessors) => {
self.extend_sorted_doc_ids(
&reader_ordinal_and_field_accessors,
|a, b| {
let val1 = a.2.remapped_term_ord(a.0);
let val2 = b.2.remapped_term_ord(b.0);
match (val1, val2) {
(None, None) => false,
(None, Some(_)) => asc,
(Some(_), None) => !asc,
(Some(left), Some(right)) => {
if asc {
left < right
} else {
left > right
}
}
}
},
&mut sorted_doc_ids,
);
}
}
let alive_bitsets: Vec<Option<ReadOnlyBitSet>> = self
.readers

View File

@@ -1,5 +1,9 @@
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use proptest::prelude::*;
use crate::collector::TopDocs;
use crate::fastfield::AliveBitSet;
use crate::index::Index;
@@ -7,7 +11,7 @@ mod tests {
use crate::query::QueryParser;
use crate::schema::{
self, BytesOptions, Facet, FacetOptions, IndexRecordOption, NumericOptions,
TextFieldIndexing, TextOptions, Value,
TextFieldIndexing, TextOptions, Value, FAST, STRING,
};
use crate::{
DocAddress, DocSet, IndexSettings, IndexSortByField, IndexWriter, Order, TantivyDocument,
@@ -359,6 +363,454 @@ mod tests {
}
}
// ---- Str/Bytes sort_by helpers ----
fn build_str_sorted_index(order: Order, segments: Vec<Vec<Option<&str>>>) -> Index {
let segments = segments
.into_iter()
.map(|segment| {
segment
.into_iter()
.map(|value| value.map(str::to_owned))
.collect()
})
.collect();
build_str_sorted_index_owned(order, segments)
}
fn build_str_sorted_index_owned(order: Order, segments: Vec<Vec<Option<String>>>) -> Index {
let mut schema_builder = schema::Schema::builder();
let str_field = schema_builder.add_text_field("str", STRING | FAST);
let schema = schema_builder.build();
let index_builder = Index::builder().schema(schema).settings(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "str".to_string(),
order,
}),
..Default::default()
});
let index = index_builder.create_in_ram().unwrap();
{
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
for segment in segments {
for value in segment {
let mut doc = TantivyDocument::new();
if let Some(val) = value {
doc.add_text(str_field, val);
}
index_writer.add_document(doc).unwrap();
}
index_writer.commit().unwrap();
}
}
{
let segment_ids = index.searchable_segment_ids().unwrap();
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
index_writer.merge(&segment_ids).wait().unwrap();
index_writer.wait_merging_threads().unwrap();
}
index
}
fn build_bytes_sorted_index(order: Order, segments: Vec<Vec<Option<&[u8]>>>) -> Index {
let segments = segments
.into_iter()
.map(|segment| {
segment
.into_iter()
.map(|value| value.map(<[u8]>::to_vec))
.collect()
})
.collect();
build_bytes_sorted_index_owned(order, segments)
}
fn build_bytes_sorted_index_owned(order: Order, segments: Vec<Vec<Option<Vec<u8>>>>) -> Index {
let mut schema_builder = schema::Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("bytes", BytesOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index_builder = Index::builder().schema(schema).settings(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "bytes".to_string(),
order,
}),
..Default::default()
});
let index = index_builder.create_in_ram().unwrap();
{
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
for segment in segments {
for value in segment {
let mut doc = TantivyDocument::new();
if let Some(val) = value {
doc.add_bytes(bytes_field, &val);
}
index_writer.add_document(doc).unwrap();
}
index_writer.commit().unwrap();
}
}
{
let segment_ids = index.searchable_segment_ids().unwrap();
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
index_writer.merge(&segment_ids).wait().unwrap();
index_writer.wait_merging_threads().unwrap();
}
index
}
fn collect_str_values(index: &Index) -> Vec<Option<String>> {
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_readers().last().unwrap();
let str_col = segment_reader.fast_fields().str("str").unwrap().unwrap();
let mut values = Vec::new();
for doc in 0..segment_reader.max_doc() {
if let Some(ord) = str_col.ords().first(doc) {
let mut s = String::new();
str_col.ord_to_str(ord, &mut s).unwrap();
values.push(Some(s));
} else {
values.push(None);
}
}
values
}
fn collect_bytes_values(index: &Index) -> Vec<Option<Vec<u8>>> {
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_readers().last().unwrap();
let bytes_col = segment_reader
.fast_fields()
.bytes("bytes")
.unwrap()
.unwrap();
let mut values = Vec::new();
for doc in 0..segment_reader.max_doc() {
if let Some(ord) = bytes_col.ords().first(doc) {
let mut buf = Vec::new();
bytes_col.ord_to_bytes(ord, &mut buf).unwrap();
values.push(Some(buf));
} else {
values.push(None);
}
}
values
}
fn compare_option_values<T: Ord>(
left: &Option<T>,
right: &Option<T>,
order: Order,
) -> Ordering {
match (left, right) {
(None, None) => Ordering::Equal,
(None, Some(_)) => {
if order.is_asc() {
Ordering::Less
} else {
Ordering::Greater
}
}
(Some(_), None) => {
if order.is_asc() {
Ordering::Greater
} else {
Ordering::Less
}
}
(Some(left), Some(right)) => {
if order.is_asc() {
left.cmp(right)
} else {
right.cmp(left)
}
}
}
}
// ---- Single-segment sort ----
#[test]
fn test_single_segment_str_sorted() {
// Insert out of order into a single segment.
// Read back and verify the segment itself is sorted — no merge involved.
let mut schema_builder = schema::Schema::builder();
let str_field = schema_builder.add_text_field("str", STRING | FAST);
let schema = schema_builder.build();
let index = Index::builder()
.schema(schema)
.settings(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "str".to_string(),
order: Order::Asc,
}),
..Default::default()
})
.create_in_ram()
.unwrap();
{
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
index_writer.add_document(doc!(str_field => "z")).unwrap();
index_writer.add_document(doc!(str_field => "a")).unwrap();
index_writer.add_document(doc!(str_field => "m")).unwrap();
index_writer.commit().unwrap();
}
// No merge — read the single segment directly.
let values = collect_str_values(&index);
assert_eq!(
values,
vec![
Some("a".to_string()),
Some("m".to_string()),
Some("z".to_string())
]
);
}
// ---- Cross-segment merge: Str ----
#[test]
fn test_merge_sorted_index_str_asc() {
// Segment A: ["z", "a"] (out of order — proves per-segment sorting).
// Segment B: ["m", "b"] (also out of order).
// If per-segment sorting failed, kmerge would see unsorted streams.
// If the disjunct shortcut fired, it would stack segments without re-sorting.
// Correct merged order is ["a","b","m","z"].
let index = build_str_sorted_index(
Order::Asc,
vec![vec![Some("z"), Some("a")], vec![Some("m"), Some("b")]],
);
let values = collect_str_values(&index);
assert_eq!(
values,
vec![
Some("a".to_string()),
Some("b".to_string()),
Some("m".to_string()),
Some("z".to_string())
]
);
}
#[test]
fn test_merge_sorted_index_str_desc() {
let index = build_str_sorted_index(
Order::Desc,
vec![vec![Some("z"), None], vec![Some("m"), Some("a")]],
);
let values = collect_str_values(&index);
assert_eq!(
values,
vec![
Some("z".to_string()),
Some("m".to_string()),
Some("a".to_string()),
None
]
);
}
#[test]
fn test_merge_sorted_index_str_missing_values() {
// Second segment has no values for the sort field.
let index = build_str_sorted_index(
Order::Asc,
vec![vec![Some("b"), Some("c")], vec![None, None]],
);
let values = collect_str_values(&index);
assert_eq!(
values,
vec![None, None, Some("b".to_string()), Some("c".to_string())]
);
}
#[test]
fn test_merge_sorted_index_str_with_deletes() {
let mut schema_builder = schema::Schema::builder();
let str_field = schema_builder.add_text_field("str", STRING | FAST);
let schema = schema_builder.build();
let index_builder = Index::builder().schema(schema).settings(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "str".to_string(),
order: Order::Asc,
}),
..Default::default()
});
let index = index_builder.create_in_ram().unwrap();
{
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
// Segment 1 (with a delete)
index_writer.add_document(doc!(str_field => "z")).unwrap();
index_writer
.add_document(doc!(str_field => "deleteme"))
.unwrap();
index_writer.delete_term(Term::from_field_text(str_field, "deleteme"));
index_writer.commit().unwrap();
index_writer.add_document(doc!(str_field => "a")).unwrap();
index_writer.add_document(doc!(str_field => "m")).unwrap();
index_writer.commit().unwrap();
}
{
let segment_ids = index.searchable_segment_ids().unwrap();
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
index_writer.merge(&segment_ids).wait().unwrap();
index_writer.wait_merging_threads().unwrap();
}
let values = collect_str_values(&index);
assert_eq!(
values,
vec![
Some("a".to_string()),
Some("m".to_string()),
Some("z".to_string())
]
);
}
// ---- Cross-segment merge: Bytes ----
#[test]
fn test_merge_sorted_index_bytes_asc() {
let index = build_bytes_sorted_index(
Order::Asc,
vec![
vec![Some(&[0x02][..]), Some(&[0x01][..])],
vec![Some(&[0x00][..])],
],
);
let values = collect_bytes_values(&index);
assert_eq!(
values,
vec![Some(vec![0x00]), Some(vec![0x01]), Some(vec![0x02])]
);
}
#[test]
fn test_merge_sorted_index_bytes_desc() {
let index = build_bytes_sorted_index(
Order::Desc,
vec![
vec![Some(&[0x02][..]), None],
vec![Some(&[0x01][..]), Some(&[0x00][..])],
],
);
let values = collect_bytes_values(&index);
assert_eq!(
values,
vec![Some(vec![0x02]), Some(vec![0x01]), Some(vec![0x00]), None]
);
}
#[test]
fn test_merge_sorted_index_bytes_missing_values() {
// Second segment has no values for the sort field.
let index = build_bytes_sorted_index(
Order::Asc,
vec![vec![Some(&[0x01][..]), Some(&[0x02][..])], vec![None, None]],
);
let values = collect_bytes_values(&index);
assert_eq!(values, vec![None, None, Some(vec![0x01]), Some(vec![0x02])]);
}
#[test]
fn test_merge_sorted_index_bytes_with_deletes() {
let mut schema_builder = schema::Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("bytes", BytesOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index_builder = Index::builder().schema(schema).settings(IndexSettings {
sort_by_field: Some(IndexSortByField {
field: "bytes".to_string(),
order: Order::Asc,
}),
..Default::default()
});
let index = index_builder.create_in_ram().unwrap();
{
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
// Segment 1 (with a delete)
index_writer
.add_document(doc!(bytes_field => vec![0x02]))
.unwrap();
index_writer
.add_document(doc!(bytes_field => vec![0x01]))
.unwrap();
index_writer.delete_term(Term::from_field_bytes(bytes_field, &[0x01]));
index_writer.commit().unwrap();
index_writer
.add_document(doc!(bytes_field => vec![0x00]))
.unwrap();
index_writer.commit().unwrap();
}
{
let segment_ids = index.searchable_segment_ids().unwrap();
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
index_writer.merge(&segment_ids).wait().unwrap();
index_writer.wait_merging_threads().unwrap();
}
let values = collect_bytes_values(&index);
assert_eq!(values, vec![Some(vec![0x00]), Some(vec![0x02])]);
}
proptest! {
#[test]
fn test_merge_sorted_index_str_matches_sorted_input(
order in prop_oneof![Just(Order::Asc), Just(Order::Desc)],
segments in proptest::collection::vec(
proptest::collection::vec(proptest::option::of("[a-z]{0,8}"), 1..8),
1..6,
)
) {
let index = build_str_sorted_index_owned(order, segments.clone());
let values = collect_str_values(&index);
let mut expected: Vec<Option<String>> = segments.into_iter().flatten().collect();
expected.sort_by(|left, right| compare_option_values(left, right, order));
prop_assert_eq!(values, expected);
}
#[test]
fn test_merge_sorted_index_bytes_matches_sorted_input(
order in prop_oneof![Just(Order::Asc), Just(Order::Desc)],
segments in proptest::collection::vec(
proptest::collection::vec(
proptest::option::of(proptest::collection::vec(any::<u8>(), 0..8)),
1..8,
),
1..6,
)
) {
let index = build_bytes_sorted_index_owned(order, segments.clone());
let values = collect_bytes_values(&index);
let mut expected: Vec<Option<Vec<u8>>> = segments.into_iter().flatten().collect();
expected.sort_by(|left, right| compare_option_values(left, right, order));
prop_assert_eq!(values, expected);
}
}
// #[test]
// fn test_merge_sorted_index_asc() {
// let index = create_test_index(