Skip to main content

common_recordbatch/
lib.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(never_type)]
16
17pub mod adapter;
18pub mod cursor;
19pub mod error;
20pub mod ext;
21pub mod filter;
22pub mod recordbatch;
23pub mod util;
24
25use std::fmt;
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29
30use adapter::RecordBatchMetrics;
31use arc_swap::ArcSwapOption;
32use common_base::readable_size::ReadableSize;
33use common_error::ext::BoxedError;
34use common_memory_manager::{
35    MemoryGuard, MemoryManager, MemoryMetrics, OnExhaustedPolicy, PermitGranularity,
36};
37use common_telemetry::tracing::Span;
38pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
39use datatypes::arrow::array::{ArrayRef, AsArray, StringBuilder};
40use datatypes::arrow::compute::SortOptions;
41pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
42use datatypes::arrow::util::pretty;
43use datatypes::prelude::{ConcreteDataType, VectorRef};
44use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
45use datatypes::types::{JsonFormat, jsonb_to_string};
46use error::Result;
47use futures::task::{Context, Poll};
48use futures::{Stream, TryStreamExt};
49pub use recordbatch::RecordBatch;
50use snafu::{IntoError, ResultExt, ensure};
51
52use crate::error::NewDfRecordBatchSnafu;
53
54pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
55    fn name(&self) -> &str {
56        "RecordBatchStream"
57    }
58
59    fn schema(&self) -> SchemaRef;
60
61    fn output_ordering(&self) -> Option<&[OrderOption]>;
62
63    fn metrics(&self) -> Option<RecordBatchMetrics>;
64}
65
/// A pinned, boxed, `Send` [RecordBatchStream] trait object.
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;

