1#![feature(never_type)]
16
17pub mod adapter;
18pub mod cursor;
19pub mod error;
20pub mod filter;
21pub mod recordbatch;
22pub mod util;
23
24use std::fmt;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28
29use adapter::RecordBatchMetrics;
30use arc_swap::ArcSwapOption;
31use common_base::readable_size::ReadableSize;
32use common_error::ext::BoxedError;
33use common_memory_manager::{
34 MemoryGuard, MemoryManager, MemoryMetrics, OnExhaustedPolicy, PermitGranularity,
35};
36use common_telemetry::tracing::Span;
37pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
38use datatypes::arrow::array::{ArrayRef, AsArray, StringBuilder};
39use datatypes::arrow::compute::SortOptions;
40pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
41use datatypes::arrow::util::pretty;
42use datatypes::prelude::{ConcreteDataType, VectorRef};
43use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
44use datatypes::types::{JsonFormat, jsonb_to_string};
45use error::Result;
46use futures::task::{Context, Poll};
47use futures::{Stream, TryStreamExt};
48pub use recordbatch::RecordBatch;
49use snafu::{IntoError, ResultExt, ensure};
50
51use crate::error::NewDfRecordBatchSnafu;
52
53pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
54 fn name(&self) -> &str {
55 "RecordBatchStream"
56 }
57
58 fn schema(&self) -> SchemaRef;
59
60 fn output_ordering(&self) -> Option<&[OrderOption]>;
61
62 fn metrics(&self) -> Option<RecordBatchMetrics>;
63}
64
65pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
66
67#[derive(Debug, Clone, PartialEq, Eq)]
68pub struct OrderOption {
69 pub name: String,
70 pub options: SortOptions,
71}
72
/// Wraps a [`SendableRecordBatchStream`] and, when enabled, rewrites every batch it yields
/// with `mapper`.
pub struct SendableRecordBatchMapper {
    /// The wrapped source stream.
    inner: SendableRecordBatchStream,
    /// Converts a batch; it is called as `mapper(batch, &inner_schema, &mapped_schema)`.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// Schema reported to consumers, as produced by the schema mapper given to `new`.
    schema: SchemaRef,
    /// When `false`, batches are forwarded unchanged and `mapper` is never invoked.
    apply_mapper: bool,
}
89
90pub fn map_json_type_to_string(
96 batch: RecordBatch,
97 original_schema: &SchemaRef,
98 mapped_schema: &SchemaRef,
99) -> Result<RecordBatch> {
100 let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
101 for (vector, schema) in batch.columns().iter().zip(original_schema.column_schemas()) {
102 if let ConcreteDataType::Json(j) = &schema.data_type {
103 if matches!(&j.format, JsonFormat::Jsonb) {
104 let mut string_vector_builder = StringBuilder::new();
105 let binary_vector = vector.as_binary::<i32>();
106 for value in binary_vector.iter() {
107 let Some(value) = value else {
108 string_vector_builder.append_null();
109 continue;
110 };
111 let string_value =
112 jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
113 from_type: schema.data_type.clone(),
114 to_type: ConcreteDataType::string_datatype(),
115 })?;
116 string_vector_builder.append_value(string_value);
117 }
118
119 let string_vector = string_vector_builder.finish();
120 vectors.push(Arc::new(string_vector) as ArrayRef);
121 } else {
122 vectors.push(vector.clone());
123 }
124 } else {
125 vectors.push(vector.clone());
126 }
127 }
128
129 let record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
130 mapped_schema.arrow_schema().clone(),
131 vectors,
132 )
133 .context(NewDfRecordBatchSnafu)?;
134 Ok(RecordBatch::from_df_record_batch(
135 mapped_schema.clone(),
136 record_batch,
137 ))
138}
139
140pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
148 let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
149 let mut apply_mapper = false;
150 for column in schema.column_schemas() {
151 if matches!(column.data_type, ConcreteDataType::Json(_)) {
152 new_columns.push(ColumnSchema::new(
153 column.name.clone(),
154 ConcreteDataType::string_datatype(),
155 column.is_nullable(),
156 ));
157 apply_mapper = true;
158 } else {
159 new_columns.push(column.clone());
160 }
161 }
162 (Arc::new(Schema::new(new_columns)), apply_mapper)
163}
164
165impl SendableRecordBatchMapper {
166 pub fn new(
168 inner: SendableRecordBatchStream,
169 mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
170 schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
171 ) -> Self {
172 let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
173 Self {
174 inner,
175 mapper,
176 schema: mapped_schema,
177 apply_mapper,
178 }
179 }
180}
181
182impl RecordBatchStream for SendableRecordBatchMapper {
183 fn name(&self) -> &str {
184 "SendableRecordBatchMapper"
185 }
186
187 fn schema(&self) -> SchemaRef {
188 self.schema.clone()
189 }
190
191 fn output_ordering(&self) -> Option<&[OrderOption]> {
192 self.inner.output_ordering()
193 }
194
195 fn metrics(&self) -> Option<RecordBatchMetrics> {
196 self.inner.metrics()
197 }
198}
199
200impl Stream for SendableRecordBatchMapper {
201 type Item = Result<RecordBatch>;
202
203 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204 if self.apply_mapper {
205 Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
206 opt.map(|result| {
207 result
208 .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
209 })
210 })
211 } else {
212 Pin::new(&mut self.inner).poll_next(cx)
213 }
214 }
215}
216
217pub struct EmptyRecordBatchStream {
220 schema: SchemaRef,
222}
223
224impl EmptyRecordBatchStream {
225 pub fn new(schema: SchemaRef) -> Self {
227 Self { schema }
228 }
229}
230
231impl RecordBatchStream for EmptyRecordBatchStream {
232 fn schema(&self) -> SchemaRef {
233 self.schema.clone()
234 }
235
236 fn output_ordering(&self) -> Option<&[OrderOption]> {
237 None
238 }
239
240 fn metrics(&self) -> Option<RecordBatchMetrics> {
241 None
242 }
243}
244
245impl Stream for EmptyRecordBatchStream {
246 type Item = Result<RecordBatch>;
247
248 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
249 Poll::Ready(None)
250 }
251}
252
253#[derive(Debug, PartialEq)]
254pub struct RecordBatches {
255 schema: SchemaRef,
256 batches: Vec<RecordBatch>,
257}
258
259impl RecordBatches {
260 pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
261 schema: SchemaRef,
262 columns: I,
263 ) -> Result<Self> {
264 let batches = vec![RecordBatch::new(schema.clone(), columns)?];
265 Ok(Self { schema, batches })
266 }
267
268 pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
269 let schema = stream.schema();
270 let batches = stream.try_collect::<Vec<_>>().await?;
271 Ok(Self { schema, batches })
272 }
273
274 #[inline]
275 pub fn empty() -> Self {
276 Self {
277 schema: Arc::new(Schema::new(vec![])),
278 batches: vec![],
279 }
280 }
281
282 pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
283 self.batches.iter()
284 }
285
286 pub fn pretty_print(&self) -> Result<String> {
287 let df_batches = &self
288 .iter()
289 .map(|x| x.df_record_batch().clone())
290 .collect::<Vec<_>>();
291 let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;
292
293 Ok(result.to_string())
294 }
295
296 pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
297 for batch in &batches {
298 ensure!(
299 batch.schema == schema,
300 error::CreateRecordBatchesSnafu {
301 reason: format!(
302 "expect RecordBatch schema equals {:?}, actual: {:?}",
303 schema, batch.schema
304 )
305 }
306 )
307 }
308 Ok(Self { schema, batches })
309 }
310
311 pub fn schema(&self) -> SchemaRef {
312 self.schema.clone()
313 }
314
315 pub fn take(self) -> Vec<RecordBatch> {
316 self.batches
317 }
318
319 pub fn as_stream(&self) -> SendableRecordBatchStream {
320 Box::pin(SimpleRecordBatchStream {
321 inner: RecordBatches {
322 schema: self.schema(),
323 batches: self.batches.clone(),
324 },
325 index: 0,
326 })
327 }
328}
329
330impl IntoIterator for RecordBatches {
331 type Item = RecordBatch;
332 type IntoIter = std::vec::IntoIter<Self::Item>;
333
334 fn into_iter(self) -> Self::IntoIter {
335 self.batches.into_iter()
336 }
337}
338
339pub struct SimpleRecordBatchStream {
340 inner: RecordBatches,
341 index: usize,
342}
343
344impl RecordBatchStream for SimpleRecordBatchStream {
345 fn schema(&self) -> SchemaRef {
346 self.inner.schema()
347 }
348
349 fn output_ordering(&self) -> Option<&[OrderOption]> {
350 None
351 }
352
353 fn metrics(&self) -> Option<RecordBatchMetrics> {
354 None
355 }
356}
357
358impl Stream for SimpleRecordBatchStream {
359 type Item = Result<RecordBatch>;
360
361 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362 Poll::Ready(if self.index < self.inner.batches.len() {
363 let batch = self.inner.batches[self.index].clone();
364 self.index += 1;
365 Some(Ok(batch))
366 } else {
367 None
368 })
369 }
370}
371
372pub struct RecordBatchStreamWrapper<S> {
374 pub schema: SchemaRef,
375 pub stream: S,
376 pub output_ordering: Option<Vec<OrderOption>>,
377 pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
378 pub span: Span,
379}
380
381impl<S> RecordBatchStreamWrapper<S> {
382 pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
384 RecordBatchStreamWrapper {
385 schema,
386 stream,
387 output_ordering: None,
388 metrics: Default::default(),
389 span: Span::current(),
390 }
391 }
392}
393
394impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
395 for RecordBatchStreamWrapper<S>
396{
397 fn name(&self) -> &str {
398 "RecordBatchStreamWrapper"
399 }
400
401 fn schema(&self) -> SchemaRef {
402 self.schema.clone()
403 }
404
405 fn output_ordering(&self) -> Option<&[OrderOption]> {
406 self.output_ordering.as_deref()
407 }
408
409 fn metrics(&self) -> Option<RecordBatchMetrics> {
410 self.metrics.load().as_ref().map(|s| s.as_ref().clone())
411 }
412}
413
414impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
415 type Item = Result<RecordBatch>;
416
417 fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
418 let _entered = self.span.clone().entered();
419 Pin::new(&mut self.stream).poll_next(ctx)
420 }
421}
422
423#[derive(Clone)]
427pub struct QueryMemoryTracker {
428 manager: MemoryManager<CallbackMemoryMetrics>,
429 metrics: CallbackMemoryMetrics,
430 on_exhausted_policy: OnExhaustedPolicy,
431}
432
433impl fmt::Debug for QueryMemoryTracker {
434 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435 f.debug_struct("QueryMemoryTracker")
436 .field("current", &self.current())
437 .field("limit", &self.limit())
438 .field("on_exhausted_policy", &self.on_exhausted_policy)
439 .field("on_update", &self.metrics.has_on_update())
440 .field("on_reject", &self.metrics.has_on_reject())
441 .finish()
442 }
443}
444
445impl QueryMemoryTracker {
446 pub fn builder(
448 limit: usize,
449 on_exhausted_policy: OnExhaustedPolicy,
450 ) -> QueryMemoryTrackerBuilder {
451 QueryMemoryTrackerBuilder {
452 limit,
453 on_exhausted_policy,
454 on_update: None,
455 on_reject: None,
456 }
457 }
458
459 fn new_stream_tracker(&self) -> StreamMemoryTracker {
460 StreamMemoryTracker {
461 tracker: self.clone(),
462 guard: self.manager.try_acquire(0).unwrap(),
463 tracked_bytes: 0,
464 }
465 }
466 pub fn current(&self) -> usize {
468 self.manager.used_bytes() as usize
469 }
470
471 fn limit(&self) -> usize {
472 self.manager.limit_bytes() as usize
473 }
474
475 fn reject_error(
476 &self,
477 current: usize,
478 additional: usize,
479 stream_tracked: usize,
480 ) -> error::Error {
481 let limit = self.limit();
482 let msg = format!(
483 "{} requested, {} used globally ({}%), {} used by this stream, hard limit: {}",
484 ReadableSize(additional as u64),
485 ReadableSize(current as u64),
486 if limit > 0 { current * 100 / limit } else { 0 },
487 ReadableSize(stream_tracked as u64),
488 ReadableSize(limit as u64)
489 );
490 error::ExceedMemoryLimitSnafu { msg }.build()
491 }
492}
493
494pub struct QueryMemoryTrackerBuilder {
496 limit: usize,
497 on_exhausted_policy: OnExhaustedPolicy,
498 on_update: Option<UpdateCallback>,
499 on_reject: Option<RejectCallback>,
500}
501
502impl QueryMemoryTrackerBuilder {
503 pub fn on_update<F>(mut self, on_update: F) -> Self
510 where
511 F: Fn(usize) + Send + Sync + 'static,
512 {
513 self.on_update = Some(Arc::new(on_update));
514 self
515 }
516
517 pub fn on_reject<F>(mut self, on_reject: F) -> Self
523 where
524 F: Fn() + Send + Sync + 'static,
525 {
526 self.on_reject = Some(Arc::new(on_reject));
527 self
528 }
529
530 pub fn build(self) -> QueryMemoryTracker {
532 let metrics = CallbackMemoryMetrics::new(self.on_update, self.on_reject);
533 let manager = MemoryManager::with_granularity(
534 self.limit as u64,
535 PermitGranularity::Kilobyte,
536 metrics.clone(),
537 );
538
539 QueryMemoryTracker {
540 manager,
541 metrics,
542 on_exhausted_policy: self.on_exhausted_policy,
543 }
544 }
545}
546
547struct StreamMemoryTracker {
548 tracker: QueryMemoryTracker,
549 guard: MemoryGuard<CallbackMemoryMetrics>,
550 tracked_bytes: usize,
551}
552
553type MemoryAcquireResult = std::result::Result<(), common_memory_manager::Error>;
554
555impl StreamMemoryTracker {
556 fn try_track(&mut self, additional: usize) -> Result<()> {
557 if self.guard.try_acquire_additional(additional as u64) {
558 self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
559 Ok(())
560 } else {
561 Err(self.reject_error(additional))
562 }
563 }
564
565 async fn track_with_policy(mut self, additional: usize) -> (Self, MemoryAcquireResult) {
566 let result = self
567 .guard
568 .acquire_additional_with_policy(additional as u64, self.tracker.on_exhausted_policy)
569 .await;
570 if result.is_ok() {
571 self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
572 }
573 (self, result)
574 }
575
576 fn reject_error(&self, additional: usize) -> error::Error {
577 let current = self.tracker.current();
578 self.tracker
579 .reject_error(current, additional, self.tracked_bytes)
580 }
581
582 fn wait_error(&self, additional: usize, source: common_memory_manager::Error) -> error::Error {
583 match source {
584 common_memory_manager::Error::MemoryLimitExceeded { .. } => {
585 self.reject_error(additional)
586 }
587 common_memory_manager::Error::MemoryAcquireTimeout { waited, .. } => {
588 let current = self.tracker.current();
589 let limit = self.tracker.limit();
590 let msg = format!(
591 "timed out waiting {:?} for {}, {} used globally ({}%), {} used by this stream, hard limit: {}",
592 waited,
593 ReadableSize(additional as u64),
594 ReadableSize(current as u64),
595 if limit > 0 { current * 100 / limit } else { 0 },
596 ReadableSize(self.tracked_bytes as u64),
597 ReadableSize(limit as u64)
598 );
599 error::ExceedMemoryLimitSnafu { msg }.build()
600 }
601 error => error::ExternalSnafu.into_error(BoxedError::new(error)),
602 }
603 }
604}
605
606type PendingTrackFuture = Pin<
607 Box<dyn Future<Output = (StreamMemoryTracker, RecordBatch, usize, MemoryAcquireResult)> + Send>,
608>;
609
610#[derive(Clone)]
611struct CallbackMemoryMetrics {
612 inner: Arc<CallbackMemoryMetricsInner>,
613}
614
615type UpdateCallback = Arc<dyn Fn(usize) + Send + Sync>;
616type RejectCallback = Arc<dyn Fn() + Send + Sync>;
617
618struct CallbackMemoryMetricsInner {
619 on_update: Option<UpdateCallback>,
620 on_reject: Option<RejectCallback>,
621}
622
623impl CallbackMemoryMetrics {
624 fn new(on_update: Option<UpdateCallback>, on_reject: Option<RejectCallback>) -> Self {
625 Self {
626 inner: Arc::new(CallbackMemoryMetricsInner {
627 on_update,
628 on_reject,
629 }),
630 }
631 }
632
633 fn has_on_update(&self) -> bool {
634 self.inner.on_update.is_some()
635 }
636
637 fn has_on_reject(&self) -> bool {
638 self.inner.on_reject.is_some()
639 }
640}
641
642impl MemoryMetrics for CallbackMemoryMetrics {
643 fn set_limit(&self, _: i64) {}
644
645 fn set_in_use(&self, bytes: i64) {
646 if let Some(callback) = &self.inner.on_update {
647 callback(bytes.max(0) as usize);
648 }
649 }
650
651 fn inc_rejected(&self, _: &str) {
652 if let Some(callback) = &self.inner.on_reject {
653 callback();
654 }
655 }
656}
657
/// A [`RecordBatchStream`] adapter that reserves memory in a shared [`QueryMemoryTracker`]
/// for each batch before yielding it.
///
/// Reservations accumulate for the lifetime of the stream and are released when it is
/// dropped. The stream is always in one of two states:
/// - ready: `tracker` is `Some` and `waiting` is `None`;
/// - waiting: the tracker has been moved into the `waiting` future, so `tracker` is `None`.
pub struct MemoryTrackedStream {
    /// Source of the batches being tracked.
    inner: SendableRecordBatchStream,
    /// This stream's reservation; `None` only while `waiting` owns it.
    tracker: Option<StreamMemoryTracker>,
    /// In-flight blocking reservation for one parked batch (only under the `Wait` policy).
    waiting: Option<PendingTrackFuture>,
}

