Compare commits

...

16 Commits

Author SHA1 Message Date
Paul Masurel
6dd0be28cf Codex CR comments 2026-05-27 23:57:05 +02:00
Paul Masurel
465a761b2f Adding implem for filter-vec for neon as suggested by Adam. 2026-05-27 23:22:27 +02:00
Paul Masurel
46b3fb9ed3 Relying on upstream version of datasketch and stop using HLL 4. (#2936)
We were relying on a fork for:

a bugfix in LIST serialization
a better API exposing a new Coupon type, required for caching coupons.
We also stop using HLL8 in hope to fix
https://datadoghq.atlassian.net/browse/CLOUDPREM-625

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2026-05-19 13:29:35 +02:00
trinity-1686a
fbe620b9b4 Merge pull request #2933 from quickwit-oss/1686a/sstable-opt
optimise sstable index access pattern
2026-05-19 11:43:17 +02:00
trinity-1686a
95d8a3989a cr 2026-05-19 11:38:48 +02:00
trinity-1686a
ea61a68db4 skip sstable index binary search when ordinal is in same block 2026-05-16 11:35:38 +02:00
trinity-1686a
c367df37c1 refactor sstable index 2026-05-16 11:30:02 +02:00
Mohammad Dashti
d99a5d4e91 Rename validate_aggregation_fields to validate_aggregation_fields_exist
Applies @PSeitz's review suggestion to make the function name more
descriptive of what it checks. Also adds a doc note clarifying why
validation is opt-in rather than enforced by default.
2026-05-16 15:45:20 +08:00
Mohammad Dashti
2de6f075ce Fixed the example 2026-05-16 15:45:20 +08:00
Mohammad Dashti
18080067c7 Applied PR comment:
I would move it outside of the aggregation. You can fetch the fields from the aggregation request and do a validation in a helper function
2026-05-16 15:45:20 +08:00
Mohammad Dashti
95db7d2e5c Revert "Revert all impl."
This reverts commit d5e0991549a05bf80f19f853f7689ad69f96e7e5.
2026-05-16 15:45:20 +08:00
Mohammad Dashti
fc017c4c74 Applied PR comments. 2026-05-16 15:45:20 +08:00
Mohammad Dashti
141c91d028 Added a flag: strict_validation 2026-05-16 15:45:20 +08:00
Mohammad Dashti
36a83e7c1a Fixed agg validation 2026-05-16 15:45:20 +08:00
jinhelin
be11f8a6a1 Fix opening positions file error 2026-05-14 15:55:59 +08:00
dependabot[bot]
4305e4029e Update binggan requirement from 0.16.1 to 0.17.0
Updates the requirements on [binggan](https://github.com/pseitz/binggan) to permit the latest version.
- [Changelog](https://github.com/PSeitz/binggan/blob/main/CHANGELOG.md)
- [Commits](https://github.com/pseitz/binggan/commits)

---
updated-dependencies:
- dependency-name: binggan
  dependency-version: 0.17.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-12 15:10:20 +08:00
18 changed files with 991 additions and 337 deletions

View File

@@ -65,7 +65,7 @@ tantivy-bitpacker = { version = "0.10", path = "./bitpacker" }
common = { version = "0.11", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.7", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.4", features = ["use_serde"] }
datasketches = { git = "https://github.com/fulmicoton-dd/datasketches-rust", rev = "7635fb8" }
datasketches = { version = "0.3.0", features = ["hll"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
@@ -75,7 +75,7 @@ typetag = "0.2.21"
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.16.1"
binggan = "0.17.0"
rand = "0.9"
maplit = "1.0.2"
matches = "0.1.9"

View File

@@ -14,7 +14,6 @@ mod tests {
let mut bitpacker = BitPacker::new();
let mut buffer = Vec::new();
for _ in 0..num_els {
// the values do not matter.
bitpacker.write(0u64, bit_width, &mut buffer).unwrap();
bitpacker.flush(&mut buffer).unwrap();
}
@@ -62,4 +61,124 @@ mod tests {
blocked_bitpacker
});
}
// --- filter_vec benchmarks ---
//
// We use a large N so that the vec clone is a smaller fraction of the total time,
// and so L2/L3 cache effects are representative of real workloads.
// Values are spread uniformly in [0, MAX_VAL].
const N: usize = 100_000;
const MAX_VAL: u32 = 1_000;
fn make_values(n: usize, max_val: u32) -> Vec<u32> {
(0..n as u32)
.map(|i| (i as u64 * max_val as u64 / n as u64) as u32)
.collect()
}
#[bench]
fn bench_filter_vec_dense(b: &mut Bencher) {
// ~50% of values match [250, 750]
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::filter_vec_in_place(250..=750, 0, &mut v);
v
});
}
#[bench]
fn bench_filter_vec_sparse(b: &mut Bencher) {
// ~5% of values match [0, 50]
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::filter_vec_in_place(0..=50, 0, &mut v);
v
});
}
#[bench]
fn bench_filter_vec_full(b: &mut Bencher) {
// 100% of values match
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::filter_vec_in_place(0..=MAX_VAL, 0, &mut v);
v
});
}
#[bench]
fn bench_filter_vec_scalar_dense(b: &mut Bencher) {
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::scalar_filter_vec_in_place(250..=750, 0, &mut v);
v
});
}
#[bench]
fn bench_filter_vec_scalar_sparse(b: &mut Bencher) {
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::scalar_filter_vec_in_place(0..=50, 0, &mut v);
v
});
}
// --- NEON / SVE / SVE2 (aarch64 only) ---
#[bench]
#[cfg(target_arch = "aarch64")]
fn bench_filter_vec_neon_dense(b: &mut Bencher) {
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::neon_filter_vec_in_place(250..=750, 0, &mut v);
v
});
}
#[bench]
#[cfg(target_arch = "aarch64")]
fn bench_filter_vec_neon_sparse(b: &mut Bencher) {
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::neon_filter_vec_in_place(0..=50, 0, &mut v);
v
});
}
#[bench]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
fn bench_filter_vec_sve_dense(b: &mut Bencher) {
if !std::arch::is_aarch64_feature_detected!("sve") {
return;
}
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::sve_filter_vec_in_place(250..=750, 0, &mut v);
v
});
}
#[bench]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
fn bench_filter_vec_sve_sparse(b: &mut Bencher) {
if !std::arch::is_aarch64_feature_detected!("sve") {
return;
}
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::sve_filter_vec_in_place(0..=50, 0, &mut v);
v
});
}
}