/// One entry of a stream's output ordering: a name plus arrow [SortOptions].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderOption {
    /// Name the ordering applies to (presumably a column name — confirm against producers).
    pub name: String,
    /// Arrow sort options (direction and null placement) for this entry.
    pub options: SortOptions,
}
73
/// A wrapper that maps a [RecordBatchStream] to a new [RecordBatchStream] by applying a function to each [RecordBatch].
///
/// - When `apply_mapper` is set, every [RecordBatch] of the inner stream is passed through
///   `mapper`; otherwise batches are forwarded unchanged.
/// - The schema of the new stream is the inner stream's schema after applying the schema
///   mapper function given at construction.
/// - The output ordering and the metrics are forwarded from the inner stream unchanged.
pub struct SendableRecordBatchMapper {
    /// The wrapped stream producing the original batches.
    inner: SendableRecordBatchStream,
    /// Function applied to each [RecordBatch] in the stream.
    /// It receives the batch, the inner (original) schema, and the mapped schema.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// Schema exposed by this stream: the inner schema after applying the schema mapper function.
    schema: SchemaRef,
    /// Whether `mapper` is applied to each batch; when `false`, batches pass through untouched.
    apply_mapper: bool,
}
90
91/// Maps the json type to string in the batch.
92///
93/// The json type is mapped to string by converting the json value to string.
94/// The batch is updated to have the same number of columns as the original batch,
95/// but with the json type mapped to string.
96pub fn map_json_type_to_string(
97    batch: RecordBatch,
98    original_schema: &SchemaRef,
99    mapped_schema: &SchemaRef,
100) -> Result<RecordBatch> {
101    let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
102    for (vector, schema) in batch.columns().iter().zip(original_schema.column_schemas()) {
103        if let ConcreteDataType::Json(j) = &schema.data_type {
104            if matches!(&j.format, JsonFormat::Jsonb) {
105                let mut string_vector_builder = StringBuilder::new();
106                let binary_vector = vector.as_binary::<i32>();
107                for value in binary_vector.iter() {
108                    let Some(value) = value else {
109                        string_vector_builder.append_null();
110                        continue;
111                    };
112                    let string_value =
113                        jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
114                            from_type: schema.data_type.clone(),
115                            to_type: ConcreteDataType::string_datatype(),
116                        })?;
117                    string_vector_builder.append_value(string_value);
118                }
119
120                let string_vector = string_vector_builder.finish();
121                vectors.push(Arc::new(string_vector) as ArrayRef);
122            } else {
123                vectors.push(vector.clone());
124            }
125        } else {
126            vectors.push(vector.clone());
127        }
128    }
129
130    let record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
131        mapped_schema.arrow_schema().clone(),
132        vectors,
133    )
134    .context(NewDfRecordBatchSnafu)?;
135    Ok(RecordBatch::from_df_record_batch(
136        mapped_schema.clone(),
137        record_batch,
138    ))
139}
140
141/// Maps the json type to string in the schema.
142///
143/// The json type is mapped to string by converting the json value to string.
144/// The schema is updated to have the same number of columns as the original schema,
145/// but with the json type mapped to string.
146///
147/// Returns the new schema and whether the schema needs to be mapped to string.
148pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
149    let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
150    let mut apply_mapper = false;
151    for column in schema.column_schemas() {
152        if matches!(column.data_type, ConcreteDataType::Json(_)) {
153            new_columns.push(ColumnSchema::new(
154                column.name.clone(),
155                ConcreteDataType::string_datatype(),
156                column.is_nullable(),
157            ));
158            apply_mapper = true;
159        } else {
160            new_columns.push(column.clone());
161        }
162    }
163    (Arc::new(Schema::new(new_columns)), apply_mapper)
164}
165
166impl SendableRecordBatchMapper {
167    /// Creates a new [SendableRecordBatchMapper] with the given inner [RecordBatchStream], mapper function, and schema mapper function.
168    pub fn new(
169        inner: SendableRecordBatchStream,
170        mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
171        schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
172    ) -> Self {
173        let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
174        Self {
175            inner,
176            mapper,
177            schema: mapped_schema,
178            apply_mapper,
179        }
180    }
181}
182
183impl RecordBatchStream for SendableRecordBatchMapper {
184    fn name(&self) -> &str {
185        "SendableRecordBatchMapper"
186    }
187
188    fn schema(&self) -> SchemaRef {
189        self.schema.clone()
190    }
191
192    fn output_ordering(&self) -> Option<&[OrderOption]> {
193        self.inner.output_ordering()
194    }
195
196    fn metrics(&self) -> Option<RecordBatchMetrics> {
197        self.inner.metrics()
198    }
199}
200
201impl Stream for SendableRecordBatchMapper {
202    type Item = Result<RecordBatch>;
203
204    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
205        if self.apply_mapper {
206            Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
207                opt.map(|result| {
208                    result
209                        .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
210                })
211            })
212        } else {
213            Pin::new(&mut self.inner).poll_next(cx)
214        }
215    }
216}
217
218/// EmptyRecordBatchStream can be used to create a RecordBatchStream
219/// that will produce no results
220pub struct EmptyRecordBatchStream {
221    /// Schema wrapped by Arc
222    schema: SchemaRef,
223}
224
225impl EmptyRecordBatchStream {
226    /// Create an empty RecordBatchStream
227    pub fn new(schema: SchemaRef) -> Self {
228        Self { schema }
229    }
230}
231
232impl RecordBatchStream for EmptyRecordBatchStream {
233    fn schema(&self) -> SchemaRef {
234        self.schema.clone()
235    }
236
237    fn output_ordering(&self) -> Option<&[OrderOption]> {
238        None
239    }
240
241    fn metrics(&self) -> Option<RecordBatchMetrics> {
242        None
243    }
244}
245
246impl Stream for EmptyRecordBatchStream {
247    type Item = Result<RecordBatch>;
248
249    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
250        Poll::Ready(None)
251    }
252}
253
/// An in-memory, ordered collection of [RecordBatch]es together with their schema.
///
/// [RecordBatches::try_new] checks that every batch carries exactly `schema`.
#[derive(Debug, PartialEq)]
pub struct RecordBatches {
    /// Schema the batches are expected to share.
    schema: SchemaRef,
    /// The batches, in insertion order.
    batches: Vec<RecordBatch>,
}
259
260impl RecordBatches {
261    pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
262        schema: SchemaRef,
263        columns: I,
264    ) -> Result<Self> {
265        let batches = vec![RecordBatch::new(schema.clone(), columns)?];
266        Ok(Self { schema, batches })
267    }
268
269    pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
270        let schema = stream.schema();
271        let batches = stream.try_collect::<Vec<_>>().await?;
272        Ok(Self { schema, batches })
273    }
274
275    #[inline]
276    pub fn empty() -> Self {
277        Self {
278            schema: Arc::new(Schema::new(vec![])),
279            batches: vec![],
280        }
281    }
282
283    pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
284        self.batches.iter()
285    }
286
287    pub fn pretty_print(&self) -> Result<String> {
288        let df_batches = &self
289            .iter()
290            .map(|x| x.df_record_batch().clone())
291            .collect::<Vec<_>>();
292        let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;
293
294        Ok(result.to_string())
295    }
296
297    pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
298        for batch in &batches {
299            ensure!(
300                batch.schema == schema,
301                error::CreateRecordBatchesSnafu {
302                    reason: format!(
303                        "expect RecordBatch schema equals {:?}, actual: {:?}",
304                        schema, batch.schema
305                    )
306                }
307            )
308        }
309        Ok(Self { schema, batches })
310    }
311
312    pub fn schema(&self) -> SchemaRef {
313        self.schema.clone()
314    }
315
316    pub fn take(self) -> Vec<RecordBatch> {
317        self.batches
318    }
319
320    pub fn as_stream(&self) -> SendableRecordBatchStream {
321        Box::pin(SimpleRecordBatchStream {
322            inner: RecordBatches {
323                schema: self.schema(),
324                batches: self.batches.clone(),
325            },
326            index: 0,
327        })
328    }
329}
330
331impl IntoIterator for RecordBatches {
332    type Item = RecordBatch;
333    type IntoIter = std::vec::IntoIter<Self::Item>;
334
335    fn into_iter(self) -> Self::IntoIter {
336        self.batches.into_iter()
337    }
338}
339
340pub struct SimpleRecordBatchStream {
341    inner: RecordBatches,
342    index: usize,
343}
344
345impl RecordBatchStream for SimpleRecordBatchStream {
346    fn schema(&self) -> SchemaRef {
347        self.inner.schema()
348    }
349
350    fn output_ordering(&self) -> Option<&[OrderOption]> {
351        None
352    }
353
354    fn metrics(&self) -> Option<RecordBatchMetrics> {
355        None
356    }
357}
358
359impl Stream for SimpleRecordBatchStream {
360    type Item = Result<RecordBatch>;
361
362    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
363        Poll::Ready(if self.index < self.inner.batches.len() {
364            let batch = self.inner.batches[self.index].clone();
365            self.index += 1;
366            Some(Ok(batch))
367        } else {
368            None
369        })
370    }
371}
372
/// Adapt a [Stream] of [RecordBatch] to a [RecordBatchStream].
///
/// The wrapper reports the given schema, output ordering and metrics. It enters `span`
/// while polling the wrapped stream.
pub struct RecordBatchStreamWrapper<S> {
    /// Schema reported by [RecordBatchStream::schema].
    pub schema: SchemaRef,
    /// The wrapped stream that produces the batches.
    pub stream: S,
    /// Ordering reported by [RecordBatchStream::output_ordering]; `None` reports no ordering.
    pub output_ordering: Option<Vec<OrderOption>>,
    /// Shared, swappable metrics slot. [RecordBatchStream::metrics] returns a clone of the
    /// value currently stored, if any.
    pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
    /// Tracing span entered on every poll of `stream`.
    pub span: Span,
}
381
382impl<S> RecordBatchStreamWrapper<S> {
383    /// Creates a [RecordBatchStreamWrapper] without output ordering requirement.
384    pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
385        RecordBatchStreamWrapper {
386            schema,
387            stream,
388            output_ordering: None,
389            metrics: Default::default(),
390            span: Span::current(),
391        }
392    }
393}
394
395impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
396    for RecordBatchStreamWrapper<S>
397{
398    fn name(&self) -> &str {
399        "RecordBatchStreamWrapper"
400    }
401
402    fn schema(&self) -> SchemaRef {
403        self.schema.clone()
404    }
405
406    fn output_ordering(&self) -> Option<&[OrderOption]> {
407        self.output_ordering.as_deref()
408    }
409
410    fn metrics(&self) -> Option<RecordBatchMetrics> {
411        self.metrics.load().as_ref().map(|s| s.as_ref().clone())
412    }
413}
414
415impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
416    type Item = Result<RecordBatch>;
417
418    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
419        let _entered = self.span.clone().entered();
420        Pin::new(&mut self.stream).poll_next(ctx)
421    }
422}
423
424/// Memory tracker for RecordBatch streams. Clone to share the same limit across queries.
425///
426/// Each stream acquires quota independently from this tracker.
427#[derive(Clone)]
428pub struct QueryMemoryTracker {
429    manager: MemoryManager<CallbackMemoryMetrics>,
430    metrics: CallbackMemoryMetrics,
431    on_exhausted_policy: OnExhaustedPolicy,
432}
433
434impl fmt::Debug for QueryMemoryTracker {
435    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436        f.debug_struct("QueryMemoryTracker")
437            .field("current", &self.current())
438            .field("limit", &self.limit())
439            .field("on_exhausted_policy", &self.on_exhausted_policy)
440            .field("on_update", &self.metrics.has_on_update())
441            .field("on_exhausted", &self.metrics.has_on_exhausted())
442            .field("on_rejected", &self.metrics.has_on_rejected())
443            .finish()
444    }
445}
446
447impl QueryMemoryTracker {
448    /// Create a builder for a query memory tracker.
449    pub fn builder(
450        limit: usize,
451        on_exhausted_policy: OnExhaustedPolicy,
452    ) -> QueryMemoryTrackerBuilder {
453        QueryMemoryTrackerBuilder {
454            limit,
455            on_exhausted_policy,
456            on_update: None,
457            on_exhausted: None,
458            on_reject: None,
459        }
460    }
461
462    fn new_stream_tracker(&self) -> StreamMemoryTracker {
463        StreamMemoryTracker {
464            tracker: self.clone(),
465            guard: self.manager.try_acquire(0).unwrap(),
466            tracked_bytes: 0,
467        }
468    }
469    /// Get the current memory usage in bytes.
470    pub fn current(&self) -> usize {
471        self.manager.used_bytes() as usize
472    }
473
474    fn limit(&self) -> usize {
475        self.manager.limit_bytes() as usize
476    }
477
478    fn reject_error(
479        &self,
480        current: usize,
481        additional: usize,
482        stream_tracked: usize,
483    ) -> error::Error {
484        let limit = self.limit();
485        let msg = format!(
486            "{} requested, {} used globally ({}%), {} used by this stream, hard limit: {}",
487            ReadableSize(additional as u64),
488            ReadableSize(current as u64),
489            (current * 100).checked_div(limit).unwrap_or(0),
490            ReadableSize(stream_tracked as u64),
491            ReadableSize(limit as u64)
492        );
493        error::ExceedMemoryLimitSnafu { msg }.build()
494    }
495
496    fn inc_rejected(&self) {
497        self.metrics.inc_rejected();
498    }
499}
500
501/// Builder for constructing a [`QueryMemoryTracker`] with optional callbacks.
502pub struct QueryMemoryTrackerBuilder {
503    limit: usize,
504    on_exhausted_policy: OnExhaustedPolicy,
505    on_update: Option<UpdateCallback>,
506    on_exhausted: Option<UnitCallback>,
507    on_reject: Option<RejectCallback>,
508}
509
510impl QueryMemoryTrackerBuilder {
511    /// Set a callback to be called whenever the usage changes successfully.
512    /// The callback receives the new total usage in bytes.
513    ///
514    /// # Note
515    /// The callback is called after both successful `track()` and stream drop.
516    /// Usage is exact in unlimited mode and 1KB-aligned in limited mode.
517    pub fn on_update<F>(mut self, on_update: F) -> Self
518    where
519        F: Fn(usize) + Send + Sync + 'static,
520    {
521        self.on_update = Some(Arc::new(on_update));
522        self
523    }
524
525    /// Set a callback to be called when memory is unavailable for immediate acquisition.
526    ///
527    /// # Note
528    /// This is called when the non-blocking allocation fast path fails.
529    /// Requests using `OnExhaustedPolicy::Wait` may still succeed after waiting.
530    /// It is never called when `limit == 0` (unlimited mode).
531    pub fn on_exhausted<F>(mut self, on_exhausted: F) -> Self
532    where
533        F: Fn() + Send + Sync + 'static,
534    {
535        self.on_exhausted = Some(Arc::new(on_exhausted));
536        self
537    }
538
539    /// Set a callback to be called when the request ultimately fails due to memory pressure.
540    pub fn on_reject<F>(mut self, on_reject: F) -> Self
541    where
542        F: Fn() + Send + Sync + 'static,
543    {
544        self.on_reject = Some(Arc::new(on_reject));
545        self
546    }
547
548    /// Build a [`QueryMemoryTracker`] from this builder.
549    pub fn build(self) -> QueryMemoryTracker {
550        let metrics = CallbackMemoryMetrics::new(self.on_update, self.on_exhausted, self.on_reject);
551        let manager = MemoryManager::with_granularity(
552            self.limit as u64,
553            PermitGranularity::Kilobyte,
554            metrics.clone(),
555        );
556
557        QueryMemoryTracker {
558            manager,
559            metrics,
560            on_exhausted_policy: self.on_exhausted_policy,
561        }
562    }
563}
564
/// Per-stream view of a [QueryMemoryTracker]. It holds this stream's reservation and the
/// number of bytes the stream has accounted for.
struct StreamMemoryTracker {
    /// Shared tracker (limit, exhaustion policy and callbacks).
    tracker: QueryMemoryTracker,
    /// Reservation on the shared manager, grown as batches are tracked.
    /// NOTE(review): presumably returned to the manager on drop — confirm in `common_memory_manager`.
    guard: MemoryGuard<CallbackMemoryMetrics>,
    /// Sum of the requested sizes successfully tracked by this stream. It uses the raw byte
    /// counts, not the granularity-aligned amounts.
    tracked_bytes: usize,
}

