common_recordbatch/
lib.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(never_type)]
16
17pub mod adapter;
18pub mod cursor;
19pub mod error;
20pub mod filter;
21pub mod recordbatch;
22pub mod util;
23
24use std::fmt;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28
29use adapter::RecordBatchMetrics;
30use arc_swap::ArcSwapOption;
31use common_base::readable_size::ReadableSize;
32use common_error::ext::BoxedError;
33use common_memory_manager::{
34    MemoryGuard, MemoryManager, MemoryMetrics, OnExhaustedPolicy, PermitGranularity,
35};
36use common_telemetry::tracing::Span;
37pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
38use datatypes::arrow::array::{ArrayRef, AsArray, StringBuilder};
39use datatypes::arrow::compute::SortOptions;
40pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
41use datatypes::arrow::util::pretty;
42use datatypes::prelude::{ConcreteDataType, VectorRef};
43use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
44use datatypes::types::{JsonFormat, jsonb_to_string};
45use error::Result;
46use futures::task::{Context, Poll};
47use futures::{Stream, TryStreamExt};
48pub use recordbatch::RecordBatch;
49use snafu::{IntoError, ResultExt, ensure};
50
51use crate::error::NewDfRecordBatchSnafu;
52
/// A [Stream] of [RecordBatch] results that additionally exposes metadata
/// (name, schema, output ordering and metrics) about itself.
pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
    /// Returns a human-readable name for this stream implementation.
    ///
    /// Defaults to `"RecordBatchStream"`; implementations may override it.
    fn name(&self) -> &str {
        "RecordBatchStream"
    }

    /// Returns the schema reported by this stream.
    fn schema(&self) -> SchemaRef;

    /// Returns the output ordering of this stream, or `None` if it reports none.
    fn output_ordering(&self) -> Option<&[OrderOption]>;

    /// Returns the metrics of this stream, or `None` if it reports none.
    fn metrics(&self) -> Option<RecordBatchMetrics>;
}
64
/// A pinned, boxed [RecordBatchStream] trait object that is `Send`.
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
66
/// One entry of a stream's output ordering: a name paired with arrow [SortOptions].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderOption {
    /// Name this entry refers to (presumably a column name; confirm against callers).
    pub name: String,
    /// Arrow sort options applied to this entry.
    pub options: SortOptions,
}
72
/// A wrapper that maps a [RecordBatchStream] to a new [RecordBatchStream] by applying a function to each [RecordBatch].
///
/// The schema of the wrapper is the inner stream's schema after applying the schema mapper function.
/// The output ordering and the metrics of the wrapper are those of the inner stream.
pub struct SendableRecordBatchMapper {
    /// The wrapped stream that produces the batches to map.
    inner: SendableRecordBatchStream,
    /// The mapper function applied to each [RecordBatch] in the stream.
    /// It receives the batch, the original (inner) schema and the mapped schema, in that order.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// The mapped schema, i.e. the inner stream's schema after the schema mapper function was applied.
    schema: SchemaRef,
    /// Whether `mapper` has to be applied; when `false`, batches are passed through untouched.
    apply_mapper: bool,
}
89
90/// Maps the json type to string in the batch.
91///
92/// The json type is mapped to string by converting the json value to string.
93/// The batch is updated to have the same number of columns as the original batch,
94/// but with the json type mapped to string.
95pub fn map_json_type_to_string(
96    batch: RecordBatch,
97    original_schema: &SchemaRef,
98    mapped_schema: &SchemaRef,
99) -> Result<RecordBatch> {
100    let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
101    for (vector, schema) in batch.columns().iter().zip(original_schema.column_schemas()) {
102        if let ConcreteDataType::Json(j) = &schema.data_type {
103            if matches!(&j.format, JsonFormat::Jsonb) {
104                let mut string_vector_builder = StringBuilder::new();
105                let binary_vector = vector.as_binary::<i32>();
106                for value in binary_vector.iter() {
107                    let Some(value) = value else {
108                        string_vector_builder.append_null();
109                        continue;
110                    };
111                    let string_value =
112                        jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
113                            from_type: schema.data_type.clone(),
114                            to_type: ConcreteDataType::string_datatype(),
115                        })?;
116                    string_vector_builder.append_value(string_value);
117                }
118
119                let string_vector = string_vector_builder.finish();
120                vectors.push(Arc::new(string_vector) as ArrayRef);
121            } else {
122                vectors.push(vector.clone());
123            }
124        } else {
125            vectors.push(vector.clone());
126        }
127    }
128
129    let record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
130        mapped_schema.arrow_schema().clone(),
131        vectors,
132    )
133    .context(NewDfRecordBatchSnafu)?;
134    Ok(RecordBatch::from_df_record_batch(
135        mapped_schema.clone(),
136        record_batch,
137    ))
138}
139
140/// Maps the json type to string in the schema.
141///
142/// The json type is mapped to string by converting the json value to string.
143/// The schema is updated to have the same number of columns as the original schema,
144/// but with the json type mapped to string.
145///
146/// Returns the new schema and whether the schema needs to be mapped to string.
147pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
148    let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
149    let mut apply_mapper = false;
150    for column in schema.column_schemas() {
151        if matches!(column.data_type, ConcreteDataType::Json(_)) {
152            new_columns.push(ColumnSchema::new(
153                column.name.clone(),
154                ConcreteDataType::string_datatype(),
155                column.is_nullable(),
156            ));
157            apply_mapper = true;
158        } else {
159            new_columns.push(column.clone());
160        }
161    }
162    (Arc::new(Schema::new(new_columns)), apply_mapper)
163}
164
165impl SendableRecordBatchMapper {
166    /// Creates a new [SendableRecordBatchMapper] with the given inner [RecordBatchStream], mapper function, and schema mapper function.
167    pub fn new(
168        inner: SendableRecordBatchStream,
169        mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
170        schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
171    ) -> Self {
172        let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
173        Self {
174            inner,
175            mapper,
176            schema: mapped_schema,
177            apply_mapper,
178        }
179    }
180}
181
/// Stream metadata for the mapper: it reports its own (mapped) schema and
/// forwards ordering and metrics to the inner stream.
impl RecordBatchStream for SendableRecordBatchMapper {
    fn name(&self) -> &str {
        "SendableRecordBatchMapper"
    }