View File

@@ -1,8 +1,17 @@
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
use std::arch::is_aarch64_feature_detected;
use std::ops::RangeInclusive;
#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "aarch64")]
mod neon;
// SVE intrinsics are nightly-only and not exposed on aarch64-apple-darwin.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
mod sve;
mod scalar;
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
@@ -10,6 +19,10 @@ mod scalar;
enum FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
AVX2 = 0u8,
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
Sve = 3u8,
#[cfg(target_arch = "aarch64")]
Neon = 2u8,
Scalar = 1u8,
}
@@ -19,29 +32,64 @@ impl FilterImplPerInstructionSet {
match *self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => is_x86_feature_detected!("avx2"),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
FilterImplPerInstructionSet::Sve => is_aarch64_feature_detected!("sve"),
// TIL Neon is required on aarch 64.
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => true,
FilterImplPerInstructionSet::Scalar => true,
}
}
}
// List of available implementation in preferred order.
// List of available implementations in preferred order.
#[cfg(target_arch = "x86_64")]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::AVX2,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(target_arch = "x86_64"))]
// Non-Apple aarch64 with nightly: try SVE, NEON, Scalar.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
const IMPLS: [FilterImplPerInstructionSet; 3] = [
FilterImplPerInstructionSet::Sve,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
// Non-Apple aarch64 without nightly: SVE unavailable; use NEON or Scalar.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), not(nightly)))]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
// Apple aarch64 (M-series): SVE not available; use NEON or Scalar.
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
const IMPLS: [FilterImplPerInstructionSet; 1] = [FilterImplPerInstructionSet::Scalar];
impl FilterImplPerInstructionSet {
#[inline]
#[allow(unused_variables)] // on non-x86_64, code is unused.
#[allow(unused_variables)]
fn from(code: u8) -> FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
if code == FilterImplPerInstructionSet::AVX2 as u8 {
return FilterImplPerInstructionSet::AVX2;
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
if code == FilterImplPerInstructionSet::Sve as u8 {
return FilterImplPerInstructionSet::Sve;
}
#[cfg(target_arch = "aarch64")]
if code == FilterImplPerInstructionSet::Neon as u8 {
return FilterImplPerInstructionSet::Neon;
}
FilterImplPerInstructionSet::Scalar
}
@@ -50,6 +98,10 @@ impl FilterImplPerInstructionSet {
match self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => avx2::filter_vec_in_place(range, offset, output),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
FilterImplPerInstructionSet::Sve => sve::filter_vec_in_place(range, offset, output),
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => neon::filter_vec_in_place(range, offset, output),
FilterImplPerInstructionSet::Scalar => {
scalar::filter_vec_in_place(range, offset, output)
}
@@ -63,7 +115,6 @@ fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
static INSTRUCTION_SET_BYTE: AtomicU8 = AtomicU8::new(u8::MAX);
let instruction_set_byte: u8 = INSTRUCTION_SET_BYTE.load(Ordering::Relaxed);
if instruction_set_byte == u8::MAX {
// Let's initialize the instruction set and cache it.
let instruction_set = IMPLS
.into_iter()
.find(FilterImplPerInstructionSet::is_available)
@@ -78,14 +129,29 @@ pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut
get_best_available_instruction_set().filter_vec_in_place(range, offset, output)
}
#[doc(hidden)]
pub fn scalar_filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
scalar::filter_vec_in_place(range, offset, output);
}
#[doc(hidden)]
#[cfg(target_arch = "aarch64")]
pub fn neon_filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
neon::filter_vec_in_place(range, offset, output);
}
#[doc(hidden)]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
pub fn sve_filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
sve::filter_vec_in_place(range, offset, output);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_best_available_instruction_set() {
// This does not test much unfortunately.
// We just make sure the function returns without crashing and returns the same result.
let instruction_set = get_best_available_instruction_set();
assert_eq!(get_best_available_instruction_set(), instruction_set);
}
@@ -102,6 +168,43 @@ mod tests {
}
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Sve,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), not(nightly)))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
fn test_filter_impl_empty_aux(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![];
filter_impl.filter_vec_in_place(0..=u32::MAX, 0, &mut output);
@@ -126,11 +229,20 @@ mod tests {
assert_eq!(&output, &[1, 3, 4, 5, 6, 7, 8]);
}
fn test_filter_impl_empty_range_aux(filter_impl: FilterImplPerInstructionSet) {
// start > end: RangeInclusive::contains always returns false; output must be empty.
// The SVE path's wrapping_sub would otherwise produce a huge range_width.
let mut output = vec![3, 2, 1, 5, 11, 2, 5, 10, 2];
filter_impl.filter_vec_in_place(10..=5, 0, &mut output);
assert_eq!(&output, &[]);
}
fn test_filter_impl_test_suite(filter_impl: FilterImplPerInstructionSet) {
test_filter_impl_empty_aux(filter_impl);
test_filter_impl_simple_aux(filter_impl);
test_filter_impl_simple_aux_shifted(filter_impl);
test_filter_impl_simple_outside_i32_range(filter_impl);
test_filter_impl_empty_range_aux(filter_impl);
}
#[test]
@@ -141,6 +253,20 @@ mod tests {
}
}
#[test]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple"), nightly))]
fn test_filter_implementation_sve() {
if FilterImplPerInstructionSet::Sve.is_available() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Sve);
}
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_filter_implementation_neon() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Neon);
}
#[test]
fn test_filter_implementation_scalar() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Scalar);
@@ -162,4 +288,19 @@ mod tests {
}
}
}
#[cfg(target_arch = "aarch64")]
proptest::proptest! {
#[test]
fn test_filter_compare_scalar_and_neon_impl_proptest(
start in proptest::prelude::any::<u32>(),
end in proptest::prelude::any::<u32>(),
offset in 0u32..2u32,
mut vals in proptest::collection::vec(0..u32::MAX, 0..30)) {
let mut vals_clone = vals.clone();
FilterImplPerInstructionSet::Neon.filter_vec_in_place(start..=end, offset, &mut vals);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut vals_clone);
assert_eq!(&vals, &vals_clone);
}
}
}