impl MemoryTrackedStream {
    /// Wraps `inner`, tracking its batches against `tracker`'s shared limit.
    pub fn new(inner: SendableRecordBatchStream, tracker: QueryMemoryTracker) -> Self {
        Self {
            inner,
            tracker: Some(tracker.new_stream_tracker()),
            waiting: None,
        }
    }

    /// Returns the tracker in the ready state.
    ///
    /// # Panics
    /// Panics if the tracker is currently owned by a pending wait.
    fn ready_tracker_mut(&mut self) -> &mut StreamMemoryTracker {
        debug_assert!(
            self.waiting.is_none(),
            "a ready tracker must not coexist with a waiting future"
        );
        self.tracker.as_mut().unwrap()
    }

    /// Moves from the ready state to the waiting state.
    ///
    /// The tracker and `batch` are moved into a future that reserves `additional` bytes
    /// with the configured exhaustion policy.
    fn enter_waiting(&mut self, batch: RecordBatch, additional: usize) {
        debug_assert!(
            self.waiting.is_none(),
            "enter_waiting should only be called from the ready state"
        );
        debug_assert!(
            self.tracker.is_some(),
            "enter_waiting requires a tracker in the ready state"
        );
        let tracker = self.tracker.take().unwrap();
        self.waiting = Some(Self::start_waiting(tracker, batch, additional));
    }

