diff --git a/Cargo.lock b/Cargo.lock
index 98017cb3f3..9f40f5bfdb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3015,6 +3015,7 @@ dependencies = [
  "snafu",
  "sql",
  "table",
+ "test-util",
  "tokio",
  "tokio-stream",
 ]
@@ -3786,6 +3787,22 @@ dependencies = [
  "winapi-util",
 ]
 
+[[package]]
+name = "test-util"
+version = "0.1.0"
+dependencies = [
+ "arrow2",
+ "async-trait",
+ "common-query",
+ "common-recordbatch",
+ "datafusion",
+ "datatypes",
+ "futures",
+ "snafu",
+ "table",
+ "tokio",
+]
+
 [[package]]
 name = "textwrap"
 version = "0.11.0"
diff --git a/Cargo.toml b/Cargo.toml
index 00274c419d..132b677d7d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,4 +23,5 @@ members = [
     "src/store-api",
     "src/table",
     "src/table-engine",
+    "test-util",
 ]
diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml
index ea4a24a20d..797dfccc26 100644
--- a/src/query/Cargo.toml
+++ b/src/query/Cargo.toml
@@ -32,5 +32,6 @@ tokio = "1.0"
 num = "0.4"
 num-traits = "0.2"
 rand = "0.8"
+test-util = { path = "../../test-util" }
 tokio = { version = "1.0", features = ["full"] }
 tokio-stream = "0.1"
diff --git a/src/query/src/datafusion/catalog_adapter.rs b/src/query/src/datafusion/catalog_adapter.rs
index 580ff783c6..aff837b8d7 100644
--- a/src/query/src/datafusion/catalog_adapter.rs
+++ b/src/query/src/datafusion/catalog_adapter.rs
@@ -224,7 +224,7 @@ impl SchemaProvider for SchemaProviderAdapter {
             })?
             .map(|table| {
                 let adapter = TableAdapter::new(table, self.runtime.clone())
-                    .context(error::ConvertTableSnafu)?;
+                    .context(error::TableSchemaMismatchSnafu)?;
                 Ok(Arc::new(adapter) as _)
             })
             .transpose()
diff --git a/src/query/src/datafusion/error.rs b/src/query/src/datafusion/error.rs
index 1d3eafd8a0..5ac3aa3684 100644
--- a/src/query/src/datafusion/error.rs
+++ b/src/query/src/datafusion/error.rs
@@ -39,8 +39,8 @@ pub enum InnerError {
         source: datatypes::error::Error,
     },
 