/// Outcome of a policy-driven acquisition in [StreamMemoryTracker::track_with_policy].
type MemoryAcquireResult = std::result::Result<(), common_memory_manager::Error>;
572
573impl StreamMemoryTracker {
574    fn inc_rejected(&self) {
575        self.tracker.inc_rejected();
576    }
577
578    fn try_track(&mut self, additional: usize) -> Result<()> {
579        if self.guard.try_acquire_additional(additional as u64) {
580            self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
581            Ok(())
582        } else {
583            Err(self.reject_error(additional))
584        }
585    }
586
587    async fn track_with_policy(mut self, additional: usize) -> (Self, MemoryAcquireResult) {
588        let result = self
589            .guard
590            .acquire_additional_with_policy(additional as u64, self.tracker.on_exhausted_policy)
591            .await;
592        if result.is_ok() {
593            self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
594        }
595        (self, result)
596    }
597
598    fn reject_error(&self, additional: usize) -> error::Error {
599        let current = self.tracker.current();
600        self.tracker
601            .reject_error(current, additional, self.tracked_bytes)
602    }
603
604    fn wait_error(&self, additional: usize, source: common_memory_manager::Error) -> error::Error {
605        match source {
606            common_memory_manager::Error::MemoryLimitExceeded { .. } => {
607                self.reject_error(additional)
608            }
609            common_memory_manager::Error::MemoryAcquireTimeout { waited, .. } => {
610                let current = self.tracker.current();
611                let limit = self.tracker.limit();
612                let msg = format!(
613                    "timed out waiting {:?} for {}, {} used globally ({}%), {} used by this stream, hard limit: {}",
614                    waited,
615                    ReadableSize(additional as u64),
616                    ReadableSize(current as u64),
617                    (current * 100).checked_div(limit).unwrap_or(0),
618                    ReadableSize(self.tracked_bytes as u64),
619                    ReadableSize(limit as u64)
620                );
621                error::ExceedMemoryLimitSnafu { msg }.build()
622            }
623            error => error::ExternalSnafu.into_error(BoxedError::new(error)),
624        }
625    }
626}
627
/// Future driven while a stream waits for quota. It resolves to:
/// - the per-stream tracker (handed back for reuse),
/// - the batch being held,
/// - the number of bytes requested,
/// - the acquisition outcome.
type PendingTrackFuture = Pin<
    Box<dyn Future<Output = (StreamMemoryTracker, RecordBatch, usize, MemoryAcquireResult)> + Send>,
