diff --git a/src/common/substrait/Cargo.toml b/src/common/substrait/Cargo.toml index db563275b4..301b8289db 100644 --- a/src/common/substrait/Cargo.toml +++ b/src/common/substrait/Cargo.toml @@ -11,6 +11,7 @@ common-error = { path = "../error" } datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [ "simd", ] } +datatypes = { path = "../../datatypes" } futures = "0.3" prost = "0.9" snafu = { version = "0.7", features = ["backtraces"] } diff --git a/src/common/substrait/src/df_logical.rs b/src/common/substrait/src/df_logical.rs index a8f2426340..2e746cdc81 100644 --- a/src/common/substrait/src/df_logical.rs +++ b/src/common/substrait/src/df_logical.rs @@ -5,9 +5,12 @@ use catalog::CatalogManagerRef; use common_error::prelude::BoxedError; use datafusion::datasource::TableProvider; use datafusion::logical_plan::{LogicalPlan, TableScan, ToDFSchema}; +use datafusion::physical_plan::project_schema; use prost::Message; use snafu::ensure; use snafu::{OptionExt, ResultExt}; +use substrait_proto::protobuf::expression::mask_expression::{StructItem, StructSelect}; +use substrait_proto::protobuf::expression::MaskExpression; use substrait_proto::protobuf::plan_rel::RelType as PlanRelType; use substrait_proto::protobuf::read_rel::{NamedTable, ReadType}; use substrait_proto::protobuf::rel::RelType; @@ -19,9 +22,10 @@ use table::table::adapter::DfTableProviderAdapter; use crate::error::Error; use crate::error::{ DFInternalSnafu, DecodeRelSnafu, EmptyPlanSnafu, EncodeRelSnafu, InternalSnafu, - InvalidParametersSnafu, MissingFieldSnafu, TableNotFoundSnafu, UnknownPlanSnafu, - UnsupportedExprSnafu, UnsupportedPlanSnafu, + InvalidParametersSnafu, MissingFieldSnafu, SchemaNotMatchSnafu, TableNotFoundSnafu, + UnknownPlanSnafu, UnsupportedExprSnafu, UnsupportedPlanSnafu, }; +use crate::schema::{from_schema, to_schema}; use crate::SubstraitPlan; pub struct DFLogicalSubstraitConvertor { @@ -148,6 +152,11 @@ impl DFLogicalSubstraitConvertor { } }; + // Get projection indices + let projection = read_rel + .projection + .map(|mask_expr| self.convert_mask_expression(mask_expr)); + // Get table handle from catalog manager let table_ref = self .catalog_manager @@ -158,23 +167,45 @@ impl DFLogicalSubstraitConvertor { name: format!("{}.{}.{}", catalog_name, schema_name, table_name), })?; let adapter = Arc::new(DfTableProviderAdapter::new(table_ref)); - // Get schema direct from the table. - // TODO(ruihang): Maybe need to verify the schema with the one in Substrait? - let schema = adapter - .schema() + + // Get schema directly from the table, and compare it with the schema retrived from substrait proto. + let stored_schema = adapter.schema(); + let retrived_schema = to_schema(read_rel.base_schema.unwrap_or_default())?; + let retrived_arrow_schema = retrived_schema.arrow_schema(); + ensure!( + stored_schema.fields == retrived_arrow_schema.fields, + SchemaNotMatchSnafu { + substrait_schema: retrived_arrow_schema.clone(), + storage_schema: stored_schema + } + ); + + // Calculate the projected schema + let projected_schema = project_schema(&stored_schema, projection.as_ref()) + .context(DFInternalSnafu)? .to_dfschema_ref() .context(DFInternalSnafu)?; - // TODO(ruihang): Support projection, filters and limit + // TODO(ruihang): Support filters and limit Ok(LogicalPlan::TableScan(TableScan { table_name, source: adapter, - projection: None, - projected_schema: schema, + projection, + projected_schema, filters: vec![], limit: None, })) } + + fn convert_mask_expression(&self, mask_expression: MaskExpression) -> Vec { + mask_expression + .select + .unwrap_or_default() + .struct_items + .into_iter() + .map(|select| select.field as _) + .collect() + } } impl DFLogicalSubstraitConvertor { @@ -254,27 +285,51 @@ impl DFLogicalSubstraitConvertor { .context(UnknownPlanSnafu)?; let table_info = provider.table().table_info(); + // assemble NamedTable and ReadType let catalog_name = table_info.catalog_name.clone(); let schema_name = table_info.schema_name.clone(); let table_name = table_info.name.clone(); - let named_table = NamedTable { names: vec![catalog_name, schema_name, table_name], advanced_extension: None, }; let read_type = ReadType::NamedTable(named_table); + // assemble projection + let projection = table_scan + .projection + .map(|proj| self.convert_schema_projection(&proj)); + + // assemble base (unprojected) schema using Table's schema. + let base_schema = from_schema(&provider.table().schema())?; + let read_rel = ReadRel { common: None, - base_schema: None, + base_schema: Some(base_schema), filter: None, - projection: None, + projection, advanced_extension: None, read_type: Some(read_type), }; Ok(read_rel) } + + /// Convert a index-based schema projection to substrait's [MaskExpression]. + fn convert_schema_projection(&self, projections: &[usize]) -> MaskExpression { + let struct_items = projections + .iter() + .map(|index| StructItem { + field: *index as i32, + child: None, + }) + .collect(); + MaskExpression { + select: Some(StructSelect { struct_items }), + // TODO(ruihang): this field is unspecified + maintain_singular_struct: true, + } + } } #[cfg(test)] @@ -285,10 +340,12 @@ mod test { CatalogList, CatalogProvider, RegisterTableRequest, }; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; + use datafusion::logical_plan::DFSchema; use datatypes::schema::Schema; use table::{requests::CreateTableRequest, test_util::EmptyTable, test_util::MockTableEngine}; use super::*; + use crate::schema::test::supported_types; const DEFAULT_TABLE_NAME: &str = "SubstraitTable"; @@ -319,7 +376,7 @@ mod test { schema_name: DEFAULT_SCHEMA_NAME.to_string(), table_name: table_name.to_string(), desc: None, - schema: Arc::new(Schema::new(vec![])), + schema: Arc::new(Schema::new(supported_types())), primary_key_indices: vec![], create_if_not_exists: true, table_options: Default::default(), @@ -336,7 +393,7 @@ mod test { } #[tokio::test] - async fn test_bare_table_scan() { + async fn test_table_scan() { let catalog_manager = build_mock_catalog_manager().await; let table_ref = Arc::new(EmptyTable::new(build_create_table_request( DEFAULT_TABLE_NAME, @@ -352,13 +409,20 @@ mod test { .await .unwrap(); let adapter = Arc::new(DfTableProviderAdapter::new(table_ref)); - let schema = adapter.schema().to_dfschema_ref().unwrap(); + let projection = vec![1, 3, 5]; + let df_schema = adapter.schema().to_dfschema().unwrap(); + let projected_fields = projection + .iter() + .map(|index| df_schema.field(*index).clone()) + .collect(); + let projected_schema = + Arc::new(DFSchema::new_with_metadata(projected_fields, Default::default()).unwrap()); let table_scan_plan = LogicalPlan::TableScan(TableScan { table_name: DEFAULT_TABLE_NAME.to_string(), source: adapter, - projection: None, - projected_schema: schema, + projection: Some(projection), + projected_schema, filters: vec![], limit: None, }); diff --git a/src/common/substrait/src/error.rs b/src/common/substrait/src/error.rs index 6cebc19020..ad591b949c 100644 --- a/src/common/substrait/src/error.rs +++ b/src/common/substrait/src/error.rs @@ -2,6 +2,7 @@ use std::any::Any; use common_error::prelude::{BoxedError, ErrorExt, StatusCode}; use datafusion::error::DataFusionError; +use datatypes::prelude::ConcreteDataType; use prost::{DecodeError, EncodeError}; use snafu::{Backtrace, ErrorCompat, Snafu}; @@ -14,6 +15,15 @@ pub enum Error { #[snafu(display("Unsupported physical plan: {}", name))] UnsupportedExpr { name: String, backtrace: Backtrace }, + #[snafu(display("Unsupported concrete type: {:?}", ty))] + UnsupportedConcreteType { + ty: ConcreteDataType, + backtrace: Backtrace, + }, + + #[snafu(display("Unsupported substrait type: {}", ty))] + UnsupportedSubstraitType { ty: String, backtrace: Backtrace }, + #[snafu(display("Failed to decode substrait relation, source: {}", source))] DecodeRel { source: DecodeError, @@ -60,16 +70,32 @@ pub enum Error { #[snafu(display("Table quering not found: {}", name))] TableNotFound { name: String, backtrace: Backtrace }, - #[snafu(display("Cannot convert plan doesn't belong to GrepTimeDB"))] + #[snafu(display("Cannot convert plan doesn't belong to GreptimeDB"))] UnknownPlan { backtrace: Backtrace }, + + #[snafu(display( + "Schema from Substrait proto doesn't match with the schema in storage. + Substrait schema: {:?} + Storage schema: {:?}", + substrait_schema, + storage_schema + ))] + SchemaNotMatch { + substrait_schema: datafusion::arrow::datatypes::SchemaRef, + storage_schema: datafusion::arrow::datatypes::SchemaRef, + backtrace: Backtrace, + }, } +pub type Result = std::result::Result; + impl ErrorExt for Error { fn status_code(&self) -> StatusCode { match self { - Error::UnsupportedPlan { .. } | Error::UnsupportedExpr { .. } => { - StatusCode::Unsupported - } + Error::UnsupportedConcreteType { .. } + | Error::UnsupportedPlan { .. } + | Error::UnsupportedExpr { .. } + | Error::UnsupportedSubstraitType { .. } => StatusCode::Unsupported, Error::UnknownPlan { .. } | Error::EncodeRel { .. } | Error::DecodeRel { .. } @@ -77,7 +103,8 @@ impl ErrorExt for Error { | Error::EmptyExpr { .. } | Error::MissingField { .. } | Error::InvalidParameters { .. } - | Error::TableNotFound { .. } => StatusCode::InvalidArguments, + | Error::TableNotFound { .. } + | Error::SchemaNotMatch { .. } => StatusCode::InvalidArguments, Error::DFInternal { .. } | Error::Internal { .. } => StatusCode::Internal, } } diff --git a/src/common/substrait/src/lib.rs b/src/common/substrait/src/lib.rs index 3808ce9ec6..b5bf75e0bd 100644 --- a/src/common/substrait/src/lib.rs +++ b/src/common/substrait/src/lib.rs @@ -1,5 +1,7 @@ mod df_logical; pub mod error; +mod schema; +mod types; use bytes::{Buf, Bytes}; diff --git a/src/common/substrait/src/schema.rs b/src/common/substrait/src/schema.rs new file mode 100644 index 0000000000..42a5ea9ec9 --- /dev/null +++ b/src/common/substrait/src/schema.rs @@ -0,0 +1,97 @@ +use datatypes::schema::{ColumnSchema, Schema}; +use substrait_proto::protobuf::r#type::{Nullability, Struct as SubstraitStruct}; +use substrait_proto::protobuf::NamedStruct; + +use crate::error::Result; +use crate::types::{from_concrete_type, to_concrete_type}; + +pub fn to_schema(named_struct: NamedStruct) -> Result { + if named_struct.r#struct.is_none() { + return Ok(Schema::new(vec![])); + } + + let column_schemas = named_struct + .r#struct + .unwrap() + .types + .into_iter() + .zip(named_struct.names.into_iter()) + .map(|(ty, name)| { + let (concrete_type, is_nullable) = to_concrete_type(&ty)?; + let column_schema = ColumnSchema::new(name, concrete_type, is_nullable); + Ok(column_schema) + }) + .collect::>()?; + + Ok(Schema::new(column_schemas)) +} + +pub fn from_schema(schema: &Schema) -> Result { + let mut names = Vec::with_capacity(schema.num_columns()); + let mut types = Vec::with_capacity(schema.num_columns()); + + for column_schema in schema.column_schemas() { + names.push(column_schema.name.clone()); + let substrait_type = from_concrete_type( + column_schema.data_type.clone(), + Some(column_schema.is_nullable()), + )?; + types.push(substrait_type); + } + + // TODO(ruihang): `type_variation_reference` and `nullability` are unspecified. + let substrait_struct = SubstraitStruct { + types, + type_variation_reference: 0, + nullability: Nullability::Unspecified as _, + }; + + Ok(NamedStruct { + names, + r#struct: Some(substrait_struct), + }) +} + +#[cfg(test)] +pub(crate) mod test { + use datatypes::prelude::{ConcreteDataType, DataType}; + + use super::*; + + pub(crate) fn supported_types() -> Vec { + [ + ConcreteDataType::null_datatype(), + ConcreteDataType::boolean_datatype(), + ConcreteDataType::int8_datatype(), + ConcreteDataType::int16_datatype(), + ConcreteDataType::int32_datatype(), + ConcreteDataType::int64_datatype(), + ConcreteDataType::uint8_datatype(), + ConcreteDataType::uint16_datatype(), + ConcreteDataType::uint32_datatype(), + ConcreteDataType::uint64_datatype(), + ConcreteDataType::float32_datatype(), + ConcreteDataType::float64_datatype(), + ConcreteDataType::binary_datatype(), + ConcreteDataType::string_datatype(), + ConcreteDataType::date_datatype(), + ConcreteDataType::timestamp_datatype(Default::default()), + // TODO(ruihang): DateTime and List type are not supported now + ] + .into_iter() + .enumerate() + .map(|(ordinal, ty)| ColumnSchema::new(ty.name().to_string(), ty, ordinal % 2 == 0)) + .collect() + } + + #[test] + fn supported_types_round_trip() { + let column_schemas = supported_types(); + let schema = Schema::new(column_schemas); + + let named_struct = from_schema(&schema).unwrap(); + let converted_schema = to_schema(named_struct).unwrap(); + + assert_eq!(schema, converted_schema); + } +} diff --git a/src/common/substrait/src/types.rs b/src/common/substrait/src/types.rs new file mode 100644 index 0000000000..b15f7ed7a0 --- /dev/null +++ b/src/common/substrait/src/types.rs @@ -0,0 +1,123 @@ +//! Methods that perform convertion between Substrait's type ([Type](SType)) and GreptimeDB's type ([ConcreteDataType]). +//! +//! Substrait use [type variation](https://substrait.io/types/type_variations/) to express different "logical types". +//! Current we only have variations on integer types. Variation 0 (system prefered) are the same with base types, which +//! are signed integer (i.e. I8 -> [i8]), and Variation 1 stands for unsigned integer (i.e. I8 -> [u8]). + +use datatypes::prelude::ConcreteDataType; +use substrait_proto::protobuf::r#type::{self as s_type, Kind, Nullability}; +use substrait_proto::protobuf::Type as SType; + +use crate::error::Result; +use crate::error::{UnsupportedConcreteTypeSnafu, UnsupportedSubstraitTypeSnafu}; + +macro_rules! substrait_kind { + ($desc:ident, $concrete_ty:ident) => {{ + let nullable = $desc.nullability() == Nullability::Nullable; + let ty = ConcreteDataType::$concrete_ty(); + Ok((ty, nullable)) + }}; + + ($desc:ident, $concrete_ty:expr) => {{ + let nullable = $desc.nullability() == Nullability::Nullable; + Ok(($concrete_ty, nullable)) + }}; + + ($desc:ident, $concrete_ty_0:ident, $concrete_ty_1:ident) => {{ + let nullable = $desc.nullability() == Nullability::Nullable; + let ty = match $desc.type_variation_reference { + 0 => ConcreteDataType::$concrete_ty_0(), + 1 => ConcreteDataType::$concrete_ty_1(), + _ => UnsupportedSubstraitTypeSnafu { + ty: format!("{:?}", $desc), + } + .fail()?, + }; + Ok((ty, nullable)) + }}; +} + +/// Convert Substrait [Type](SType) to GreptimeDB's [ConcreteDataType]. The bool in return +/// tuple is the nullability identifier. +pub fn to_concrete_type(ty: &SType) -> Result<(ConcreteDataType, bool)> { + if ty.kind.is_none() { + return Ok((ConcreteDataType::null_datatype(), true)); + } + let kind = ty.kind.as_ref().unwrap(); + match kind { + Kind::Bool(desc) => substrait_kind!(desc, boolean_datatype), + Kind::I8(desc) => substrait_kind!(desc, int8_datatype, uint8_datatype), + Kind::I16(desc) => substrait_kind!(desc, int16_datatype, uint16_datatype), + Kind::I32(desc) => substrait_kind!(desc, int32_datatype, uint32_datatype), + Kind::I64(desc) => substrait_kind!(desc, int64_datatype, uint64_datatype), + Kind::Fp32(desc) => substrait_kind!(desc, float32_datatype), + Kind::Fp64(desc) => substrait_kind!(desc, float64_datatype), + Kind::String(desc) => substrait_kind!(desc, string_datatype), + Kind::Binary(desc) => substrait_kind!(desc, binary_datatype), + Kind::Timestamp(desc) => substrait_kind!( + desc, + ConcreteDataType::timestamp_datatype(Default::default()) + ), + Kind::Date(desc) => substrait_kind!(desc, date_datatype), + Kind::Time(_) + | Kind::IntervalYear(_) + | Kind::IntervalDay(_) + | Kind::TimestampTz(_) + | Kind::Uuid(_) + | Kind::FixedChar(_) + | Kind::Varchar(_) + | Kind::FixedBinary(_) + | Kind::Decimal(_) + | Kind::Struct(_) + | Kind::List(_) + | Kind::Map(_) + | Kind::UserDefinedTypeReference(_) => UnsupportedSubstraitTypeSnafu { + ty: format!("{:?}", kind), + } + .fail(), + } +} + +macro_rules! build_substrait_kind { + ($kind:ident,$s_type:ident,$nullable:ident,$variation:literal) => {{ + let nullability = match $nullable { + Some(true) => Nullability::Nullable, + Some(false) => Nullability::Required, + None => Nullability::Unspecified, + } as _; + Some(Kind::$kind(s_type::$s_type { + type_variation_reference: $variation, + nullability, + })) + }}; +} + +/// Convert GreptimeDB's [ConcreteDataType] to Substrait [Type](SType). +/// +/// Refer to [mod level documentation](super::types) for more information about type variation. +pub fn from_concrete_type(ty: ConcreteDataType, nullability: Option) -> Result { + let kind = match ty { + ConcreteDataType::Null(_) => None, + ConcreteDataType::Boolean(_) => build_substrait_kind!(Bool, Boolean, nullability, 0), + ConcreteDataType::Int8(_) => build_substrait_kind!(I8, I8, nullability, 0), + ConcreteDataType::Int16(_) => build_substrait_kind!(I16, I16, nullability, 0), + ConcreteDataType::Int32(_) => build_substrait_kind!(I32, I32, nullability, 0), + ConcreteDataType::Int64(_) => build_substrait_kind!(I64, I64, nullability, 0), + ConcreteDataType::UInt8(_) => build_substrait_kind!(I8, I8, nullability, 1), + ConcreteDataType::UInt16(_) => build_substrait_kind!(I16, I16, nullability, 1), + ConcreteDataType::UInt32(_) => build_substrait_kind!(I32, I32, nullability, 1), + ConcreteDataType::UInt64(_) => build_substrait_kind!(I64, I64, nullability, 1), + ConcreteDataType::Float32(_) => build_substrait_kind!(Fp32, Fp32, nullability, 0), + ConcreteDataType::Float64(_) => build_substrait_kind!(Fp64, Fp64, nullability, 0), + ConcreteDataType::Binary(_) => build_substrait_kind!(Binary, Binary, nullability, 0), + ConcreteDataType::String(_) => build_substrait_kind!(String, String, nullability, 0), + ConcreteDataType::Date(_) => build_substrait_kind!(Date, Date, nullability, 0), + ConcreteDataType::DateTime(_) => UnsupportedConcreteTypeSnafu { ty }.fail()?, + ConcreteDataType::Timestamp(_) => { + build_substrait_kind!(Timestamp, Timestamp, nullability, 0) + } + ConcreteDataType::List(_) => UnsupportedConcreteTypeSnafu { ty }.fail()?, + }; + + Ok(SType { kind }) +}