From 079ee8615f41f30efb38ca4acf7aef22e48edfef Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 26 Jan 2026 17:53:45 +0800 Subject: [PATCH] feat: UDF json_get with user specified return type (#7554) * feat: add return_field_from_args * feat: add JsonGetWithType * port json_get_float and json_get_bool to new implementation, add json_get with third argument accepting a scalar value for type. * fix: lint fix * chore: add sqlness tests * chore: update tests --- src/common/function/src/function.rs | 21 +- src/common/function/src/scalars/json.rs | 2 + .../function/src/scalars/json/json_get.rs | 802 +++++++++++++----- src/common/function/src/scalars/udf.rs | 7 + .../common/function/json/json_get.result | 72 ++ .../common/function/json/json_get.sql | 12 + 6 files changed, 690 insertions(+), 226 deletions(-) diff --git a/src/common/function/src/function.rs b/src/common/function/src/function.rs index cf93b35adf..b4b44d52c7 100644 --- a/src/common/function/src/function.rs +++ b/src/common/function/src/function.rs @@ -17,8 +17,8 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use datafusion::arrow::datatypes::DataType; -use datafusion::logical_expr::ColumnarValue; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::logical_expr::{ColumnarValue, ReturnFieldArgs}; use datafusion_common::DataFusionError; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions}; @@ -116,6 +116,23 @@ pub trait Function: fmt::Display + Sync + Send { fn aliases(&self) -> &[String] { &[] } + + /// Returns the return field for this function given the input fields. + /// + /// Default implementation extracts data types from input fields and calls + /// [`Function::return_type`], creating a generic field with the returned type. + fn return_field_from_args( + &self, + args: ReturnFieldArgs<'_>, + ) -> datafusion_common::Result> { + let input_types = args + .arg_fields + .iter() + .map(|f| f.data_type().clone()) + .collect::>(); + let return_type = self.return_type(&input_types)?; + Ok(Arc::new(Field::new(self.name(), return_type, true))) + } } pub type FunctionRef = Arc; diff --git a/src/common/function/src/scalars/json.rs b/src/common/function/src/scalars/json.rs index f84937fa0f..801744a0fa 100644 --- a/src/common/function/src/scalars/json.rs +++ b/src/common/function/src/scalars/json.rs @@ -27,6 +27,7 @@ use json_to_string::JsonToStringFunction; use parse_json::ParseJsonFunction; use crate::function_registry::FunctionRegistry; +use crate::scalars::json::json_get::JsonGetWithType; pub(crate) struct JsonFunction; @@ -40,6 +41,7 @@ impl JsonFunction { registry.register_scalar(JsonGetString::default()); registry.register_scalar(JsonGetBool::default()); registry.register_scalar(JsonGetObject::default()); + registry.register_scalar(JsonGetWithType::default()); registry.register_scalar(JsonIsNull::default()); registry.register_scalar(JsonIsInt::default()); diff --git a/src/common/function/src/scalars/json/json_get.rs b/src/common/function/src/scalars/json/json_get.rs index ce2ee08f91..ff7b1b8fbe 100644 --- a/src/common/function/src/scalars/json/json_get.rs +++ b/src/common/function/src/scalars/json/json_get.rs @@ -12,22 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::fmt::{self, Display}; use std::str::FromStr; use std::sync::Arc; use arrow::array::{ArrayRef, BinaryViewArray, StringViewArray, StructArray}; use arrow::compute; use arrow::datatypes::{Float64Type, Int64Type, UInt64Type}; +use arrow_schema::Field; use datafusion_common::arrow::array::{ Array, AsArray, BinaryViewBuilder, BooleanBuilder, Float64Builder, Int64Builder, StringViewBuilder, }; use datafusion_common::arrow::datatypes::DataType; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use datatypes::arrow_array::{int_array_value_at_index, string_array_value_at_index}; use datatypes::json::JsonStructureSettings; +use derive_more::Display; use jsonpath_rust::JsonPath; use serde_json::Value; @@ -49,101 +50,6 @@ fn get_json_by_path(json: &[u8], path: &str) -> Option> { } } -/// Get the value from the JSONB by the given path and return it as specified type. -/// If the path does not exist or the value is not the type specified, return `NULL`. -macro_rules! json_get { - // e.g. name = JsonGetInt, type = Int64, rust_type = i64, doc = "Get the value from the JSONB by the given path and return it as an integer." - ($name:ident, $type:ident, $rust_type:ident, $doc:expr) => { - paste::paste! { - #[doc = $doc] - #[derive(Clone, Debug)] - pub struct $name { - signature: Signature, - } - - impl $name { - pub const NAME: &'static str = stringify!([<$name:snake>]); - } - - impl Default for $name { - fn default() -> Self { - Self { - // TODO(LFC): Use a more clear type here instead of "Binary" for Json input, once we have a "Json" type. - signature: helper::one_of_sigs2( - vec![DataType::Binary, DataType::BinaryView], - vec![DataType::Utf8, DataType::Utf8View], - ), - } - } - } - - impl Function for $name { - fn name(&self) -> &str { - Self::NAME - } - - fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { - Ok(DataType::[<$type>]) - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn invoke_with_args( - &self, - args: ScalarFunctionArgs, - ) -> datafusion_common::Result { - let [arg0, arg1] = extract_args(self.name(), &args)?; - let arg0 = compute::cast(&arg0, &DataType::BinaryView)?; - let jsons = arg0.as_binary_view(); - let arg1 = compute::cast(&arg1, &DataType::Utf8View)?; - let paths = arg1.as_string_view(); - - let size = jsons.len(); - let mut builder = [<$type Builder>]::with_capacity(size); - - for i in 0..size { - let json = jsons.is_valid(i).then(|| jsons.value(i)); - let path = paths.is_valid(i).then(|| paths.value(i)); - let result = match (json, path) { - (Some(json), Some(path)) => { - get_json_by_path(json, path) - .and_then(|json| { jsonb::[](&json).ok() }) - } - _ => None, - }; - - builder.append_option(result); - } - - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) - } - } - - impl Display for $name { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", Self::NAME.to_ascii_uppercase()) - } - } - } - }; -} - -json_get!( - JsonGetFloat, - Float64, - f64, - "Get the value from the JSONB by the given path and return it as a float." -); - -json_get!( - JsonGetBool, - Boolean, - bool, - "Get the value from the JSONB by the given path and return it as a boolean." -); - enum JsonResultValue<'a> { Jsonb(Vec), JsonStructByColumn(&'a ArrayRef, usize), @@ -164,6 +70,7 @@ trait JsonGetResultBuilder { /// based on a path expression. Different JSON get functions reuse this /// implementation by supplying their own `JsonGetResultBuilder` to control /// how the resulting values are materialized into an Arrow array. +#[derive(Debug)] struct JsonGet { signature: Signature, } @@ -210,7 +117,42 @@ impl Default for JsonGet { } } -#[derive(Default)] +struct StringResultBuilder(StringViewBuilder); + +impl JsonGetResultBuilder for StringResultBuilder { + fn append_value(&mut self, value: JsonResultValue<'_>) -> Result<()> { + match value { + JsonResultValue::Jsonb(value) => self.0.append_option(jsonb::to_str(&value).ok()), + JsonResultValue::JsonStructByColumn(column, i) => { + if let Some(v) = string_array_value_at_index(column, i) { + self.0.append_value(v); + } else { + self.0 + .append_value(arrow_cast::display::array_value_to_string(column, i)?); + } + } + JsonResultValue::JsonStructByValue(value) => { + if let Some(s) = value.as_str() { + self.0.append_value(s) + } else { + self.0.append_value(value.to_string()) + } + } + } + Ok(()) + } + + fn append_null(&mut self) { + self.0.append_null(); + } + + fn build(&mut self) -> ArrayRef { + Arc::new(self.0.finish()) + } +} + +#[derive(Default, Display, Debug)] +#[display("{}", Self::NAME.to_ascii_uppercase())] pub struct JsonGetString(JsonGet); impl JsonGetString { @@ -231,51 +173,37 @@ impl Function for JsonGetString { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - struct StringResultBuilder(StringViewBuilder); - - impl JsonGetResultBuilder for StringResultBuilder { - fn append_value(&mut self, value: JsonResultValue<'_>) -> Result<()> { - match value { - JsonResultValue::Jsonb(value) => { - self.0.append_option(jsonb::to_str(&value).ok()) - } - JsonResultValue::JsonStructByColumn(column, i) => { - if let Some(v) = string_array_value_at_index(column, i) { - self.0.append_value(v); - } else { - self.0 - .append_value(arrow_cast::display::array_value_to_string( - column, i, - )?); - } - } - JsonResultValue::JsonStructByValue(value) => { - if let Some(s) = value.as_str() { - self.0.append_value(s) - } else { - self.0.append_value(value.to_string()) - } - } - } - Ok(()) - } - - fn append_null(&mut self) { - self.0.append_null(); - } - - fn build(&mut self) -> ArrayRef { - Arc::new(self.0.finish()) - } - } - self.0.invoke(args, |len: usize| { StringResultBuilder(StringViewBuilder::with_capacity(len)) }) } } -#[derive(Default)] +struct IntResultBuilder(Int64Builder); + +impl JsonGetResultBuilder for IntResultBuilder { + fn append_value(&mut self, value: JsonResultValue<'_>) -> Result<()> { + match value { + JsonResultValue::Jsonb(value) => self.0.append_option(jsonb::to_i64(&value).ok()), + JsonResultValue::JsonStructByColumn(column, i) => { + self.0.append_option(int_array_value_at_index(column, i)) + } + JsonResultValue::JsonStructByValue(value) => self.0.append_option(value.as_i64()), + } + Ok(()) + } + + fn append_null(&mut self) { + self.0.append_null(); + } + + fn build(&mut self) -> ArrayRef { + Arc::new(self.0.finish()) + } +} + +#[derive(Default, Display, Debug)] +#[display("{}", Self::NAME.to_ascii_uppercase())] pub struct JsonGetInt(JsonGet); impl JsonGetInt { @@ -296,49 +224,134 @@ impl Function for JsonGetInt { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - struct IntResultBuilder(Int64Builder); - - impl JsonGetResultBuilder for IntResultBuilder { - fn append_value(&mut self, value: JsonResultValue<'_>) -> Result<()> { - match value { - JsonResultValue::Jsonb(value) => { - self.0.append_option(jsonb::to_i64(&value).ok()) - } - JsonResultValue::JsonStructByColumn(column, i) => { - self.0.append_option(int_array_value_at_index(column, i)) - } - JsonResultValue::JsonStructByValue(value) => { - self.0.append_option(value.as_i64()) - } - } - Ok(()) - } - - fn append_null(&mut self) { - self.0.append_null(); - } - - fn build(&mut self) -> ArrayRef { - Arc::new(self.0.finish()) - } - } - self.0.invoke(args, |len: usize| { IntResultBuilder(Int64Builder::with_capacity(len)) }) } } -impl Display for JsonGetInt { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", Self::NAME.to_ascii_uppercase()) +struct FloatResultBuilder(Float64Builder); + +impl JsonGetResultBuilder for FloatResultBuilder { + fn append_value(&mut self, value: JsonResultValue<'_>) -> Result<()> { + match value { + JsonResultValue::Jsonb(value) => self.0.append_option(jsonb::to_f64(&value).ok()), + JsonResultValue::JsonStructByColumn(column, i) => { + let result = if column.data_type() == &DataType::Float64 { + column + .as_primitive::() + .is_valid(i) + .then(|| column.as_primitive::().value(i)) + } else { + None + }; + self.0.append_option(result); + } + JsonResultValue::JsonStructByValue(value) => self.0.append_option(value.as_f64()), + } + Ok(()) + } + + fn append_null(&mut self) { + self.0.append_null(); + } + + fn build(&mut self) -> ArrayRef { + Arc::new(self.0.finish()) + } +} + +#[derive(Default, Display, Debug)] +#[display("{}", Self::NAME.to_ascii_uppercase())] +pub struct JsonGetFloat(JsonGet); + +impl JsonGetFloat { + pub const NAME: &'static str = "json_get_float"; +} + +impl Function for JsonGetFloat { + fn name(&self) -> &str { + Self::NAME + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn signature(&self) -> &Signature { + &self.0.signature + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.0.invoke(args, |len: usize| { + FloatResultBuilder(Float64Builder::with_capacity(len)) + }) + } +} + +struct BoolResultBuilder(BooleanBuilder); + +impl JsonGetResultBuilder for BoolResultBuilder { + fn append_value(&mut self, value: JsonResultValue<'_>) -> Result<()> { + match value { + JsonResultValue::Jsonb(value) => self.0.append_option(jsonb::to_bool(&value).ok()), + JsonResultValue::JsonStructByColumn(column, i) => { + let result = if column.data_type() == &DataType::Boolean { + column + .as_boolean() + .is_valid(i) + .then(|| column.as_boolean().value(i)) + } else { + None + }; + self.0.append_option(result); + } + JsonResultValue::JsonStructByValue(value) => self.0.append_option(value.as_bool()), + } + Ok(()) + } + + fn append_null(&mut self) { + self.0.append_null(); + } + + fn build(&mut self) -> ArrayRef { + Arc::new(self.0.finish()) + } +} + +#[derive(Default, Display, Debug)] +#[display("{}", Self::NAME.to_ascii_uppercase())] +pub struct JsonGetBool(JsonGet); + +impl JsonGetBool { + pub const NAME: &'static str = "json_get_bool"; +} + +impl Function for JsonGetBool { + fn name(&self) -> &str { + Self::NAME + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn signature(&self) -> &Signature { + &self.0.signature + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.0.invoke(args, |len: usize| { + BoolResultBuilder(BooleanBuilder::with_capacity(len)) + }) } } fn jsonb_get( jsons: &BinaryViewArray, paths: &StringViewArray, - builder: &mut impl JsonGetResultBuilder, + builder: &mut dyn JsonGetResultBuilder, ) -> Result<()> { let size = jsons.len(); for i in 0..size { @@ -360,7 +373,7 @@ fn jsonb_get( fn json_struct_get( jsons: &StructArray, paths: &StringViewArray, - builder: &mut impl JsonGetResultBuilder, + builder: &mut dyn JsonGetResultBuilder, ) -> Result<()> { let size = jsons.len(); for i in 0..size { @@ -494,13 +507,121 @@ fn json_struct_to_value(raw: &str, jsons: &StructArray, i: usize) -> Result fmt::Result { - write!(f, "{}", Self::NAME.to_ascii_uppercase()) +#[derive(Debug, Display)] +#[display("{}", Self::NAME.to_ascii_uppercase())] +pub(super) struct JsonGetWithType { + signature: Signature, +} + +impl JsonGetWithType { + const NAME: &'static str = "json_get"; +} + +impl Default for JsonGetWithType { + fn default() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + } + } +} + +impl Function for JsonGetWithType { + fn name(&self) -> &str { + Self::NAME + } + + fn return_type(&self, _input_types: &[DataType]) -> datafusion_common::Result { + Err(DataFusionError::Internal( + "This method isn't meant to be called".to_string(), + )) + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs<'_>, + ) -> datafusion_common::Result> { + match args.scalar_arguments[2] { + Some(ScalarValue::Utf8(Some(type_str))) + | Some(ScalarValue::Utf8View(Some(type_str))) + | Some(ScalarValue::LargeUtf8(Some(type_str))) => { + let type_str = type_str.to_ascii_lowercase(); + match type_str.as_str() { + "bool" | "boolean" => { + Ok(Arc::new(Field::new(self.name(), DataType::Boolean, true))) + } + "int" | "integer" => { + Ok(Arc::new(Field::new(self.name(), DataType::Int64, true))) + } + "float" | "double" => { + Ok(Arc::new(Field::new(self.name(), DataType::Float64, true))) + } + "string" => Ok(Arc::new(Field::new(self.name(), DataType::Utf8View, true))), + _ => Err(DataFusionError::Internal(format!( + "Unsupported type: {}", + type_str + ))), + } + } + _ => Err(DataFusionError::Internal( + "Invalid argument provided for type".to_string(), + )), + } + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1, _] = extract_args("JSON_GET", &args)?; + let len = arg0.len(); + + let arg1 = compute::cast(&arg1, &DataType::Utf8View)?; + let paths = arg1.as_string_view(); + + // mapping datatypes returned from return_field_from_args + let mut builder: Box = match args.return_field.data_type() { + DataType::Utf8View => { + Box::new(StringResultBuilder(StringViewBuilder::with_capacity(len))) + } + DataType::Int64 => Box::new(IntResultBuilder(Int64Builder::with_capacity(len))), + DataType::Float64 => Box::new(FloatResultBuilder(Float64Builder::with_capacity(len))), + DataType::Boolean => Box::new(BoolResultBuilder(BooleanBuilder::with_capacity(len))), + _ => { + return Err(DataFusionError::Internal( + "Unsupported return type".to_string(), + )); + } + }; + + match arg0.data_type() { + DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { + let arg0 = compute::cast(&arg0, &DataType::BinaryView)?; + let jsons = arg0.as_binary_view(); + jsonb_get(jsons, paths, builder.as_mut())?; + } + DataType::Struct(_) => { + let jsons = arg0.as_struct(); + json_struct_get(jsons, paths, builder.as_mut())?; + } + _ => { + return Err(DataFusionError::Execution(format!( + "JSON_GET not supported argument type {}", + arg0.data_type(), + ))); + } + }; + + Ok(ColumnarValue::Array(builder.build())) } } /// Get the object from JSON value by path. +#[derive(Display, Debug)] +#[display("{}", Self::NAME.to_ascii_uppercase())] pub(super) struct JsonGetObject { signature: Signature, } @@ -571,12 +692,6 @@ impl Function for JsonGetObject { } } -impl Display for JsonGetObject { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", Self::NAME.to_ascii_uppercase()) - } -} - #[cfg(test)] mod tests { use std::sync::Arc; @@ -721,37 +836,55 @@ mod tests { r#"{"a": 4.4, "b": {"c": 6.6}, "c": 6.6}"#, r#"{"a": 7.7, "b": 8.8, "c": {"a": 7.7}}"#, ]; - let paths = vec!["$.a.b", "$.a", "$.c"]; - let results = [Some(2.1), Some(4.4), None]; + let json_struct = test_json_struct(); - let jsonbs = json_strings + let path_expects = vec![ + ("$.a.b", Some(2.1)), + ("$.a", Some(4.4)), + ("$.c", None), + ("$.kind", None), + ("$.payload.code", None), + ("$.payload.success", None), + ("$.payload.result.time_cost", Some(1.234)), + ("$.payload.not-exists", None), + ("$.not-exists", None), + ("$", None), + ]; + + let mut jsons = json_strings .iter() .map(|s| { let value = jsonb::parse_value(s.as_bytes()).unwrap(); - value.to_vec() + Arc::new(BinaryArray::from_iter_values([value.to_vec()])) as ArrayRef }) .collect::>(); + let json_struct_arrays = + std::iter::repeat_n(json_struct, path_expects.len() - jsons.len()).collect::>(); + jsons.extend(json_struct_arrays); - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(BinaryArray::from_iter_values(jsonbs))), - ColumnarValue::Array(Arc::new(StringArray::from_iter_values(paths))), - ], - arg_fields: vec![], - number_rows: 3, - return_field: Arc::new(Field::new("x", DataType::Float64, false)), - config_options: Arc::new(Default::default()), - }; - let result = json_get_float - .invoke_with_args(args) - .and_then(|x| x.to_array(3)) - .unwrap(); - let vector = result.as_primitive::(); + for i in 0..jsons.len() { + let json = &jsons[i]; + let (path, expect) = path_expects[i]; - assert_eq!(3, vector.len()); - for (i, gt) in results.iter().enumerate() { - let result = vector.is_valid(i).then(|| vector.value(i)); - assert_eq!(*gt, result); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(json.clone()), + ColumnarValue::Scalar(path.into()), + ], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Float64, false)), + config_options: Arc::new(Default::default()), + }; + let result = json_get_float + .invoke_with_args(args) + .and_then(|x| x.to_array(1)) + .unwrap(); + + let result = result.as_primitive::(); + assert_eq!(1, result.len()); + let actual = result.is_valid(0).then(|| result.value(0)); + assert_eq!(actual, expect); } } @@ -772,37 +905,55 @@ mod tests { r#"{"a": false, "b": {"c": true}, "c": false}"#, r#"{"a": true, "b": false, "c": {"a": true}}"#, ]; - let paths = vec!["$.a.b", "$.a", "$.c"]; - let results = [Some(true), Some(false), None]; + let json_struct = test_json_struct(); - let jsonbs = json_strings + let path_expects = vec![ + ("$.a.b", Some(true)), + ("$.a", Some(false)), + ("$.c", None), + ("$.kind", None), + ("$.payload.code", None), + ("$.payload.success", Some(false)), + ("$.payload.result.time_cost", None), + ("$.payload.not-exists", None), + ("$.not-exists", None), + ("$", None), + ]; + + let mut jsons = json_strings .iter() .map(|s| { let value = jsonb::parse_value(s.as_bytes()).unwrap(); - value.to_vec() + Arc::new(BinaryArray::from_iter_values([value.to_vec()])) as ArrayRef }) .collect::>(); + let json_struct_arrays = + std::iter::repeat_n(json_struct, path_expects.len() - jsons.len()).collect::>(); + jsons.extend(json_struct_arrays); - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(BinaryArray::from_iter_values(jsonbs))), - ColumnarValue::Array(Arc::new(StringArray::from_iter_values(paths))), - ], - arg_fields: vec![], - number_rows: 3, - return_field: Arc::new(Field::new("x", DataType::Boolean, false)), - config_options: Arc::new(Default::default()), - }; - let result = json_get_bool - .invoke_with_args(args) - .and_then(|x| x.to_array(3)) - .unwrap(); - let vector = result.as_boolean(); + for i in 0..jsons.len() { + let json = &jsons[i]; + let (path, expect) = path_expects[i]; - assert_eq!(3, vector.len()); - for (i, gt) in results.iter().enumerate() { - let result = vector.is_valid(i).then(|| vector.value(i)); - assert_eq!(*gt, result); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(json.clone()), + ColumnarValue::Scalar(path.into()), + ], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Boolean, false)), + config_options: Arc::new(Default::default()), + }; + let result = json_get_bool + .invoke_with_args(args) + .and_then(|x| x.to_array(1)) + .unwrap(); + + let result = result.as_boolean(); + assert_eq!(1, result.len()); + let actual = result.is_valid(0).then(|| result.value(0)); + assert_eq!(actual, expect); } } @@ -944,4 +1095,207 @@ mod tests { assert_eq!(result, expected); Ok(()) } + + #[test] + fn test_json_get_with_type() { + let json_get_with_type = JsonGetWithType::default(); + + assert_eq!("json_get", json_get_with_type.name()); + + let json_strings = [ + r#"{"a": {"b": "a"}, "b": "b", "c": "c"}"#, + r#"{"a": "d", "b": {"c": "e"}, "c": "f"}"#, + r#"{"a": "g", "b": "h", "c": {"a": "g"}}"#, + ]; + let json_struct = test_json_struct(); + + let paths = vec![ + "$.a.b", + "$.a", + "", + "$.kind", + "$.payload.code", + "$.payload.result.time_cost", + "$.payload", + "$.payload.success", + "$.payload.result", + "$.payload.result.error", + "$.payload.result.not-exists", + "$.payload.not-exists", + "$.not-exists", + "$", + ]; + let expects = [ + Some("a"), + Some("d"), + None, + Some("foo"), + Some("404"), + Some("1.234"), + Some( + r#"{"code":404,"result":{"error":"not found","time_cost":1.234},"success":false}"#, + ), + Some("false"), + Some(r#"{"error":"not found","time_cost":1.234}"#), + Some("not found"), + None, + None, + None, + Some( + r#"{"kind":"foo","payload":{"code":404,"result":{"error":"not found","time_cost":1.234},"success":false}}"#, + ), + ]; + + let mut jsons = json_strings + .iter() + .map(|s| { + let value = jsonb::parse_value(s.as_bytes()).unwrap(); + Arc::new(BinaryArray::from_iter_values([value.to_vec()])) as ArrayRef + }) + .collect::>(); + let json_struct_arrays = + std::iter::repeat_n(json_struct, expects.len() - jsons.len()).collect::>(); + jsons.extend(json_struct_arrays); + + for i in 0..jsons.len() { + let json = &jsons[i]; + let path = paths[i]; + let expect = expects[i]; + + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(json.clone()), + ColumnarValue::Scalar(path.into()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("string".to_string()))), + ], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = json_get_with_type + .invoke_with_args(args) + .and_then(|x| x.to_array(1)) + .unwrap(); + + let result = result.as_string_view(); + assert_eq!(1, result.len()); + let actual = result.is_valid(0).then(|| result.value(0)); + assert_eq!(actual, expect); + } + + let json_strings = [ + r#"{"a": {"b": 2}, "b": 2, "c": 3}"#, + r#"{"a": 4, "b": {"c": 6}, "c": 6}"#, + r#"{"a": 7, "b": 8, "c": {"a": 7}}"#, + ]; + let paths = ["$.a.b", "$.a", "$.c", "$.payload.code"]; + let expects = [Some(2), Some(4), None, Some(404)]; + + for (i, (path, expect)) in paths.iter().zip(expects.iter()).enumerate() { + let json = if i < json_strings.len() { + let value = jsonb::parse_value(json_strings[i].as_bytes()).unwrap(); + Arc::new(BinaryArray::from_iter_values([value.to_vec()])) as ArrayRef + } else { + test_json_struct() + }; + + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(json), + ColumnarValue::Scalar((*path).into()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("int".to_string()))), + ], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Int64, false)), + config_options: Arc::new(Default::default()), + }; + let result = json_get_with_type + .invoke_with_args(args) + .and_then(|x| x.to_array(1)) + .unwrap(); + + let result = result.as_primitive::(); + assert_eq!(1, result.len()); + let actual = result.is_valid(0).then(|| result.value(0)); + assert_eq!(actual, *expect); + } + + let json_strings = [ + r#"{"a": {"b": 2.1}, "b": 2.2, "c": 3.3}"#, + r#"{"a": 4.4, "b": {"c": 6.6}, "c": 6.6}"#, + r#"{"a": 7.7, "b": 8.8, "c": {"a": 7.7}}"#, + ]; + let paths = ["$.a.b", "$.a", "$.c", "$.payload.result.time_cost"]; + let expects = [Some(2.1), Some(4.4), None, Some(1.234)]; + + for (i, (path, expect)) in paths.iter().zip(expects.iter()).enumerate() { + let json = if i < json_strings.len() { + let value = jsonb::parse_value(json_strings[i].as_bytes()).unwrap(); + Arc::new(BinaryArray::from_iter_values([value.to_vec()])) as ArrayRef + } else { + test_json_struct() + }; + + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(json), + ColumnarValue::Scalar((*path).into()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("float".to_string()))), + ], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Float64, false)), + config_options: Arc::new(Default::default()), + }; + let result = json_get_with_type + .invoke_with_args(args) + .and_then(|x| x.to_array(1)) + .unwrap(); + + let result = result.as_primitive::(); + assert_eq!(1, result.len()); + let actual = result.is_valid(0).then(|| result.value(0)); + assert_eq!(actual, *expect); + } + + let json_strings = [ + r#"{"a": {"b": true}, "b": false, "c": true}"#, + r#"{"a": false, "b": {"c": true}, "c": false}"#, + r#"{"a": true, "b": false, "c": {"a": true}}"#, + ]; + let paths = ["$.a.b", "$.a", "$.c", "$.payload.success"]; + let expects = [Some(true), Some(false), None, Some(false)]; + + for (i, (path, expect)) in paths.iter().zip(expects.iter()).enumerate() { + let json = if i < json_strings.len() { + let value = jsonb::parse_value(json_strings[i].as_bytes()).unwrap(); + Arc::new(BinaryArray::from_iter_values([value.to_vec()])) as ArrayRef + } else { + test_json_struct() + }; + + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(json), + ColumnarValue::Scalar((*path).into()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("bool".to_string()))), + ], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Boolean, false)), + config_options: Arc::new(Default::default()), + }; + let result = json_get_with_type + .invoke_with_args(args) + .and_then(|x| x.to_array(1)) + .unwrap(); + + let result = result.as_boolean(); + assert_eq!(1, result.len()); + let actual = result.is_valid(0).then(|| result.value(0)); + assert_eq!(actual, *expect); + } + } } diff --git a/src/common/function/src/scalars/udf.rs b/src/common/function/src/scalars/udf.rs index eee3ede801..638f7c38af 100644 --- a/src/common/function/src/scalars/udf.rs +++ b/src/common/function/src/scalars/udf.rs @@ -69,6 +69,13 @@ impl ScalarUDFImpl for ScalarUdf { self.function.return_type(arg_types) } + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> datafusion_common::Result { + self.function.return_field_from_args(args) + } + fn invoke_with_args( &self, args: ScalarFunctionArgs, diff --git a/tests/cases/standalone/common/function/json/json_get.result b/tests/cases/standalone/common/function/json/json_get.result index 5f17415d0c..8f605b82bd 100644 --- a/tests/cases/standalone/common/function/json/json_get.result +++ b/tests/cases/standalone/common/function/json/json_get.result @@ -293,6 +293,78 @@ SELECT json_get_int(json_get_object(j, '[9]'), 'a.i') FROM jsons; | | +----------------------------------------------------------------+ +SELECT json_get(j, '[0]', 'int') FROM jsons; + ++-------------------------------------------+ +| json_get(jsons.j,Utf8("[0]"),Utf8("int")) | ++-------------------------------------------+ +| | +| 1 | +| 1 | +| | +| | ++-------------------------------------------+ + +SELECT json_get(j, '[0]', 'float') FROM jsons; + ++---------------------------------------------+ +| json_get(jsons.j,Utf8("[0]"),Utf8("float")) | ++---------------------------------------------+ +| | +| 1.0 | +| 1.0 | +| 1.2 | +| | ++---------------------------------------------+ + +SELECT json_get(j, '[1]', 'int') FROM jsons; + ++-------------------------------------------+ +| json_get(jsons.j,Utf8("[1]"),Utf8("int")) | ++-------------------------------------------+ +| | +| 0 | +| 0 | +| | +| | ++-------------------------------------------+ + +SELECT json_get(j, '[1]', 'float') FROM jsons; + ++---------------------------------------------+ +| json_get(jsons.j,Utf8("[1]"),Utf8("float")) | ++---------------------------------------------+ +| | +| 0.0 | +| 0.0 | +| 3.141592653589793 | +| | ++---------------------------------------------+ + +SELECT json_get(j, '[2]', 'bool') FROM jsons; + ++--------------------------------------------+ +| json_get(jsons.j,Utf8("[2]"),Utf8("bool")) | ++--------------------------------------------+ +| | +| false | +| | +| | +| | ++--------------------------------------------+ + +SELECT json_get(j, '[3]', 'string') FROM jsons; + ++--------------------------------------------------------+ +| json_get(jsons.j,Utf8("[3]"),Utf8("string")) | ++--------------------------------------------------------+ +| Long time ago, there is a little pig flying in the sky | +| false | +| 2147483648 | +| 1e100 | +| | ++--------------------------------------------------------+ + DROP TABLE jsons; Affected Rows: 0 diff --git a/tests/cases/standalone/common/function/json/json_get.sql b/tests/cases/standalone/common/function/json/json_get.sql index 010a3bd7a7..061815b922 100644 --- a/tests/cases/standalone/common/function/json/json_get.sql +++ b/tests/cases/standalone/common/function/json/json_get.sql @@ -73,6 +73,18 @@ SELECT json_get_int(json_get_object(j, '[0]'), 'a.i') FROM jsons; SELECT json_get_int(json_get_object(j, '[9]'), 'a.i') FROM jsons; +SELECT json_get(j, '[0]', 'int') FROM jsons; + +SELECT json_get(j, '[0]', 'float') FROM jsons; + +SELECT json_get(j, '[1]', 'int') FROM jsons; + +SELECT json_get(j, '[1]', 'float') FROM jsons; + +SELECT json_get(j, '[2]', 'bool') FROM jsons; + +SELECT json_get(j, '[3]', 'string') FROM jsons; + DROP TABLE jsons; -- test functions in WHERE clause --