>;

/// [MemoryMetrics] sink that forwards memory-manager events to optional user callbacks.
///
/// Clones are cheap and share the same callbacks through an [Arc].
#[derive(Clone)]
struct CallbackMemoryMetrics {
    inner: Arc<CallbackMemoryMetricsInner>,
}

/// Callback receiving the new in-use byte count.
type UpdateCallback = Arc<dyn Fn(usize) + Send + Sync>;
/// Callback taking no arguments.
type UnitCallback = Arc<dyn Fn() + Send + Sync>;
/// Callback invoked when a request is finally rejected.
type RejectCallback = UnitCallback;

/// Shared storage for the optional callbacks of [CallbackMemoryMetrics].
struct CallbackMemoryMetricsInner {
    /// Called with the new usage whenever the manager updates its in-use bytes.
    on_update: Option<UpdateCallback>,
    /// Called when the manager reports exhaustion.
    on_exhausted: Option<UnitCallback>,
    /// Called when a request is rejected.
    on_reject: Option<RejectCallback>,
}
646
647impl CallbackMemoryMetrics {
648    fn new(
649        on_update: Option<UpdateCallback>,
650        on_exhausted: Option<UnitCallback>,
651        on_reject: Option<RejectCallback>,
652    ) -> Self {
653        Self {
654            inner: Arc::new(CallbackMemoryMetricsInner {
655                on_update,
656                on_exhausted,
657                on_reject,
658            }),
659        }
660    }
661
662    fn has_on_update(&self) -> bool {
663        self.inner.on_update.is_some()
664    }
665
666    fn has_on_exhausted(&self) -> bool {
667        self.inner.on_exhausted.is_some()
668    }
669
670    fn has_on_rejected(&self) -> bool {
671        self.inner.on_reject.is_some()
672    }
673
674    fn inc_rejected(&self) {
675        if let Some(callback) = &self.inner.on_reject {
676            callback();
677        }
678    }
679}
680
681impl MemoryMetrics for CallbackMemoryMetrics {
682    fn set_limit(&self, _: i64) {}
683
684    fn set_in_use(&self, bytes: i64) {
685        if let Some(callback) = &self.inner.on_update {
686            callback(bytes.max(0) as usize);
687        }
688    }
689
690    fn inc_exhausted(&self, _: &str) {
691        if let Some(callback) = &self.inner.on_exhausted {
692            callback();
693        }
694    }
695}
696
/// A wrapper stream that tracks memory usage of RecordBatches.
///
/// Every batch pulled from `inner` is charged, by its buffer memory size, against a shared
/// [QueryMemoryTracker]. The stream is always in one of two states:
/// - ready: `tracker` is `Some` and `waiting` is `None`;
/// - waiting: the tracker has been moved into the `waiting` future, so `tracker` is `None`.
pub struct MemoryTrackedStream {
    /// The stream whose batches are being accounted for.
    inner: SendableRecordBatchStream,
    /// Per-stream tracker; present only in the ready state.
    tracker: Option<StreamMemoryTracker>,
    // Waiting stores a batch that has already been pulled from the inner stream but has not yet
    // acquired additional quota. This keeps `poll_next()` non-blocking and allows bounded waits,
    // at the cost of temporarily holding one untracked batch per blocked stream in memory.
    waiting: Option<PendingTrackFuture>,
}
706
impl MemoryTrackedStream {
    /// Wraps `inner` so that each batch it yields is charged against `tracker`'s shared limit.
    ///
    /// The stream starts in the ready state, with a zero-byte per-stream tracker and no
    /// pending wait.
    pub fn new(inner: SendableRecordBatchStream, tracker: QueryMemoryTracker) -> Self {
        Self {
            inner,
            tracker: Some(tracker.new_stream_tracker()),
            waiting: None,
        }
    }