-    #[snafu(display("Fail to convert table, source: {}", source))]
-    ConvertTable {
+    #[snafu(display("Failed to convert table schema, source: {}", source))]
+    TableSchemaMismatch {
         #[snafu(backtrace)]
         source: table::error::Error,
     },
@@ -54,7 +54,7 @@ impl ErrorExt for InnerError {
             // TODO(yingwen): Further categorize datafusion error.
             Datafusion { .. } => StatusCode::EngineExecuteQuery,
             // This downcast should not fail in usual case.
-            PhysicalPlanDowncast { .. } | ConvertSchema { .. } | ConvertTable { .. } => {
+            PhysicalPlanDowncast { .. } | ConvertSchema { .. } | TableSchemaMismatch { .. } => {
                 StatusCode::Unexpected
             }
             ParseSql { source, .. } => source.status_code(),
diff --git a/src/query/src/datafusion/plan_adapter.rs b/src/query/src/datafusion/plan_adapter.rs
index 7a1e24e606..69852fb7e7 100644
--- a/src/query/src/datafusion/plan_adapter.rs
+++ b/src/query/src/datafusion/plan_adapter.rs
@@ -97,10 +97,10 @@ impl PhysicalPlan for PhysicalPlanAdapter {
                 msg: "Fail to execute physical plan",
             })?;
 
-        Ok(Box::pin(RecordBatchStreamAdapter::new(
-            self.schema.clone(),
-            df_stream,
-        )))
+        Ok(Box::pin(
+            RecordBatchStreamAdapter::try_new(df_stream)
+                .context(error::TableSchemaMismatchSnafu)?,
+        ))
     }
 
     fn as_any(&self) -> &dyn Any {
diff --git a/src/query/tests/my_sum_udaf_example.rs b/src/query/tests/my_sum_udaf_example.rs
index 79ca494c16..226d393370 100644
--- a/src/query/tests/my_sum_udaf_example.rs
+++ b/src/query/tests/my_sum_udaf_example.rs
@@ -2,8 +2,6 @@ use std::fmt::Debug;
 use std::marker::PhantomData;
 use std::sync::Arc;
 
-mod testing_table;
-
 use arc_swap::ArcSwapOption;
 use catalog::memory::{MemoryCatalogList, MemoryCatalogProvider, MemorySchemaProvider};
 use catalog::{CatalogList, SchemaProvider, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
@@ -13,10 +11,11 @@ use common_query::error::Result as QueryResult;
 use common_query::logical_plan::Accumulator;
 use common_query::logical_plan::AggregateFunctionCreator;
 use common_query::prelude::*;
-use common_recordbatch::util;
+use common_recordbatch::{util, RecordBatch};
 use datafusion::arrow_print;
 use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
 use datatypes::prelude::*;
+use datatypes::schema::{ColumnSchema, Schema};
 use datatypes::types::DataTypeBuilder;
 use datatypes::types::PrimitiveType;
 use datatypes::vectors::PrimitiveVector;
@@ -26,8 +25,7 @@ use query::error::Result;
 use query::query_engine::Output;
 use query::QueryEngineFactory;
 use table::TableRef;
-
-use crate::testing_table::TestingTable;
+use test_util::MemTable;
 
 #[derive(Debug, Default)]
 struct MySumAccumulator<T, SumT>
@@ -217,10 +215,15 @@ where
     let table_name = format!("{}_numbers", std::any::type_name::<T>());
     let column_name = format!("{}_number", std::any::type_name::<T>());
 
-    let testing_table = Arc::new(TestingTable::new(
-        &column_name,
-        Arc::new(PrimitiveVector::<T>::from_vec(numbers.clone())),
-    ));
+    let column_schemas = vec![ColumnSchema::new(
+        column_name.clone(),
+        T::build_data_type(),
+        true,
+    )];
+    let schema = Arc::new(Schema::new(column_schemas.clone()));
+    let column: VectorRef = Arc::new(PrimitiveVector::<T>::from_vec(numbers));
+    let recordbatch = RecordBatch::new(schema, vec![column]).unwrap();
+    let testing_table = Arc::new(MemTable::new(recordbatch));
 
     let factory = new_query_engine_factory(table_name.clone(), testing_table);
     let engine = factory.query_engine();
diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs
index 253e4b80db..de3fc0bdfe 100644
--- a/src/query/tests/query_engine_test.rs
+++ b/src/query/tests/query_engine_test.rs
@@ -1,5 +1,4 @@
 mod pow;
-mod testing_table;
 
 use std::sync::Arc;
 
@@ -14,8 +13,9 @@ use datafusion::field_util::SchemaExt;
 use datafusion::logical_plan::LogicalPlanBuilder;
 use datatypes::for_all_ordered_primitive_types;
 use datatypes::prelude::*;
+use datatypes::schema::{ColumnSchema, Schema};
 use datatypes::types::DataTypeBuilder;
-use datatypes::vectors::PrimitiveVector;
+use datatypes::vectors::{Float32Vector, Float64Vector, PrimitiveVector, UInt32Vector};
 use num::NumCast;
 use query::error::Result;
 use query::plan::LogicalPlan;
@@ -23,10 +23,9 @@ use query::query_engine::{Output, QueryEngineFactory};
 use query::QueryEngine;
 use rand::Rng;
 use table::table::adapter::DfTableProviderAdapter;
-use table::table::numbers::NumbersTable;
+use test_util::MemTable;
 
 use crate::pow::pow;
-use crate::testing_table::TestingTable;
 
 #[tokio::test]
 async fn test_datafusion_query_engine() -> Result<()> {
@@ -35,8 +34,19 @@
     let factory = QueryEngineFactory::new(catalog_list);
     let engine = factory.query_engine();
 
+    let column_schemas = vec![ColumnSchema::new(
+        "number",
+        ConcreteDataType::uint32_datatype(),
+        false,
+    )];
+    let schema = Arc::new(Schema::new(column_schemas));
+    let columns: Vec<VectorRef> = vec![Arc::new(UInt32Vector::from_slice(
+        (0..100).collect::<Vec<u32>>(),
+    ))];
+    let recordbatch = RecordBatch::new(schema, columns).unwrap();
+    let table = Arc::new(MemTable::new(recordbatch));
+
     let limit = 10;
-    let table = Arc::new(NumbersTable::default());
     let table_provider = Arc::new(DfTableProviderAdapter::new(table.clone()));
     let plan = LogicalPlan::DfPlan(
         LogicalPlanBuilder::scan("numbers", table_provider, None)
@@ -126,47 +136,73 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
     let catalog_provider = Arc::new(MemoryCatalogProvider::new());
     let catalog_list = Arc::new(MemoryCatalogList::default());
 
-    macro_rules! create_testing_table {
+    // create table with ordered primitives, and all columns' length are even
+    let mut column_schemas = vec![];
+    let mut columns = vec![];
+    macro_rules! create_even_number_table {
         ([], $( { $T:ty } ),*) => {
             $(
                 let mut rng = rand::thread_rng();
 
-                let table_name = format!("{}_number_even", std::any::type_name::<$T>());
-                let column_name = table_name.clone();
-                let numbers = (1..=100).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
-                let table = Arc::new(TestingTable::new(
-                    &column_name,
-                    Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())),
-                ));
-                schema_provider.register_table(table_name, table).unwrap();
+                let column_name = format!("{}_number_even", std::any::type_name::<$T>());
+                let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
+                column_schemas.push(column_schema);
 
-                let table_name = format!("{}_number_odd", std::any::type_name::<$T>());
-                let column_name = table_name.clone();
-                let numbers = (1..=99).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
-                let table = Arc::new(TestingTable::new(
-                    &column_name,
-                    Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())),
-                ));
-                schema_provider.register_table(table_name, table).unwrap();
+                let numbers = (1..=100).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
+                let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
+                columns.push(column);
             )*
         }
     }
-    for_all_ordered_primitive_types! { create_testing_table }
+    for_all_ordered_primitive_types! { create_even_number_table }
 
-    let table = Arc::new(TestingTable::new(
-        "f32_number",
-        Arc::new(PrimitiveVector::<f32>::from_vec(vec![1.0f32, 2.0, 3.0])),
-    ));
+    let schema = Arc::new(Schema::new(column_schemas.clone()));
+    let recordbatch = RecordBatch::new(schema, columns).unwrap();
+    let even_number_table = Arc::new(MemTable::new(recordbatch));
     schema_provider
-        .register_table("f32_number".to_string(), table)
+        .register_table("even_numbers".to_string(), even_number_table)
         .unwrap();
 
-    let table = Arc::new(TestingTable::new(
-        "f64_number",
-        Arc::new(PrimitiveVector::<f64>::from_vec(vec![1.0f64, 2.0, 3.0])),
-    ));
+    // create table with ordered primitives, and all columns' length are odd
+    let mut column_schemas = vec![];
+    let mut columns = vec![];
+    macro_rules! create_odd_number_table {
+        ([], $( { $T:ty } ),*) => {
+            $(
+                let mut rng = rand::thread_rng();
+
+                let column_name = format!("{}_number_odd", std::any::type_name::<$T>());
+                let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
+                column_schemas.push(column_schema);
+
+                let numbers = (1..=99).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
+                let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
+                columns.push(column);
+            )*
+        }
+    }
+    for_all_ordered_primitive_types! { create_odd_number_table }
+
+    let schema = Arc::new(Schema::new(column_schemas.clone()));
+    let recordbatch = RecordBatch::new(schema, columns).unwrap();
+    let odd_number_table = Arc::new(MemTable::new(recordbatch));
     schema_provider
-        .register_table("f64_number".to_string(), table)
+        .register_table("odd_numbers".to_string(), odd_number_table)
+        .unwrap();
+
+    // create table with floating numbers
+    let column_schemas = vec![
+        ColumnSchema::new("f32_number", ConcreteDataType::float32_datatype(), true),
+        ColumnSchema::new("f64_number", ConcreteDataType::float64_datatype(), true),
+    ];
+    let f32_numbers: VectorRef = Arc::new(Float32Vector::from_vec(vec![1.0f32, 2.0, 3.0]));
+    let f64_numbers: VectorRef = Arc::new(Float64Vector::from_vec(vec![1.0f64, 2.0, 3.0]));
+    let columns = vec![f32_numbers, f64_numbers];
+    let schema = Arc::new(Schema::new(column_schemas));
+    let recordbatch = RecordBatch::new(schema, columns).unwrap();
+    let float_number_table = Arc::new(MemTable::new(recordbatch));
+    schema_provider
+        .register_table("float_numbers".to_string(), float_number_table)
         .unwrap();
 
     catalog_provider.register_schema(DEFAULT_SCHEMA_NAME, schema_provider);
@@ -176,12 +212,15 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
     factory.query_engine().clone()
 }
 
-async fn get_numbers_from_table<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Vec<T>
+async fn get_numbers_from_table<'s, T>(
+    column_name: &'s str,
+    table_name: &'s str,
+    engine: Arc<dyn QueryEngine>,
+) -> Vec<T>
 where
     T: Primitive + DataTypeBuilder,
     for<'a> T: Scalar<RefType<'a> = T>,
 {
-    let column_name = table_name;
     let sql = format!("SELECT {} FROM {}", column_name, table_name);
     let plan = engine.sql_to_plan(&sql).unwrap();
 
@@ -204,17 +243,17 @@
     let engine = create_query_engine();
 
-    test_median_failed::<f32>("f32_number", engine.clone()).await?;
-    test_median_failed::<f64>("f64_number", engine.clone()).await?;
+    test_median_failed::<f32>("f32_number", "float_numbers", engine.clone()).await?;
+    test_median_failed::<f64>("f64_number", "float_numbers", engine.clone()).await?;
 
     macro_rules! test_median {
         ([], $( { $T:ty } ),*) => {
             $(
-                let table_name = format!("{}_number_even", std::any::type_name::<$T>());
-                test_median_success::<$T>(&table_name, engine.clone()).await?;
+                let column_name = format!("{}_number_even", std::any::type_name::<$T>());
+                test_median_success::<$T>(&column_name, "even_numbers", engine.clone()).await?;
 
-                let table_name = format!("{}_number_odd", std::any::type_name::<$T>());
-                test_median_success::<$T>(&table_name, engine.clone()).await?;
+                let column_name = format!("{}_number_odd", std::any::type_name::<$T>());
+                test_median_success::<$T>(&column_name, "odd_numbers", engine.clone()).await?;
             )*
         }
     }
 
@@ -222,12 +261,18 @@
     Ok(())
 }
 