    /// Boxes the blocking reservation for `batch`.
    ///
    /// The future resolves to the tracker, the batch, `additional` and the reservation
    /// outcome, so ownership can be handed back afterwards.
    fn start_waiting(
        tracker: StreamMemoryTracker,
        batch: RecordBatch,
        additional: usize,
    ) -> PendingTrackFuture {
        Box::pin(async move {
            let (tracker, result) = tracker.track_with_policy(additional).await;
            (tracker, batch, additional, result)
        })
    }

    /// Polls the pending reservation.
    ///
    /// On completion it restores the ready state. It then yields the parked batch, or,
    /// if the reservation failed, the error produced by `wait_error`; the batch is dropped
    /// in that case.
    ///
    /// # Panics
    /// Panics if called while not in the waiting state.
    fn poll_waiting(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
        let future = self.waiting.as_mut().unwrap();
        match future.as_mut().poll(cx) {
            Poll::Ready((tracker, batch, additional, result)) => {
                let output = match result {
                    Ok(()) => Ok(batch),
                    Err(error) => Err(tracker.wait_error(additional, error)),
                };
                self.waiting = None;
                self.tracker = Some(tracker);
                Poll::Ready(Some(output))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Reserves `batch.buffer_memory_size()` bytes for `batch`, then yields it.
    ///
    /// A non-blocking reservation is tried first. If it is refused:
    /// - under `Fail`, the rejection error is returned and the batch is dropped;
    /// - under `Wait`, the batch is parked in a waiting future that is polled immediately.
    fn poll_batch(
        &mut self,
        batch: RecordBatch,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let additional = batch.buffer_memory_size();
        let tracker = self.ready_tracker_mut();

        if let Err(error) = tracker.try_track(additional) {
            match tracker.tracker.on_exhausted_policy {
                OnExhaustedPolicy::Fail => return Poll::Ready(Some(Err(error))),
                OnExhaustedPolicy::Wait { .. } => {
                    // The fast-path error is discarded: hand the batch and tracker to a
                    // future that waits for capacity. Poll it right away so that, if it is
                    // pending, it has already seen this task's context.
                    self.enter_waiting(batch, additional);
                    return self.poll_waiting(cx);
                }
            }
        }

        Poll::Ready(Some(Ok(batch)))
    }
}
750
751impl Stream for MemoryTrackedStream {
752 type Item = Result<RecordBatch>;
753
754 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
755 if self.waiting.is_some() {
756 return self.poll_waiting(cx);
757 }
758
759 match Pin::new(&mut self.inner).poll_next(cx) {
760 Poll::Ready(Some(Ok(batch))) => self.poll_batch(batch, cx),
761 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
762 Poll::Ready(None) => Poll::Ready(None),
763 Poll::Pending => Poll::Pending,
764 }
765 }
766
767 fn size_hint(&self) -> (usize, Option<usize>) {
768 self.inner.size_hint()
769 }
770}
771
772impl RecordBatchStream for MemoryTrackedStream {
773 fn schema(&self) -> SchemaRef {
774 self.inner.schema()
775 }
776
777 fn output_ordering(&self) -> Option<&[OrderOption]> {
778 self.inner.output_ordering()
779 }
780
781 fn metrics(&self) -> Option<RecordBatchMetrics> {
782 self.inner.metrics()
783 }
784}
785
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use common_memory_manager::{OnExhaustedPolicy, PermitGranularity};
    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};
    use futures::StreamExt;
    use tokio::time::{sleep, timeout};

