From b936d8b18a6670421884c10cdacaaacb8713022c Mon Sep 17 00:00:00 2001 From: Yingwen Date: Thu, 8 Dec 2022 17:51:20 +0800 Subject: [PATCH] fix: Fix common::grpc compiler errors (#722) * fix: Fix common::grpc compiler errors This commit refactors RecordBatch and holds vectors in the RecordBatch struct, so we don't need to cast the array to vector when doing serialization or iterating the batch. Now we use the vector API instead of the arrow API in grpc crate. * chore: Address CR comments --- Cargo.lock | 1 + src/common/grpc/Cargo.toml | 1 + src/common/grpc/src/select.rs | 211 ++++++++++++-------- src/common/grpc/src/writer.rs | 52 ++--- src/common/recordbatch/src/adapter.rs | 19 +- src/common/recordbatch/src/lib.rs | 4 +- src/common/recordbatch/src/recordbatch.rs | 123 +++++++----- src/datatypes/src/arrow_array.rs | 224 ---------------------- 8 files changed, 255 insertions(+), 380 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 468633826d..ead0a5fc2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1321,6 +1321,7 @@ dependencies = [ "criterion 0.4.0", "dashmap", "datafusion", + "datatypes", "rand 0.8.5", "snafu", "tokio", diff --git a/src/common/grpc/Cargo.toml b/src/common/grpc/Cargo.toml index 7665f3b721..b1b5a25b6e 100644 --- a/src/common/grpc/Cargo.toml +++ b/src/common/grpc/Cargo.toml @@ -14,6 +14,7 @@ common-recordbatch = { path = "../recordbatch" } common-runtime = { path = "../runtime" } dashmap = "5.4" datafusion = "14.0.0" +datatypes = { path = "../../datatypes" } snafu = { version = "0.7", features = ["backtraces"] } tokio = { version = "1.0", features = ["full"] } tonic = "0.8" diff --git a/src/common/grpc/src/select.rs b/src/common/grpc/src/select.rs index 0801370dbd..f352d5697d 100644 --- a/src/common/grpc/src/select.rs +++ b/src/common/grpc/src/select.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use api::helper::ColumnDataTypeWrapper; use api::result::{build_err_result, ObjectResultBuilder}; use api::v1::codec::SelectResult; @@ -24,9 +22,14 @@ use common_error::prelude::ErrorExt; use common_error::status_code::StatusCode; use common_query::Output; use common_recordbatch::{util, RecordBatches, SendableRecordBatchStream}; -use datatypes::arrow::array::{Array, BooleanArray, PrimitiveArray}; -use datatypes::arrow_array::{BinaryArray, StringArray}; use datatypes::schema::SchemaRef; +use datatypes::types::{TimestampType, WrapperType}; +use datatypes::vectors::{ + BinaryVector, BooleanVector, DateTimeVector, DateVector, Float32Vector, Float64Vector, + Int16Vector, Int32Vector, Int64Vector, Int8Vector, StringVector, TimestampMicrosecondVector, + TimestampMillisecondVector, TimestampNanosecondVector, TimestampSecondVector, UInt16Vector, + UInt32Vector, UInt64Vector, UInt8Vector, VectorRef, +}; use snafu::{OptionExt, ResultExt}; use crate::error::{self, ConversionSnafu, Result}; @@ -46,6 +49,7 @@ pub async fn to_object_result(output: std::result::Result Err(e) => build_err_result(&e), } } + async fn collect(stream: SendableRecordBatchStream) -> Result { let schema = stream.schema(); @@ -82,10 +86,7 @@ fn try_convert(record_batches: RecordBatches) -> Result { let schema = record_batches.schema(); let record_batches = record_batches.take(); - let row_count: usize = record_batches - .iter() - .map(|r| r.df_recordbatch.num_rows()) - .sum(); + let row_count: usize = record_batches.iter().map(|r| r.num_rows()).sum(); let schemas = schema.column_schemas(); let mut columns = Vec::with_capacity(schemas.len()); @@ -93,9 +94,9 @@ fn try_convert(record_batches: RecordBatches) -> Result { for (idx, column_schema) in schemas.iter().enumerate() { let column_name = column_schema.name.clone(); - let arrays: Vec> = record_batches + let arrays: Vec<_> = record_batches .iter() - .map(|r| r.df_recordbatch.columns()[idx].clone()) + .map(|r| r.column(idx).clone()) .collect(); let column = Column { @@ -116,7 +117,7 @@ fn try_convert(record_batches: RecordBatches) -> Result { }) } -pub fn null_mask(arrays: &Vec>, row_count: usize) -> Vec { +pub fn null_mask(arrays: &[VectorRef], row_count: usize) -> Vec { let null_count: usize = arrays.iter().map(|a| a.null_count()).sum(); if null_count == 0 { @@ -126,10 +127,12 @@ pub fn null_mask(arrays: &Vec>, row_count: usize) -> Vec { let mut null_mask = BitVec::with_capacity(row_count); for array in arrays { let validity = array.validity(); - if let Some(v) = validity { - v.iter().for_each(|x| null_mask.push(!x)); - } else { + if validity.is_all_valid() { null_mask.extend_from_bitslice(&BitVec::repeat(false, array.len())); + } else { + for i in 0..array.len() { + null_mask.push(!validity.is_set(i)); + } } } null_mask.into_vec() @@ -137,7 +140,9 @@ pub fn null_mask(arrays: &Vec>, row_count: usize) -> Vec { macro_rules! convert_arrow_array_to_grpc_vals { ($data_type: expr, $arrays: ident, $(($Type: pat, $CastType: ty, $field: ident, $MapFunction: expr)), +) => {{ - use datatypes::arrow::datatypes::{DataType, TimeUnit}; + use datatypes::data_type::{ConcreteDataType}; + use datatypes::prelude::ScalarVector; + match $data_type { $( $Type => { @@ -147,52 +152,114 @@ macro_rules! convert_arrow_array_to_grpc_vals { from: format!("{:?}", $data_type), })?; vals.$field.extend(array - .iter() + .iter_data() .filter_map(|i| i.map($MapFunction)) .collect::>()); } return Ok(vals); }, )+ - _ => unimplemented!(), + ConcreteDataType::Null(_) | ConcreteDataType::List(_) => unreachable!("Should not send {:?} in gRPC", $data_type), } }}; } -pub fn values(arrays: &[Arc]) -> Result { +pub fn values(arrays: &[VectorRef]) -> Result { if arrays.is_empty() { return Ok(Values::default()); } let data_type = arrays[0].data_type(); convert_arrow_array_to_grpc_vals!( - data_type, arrays, - - (DataType::Boolean, BooleanArray, bool_values, |x| {x}), - - (DataType::Int8, PrimitiveArray, i8_values, |x| {*x as i32}), - (DataType::Int16, PrimitiveArray, i16_values, |x| {*x as i32}), - (DataType::Int32, PrimitiveArray, i32_values, |x| {*x}), - (DataType::Int64, PrimitiveArray, i64_values, |x| {*x}), - - (DataType::UInt8, PrimitiveArray, u8_values, |x| {*x as u32}), - (DataType::UInt16, PrimitiveArray, u16_values, |x| {*x as u32}), - (DataType::UInt32, PrimitiveArray, u32_values, |x| {*x}), - (DataType::UInt64, PrimitiveArray, u64_values, |x| {*x}), - - (DataType::Float32, PrimitiveArray, f32_values, |x| {*x}), - (DataType::Float64, PrimitiveArray, f64_values, |x| {*x}), - - (DataType::Binary, BinaryArray, binary_values, |x| {x.into()}), - (DataType::LargeBinary, BinaryArray, binary_values, |x| {x.into()}), - - (DataType::Utf8, StringArray, string_values, |x| {x.into()}), - (DataType::LargeUtf8, StringArray, string_values, |x| {x.into()}), - - (DataType::Date32, PrimitiveArray, date_values, |x| {*x as i32}), - (DataType::Date64, PrimitiveArray, datetime_values,|x| {*x as i64}), - - (DataType::Timestamp(TimeUnit::Millisecond, _), PrimitiveArray, ts_millis_values, |x| {*x}) + data_type, + arrays, + ( + ConcreteDataType::Boolean(_), + BooleanVector, + bool_values, + |x| { x } + ), + (ConcreteDataType::Int8(_), Int8Vector, i8_values, |x| { + i32::from(x) + }), + (ConcreteDataType::Int16(_), Int16Vector, i16_values, |x| { + i32::from(x) + }), + (ConcreteDataType::Int32(_), Int32Vector, i32_values, |x| { + x + }), + (ConcreteDataType::Int64(_), Int64Vector, i64_values, |x| { + x + }), + (ConcreteDataType::UInt8(_), UInt8Vector, u8_values, |x| { + u32::from(x) + }), + (ConcreteDataType::UInt16(_), UInt16Vector, u16_values, |x| { + u32::from(x) + }), + (ConcreteDataType::UInt32(_), UInt32Vector, u32_values, |x| { + x + }), + (ConcreteDataType::UInt64(_), UInt64Vector, u64_values, |x| { + x + }), + ( + ConcreteDataType::Float32(_), + Float32Vector, + f32_values, + |x| { x } + ), + ( + ConcreteDataType::Float64(_), + Float64Vector, + f64_values, + |x| { x } + ), + ( + ConcreteDataType::Binary(_), + BinaryVector, + binary_values, + |x| { x.into() } + ), + ( + ConcreteDataType::String(_), + StringVector, + string_values, + |x| { x.into() } + ), + (ConcreteDataType::Date(_), DateVector, date_values, |x| { + x.val() + }), + ( + ConcreteDataType::DateTime(_), + DateTimeVector, + datetime_values, + |x| { x.val() } + ), + ( + ConcreteDataType::Timestamp(TimestampType::Second(_)), + TimestampSecondVector, + ts_second_values, + |x| { x.into_native() } + ), + ( + ConcreteDataType::Timestamp(TimestampType::Millisecond(_)), + TimestampMillisecondVector, + ts_millisecond_values, + |x| { x.into_native() } + ), + ( + ConcreteDataType::Timestamp(TimestampType::Microsecond(_)), + TimestampMicrosecondVector, + ts_microsecond_values, + |x| { x.into_native() } + ), + ( + ConcreteDataType::Timestamp(TimestampType::Nanosecond(_)), + TimestampNanosecondVector, + ts_nanosecond_values, + |x| { x.into_native() } + ) ) } @@ -201,14 +268,10 @@ mod tests { use std::sync::Arc; use common_recordbatch::{RecordBatch, RecordBatches}; - use datafusion::field_util::SchemaExt; - use datatypes::arrow::array::{Array, BooleanArray, PrimitiveArray}; - use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; - use datatypes::arrow_array::StringArray; - use datatypes::schema::Schema; - use datatypes::vectors::{UInt32Vector, VectorRef}; + use datatypes::data_type::ConcreteDataType; + use datatypes::schema::{ColumnSchema, Schema}; - use crate::select::{null_mask, try_convert, values}; + use super::*; #[test] fn test_convert_record_batches_to_select_result() { @@ -234,9 +297,8 @@ mod tests { #[test] fn test_convert_arrow_arrays_i32() { - let array: PrimitiveArray = - PrimitiveArray::from(vec![Some(1), Some(2), None, Some(3)]); - let array: Arc = Arc::new(array); + let array = Int32Vector::from(vec![Some(1), Some(2), None, Some(3)]); + let array: VectorRef = Arc::new(array); let values = values(&[array]).unwrap(); @@ -245,14 +307,14 @@ mod tests { #[test] fn test_convert_arrow_arrays_string() { - let array = StringArray::from(vec![ + let array = StringVector::from(vec![ Some("1".to_string()), Some("2".to_string()), None, Some("3".to_string()), None, ]); - let array: Arc = Arc::new(array); + let array: VectorRef = Arc::new(array); let values = values(&[array]).unwrap(); @@ -261,8 +323,8 @@ mod tests { #[test] fn test_convert_arrow_arrays_bool() { - let array = BooleanArray::from(vec![Some(true), Some(false), None, Some(false), None]); - let array: Arc = Arc::new(array); + let array = BooleanVector::from(vec![Some(true), Some(false), None, Some(false), None]); + let array: VectorRef = Arc::new(array); let values = values(&[array]).unwrap(); @@ -271,43 +333,42 @@ mod tests { #[test] fn test_convert_arrow_arrays_empty() { - let array = BooleanArray::from(vec![None, None, None, None, None]); - let array: Arc = Arc::new(array); + let array = BooleanVector::from(vec![None, None, None, None, None]); + let array: VectorRef = Arc::new(array); let values = values(&[array]).unwrap(); - assert_eq!(Vec::::default(), values.bool_values); + assert!(values.bool_values.is_empty()); } #[test] fn test_null_mask() { - let a1: Arc = Arc::new(PrimitiveArray::from(vec![None, Some(2), None])); - let a2: Arc = - Arc::new(PrimitiveArray::from(vec![Some(1), Some(2), None, Some(4)])); + let a1: VectorRef = Arc::new(Int32Vector::from(vec![None, Some(2), None])); + let a2: VectorRef = Arc::new(Int32Vector::from(vec![Some(1), Some(2), None, Some(4)])); let mask = null_mask(&vec![a1, a2], 3 + 4); assert_eq!(vec![0b0010_0101], mask); - let empty: Arc = Arc::new(PrimitiveArray::::from(vec![None, None, None])); + let empty: VectorRef = Arc::new(Int32Vector::from(vec![None, None, None])); let mask = null_mask(&vec![empty.clone(), empty.clone(), empty], 9); assert_eq!(vec![0b1111_1111, 0b0000_0001], mask); - let a1: Arc = Arc::new(PrimitiveArray::from(vec![Some(1), Some(2), Some(3)])); - let a2: Arc = Arc::new(PrimitiveArray::from(vec![Some(4), Some(5), Some(6)])); + let a1: VectorRef = Arc::new(Int32Vector::from(vec![Some(1), Some(2), Some(3)])); + let a2: VectorRef = Arc::new(Int32Vector::from(vec![Some(4), Some(5), Some(6)])); let mask = null_mask(&vec![a1, a2], 3 + 3); assert_eq!(Vec::::default(), mask); - let a1: Arc = Arc::new(PrimitiveArray::from(vec![Some(1), Some(2), Some(3)])); - let a2: Arc = Arc::new(PrimitiveArray::from(vec![Some(4), Some(5), None])); + let a1: VectorRef = Arc::new(Int32Vector::from(vec![Some(1), Some(2), Some(3)])); + let a2: VectorRef = Arc::new(Int32Vector::from(vec![Some(4), Some(5), None])); let mask = null_mask(&vec![a1, a2], 3 + 3); assert_eq!(vec![0b0010_0000], mask); } fn mock_record_batch() -> RecordBatch { - let arrow_schema = Arc::new(ArrowSchema::new(vec![ - Field::new("c1", DataType::UInt32, false), - Field::new("c2", DataType::UInt32, false), - ])); - let schema = Arc::new(Schema::try_from(arrow_schema).unwrap()); + let column_schemas = vec![ + ColumnSchema::new("c1", ConcreteDataType::uint32_datatype(), true), + ColumnSchema::new("c2", ConcreteDataType::uint32_datatype(), true), + ]; + let schema = Arc::new(Schema::try_new(column_schemas).unwrap()); let v1 = Arc::new(UInt32Vector::from(vec![Some(1), Some(2), None])); let v2 = Arc::new(UInt32Vector::from(vec![Some(1), None, None])); diff --git a/src/common/grpc/src/writer.rs b/src/common/grpc/src/writer.rs index 2cd28f45af..d05a2908e1 100644 --- a/src/common/grpc/src/writer.rs +++ b/src/common/grpc/src/writer.rs @@ -45,11 +45,11 @@ impl LinesWriter { pub fn write_ts(&mut self, column_name: &str, value: (i64, Precision)) -> Result<()> { let (idx, column) = self.mut_column( column_name, - ColumnDataType::Timestamp, + ColumnDataType::TimestampMillisecond, SemanticType::Timestamp, ); ensure!( - column.datatype == ColumnDataType::Timestamp as i32, + column.datatype == ColumnDataType::TimestampMillisecond as i32, TypeMismatchSnafu { column_name, expected: "timestamp", @@ -58,7 +58,9 @@ impl LinesWriter { ); // It is safe to use unwrap here, because values has been initialized in mut_column() let values = column.values.as_mut().unwrap(); - values.ts_millis_values.push(to_ms_ts(value.1, value.0)); + values + .ts_millisecond_values + .push(to_ms_ts(value.1, value.0)); self.null_masks[idx].push(false); Ok(()) } @@ -224,23 +226,23 @@ impl LinesWriter { pub fn to_ms_ts(p: Precision, ts: i64) -> i64 { match p { - Precision::NANOSECOND => ts / 1_000_000, - Precision::MICROSECOND => ts / 1000, - Precision::MILLISECOND => ts, - Precision::SECOND => ts * 1000, - Precision::MINUTE => ts * 1000 * 60, - Precision::HOUR => ts * 1000 * 60 * 60, + Precision::Nanosecond => ts / 1_000_000, + Precision::Microsecond => ts / 1000, + Precision::Millisecond => ts, + Precision::Second => ts * 1000, + Precision::Minute => ts * 1000 * 60, + Precision::Hour => ts * 1000 * 60 * 60, } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Precision { - NANOSECOND, - MICROSECOND, - MILLISECOND, - SECOND, - MINUTE, - HOUR, + Nanosecond, + Microsecond, + Millisecond, + Second, + Minute, + Hour, } #[cfg(test)] @@ -261,13 +263,13 @@ mod tests { writer.write_f64("memory", 0.4).unwrap(); writer.write_string("name", "name1").unwrap(); writer - .write_ts("ts", (101011000, Precision::MILLISECOND)) + .write_ts("ts", (101011000, Precision::Millisecond)) .unwrap(); writer.commit(); writer.write_tag("host", "host2").unwrap(); writer - .write_ts("ts", (102011001, Precision::MILLISECOND)) + .write_ts("ts", (102011001, Precision::Millisecond)) .unwrap(); writer.write_bool("enable_reboot", true).unwrap(); writer.write_u64("year_of_service", 2).unwrap(); @@ -278,7 +280,7 @@ mod tests { writer.write_f64("cpu", 0.4).unwrap(); writer.write_u64("cpu_core_num", 16).unwrap(); writer - .write_ts("ts", (103011002, Precision::MILLISECOND)) + .write_ts("ts", (103011002, Precision::Millisecond)) .unwrap(); writer.commit(); @@ -321,11 +323,11 @@ mod tests { let column = &columns[4]; assert_eq!("ts", column.column_name); - assert_eq!(ColumnDataType::Timestamp as i32, column.datatype); + assert_eq!(ColumnDataType::TimestampMillisecond as i32, column.datatype); assert_eq!(SemanticType::Timestamp as i32, column.semantic_type); assert_eq!( vec![101011000, 102011001, 103011002], - column.values.as_ref().unwrap().ts_millis_values + column.values.as_ref().unwrap().ts_millisecond_values ); verify_null_mask(&column.null_mask, vec![false, false, false]); @@ -367,16 +369,16 @@ mod tests { #[test] fn test_to_ms() { - assert_eq!(100, to_ms_ts(Precision::NANOSECOND, 100110000)); - assert_eq!(100110, to_ms_ts(Precision::MICROSECOND, 100110000)); - assert_eq!(100110000, to_ms_ts(Precision::MILLISECOND, 100110000)); + assert_eq!(100, to_ms_ts(Precision::Nanosecond, 100110000)); + assert_eq!(100110, to_ms_ts(Precision::Microsecond, 100110000)); + assert_eq!(100110000, to_ms_ts(Precision::Millisecond, 100110000)); assert_eq!( 100110000 * 1000 * 60, - to_ms_ts(Precision::MINUTE, 100110000) + to_ms_ts(Precision::Minute, 100110000) ); assert_eq!( 100110000 * 1000 * 60 * 60, - to_ms_ts(Precision::HOUR, 100110000) + to_ms_ts(Precision::Hour, 100110000) ); } } diff --git a/src/common/recordbatch/src/adapter.rs b/src/common/recordbatch/src/adapter.rs index f2f53861ce..2b8436ec4e 100644 --- a/src/common/recordbatch/src/adapter.rs +++ b/src/common/recordbatch/src/adapter.rs @@ -63,7 +63,7 @@ impl Stream for DfRecordBatchStreamAdapter { match Pin::new(&mut self.stream).poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Some(recordbatch)) => match recordbatch { - Ok(recordbatch) => Poll::Ready(Some(Ok(recordbatch.df_recordbatch))), + Ok(recordbatch) => Poll::Ready(Some(Ok(recordbatch.into_df_record_batch()))), Err(e) => Poll::Ready(Some(Err(ArrowError::ExternalError(Box::new(e))))), }, Poll::Ready(None) => Poll::Ready(None), @@ -102,10 +102,13 @@ impl Stream for RecordBatchStreamAdapter { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match Pin::new(&mut self.stream).poll_next(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(Some(df_recordbatch)) => Poll::Ready(Some(Ok(RecordBatch { - schema: self.schema(), - df_recordbatch: df_recordbatch.context(error::PollStreamSnafu)?, - }))), + Poll::Ready(Some(df_record_batch)) => { + let df_record_batch = df_record_batch.context(error::PollStreamSnafu)?; + Poll::Ready(Some(RecordBatch::try_from_df_record_batch( + self.schema(), + df_record_batch, + ))) + } Poll::Ready(None) => Poll::Ready(None), } } @@ -157,10 +160,8 @@ impl Stream for AsyncRecordBatchStreamAdapter { AsyncRecordBatchStreamAdapterState::Inited(stream) => match stream { Ok(stream) => { return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)).map(|df| { - Ok(RecordBatch { - schema: self.schema(), - df_recordbatch: df.context(error::PollStreamSnafu)?, - }) + let df_record_batch = df.context(error::PollStreamSnafu)?; + RecordBatch::try_from_df_record_batch(self.schema(), df_record_batch) })); } Err(e) => { diff --git a/src/common/recordbatch/src/lib.rs b/src/common/recordbatch/src/lib.rs index 75f463404e..23aa04a9bf 100644 --- a/src/common/recordbatch/src/lib.rs +++ b/src/common/recordbatch/src/lib.rs @@ -96,7 +96,7 @@ impl RecordBatches { pub fn pretty_print(&self) -> Result { let df_batches = &self .iter() - .map(|x| x.df_recordbatch.clone()) + .map(|x| x.df_record_batch().clone()) .collect::>(); let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?; @@ -140,7 +140,7 @@ impl RecordBatches { let df_record_batches = self .batches .into_iter() - .map(|batch| batch.df_recordbatch) + .map(|batch| batch.into_df_record_batch()) .collect(); // unwrap safety: `MemoryStream::try_new` won't fail Box::pin( diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs index 76c1ee5ef7..47c86831dc 100644 --- a/src/common/recordbatch/src/recordbatch.rs +++ b/src/common/recordbatch/src/recordbatch.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use datatypes::arrow_array::arrow_array_get; use datatypes::schema::SchemaRef; use datatypes::value::Value; use datatypes::vectors::{Helper, VectorRef}; @@ -23,32 +22,76 @@ use snafu::ResultExt; use crate::error::{self, Result}; use crate::DfRecordBatch; -// TODO(yingwen): We should hold vectors in the RecordBatch. /// A two-dimensional batch of column-oriented data with a defined schema. #[derive(Clone, Debug, PartialEq)] pub struct RecordBatch { pub schema: SchemaRef, - pub df_recordbatch: DfRecordBatch, + columns: Vec, + df_record_batch: DfRecordBatch, } impl RecordBatch { + /// Create a new [`RecordBatch`] from `schema` and `columns`. pub fn new>( schema: SchemaRef, columns: I, ) -> Result { - let arrow_arrays = columns.into_iter().map(|v| v.to_arrow_array()).collect(); + let columns: Vec<_> = columns.into_iter().collect(); + let arrow_arrays = columns.iter().map(|v| v.to_arrow_array()).collect(); - let df_recordbatch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays) + let df_record_batch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays) .context(error::NewDfRecordBatchSnafu)?; Ok(RecordBatch { schema, - df_recordbatch, + columns, + df_record_batch, }) } + /// Create a new [`RecordBatch`] from `schema` and `df_record_batch`. + /// + /// This method doesn't check the schema. + pub fn try_from_df_record_batch( + schema: SchemaRef, + df_record_batch: DfRecordBatch, + ) -> Result { + let columns = df_record_batch + .columns() + .iter() + .map(|c| Helper::try_into_vector(c.clone()).context(error::DataTypesSnafu)) + .collect::>>()?; + + Ok(RecordBatch { + schema, + columns, + df_record_batch, + }) + } + + #[inline] + pub fn df_record_batch(&self) -> &DfRecordBatch { + &self.df_record_batch + } + + #[inline] + pub fn into_df_record_batch(self) -> DfRecordBatch { + self.df_record_batch + } + + #[inline] + pub fn column(&self, idx: usize) -> &VectorRef { + &self.columns[idx] + } + + #[inline] + pub fn num_columns(&self) -> usize { + self.columns.len() + } + + #[inline] pub fn num_rows(&self) -> usize { - self.df_recordbatch.num_rows() + self.df_record_batch.num_rows() } /// Create an iterator to traverse the data by row @@ -62,14 +105,15 @@ impl Serialize for RecordBatch { where S: Serializer, { + // TODO(yingwen): arrow and arrow2's schemas have different fields, so + // it might be better to use our `RawSchema` as serialized field. let mut s = serializer.serialize_struct("record", 2)?; s.serialize_field("schema", &**self.schema.arrow_schema())?; - let df_columns = self.df_recordbatch.columns(); - - let vec = df_columns + let vec = self + .columns .iter() - .map(|c| Helper::try_into_vector(c.clone())?.serialize_to_json()) + .map(|c| c.serialize_to_json()) .collect::, _>>() .map_err(S::Error::custom)?; @@ -89,8 +133,8 @@ impl<'a> RecordBatchRowIterator<'a> { fn new(record_batch: &'a RecordBatch) -> RecordBatchRowIterator { RecordBatchRowIterator { record_batch, - rows: record_batch.df_recordbatch.num_rows(), - columns: record_batch.df_recordbatch.num_columns(), + rows: record_batch.df_record_batch.num_rows(), + columns: record_batch.df_record_batch.num_columns(), row_cursor: 0, } } @@ -105,15 +149,9 @@ impl<'a> Iterator for RecordBatchRowIterator<'a> { } else { let mut row = Vec::with_capacity(self.columns); - // TODO(yingwen): Get from the vector if RecordBatch also holds vectors. for col in 0..self.columns { - let column_array = self.record_batch.df_recordbatch.column(col); - match arrow_array_get(column_array.as_ref(), self.row_cursor) - .context(error::DataTypesSnafu) - { - Ok(field) => row.push(field), - Err(e) => return Some(Err(e)), - } + let column = self.record_batch.column(col); + row.push(column.get(self.row_cursor)); } self.row_cursor += 1; @@ -126,17 +164,15 @@ impl<'a> Iterator for RecordBatchRowIterator<'a> { mod tests { use std::sync::Arc; - use datatypes::arrow::array::UInt32Array; use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; use datatypes::data_type::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema}; - use datatypes::vectors::{StringVector, UInt32Vector, Vector}; + use datatypes::vectors::{StringVector, UInt32Vector}; use super::*; - use crate::DfRecordBatch; #[test] - fn test_new_record_batch() { + fn test_record_batch() { let arrow_schema = Arc::new(ArrowSchema::new(vec![ Field::new("c1", DataType::UInt32, false), Field::new("c2", DataType::UInt32, false), @@ -147,39 +183,36 @@ mod tests { let columns: Vec = vec![v.clone(), v.clone()]; let batch = RecordBatch::new(schema.clone(), columns).unwrap(); - let expect = v.to_arrow_array(); - for column in batch.df_recordbatch.columns() { - let array = column.as_any().downcast_ref::().unwrap(); - assert_eq!( - expect.as_any().downcast_ref::().unwrap(), - array - ); + assert_eq!(3, batch.num_rows()); + for i in 0..batch.num_columns() { + let column = batch.column(i); + let actual = column.as_any().downcast_ref::().unwrap(); + assert_eq!(&*v, actual); } assert_eq!(schema, batch.schema); + + let converted = + RecordBatch::try_from_df_record_batch(schema, batch.df_record_batch().clone()).unwrap(); + assert_eq!(batch, converted); + assert_eq!(*batch.df_record_batch(), converted.into_df_record_batch()); } #[test] pub fn test_serialize_recordbatch() { - let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new( + let column_schemas = vec![ColumnSchema::new( "number", - DataType::UInt32, + ConcreteDataType::uint32_datatype(), false, - )])); - let schema = Arc::new(Schema::try_from(arrow_schema.clone()).unwrap()); + )]; + let schema = Arc::new(Schema::try_new(column_schemas).unwrap()); let numbers: Vec = (0..10).collect(); - let df_batch = - DfRecordBatch::try_new(arrow_schema, vec![Arc::new(UInt32Array::from(numbers))]) - .unwrap(); - - let batch = RecordBatch { - schema, - df_recordbatch: df_batch, - }; + let columns = vec![Arc::new(UInt32Vector::from_slice(&numbers)) as VectorRef]; + let batch = RecordBatch::new(schema, columns).unwrap(); let output = serde_json::to_string(&batch).unwrap(); assert_eq!( - r#"{"schema":{"fields":[{"name":"number","data_type":"UInt32","is_nullable":false,"metadata":{}}],"metadata":{}},"columns":[[0,1,2,3,4,5,6,7,8,9]]}"#, + r#"{"schema":{"fields":[{"name":"number","data_type":"UInt32","nullable":false,"dict_id":0,"dict_is_ordered":false}],"metadata":{"greptime:version":"0"}},"columns":[[0,1,2,3,4,5,6,7,8,9]]}"#, output ); } diff --git a/src/datatypes/src/arrow_array.rs b/src/datatypes/src/arrow_array.rs index 7405c8a665..72de422142 100644 --- a/src/datatypes/src/arrow_array.rs +++ b/src/datatypes/src/arrow_array.rs @@ -12,231 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::{ - Array, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, -}; -use arrow::datatypes::DataType; -use common_time::timestamp::TimeUnit; -use common_time::Timestamp; -use snafu::OptionExt; - -use crate::data_type::ConcreteDataType; -use crate::error::{ConversionSnafu, Result}; -use crate::value::{ListValue, Value}; - pub type BinaryArray = arrow::array::LargeBinaryArray; pub type MutableBinaryArray = arrow::array::LargeBinaryBuilder; pub type StringArray = arrow::array::StringArray; pub type MutableStringArray = arrow::array::StringBuilder; - -macro_rules! cast_array { - ($arr: ident, $CastType: ty) => { - $arr.as_any() - .downcast_ref::<$CastType>() - .with_context(|| ConversionSnafu { - from: format!("{:?}", $arr.data_type()), - })? - }; -} - -// TODO(yingwen): Remove this function. -pub fn arrow_array_get(array: &dyn Array, idx: usize) -> Result { - if array.is_null(idx) { - return Ok(Value::Null); - } - - let result = match array.data_type() { - DataType::Null => Value::Null, - DataType::Boolean => Value::Boolean(cast_array!(array, BooleanArray).value(idx)), - DataType::Binary => Value::Binary(cast_array!(array, BinaryArray).value(idx).into()), - DataType::Int8 => Value::Int8(cast_array!(array, Int8Array).value(idx)), - DataType::Int16 => Value::Int16(cast_array!(array, Int16Array).value(idx)), - DataType::Int32 => Value::Int32(cast_array!(array, Int32Array).value(idx)), - DataType::Int64 => Value::Int64(cast_array!(array, Int64Array).value(idx)), - DataType::UInt8 => Value::UInt8(cast_array!(array, UInt8Array).value(idx)), - DataType::UInt16 => Value::UInt16(cast_array!(array, UInt16Array).value(idx)), - DataType::UInt32 => Value::UInt32(cast_array!(array, UInt32Array).value(idx)), - DataType::UInt64 => Value::UInt64(cast_array!(array, UInt64Array).value(idx)), - DataType::Float32 => Value::Float32(cast_array!(array, Float32Array).value(idx).into()), - DataType::Float64 => Value::Float64(cast_array!(array, Float64Array).value(idx).into()), - DataType::Utf8 => Value::String(cast_array!(array, StringArray).value(idx).into()), - DataType::Date32 => Value::Date(cast_array!(array, Date32Array).value(idx).into()), - DataType::Date64 => Value::DateTime(cast_array!(array, Date64Array).value(idx).into()), - DataType::Timestamp(t, _) => match t { - arrow::datatypes::TimeUnit::Second => Value::Timestamp(Timestamp::new( - cast_array!(array, arrow::array::TimestampSecondArray).value(idx), - TimeUnit::Second, - )), - arrow::datatypes::TimeUnit::Millisecond => Value::Timestamp(Timestamp::new( - cast_array!(array, arrow::array::TimestampMillisecondArray).value(idx), - TimeUnit::Millisecond, - )), - arrow::datatypes::TimeUnit::Microsecond => Value::Timestamp(Timestamp::new( - cast_array!(array, arrow::array::TimestampMicrosecondArray).value(idx), - TimeUnit::Microsecond, - )), - arrow::datatypes::TimeUnit::Nanosecond => Value::Timestamp(Timestamp::new( - cast_array!(array, arrow::array::TimestampNanosecondArray).value(idx), - TimeUnit::Nanosecond, - )), - }, - DataType::List(_) => { - let array = cast_array!(array, ListArray).value(idx); - let item_type = ConcreteDataType::try_from(array.data_type())?; - let values = (0..array.len()) - .map(|i| arrow_array_get(&*array, i)) - .collect::>>()?; - Value::List(ListValue::new(Some(Box::new(values)), item_type)) - } - _ => unimplemented!("Arrow array datatype: {:?}", array.data_type()), - }; - - Ok(result) -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::{ - BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - LargeBinaryArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, - }; - use arrow::datatypes::Int32Type; - use common_time::timestamp::{TimeUnit, Timestamp}; - use paste::paste; - - use super::*; - use crate::data_type::ConcreteDataType; - use crate::types::TimestampType; - - macro_rules! test_arrow_array_get_for_timestamps { - ( $($unit: ident), *) => { - $( - paste! { - let mut builder = arrow::array::[]::builder(3); - builder.append_value(1); - builder.append_value(0); - builder.append_value(-1); - let ts_array = Arc::new(builder.finish()) as Arc; - let v = arrow_array_get(&ts_array, 1).unwrap(); - assert_eq!( - ConcreteDataType::Timestamp(TimestampType::$unit( - $crate::types::[]::default(), - )), - v.data_type() - ); - } - )* - }; - } - - #[test] - fn test_timestamp_array() { - test_arrow_array_get_for_timestamps![Second, Millisecond, Microsecond, Nanosecond]; - } - - #[test] - fn test_arrow_array_access() { - let array1 = BooleanArray::from(vec![true, true, false, false]); - assert_eq!(Value::Boolean(true), arrow_array_get(&array1, 1).unwrap()); - let array1 = Int8Array::from(vec![1, 2, 3, 4]); - assert_eq!(Value::Int8(2), arrow_array_get(&array1, 1).unwrap()); - let array1 = UInt8Array::from(vec![1, 2, 3, 4]); - assert_eq!(Value::UInt8(2), arrow_array_get(&array1, 1).unwrap()); - let array1 = Int16Array::from(vec![1, 2, 3, 4]); - assert_eq!(Value::Int16(2), arrow_array_get(&array1, 1).unwrap()); - let array1 = UInt16Array::from(vec![1, 2, 3, 4]); - assert_eq!(Value::UInt16(2), arrow_array_get(&array1, 1).unwrap()); - let array1 = Int32Array::from(vec![1, 2, 3, 4]); - assert_eq!(Value::Int32(2), arrow_array_get(&array1, 1).unwrap()); - let array1 = UInt32Array::from(vec![1, 2, 3, 4]); - assert_eq!(Value::UInt32(2), arrow_array_get(&array1, 1).unwrap()); - let array = Int64Array::from(vec![1, 2, 3, 4]); - assert_eq!(Value::Int64(2), arrow_array_get(&array, 1).unwrap()); - let array1 = UInt64Array::from(vec![1, 2, 3, 4]); - assert_eq!(Value::UInt64(2), arrow_array_get(&array1, 1).unwrap()); - let array1 = Float32Array::from(vec![1f32, 2f32, 3f32, 4f32]); - assert_eq!( - Value::Float32(2f32.into()), - arrow_array_get(&array1, 1).unwrap() - ); - let array1 = Float64Array::from(vec![1f64, 2f64, 3f64, 4f64]); - assert_eq!( - Value::Float64(2f64.into()), - arrow_array_get(&array1, 1).unwrap() - ); - - let array2 = StringArray::from(vec![Some("hello"), None, Some("world")]); - assert_eq!( - Value::String("hello".into()), - arrow_array_get(&array2, 0).unwrap() - ); - assert_eq!(Value::Null, arrow_array_get(&array2, 1).unwrap()); - - let array3 = LargeBinaryArray::from(vec![ - Some("hello".as_bytes()), - None, - Some("world".as_bytes()), - ]); - assert_eq!(Value::Null, arrow_array_get(&array3, 1).unwrap()); - - let array = TimestampSecondArray::from(vec![1, 2, 3]); - let value = arrow_array_get(&array, 1).unwrap(); - assert_eq!(value, Value::Timestamp(Timestamp::new(2, TimeUnit::Second))); - let array = TimestampMillisecondArray::from(vec![1, 2, 3]); - let value = arrow_array_get(&array, 1).unwrap(); - assert_eq!( - value, - Value::Timestamp(Timestamp::new(2, TimeUnit::Millisecond)) - ); - let array = TimestampMicrosecondArray::from(vec![1, 2, 3]); - let value = arrow_array_get(&array, 1).unwrap(); - assert_eq!( - value, - Value::Timestamp(Timestamp::new(2, TimeUnit::Microsecond)) - ); - let array = TimestampNanosecondArray::from(vec![1, 2, 3]); - let value = arrow_array_get(&array, 1).unwrap(); - assert_eq!( - value, - Value::Timestamp(Timestamp::new(2, TimeUnit::Nanosecond)) - ); - - // test list array - let data = vec![ - Some(vec![Some(1), Some(2), Some(3)]), - None, - Some(vec![Some(4), None, Some(6)]), - ]; - let arrow_array = ListArray::from_iter_primitive::(data); - - let v0 = arrow_array_get(&arrow_array, 0).unwrap(); - match v0 { - Value::List(list) => { - assert!(matches!(list.datatype(), ConcreteDataType::Int32(_))); - let items = list.items().as_ref().unwrap(); - assert_eq!( - **items, - vec![Value::Int32(1), Value::Int32(2), Value::Int32(3)] - ); - } - _ => unreachable!(), - } - - assert_eq!(Value::Null, arrow_array_get(&arrow_array, 1).unwrap()); - let v2 = arrow_array_get(&arrow_array, 2).unwrap(); - match v2 { - Value::List(list) => { - assert!(matches!(list.datatype(), ConcreteDataType::Int32(_))); - let items = list.items().as_ref().unwrap(); - assert_eq!(**items, vec![Value::Int32(4), Value::Null, Value::Int32(6)]); - } - _ => unreachable!(), - } - } -}