1#![feature(never_type)]
16
17pub mod adapter;
18pub mod cursor;
19pub mod error;
20pub mod ext;
21pub mod filter;
22pub mod recordbatch;
23pub mod util;
24
25use std::fmt;
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29
30use adapter::RecordBatchMetrics;
31use arc_swap::ArcSwapOption;
32use common_base::readable_size::ReadableSize;
33use common_error::ext::BoxedError;
34use common_memory_manager::{
35 MemoryGuard, MemoryManager, MemoryMetrics, OnExhaustedPolicy, PermitGranularity,
36};
37use common_telemetry::tracing::Span;
38pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
39use datatypes::arrow::array::{ArrayRef, AsArray, StringBuilder};
40use datatypes::arrow::compute::SortOptions;
41pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
42use datatypes::arrow::util::pretty;
43use datatypes::prelude::{ConcreteDataType, VectorRef};
44use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
45use datatypes::types::{JsonFormat, jsonb_to_string};
46use error::Result;
47use futures::task::{Context, Poll};
48use futures::{Stream, TryStreamExt};
49pub use recordbatch::RecordBatch;
50use snafu::{IntoError, ResultExt, ensure};
51
52use crate::error::NewDfRecordBatchSnafu;
53
/// A stream of [`RecordBatch`] results that also exposes metadata about itself.
pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
    /// Returns a human-readable name for this stream.
    ///
    /// Defaults to `"RecordBatchStream"`; implementors may override it.
    fn name(&self) -> &str {
        "RecordBatchStream"
    }

    /// Returns the schema associated with this stream.
    fn schema(&self) -> SchemaRef;

    /// Returns the ordering reported for the stream's output, or `None` when
    /// no ordering is reported.
    fn output_ordering(&self) -> Option<&[OrderOption]>;

    /// Returns the metrics reported by this stream, or `None` when there are none.
    fn metrics(&self) -> Option<RecordBatchMetrics>;
}
65
/// A pinned, boxed [`RecordBatchStream`] trait object that is `Send`.
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
67
/// One entry of a stream's reported output ordering.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderOption {
    /// Name this entry applies to (presumably a column name; confirm against callers).
    pub name: String,
    /// Arrow sort options for this entry.
    pub options: SortOptions,
}
73
/// A stream adapter that optionally rewrites every batch of an inner stream
/// with a mapper function and exposes a mapped schema.
pub struct SendableRecordBatchMapper {
    /// The wrapped stream that produces the original batches.
    inner: SendableRecordBatchStream,
    /// Rewrites one batch; called with the batch, the inner stream's schema,
    /// and the mapped schema.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// The schema reported by this stream (the result of the schema mapper).
    schema: SchemaRef,
    /// Whether `mapper` is applied; when `false`, batches are forwarded untouched.
    apply_mapper: bool,
}
90
91pub fn map_json_type_to_string(
97 batch: RecordBatch,
98 original_schema: &SchemaRef,
99 mapped_schema: &SchemaRef,
100) -> Result<RecordBatch> {
101 let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
102 for (vector, schema) in batch.columns().iter().zip(original_schema.column_schemas()) {
103 if let ConcreteDataType::Json(j) = &schema.data_type {
104 if matches!(&j.format, JsonFormat::Jsonb) {
105 let mut string_vector_builder = StringBuilder::new();
106 let binary_vector = vector.as_binary::<i32>();
107 for value in binary_vector.iter() {
108 let Some(value) = value else {
109 string_vector_builder.append_null();
110 continue;
111 };
112 let string_value =
113 jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
114 from_type: schema.data_type.clone(),
115 to_type: ConcreteDataType::string_datatype(),
116 })?;
117 string_vector_builder.append_value(string_value);
118 }
119
120 let string_vector = string_vector_builder.finish();
121 vectors.push(Arc::new(string_vector) as ArrayRef);
122 } else {
123 vectors.push(vector.clone());
124 }
125 } else {
126 vectors.push(vector.clone());
127 }
128 }
129
130 let record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
131 mapped_schema.arrow_schema().clone(),
132 vectors,
133 )
134 .context(NewDfRecordBatchSnafu)?;
135 Ok(RecordBatch::from_df_record_batch(
136 mapped_schema.clone(),
137 record_batch,
138 ))
139}
140
141pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
149 let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
150 let mut apply_mapper = false;
151 for column in schema.column_schemas() {
152 if matches!(column.data_type, ConcreteDataType::Json(_)) {
153 new_columns.push(ColumnSchema::new(
154 column.name.clone(),
155 ConcreteDataType::string_datatype(),
156 column.is_nullable(),
157 ));
158 apply_mapper = true;
159 } else {
160 new_columns.push(column.clone());
161 }
162 }
163 (Arc::new(Schema::new(new_columns)), apply_mapper)
164}
165
166impl SendableRecordBatchMapper {
167 pub fn new(
169 inner: SendableRecordBatchStream,
170 mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
171 schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
172 ) -> Self {
173 let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
174 Self {
175 inner,
176 mapper,
177 schema: mapped_schema,
178 apply_mapper,
179 }
180 }
181}
182
183impl RecordBatchStream for SendableRecordBatchMapper {
184 fn name(&self) -> &str {
185 "SendableRecordBatchMapper"
186 }
187
188 fn schema(&self) -> SchemaRef {
189 self.schema.clone()
190 }
191
192 fn output_ordering(&self) -> Option<&[OrderOption]> {
193 self.inner.output_ordering()
194 }
195
196 fn metrics(&self) -> Option<RecordBatchMetrics> {
197 self.inner.metrics()
198 }
199}
200
201impl Stream for SendableRecordBatchMapper {
202 type Item = Result<RecordBatch>;
203
204 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
205 if self.apply_mapper {
206 Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
207 opt.map(|result| {
208 result
209 .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
210 })
211 })
212 } else {
213 Pin::new(&mut self.inner).poll_next(cx)
214 }
215 }
216}
217
/// A stream that carries a schema but never yields any batch.
pub struct EmptyRecordBatchStream {
    /// The schema reported by this stream.
    schema: SchemaRef,
}
224
225impl EmptyRecordBatchStream {
226 pub fn new(schema: SchemaRef) -> Self {
228 Self { schema }
229 }
230}
231
232impl RecordBatchStream for EmptyRecordBatchStream {
233 fn schema(&self) -> SchemaRef {
234 self.schema.clone()
235 }
236
237 fn output_ordering(&self) -> Option<&[OrderOption]> {
238 None
239 }
240
241 fn metrics(&self) -> Option<RecordBatchMetrics> {
242 None
243 }
244}
245
impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Always reports end-of-stream immediately; the context is never used.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}
253
/// An in-memory collection of record batches together with a schema.
#[derive(Debug, PartialEq)]
pub struct RecordBatches {
    /// The schema of the collection.
    schema: SchemaRef,
    /// The batches, in the order they were supplied or collected.
    batches: Vec<RecordBatch>,
}
259
260impl RecordBatches {
261 pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
262 schema: SchemaRef,
263 columns: I,
264 ) -> Result<Self> {
265 let batches = vec![RecordBatch::new(schema.clone(), columns)?];
266 Ok(Self { schema, batches })
267 }
268
269 pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
270 let schema = stream.schema();
271 let batches = stream.try_collect::<Vec<_>>().await?;
272 Ok(Self { schema, batches })
273 }
274
275 #[inline]
276 pub fn empty() -> Self {
277 Self {
278 schema: Arc::new(Schema::new(vec![])),
279 batches: vec![],
280 }
281 }
282
283 pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
284 self.batches.iter()
285 }
286
287 pub fn pretty_print(&self) -> Result<String> {
288 let df_batches = &self
289 .iter()
290 .map(|x| x.df_record_batch().clone())
291 .collect::<Vec<_>>();
292 let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;
293
294 Ok(result.to_string())
295 }
296
297 pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
298 for batch in &batches {
299 ensure!(
300 batch.schema == schema,
301 error::CreateRecordBatchesSnafu {
302 reason: format!(
303 "expect RecordBatch schema equals {:?}, actual: {:?}",
304 schema, batch.schema
305 )
306 }
307 )
308 }
309 Ok(Self { schema, batches })
310 }
311
312 pub fn schema(&self) -> SchemaRef {
313 self.schema.clone()
314 }
315
316 pub fn take(self) -> Vec<RecordBatch> {
317 self.batches
318 }
319
320 pub fn as_stream(&self) -> SendableRecordBatchStream {
321 Box::pin(SimpleRecordBatchStream {
322 inner: RecordBatches {
323 schema: self.schema(),
324 batches: self.batches.clone(),
325 },
326 index: 0,
327 })
328 }
329}
330
331impl IntoIterator for RecordBatches {
332 type Item = RecordBatch;
333 type IntoIter = std::vec::IntoIter<Self::Item>;
334
335 fn into_iter(self) -> Self::IntoIter {
336 self.batches.into_iter()
337 }
338}
339
/// A stream that replays the batches of an in-memory [`RecordBatches`].
pub struct SimpleRecordBatchStream {
    /// The batches (and schema) being replayed.
    inner: RecordBatches,
    /// Index of the next batch to yield.
    index: usize,
}
344
345impl RecordBatchStream for SimpleRecordBatchStream {
346 fn schema(&self) -> SchemaRef {
347 self.inner.schema()
348 }
349
350 fn output_ordering(&self) -> Option<&[OrderOption]> {
351 None
352 }
353
354 fn metrics(&self) -> Option<RecordBatchMetrics> {
355 None
356 }
357}
358
359impl Stream for SimpleRecordBatchStream {
360 type Item = Result<RecordBatch>;
361
362 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
363 Poll::Ready(if self.index < self.inner.batches.len() {
364 let batch = self.inner.batches[self.index].clone();
365 self.index += 1;
366 Some(Ok(batch))
367 } else {
368 None
369 })
370 }
371}
372
/// Adapts a plain stream of batch results into a [`RecordBatchStream`] by
/// attaching a schema, an optional ordering, metrics, and a tracing span.
pub struct RecordBatchStreamWrapper<S> {
    /// The schema reported by the wrapper.
    pub schema: SchemaRef,
    /// The wrapped stream.
    pub stream: S,
    /// The output ordering reported by the wrapper, if any.
    pub output_ordering: Option<Vec<OrderOption>>,
    /// Shared, swappable slot holding the metrics reported by the wrapper.
    pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
    /// Span that is entered while the wrapped stream is polled.
    pub span: Span,
}
381
382impl<S> RecordBatchStreamWrapper<S> {
383 pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
385 RecordBatchStreamWrapper {
386 schema,
387 stream,
388 output_ordering: None,
389 metrics: Default::default(),
390 span: Span::current(),
391 }
392 }
393}
394
395impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
396 for RecordBatchStreamWrapper<S>
397{
398 fn name(&self) -> &str {
399 "RecordBatchStreamWrapper"
400 }
401
402 fn schema(&self) -> SchemaRef {
403 self.schema.clone()
404 }
405
406 fn output_ordering(&self) -> Option<&[OrderOption]> {
407 self.output_ordering.as_deref()
408 }
409
410 fn metrics(&self) -> Option<RecordBatchMetrics> {
411 self.metrics.load().as_ref().map(|s| s.as_ref().clone())
412 }
413}
414
415impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
416 type Item = Result<RecordBatch>;
417
418 fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
419 let _entered = self.span.clone().entered();
420 Pin::new(&mut self.stream).poll_next(ctx)
421 }
422}
423
/// Tracks the memory used by query streams against a shared limit.
///
/// Cloning yields another handle; each stream obtains its own
/// `StreamMemoryTracker` from it.
#[derive(Clone)]
pub struct QueryMemoryTracker {
    /// The memory manager that accounts for acquired bytes.
    manager: MemoryManager<CallbackMemoryMetrics>,
    /// Callback-backed metrics sink; also passed to `manager` at build time.
    metrics: CallbackMemoryMetrics,
    /// What a stream does when memory cannot be acquired immediately.
    on_exhausted_policy: OnExhaustedPolicy,
}
433
434impl fmt::Debug for QueryMemoryTracker {
435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436 f.debug_struct("QueryMemoryTracker")
437 .field("current", &self.current())
438 .field("limit", &self.limit())
439 .field("on_exhausted_policy", &self.on_exhausted_policy)
440 .field("on_update", &self.metrics.has_on_update())
441 .field("on_exhausted", &self.metrics.has_on_exhausted())
442 .field("on_rejected", &self.metrics.has_on_rejected())
443 .finish()
444 }
445}
446
447impl QueryMemoryTracker {
448 pub fn builder(
450 limit: usize,
451 on_exhausted_policy: OnExhaustedPolicy,
452 ) -> QueryMemoryTrackerBuilder {
453 QueryMemoryTrackerBuilder {
454 limit,
455 on_exhausted_policy,
456 on_update: None,
457 on_exhausted: None,
458 on_reject: None,
459 }
460 }
461
462 fn new_stream_tracker(&self) -> StreamMemoryTracker {
463 StreamMemoryTracker {
464 tracker: self.clone(),
465 guard: self.manager.try_acquire(0).unwrap(),
466 tracked_bytes: 0,
467 }
468 }
469 pub fn current(&self) -> usize {
471 self.manager.used_bytes() as usize
472 }
473
474 fn limit(&self) -> usize {
475 self.manager.limit_bytes() as usize
476 }
477
478 fn reject_error(
479 &self,
480 current: usize,
481 additional: usize,
482 stream_tracked: usize,
483 ) -> error::Error {
484 let limit = self.limit();
485 let msg = format!(
486 "{} requested, {} used globally ({}%), {} used by this stream, hard limit: {}",
487 ReadableSize(additional as u64),
488 ReadableSize(current as u64),
489 (current * 100).checked_div(limit).unwrap_or(0),
490 ReadableSize(stream_tracked as u64),
491 ReadableSize(limit as u64)
492 );
493 error::ExceedMemoryLimitSnafu { msg }.build()
494 }
495
496 fn inc_rejected(&self) {
497 self.metrics.inc_rejected();
498 }
499}
500
/// Builder for [`QueryMemoryTracker`].
pub struct QueryMemoryTrackerBuilder {
    /// Byte limit handed to the memory manager.
    limit: usize,
    /// Policy applied when memory cannot be acquired immediately.
    on_exhausted_policy: OnExhaustedPolicy,
    /// Optional callback invoked by the metrics sink's `set_in_use`.
    on_update: Option<UpdateCallback>,
    /// Optional callback invoked by the metrics sink's `inc_exhausted`.
    on_exhausted: Option<UnitCallback>,
    /// Optional callback invoked by the metrics sink's `inc_rejected`.
    on_reject: Option<RejectCallback>,
}
509
510impl QueryMemoryTrackerBuilder {
511 pub fn on_update<F>(mut self, on_update: F) -> Self
518 where
519 F: Fn(usize) + Send + Sync + 'static,
520 {
521 self.on_update = Some(Arc::new(on_update));
522 self
523 }
524
525 pub fn on_exhausted<F>(mut self, on_exhausted: F) -> Self
532 where
533 F: Fn() + Send + Sync + 'static,
534 {
535 self.on_exhausted = Some(Arc::new(on_exhausted));
536 self
537 }
538
539 pub fn on_reject<F>(mut self, on_reject: F) -> Self
541 where
542 F: Fn() + Send + Sync + 'static,
543 {
544 self.on_reject = Some(Arc::new(on_reject));
545 self
546 }
547
548 pub fn build(self) -> QueryMemoryTracker {
550 let metrics = CallbackMemoryMetrics::new(self.on_update, self.on_exhausted, self.on_reject);
551 let manager = MemoryManager::with_granularity(
552 self.limit as u64,
553 PermitGranularity::Kilobyte,
554 metrics.clone(),
555 );
556
557 QueryMemoryTracker {
558 manager,
559 metrics,
560 on_exhausted_policy: self.on_exhausted_policy,
561 }
562 }
563}
564
/// Per-stream memory accounting on top of a shared [`QueryMemoryTracker`].
struct StreamMemoryTracker {
    /// Handle to the shared tracker (limit, policy, metrics).
    tracker: QueryMemoryTracker,
    /// Guard through which this stream acquires additional bytes.
    guard: MemoryGuard<CallbackMemoryMetrics>,
    /// Sum of the sizes successfully requested by this stream; used in error messages.
    tracked_bytes: usize,
}
570
/// Outcome of a policy-driven memory acquisition.
type MemoryAcquireResult = std::result::Result<(), common_memory_manager::Error>;
572
573impl StreamMemoryTracker {
574 fn inc_rejected(&self) {
575 self.tracker.inc_rejected();
576 }
577
578 fn try_track(&mut self, additional: usize) -> Result<()> {
579 if self.guard.try_acquire_additional(additional as u64) {
580 self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
581 Ok(())
582 } else {
583 Err(self.reject_error(additional))
584 }
585 }
586
587 async fn track_with_policy(mut self, additional: usize) -> (Self, MemoryAcquireResult) {
588 let result = self
589 .guard
590 .acquire_additional_with_policy(additional as u64, self.tracker.on_exhausted_policy)
591 .await;
592 if result.is_ok() {
593 self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
594 }
595 (self, result)
596 }
597
598 fn reject_error(&self, additional: usize) -> error::Error {
599 let current = self.tracker.current();
600 self.tracker
601 .reject_error(current, additional, self.tracked_bytes)
602 }
603
604 fn wait_error(&self, additional: usize, source: common_memory_manager::Error) -> error::Error {
605 match source {
606 common_memory_manager::Error::MemoryLimitExceeded { .. } => {
607 self.reject_error(additional)
608 }
609 common_memory_manager::Error::MemoryAcquireTimeout { waited, .. } => {
610 let current = self.tracker.current();
611 let limit = self.tracker.limit();
612 let msg = format!(
613 "timed out waiting {:?} for {}, {} used globally ({}%), {} used by this stream, hard limit: {}",
614 waited,
615 ReadableSize(additional as u64),
616 ReadableSize(current as u64),
617 (current * 100).checked_div(limit).unwrap_or(0),
618 ReadableSize(self.tracked_bytes as u64),
619 ReadableSize(limit as u64)
620 );
621 error::ExceedMemoryLimitSnafu { msg }.build()
622 }
623 error => error::ExternalSnafu.into_error(BoxedError::new(error)),
624 }
625 }
626}
627
/// Boxed future of an in-flight policy-driven acquisition.
///
/// It resolves to the tracker, the batch being admitted, the requested size,
/// and the acquisition outcome.
type PendingTrackFuture = Pin<
    Box<dyn Future<Output = (StreamMemoryTracker, RecordBatch, usize, MemoryAcquireResult)> + Send>,