    /// Returns the mapped schema computed at construction time.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Forwards to the inner stream.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    /// Forwards to the inner stream.
    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}
199
200impl Stream for SendableRecordBatchMapper {
201    type Item = Result<RecordBatch>;
202
203    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204        if self.apply_mapper {
205            Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
206                opt.map(|result| {
207                    result
208                        .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
209                })
210            })
211        } else {
212            Pin::new(&mut self.inner).poll_next(cx)
213        }
214    }
215}
216
/// EmptyRecordBatchStream can be used to create a RecordBatchStream
/// that will produce no results. It still reports the schema it was created with.
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}
223
impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream that reports `schema` as its schema.
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}
230
/// Stream metadata for the empty stream: only the schema is available.
impl RecordBatchStream for EmptyRecordBatchStream {
    /// Returns the schema given at construction time.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Always `None`.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    /// Always `None`.
    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}
244
impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Immediately reports end-of-stream; no batch is ever yielded.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}
252
/// An in-memory collection of [RecordBatch]es together with a schema.
#[derive(Debug, PartialEq)]
pub struct RecordBatches {
    /// Schema associated with the collection.
    schema: SchemaRef,
    /// The collected batches, in order.
    batches: Vec<RecordBatch>,
}
258
259impl RecordBatches {
260    pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
261        schema: SchemaRef,
262        columns: I,
263    ) -> Result<Self> {
264        let batches = vec![RecordBatch::new(schema.clone(), columns)?];
265        Ok(Self { schema, batches })
266    }
267
268    pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
269        let schema = stream.schema();
270        let batches = stream.try_collect::<Vec<_>>().await?;
271        Ok(Self { schema, batches })
272    }
273
274    #[inline]
275    pub fn empty() -> Self {
276        Self {
277            schema: Arc::new(Schema::new(vec![])),
278            batches: vec![],
279        }
280    }
281
282    pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
283        self.batches.iter()
284    }
285
286    pub fn pretty_print(&self) -> Result<String> {
287        let df_batches = &self
288            .iter()
289            .map(|x| x.df_record_batch().clone())
290            .collect::<Vec<_>>();
291        let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;
292
293        Ok(result.to_string())
294    }
295
296    pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
297        for batch in &batches {
298            ensure!(
299                batch.schema == schema,
300                error::CreateRecordBatchesSnafu {
301                    reason: format!(
302                        "expect RecordBatch schema equals {:?}, actual: {:?}",
303                        schema, batch.schema
304                    )
305                }
306            )
307        }
308        Ok(Self { schema, batches })
309    }
310
311    pub fn schema(&self) -> SchemaRef {
312        self.schema.clone()
313    }
314
315    pub fn take(self) -> Vec<RecordBatch> {
316        self.batches
317    }
318
319    pub fn as_stream(&self) -> SendableRecordBatchStream {
320        Box::pin(SimpleRecordBatchStream {
321            inner: RecordBatches {
322                schema: self.schema(),
323                batches: self.batches.clone(),
324            },
325            index: 0,
326        })
327    }
328}
329
/// Consuming iteration over the batches of a [RecordBatches]; the schema is dropped.
impl IntoIterator for RecordBatches {
    type Item = RecordBatch;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.batches.into_iter()
    }
}
338
/// A stream that yields the batches of an in-memory [RecordBatches] one by one.
pub struct SimpleRecordBatchStream {
    /// The batches to replay, together with their schema.
    inner: RecordBatches,
    /// Index of the next batch to yield.
    index: usize,
}
343
/// Stream metadata for the simple stream: only the schema is available.
impl RecordBatchStream for SimpleRecordBatchStream {
    /// Returns the schema of the wrapped [RecordBatches].
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    /// Always `None`.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    /// Always `None`.
    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}