    /// Returns the per-stream tracker; only valid in the ready state.
    ///
    /// # Panics
    /// Panics if the tracker is currently owned by the waiting future. Debug builds also
    /// assert that no wait is in progress.
    fn ready_tracker_mut(&mut self) -> &mut StreamMemoryTracker {
        debug_assert!(
            self.waiting.is_none(),
            "a ready tracker must not coexist with a waiting future"
        );
        self.tracker.as_mut().unwrap()
    }

    /// Transitions from ready to waiting. The tracker and the not-yet-tracked `batch` are moved
    /// into a future that acquires `additional` bytes according to the exhaustion policy.
    ///
    /// # Panics
    /// Panics if the tracker is absent, i.e. if called outside the ready state.
    fn enter_waiting(&mut self, batch: RecordBatch, additional: usize) {
        debug_assert!(
            self.waiting.is_none(),
            "enter_waiting should only be called from the ready state"
        );
        debug_assert!(
            self.tracker.is_some(),
            "enter_waiting requires a tracker in the ready state"
        );
        let tracker = self.tracker.take().unwrap();
        self.waiting = Some(Self::start_waiting(tracker, batch, additional));
    }

    /// Builds the boxed future used while waiting for quota. It owns the tracker and the held
    /// batch, and hands both back together with the request size and the acquisition result.
    fn start_waiting(
        tracker: StreamMemoryTracker,
        batch: RecordBatch,
        additional: usize,
    ) -> PendingTrackFuture {
        Box::pin(async move {
            let (tracker, result) = tracker.track_with_policy(additional).await;
            (tracker, batch, additional, result)
        })
    }

