diff --git a/src/common/recordbatch/src/error.rs b/src/common/recordbatch/src/error.rs index e0839924ae..383b848ba5 100644 --- a/src/common/recordbatch/src/error.rs +++ b/src/common/recordbatch/src/error.rs @@ -15,12 +15,19 @@ pub enum InnerError { source: datatypes::arrow::error::ArrowError, backtrace: Backtrace, }, + + #[snafu(display("Data types error, source: {}", source))] + DataTypes { + #[snafu(backtrace)] + source: datatypes::error::Error, + }, } impl ErrorExt for InnerError { fn status_code(&self) -> StatusCode { match self { InnerError::NewDfRecordBatch { .. } => StatusCode::InvalidArguments, + InnerError::DataTypes { .. } => StatusCode::Internal, } } diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs index d62caaf761..bb996ea463 100644 --- a/src/common/recordbatch/src/recordbatch.rs +++ b/src/common/recordbatch/src/recordbatch.rs @@ -1,5 +1,7 @@ use datafusion_common::record_batch::RecordBatch as DfRecordBatch; +use datatypes::arrow_array::arrow_array_get; use datatypes::schema::SchemaRef; +use datatypes::value::Value; use datatypes::vectors::{Helper, VectorRef}; use serde::ser::{Error, SerializeStruct}; use serde::{Serialize, Serializer}; @@ -28,6 +30,11 @@ impl RecordBatch { df_recordbatch, }) } + + /// Create an iterator to traverse the data by row + pub fn rows(&self) -> RecordBatchRowIterator<'_> { + RecordBatchRowIterator::new(self) + } } impl Serialize for RecordBatch { @@ -51,6 +58,49 @@ impl Serialize for RecordBatch { } } +pub struct RecordBatchRowIterator<'a> { + record_batch: &'a RecordBatch, + rows: usize, + columns: usize, + row_cursor: usize, +} + +impl<'a> RecordBatchRowIterator<'a> { + fn new(record_batch: &'a RecordBatch) -> RecordBatchRowIterator { + RecordBatchRowIterator { + record_batch, + rows: record_batch.df_recordbatch.num_rows(), + columns: record_batch.df_recordbatch.num_columns(), + row_cursor: 0, + } + } +} + +impl<'a> Iterator for RecordBatchRowIterator<'a> { + type Item = Result>; + + fn next(&mut self) -> Option { + if self.row_cursor == self.rows { + None + } else { + let mut row = Vec::with_capacity(self.columns); + + for col in 0..self.columns { + let column_array = self.record_batch.df_recordbatch.column(col); + match arrow_array_get(column_array.as_ref(), self.row_cursor) + .context(error::DataTypesSnafu) + { + Ok(field) => row.push(field), + Err(e) => return Some(Err(e.into())), + } + } + + self.row_cursor += 1; + Some(Ok(row)) + } + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -59,8 +109,9 @@ mod tests { use datafusion_common::record_batch::RecordBatch as DfRecordBatch; use datatypes::arrow::array::UInt32Array; use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; - use datatypes::schema::Schema; - use datatypes::vectors::{UInt32Vector, Vector}; + use datatypes::prelude::*; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::{StringVector, UInt32Vector, Vector}; use super::*; @@ -114,4 +165,66 @@ mod tests { output ); } + + #[test] + fn test_record_batch_visitor() { + let column_schemas = vec![ + ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false), + ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true), + ]; + let schema = Arc::new(Schema::new(column_schemas)); + let columns: Vec = vec![ + Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])), + Arc::new(StringVector::from(vec![ + None, + Some("hello"), + Some("greptime"), + None, + ])), + ]; + let recordbatch = RecordBatch::new(schema, columns).unwrap(); + + let mut record_batch_iter = recordbatch.rows(); + assert_eq!( + vec![Value::UInt32(1), Value::Null], + record_batch_iter + .next() + .unwrap() + .unwrap() + .into_iter() + .collect::>() + ); + + assert_eq!( + vec![Value::UInt32(2), Value::String("hello".into())], + record_batch_iter + .next() + .unwrap() + .unwrap() + .into_iter() + .collect::>() + ); + + assert_eq!( + vec![Value::UInt32(3), Value::String("greptime".into())], + record_batch_iter + .next() + .unwrap() + .unwrap() + .into_iter() + .collect::>() + ); + + assert_eq!( + vec![Value::UInt32(4), Value::Null], + record_batch_iter + .next() + .unwrap() + .unwrap() + .into_iter() + .collect::>() + ); + + assert!(record_batch_iter.next().is_none()); + } } diff --git a/src/datatypes/src/arrow_array.rs b/src/datatypes/src/arrow_array.rs index 27da7d29a5..bc4711d724 100644 --- a/src/datatypes/src/arrow_array.rs +++ b/src/datatypes/src/arrow_array.rs @@ -1,9 +1,124 @@ use arrow::array::{ - BinaryArray as ArrowBinaryArray, MutableBinaryArray as ArrowMutableBinaryArray, - MutableUtf8Array, Utf8Array, + self, Array, BinaryArray as ArrowBinaryArray, MutableBinaryArray as ArrowMutableBinaryArray, + MutableUtf8Array, PrimitiveArray, Utf8Array, }; +use arrow::datatypes::DataType as ArrowDataType; +use snafu::OptionExt; + +use crate::error::{ConversionSnafu, Result}; +use crate::value::Value; pub type BinaryArray = ArrowBinaryArray; pub type MutableBinaryArray = ArrowMutableBinaryArray; pub type MutableStringArray = MutableUtf8Array; pub type StringArray = Utf8Array; + +macro_rules! cast_array { + ($arr: ident, $CastType: ty) => { + $arr.as_any() + .downcast_ref::<$CastType>() + .with_context(|| ConversionSnafu { + from: format!("{:?}", $arr.data_type()), + })? + }; +} + +pub fn arrow_array_get(array: &dyn Array, idx: usize) -> Result { + if array.is_null(idx) { + return Ok(Value::Null); + } + + let result = match array.data_type() { + ArrowDataType::Null => Value::Null, + ArrowDataType::Boolean => { + Value::Boolean(cast_array!(array, array::BooleanArray).value(idx)) + } + ArrowDataType::Binary | ArrowDataType::LargeBinary => { + Value::Binary(cast_array!(array, BinaryArray).value(idx).into()) + } + ArrowDataType::Int8 => Value::Int8(cast_array!(array, PrimitiveArray::).value(idx)), + ArrowDataType::Int16 => Value::Int16(cast_array!(array, PrimitiveArray::).value(idx)), + ArrowDataType::Int32 => Value::Int32(cast_array!(array, PrimitiveArray::).value(idx)), + ArrowDataType::Int64 => Value::Int64(cast_array!(array, PrimitiveArray::).value(idx)), + ArrowDataType::UInt8 => Value::UInt8(cast_array!(array, PrimitiveArray::).value(idx)), + ArrowDataType::UInt16 => { + Value::UInt16(cast_array!(array, PrimitiveArray::).value(idx)) + } + ArrowDataType::UInt32 => { + Value::UInt32(cast_array!(array, PrimitiveArray::).value(idx)) + } + ArrowDataType::UInt64 => { + Value::UInt64(cast_array!(array, PrimitiveArray::).value(idx)) + } + ArrowDataType::Float32 => { + Value::Float32(cast_array!(array, PrimitiveArray::).value(idx).into()) + } + ArrowDataType::Float64 => { + Value::Float64(cast_array!(array, PrimitiveArray::).value(idx).into()) + } + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { + Value::String(cast_array!(array, StringArray).value(idx).into()) + } + // TODO(sunng87): List + _ => unimplemented!("Arrow array datatype: {:?}", array.data_type()), + }; + + Ok(result) +} + +#[cfg(test)] +mod test { + use arrow::array::*; + + use super::*; + + #[test] + fn test_arrow_array_access() { + let array1 = BooleanArray::from_slice(vec![true, true, false, false]); + assert_eq!(Value::Boolean(true), arrow_array_get(&array1, 1).unwrap()); + let array1 = Int8Array::from_vec(vec![1, 2, 3, 4]); + assert_eq!(Value::Int8(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = UInt8Array::from_vec(vec![1, 2, 3, 4]); + assert_eq!(Value::UInt8(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = Int16Array::from_vec(vec![1, 2, 3, 4]); + assert_eq!(Value::Int16(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = UInt16Array::from_vec(vec![1, 2, 3, 4]); + assert_eq!(Value::UInt16(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = Int32Array::from_vec(vec![1, 2, 3, 4]); + assert_eq!(Value::Int32(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = UInt32Array::from_vec(vec![1, 2, 3, 4]); + assert_eq!(Value::UInt32(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = Int64Array::from_vec(vec![1, 2, 3, 4]); + assert_eq!(Value::Int64(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = UInt64Array::from_vec(vec![1, 2, 3, 4]); + assert_eq!(Value::UInt64(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = Float32Array::from_vec(vec![1f32, 2f32, 3f32, 4f32]); + assert_eq!( + Value::Float32(2f32.into()), + arrow_array_get(&array1, 1).unwrap() + ); + let array1 = Float64Array::from_vec(vec![1f64, 2f64, 3f64, 4f64]); + assert_eq!( + Value::Float64(2f64.into()), + arrow_array_get(&array1, 1).unwrap() + ); + + let array2 = StringArray::from(vec![Some("hello"), None, Some("world")]); + assert_eq!( + Value::String("hello".into()), + arrow_array_get(&array2, 0).unwrap() + ); + assert_eq!(Value::Null, arrow_array_get(&array2, 1).unwrap()); + + let array3 = super::BinaryArray::from(vec![ + Some("hello".as_bytes()), + None, + Some("world".as_bytes()), + ]); + assert_eq!( + Value::Binary("hello".as_bytes().into()), + arrow_array_get(&array3, 0).unwrap() + ); + assert_eq!(Value::Null, arrow_array_get(&array3, 1).unwrap()); + } +} diff --git a/src/servers/src/mysql/writer.rs b/src/servers/src/mysql/writer.rs index 42b7086d14..0b21561da4 100644 --- a/src/servers/src/mysql/writer.rs +++ b/src/servers/src/mysql/writer.rs @@ -1,7 +1,8 @@ use std::io; +use std::ops::Deref; use common_recordbatch::{util, RecordBatch}; -use datatypes::prelude::{ConcreteDataType, Value, VectorHelper}; +use datatypes::prelude::{ConcreteDataType, Value}; use datatypes::schema::{ColumnSchema, SchemaRef}; use opensrv_mysql::{ Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter, @@ -82,12 +83,12 @@ impl<'a, W: io::Write> MysqlResultWriter<'a, W> { } fn write_recordbatch(row_writer: &mut RowWriter, recordbatch: &RecordBatch) -> Result<()> { - let matrix = transpose(recordbatch)?; - for row in matrix.iter() { - for v in row.iter() { - match v { + for row in recordbatch.rows() { + let row = row.context(error::CollectRecordbatchSnafu)?; + for value in row.into_iter() { + match value { Value::Null => row_writer.write_col(None::)?, - Value::Boolean(v) => row_writer.write_col(*v as i8)?, + Value::Boolean(v) => row_writer.write_col(v as i8)?, Value::UInt8(v) => row_writer.write_col(v)?, Value::UInt16(v) => row_writer.write_col(v)?, Value::UInt32(v) => row_writer.write_col(v)?, @@ -99,14 +100,14 @@ impl<'a, W: io::Write> MysqlResultWriter<'a, W> { Value::Float32(v) => row_writer.write_col(v.0)?, Value::Float64(v) => row_writer.write_col(v.0)?, Value::String(v) => row_writer.write_col(v.as_utf8())?, - Value::Binary(v) => row_writer.write_col(v.to_vec())?, + Value::Binary(v) => row_writer.write_col(v.deref())?, Value::Date(v) => row_writer.write_col(v)?, Value::DateTime(v) => row_writer.write_col(v)?, - _ => { + Value::List(_) => { return Err(Error::Internal { err_msg: format!( "cannot write value {:?} in mysql protocol: unimplemented", - v + &value ), }) } @@ -169,69 +170,3 @@ pub fn create_mysql_column_def(schema: &SchemaRef) -> Result> { .map(create_mysql_column) .collect() } - -/// RecordBatch organizes its values in columns while MySQL needs to write row by row. -/// This function creates a view of [Value]s organized in rows from RecordBatch (just like matrix -/// transpose, hence the function name), helping us write RecordBatch to MySQL. -fn transpose(recordbatch: &RecordBatch) -> Result>> { - let recordbatch = &recordbatch.df_recordbatch; - let rows = recordbatch.num_rows(); - let columns = recordbatch.num_columns(); - let mut matrix = vec![vec![Value::Null; columns]; rows]; - for column in 0..columns { - let array = recordbatch.column(column); - let vector = VectorHelper::try_into_vector(array).context(error::VectorConversionSnafu)?; - // Clippy suggests us to use "matrix.iter_mut().enumerate().take(rows)", which is not wanted. - #[allow(clippy::needless_range_loop)] - for row in 0..rows { - matrix[row][column] = vector.get(row); - } - } - Ok(matrix) -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use common_base::bytes::StringBytes; - use datatypes::prelude::*; - use datatypes::schema::Schema; - use datatypes::vectors::{StringVector, UInt32Vector}; - - use super::*; - - #[test] - fn test_transpose() { - let column_schemas = vec![ - ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false), - ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true), - ]; - let schema = Arc::new(Schema::new(column_schemas)); - let columns: Vec = vec![ - Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])), - Arc::new(StringVector::from(vec![ - None, - Some("hello"), - Some("greptime"), - None, - ])), - ]; - let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let matrix = transpose(&recordbatch).unwrap(); - assert_eq!(4, matrix.len()); - assert_eq!(vec![Value::UInt32(1), Value::Null], matrix[0]); - assert_eq!( - vec![Value::UInt32(2), Value::String(StringBytes::from("hello"))], - matrix[1] - ); - assert_eq!( - vec![ - Value::UInt32(3), - Value::String(StringBytes::from("greptime")) - ], - matrix[2] - ); - assert_eq!(vec![Value::UInt32(4), Value::Null], matrix[3]); - } -}