357
358impl Stream for SimpleRecordBatchStream {
359    type Item = Result<RecordBatch>;
360
361    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362        Poll::Ready(if self.index < self.inner.batches.len() {
363            let batch = self.inner.batches[self.index].clone();
364            self.index += 1;
365            Some(Ok(batch))
366        } else {
367            None
368        })
369    }
370}
371
/// Adapt a [Stream] of [RecordBatch] to a [RecordBatchStream].
pub struct RecordBatchStreamWrapper<S> {
    /// Schema reported by the wrapper.
    pub schema: SchemaRef,
    /// The underlying stream that is polled for batches.
    pub stream: S,
    /// Output ordering reported by the wrapper, if any.
    pub output_ordering: Option<Vec<OrderOption>>,
    /// Shared, swappable slot holding the metrics reported by the wrapper, if any.
    pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
    /// Tracing span that is entered while the underlying stream is polled.
    pub span: Span,
}
380
impl<S> RecordBatchStreamWrapper<S> {
    /// Creates a [RecordBatchStreamWrapper] without output ordering requirement.
    ///
    /// The metrics slot starts at its default value and the span is the one
    /// that is current when this constructor is called.
    pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
        RecordBatchStreamWrapper {
            schema,
            stream,
            output_ordering: None,
            metrics: Default::default(),
            span: Span::current(),
        }
    }
}
393
394impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
395    for RecordBatchStreamWrapper<S>
396{
397    fn name(&self) -> &str {
398        "RecordBatchStreamWrapper"
399    }
400
401    fn schema(&self) -> SchemaRef {
402        self.schema.clone()
403    }
404
405    fn output_ordering(&self) -> Option<&[OrderOption]> {
406        self.output_ordering.as_deref()
407    }
408
409    fn metrics(&self) -> Option<RecordBatchMetrics> {
410        self.metrics.load().as_ref().map(|s| s.as_ref().clone())
411    }
412}
413
414impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
415    type Item = Result<RecordBatch>;
416
417    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
418        let _entered = self.span.clone().entered();
419        Pin::new(&mut self.stream).poll_next(ctx)
420    }
421}
422
/// Memory tracker for RecordBatch streams. Clone to share the same limit across queries.
///
/// Each stream acquires quota independently from this tracker.
#[derive(Clone)]
pub struct QueryMemoryTracker {
    /// Memory manager that hands out quota guards and reports used/limit bytes.
    manager: MemoryManager<CallbackMemoryMetrics>,
    /// Callback holder; the same value is handed to `manager` as its metrics sink.
    metrics: CallbackMemoryMetrics,
    /// Policy consulted when quota cannot be acquired immediately.
    on_exhausted_policy: OnExhaustedPolicy,
}
432
/// Debug output shows current usage, the limit, the exhaustion policy and
/// whether each callback is set (the callbacks themselves are not printable).
impl fmt::Debug for QueryMemoryTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("QueryMemoryTracker")
            .field("current", &self.current())
            .field("limit", &self.limit())
            .field("on_exhausted_policy", &self.on_exhausted_policy)
            .field("on_update", &self.metrics.has_on_update())
            .field("on_reject", &self.metrics.has_on_reject())
            .finish()
    }
}
444
445impl QueryMemoryTracker {
446    /// Create a builder for a query memory tracker.
447    pub fn builder(
448        limit: usize,
449        on_exhausted_policy: OnExhaustedPolicy,
450    ) -> QueryMemoryTrackerBuilder {
451        QueryMemoryTrackerBuilder {
452            limit,
453            on_exhausted_policy,
454            on_update: None,
455            on_reject: None,
456        }
457    }
458
459    fn new_stream_tracker(&self) -> StreamMemoryTracker {
460        StreamMemoryTracker {
461            tracker: self.clone(),
462            guard: self.manager.try_acquire(0).unwrap(),
463            tracked_bytes: 0,
464        }
465    }
466    /// Get the current memory usage in bytes.
467    pub fn current(&self) -> usize {
468        self.manager.used_bytes() as usize
469    }
470
471    fn limit(&self) -> usize {
472        self.manager.limit_bytes() as usize
473    }
474
475    fn reject_error(
476        &self,
477        current: usize,
478        additional: usize,
479        stream_tracked: usize,
480    ) -> error::Error {
481        let limit = self.limit();
482        let msg = format!(
483            "{} requested, {} used globally ({}%), {} used by this stream, hard limit: {}",
484            ReadableSize(additional as u64),
485            ReadableSize(current as u64),
486            if limit > 0 { current * 100 / limit } else { 0 },
487            ReadableSize(stream_tracked as u64),
488            ReadableSize(limit as u64)
489        );
490        error::ExceedMemoryLimitSnafu { msg }.build()
491    }
492}
493
/// Builder for constructing a [`QueryMemoryTracker`] with optional callbacks.
pub struct QueryMemoryTrackerBuilder {
    /// Limit in bytes passed to the memory manager.
    limit: usize,
    /// Policy stored on the built tracker for exhausted-quota situations.
    on_exhausted_policy: OnExhaustedPolicy,
    /// Optional callback receiving the new usage in bytes.
    on_update: Option<UpdateCallback>,
    /// Optional callback invoked when a request is rejected.
    on_reject: Option<RejectCallback>,
}
501
502impl QueryMemoryTrackerBuilder {
503    /// Set a callback to be called whenever the usage changes successfully.
504    /// The callback receives the new total usage in bytes.
505    ///
506    /// # Note
507    /// The callback is called after both successful `track()` and stream drop.
508    /// Usage is exact in unlimited mode and 1KB-aligned in limited mode.
509    pub fn on_update<F>(mut self, on_update: F) -> Self
510    where
511        F: Fn(usize) + Send + Sync + 'static,
512    {
513        self.on_update = Some(Arc::new(on_update));
514        self
515    }
516
517    /// Set a callback to be called when memory allocation is rejected.
518    ///
519    /// # Note
520    /// This is only called when `track()` fails due to exceeding the limit.
521    /// It is never called when `limit == 0` (unlimited mode).
522    pub fn on_reject<F>(mut self, on_reject: F) -> Self
523    where
524        F: Fn() + Send + Sync + 'static,
525    {
526        self.on_reject = Some(Arc::new(on_reject));
527        self
528    }
529
530    /// Build a [`QueryMemoryTracker`] from this builder.
531    pub fn build(self) -> QueryMemoryTracker {
532        let metrics = CallbackMemoryMetrics::new(self.on_update, self.on_reject);
533        let manager = MemoryManager::with_granularity(
534            self.limit as u64,
535            PermitGranularity::Kilobyte,
536            metrics.clone(),
537        );
538
539        QueryMemoryTracker {
540            manager,
541            metrics,
542            on_exhausted_policy: self.on_exhausted_policy,
543        }
544    }
545}
546
/// Per-stream view onto a shared [`QueryMemoryTracker`].
struct StreamMemoryTracker {
    /// The shared tracker this stream draws quota from.
    tracker: QueryMemoryTracker,
    /// Guard through which additional quota is acquired for this stream.
    guard: MemoryGuard<CallbackMemoryMetrics>,
    /// Sum of the byte counts successfully tracked by this stream so far
    /// (as requested, used for error messages).
    tracked_bytes: usize,
}
552
/// Outcome of a policy-driven quota acquisition, carrying the memory manager's own error type.
type MemoryAcquireResult = std::result::Result<(), common_memory_manager::Error>;
554
555impl StreamMemoryTracker {
556    fn try_track(&mut self, additional: usize) -> Result<()> {
557        if self.guard.try_acquire_additional(additional as u64) {
558            self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
559            Ok(())
560        } else {
561            Err(self.reject_error(additional))
562        }
563    }
564
565    async fn track_with_policy(mut self, additional: usize) -> (Self, MemoryAcquireResult) {
566        let result = self
567            .guard
568            .acquire_additional_with_policy(additional as u64, self.tracker.on_exhausted_policy)
569            .await;
570        if result.is_ok() {
571            self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
572        }
573        (self, result)
574    }
575
576    fn reject_error(&self, additional: usize) -> error::Error {
577        let current = self.tracker.current();
578        self.tracker
579            .reject_error(current, additional, self.tracked_bytes)
580    }
581
582    fn wait_error(&self, additional: usize, source: common_memory_manager::Error) -> error::Error {
583        match source {
584            common_memory_manager::Error::MemoryLimitExceeded { .. } => {
585                self.reject_error(additional)
586            }
587            common_memory_manager::Error::MemoryAcquireTimeout { waited, .. } => {
588                let current = self.tracker.current();
589                let limit = self.tracker.limit();
590                let msg = format!(
591                    "timed out waiting {:?} for {}, {} used globally ({}%), {} used by this stream, hard limit: {}",
592                    waited,
593                    ReadableSize(additional as u64),
594                    ReadableSize(current as u64),
595                    if limit > 0 { current * 100 / limit } else { 0 },
596                    ReadableSize(self.tracked_bytes as u64),
597                    ReadableSize(limit as u64)
598                );
599                error::ExceedMemoryLimitSnafu { msg }.build()
600            }
601            error => error::ExternalSnafu.into_error(BoxedError::new(error)),
602        }
603    }
604}
605
/// Boxed future of an in-flight quota acquisition. It resolves to the stream
/// tracker, the buffered batch, the requested byte count and the acquisition result.
type PendingTrackFuture = Pin<
    Box<dyn Future<Output = (StreamMemoryTracker, RecordBatch, usize, MemoryAcquireResult)> + Send>,