    /// Polls the pending quota future.
    ///
    /// On completion the stream returns to the ready state. On success it yields the held
    /// batch. On failure it reports a rejection and yields the converted error, and the held
    /// batch is dropped.
    ///
    /// # Panics
    /// Panics if no wait is in progress.
    fn poll_waiting(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
        let future = self.waiting.as_mut().unwrap();
        match future.as_mut().poll(cx) {
            Poll::Ready((tracker, batch, additional, result)) => {
                let output = match result {
                    Ok(()) => Ok(batch),
                    Err(error) => {
                        tracker.inc_rejected();
                        Err(tracker.wait_error(additional, error))
                    }
                };
                // Back to the ready state: drop the finished future, then restore the tracker.
                self.waiting = None;
                self.tracker = Some(tracker);
                Poll::Ready(Some(output))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Charges `batch`, by its buffer memory size, against the quota and yields it.
    ///
    /// If the non-blocking attempt fails, behavior depends on the policy:
    /// - `Fail`: a rejection is reported and the error is returned (the batch is dropped).
    /// - `Wait`: the batch is parked in a waiting future, which is polled immediately and may
    ///   therefore complete within this same call.
    fn poll_batch(
        &mut self,
        batch: RecordBatch,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let additional = batch.buffer_memory_size();
        let tracker = self.ready_tracker_mut();

        // NOTE(review): on the `Wait` path the already-formatted `error` is discarded; building
        // it only for `Fail` would require changing `try_track`'s return type.
        if let Err(error) = tracker.try_track(additional) {
            match tracker.tracker.on_exhausted_policy {
                OnExhaustedPolicy::Fail => {
                    tracker.inc_rejected();
                    return Poll::Ready(Some(Err(error)));
                }
                // `Wait` is a deliberate tradeoff: the batch has already been materialized, so we
                // keep it in memory while waiting for quota instead of failing immediately. Under
                // contention, real memory usage can therefore exceed `scan_memory_limit` by up to
                // one buffered batch per blocked stream.
                OnExhaustedPolicy::Wait { .. } => {
                    self.enter_waiting(batch, additional);
                    return self.poll_waiting(cx);
                }
            }
        }

        Poll::Ready(Some(Ok(batch)))
    }
}
795
796impl Stream for MemoryTrackedStream {
797    type Item = Result<RecordBatch>;
798
799    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
800        if self.waiting.is_some() {
801            return self.poll_waiting(cx);
802        }
803
804        match Pin::new(&mut self.inner).poll_next(cx) {
805            Poll::Ready(Some(Ok(batch))) => self.poll_batch(batch, cx),
806            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
807            Poll::Ready(None) => Poll::Ready(None),
808            Poll::Pending => Poll::Pending,
809        }
810    }
811
812    fn size_hint(&self) -> (usize, Option<usize>) {
813        self.inner.size_hint()
814    }
815}
816
817impl RecordBatchStream for MemoryTrackedStream {
818    fn schema(&self) -> SchemaRef {
819        self.inner.schema()
820    }
821
822    fn output_ordering(&self) -> Option<&[OrderOption]> {
823        self.inner.output_ordering()
824    }
825
826    fn metrics(&self) -> Option<RecordBatchMetrics> {
827        self.inner.metrics()
828    }
829}
830
831#[cfg(test)]
832mod tests {
833    use std::sync::Arc;
834    use std::sync::atomic::{AtomicUsize, Ordering};
835    use std::time::Duration;
836
837    use common_memory_manager::{OnExhaustedPolicy, PermitGranularity};
838    use datatypes::prelude::{ConcreteDataType, VectorRef};
839    use datatypes::schema::{ColumnSchema, Schema};
840    use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};
841    use futures::StreamExt;
842    use tokio::time::{sleep, timeout};
843
844    use super::*;
845
846    fn large_string_batch(bytes: usize) -> RecordBatch {
847        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
848            "payload",
849            ConcreteDataType::string_datatype(),
850            false,
851        )]));
852        let payload = "x".repeat(bytes);
853        let vector: VectorRef = Arc::new(StringVector::from(vec![payload]));
854        RecordBatch::new(schema, vec![vector]).unwrap()
855    }
856
857    fn aligned_tracked_bytes(bytes: usize) -> usize {
858        PermitGranularity::Kilobyte
859            .permits_to_bytes(PermitGranularity::Kilobyte.bytes_to_permits(bytes as u64))
860            as usize
861    }
862
863    #[test]
864    fn test_recordbatches_try_from_columns() {
865        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
866            "a",
867            ConcreteDataType::int32_datatype(),
868            false,
869        )]));
870        let result = RecordBatches::try_from_columns(
871            schema.clone(),
872            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
873        );
874        assert!(result.is_err());
875
876        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
877        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
878        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
879        assert_eq!(r.take(), expected);
880    }
881
882    #[test]
883    fn test_recordbatches_try_new() {
884        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
885        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
886        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);
887
888        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
889        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
890        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));
891
892        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
893        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();
894
895        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
896        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();
897
898        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
899        assert!(result.is_err());
900        assert_eq!(
901            result.unwrap_err().to_string(),
902            format!(
903                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
904            )
905        );
906
907        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
908        let expected = "\
909+---+-------+
910| a | b     |
911+---+-------+
912| 1 | hello |
913| 2 | world |
914+---+-------+";
915        assert_eq!(batches.pretty_print().unwrap(), expected);
916
917        assert_eq!(schema1, batches.schema());
918        assert_eq!(vec![batch1], batches.take());
919    }
920
921    #[tokio::test]
922    async fn test_simple_recordbatch_stream() {
923        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
924        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
925        let schema = Arc::new(Schema::new(vec![column_a, column_b]));
926
927        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
928        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
929        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();
930
931        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
932        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
933        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();
934
935        let recordbatches =
936            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
937        let stream = recordbatches.as_stream();
938        let collected = util::collect(stream).await.unwrap();
939        assert_eq!(collected.len(), 2);
940        assert_eq!(collected[0], batch1);
941        assert_eq!(collected[1], batch2);
942    }
943
944    const MB: usize = 1024 * 1024;
945
946    #[test]
947    fn test_query_memory_tracker_basic() {
948        let tracker =
949            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
950
951        let mut stream1 = tracker.new_stream_tracker();
952        assert!(stream1.try_track(5 * MB).is_ok());
953        assert_eq!(tracker.current(), 5 * MB);
954
955        let mut stream2 = tracker.new_stream_tracker();
956        assert!(stream2.try_track(4 * MB).is_ok());
957        assert_eq!(tracker.current(), 9 * MB);
958
959        drop(stream1);
960        drop(stream2);
961        assert_eq!(tracker.current(), 0);
962    }
963
964    #[test]
965    fn test_query_memory_tracker_shared_global_limit() {
966        let tracker =
967            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
968        let mut stream1 = tracker.new_stream_tracker();
969        let mut stream2 = tracker.new_stream_tracker();
970
971        assert!(stream1.try_track(3 * MB).is_ok());
972        assert_eq!(tracker.current(), 3 * MB);
973        assert!(stream2.try_track(6 * MB).is_ok());
974        assert_eq!(tracker.current(), 9 * MB);
975
976        let err = stream2.try_track(2 * MB).unwrap_err();
977        let err_msg = err.to_string();
978        assert!(err_msg.contains("6.0MiB used by this stream"));
979        assert!(err_msg.contains("9.0MiB used globally (90%)"));
980        assert!(err_msg.contains("hard limit: 10.0MiB"));
981        assert_eq!(tracker.current(), 9 * MB);
982
983        drop(stream1);
984        assert_eq!(tracker.current(), 6 * MB);
985        drop(stream2);
986        assert_eq!(tracker.current(), 0);
987    }
988
989    #[test]
990    fn test_query_memory_tracker_hard_limit() {
991        let tracker =
992            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
993        let mut stream = tracker.new_stream_tracker();
994
995        assert!(stream.try_track(9 * MB).is_ok());
996        assert_eq!(tracker.current(), 9 * MB);
997
998        assert!(stream.try_track(2 * MB).is_err());
999        assert_eq!(tracker.current(), 9 * MB);
1000
1001        assert!(stream.try_track(MB).is_ok());
1002        assert_eq!(tracker.current(), 10 * MB);
1003
1004        assert!(stream.try_track(MB).is_err());
1005        assert_eq!(tracker.current(), 10 * MB);
1006
1007        drop(stream);
1008        assert_eq!(tracker.current(), 0);
1009    }
1010
1011    #[test]
1012    fn test_query_memory_tracker_unlimited() {
1013        let tracker = Arc::new(QueryMemoryTracker::builder(0, OnExhaustedPolicy::Fail).build());
1014        let mut stream = tracker.new_stream_tracker();
1015
1016        assert!(stream.try_track(10 * MB).is_ok());
1017        assert_eq!(tracker.current(), 10 * MB);
1018        drop(stream);
1019        assert_eq!(tracker.current(), 0);
1020    }
1021
1022    #[test]
1023    fn test_query_memory_tracker_rounds_to_kilobytes() {
1024        let tracker =
1025            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
1026        let mut stream = tracker.new_stream_tracker();
1027
1028        assert!(stream.try_track(1_537).is_ok());
1029        assert_eq!(tracker.current(), 2 * 1024);
1030
1031        drop(stream);
1032        assert_eq!(tracker.current(), 0);
1033    }
1034
1035    #[tokio::test]
1036    async fn test_memory_tracked_stream_waits_for_capacity() {
1037        let exhausted = Arc::new(AtomicUsize::new(0));
1038        let rejected = Arc::new(AtomicUsize::new(0));
1039        let exhausted_counter = exhausted.clone();
1040        let rejected_counter = rejected.clone();
1041        let tracker = QueryMemoryTracker::builder(
1042            MB,
1043            OnExhaustedPolicy::Wait {
1044                timeout: Duration::from_millis(200),
1045            },
1046        )
1047        .on_exhausted(move || {
1048            exhausted_counter.fetch_add(1, Ordering::Relaxed);
1049        })
1050        .on_reject(move || {
1051            rejected_counter.fetch_add(1, Ordering::Relaxed);
1052        })
1053        .build();
1054        let batch = large_string_batch(700 * 1024);
1055        let expected_bytes = aligned_tracked_bytes(batch.buffer_memory_size());
1056
1057        let mut stream1 = MemoryTrackedStream::new(
1058            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
1059                .unwrap()
1060                .as_stream(),
1061            tracker.clone(),
1062        );
1063        let first = stream1.next().await.unwrap().unwrap();
1064        assert_eq!(first.num_rows(), 1);
1065        assert_eq!(tracker.current(), expected_bytes);
1066
1067        let stream2 = MemoryTrackedStream::new(
1068            RecordBatches::try_new(batch.schema.clone(), vec![batch])
1069                .unwrap()
1070                .as_stream(),
1071            tracker.clone(),
1072        );
1073        let waiter = tokio::spawn(async move {
1074            let mut stream2 = stream2;
1075            stream2.next().await.unwrap()
1076        });
1077
1078        sleep(Duration::from_millis(50)).await;
1079        assert!(!waiter.is_finished());
1080
1081        drop(stream1);
1082        let second = waiter.await.unwrap().unwrap();
1083        assert_eq!(second.num_rows(), 1);
1084        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
1085        assert_eq!(rejected.load(Ordering::Relaxed), 0);
1086    }
1087
1088    #[tokio::test]
1089    async fn test_memory_tracked_stream_wait_times_out() {
1090        let exhausted = Arc::new(AtomicUsize::new(0));
1091        let rejected = Arc::new(AtomicUsize::new(0));
1092        let exhausted_counter = exhausted.clone();
1093        let rejected_counter = rejected.clone();
1094        let tracker = QueryMemoryTracker::builder(
1095            MB,
1096            OnExhaustedPolicy::Wait {
1097                timeout: Duration::from_millis(50),
1098            },
1099        )
1100        .on_exhausted(move || {
1101            exhausted_counter.fetch_add(1, Ordering::Relaxed);
1102        })
1103        .on_reject(move || {
1104            rejected_counter.fetch_add(1, Ordering::Relaxed);
1105        })
1106        .build();
1107        let batch = large_string_batch(700 * 1024);
1108
1109        let mut stream1 = MemoryTrackedStream::new(
1110            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
1111                .unwrap()
1112                .as_stream(),
1113            tracker.clone(),
1114        );
1115        let first = stream1.next().await.unwrap().unwrap();
1116        assert_eq!(first.num_rows(), 1);
1117
1118        let mut stream2 = MemoryTrackedStream::new(
1119            RecordBatches::try_new(batch.schema.clone(), vec![batch])
1120                .unwrap()
1121                .as_stream(),
1122            tracker,
1123        );
1124        let result = timeout(Duration::from_secs(1), stream2.next())
1125            .await
1126            .unwrap();
1127        let error = result.unwrap().unwrap_err();
1128        assert!(error.to_string().contains("timed out waiting"));
1129        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
1130        assert_eq!(rejected.load(Ordering::Relaxed), 1);
1131    }
1132
1133    #[tokio::test]
1134    async fn test_memory_tracked_stream_fail_policy_rejects_immediately() {
1135        let exhausted = Arc::new(AtomicUsize::new(0));
1136        let rejected = Arc::new(AtomicUsize::new(0));
1137        let exhausted_counter = exhausted.clone();
1138        let rejected_counter = rejected.clone();
1139        let tracker = QueryMemoryTracker::builder(MB, OnExhaustedPolicy::Fail)
1140            .on_exhausted(move || {
1141                exhausted_counter.fetch_add(1, Ordering::Relaxed);
1142            })
1143            .on_reject(move || {
1144                rejected_counter.fetch_add(1, Ordering::Relaxed);
1145            })
1146            .build();
1147        let batch = large_string_batch(700 * 1024);
1148
1149        let mut stream1 = MemoryTrackedStream::new(
1150            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
1151                .unwrap()
1152                .as_stream(),
1153            tracker.clone(),
1154        );
1155        let first = stream1.next().await.unwrap().unwrap();
1156        assert_eq!(first.num_rows(), 1);
1157
1158        let mut stream2 = MemoryTrackedStream::new(
1159            RecordBatches::try_new(batch.schema.clone(), vec![batch])
1160                .unwrap()
1161                .as_stream(),
1162            tracker,
1163        );
1164        let result = stream2.next().await.unwrap();
1165        assert!(result.is_err());
1166        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
1167        assert_eq!(rejected.load(Ordering::Relaxed), 1);
1168    }
1169}