    use super::*;

    /// Builds a one-row batch with a single non-nullable string column `payload`
    /// holding `bytes` ASCII characters.
    fn large_string_batch(bytes: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "payload",
            ConcreteDataType::string_datatype(),
            false,
        )]));
        let payload = "x".repeat(bytes);
        let vector: VectorRef = Arc::new(StringVector::from(vec![payload]));
        RecordBatch::new(schema, vec![vector]).unwrap()
    }

    /// Converts `bytes` to kilobyte permits and back, giving the amount the tracker
    /// actually reserves for a request of that size.
    fn aligned_tracked_bytes(bytes: usize) -> usize {
        PermitGranularity::Kilobyte
            .permits_to_bytes(PermitGranularity::Kilobyte.bytes_to_permits(bytes as u64))
            as usize
    }

    #[test]
    fn test_recordbatches_try_from_columns() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        // A string column does not fit an int32 schema.
        let result = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        );
        assert!(result.is_err());

        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
        assert_eq!(r.take(), expected);
    }

    #[test]
    fn test_recordbatches_try_new() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);

        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();

        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();

        // A batch with a different schema must be rejected with a descriptive message.
        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!(
                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
            )
        );

        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
        // The table lines must stay at column 0: they are part of the string literal.
        let expected = "\
+---+-------+
| a | b     |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
        assert_eq!(batches.pretty_print().unwrap(), expected);

        assert_eq!(schema1, batches.schema());
        assert_eq!(vec![batch1], batches.take());
    }

    #[tokio::test]
    async fn test_simple_recordbatch_stream() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column_a, column_b]));

        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();

        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();

        // `as_stream` replays the batches in their original order.
        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
        let stream = recordbatches.as_stream();
        let collected = util::collect(stream).await.unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0], batch1);
        assert_eq!(collected[1], batch2);
    }

    const MB: usize = 1024 * 1024;

    /// Usage from several streams adds up globally and is released when they drop.
    #[test]
    fn test_query_memory_tracker_basic() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());

        let mut stream1 = tracker.new_stream_tracker();
        assert!(stream1.try_track(5 * MB).is_ok());
        assert_eq!(tracker.current(), 5 * MB);

        let mut stream2 = tracker.new_stream_tracker();
        assert!(stream2.try_track(4 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        drop(stream1);
        drop(stream2);
        assert_eq!(tracker.current(), 0);
    }

    /// The global limit is shared, and the rejection message reports per-stream and
    /// global usage.
    #[test]
    fn test_query_memory_tracker_shared_global_limit() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream1 = tracker.new_stream_tracker();
        let mut stream2 = tracker.new_stream_tracker();

        assert!(stream1.try_track(3 * MB).is_ok());
        assert_eq!(tracker.current(), 3 * MB);
        assert!(stream2.try_track(6 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        // 9 MiB used + 2 MiB requested exceeds the 10 MiB limit.
        let err = stream2.try_track(2 * MB).unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("6.0MiB used by this stream"));
        assert!(err_msg.contains("9.0MiB used globally (90%)"));
        assert!(err_msg.contains("hard limit: 10.0MiB"));
        // A rejected request must not change the usage.
        assert_eq!(tracker.current(), 9 * MB);

        drop(stream1);
        assert_eq!(tracker.current(), 6 * MB);
        drop(stream2);
        assert_eq!(tracker.current(), 0);
    }

    /// Requests are admitted up to exactly the limit and rejected beyond it.
    #[test]
    fn test_query_memory_tracker_hard_limit() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        assert!(stream.try_track(9 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        assert!(stream.try_track(2 * MB).is_err());
        assert_eq!(tracker.current(), 9 * MB);

        assert!(stream.try_track(MB).is_ok());
        assert_eq!(tracker.current(), 10 * MB);

        assert!(stream.try_track(MB).is_err());
        assert_eq!(tracker.current(), 10 * MB);

        drop(stream);
        assert_eq!(tracker.current(), 0);
    }

    /// A limit of 0 means unlimited, but usage is still tracked.
    #[test]
    fn test_query_memory_tracker_unlimited() {
        let tracker = Arc::new(QueryMemoryTracker::builder(0, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        assert!(stream.try_track(10 * MB).is_ok());
        assert_eq!(tracker.current(), 10 * MB);
        drop(stream);
        assert_eq!(tracker.current(), 0);
    }

    /// Reservations are rounded up to whole kilobytes (1537 bytes -> 2 KiB).
    #[test]
    fn test_query_memory_tracker_rounds_to_kilobytes() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        assert!(stream.try_track(1_537).is_ok());
        assert_eq!(tracker.current(), 2 * 1024);

        drop(stream);
        assert_eq!(tracker.current(), 0);
    }

    /// Under the `Wait` policy, a stream that cannot reserve memory blocks until another
    /// stream releases its reservation.
    #[tokio::test]
    async fn test_memory_tracked_stream_waits_for_capacity() {
        let tracker = QueryMemoryTracker::builder(
            MB,
            OnExhaustedPolicy::Wait {
                timeout: Duration::from_millis(200),
            },
        )
        .build();
        // Two ~700 KiB batches cannot both fit in the 1 MiB limit.
        let batch = large_string_batch(700 * 1024);
        let expected_bytes = aligned_tracked_bytes(batch.buffer_memory_size());

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);
        assert_eq!(tracker.current(), expected_bytes);

        let stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let waiter = tokio::spawn(async move {
            let mut stream2 = stream2;
            stream2.next().await.unwrap()
        });

        // The second stream must still be blocked while the first holds its reservation.
        sleep(Duration::from_millis(50)).await;
        assert!(!waiter.is_finished());

        // Dropping the first stream frees the memory and lets the waiter proceed.
        drop(stream1);
        let second = waiter.await.unwrap().unwrap();
        assert_eq!(second.num_rows(), 1);
    }

    /// Under the `Wait` policy, a stream that never gets capacity fails with a timeout error.
    #[tokio::test]
    async fn test_memory_tracked_stream_wait_times_out() {
        let tracker = QueryMemoryTracker::builder(
            MB,
            OnExhaustedPolicy::Wait {
                timeout: Duration::from_millis(50),
            },
        )
        .build();
        let batch = large_string_batch(700 * 1024);

        // `stream1` stays alive for the whole test, so its reservation is never released.
        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);

        let mut stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker,
        );
        // The outer timeout only guards the test; the policy's 50 ms timeout should fire first.
        let result = timeout(Duration::from_secs(1), stream2.next())
            .await
            .unwrap();
        let error = result.unwrap().unwrap_err();
        assert!(error.to_string().contains("timed out waiting"));
    }
}
1062}