>;
631
/// A cheaply clonable metrics sink that forwards events to optional callbacks.
#[derive(Clone)]
struct CallbackMemoryMetrics {
    /// Shared callback storage.
    inner: Arc<CallbackMemoryMetricsInner>,
}
636
/// Callback that receives a byte count.
type UpdateCallback = Arc<dyn Fn(usize) + Send + Sync>;
/// Callback that takes no arguments.
type UnitCallback = Arc<dyn Fn() + Send + Sync>;
/// Callback invoked on rejection; same shape as [`UnitCallback`].
type RejectCallback = UnitCallback;
640
/// The callbacks held by [`CallbackMemoryMetrics`]; each one is optional.
struct CallbackMemoryMetricsInner {
    /// Called from `set_in_use` with the non-negative in-use byte count.
    on_update: Option<UpdateCallback>,
    /// Called from `inc_exhausted`.
    on_exhausted: Option<UnitCallback>,
    /// Called from `inc_rejected`.
    on_reject: Option<RejectCallback>,
}
646
647impl CallbackMemoryMetrics {
648 fn new(
649 on_update: Option<UpdateCallback>,
650 on_exhausted: Option<UnitCallback>,
651 on_reject: Option<RejectCallback>,
652 ) -> Self {
653 Self {
654 inner: Arc::new(CallbackMemoryMetricsInner {
655 on_update,
656 on_exhausted,
657 on_reject,
658 }),
659 }
660 }
661
662 fn has_on_update(&self) -> bool {
663 self.inner.on_update.is_some()
664 }
665
666 fn has_on_exhausted(&self) -> bool {
667 self.inner.on_exhausted.is_some()
668 }
669
670 fn has_on_rejected(&self) -> bool {
671 self.inner.on_reject.is_some()
672 }
673
674 fn inc_rejected(&self) {
675 if let Some(callback) = &self.inner.on_reject {
676 callback();
677 }
678 }
679}
680
681impl MemoryMetrics for CallbackMemoryMetrics {
682 fn set_limit(&self, _: i64) {}
683
684 fn set_in_use(&self, bytes: i64) {
685 if let Some(callback) = &self.inner.on_update {
686 callback(bytes.max(0) as usize);
687 }
688 }
689
690 fn inc_exhausted(&self, _: &str) {
691 if let Some(callback) = &self.inner.on_exhausted {
692 callback();
693 }
694 }
695}
696
/// A stream wrapper that accounts for the memory of every batch it yields.
///
/// The stream is either *ready* (`tracker` is `Some`, `waiting` is `None`) or
/// *waiting* (`tracker` has been moved into the `waiting` future).
pub struct MemoryTrackedStream {
    /// The wrapped stream.
    inner: SendableRecordBatchStream,
    /// The per-stream tracker; `None` only while a waiting future owns it.
    tracker: Option<StreamMemoryTracker>,
    /// In-flight acquisition for a batch that could not be admitted immediately.
    waiting: Option<PendingTrackFuture>,
}
706
impl MemoryTrackedStream {
    /// Wraps `inner`, creating a fresh per-stream tracker from `tracker`.
    pub fn new(inner: SendableRecordBatchStream, tracker: QueryMemoryTracker) -> Self {
        Self {
            inner,
            tracker: Some(tracker.new_stream_tracker()),
            waiting: None,
        }
    }