View File

@@ -0,0 +1,113 @@
use std::arch::aarch64::*;
use std::ops::RangeInclusive;
const NUM_LANES: usize = 4;
// Compacts matching lanes to the front using a byte-level shuffle.
// `mask` is a 4-bit value: bit k=1 means lane k should appear in the output.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn compact(data: uint32x4_t, mask: u8) -> uint32x4_t {
unsafe {
// SAFETY: mask is always in [0, 15] by construction (max sum of [1,2,4,8]).
// BYTE_SHUFFLE_TABLE has 16 entries, so this is always in bounds.
let shuffle = BYTE_SHUFFLE_TABLE.get_unchecked(mask as usize);
let shuffle_vec = vld1q_u8(shuffle.as_ptr());
vreinterpretq_u32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(data), shuffle_vec))
}
}
#[inline(never)]
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
let num_words = output.len() / NUM_LANES;
let mut output_len = unsafe {
filter_vec_neon_aux(
output.as_ptr(),
range.clone(),
output.as_mut_ptr(),
offset,
num_words,
)
};
let remainder_start = num_words * NUM_LANES;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
#[target_feature(enable = "neon")]
unsafe fn filter_vec_neon_aux(
input: *const u32,
range: RangeInclusive<u32>,
output: *mut u32,
offset: u32,
num_words: usize,
) -> usize {
unsafe {
let mut input = input;
let mut output_tail = output;
let range_start_simd = vdupq_n_u32(*range.start());
let range_end_simd = vdupq_n_u32(*range.end());
let mut ids = vld1q_u32([offset, offset + 1, offset + 2, offset + 3].as_ptr());
let shift = vdupq_n_u32(NUM_LANES as u32);
let bit_weights = vld1q_u32([1u32, 2, 4, 8].as_ptr());
for _ in 0..num_words {
let word = vld1q_u32(input);
// Unsigned compares: CMHS (compare higher or same) tests `word >= start`
// and `end >= word`. ANDing both gives the inside-range mask directly,
// which is cheaper than computing `outside` and then negating.
let ge_start = vcgeq_u32(word, range_start_simd);
let le_end = vcleq_u32(word, range_end_simd);
// inside[k] = 0xFFFFFFFF if val[k] is in range, 0 otherwise.
let inside = vandq_u32(ge_start, le_end);
// Build the 4-bit mask: AND bit_weights with the inside lane mask, so each
// inside lane contributes its bit_weight (1, 2, 4, or 8). Summing yields the
// 4-bit mask in one addv.
let inside_bits = vandq_u32(bit_weights, inside);
let mask = vaddvq_u32(inside_bits) as u8;
// mask is mathematically bounded: max value is 1+2+4+8=15 (all lanes match)
debug_assert!(mask <= 15, "mask must fit in 4 bits: {}", mask);
// Count of matching lanes = popcount(mask). Derives the count directly from
// the mask instead of running a parallel SIMD reduction over `outside`.
let added_len = mask.count_ones() as usize;
// Safe because mask is guaranteed to be in [0, 15]
let filtered_ids = compact(ids, mask);
vst1q_u32(output_tail, filtered_ids);
output_tail = output_tail.add(added_len);
ids = vaddq_u32(ids, shift);
input = input.add(NUM_LANES);
}
output_tail.offset_from(output) as usize
}
}
// Byte shuffle patterns to compact matching lanes to the front of the vector.
// Index is a 4-bit mask: bit k=1 means lane k (bytes 4k..4k+3) is in-range.
// The j-th set bit determines which input lane goes to output position j.
const BYTE_SHUFFLE_TABLE: [[u8; 16]; 16] = [
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0000: none
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0001: lane 0
[4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0010: lane 1
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0011: lanes 0,1
[8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0100: lane 2
[0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0101: lanes 0,2
[4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0110: lanes 1,2
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3], // 0b0111: lanes 0,1,2
[12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1000: lane 3
[0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1001: lanes 0,3
[4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1010: lanes 1,3
[0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1011: lanes 0,1,3
[8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1100: lanes 2,3
[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1101: lanes 0,2,3
[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1110: lanes 1,2,3
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], // 0b1111: all lanes
];

View File

@@ -0,0 +1,112 @@
use std::arch::aarch64::*;
use std::ops::RangeInclusive;
// SVE vector length (in u32 lanes) is not a compile-time constant; query at runtime.
// Safe to call only when SVE is confirmed available via is_aarch64_feature_detected!("sve").
#[target_feature(enable = "sve")]
fn num_lanes() -> usize {
svcntw() as usize
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
if range.start() > range.end() {
output.clear();
return;
}
let vl = unsafe { num_lanes() };
let num_words = output.len() / vl;
let range_start = *range.start();
// Unsigned subtraction trick: val ∈ [lo, hi] ↔ (val - lo) ≤ᵤ (hi - lo).
// Values below lo wrap around to large u32, so the single unsigned ≤ excludes them.
let range_width = range.end().wrapping_sub(range_start);
let mut output_len = unsafe {
filter_vec_sve_aux(
output.as_ptr(),
range_start,
range_width,
output.as_mut_ptr(),
offset,
num_words,
vl,
)
};
let remainder_start = num_words * vl;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
#[target_feature(enable = "sve")]
unsafe fn filter_vec_sve_aux(
input: *const u32,
range_start: u32,
range_width: u32,
output: *mut u32,
offset: u32,
num_words: usize,
vl: usize,
) -> usize {
unsafe {
let all_true = svptrue_b32();
let range_start_simd = svdup_n_u32(range_start);
let range_width_simd = svdup_n_u32(range_width);
// ids_a covers [offset .. offset+vl), ids_b covers the next vl ids.
// Keeping them separate breaks the loop-carried dependency through ids so
// both compact/cntp chains are fully independent within each unrolled body.
let mut ids_a = svindex_u32(offset, 1);
let step = svdup_n_u32(vl as u32);
let step2 = svdup_n_u32(2 * vl as u32);
let mut ids_b = svadd_u32_x(all_true, ids_a, step);
let mut input = input;
let mut output_tail = output;
// Unrolled ×2: both cntp calls have independent inputs and execute in parallel.
// The two output_tail updates are sequential but together cost 4+1+1=6 cy per
// pair vs 5+5=10 cy for two scalar iterations, breaking the cntp latency chain.
let num_pairs = num_words / 2;
for _ in 0..num_pairs {
let word_a = svld1_u32(all_true, input);
let word_b = svld1_u32(all_true, input.add(vl));
let shifted_a = svsub_u32_x(all_true, word_a, range_start_simd);
let shifted_b = svsub_u32_x(all_true, word_b, range_start_simd);
let in_range_a = svcmple_u32(all_true, shifted_a, range_width_simd);
let in_range_b = svcmple_u32(all_true, shifted_b, range_width_simd);
let compacted_a = svcompact_u32(in_range_a, ids_a);
let compacted_b = svcompact_u32(in_range_b, ids_b);
// cntp_a and cntp_b have independent inputs: OOO engine issues them in parallel.
let added_len_a = svcntp_b32(all_true, in_range_a) as usize;
let added_len_b = svcntp_b32(all_true, in_range_b) as usize;
// Write the full vector — only the first added_len slots are valid.
// Subsequent iterations overwrite the trailing zeros before truncate.
svst1_u32(all_true, output_tail, compacted_a);
output_tail = output_tail.add(added_len_a);
svst1_u32(all_true, output_tail, compacted_b);
output_tail = output_tail.add(added_len_b);
ids_a = svadd_u32_x(all_true, ids_a, step2);
ids_b = svadd_u32_x(all_true, ids_b, step2);
input = input.add(2 * vl);
}
// Handle an odd trailing word.
if num_words % 2 == 1 {
let word = svld1_u32(all_true, input);
let shifted = svsub_u32_x(all_true, word, range_start_simd);
let in_range = svcmple_u32(all_true, shifted, range_width_simd);
let added_len = svcntp_b32(all_true, in_range) as usize;
let compacted_ids = svcompact_u32(in_range, ids_a);
svst1_u32(all_true, output_tail, compacted_ids);
output_tail = output_tail.add(added_len);
}
output_tail.offset_from(output) as usize
}
}

View File

@@ -1,6 +1,13 @@
// SVE/SVE2 intrinsics require nightly; only unlock when build.rs detects a nightly compiler.
#![cfg_attr(
all(target_arch = "aarch64", not(target_vendor = "apple"), nightly),
feature(stdarch_aarch64_sve)
)]
mod bitpacker;
mod blocked_bitpacker;
mod filter_vec;
#[doc(hidden)]
pub mod filter_vec;
use std::cmp::Ordering;

View File

@@ -23,7 +23,7 @@ downcast-rs = "2.0.1"
proptest = "1"
more-asserts = "0.3.1"
rand = "0.9"
binggan = "0.16.1"
binggan = "0.17.0"
[[bench]]
name = "bench_merge"

View File

@@ -19,6 +19,6 @@ time = { version = "0.3.47", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.16.1"
binggan = "0.17.0"
proptest = "1.0.0"
rand = "0.9"

View File

@@ -115,6 +115,71 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
fast_field_names
}
/// Validates that all fields referenced in the aggregation request exist in the schema
/// and are configured as fast fields.
///
/// This is a convenience function for upfront validation before executing aggregations.
/// Returns an error if any field doesn't exist or is not a fast field.
///
/// Validation is intentionally opt-in rather than baked into aggregation execution: the
/// default lenient behavior (returning empty results for missing fields) supports
/// schema evolution and federated queries where the same request runs against segments
/// or indices with different schemas.
///
/// # Example
/// ```
/// use tantivy::aggregation::agg_req::{Aggregations, validate_aggregation_fields_exist};
/// use tantivy::schema::{Schema, FAST};
/// use tantivy::Index;
///
/// # fn main() -> tantivy::Result<()> {
/// // Create a simple index
/// let mut schema_builder = Schema::builder();
/// schema_builder.add_f64_field("price", FAST);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
///
/// // Parse aggregation request
/// let agg_req: Aggregations = serde_json::from_str(r#"{
/// "avg_price": { "avg": { "field": "price" } }
/// }"#)?;
///
/// let reader = index.reader()?;
/// let searcher = reader.searcher();
///
/// // Validate fields before executing
/// for segment_reader in searcher.segment_readers() {
/// validate_aggregation_fields_exist(&agg_req, segment_reader)?;
/// }
/// # Ok(())
/// # }
/// ```
pub fn validate_aggregation_fields_exist(
aggs: &Aggregations,
reader: &crate::SegmentReader,
) -> crate::Result<()> {
let field_names = get_fast_field_names(aggs);
let schema = reader.schema();
for field_name in field_names {
// Check if the field is either directly in the schema or could be part of a json field
// present in the schema, and verify it's a fast field.
if let Some((field, _path)) = schema.find_field(&field_name) {
let field_type = schema.get_field_entry(field).field_type();
if !field_type.is_fast() {
return Err(crate::TantivyError::SchemaError(format!(
"Field '{}' is not a fast field. Aggregations require fast fields.",
field_name
)));
}
} else {
return Err(crate::TantivyError::FieldNotFound(field_name));
}
}
Ok(())
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// All aggregation types.
pub enum AggregationVariants {

View File

@@ -1436,3 +1436,46 @@ fn test_aggregation_on_json_object_mixed_numerical_segments() {
)
);
}
#[test]
fn test_aggregation_field_validation_helper() {
// Test the standalone validation helper function for field validation
let index = get_test_index_2_segments(false).unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
// Test with invalid field
let agg_req: Aggregations = serde_json::from_str(
r#"{
"avg_test": {
"avg": { "field": "nonexistent_field" }
}
}"#,
)
.unwrap();
let result =
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
assert!(result.is_err());
match result {
Err(crate::TantivyError::FieldNotFound(field_name)) => {
assert_eq!(field_name, "nonexistent_field");
}
_ => panic!("Expected FieldNotFound error, got: {:?}", result),
}
// Test with valid field
let agg_req: Aggregations = serde_json::from_str(
r#"{
"avg_test": {
"avg": { "field": "score" }
}
}"#,
)
.unwrap();
let result =
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
assert!(result.is_ok());
}

View File

@@ -166,8 +166,12 @@ impl CouponCache {
let should_use_dense =
highest_term_ord < 1_000_000u64 || highest_term_ord < num_terms as u64 * 3u64;
if should_use_dense {
let mut coupon_map: Vec<Coupon> = vec![Coupon::EMPTY; highest_term_ord as usize + 1];
for (term_ord, coupon) in term_ords.into_iter().zip(coupons.into_iter()) {
// We don't really care about the value here. We will populate all the values we will
// read anyway.
let uninitialized_coupon = Coupon::from_hash(0);
let mut coupon_map: Vec<Coupon> =
vec![uninitialized_coupon; highest_term_ord as usize + 1];
for (term_ord, coupon) in term_ords.into_iter().zip(coupons) {
coupon_map[term_ord as usize] = coupon;
}
CouponCache::Dense {
@@ -821,7 +825,7 @@ impl<'de> Deserialize<'de> for CardinalityCollector {
impl CardinalityCollector {
fn new(salt: u8) -> Self {
Self {
sketch: HllSketch::new(LG_K, HllType::Hll4),
sketch: HllSketch::new(LG_K, HllType::Hll8),
salt,
}
}
@@ -852,7 +856,7 @@ impl CardinalityCollector {
let mut union = HllUnion::new(LG_K);
union.update(&self.sketch);
union.update(&right.sketch);
self.sketch = union.to_sketch(HllType::Hll4);
self.sketch = union.to_sketch(HllType::Hll8);
Ok(())
}
}

View File

@@ -6,6 +6,7 @@ use common::{ByteCount, HasLen};
use fnv::FnvHashMap;
use itertools::Itertools;
use crate::directory::error::OpenReadError;
use crate::directory::{CompositeFile, FileSlice};
use crate::error::DataCorruption;
use crate::fastfield::{intersect_alive_bitsets, AliveBitSet, FacetReader, FastFieldReaders};
@@ -159,12 +160,10 @@ impl SegmentReader {
let postings_file = segment.open_read(SegmentComponent::Postings)?;
let postings_composite = CompositeFile::open(&postings_file)?;
let positions_composite = {
if let Ok(positions_file) = segment.open_read(SegmentComponent::Positions) {
CompositeFile::open(&positions_file)?
} else {
CompositeFile::empty()
}
let positions_composite = match segment.open_read(SegmentComponent::Positions) {
Ok(positions_file) => CompositeFile::open(&positions_file)?,
Err(OpenReadError::FileDoesNotExist(_)) => CompositeFile::empty(),
Err(open_read_error) => return Err(open_read_error.into()),
};
let schema = segment.schema();

View File

@@ -14,11 +14,8 @@ use itertools::Itertools;
use tantivy_fst::Automaton;
use tantivy_fst::automaton::AlwaysMatch;
use crate::sstable_index_v3::SSTableIndexV3Empty;
use crate::streamer::{Streamer, StreamerBuilder};
use crate::{
BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, SSTableIndexV3, TermOrdinal, VoidSSTable,
};
use crate::{BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, TermOrdinal, VoidSSTable};
/// An SSTable is a sorted map that associates sorted `&[u8]` keys
/// to any kind of typed values.
@@ -288,33 +285,7 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
let (sstable_slice, index_slice) = main_slice.split(index_offset as usize);
let sstable_index_bytes = index_slice.read_bytes()?;
let sstable_index = match version {
2 => SSTableIndex::V2(
crate::sstable_index_v2::SSTableIndex::load(sstable_index_bytes).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?,
),
3 => {
let (sstable_index_bytes, mut footerv3_len_bytes) = sstable_index_bytes.rsplit(8);
let store_offset = u64::deserialize(&mut footerv3_len_bytes)?;
if store_offset != 0 {
SSTableIndex::V3(
SSTableIndexV3::load(sstable_index_bytes, store_offset).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?,
)
} else {
// if store_offset is zero, there is no index, so we build a pseudo-index
// assuming a single block of sstable covering everything.
SSTableIndex::V3Empty(SSTableIndexV3Empty::load(index_offset as usize))
}
}
_ => {
return Err(io::Error::other(format!(
"Unsupported sstable version, expected one of [2, 3], found {version}"
)));
}
};
let sstable_index = SSTableIndex::open(version, index_offset, sstable_index_bytes)?;
Ok(Dictionary {
sstable_slice,
@@ -525,10 +496,15 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
// Open the block for the first ordinal.
let mut bytes = Vec::new();
let mut current_block_addr = self.sstable_index.get_block_with_ord(ord);
let (mut current_block_addr, block_id) = self.sstable_index.get_and_locate_with_ord(ord);
let mut current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
let mut current_block_ordinal = current_block_addr.first_ordinal;
let mut current_block_end_bound = self
.sstable_index
.get_block(block_id + 1)
.map(|block_addr| block_addr.first_ordinal)
.unwrap_or(u64::MAX);
loop {
// move to the ord inside the current block
@@ -557,17 +533,19 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
}
};
// TODO optimization: it is silly to do a binary search to get the block every single
// time.
//
// Check if block changed for new term_ord
let new_block_addr = self.sstable_index.get_block_with_ord(next_ord);
if new_block_addr != current_block_addr {
if next_ord >= current_block_end_bound {
let (new_block_addr, block_id) =
self.sstable_index.get_and_locate_with_ord(next_ord);
current_block_addr = new_block_addr;
current_block_ordinal = current_block_addr.first_ordinal;
current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
bytes.clear();
current_block_end_bound = self
.sstable_index
.get_block(block_id + 1)
.map(|block_addr| block_addr.first_ordinal)
.unwrap_or(u64::MAX)
}
ord = next_ord;
}

319
sstable/src/index/mod.rs Normal file
View File

@@ -0,0 +1,319 @@
pub(crate) mod v2;
pub(crate) mod v3;
use std::io::{self, Read, Write};
use std::ops::Range;
use common::{BinarySerializable, FixedSize, OwnedBytes};
use tantivy_fst::{Automaton, MapBuilder};
use crate::{TermOrdinal, common_prefix_len};
#[derive(Debug, Clone)]
pub enum SSTableIndex {
V2(v2::SSTableIndex),
V3(v3::SSTableIndexV3),
V3Empty(v3::SSTableIndexV3Empty),
}
impl SSTableIndex {
pub(crate) fn open(
version: u32,
index_offset: u64,
index_bytes: OwnedBytes,
) -> io::Result<Self> {
let index = match version {
2 => {
SSTableIndex::V2(v2::SSTableIndex::load(index_bytes).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption")
})?)
}
3 => {
let (index_bytes, mut footerv3_len_bytes) = index_bytes.rsplit(8);
let store_offset = u64::deserialize(&mut footerv3_len_bytes)?;
if store_offset != 0 {
SSTableIndex::V3(v3::SSTableIndexV3::load(index_bytes, store_offset).map_err(
|_| io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption"),
)?)
} else {
// if store_offset is zero, there is no index, so we build a pseudo-index
// assuming a single block of sstable covering everything.
SSTableIndex::V3Empty(v3::SSTableIndexV3Empty::load(index_offset as usize))
}
}
_ => {
return Err(io::Error::other(format!(
"Unsupported sstable version, expected one of [2, 3], found {version}"
)));
}
};
Ok(index)
}
/// Get the [`BlockAddr`] of the requested block.
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block(block_id as usize),
SSTableIndex::V3(v3_index) => v3_index.get_block(block_id),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block(block_id),
}
}
/// Get the block id of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_key(key).map(|i| i as u64),
SSTableIndex::V3(v3_index) => v3_index.locate_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_key(key),
}
}
/// Get the [`BlockAddr`] of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_key(key),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_key(key),
}
}
pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_ord(ord) as u64,
SSTableIndex::V3(v3_index) => v3_index.locate_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_ord(ord),
}
}
/// Get the [`BlockAddr`] of the block containing the `ord`-th term.
pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_ord(ord),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_ord(ord),
}
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_and_locate_with_ord(ord),
SSTableIndex::V3(v3_index) => v3_index.get_and_locate_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_and_locate_with_ord(ord),
}
}
pub fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,
) -> impl Iterator<Item = (u64, BlockAddr)> + 'a {
match self {
SSTableIndex::V2(v2_index) => {
BlockIter::V2(v2_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3(v3_index) => {
BlockIter::V3(v3_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3Empty(v3_empty) => {
BlockIter::V3Empty(std::iter::once((0, v3_empty.block_addr.clone())))
}
}
}
}
enum BlockIter<V2, V3, T> {
V2(V2),
V3(V3),
V3Empty(std::iter::Once<T>),
}
impl<V2: Iterator<Item = T>, V3: Iterator<Item = T>, T> Iterator for BlockIter<V2, V3, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
match self {
BlockIter::V2(v2) => v2.next(),
BlockIter::V3(v3) => v3.next(),
BlockIter::V3Empty(once) => once.next(),
}
}
}
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct BlockAddr {
pub first_ordinal: u64,
pub byte_range: Range<usize>,
}
impl BlockAddr {
fn to_block_start(&self) -> BlockStartAddr {
BlockStartAddr {
first_ordinal: self.first_ordinal,
byte_range_start: self.byte_range.start,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct BlockStartAddr {
first_ordinal: u64,
byte_range_start: usize,
}
impl BlockStartAddr {
fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr {
BlockAddr {
first_ordinal: self.first_ordinal,
byte_range: self.byte_range_start..byte_range_end,
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct BlockMeta {
/// Any byte string that is lexicographically greater or equal to
/// the last key in the block,
/// and yet strictly smaller than the first key in the next block.
pub last_key_or_greater: Vec<u8>,
pub block_addr: BlockAddr,
}
impl BinarySerializable for BlockStartAddr {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
let start = self.byte_range_start as u64;
start.serialize(writer)?;
self.first_ordinal.serialize(writer)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let byte_range_start = u64::deserialize(reader)? as usize;
let first_ordinal = u64::deserialize(reader)?;
Ok(BlockStartAddr {
first_ordinal,
byte_range_start,
})
}
// Provided method
fn num_bytes(&self) -> u64 {
BlockStartAddr::SIZE_IN_BYTES as u64
}
}
impl FixedSize for BlockStartAddr {
const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES;
}
/// Given that left < right,
/// mutates `left into a shorter byte string left'` that
/// matches `left <= left' < right`.
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
assert!(&left[..] < right);
let common_len = common_prefix_len(left, right);
if left.len() == common_len {
return;
}
// It is possible to do one character shorter in some case,
// but it is not worth the extra complexity
for pos in (common_len + 1)..left.len() {
if left[pos] != u8::MAX {
left[pos] += 1;
left.truncate(pos + 1);
return;
}
}
}
#[derive(Default)]
pub struct SSTableIndexBuilder {
blocks: Vec<BlockMeta>,
}
impl SSTableIndexBuilder {
/// In order to make the index as light as possible, we
/// try to find a shorter alternative to the last key of the last block
/// that is still smaller than the next key.
pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) {
if let Some(last_block) = self.blocks.last_mut() {
find_shorter_str_in_between(&mut last_block.last_key_or_greater, next_key);
}
}
pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) {
self.blocks.push(BlockMeta {
last_key_or_greater: last_key.to_vec(),
block_addr: BlockAddr {
byte_range,
first_ordinal,
},
})
}
pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> {
if self.blocks.len() <= 1 {
return Ok(0);
}
let counting_writer = common::CountingWriter::wrap(wrt);
let mut map_builder = MapBuilder::new(counting_writer).map_err(fst_error_to_io_error)?;
for (i, block) in self.blocks.iter().enumerate() {
map_builder
.insert(&block.last_key_or_greater, i as u64)
.map_err(fst_error_to_io_error)?;
}
let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?;
let written_bytes = counting_writer.written_bytes();
let mut wrt = counting_writer.finish();
let mut block_store_writer = v3::BlockAddrStoreWriter::new();
for block in &self.blocks {
block_store_writer.write_block_meta(block.block_addr.clone())?;
}
block_store_writer.serialize(&mut wrt)?;
Ok(written_bytes)
}
}
fn fst_error_to_io_error(error: tantivy_fst::Error) -> io::Error {
match error {
tantivy_fst::Error::Fst(fst_error) => io::Error::other(fst_error),
tantivy_fst::Error::Io(ioerror) => ioerror,
}
}
#[cfg(test)]
mod tests {
#[track_caller]
fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) {
let mut left_buf = left.to_vec();
super::find_shorter_str_in_between(&mut left_buf, right);
assert!(left_buf.len() <= left.len());
assert!(left <= &left_buf);
assert!(&left_buf[..] < right);
}
#[test]
fn test_find_shorter_str_in_between() {
test_find_shorter_str_in_between_aux(b"", b"hello");
test_find_shorter_str_in_between_aux(b"abc", b"abcd");
test_find_shorter_str_in_between_aux(b"abcd", b"abd");
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]);
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]);
test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]);
}
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) {
if left < right {
test_find_shorter_str_in_between_aux(&left, &right);
}
}
}
}