-async fn test_median_success<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Result<()>
+async fn test_median_success<T>(
+    column_name: &str,
+    table_name: &str,
+    engine: Arc<dyn QueryEngine>,
+) -> Result<()>
 where
     T: Primitive + Ord + DataTypeBuilder,
     for<'a> T: Scalar<RefType<'a> = T>,
 {
-    let result = execute_median(table_name, engine.clone()).await.unwrap();
+    let result = execute_median(column_name, table_name, engine.clone())
+        .await
+        .unwrap();
     assert_eq!(1, result.len());
     assert_eq!(result[0].df_recordbatch.num_columns(), 1);
     assert_eq!(1, result[0].schema.arrow_schema().fields().len());
@@ -240,7 +285,7 @@ where
     assert_eq!(1, v.len());
     let median = v.get(0);
 
-    let mut numbers = get_numbers_from_table::<T>(table_name, engine.clone()).await;
+    let mut numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
     numbers.sort();
     let len = numbers.len();
     let expected_median: Value = if len % 2 == 1 {
@@ -255,11 +300,15 @@ where
     Ok(())
 }
 
-async fn test_median_failed<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Result<()>
+async fn test_median_failed<T>(
+    column_name: &str,
+    table_name: &str,
+    engine: Arc<dyn QueryEngine>,
+) -> Result<()>
 where
     T: Primitive + DataTypeBuilder,
 {
-    let result = execute_median(table_name, engine).await;
+    let result = execute_median(column_name, table_name, engine).await;
     assert!(result.is_err());
     let error = result.unwrap_err();
     assert!(error.to_string().contains(&format!(
@@ -269,11 +318,11 @@ where
-async fn execute_median(
-    table_name: &str,
+async fn execute_median<'a>(
+    column_name: &'a str,
+    table_name: &'a str,
     engine: Arc<dyn QueryEngine>,
 ) -> RecordResult<Vec<RecordBatch>> {
-    let column_name = table_name;
     let sql = format!(
         "select MEDIAN({}) as median from {}",
         column_name, table_name
diff --git a/src/query/tests/testing_table.rs b/src/query/tests/testing_table.rs
deleted file mode 100644
index 6965af1908..0000000000
--- a/src/query/tests/testing_table.rs
+++ /dev/null
@@ -1,73 +0,0 @@
-use std::any::Any;
-use std::pin::Pin;
-use std::sync::Arc;
-
-use common_query::prelude::Expr;
-use common_recordbatch::error::Result as RecordBatchResult;
-use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
-use datatypes::prelude::VectorRef;
-use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
-use futures::task::{Context, Poll};
-use futures::Stream;
-use table::error::Result;
-use table::Table;
-
-#[derive(Debug, Clone)]
-pub struct TestingTable {
-    records: RecordBatch,
-}
-
-impl TestingTable {
-    pub fn new(column_name: &str, values: VectorRef) -> Self {
-        let column_schemas = vec![ColumnSchema::new(column_name, values.data_type(), false)];
-        let schema = Arc::new(Schema::new(column_schemas));
-        Self {
-            records: RecordBatch::new(schema, vec![values]).unwrap(),
-        }
-    }
-}
-
-#[async_trait::async_trait]
-impl Table for TestingTable {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.records.schema.clone()
-    }
-
-    async fn scan(
-        &self,
-        _projection: &Option<Vec<usize>>,
-        _filters: &[Expr],
-        _limit: Option<usize>,
-    ) -> Result<SendableRecordBatchStream> {
-        Ok(Box::pin(TestingRecordsStream {
-            schema: self.records.schema.clone(),
-            records: Some(self.records.clone()),
-        }))
-    }
-}
-
-impl RecordBatchStream for TestingRecordsStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
-struct TestingRecordsStream {
-    schema: SchemaRef,
-    records: Option<RecordBatch>,
-}
-
-impl Stream for TestingRecordsStream {
-    type Item = RecordBatchResult<RecordBatch>;
-
-    fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.records.take() {
-            Some(records) => Poll::Ready(Some(Ok(records))),
-            None => Poll::Ready(None),
-        }
-    }
-}
diff --git a/src/table/src/error.rs b/src/table/src/error.rs
index 0fb7bb77a8..08aab5b2a9 100644
--- a/src/table/src/error.rs
+++ b/src/table/src/error.rs
@@ -36,14 +36,27 @@
         source: ArrowError,
         backtrace: Backtrace,
     },
+
+    #[snafu(display("Failed to convert Arrow schema, source: {}", source))]
+    SchemaConversion {
+        source: datatypes::error::Error,
+        backtrace: Backtrace,
+    },
+
+    #[snafu(display("Table projection error, source: {}", source))]
+    TableProjection {
+        source: ArrowError,
+        backtrace: Backtrace,
+    },
 }
 
 impl ErrorExt for InnerError {
     fn status_code(&self) -> StatusCode {
         match self {
-            InnerError::Datafusion { .. } | InnerError::PollStream { .. } => {
-                StatusCode::EngineExecuteQuery
-            }
+            InnerError::Datafusion { .. }
+            | InnerError::PollStream { .. }
+            | InnerError::SchemaConversion { .. }
+            | InnerError::TableProjection { .. } => StatusCode::EngineExecuteQuery,
             InnerError::MissingColumn { .. } => StatusCode::InvalidArguments,
             InnerError::ExecuteRepeatedly { .. } => StatusCode::Unexpected,
         }
diff --git a/src/table/src/table/adapter.rs b/src/table/src/table/adapter.rs
index 60e329a845..4d4c71b0f1 100644
--- a/src/table/src/table/adapter.rs
+++ b/src/table/src/table/adapter.rs
@@ -25,8 +25,8 @@ use datafusion::physical_plan::{
 };
 use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
 use datatypes::arrow::error::{ArrowError, Result as ArrowResult};
-use datatypes::schema::SchemaRef as TableSchemaRef;
 use datatypes::schema::SchemaRef;
+use datatypes::schema::{Schema, SchemaRef as TableSchemaRef};
 use futures::Stream;
 use snafu::prelude::*;
 
@@ -215,10 +215,7 @@ impl Table for TableAdapter {
             .await
             .context(error::DatafusionSnafu)?;
 
-        Ok(Box::pin(RecordBatchStreamAdapter::new(
-            self.schema.clone(),
-            df_stream,
-        )))
+        Ok(Box::pin(RecordBatchStreamAdapter::try_new(df_stream)?))
     }
 
     fn supports_filter_pushdown(&self, filter: &Expr) -> Result<FilterPushDownType> {
@@ -278,8 +275,10 @@ pub struct RecordBatchStreamAdapter {
 }
 
 impl RecordBatchStreamAdapter {
-    pub fn new(schema: SchemaRef, stream: DfSendableRecordBatchStream) -> Self {
-        Self { schema, stream }
+    pub fn try_new(stream: DfSendableRecordBatchStream) -> Result<Self> {
+        let schema =
+            Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
+        Ok(Self { schema, stream })
     }
 }
diff --git a/test-util/Cargo.toml b/test-util/Cargo.toml
new file mode 100644
index 0000000000..49fe6ee31c
--- /dev/null
+++ b/test-util/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "test-util"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies.arrow]
+package = "arrow2"
+version = "0.10"
+features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "ahash", "compute", "serde_types"]
+
+[dependencies]
+async-trait = "0.1"
+common-query = { path = "../src/common/query" }
+common-recordbatch = { path = "../src/common/recordbatch" }
+datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = ["simd"] }
+datatypes = { path = "../src/datatypes" }
+futures = "0.3"
+snafu = { version = "0.7", features = ["backtraces"] }
+table = { path = "../src/table" }
+tokio = { version = "1.20", features = ["full"] }
diff --git a/test-util/src/lib.rs b/test-util/src/lib.rs
new file mode 100644
index 0000000000..135489fa90
--- /dev/null
+++ b/test-util/src/lib.rs
@@ -0,0 +1,3 @@
+mod memtable;
+
+pub use memtable::MemTable;
diff --git a/test-util/src/memtable.rs b/test-util/src/memtable.rs
new file mode 100644
index 0000000000..08066eb3ff
--- /dev/null
+++ b/test-util/src/memtable.rs
@@ -0,0 +1,172 @@
+use std::any::Any;
+use std::pin::Pin;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use common_query::prelude::Expr;
+use common_recordbatch::error::Result as RecordBatchResult;
+use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
+use datatypes::schema::{Schema, SchemaRef};
+use futures::task::{Context, Poll};
+use futures::Stream;
+use snafu::prelude::*;
+use table::error::{Result, SchemaConversionSnafu, TableProjectionSnafu};
+use table::Table;
+
+#[derive(Debug, Clone)]
+pub struct MemTable {
+    recordbatch: RecordBatch,
+}
+
+impl MemTable {
+    pub fn new(recordbatch: RecordBatch) -> Self {
+        Self { recordbatch }
+    }
+}
+
+#[async_trait]
+impl Table for MemTable {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.recordbatch.schema.clone()
+    }
+
+    async fn scan(
+        &self,
+        projection: &Option<Vec<usize>>,
+        _filters: &[Expr],
+        limit: Option<usize>,
+    ) -> Result<SendableRecordBatchStream> {
+        let df_recordbatch = if let Some(indices) = projection {
+            self.recordbatch
+                .df_recordbatch
+                .project(indices)
+                .context(TableProjectionSnafu)?
+        } else {
+            self.recordbatch.df_recordbatch.clone()
+        };
+
+        let rows = df_recordbatch.num_rows();
+        let limit = if let Some(limit) = limit {
+            limit.min(rows)
+        } else {
+            rows
+        };
+        let df_recordbatch = df_recordbatch.slice(0, limit);
+
+        let recordbatch = RecordBatch {
+            schema: Arc::new(
+                Schema::try_from(df_recordbatch.schema().clone()).context(SchemaConversionSnafu)?,
+            ),
+            df_recordbatch,
+        };
+        Ok(Box::pin(MemtableStream {
+            schema: recordbatch.schema.clone(),
+            recordbatch: Some(recordbatch),
+        }))
+    }
+}
+
+impl RecordBatchStream for MemtableStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+struct MemtableStream {
+    schema: SchemaRef,
+    recordbatch: Option<RecordBatch>,
+}
+
+impl Stream for MemtableStream {
+    type Item = RecordBatchResult<RecordBatch>;
+
+    fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match self.recordbatch.take() {
+            Some(records) => Poll::Ready(Some(Ok(records))),
+            None => Poll::Ready(None),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use common_recordbatch::util;
+    use datatypes::prelude::*;
+    use datatypes::schema::ColumnSchema;
+    use datatypes::vectors::{Int32Vector, StringVector};
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_scan_with_projection() {
+        let table = build_testing_table();
+
+        let scan_stream = table.scan(&Some(vec![1]), &[], None).await.unwrap();
+        let recordbatch = util::collect(scan_stream).await.unwrap();
+        assert_eq!(1, recordbatch.len());
+        let columns = recordbatch[0].df_recordbatch.columns();
+        assert_eq!(1, columns.len());
+
+        let string_column = VectorHelper::try_into_vector(&columns[0]).unwrap();
+        let string_column = string_column
+            .as_any()
+            .downcast_ref::<StringVector>()
+            .unwrap();
+        let string_column = string_column.iter_data().flatten().collect::<Vec<&str>>();
+        assert_eq!(vec!["hello", "greptime"], string_column);
+    }
+
+    #[tokio::test]
+    async fn test_scan_with_limit() {
+        let table = build_testing_table();
+
+        let scan_stream = table.scan(&None, &[], Some(2)).await.unwrap();
+        let recordbatch = util::collect(scan_stream).await.unwrap();
+        assert_eq!(1, recordbatch.len());
+        let columns = recordbatch[0].df_recordbatch.columns();
+        assert_eq!(2, columns.len());
+
+        let i32_column = VectorHelper::try_into_vector(&columns[0]).unwrap();
+        let i32_column = i32_column.as_any().downcast_ref::<Int32Vector>().unwrap();
+        let i32_column = i32_column.iter_data().flatten().collect::<Vec<i32>>();
+        assert_eq!(vec![-100], i32_column);
+
+        let string_column = VectorHelper::try_into_vector(&columns[1]).unwrap();
+        let string_column = string_column
+            .as_any()
+            .downcast_ref::<StringVector>()
+            .unwrap();
+        let string_column = string_column.iter_data().flatten().collect::<Vec<&str>>();
+        assert_eq!(vec!["hello"], string_column);
+    }
+
+    fn build_testing_table() -> MemTable {
+        let i32_column_schema =
+            ColumnSchema::new("i32_numbers", ConcreteDataType::int32_datatype(), true);
+        let string_column_schema =
+            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true);
+        let column_schemas = vec![i32_column_schema, string_column_schema];
+
+        let schema = Arc::new(Schema::new(column_schemas));
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(Int32Vector::from(vec![
+                Some(-100),
+                None,
+                Some(1),
+                Some(100),
+            ])),
+            Arc::new(StringVector::from(vec![
+                Some("hello"),
+                None,
+                Some("greptime"),
+                None,
+            ])),
+        ];
+        let recordbatch = RecordBatch::new(schema, columns).unwrap();
+        MemTable::new(recordbatch)
+    }
+}