>;
609
/// [`MemoryMetrics`] implementation that forwards events to optional user callbacks.
/// Clones share the same callbacks.
#[derive(Clone)]
struct CallbackMemoryMetrics {
    /// Shared callback storage.
    inner: Arc<CallbackMemoryMetricsInner>,
}
614
/// Shared callback that receives a usage value in bytes.
type UpdateCallback = Arc<dyn Fn(usize) + Send + Sync>;
/// Shared callback invoked without arguments when a request is rejected.
type RejectCallback = Arc<dyn Fn() + Send + Sync>;
617
/// Callback storage shared by all clones of a [`CallbackMemoryMetrics`].
struct CallbackMemoryMetricsInner {
    /// Called from `set_in_use` with the reported usage, if set.
    on_update: Option<UpdateCallback>,
    /// Called from `inc_rejected`, if set.
    on_reject: Option<RejectCallback>,
}
622
impl CallbackMemoryMetrics {
    /// Creates a metrics sink from the two optional callbacks.
    fn new(on_update: Option<UpdateCallback>, on_reject: Option<RejectCallback>) -> Self {
        Self {
            inner: Arc::new(CallbackMemoryMetricsInner {
                on_update,
                on_reject,
            }),
        }
    }

    /// Returns `true` if an update callback is registered.
    fn has_on_update(&self) -> bool {
        self.inner.on_update.is_some()
    }