    /// Returns the tracker while in the ready state.
    ///
    /// Panics if the tracker has been moved into a waiting future.
    fn ready_tracker_mut(&mut self) -> &mut StreamMemoryTracker {
        debug_assert!(
            self.waiting.is_none(),
            "a ready tracker must not coexist with a waiting future"
        );
        self.tracker.as_mut().unwrap()
    }

    /// Switches from the ready state to the waiting state for `batch`.
    ///
    /// The tracker is moved into the waiting future and is handed back by
    /// `poll_waiting` once that future resolves.
    fn enter_waiting(&mut self, batch: RecordBatch, additional: usize) {
        debug_assert!(
            self.waiting.is_none(),
            "enter_waiting should only be called from the ready state"
        );
        debug_assert!(
            self.tracker.is_some(),
            "enter_waiting requires a tracker in the ready state"
        );
        let tracker = self.tracker.take().unwrap();
        self.waiting = Some(Self::start_waiting(tracker, batch, additional));
    }

    /// Builds the future that acquires `additional` bytes under the tracker's
    /// policy and then returns the tracker, the batch, the size, and the outcome.
    fn start_waiting(
        tracker: StreamMemoryTracker,
        batch: RecordBatch,
        additional: usize,
    ) -> PendingTrackFuture {
        Box::pin(async move {
            let (tracker, result) = tracker.track_with_policy(additional).await;
            (tracker, batch, additional, result)
        })
    }