View File

@@ -77,6 +77,13 @@ impl SSTableIndex {
self.get_block(self.locate_with_ord(ord)).unwrap()
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
let location = self.locate_with_ord(ord);
// locate_with_ord always returns an index within range
let block_addr = self.get_block(location).unwrap();
(block_addr, location as u64)
}
pub(crate) fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,

View File

@@ -1,106 +1,14 @@
use std::io::{self, Read, Write};
use std::ops::Range;
use std::sync::Arc;
use common::{BinarySerializable, FixedSize, OwnedBytes};
use tantivy_bitpacker::{BitPacker, compute_num_bits};
use tantivy_fst::raw::Fst;
use tantivy_fst::{Automaton, IntoStreamer, Map, MapBuilder, Streamer};
use tantivy_fst::{Automaton, IntoStreamer, Map, Streamer};
use super::{BlockAddr, BlockStartAddr};
use crate::block_match_automaton::can_block_match_automaton;
use crate::{SSTableDataCorruption, TermOrdinal, common_prefix_len};
#[derive(Debug, Clone)]
pub enum SSTableIndex {
V2(crate::sstable_index_v2::SSTableIndex),
V3(SSTableIndexV3),
V3Empty(SSTableIndexV3Empty),
}
impl SSTableIndex {
/// Get the [`BlockAddr`] of the requested block.
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block(block_id as usize),
SSTableIndex::V3(v3_index) => v3_index.get_block(block_id),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block(block_id),
}
}
/// Get the block id of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_key(key).map(|i| i as u64),
SSTableIndex::V3(v3_index) => v3_index.locate_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_key(key),
}
}
/// Get the [`BlockAddr`] of the block that would contain `key`.
///
/// Returns None if `key` is lexicographically after the last key recorded.
pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_key(key),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_key(key),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_key(key),
}
}
pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 {
match self {
SSTableIndex::V2(v2_index) => v2_index.locate_with_ord(ord) as u64,
SSTableIndex::V3(v3_index) => v3_index.locate_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_ord(ord),
}
}
/// Get the [`BlockAddr`] of the block containing the `ord`-th term.
pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr {
match self {
SSTableIndex::V2(v2_index) => v2_index.get_block_with_ord(ord),
SSTableIndex::V3(v3_index) => v3_index.get_block_with_ord(ord),
SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_ord(ord),
}
}
pub fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,
) -> impl Iterator<Item = (u64, BlockAddr)> + 'a {
match self {
SSTableIndex::V2(v2_index) => {
BlockIter::V2(v2_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3(v3_index) => {
BlockIter::V3(v3_index.get_block_for_automaton(automaton))
}
SSTableIndex::V3Empty(v3_empty) => {
BlockIter::V3Empty(std::iter::once((0, v3_empty.block_addr.clone())))
}
}
}
}
enum BlockIter<V2, V3, T> {
V2(V2),
V3(V3),
V3Empty(std::iter::Once<T>),
}
impl<V2: Iterator<Item = T>, V3: Iterator<Item = T>, T> Iterator for BlockIter<V2, V3, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
match self {
BlockIter::V2(v2) => v2.next(),
BlockIter::V3(v3) => v3.next(),
BlockIter::V3Empty(once) => once.next(),
}
}
}
use crate::{SSTableDataCorruption, TermOrdinal};
#[derive(Debug, Clone)]
pub struct SSTableIndexV3 {
@@ -160,6 +68,11 @@ impl SSTableIndexV3 {
self.block_addr_store.binary_search_ord(ord).1
}
pub(crate) fn get_and_locate_with_ord(&self, ord: TermOrdinal) -> (BlockAddr, u64) {
let (location, block_addr) = self.block_addr_store.binary_search_ord(ord);
(block_addr, location)
}
pub(crate) fn get_block_for_automaton<'a>(
&'a self,
automaton: &'a impl Automaton,
@@ -216,7 +129,7 @@ impl<A: Automaton> Iterator for GetBlockForAutomaton<'_, A> {
#[derive(Debug, Clone)]
pub struct SSTableIndexV3Empty {
block_addr: BlockAddr,
pub block_addr: BlockAddr,
}
impl SSTableIndexV3Empty {
@@ -230,8 +143,8 @@ impl SSTableIndexV3Empty {
}
/// Get the [`BlockAddr`] of the requested block.
pub(crate) fn get_block(&self, _block_id: u64) -> Option<BlockAddr> {
Some(self.block_addr.clone())
pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> {
(block_id == 0).then(|| self.block_addr.clone())
}
/// Get the block id of the block that would contain `key`.
@@ -256,146 +169,9 @@ impl SSTableIndexV3Empty {
pub(crate) fn get_block_with_ord(&self, _ord: TermOrdinal) -> BlockAddr {
self.block_addr.clone()
}
}
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct BlockAddr {
pub first_ordinal: u64,
pub byte_range: Range<usize>,
}
impl BlockAddr {
fn to_block_start(&self) -> BlockStartAddr {
BlockStartAddr {
first_ordinal: self.first_ordinal,
byte_range_start: self.byte_range.start,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct BlockStartAddr {
first_ordinal: u64,
byte_range_start: usize,
}
impl BlockStartAddr {
fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr {
BlockAddr {
first_ordinal: self.first_ordinal,
byte_range: self.byte_range_start..byte_range_end,
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct BlockMeta {
/// Any byte string that is lexicographically greater or equal to
/// the last key in the block,
/// and yet strictly smaller than the first key in the next block.
pub last_key_or_greater: Vec<u8>,
pub block_addr: BlockAddr,
}
impl BinarySerializable for BlockStartAddr {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
let start = self.byte_range_start as u64;
start.serialize(writer)?;
self.first_ordinal.serialize(writer)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let byte_range_start = u64::deserialize(reader)? as usize;
let first_ordinal = u64::deserialize(reader)?;
Ok(BlockStartAddr {
first_ordinal,
byte_range_start,
})
}
// Provided method
fn num_bytes(&self) -> u64 {
BlockStartAddr::SIZE_IN_BYTES as u64
}
}
impl FixedSize for BlockStartAddr {
const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES;
}
/// Given that left < right,
/// mutates `left into a shorter byte string left'` that
/// matches `left <= left' < right`.
fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) {
assert!(&left[..] < right);
let common_len = common_prefix_len(left, right);
if left.len() == common_len {
return;
}
// It is possible to do one character shorter in some case,
// but it is not worth the extra complexity
for pos in (common_len + 1)..left.len() {
if left[pos] != u8::MAX {
left[pos] += 1;
left.truncate(pos + 1);
return;
}
}
}
#[derive(Default)]
pub struct SSTableIndexBuilder {
blocks: Vec<BlockMeta>,
}
impl SSTableIndexBuilder {
/// In order to make the index as light as possible, we
/// try to find a shorter alternative to the last key of the last block
/// that is still smaller than the next key.
pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) {
if let Some(last_block) = self.blocks.last_mut() {
find_shorter_str_in_between(&mut last_block.last_key_or_greater, next_key);
}
}
pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) {
self.blocks.push(BlockMeta {
last_key_or_greater: last_key.to_vec(),
block_addr: BlockAddr {
byte_range,
first_ordinal,
},
})
}
pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> {
if self.blocks.len() <= 1 {
return Ok(0);
}
let counting_writer = common::CountingWriter::wrap(wrt);
let mut map_builder = MapBuilder::new(counting_writer).map_err(fst_error_to_io_error)?;
for (i, block) in self.blocks.iter().enumerate() {
map_builder
.insert(&block.last_key_or_greater, i as u64)
.map_err(fst_error_to_io_error)?;
}
let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?;
let written_bytes = counting_writer.written_bytes();
let mut wrt = counting_writer.finish();
let mut block_store_writer = BlockAddrStoreWriter::new();
for block in &self.blocks {
block_store_writer.write_block_meta(block.block_addr.clone())?;
}
block_store_writer.serialize(&mut wrt)?;
Ok(written_bytes)
}
}
fn fst_error_to_io_error(error: tantivy_fst::Error) -> io::Error {
match error {
tantivy_fst::Error::Fst(fst_error) => io::Error::other(fst_error),
tantivy_fst::Error::Io(ioerror) => ioerror,
pub(crate) fn get_and_locate_with_ord(&self, _ord: TermOrdinal) -> (BlockAddr, u64) {
(self.block_addr.clone(), 0)
}
}
@@ -647,14 +423,14 @@ fn binary_search(max: u64, cmp_fn: impl Fn(u64) -> std::cmp::Ordering) -> Result
Err(left)
}
struct BlockAddrStoreWriter {
pub(crate) struct BlockAddrStoreWriter {
buffer_block_metas: Vec<u8>,
buffer_addrs: Vec<u8>,
block_addrs: Vec<BlockAddr>,
}
impl BlockAddrStoreWriter {
fn new() -> Self {
pub(crate) fn new() -> Self {
BlockAddrStoreWriter {
buffer_block_metas: Vec::new(),
buffer_addrs: Vec::new(),
@@ -662,7 +438,7 @@ impl BlockAddrStoreWriter {
}
}
fn flush_block(&mut self) -> io::Result<()> {
pub(crate) fn flush_block(&mut self) -> io::Result<()> {
if self.block_addrs.is_empty() {
return Ok(());
}
@@ -741,7 +517,7 @@ impl BlockAddrStoreWriter {
Ok(())
}
fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> {
pub(crate) fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> {
self.block_addrs.push(block_addr);
if self.block_addrs.len() >= STORE_BLOCK_LEN {
self.flush_block()?;
@@ -749,7 +525,7 @@ impl BlockAddrStoreWriter {
Ok(())
}
fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> {
pub(crate) fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> {
self.flush_block()?;
let len = self.buffer_block_metas.len() as u64;
len.serialize(wrt)?;
@@ -824,8 +600,9 @@ mod tests {
use common::OwnedBytes;
use super::*;
use crate::SSTableDataCorruption;
use crate::block_match_automaton::tests::EqBuffer;
use crate::index::BlockMeta;
use crate::{SSTableDataCorruption, SSTableIndexBuilder};
#[test]
fn test_sstable_index() {
@@ -874,36 +651,7 @@ mod tests {
assert!(matches!(data_corruption_err, SSTableDataCorruption));
}
#[track_caller]
fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) {
let mut left_buf = left.to_vec();
super::find_shorter_str_in_between(&mut left_buf, right);
assert!(left_buf.len() <= left.len());
assert!(left <= &left_buf);
assert!(&left_buf[..] < right);
}
#[test]
fn test_find_shorter_str_in_between() {
test_find_shorter_str_in_between_aux(b"", b"hello");
test_find_shorter_str_in_between_aux(b"abc", b"abcd");
test_find_shorter_str_in_between_aux(b"abcd", b"abd");
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]);
test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]);
test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]);
}
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) {
if left < right {
test_find_shorter_str_in_between_aux(&left, &right);
}
}
}
// use proptest::prelude::*;
#[test]
fn test_find_best_slop() {

View File

@@ -47,9 +47,8 @@ pub mod merge;
mod streamer;
pub mod value;
mod sstable_index_v3;
pub use sstable_index_v3::{BlockAddr, SSTableIndex, SSTableIndexBuilder, SSTableIndexV3};
mod sstable_index_v2;
mod index;
pub use index::{BlockAddr, SSTableIndex, SSTableIndexBuilder};
pub(crate) mod vint;
pub use dictionary::{Dictionary, TermOrdHit};
pub use streamer::{Streamer, StreamerBuilder};

View File

@@ -27,7 +27,7 @@ rand = "0.9"
zipf = "7.0.0"
rustc-hash = "2.1.0"
proptest = "1.2.0"
binggan = { version = "0.16.1" }
binggan = { version = "0.17.0" }
rand_distr = "0.5"
[features]