    /// Returns `true` if a reject callback is registered.
    fn has_on_reject(&self) -> bool {
        self.inner.on_reject.is_some()
    }
}
641
642impl MemoryMetrics for CallbackMemoryMetrics {
643    fn set_limit(&self, _: i64) {}
644
645    fn set_in_use(&self, bytes: i64) {
646        if let Some(callback) = &self.inner.on_update {
647            callback(bytes.max(0) as usize);
648        }
649    }
650
651    fn inc_rejected(&self, _: &str) {
652        if let Some(callback) = &self.inner.on_reject {
653            callback();
654        }
655    }
656}
657
/// A wrapper stream that tracks memory usage of RecordBatches.
///
/// The stream is either "ready" (`tracker` is `Some`, `waiting` is `None`) or
/// "waiting" (`tracker` has been moved into the `waiting` future).
pub struct MemoryTrackedStream {
    /// The wrapped stream that produces the batches.
    inner: SendableRecordBatchStream,
    /// Per-stream tracker; `None` only while it is owned by the `waiting` future.
    tracker: Option<StreamMemoryTracker>,
    // Waiting stores a batch that has already been pulled from the inner stream but has not yet
    // acquired additional quota. This keeps `poll_next()` non-blocking and allows bounded waits,
    // at the cost of temporarily holding one untracked batch per blocked stream in memory.
    waiting: Option<PendingTrackFuture>,
}
667
impl MemoryTrackedStream {
    /// Wraps `inner` so that every batch it yields is accounted against `tracker`.
    pub fn new(inner: SendableRecordBatchStream, tracker: QueryMemoryTracker) -> Self {
        Self {
            inner,
            tracker: Some(tracker.new_stream_tracker()),
            waiting: None,
        }
    }

