Mirror of https://github.com/GreptimeTeam/greptimedb.git, synced 2026-01-05 21:02:58 +00:00
feat: extract MemTable to ease testing (#133)
* feat: memtable backed by DataFusion to ease testing
* move test utility codes out of src folder
* Implement our own MemTable because DataFusion's MemTable does not support limit; and replace the original testing numbers table.
* fix: address PR comments
* fix: "testutil" -> "test-util"
* roll back "NumbersTable"

Co-authored-by: luofucong <luofucong@greptime.com>
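Before the diff itself, a minimal usage sketch of the new test-util MemTable, assembled from the test code in this commit; only the test name scan_memtable_with_limit is invented here, every other name appears in the diff below. The point of the custom implementation is the limit argument to scan, which DataFusion's own MemTable does not support.

use std::sync::Arc;

use common_recordbatch::{util, RecordBatch};
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::UInt32Vector;
use test_util::MemTable;

// Hypothetical test name; the construction mirrors the new
// test_datafusion_query_engine code in this commit.
#[tokio::test]
async fn scan_memtable_with_limit() {
    let column_schemas = vec![ColumnSchema::new(
        "number",
        ConcreteDataType::uint32_datatype(),
        false,
    )];
    let schema = Arc::new(Schema::new(column_schemas));
    let columns: Vec<VectorRef> = vec![Arc::new(UInt32Vector::from_slice(
        (0..100).collect::<Vec<_>>(),
    ))];
    let recordbatch = RecordBatch::new(schema, columns).unwrap();
    let table = MemTable::new(recordbatch);

    // scan(projection, filters, limit): the returned batch is truncated to `limit` rows.
    let stream = table.scan(&None, &[], Some(10)).await.unwrap();
    let batches = util::collect(stream).await.unwrap();
    assert_eq!(1, batches.len());
    assert_eq!(10, batches[0].df_recordbatch.num_rows());
}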
Cargo.lock (generated, 17 lines changed)

@@ -3015,6 +3015,7 @@ dependencies = [
  "snafu",
  "sql",
  "table",
+ "test-util",
  "tokio",
  "tokio-stream",
 ]
@@ -3786,6 +3787,22 @@ dependencies = [
  "winapi-util",
 ]

+[[package]]
+name = "test-util"
+version = "0.1.0"
+dependencies = [
+ "arrow2",
+ "async-trait",
+ "common-query",
+ "common-recordbatch",
+ "datafusion",
+ "datatypes",
+ "futures",
+ "snafu",
+ "table",
+ "tokio",
+]
+
 [[package]]
 name = "textwrap"
 version = "0.11.0"
@@ -23,4 +23,5 @@ members = [
     "src/store-api",
     "src/table",
     "src/table-engine",
+    "test-util",
 ]
@@ -32,5 +32,6 @@ tokio = "1.0"
 num = "0.4"
 num-traits = "0.2"
 rand = "0.8"
+test-util = { path = "../../test-util" }
 tokio = { version = "1.0", features = ["full"] }
 tokio-stream = "0.1"
@@ -224,7 +224,7 @@ impl SchemaProvider for SchemaProviderAdapter {
             })?
             .map(|table| {
                 let adapter = TableAdapter::new(table, self.runtime.clone())
-                    .context(error::ConvertTableSnafu)?;
+                    .context(error::TableSchemaMismatchSnafu)?;
                 Ok(Arc::new(adapter) as _)
             })
             .transpose()
@@ -39,8 +39,8 @@ pub enum InnerError {
         source: datatypes::error::Error,
     },

-    #[snafu(display("Fail to convert table, source: {}", source))]
-    ConvertTable {
+    #[snafu(display("Failed to convert table schema, source: {}", source))]
+    TableSchemaMismatch {
         #[snafu(backtrace)]
         source: table::error::Error,
     },
@@ -54,7 +54,7 @@ impl ErrorExt for InnerError {
             // TODO(yingwen): Further categorize datafusion error.
             Datafusion { .. } => StatusCode::EngineExecuteQuery,
             // This downcast should not fail in usual case.
-            PhysicalPlanDowncast { .. } | ConvertSchema { .. } | ConvertTable { .. } => {
+            PhysicalPlanDowncast { .. } | ConvertSchema { .. } | TableSchemaMismatch { .. } => {
                 StatusCode::Unexpected
             }
             ParseSql { source, .. } => source.status_code(),
@@ -97,10 +97,10 @@ impl PhysicalPlan for PhysicalPlanAdapter {
                 msg: "Fail to execute physical plan",
             })?;

-        Ok(Box::pin(RecordBatchStreamAdapter::new(
-            self.schema.clone(),
-            df_stream,
-        )))
+        Ok(Box::pin(
+            RecordBatchStreamAdapter::try_new(df_stream)
+                .context(error::TableSchemaMismatchSnafu)?,
+        ))
     }

     fn as_any(&self) -> &dyn Any {
@@ -2,8 +2,6 @@ use std::fmt::Debug;
 use std::marker::PhantomData;
 use std::sync::Arc;

-mod testing_table;
-
 use arc_swap::ArcSwapOption;
 use catalog::memory::{MemoryCatalogList, MemoryCatalogProvider, MemorySchemaProvider};
 use catalog::{CatalogList, SchemaProvider, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
@@ -13,10 +11,11 @@ use common_query::error::Result as QueryResult;
 use common_query::logical_plan::Accumulator;
 use common_query::logical_plan::AggregateFunctionCreator;
 use common_query::prelude::*;
-use common_recordbatch::util;
+use common_recordbatch::{util, RecordBatch};
 use datafusion::arrow_print;
 use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
 use datatypes::prelude::*;
+use datatypes::schema::{ColumnSchema, Schema};
 use datatypes::types::DataTypeBuilder;
 use datatypes::types::PrimitiveType;
 use datatypes::vectors::PrimitiveVector;
@@ -26,8 +25,7 @@ use query::error::Result;
 use query::query_engine::Output;
 use query::QueryEngineFactory;
 use table::TableRef;
-
-use crate::testing_table::TestingTable;
+use test_util::MemTable;

 #[derive(Debug, Default)]
 struct MySumAccumulator<T, SumT>
@@ -217,10 +215,15 @@ where
     let table_name = format!("{}_numbers", std::any::type_name::<T>());
     let column_name = format!("{}_number", std::any::type_name::<T>());

-    let testing_table = Arc::new(TestingTable::new(
-        &column_name,
-        Arc::new(PrimitiveVector::<T>::from_vec(numbers.clone())),
-    ));
+    let column_schemas = vec![ColumnSchema::new(
+        column_name.clone(),
+        T::build_data_type(),
+        true,
+    )];
+    let schema = Arc::new(Schema::new(column_schemas.clone()));
+    let column: VectorRef = Arc::new(PrimitiveVector::<T>::from_vec(numbers));
+    let recordbatch = RecordBatch::new(schema, vec![column]).unwrap();
+    let testing_table = Arc::new(MemTable::new(recordbatch));

     let factory = new_query_engine_factory(table_name.clone(), testing_table);
     let engine = factory.query_engine();
@@ -1,5 +1,4 @@
 mod pow;
-mod testing_table;

 use std::sync::Arc;

@@ -14,8 +13,9 @@ use datafusion::field_util::SchemaExt;
 use datafusion::logical_plan::LogicalPlanBuilder;
 use datatypes::for_all_ordered_primitive_types;
 use datatypes::prelude::*;
+use datatypes::schema::{ColumnSchema, Schema};
 use datatypes::types::DataTypeBuilder;
-use datatypes::vectors::PrimitiveVector;
+use datatypes::vectors::{Float32Vector, Float64Vector, PrimitiveVector, UInt32Vector};
 use num::NumCast;
 use query::error::Result;
 use query::plan::LogicalPlan;
@@ -23,10 +23,9 @@ use query::query_engine::{Output, QueryEngineFactory};
 use query::QueryEngine;
 use rand::Rng;
 use table::table::adapter::DfTableProviderAdapter;
-use table::table::numbers::NumbersTable;
+use test_util::MemTable;

 use crate::pow::pow;
-use crate::testing_table::TestingTable;

 #[tokio::test]
 async fn test_datafusion_query_engine() -> Result<()> {
@@ -35,8 +34,19 @@ async fn test_datafusion_query_engine() -> Result<()> {
     let factory = QueryEngineFactory::new(catalog_list);
     let engine = factory.query_engine();

+    let column_schemas = vec![ColumnSchema::new(
+        "number",
+        ConcreteDataType::uint32_datatype(),
+        false,
+    )];
+    let schema = Arc::new(Schema::new(column_schemas));
+    let columns: Vec<VectorRef> = vec![Arc::new(UInt32Vector::from_slice(
+        (0..100).collect::<Vec<_>>(),
+    ))];
+    let recordbatch = RecordBatch::new(schema, columns).unwrap();
+    let table = Arc::new(MemTable::new(recordbatch));
+
     let limit = 10;
-    let table = Arc::new(NumbersTable::default());
     let table_provider = Arc::new(DfTableProviderAdapter::new(table.clone()));
     let plan = LogicalPlan::DfPlan(
         LogicalPlanBuilder::scan("numbers", table_provider, None)
@@ -126,47 +136,73 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
     let catalog_provider = Arc::new(MemoryCatalogProvider::new());
     let catalog_list = Arc::new(MemoryCatalogList::default());

-    macro_rules! create_testing_table {
+    // create table with ordered primitives, and all columns' length are even
+    let mut column_schemas = vec![];
+    let mut columns = vec![];
+    macro_rules! create_even_number_table {
         ([], $( { $T:ty } ),*) => {
             $(
                 let mut rng = rand::thread_rng();

-                let table_name = format!("{}_number_even", std::any::type_name::<$T>());
-                let column_name = table_name.clone();
-                let numbers = (1..=100).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
-                let table = Arc::new(TestingTable::new(
-                    &column_name,
-                    Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())),
-                ));
-                schema_provider.register_table(table_name, table).unwrap();
+                let column_name = format!("{}_number_even", std::any::type_name::<$T>());
+                let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
+                column_schemas.push(column_schema);

-                let table_name = format!("{}_number_odd", std::any::type_name::<$T>());
-                let column_name = table_name.clone();
-                let numbers = (1..=99).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
-                let table = Arc::new(TestingTable::new(
-                    &column_name,
-                    Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())),
-                ));
-                schema_provider.register_table(table_name, table).unwrap();
+                let numbers = (1..=100).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
+                let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
+                columns.push(column);
             )*
         }
     }
-    for_all_ordered_primitive_types! { create_testing_table }
+    for_all_ordered_primitive_types! { create_even_number_table }

-    let table = Arc::new(TestingTable::new(
-        "f32_number",
-        Arc::new(PrimitiveVector::<f32>::from_vec(vec![1.0f32, 2.0, 3.0])),
-    ));
+    let schema = Arc::new(Schema::new(column_schemas.clone()));
+    let recordbatch = RecordBatch::new(schema, columns).unwrap();
+    let even_number_table = Arc::new(MemTable::new(recordbatch));
     schema_provider
-        .register_table("f32_number".to_string(), table)
+        .register_table("even_numbers".to_string(), even_number_table)
         .unwrap();

-    let table = Arc::new(TestingTable::new(
-        "f64_number",
-        Arc::new(PrimitiveVector::<f64>::from_vec(vec![1.0f64, 2.0, 3.0])),
-    ));
+    // create table with ordered primitives, and all columns' length are odd
+    let mut column_schemas = vec![];
+    let mut columns = vec![];
+    macro_rules! create_odd_number_table {
+        ([], $( { $T:ty } ),*) => {
+            $(
+                let mut rng = rand::thread_rng();
+
+                let column_name = format!("{}_number_odd", std::any::type_name::<$T>());
+                let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
+                column_schemas.push(column_schema);
+
+                let numbers = (1..=99).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
+                let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
+                columns.push(column);
+            )*
+        }
+    }
+    for_all_ordered_primitive_types! { create_odd_number_table }
+
+    let schema = Arc::new(Schema::new(column_schemas.clone()));
+    let recordbatch = RecordBatch::new(schema, columns).unwrap();
+    let odd_number_table = Arc::new(MemTable::new(recordbatch));
     schema_provider
-        .register_table("f64_number".to_string(), table)
+        .register_table("odd_numbers".to_string(), odd_number_table)
         .unwrap();

+    // create table with floating numbers
+    let column_schemas = vec![
+        ColumnSchema::new("f32_number", ConcreteDataType::float32_datatype(), true),
+        ColumnSchema::new("f64_number", ConcreteDataType::float64_datatype(), true),
+    ];
+    let f32_numbers: VectorRef = Arc::new(Float32Vector::from_vec(vec![1.0f32, 2.0, 3.0]));
+    let f64_numbers: VectorRef = Arc::new(Float64Vector::from_vec(vec![1.0f64, 2.0, 3.0]));
+    let columns = vec![f32_numbers, f64_numbers];
+    let schema = Arc::new(Schema::new(column_schemas));
+    let recordbatch = RecordBatch::new(schema, columns).unwrap();
+    let float_number_table = Arc::new(MemTable::new(recordbatch));
+    schema_provider
+        .register_table("float_numbers".to_string(), float_number_table)
+        .unwrap();
+
     catalog_provider.register_schema(DEFAULT_SCHEMA_NAME, schema_provider);
@@ -176,12 +212,15 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
     factory.query_engine().clone()
 }

-async fn get_numbers_from_table<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Vec<T>
+async fn get_numbers_from_table<'s, T>(
+    column_name: &'s str,
+    table_name: &'s str,
+    engine: Arc<dyn QueryEngine>,
+) -> Vec<T>
 where
     T: Primitive + DataTypeBuilder,
     for<'a> T: Scalar<RefType<'a> = T>,
 {
-    let column_name = table_name;
     let sql = format!("SELECT {} FROM {}", column_name, table_name);
     let plan = engine.sql_to_plan(&sql).unwrap();

@@ -204,17 +243,17 @@ async fn test_median_aggregator() -> Result<()> {

     let engine = create_query_engine();

-    test_median_failed::<f32>("f32_number", engine.clone()).await?;
-    test_median_failed::<f64>("f64_number", engine.clone()).await?;
+    test_median_failed::<f32>("f32_number", "float_numbers", engine.clone()).await?;
+    test_median_failed::<f64>("f64_number", "float_numbers", engine.clone()).await?;

     macro_rules! test_median {
         ([], $( { $T:ty } ),*) => {
             $(
-                let table_name = format!("{}_number_even", std::any::type_name::<$T>());
-                test_median_success::<$T>(&table_name, engine.clone()).await?;
+                let column_name = format!("{}_number_even", std::any::type_name::<$T>());
+                test_median_success::<$T>(&column_name, "even_numbers", engine.clone()).await?;

-                let table_name = format!("{}_number_odd", std::any::type_name::<$T>());
-                test_median_success::<$T>(&table_name, engine.clone()).await?;
+                let column_name = format!("{}_number_odd", std::any::type_name::<$T>());
+                test_median_success::<$T>(&column_name, "odd_numbers", engine.clone()).await?;
             )*
         }
     }
@@ -222,12 +261,18 @@ async fn test_median_aggregator() -> Result<()> {
     Ok(())
 }

-async fn test_median_success<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Result<()>
+async fn test_median_success<T>(
+    column_name: &str,
+    table_name: &str,
+    engine: Arc<dyn QueryEngine>,
+) -> Result<()>
 where
     T: Primitive + Ord + DataTypeBuilder,
     for<'a> T: Scalar<RefType<'a> = T>,
 {
-    let result = execute_median(table_name, engine.clone()).await.unwrap();
+    let result = execute_median(column_name, table_name, engine.clone())
+        .await
+        .unwrap();
     assert_eq!(1, result.len());
     assert_eq!(result[0].df_recordbatch.num_columns(), 1);
     assert_eq!(1, result[0].schema.arrow_schema().fields().len());
@@ -240,7 +285,7 @@ where
     assert_eq!(1, v.len());
     let median = v.get(0);

-    let mut numbers = get_numbers_from_table::<T>(table_name, engine.clone()).await;
+    let mut numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
     numbers.sort();
     let len = numbers.len();
     let expected_median: Value = if len % 2 == 1 {
@@ -255,11 +300,15 @@ where
     Ok(())
 }

-async fn test_median_failed<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Result<()>
+async fn test_median_failed<T>(
+    column_name: &str,
+    table_name: &str,
+    engine: Arc<dyn QueryEngine>,
+) -> Result<()>
 where
     T: Primitive + DataTypeBuilder,
 {
-    let result = execute_median(table_name, engine).await;
+    let result = execute_median(column_name, table_name, engine).await;
     assert!(result.is_err());
     let error = result.unwrap_err();
     assert!(error.to_string().contains(&format!(
@@ -269,11 +318,11 @@ where
     Ok(())
 }

-async fn execute_median(
-    table_name: &str,
+async fn execute_median<'a>(
+    column_name: &'a str,
+    table_name: &'a str,
     engine: Arc<dyn QueryEngine>,
 ) -> RecordResult<Vec<RecordBatch>> {
-    let column_name = table_name;
     let sql = format!(
         "select MEDIAN({}) as median from {}",
         column_name, table_name
@@ -1,73 +0,0 @@
-use std::any::Any;
-use std::pin::Pin;
-use std::sync::Arc;
-
-use common_query::prelude::Expr;
-use common_recordbatch::error::Result as RecordBatchResult;
-use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
-use datatypes::prelude::VectorRef;
-use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
-use futures::task::{Context, Poll};
-use futures::Stream;
-use table::error::Result;
-use table::Table;
-
-#[derive(Debug, Clone)]
-pub struct TestingTable {
-    records: RecordBatch,
-}
-
-impl TestingTable {
-    pub fn new(column_name: &str, values: VectorRef) -> Self {
-        let column_schemas = vec![ColumnSchema::new(column_name, values.data_type(), false)];
-        let schema = Arc::new(Schema::new(column_schemas));
-        Self {
-            records: RecordBatch::new(schema, vec![values]).unwrap(),
-        }
-    }
-}
-
-#[async_trait::async_trait]
-impl Table for TestingTable {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.records.schema.clone()
-    }
-
-    async fn scan(
-        &self,
-        _projection: &Option<Vec<usize>>,
-        _filters: &[Expr],
-        _limit: Option<usize>,
-    ) -> Result<SendableRecordBatchStream> {
-        Ok(Box::pin(TestingRecordsStream {
-            schema: self.records.schema.clone(),
-            records: Some(self.records.clone()),
-        }))
-    }
-}
-
-impl RecordBatchStream for TestingRecordsStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
-struct TestingRecordsStream {
-    schema: SchemaRef,
-    records: Option<RecordBatch>,
-}
-
-impl Stream for TestingRecordsStream {
-    type Item = RecordBatchResult<RecordBatch>;
-
-    fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.records.take() {
-            Some(records) => Poll::Ready(Some(Ok(records))),
-            None => Poll::Ready(None),
-        }
-    }
-}
@@ -36,14 +36,27 @@ pub enum InnerError {
         source: ArrowError,
         backtrace: Backtrace,
     },

+    #[snafu(display("Failed to convert Arrow schema, source: {}", source))]
+    SchemaConversion {
+        source: datatypes::error::Error,
+        backtrace: Backtrace,
+    },
+
+    #[snafu(display("Table projection error, source: {}", source))]
+    TableProjection {
+        source: ArrowError,
+        backtrace: Backtrace,
+    },
 }

 impl ErrorExt for InnerError {
     fn status_code(&self) -> StatusCode {
         match self {
-            InnerError::Datafusion { .. } | InnerError::PollStream { .. } => {
-                StatusCode::EngineExecuteQuery
-            }
+            InnerError::Datafusion { .. }
+            | InnerError::PollStream { .. }
+            | InnerError::SchemaConversion { .. }
+            | InnerError::TableProjection { .. } => StatusCode::EngineExecuteQuery,
             InnerError::MissingColumn { .. } => StatusCode::InvalidArguments,
             InnerError::ExecuteRepeatedly { .. } => StatusCode::Unexpected,
         }
@@ -25,8 +25,8 @@ use datafusion::physical_plan::{
 };
 use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
 use datatypes::arrow::error::{ArrowError, Result as ArrowResult};
-use datatypes::schema::SchemaRef as TableSchemaRef;
 use datatypes::schema::SchemaRef;
+use datatypes::schema::{Schema, SchemaRef as TableSchemaRef};
 use futures::Stream;
 use snafu::prelude::*;

@@ -215,10 +215,7 @@ impl Table for TableAdapter {
             .await
             .context(error::DatafusionSnafu)?;

-        Ok(Box::pin(RecordBatchStreamAdapter::new(
-            self.schema.clone(),
-            df_stream,
-        )))
+        Ok(Box::pin(RecordBatchStreamAdapter::try_new(df_stream)?))
     }

     fn supports_filter_pushdown(&self, filter: &Expr) -> Result<FilterPushDownType> {
@@ -278,8 +275,10 @@ pub struct RecordBatchStreamAdapter {
 }

 impl RecordBatchStreamAdapter {
-    pub fn new(schema: SchemaRef, stream: DfSendableRecordBatchStream) -> Self {
-        Self { schema, stream }
+    pub fn try_new(stream: DfSendableRecordBatchStream) -> Result<Self> {
+        let schema =
+            Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
+        Ok(Self { schema, stream })
     }
 }
test-util/Cargo.toml (new file, 20 lines)

@@ -0,0 +1,20 @@
+[package]
+name = "test-util"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies.arrow]
+package = "arrow2"
+version = "0.10"
+features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "ahash", "compute", "serde_types"]
+
+[dependencies]
+async-trait = "0.1"
+common-query = { path = "../src/common/query" }
+common-recordbatch = { path = "../src/common/recordbatch" }
+datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = ["simd"] }
+datatypes = { path = "../src/datatypes" }
+futures = "0.3"
+snafu = { version = "0.7", features = ["backtraces"] }
+table = { path = "../src/table" }
+tokio = { version = "1.20", features = ["full"] }
test-util/src/lib.rs (new file, 3 lines)

@@ -0,0 +1,3 @@
+mod memtable;
+
+pub use memtable::MemTable;
test-util/src/memtable.rs (new file, 172 lines)

@@ -0,0 +1,172 @@
+use std::any::Any;
+use std::pin::Pin;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use common_query::prelude::Expr;
+use common_recordbatch::error::Result as RecordBatchResult;
+use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
+use datatypes::schema::{Schema, SchemaRef};
+use futures::task::{Context, Poll};
+use futures::Stream;
+use snafu::prelude::*;
+use table::error::{Result, SchemaConversionSnafu, TableProjectionSnafu};
+use table::Table;
+
+#[derive(Debug, Clone)]
+pub struct MemTable {
+    recordbatch: RecordBatch,
+}
+
+impl MemTable {
+    pub fn new(recordbatch: RecordBatch) -> Self {
+        Self { recordbatch }
+    }
+}
+
+#[async_trait]
+impl Table for MemTable {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.recordbatch.schema.clone()
+    }
+
+    async fn scan(
+        &self,
+        projection: &Option<Vec<usize>>,
+        _filters: &[Expr],
+        limit: Option<usize>,
+    ) -> Result<SendableRecordBatchStream> {
+        let df_recordbatch = if let Some(indices) = projection {
+            self.recordbatch
+                .df_recordbatch
+                .project(indices)
+                .context(TableProjectionSnafu)?
+        } else {
+            self.recordbatch.df_recordbatch.clone()
+        };
+
+        let rows = df_recordbatch.num_rows();
+        let limit = if let Some(limit) = limit {
+            limit.min(rows)
+        } else {
+            rows
+        };
+        let df_recordbatch = df_recordbatch.slice(0, limit);
+
+        let recordbatch = RecordBatch {
+            schema: Arc::new(
+                Schema::try_from(df_recordbatch.schema().clone()).context(SchemaConversionSnafu)?,
+            ),
+            df_recordbatch,
+        };
+        Ok(Box::pin(MemtableStream {
+            schema: recordbatch.schema.clone(),
+            recordbatch: Some(recordbatch),
+        }))
+    }
+}
+
+impl RecordBatchStream for MemtableStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+struct MemtableStream {
+    schema: SchemaRef,
+    recordbatch: Option<RecordBatch>,
+}
+
+impl Stream for MemtableStream {
+    type Item = RecordBatchResult<RecordBatch>;
+
+    fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match self.recordbatch.take() {
+            Some(records) => Poll::Ready(Some(Ok(records))),
+            None => Poll::Ready(None),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use common_recordbatch::util;
+    use datatypes::prelude::*;
+    use datatypes::schema::ColumnSchema;
+    use datatypes::vectors::{Int32Vector, StringVector};
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_scan_with_projection() {
+        let table = build_testing_table();
+
+        let scan_stream = table.scan(&Some(vec![1]), &[], None).await.unwrap();
+        let recordbatch = util::collect(scan_stream).await.unwrap();
+        assert_eq!(1, recordbatch.len());
+        let columns = recordbatch[0].df_recordbatch.columns();
+        assert_eq!(1, columns.len());
+
+        let string_column = VectorHelper::try_into_vector(&columns[0]).unwrap();
+        let string_column = string_column
+            .as_any()
+            .downcast_ref::<StringVector>()
+            .unwrap();
+        let string_column = string_column.iter_data().flatten().collect::<Vec<&str>>();
+        assert_eq!(vec!["hello", "greptime"], string_column);
+    }
+
+    #[tokio::test]
+    async fn test_scan_with_limit() {
+        let table = build_testing_table();
+
+        let scan_stream = table.scan(&None, &[], Some(2)).await.unwrap();
+        let recordbatch = util::collect(scan_stream).await.unwrap();
+        assert_eq!(1, recordbatch.len());
+        let columns = recordbatch[0].df_recordbatch.columns();
+        assert_eq!(2, columns.len());
+
+        let i32_column = VectorHelper::try_into_vector(&columns[0]).unwrap();
+        let i32_column = i32_column.as_any().downcast_ref::<Int32Vector>().unwrap();
+        let i32_column = i32_column.iter_data().flatten().collect::<Vec<i32>>();
+        assert_eq!(vec![-100], i32_column);
+
+        let string_column = VectorHelper::try_into_vector(&columns[1]).unwrap();
+        let string_column = string_column
+            .as_any()
+            .downcast_ref::<StringVector>()
+            .unwrap();
+        let string_column = string_column.iter_data().flatten().collect::<Vec<&str>>();
+        assert_eq!(vec!["hello"], string_column);
+    }
+
+    fn build_testing_table() -> MemTable {
+        let i32_column_schema =
+            ColumnSchema::new("i32_numbers", ConcreteDataType::int32_datatype(), true);
+        let string_column_schema =
+            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true);
+        let column_schemas = vec![i32_column_schema, string_column_schema];
+
+        let schema = Arc::new(Schema::new(column_schemas));
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(Int32Vector::from(vec![
+                Some(-100),
+                None,
+                Some(1),
+                Some(100),
+            ])),
+            Arc::new(StringVector::from(vec![
+                Some("hello"),
+                None,
+                Some("greptime"),
+                None,
+            ])),
+        ];
+        let recordbatch = RecordBatch::new(schema, columns).unwrap();
+        MemTable::new(recordbatch)
+    }
+}