// mito2/read/range_cache.rs
1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Utilities for the partition range scan result cache.
16
17use std::mem;
18use std::sync::Arc;
19
20use async_stream::try_stream;
21use common_time::range::TimestampRange;
22use datatypes::arrow::array::{Array, AsArray, DictionaryArray};
23use datatypes::arrow::datatypes::UInt32Type;
24use datatypes::arrow::record_batch::RecordBatch;
25use datatypes::prelude::ConcreteDataType;
26use futures::TryStreamExt;
27use store_api::region_engine::PartitionRange;
28use store_api::storage::{ColumnId, FileId, RegionId, TimeSeriesRowSelector};
29
30use crate::cache::CacheStrategy;
31use crate::read::BoxedRecordBatchStream;
32use crate::read::scan_region::StreamContext;
33use crate::read::scan_util::PartitionMetrics;
34use crate::region::options::MergeMode;
35use crate::sst::file::FileTimeRange;
36use crate::sst::parquet::flat_format::primary_key_column_index;
37
/// Fingerprint of the scan request fields that affect partition range cache reuse.
///
/// It records a normalized view of the projected columns and filters, plus
/// scan options that can change the returned rows. Schema-dependent metadata
/// and the partition expression version are included so cached results are not
/// reused across incompatible schema or partitioning changes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct ScanRequestFingerprint {
    /// Projection and filters without the time index and partition exprs.
    inner: Arc<SharedScanRequestFingerprint>,
    /// Filters with the time index column.
    time_filters: Option<Arc<Vec<String>>>,
    /// Optional per-series row selector; a scan option that changes the returned rows.
    series_row_selector: Option<TimeSeriesRowSelector>,
    /// Append mode flag; a scan option that changes the returned rows.
    append_mode: bool,
    /// Whether deleted rows are filtered out; a scan option that changes the returned rows.
    filter_deleted: bool,
    /// Merge mode of the region; a scan option that changes the returned rows.
    merge_mode: MergeMode,
    /// We keep the partition expr version to ensure we won't reuse the fingerprint after we change the partition expr.
    /// We store the version instead of the whole partition expr or partition expr filters.
    partition_expr_version: u64,
}
58
/// Mutable input for constructing a [`ScanRequestFingerprint`].
///
/// Callers fill in all fields and then call
/// [`ScanRequestFingerprintBuilder::build`] to obtain the immutable fingerprint.
#[derive(Debug)]
pub(crate) struct ScanRequestFingerprintBuilder {
    /// Column ids of the projection.
    pub(crate) read_column_ids: Vec<ColumnId>,
    /// Column types of the projection, used to detect schema changes.
    pub(crate) read_column_types: Vec<Option<ConcreteDataType>>,
    /// Stringified filters without the time index column and partition exprs.
    pub(crate) filters: Vec<String>,
    /// Stringified filters that reference the time index column.
    pub(crate) time_filters: Vec<String>,
    /// Optional per-series row selector of the scan.
    pub(crate) series_row_selector: Option<TimeSeriesRowSelector>,
    /// Append mode flag of the scan.
    pub(crate) append_mode: bool,
    /// Whether deleted rows are filtered out.
    pub(crate) filter_deleted: bool,
    /// Merge mode of the region.
    pub(crate) merge_mode: MergeMode,
    /// Version of the region's partition expression.
    pub(crate) partition_expr_version: u64,
}
71
72impl ScanRequestFingerprintBuilder {
73    pub(crate) fn build(self) -> ScanRequestFingerprint {
74        let Self {
75            read_column_ids,
76            read_column_types,
77            filters,
78            time_filters,
79            series_row_selector,
80            append_mode,
81            filter_deleted,
82            merge_mode,
83            partition_expr_version,
84        } = self;
85
86        ScanRequestFingerprint {
87            inner: Arc::new(SharedScanRequestFingerprint {
88                read_column_ids,
89                read_column_types,
90                filters,
91            }),
92            time_filters: (!time_filters.is_empty()).then(|| Arc::new(time_filters)),
93            series_row_selector,
94            append_mode,
95            filter_deleted,
96            merge_mode,
97            partition_expr_version,
98        }
99    }
100}
101
/// Non-copiable struct of the fingerprint.
///
/// Wrapped in an `Arc` by [`ScanRequestFingerprint`] so cloning a fingerprint
/// does not deep-copy these vectors.
#[derive(Debug, PartialEq, Eq, Hash)]
struct SharedScanRequestFingerprint {
    /// Column ids of the projection.
    read_column_ids: Vec<ColumnId>,
    /// Column types of the projection.
    /// We keep this to ensure we won't reuse the fingerprint after a schema change.
    read_column_types: Vec<Option<ConcreteDataType>>,
    /// Filters without the time index column and region partition exprs.
    filters: Vec<String>,
}
113
114impl ScanRequestFingerprint {
115    #[cfg(test)]
116    pub(crate) fn read_column_ids(&self) -> &[ColumnId] {
117        &self.inner.read_column_ids
118    }
119
120    #[cfg(test)]
121    pub(crate) fn read_column_types(&self) -> &[Option<ConcreteDataType>] {
122        &self.inner.read_column_types
123    }
124
125    #[cfg(test)]
126    pub(crate) fn filters(&self) -> &[String] {
127        &self.inner.filters
128    }
129
130    #[cfg(test)]
131    pub(crate) fn time_filters(&self) -> &[String] {
132        self.time_filters
133            .as_deref()
134            .map(Vec::as_slice)
135            .unwrap_or(&[])
136    }
137
138    pub(crate) fn without_time_filters(&self) -> Self {
139        Self {
140            inner: Arc::clone(&self.inner),
141            time_filters: None,
142            series_row_selector: self.series_row_selector,
143            append_mode: self.append_mode,
144            filter_deleted: self.filter_deleted,
145            merge_mode: self.merge_mode,
146            partition_expr_version: self.partition_expr_version,
147        }
148    }
149
150    pub(crate) fn estimated_size(&self) -> usize {
151        mem::size_of::<SharedScanRequestFingerprint>()
152            + self.inner.read_column_ids.capacity() * mem::size_of::<ColumnId>()
153            + self.inner.read_column_types.capacity() * mem::size_of::<Option<ConcreteDataType>>()
154            + self.inner.filters.capacity() * mem::size_of::<String>()
155            + self
156                .inner
157                .filters
158                .iter()
159                .map(|filter| filter.capacity())
160                .sum::<usize>()
161            + self.time_filters.as_ref().map_or(0, |filters| {
162                mem::size_of::<Vec<String>>()
163                    + filters.capacity() * mem::size_of::<String>()
164                    + filters
165                        .iter()
166                        .map(|filter| filter.capacity())
167                        .sum::<usize>()
168            })
169    }
170}
171
/// Cache key for range scan outputs.
///
/// Two scans hit the same cache entry only when they target the same region,
/// the same set of file row groups, and an equal scan fingerprint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct RangeScanCacheKey {
    /// Region the cached range belongs to.
    pub(crate) region_id: RegionId,
    /// Sorted (file_id, row_group_index) pairs that uniquely identify the data this range covers.
    pub(crate) row_groups: Vec<(FileId, i64)>,
    /// Fingerprint of the scan request.
    pub(crate) scan: ScanRequestFingerprint,
}
180
181impl RangeScanCacheKey {
182    pub(crate) fn estimated_size(&self) -> usize {
183        mem::size_of::<Self>()
184            + self.row_groups.capacity() * mem::size_of::<(FileId, i64)>()
185            + self.scan.estimated_size()
186    }
187}
188
/// Cached result for one range scan.
pub(crate) struct RangeScanCacheValue {
    /// Record batches produced by the range scan, in the order they were produced.
    pub(crate) batches: Vec<RecordBatch>,
    /// Precomputed size of all batches, accounting for shared dictionary values.
    estimated_batches_size: usize,
}
195
196impl RangeScanCacheValue {
197    pub(crate) fn new(batches: Vec<RecordBatch>, estimated_batches_size: usize) -> Self {
198        Self {
199            batches,
200            estimated_batches_size,
201        }
202    }
203
204    pub(crate) fn estimated_size(&self) -> usize {
205        mem::size_of::<Self>()
206            + self.batches.capacity() * mem::size_of::<RecordBatch>()
207            + self.estimated_batches_size
208    }
209}
210
/// Row groups and whether all sources are file-only for a partition range.
#[allow(dead_code)]
pub(crate) struct PartitionRangeRowGroups {
    /// Sorted (file_id, row_group_index) pairs.
    pub(crate) row_groups: Vec<(FileId, i64)>,
    /// True when every source in the range is a file range (no other source kinds).
    pub(crate) only_file_sources: bool,
}
218
219/// Collects (file_id, row_group_index) pairs from a partition range's row group indices.
220#[allow(dead_code)]
221pub(crate) fn collect_partition_range_row_groups(
222    stream_ctx: &StreamContext,
223    part_range: &PartitionRange,
224) -> PartitionRangeRowGroups {
225    let range_meta = &stream_ctx.ranges[part_range.identifier];
226    let mut row_groups = Vec::new();
227    let mut only_file_sources = true;
228
229    for index in &range_meta.row_group_indices {
230        if stream_ctx.is_file_range_index(*index) {
231            let file_id = stream_ctx.input.file_from_index(*index).file_id().file_id();
232            row_groups.push((file_id, index.row_group_index));
233        } else {
234            only_file_sources = false;
235        }
236    }
237
238    row_groups.sort_unstable_by(|a, b| a.0.as_bytes().cmp(b.0.as_bytes()).then(a.1.cmp(&b.1)));
239
240    PartitionRangeRowGroups {
241        row_groups,
242        only_file_sources,
243    }
244}
245
246/// Builds a cache key for the given partition range if it is eligible for caching.
247#[allow(dead_code)]
248pub(crate) fn build_range_cache_key(
249    stream_ctx: &StreamContext,
250    part_range: &PartitionRange,
251) -> Option<RangeScanCacheKey> {
252    let fingerprint = stream_ctx.scan_fingerprint.as_ref()?;
253
254    // Dyn filters can change at runtime, so we can't cache when they're present.
255    let has_dyn_filters = stream_ctx
256        .input
257        .predicate_group()
258        .predicate_without_region()
259        .is_some_and(|p| !p.dyn_filters().is_empty());
260    if has_dyn_filters {
261        return None;
262    }
263
264    let rg = collect_partition_range_row_groups(stream_ctx, part_range);
265    if !rg.only_file_sources || rg.row_groups.is_empty() {
266        return None;
267    }
268
269    let range_meta = &stream_ctx.ranges[part_range.identifier];
270    let scan = if query_time_range_covers_partition_range(
271        stream_ctx.input.time_range.as_ref(),
272        range_meta.time_range,
273    ) {
274        fingerprint.without_time_filters()
275    } else {
276        fingerprint.clone()
277    };
278
279    Some(RangeScanCacheKey {
280        region_id: stream_ctx.input.region_metadata().region_id,
281        row_groups: rg.row_groups,
282        scan,
283    })
284}
285
286#[allow(dead_code)]
287fn query_time_range_covers_partition_range(
288    query_time_range: Option<&TimestampRange>,
289    partition_time_range: FileTimeRange,
290) -> bool {
291    let Some(query_time_range) = query_time_range else {
292        return true;
293    };
294
295    let (part_start, part_end) = partition_time_range;
296    query_time_range.contains(&part_start) && query_time_range.contains(&part_end)
297}
298
299/// Returns a stream that replays cached record batches.
300#[allow(dead_code)]
301pub(crate) fn cached_flat_range_stream(value: Arc<RangeScanCacheValue>) -> BoxedRecordBatchStream {
302    Box::pin(futures::stream::iter(
303        value.batches.clone().into_iter().map(Ok),
304    ))
305}
306
307/// Returns true if two primary key dictionary arrays share the same underlying
308/// values buffers by pointer comparison.
309///
310/// The primary key column is always `DictionaryArray<UInt32Type>` with `Binary` values.
311fn pk_values_ptr_eq(a: &DictionaryArray<UInt32Type>, b: &DictionaryArray<UInt32Type>) -> bool {
312    let a = a.values().as_binary::<i32>();
313    let b = b.values().as_binary::<i32>();
314    let values_eq = a.values().ptr_eq(b.values()) && a.offsets().ptr_eq(b.offsets());
315    match (a.nulls(), b.nulls()) {
316        (Some(a), Some(b)) => values_eq && a.inner().ptr_eq(b.inner()),
317        (None, None) => values_eq,
318        _ => false,
319    }
320}
321
/// Buffers record batches for caching, tracking memory size while deduplicating
/// shared dictionary values across batches.
///
/// Uses the primary key column as a proxy to detect dictionary sharing: if the PK
/// column's dictionary values are pointer-equal across batches, we assume all
/// dictionary columns share their values and deduct the total dictionary values size.
struct CacheBatchBuffer {
    /// Buffered batches in the order they were pushed.
    batches: Vec<RecordBatch>,
    /// Running total of batch memory.
    total_size: usize,
    /// The first batch's PK dictionary array, for pointer comparison.
    /// `None` if no dictionary PK column exists or no batch has been added yet.
    first_pk_dict: Option<DictionaryArray<UInt32Type>>,
    /// Sum of `get_array_memory_size()` of all dictionary value arrays from the first batch.
    total_dict_values_size: usize,
    /// Whether the PK dictionary is still shared across all batches seen so far.
    shared: bool,
}
340
341impl CacheBatchBuffer {
342    fn new() -> Self {
343        Self {
344            batches: Vec::new(),
345            total_size: 0,
346            first_pk_dict: None,
347            total_dict_values_size: 0,
348            shared: true,
349        }
350    }
351
352    fn push(&mut self, batch: RecordBatch) {
353        if self.batches.is_empty() {
354            self.init_first_batch(&batch);
355        } else {
356            self.add_subsequent_batch(&batch);
357        }
358        self.batches.push(batch);
359    }
360
361    fn init_first_batch(&mut self, batch: &RecordBatch) {
362        self.total_size += batch.get_array_memory_size();
363
364        let pk_col_idx = primary_key_column_index(batch.num_columns());
365        let mut total_dict_values_size = 0;
366        for col_idx in 0..batch.num_columns() {
367            let col = batch.column(col_idx);
368            if let Some(dict) = col.as_any().downcast_ref::<DictionaryArray<UInt32Type>>() {
369                total_dict_values_size += dict.values().get_array_memory_size();
370                if col_idx == pk_col_idx {
371                    self.first_pk_dict = Some(dict.clone());
372                }
373            }
374        }
375        self.total_dict_values_size = total_dict_values_size;
376    }
377
378    fn add_subsequent_batch(&mut self, batch: &RecordBatch) {
379        let batch_size = batch.get_array_memory_size();
380
381        if self.shared
382            && let Some(first_pk_dict) = &self.first_pk_dict
383        {
384            let pk_col_idx = primary_key_column_index(batch.num_columns());
385            let col = batch.column(pk_col_idx);
386            if let Some(dict) = col.as_any().downcast_ref::<DictionaryArray<UInt32Type>>()
387                && pk_values_ptr_eq(first_pk_dict, dict)
388            {
389                // PK dict is shared, deduct all dict values sizes.
390                self.total_size += batch_size - self.total_dict_values_size;
391                return;
392            }
393            // Dictionary diverged.
394            self.shared = false;
395        }
396
397        self.total_size += batch_size;
398    }
399
400    fn estimated_batches_size(&self) -> usize {
401        self.total_size
402    }
403
404    fn into_batches(self) -> Vec<RecordBatch> {
405        self.batches
406    }
407}
408
409/// Wraps a stream to cache its output for future range cache hits.
410#[allow(dead_code)]
411pub(crate) fn cache_flat_range_stream(
412    mut stream: BoxedRecordBatchStream,
413    cache_strategy: CacheStrategy,
414    key: RangeScanCacheKey,
415    part_metrics: PartitionMetrics,
416) -> BoxedRecordBatchStream {
417    Box::pin(try_stream! {
418        let mut buffer = CacheBatchBuffer::new();
419        while let Some(batch) = stream.try_next().await? {
420            buffer.push(batch.clone());
421            yield batch;
422        }
423
424        let estimated_size = buffer.estimated_batches_size();
425        let batches = buffer.into_batches();
426        let value = Arc::new(RangeScanCacheValue::new(batches, estimated_size));
427        part_metrics.inc_range_cache_size(key.estimated_size() + value.estimated_size());
428        cache_strategy.put_range_result(key, value);
429    })
430}
431
/// Creates a `cache_flat_range_stream` with dummy internals for benchmarking.
///
/// This avoids exposing `RangeScanCacheKey`, `ScanRequestFingerprint`, and
/// `PartitionMetrics` publicly.
#[cfg(feature = "test")]
pub fn bench_cache_flat_range_stream(
    stream: BoxedRecordBatchStream,
    cache_size_bytes: u64,
    region_id: RegionId,
) -> BoxedRecordBatchStream {
    use std::time::Instant;

    use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;

    use crate::region::options::MergeMode;

    // Dummy fingerprint: projection/filter contents don't matter for the benchmark.
    let builder = ScanRequestFingerprintBuilder {
        read_column_ids: Vec::new(),
        read_column_types: Vec::new(),
        filters: Vec::new(),
        time_filters: Vec::new(),
        series_row_selector: None,
        append_mode: false,
        filter_deleted: false,
        merge_mode: MergeMode::LastRow,
        partition_expr_version: 0,
    };
    let key = RangeScanCacheKey {
        region_id,
        row_groups: Vec::new(),
        scan: builder.build(),
    };

    let cache_manager = crate::cache::CacheManager::builder()
        .range_result_cache_size(cache_size_bytes)
        .build();
    let cache_strategy = CacheStrategy::EnableAll(Arc::new(cache_manager));

    let metrics_set = ExecutionPlanMetricsSet::new();
    let part_metrics =
        PartitionMetrics::new(region_id, 0, "bench", Instant::now(), false, &metrics_set);

    cache_flat_range_stream(stream, cache_strategy, key, part_metrics)
}
480
481#[cfg(test)]
482mod tests {
483    use std::sync::Arc;
484    use std::time::Instant;
485
486    use common_time::Timestamp;
487    use common_time::range::TimestampRange;
488    use common_time::timestamp::TimeUnit;
489    use datafusion_common::ScalarValue;
490    use datafusion_expr::{Expr, col, lit};
491    use smallvec::smallvec;
492    use store_api::storage::FileId;
493
494    use super::*;
495    use crate::cache::CacheManager;
496    use crate::read::projection::ProjectionMapper;
497    use crate::read::range::{RangeMeta, RowGroupIndex, SourceIndex};
498    use crate::read::scan_region::{PredicateGroup, ScanInput};
499    use crate::test_util::memtable_util::metadata_with_primary_key;
500    use crate::test_util::scheduler_util::SchedulerEnv;
501    use crate::test_util::sst_util::sst_file_handle_with_file_id;
502
    /// Builds a cache strategy backed by a small (1 KiB) range result cache.
    fn test_cache_strategy() -> CacheStrategy {
        CacheStrategy::EnableAll(Arc::new(
            CacheManager::builder()
                .range_result_cache_size(1024)
                .build(),
        ))
    }
510
    /// Builds a `StreamContext` with one SST file covering `partition_time_range`,
    /// the given filters, and an optional query time range, plus the matching
    /// `PartitionRange` for its single range.
    async fn new_stream_context(
        filters: Vec<Expr>,
        query_time_range: Option<TimestampRange>,
        partition_time_range: FileTimeRange,
    ) -> (StreamContext, PartitionRange) {
        let env = SchedulerEnv::new().await;
        let metadata = Arc::new(metadata_with_primary_key(vec![0, 1], false));
        let mapper = ProjectionMapper::new(&metadata, [0, 2, 3].into_iter()).unwrap();
        let predicate = PredicateGroup::new(metadata.as_ref(), &filters).unwrap();
        let file_id = FileId::random();
        let file = sst_file_handle_with_file_id(
            file_id,
            partition_time_range.0.value(),
            partition_time_range.1.value(),
        );
        let input = ScanInput::new(env.access_layer.clone(), mapper)
            .with_predicate(predicate)
            .with_time_range(query_time_range)
            .with_files(vec![file])
            .with_cache(test_cache_strategy());
        // Single range with one file source and one row group.
        let range_meta = RangeMeta {
            time_range: partition_time_range,
            indices: smallvec![SourceIndex {
                index: 0,
                num_row_groups: 1,
            }],
            row_group_indices: smallvec![RowGroupIndex {
                index: 0,
                row_group_index: 0,
            }],
            num_rows: 10,
        };
        let partition_range = range_meta.new_partition_range(0);
        let scan_fingerprint = crate::read::scan_region::build_scan_fingerprint(&input);
        let stream_ctx = StreamContext {
            input,
            ranges: vec![range_meta],
            scan_fingerprint,
            query_start: Instant::now(),
        };

        (stream_ctx, partition_range)
    }
554
555    /// Helper to create a timestamp millisecond literal.
556    fn ts_lit(val: i64) -> Expr {
557        lit(ScalarValue::TimestampMillisecond(Some(val), None))
558    }
559
    #[tokio::test]
    async fn strips_time_only_filters_when_query_covers_partition_range() {
        // Query time range (1000..2002 ms) covers the partition endpoints
        // (1000, 2000), so range-reducible `ts` filters should be stripped.
        let (stream_ctx, part_range) = new_stream_context(
            vec![
                col("ts").gt_eq(ts_lit(1000)),
                col("ts").lt(ts_lit(2001)),
                col("ts").is_not_null(),
                col("k0").eq(lit("foo")),
            ],
            TimestampRange::with_unit(1000, 2002, TimeUnit::Millisecond),
            (
                Timestamp::new_millisecond(1000),
                Timestamp::new_millisecond(2000),
            ),
        )
        .await;

        let key = build_range_cache_key(&stream_ctx, &part_range).unwrap();

        // Range-reducible time filters should be cleared when query covers partition range.
        assert!(key.scan.time_filters().is_empty());
        // Non-range time predicates stay in filters.
        let mut expected_filters = [
            col("k0").eq(lit("foo")).to_string(),
            col("ts").is_not_null().to_string(),
        ];
        expected_filters.sort_unstable();
        assert_eq!(key.scan.filters(), expected_filters.as_slice());
    }
589
    #[tokio::test]
    async fn preserves_time_filters_when_query_does_not_cover_partition_range() {
        // Query range ends at 1500 ms, before the partition end (2000 ms),
        // so the time filter still restricts the result and must stay in the key.
        let (stream_ctx, part_range) = new_stream_context(
            vec![col("ts").gt_eq(ts_lit(1000)), col("k0").eq(lit("foo"))],
            TimestampRange::with_unit(1000, 1500, TimeUnit::Millisecond),
            (
                Timestamp::new_millisecond(1000),
                Timestamp::new_millisecond(2000),
            ),
        )
        .await;

        let key = build_range_cache_key(&stream_ctx, &part_range).unwrap();

        // Time filters should be preserved when query does not cover partition range.
        assert_eq!(
            key.scan.time_filters(),
            [col("ts").gt_eq(ts_lit(1000)).to_string()].as_slice()
        );
        assert_eq!(
            key.scan.filters(),
            [col("k0").eq(lit("foo")).to_string()].as_slice()
        );
    }
614
    #[tokio::test]
    async fn strips_time_only_filters_when_query_has_no_time_range_limit() {
        // `None` query time range trivially covers any partition range.
        let (stream_ctx, part_range) = new_stream_context(
            vec![
                col("ts").gt_eq(ts_lit(1000)),
                col("ts").is_not_null(),
                col("k0").eq(lit("foo")),
            ],
            None,
            (
                Timestamp::new_millisecond(1000),
                Timestamp::new_millisecond(2000),
            ),
        )
        .await;

        let key = build_range_cache_key(&stream_ctx, &part_range).unwrap();

        // Range-reducible time filters should be cleared when query has no time range limit.
        assert!(key.scan.time_filters().is_empty());
        // Non-range time predicates stay in filters.
        let mut expected_filters = [
            col("k0").eq(lit("foo")).to_string(),
            col("ts").is_not_null().to_string(),
        ];
        expected_filters.sort_unstable();
        assert_eq!(key.scan.filters(), expected_filters.as_slice());
    }
643
    #[test]
    fn normalizes_and_clears_time_filters() {
        // Building with an empty time filter list normalizes it to "no time filters".
        let normalized = ScanRequestFingerprintBuilder {
            read_column_ids: vec![1, 2],
            read_column_types: vec![None, None],
            filters: vec!["k0 = 'foo'".to_string()],
            time_filters: vec![],
            series_row_selector: None,
            append_mode: false,
            filter_deleted: true,
            merge_mode: MergeMode::LastRow,
            partition_expr_version: 0,
        }
        .build();

        assert!(normalized.time_filters().is_empty());

        let fingerprint = ScanRequestFingerprintBuilder {
            read_column_ids: vec![1, 2],
            read_column_types: vec![None, None],
            filters: vec!["k0 = 'foo'".to_string()],
            time_filters: vec!["ts >= 1000".to_string()],
            series_row_selector: Some(TimeSeriesRowSelector::LastRow),
            append_mode: false,
            filter_deleted: true,
            merge_mode: MergeMode::LastRow,
            partition_expr_version: 7,
        }
        .build();

        // `without_time_filters` must clear only the time filters and keep
        // every other fingerprint field intact.
        let reset = fingerprint.without_time_filters();

        assert_eq!(reset.read_column_ids(), fingerprint.read_column_ids());
        assert_eq!(reset.read_column_types(), fingerprint.read_column_types());
        assert_eq!(reset.filters(), fingerprint.filters());
        assert!(reset.time_filters().is_empty());
        assert_eq!(reset.series_row_selector, fingerprint.series_row_selector);
        assert_eq!(reset.append_mode, fingerprint.append_mode);
        assert_eq!(reset.filter_deleted, fingerprint.filter_deleted);
        assert_eq!(reset.merge_mode, fingerprint.merge_mode);
        assert_eq!(
            reset.partition_expr_version,
            fingerprint.partition_expr_version
        );
    }
689
690    /// Creates a test schema with 5 columns where the primary key dictionary column
691    /// is at index 2 (`num_columns - 3`), matching the flat format layout.
692    ///
693    /// Layout: `[field0: Int64, field1: Int64, pk: Dictionary<UInt32,Binary>, ts: Int64, seq: Int64]`
694    fn dict_test_schema() -> Arc<datatypes::arrow::datatypes::Schema> {
695        use datatypes::arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
696        Arc::new(Schema::new(vec![
697            Field::new("field0", ArrowDataType::Int64, false),
698            Field::new("field1", ArrowDataType::Int64, false),
699            Field::new(
700                "pk",
701                ArrowDataType::Dictionary(
702                    Box::new(ArrowDataType::UInt32),
703                    Box::new(ArrowDataType::Binary),
704                ),
705                false,
706            ),
707            Field::new("ts", ArrowDataType::Int64, false),
708            Field::new("seq", ArrowDataType::Int64, false),
709        ]))
710    }
711
712    /// Helper to create a record batch with a dictionary column at the primary key position.
713    fn make_dict_batch(
714        schema: Arc<datatypes::arrow::datatypes::Schema>,
715        dict_values: &datatypes::arrow::array::BinaryArray,
716        keys: &[u32],
717        int_values: &[i64],
718    ) -> RecordBatch {
719        use datatypes::arrow::array::{Int64Array, UInt32Array};
720
721        let key_array = UInt32Array::from(keys.to_vec());
722        let dict_array: DictionaryArray<UInt32Type> =
723            DictionaryArray::new(key_array, Arc::new(dict_values.clone()));
724        let int_array = Int64Array::from(int_values.to_vec());
725        let zeros = Int64Array::from(vec![0i64; int_values.len()]);
726        RecordBatch::try_new(
727            schema,
728            vec![
729                Arc::new(zeros.clone()),
730                Arc::new(int_array),
731                Arc::new(dict_array),
732                Arc::new(zeros.clone()),
733                Arc::new(zeros),
734            ],
735        )
736        .unwrap()
737    }
738
739    /// Computes the total `get_array_memory_size()` of all dictionary value arrays in a batch.
740    fn compute_total_dict_values_size(batch: &RecordBatch) -> usize {
741        batch
742            .columns()
743            .iter()
744            .filter_map(|col| {
745                col.as_any()
746                    .downcast_ref::<DictionaryArray<UInt32Type>>()
747                    .map(|dict| dict.values().get_array_memory_size())
748            })
749            .sum()
750    }
751
    #[test]
    fn cache_batch_buffer_empty() {
        // A fresh buffer reports zero size and holds no batches.
        let buffer = CacheBatchBuffer::new();
        assert_eq!(buffer.estimated_batches_size(), 0);
        assert!(buffer.into_batches().is_empty());
    }
758
    #[test]
    fn cache_batch_buffer_single_batch() {
        use datatypes::arrow::array::BinaryArray;

        let schema = dict_test_schema();
        let dict_values = BinaryArray::from_vec(vec![b"a", b"b", b"c"]);
        let batch = make_dict_batch(schema, &dict_values, &[0, 1, 2], &[10, 20, 30]);

        let full_size = batch.get_array_memory_size();

        // The first batch is always counted at full size (no dedup possible yet).
        let mut buffer = CacheBatchBuffer::new();
        buffer.push(batch);
        assert_eq!(buffer.estimated_batches_size(), full_size);
        assert_eq!(buffer.into_batches().len(), 1);
    }
774
    #[test]
    fn cache_batch_buffer_shared_dictionary() {
        use datatypes::arrow::array::BinaryArray;

        let schema = dict_test_schema();
        let dict_values = BinaryArray::from_vec(vec![b"alpha", b"beta", b"gamma"]);

        // Two batches sharing the same dictionary values array.
        let batch1 = make_dict_batch(schema.clone(), &dict_values, &[0, 1], &[10, 20]);
        let batch2 = make_dict_batch(schema, &dict_values, &[1, 2], &[30, 40]);

        let batch1_full = batch1.get_array_memory_size();
        let batch2_full = batch2.get_array_memory_size();

        // The total dictionary values size that should be deduplicated for the second batch.
        let dict_values_size = compute_total_dict_values_size(&batch2);

        // PK column dictionaries are pointer-equal, so dedup kicks in.
        let mut buffer = CacheBatchBuffer::new();
        buffer.push(batch1);
        buffer.push(batch2);

        // Second batch's dict values should not be counted again.
        assert_eq!(
            buffer.estimated_batches_size(),
            batch1_full + batch2_full - dict_values_size
        );
        assert_eq!(buffer.into_batches().len(), 2);
    }
803
    #[test]
    fn cache_batch_buffer_non_shared_dictionary() {
        use datatypes::arrow::array::BinaryArray;

        let schema = dict_test_schema();
        let dict_values1 = BinaryArray::from_vec(vec![b"a", b"b"]);
        let dict_values2 = BinaryArray::from_vec(vec![b"x", b"y"]);

        // Each batch carries its own dictionary values array, so no dedup applies.
        let batch1 = make_dict_batch(schema.clone(), &dict_values1, &[0, 1], &[10, 20]);
        let batch2 = make_dict_batch(schema, &dict_values2, &[0, 1], &[30, 40]);

        let batch1_full = batch1.get_array_memory_size();
        let batch2_full = batch2.get_array_memory_size();

        let mut buffer = CacheBatchBuffer::new();
        buffer.push(batch1);
        buffer.push(batch2);

        // Different dictionaries: full size for both.
        assert_eq!(buffer.estimated_batches_size(), batch1_full + batch2_full);
    }
825
    #[test]
    fn cache_batch_buffer_shared_then_diverged() {
        use datatypes::arrow::array::BinaryArray;

        let schema = dict_test_schema();
        let shared_values = BinaryArray::from_vec(vec![b"a", b"b", b"c"]);
        let different_values = BinaryArray::from_vec(vec![b"x", b"y"]);

        // Batches 1 and 2 share dictionary values; batch 3 diverges.
        let batch1 = make_dict_batch(schema.clone(), &shared_values, &[0], &[1]);
        let batch2 = make_dict_batch(schema.clone(), &shared_values, &[1], &[2]);
        let batch3 = make_dict_batch(schema, &different_values, &[0], &[3]);

        let size1 = batch1.get_array_memory_size();
        let size2 = batch2.get_array_memory_size();
        let size3 = batch3.get_array_memory_size();

        let dict_values_size = compute_total_dict_values_size(&batch2);

        let mut buffer = CacheBatchBuffer::new();
        buffer.push(batch1);
        buffer.push(batch2);
        buffer.push(batch3);

        // batch2 shares dict with batch1 (dedup), batch3 does not (full size).
        assert_eq!(
            buffer.estimated_batches_size(),
            size1 + (size2 - dict_values_size) + size3
        );
    }
855}