    /// Returns the stream tracker while in the ready state.
    ///
    /// # Panics
    /// Panics if the tracker is currently owned by a waiting future.
    fn ready_tracker_mut(&mut self) -> &mut StreamMemoryTracker {
        debug_assert!(
            self.waiting.is_none(),
            "a ready tracker must not coexist with a waiting future"
        );
        self.tracker.as_mut().unwrap()
    }

    /// Switches from the ready state to the waiting state: the tracker and the
    /// already-pulled `batch` are moved into a future that acquires `additional` bytes.
    ///
    /// # Panics
    /// Panics if there is no tracker to move, i.e. the stream is not in the ready state.
    fn enter_waiting(&mut self, batch: RecordBatch, additional: usize) {
        debug_assert!(
            self.waiting.is_none(),
            "enter_waiting should only be called from the ready state"
        );
        debug_assert!(
            self.tracker.is_some(),
            "enter_waiting requires a tracker in the ready state"
        );
        let tracker = self.tracker.take().unwrap();
        self.waiting = Some(Self::start_waiting(tracker, batch, additional));
    }

    /// Builds the boxed future that acquires `additional` bytes according to the
    /// tracker's policy and then hands back the tracker, the batch, the request
    /// size and the acquisition result.
    fn start_waiting(
        tracker: StreamMemoryTracker,
        batch: RecordBatch,
        additional: usize,
    ) -> PendingTrackFuture {
        Box::pin(async move {
            let (tracker, result) = tracker.track_with_policy(additional).await;
            (tracker, batch, additional, result)
        })
    }

    /// Polls the in-flight acquisition.
    ///
    /// When it completes, the stream returns to the ready state and yields either
    /// the buffered batch or the converted acquisition error.
    ///
    /// # Panics
    /// Panics if no waiting future is stored.
    fn poll_waiting(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
        let future = self.waiting.as_mut().unwrap();
        match future.as_mut().poll(cx) {
            Poll::Ready((tracker, batch, additional, result)) => {
                let output = match result {
                    Ok(()) => Ok(batch),
                    Err(error) => Err(tracker.wait_error(additional, error)),
                };
                // Restore the ready state before yielding, whatever the outcome.
                self.waiting = None;
                self.tracker = Some(tracker);
                Poll::Ready(Some(output))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Accounts for a batch that was just pulled from the inner stream.
    ///
    /// The batch's buffer memory size is first acquired without waiting. If that
    /// fails, the exhaustion policy decides: `Fail` yields the error right away,
    /// `Wait` buffers the batch and starts (and immediately polls) a waiting future.
    fn poll_batch(
        &mut self,
        batch: RecordBatch,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let additional = batch.buffer_memory_size();
        let tracker = self.ready_tracker_mut();

        if let Err(error) = tracker.try_track(additional) {
            match tracker.tracker.on_exhausted_policy {
                OnExhaustedPolicy::Fail => return Poll::Ready(Some(Err(error))),
                // `Wait` is a deliberate tradeoff: the batch has already been materialized, so we
                // keep it in memory while waiting for quota instead of failing immediately. Under
                // contention, real memory usage can therefore exceed `scan_memory_limit` by up to
                // one buffered batch per blocked stream.
                OnExhaustedPolicy::Wait { .. } => {
                    self.enter_waiting(batch, additional);
                    return self.poll_waiting(cx);
                }
            }
        }

        Poll::Ready(Some(Ok(batch)))
    }
}
750
751impl Stream for MemoryTrackedStream {
752    type Item = Result<RecordBatch>;
753
754    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
755        if self.waiting.is_some() {
756            return self.poll_waiting(cx);
757        }
758
759        match Pin::new(&mut self.inner).poll_next(cx) {
760            Poll::Ready(Some(Ok(batch))) => self.poll_batch(batch, cx),
761            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
762            Poll::Ready(None) => Poll::Ready(None),
763            Poll::Pending => Poll::Pending,
764        }
765    }
766
767    fn size_hint(&self) -> (usize, Option<usize>) {
768        self.inner.size_hint()
769    }
770}
771
/// Stream metadata for the tracked stream: everything is forwarded to the inner stream.
impl RecordBatchStream for MemoryTrackedStream {
    /// Forwards to the inner stream.
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    /// Forwards to the inner stream.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    /// Forwards to the inner stream.
    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}
785
786#[cfg(test)]
787mod tests {
788    use std::sync::Arc;
789    use std::time::Duration;
790
791    use common_memory_manager::{OnExhaustedPolicy, PermitGranularity};
792    use datatypes::prelude::{ConcreteDataType, VectorRef};
793    use datatypes::schema::{ColumnSchema, Schema};
794    use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};
795    use futures::StreamExt;
796    use tokio::time::{sleep, timeout};
797
798    use super::*;
799
800    fn large_string_batch(bytes: usize) -> RecordBatch {
801        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
802            "payload",
803            ConcreteDataType::string_datatype(),
804            false,
805        )]));
806        let payload = "x".repeat(bytes);
807        let vector: VectorRef = Arc::new(StringVector::from(vec![payload]));
808        RecordBatch::new(schema, vec![vector]).unwrap()
809    }
810
811    fn aligned_tracked_bytes(bytes: usize) -> usize {
812        PermitGranularity::Kilobyte
813            .permits_to_bytes(PermitGranularity::Kilobyte.bytes_to_permits(bytes as u64))
814            as usize
815    }
816
    /// `try_from_columns` rejects columns that do not fit the schema and wraps
    /// matching columns in a single batch.
    #[test]
    fn test_recordbatches_try_from_columns() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        // A string vector for an int32 column must be refused.
        let result = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        );
        assert!(result.is_err());

        // A matching int32 vector yields exactly one batch.
        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
        assert_eq!(r.take(), expected);
    }