    /// Polls the in-flight acquisition.
    ///
    /// On completion the stream returns to the ready state and yields either
    /// the admitted batch or the converted acquisition error (counting a
    /// rejection in that case). Panics if called while not waiting.
    fn poll_waiting(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
        let future = self.waiting.as_mut().unwrap();
        match future.as_mut().poll(cx) {
            Poll::Ready((tracker, batch, additional, result)) => {
                let output = match result {
                    Ok(()) => Ok(batch),
                    Err(error) => {
                        tracker.inc_rejected();
                        Err(tracker.wait_error(additional, error))
                    }
                };
                // Restore the ready state before yielding.
                self.waiting = None;
                self.tracker = Some(tracker);
                Poll::Ready(Some(output))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Admits a batch produced by the inner stream.
    ///
    /// The batch's buffer memory size is acquired without waiting first. If
    /// that fails, the `Fail` policy yields the error immediately, while the
    /// `Wait` policy parks the batch in a waiting future and polls it.
    fn poll_batch(
        &mut self,
        batch: RecordBatch,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let additional = batch.buffer_memory_size();
        let tracker = self.ready_tracker_mut();

        if let Err(error) = tracker.try_track(additional) {
            match tracker.tracker.on_exhausted_policy {
                OnExhaustedPolicy::Fail => {
                    tracker.inc_rejected();
                    return Poll::Ready(Some(Err(error)));
                }
                OnExhaustedPolicy::Wait { .. } => {
                    // Poll right away so the new future registers the waker.
                    self.enter_waiting(batch, additional);
                    return self.poll_waiting(cx);
                }
            }
        }

        Poll::Ready(Some(Ok(batch)))
    }
}
795
796impl Stream for MemoryTrackedStream {
797 type Item = Result<RecordBatch>;
798
799 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
800 if self.waiting.is_some() {
801 return self.poll_waiting(cx);
802 }
803
804 match Pin::new(&mut self.inner).poll_next(cx) {
805 Poll::Ready(Some(Ok(batch))) => self.poll_batch(batch, cx),
806 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
807 Poll::Ready(None) => Poll::Ready(None),
808 Poll::Pending => Poll::Pending,
809 }
810 }
811
812 fn size_hint(&self) -> (usize, Option<usize>) {
813 self.inner.size_hint()
814 }
815}
816
impl RecordBatchStream for MemoryTrackedStream {
    /// Delegates to the inner stream.
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    /// Delegates to the inner stream.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    /// Delegates to the inner stream.
    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}
830
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use common_memory_manager::{OnExhaustedPolicy, PermitGranularity};
    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};
    use futures::StreamExt;
    use tokio::time::{sleep, timeout};