835
    /// `try_new` refuses batches whose schema differs from the given one, and a
    /// valid collection pretty-prints and exposes its schema and batches.
    #[test]
    fn test_recordbatches_try_new() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);

        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();

        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();

        // batch2 carries schema2, so building with schema1 must fail and name both schemas.
        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!(
                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
            )
        );

        // With only the matching batch the collection is created and prints as a table.
        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
        let expected = "\
+---+-------+
| a | b     |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
        assert_eq!(batches.pretty_print().unwrap(), expected);

        assert_eq!(schema1, batches.schema());
        assert_eq!(vec![batch1], batches.take());
    }
874
    /// The stream returned by `as_stream` yields every batch once, in order.
    #[tokio::test]
    async fn test_simple_recordbatch_stream() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column_a, column_b]));

        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();

        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();

        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
        let stream = recordbatches.as_stream();
        // Collecting the stream must return both batches in their original order.
        let collected = util::collect(stream).await.unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0], batch1);
        assert_eq!(collected[1], batch2);
    }
897
898    const MB: usize = 1024 * 1024;
899
900    #[test]
901    fn test_query_memory_tracker_basic() {
902        let tracker =
903            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
904
905        let mut stream1 = tracker.new_stream_tracker();
906        assert!(stream1.try_track(5 * MB).is_ok());
907        assert_eq!(tracker.current(), 5 * MB);
908
909        let mut stream2 = tracker.new_stream_tracker();
910        assert!(stream2.try_track(4 * MB).is_ok());
911        assert_eq!(tracker.current(), 9 * MB);
912
913        drop(stream1);
914        drop(stream2);
915        assert_eq!(tracker.current(), 0);
916    }
917
918    #[test]
919    fn test_query_memory_tracker_shared_global_limit() {
920        let tracker =
921            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
922        let mut stream1 = tracker.new_stream_tracker();
923        let mut stream2 = tracker.new_stream_tracker();
924
925        assert!(stream1.try_track(3 * MB).is_ok());
926        assert_eq!(tracker.current(), 3 * MB);
927        assert!(stream2.try_track(6 * MB).is_ok());
928        assert_eq!(tracker.current(), 9 * MB);
929
930        let err = stream2.try_track(2 * MB).unwrap_err();
931        let err_msg = err.to_string();
932        assert!(err_msg.contains("6.0MiB used by this stream"));
933        assert!(err_msg.contains("9.0MiB used globally (90%)"));
934        assert!(err_msg.contains("hard limit: 10.0MiB"));
935        assert_eq!(tracker.current(), 9 * MB);
936
937        drop(stream1);
938        assert_eq!(tracker.current(), 6 * MB);
939        drop(stream2);
940        assert_eq!(tracker.current(), 0);
941    }
942
943    #[test]
944    fn test_query_memory_tracker_hard_limit() {
945        let tracker =
946            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
947        let mut stream = tracker.new_stream_tracker();
948
949        assert!(stream.try_track(9 * MB).is_ok());
950        assert_eq!(tracker.current(), 9 * MB);
951
952        assert!(stream.try_track(2 * MB).is_err());
953        assert_eq!(tracker.current(), 9 * MB);
954
955        assert!(stream.try_track(MB).is_ok());
956        assert_eq!(tracker.current(), 10 * MB);
957
958        assert!(stream.try_track(MB).is_err());
959        assert_eq!(tracker.current(), 10 * MB);
960
961        drop(stream);
962        assert_eq!(tracker.current(), 0);
963    }
964
965    #[test]
966    fn test_query_memory_tracker_unlimited() {
967        let tracker = Arc::new(QueryMemoryTracker::builder(0, OnExhaustedPolicy::Fail).build());
968        let mut stream = tracker.new_stream_tracker();
969
970        assert!(stream.try_track(10 * MB).is_ok());
971        assert_eq!(tracker.current(), 10 * MB);
972        drop(stream);
973        assert_eq!(tracker.current(), 0);
974    }
975
976    #[test]
977    fn test_query_memory_tracker_rounds_to_kilobytes() {
978        let tracker =
979            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
980        let mut stream = tracker.new_stream_tracker();
981
982        assert!(stream.try_track(1_537).is_ok());
983        assert_eq!(tracker.current(), 2 * 1024);
984
985        drop(stream);
986        assert_eq!(tracker.current(), 0);
987    }
988
989    #[tokio::test]
990    async fn test_memory_tracked_stream_waits_for_capacity() {
991        let tracker = QueryMemoryTracker::builder(
992            MB,
993            OnExhaustedPolicy::Wait {
994                timeout: Duration::from_millis(200),
995            },
996        )
997        .build();
998        let batch = large_string_batch(700 * 1024);
999        let expected_bytes = aligned_tracked_bytes(batch.buffer_memory_size());
1000
1001        let mut stream1 = MemoryTrackedStream::new(
1002            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
1003                .unwrap()
1004                .as_stream(),
1005            tracker.clone(),
1006        );
1007        let first = stream1.next().await.unwrap().unwrap();
1008        assert_eq!(first.num_rows(), 1);
1009        assert_eq!(tracker.current(), expected_bytes);
1010
1011        let stream2 = MemoryTrackedStream::new(
1012            RecordBatches::try_new(batch.schema.clone(), vec![batch])
1013                .unwrap()
1014                .as_stream(),
1015            tracker.clone(),
1016        );
1017        let waiter = tokio::spawn(async move {
1018            let mut stream2 = stream2;
1019            stream2.next().await.unwrap()
1020        });
1021
1022        sleep(Duration::from_millis(50)).await;
1023        assert!(!waiter.is_finished());
1024
1025        drop(stream1);
1026        let second = waiter.await.unwrap().unwrap();
1027        assert_eq!(second.num_rows(), 1);
1028    }
1029
1030    #[tokio::test]
1031    async fn test_memory_tracked_stream_wait_times_out() {
1032        let tracker = QueryMemoryTracker::builder(
1033            MB,
1034            OnExhaustedPolicy::Wait {
1035                timeout: Duration::from_millis(50),
1036            },
1037        )
1038        .build();
1039        let batch = large_string_batch(700 * 1024);
1040
1041        let mut stream1 = MemoryTrackedStream::new(
1042            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
1043                .unwrap()
1044                .as_stream(),
1045            tracker.clone(),
1046        );
1047        let first = stream1.next().await.unwrap().unwrap();
1048        assert_eq!(first.num_rows(), 1);
1049
1050        let mut stream2 = MemoryTrackedStream::new(
1051            RecordBatches::try_new(batch.schema.clone(), vec![batch])
1052                .unwrap()
1053                .as_stream(),
1054            tracker,
1055        );
1056        let result = timeout(Duration::from_secs(1), stream2.next())
1057            .await
1058            .unwrap();
1059        let error = result.unwrap().unwrap_err();
1060        assert!(error.to_string().contains("timed out waiting"));
1061    }
1062}