    use super::*;

    /// Builds a one-row batch with a single string column whose value is
    /// `bytes` characters long.
    fn large_string_batch(bytes: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "payload",
            ConcreteDataType::string_datatype(),
            false,
        )]));
        let payload = "x".repeat(bytes);
        let vector: VectorRef = Arc::new(StringVector::from(vec![payload]));
        RecordBatch::new(schema, vec![vector]).unwrap()
    }

    /// Rounds `bytes` through kilobyte permits and back, matching the
    /// granularity the tracker is built with.
    fn aligned_tracked_bytes(bytes: usize) -> usize {
        PermitGranularity::Kilobyte
            .permits_to_bytes(PermitGranularity::Kilobyte.bytes_to_permits(bytes as u64))
            as usize
    }

    /// A column that does not match the schema is rejected; a matching one
    /// produces a single-batch collection.
    #[test]
    fn test_recordbatches_try_from_columns() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let result = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        );
        assert!(result.is_err());

        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
        assert_eq!(r.take(), expected);
    }

    /// `try_new` rejects a batch with a different schema and otherwise keeps
    /// schema and batches; also checks the pretty-printed table.
    #[test]
    fn test_recordbatches_try_new() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);

        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();

        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();

        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!(
                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
            )
        );

        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
        // NOTE(review): the header row looks narrower than the border rows;
        // confirm its padding matches the output of `pretty_format_batches`.
        let expected = "\
+---+-------+
| a | b |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
        assert_eq!(batches.pretty_print().unwrap(), expected);

        assert_eq!(schema1, batches.schema());
        assert_eq!(vec![batch1], batches.take());
    }

    /// `as_stream` replays the stored batches in order.
    #[tokio::test]
    async fn test_simple_recordbatch_stream() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column_a, column_b]));

        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();

        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();

        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
        let stream = recordbatches.as_stream();
        let collected = util::collect(stream).await.unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0], batch1);
        assert_eq!(collected[1], batch2);
    }

    /// One mebibyte, in bytes.
    const MB: usize = 1024 * 1024;

    /// Usage from several streams adds up and is released when they are dropped.
    #[test]
    fn test_query_memory_tracker_basic() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());

        let mut stream1 = tracker.new_stream_tracker();
        assert!(stream1.try_track(5 * MB).is_ok());
        assert_eq!(tracker.current(), 5 * MB);

        let mut stream2 = tracker.new_stream_tracker();
        assert!(stream2.try_track(4 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        drop(stream1);
        drop(stream2);
        assert_eq!(tracker.current(), 0);
    }

    /// The limit is shared across streams, and the rejection message reports
    /// per-stream usage, global usage, and the hard limit.
    #[test]
    fn test_query_memory_tracker_shared_global_limit() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream1 = tracker.new_stream_tracker();
        let mut stream2 = tracker.new_stream_tracker();

        assert!(stream1.try_track(3 * MB).is_ok());
        assert_eq!(tracker.current(), 3 * MB);
        assert!(stream2.try_track(6 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        let err = stream2.try_track(2 * MB).unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("6.0MiB used by this stream"));
        assert!(err_msg.contains("9.0MiB used globally (90%)"));
        assert!(err_msg.contains("hard limit: 10.0MiB"));
        assert_eq!(tracker.current(), 9 * MB);

        drop(stream1);
        assert_eq!(tracker.current(), 6 * MB);
        drop(stream2);
        assert_eq!(tracker.current(), 0);
    }

    /// Requests that would exceed the limit fail without changing usage;
    /// a request that exactly fills the limit succeeds.
    #[test]
    fn test_query_memory_tracker_hard_limit() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        assert!(stream.try_track(9 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        assert!(stream.try_track(2 * MB).is_err());
        assert_eq!(tracker.current(), 9 * MB);

        assert!(stream.try_track(MB).is_ok());
        assert_eq!(tracker.current(), 10 * MB);

        assert!(stream.try_track(MB).is_err());
        assert_eq!(tracker.current(), 10 * MB);

        drop(stream);
        assert_eq!(tracker.current(), 0);
    }

    /// A tracker built with a limit of 0 still admits and accounts for requests.
    #[test]
    fn test_query_memory_tracker_unlimited() {
        let tracker = Arc::new(QueryMemoryTracker::builder(0, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        assert!(stream.try_track(10 * MB).is_ok());
        assert_eq!(tracker.current(), 10 * MB);
        drop(stream);
        assert_eq!(tracker.current(), 0);
    }

    /// Tracked usage is rounded up to whole kilobytes.
    #[test]
    fn test_query_memory_tracker_rounds_to_kilobytes() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        assert!(stream.try_track(1_537).is_ok());
        assert_eq!(tracker.current(), 2 * 1024);

        drop(stream);
        assert_eq!(tracker.current(), 0);
    }

    /// Under the `Wait` policy a second stream blocks until the first one is
    /// dropped, then receives its batch without a rejection being counted.
    #[tokio::test]
    async fn test_memory_tracked_stream_waits_for_capacity() {
        let exhausted = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let exhausted_counter = exhausted.clone();
        let rejected_counter = rejected.clone();
        let tracker = QueryMemoryTracker::builder(
            MB,
            OnExhaustedPolicy::Wait {
                timeout: Duration::from_millis(200),
            },
        )
        .on_exhausted(move || {
            exhausted_counter.fetch_add(1, Ordering::Relaxed);
        })
        .on_reject(move || {
            rejected_counter.fetch_add(1, Ordering::Relaxed);
        })
        .build();
        let batch = large_string_batch(700 * 1024);
        let expected_bytes = aligned_tracked_bytes(batch.buffer_memory_size());

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);
        assert_eq!(tracker.current(), expected_bytes);

        let stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let waiter = tokio::spawn(async move {
            let mut stream2 = stream2;
            stream2.next().await.unwrap()
        });

        // The second stream must still be waiting while the first holds its memory.
        sleep(Duration::from_millis(50)).await;
        assert!(!waiter.is_finished());

        // Dropping the first stream releases its memory and unblocks the waiter.
        drop(stream1);
        let second = waiter.await.unwrap().unwrap();
        assert_eq!(second.num_rows(), 1);
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
        assert_eq!(rejected.load(Ordering::Relaxed), 0);
    }

    /// Under the `Wait` policy the stream yields a timeout error, and counts a
    /// rejection, when capacity never frees up.
    #[tokio::test]
    async fn test_memory_tracked_stream_wait_times_out() {
        let exhausted = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let exhausted_counter = exhausted.clone();
        let rejected_counter = rejected.clone();
        let tracker = QueryMemoryTracker::builder(
            MB,
            OnExhaustedPolicy::Wait {
                timeout: Duration::from_millis(50),
            },
        )
        .on_exhausted(move || {
            exhausted_counter.fetch_add(1, Ordering::Relaxed);
        })
        .on_reject(move || {
            rejected_counter.fetch_add(1, Ordering::Relaxed);
        })
        .build();
        let batch = large_string_batch(700 * 1024);

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);

        let mut stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker,
        );
        let result = timeout(Duration::from_secs(1), stream2.next())
            .await
            .unwrap();
        let error = result.unwrap().unwrap_err();
        assert!(error.to_string().contains("timed out waiting"));
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
        assert_eq!(rejected.load(Ordering::Relaxed), 1);
    }

    /// Under the `Fail` policy an over-limit batch is rejected without waiting.
    #[tokio::test]
    async fn test_memory_tracked_stream_fail_policy_rejects_immediately() {
        let exhausted = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let exhausted_counter = exhausted.clone();
        let rejected_counter = rejected.clone();
        let tracker = QueryMemoryTracker::builder(MB, OnExhaustedPolicy::Fail)
            .on_exhausted(move || {
                exhausted_counter.fetch_add(1, Ordering::Relaxed);
            })
            .on_reject(move || {
                rejected_counter.fetch_add(1, Ordering::Relaxed);
            })
            .build();
        let batch = large_string_batch(700 * 1024);

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);

        let mut stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker,
        );
        let result = stream2.next().await.unwrap();
        assert!(result.is_err());
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
        assert_eq!(rejected.load(Ordering::Relaxed), 1);
    }
}