mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-08 22:32:55 +00:00
refactor: rewrite some UDFs to DataFusion style (part 3) (#6990)
* refactor: rewrite some UDFs to DataFusion style (part 3)

  Signed-off-by: luofucong <luofc@foxmail.com>

* resolve PR comments

  Signed-off-by: luofucong <luofc@foxmail.com>

* resolve PR comments

  Signed-off-by: luofucong <luofc@foxmail.com>

* resolve PR comments

  Signed-off-by: luofucong <luofc@foxmail.com>

---------

Signed-off-by: luofucong <luofc@foxmail.com>
This commit is contained in:
@@ -17,7 +17,6 @@ use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
|
||||
use arrow::record_batch::RecordBatch;
|
||||
use arrow_schema::SchemaRef;
|
||||
use datafusion::catalog::{Session, TableProvider};
|
||||
@@ -32,10 +31,14 @@ use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
|
||||
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
|
||||
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
|
||||
use datafusion::prelude::SessionContext;
|
||||
use datafusion_common::arrow::array::{ArrayRef, AsArray, Float64Array, Int64Array, UInt64Array};
|
||||
use datafusion_common::arrow::datatypes::{Float64Type, UInt64Type};
|
||||
use datafusion_common::{Column, TableReference};
|
||||
use datafusion_expr::expr::AggregateFunction;
|
||||
use datafusion_expr::sqlparser::ast::NullTreatment;
|
||||
use datafusion_expr::{Aggregate, Expr, LogicalPlan, SortExpr, TableScan, lit};
|
||||
use datafusion_expr::{
|
||||
Aggregate, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr, TableScan, lit,
|
||||
};
|
||||
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
|
||||
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
|
||||
use datatypes::arrow_array::StringArray;
|
||||
@@ -649,14 +652,20 @@ async fn test_udaf_correct_eval_result() {
|
||||
expected_output: None,
|
||||
expected_fn: Some(|arr| {
|
||||
let percent = ScalarValue::Float64(Some(0.5)).to_array().unwrap();
|
||||
let percent = datatypes::vectors::Helper::try_into_vector(percent).unwrap();
|
||||
let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
|
||||
let percent = ColumnarValue::Array(percent);
|
||||
let state = ColumnarValue::Array(arr);
|
||||
let udd_calc = UddSketchCalcFunction;
|
||||
let res = udd_calc
|
||||
.eval(&Default::default(), &[percent, state])
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![percent, state],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
let binding = res.to_arrow_array();
|
||||
let res_arr = binding.as_any().downcast_ref::<Float64Array>().unwrap();
|
||||
let binding = res.to_array(1).unwrap();
|
||||
let res_arr = binding.as_primitive::<Float64Type>();
|
||||
assert!(res_arr.len() == 1);
|
||||
assert!((res_arr.value(0) - 2.856578984907706f64).abs() <= f64::EPSILON);
|
||||
true
|
||||
@@ -683,11 +692,20 @@ async fn test_udaf_correct_eval_result() {
|
||||
]))],
|
||||
expected_output: None,
|
||||
expected_fn: Some(|arr| {
|
||||
let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
|
||||
let number_rows = arr.len();
|
||||
let state = ColumnarValue::Array(arr);
|
||||
let hll_calc = HllCalcFunction;
|
||||
let res = hll_calc.eval(&Default::default(), &[state]).unwrap();
|
||||
let binding = res.to_arrow_array();
|
||||
let res_arr = binding.as_any().downcast_ref::<UInt64Array>().unwrap();
|
||||
let res = hll_calc
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![state],
|
||||
arg_fields: vec![],
|
||||
number_rows,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
let binding = res.to_array(1).unwrap();
|
||||
let res_arr = binding.as_primitive::<UInt64Type>();
|
||||
assert!(res_arr.len() == 1);
|
||||
assert_eq!(res_arr.value(0), 3);
|
||||
true
|
||||
|
||||
@@ -12,7 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::any::Any;
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::ext::{BoxedError, PlainError};
|
||||
@@ -20,6 +22,9 @@ use common_error::status_code::StatusCode;
|
||||
use common_query::error::{ExecuteSnafu, Result};
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::ArrayRef;
|
||||
use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::vectors::VectorRef;
|
||||
use session::context::{QueryContextBuilder, QueryContextRef};
|
||||
@@ -60,6 +65,42 @@ impl Default for FunctionContext {
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtensionOptions for FunctionContext {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn cloned(&self) -> Box<dyn ExtensionOptions> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
|
||||
fn set(&mut self, _: &str, _: &str) -> datafusion_common::Result<()> {
|
||||
Err(DataFusionError::NotImplemented(
|
||||
"set options for `FunctionContext`".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn entries(&self) -> Vec<ConfigEntry> {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for FunctionContext {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("FunctionContext")
|
||||
.field("query_ctx", &self.query_ctx)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Marks [`FunctionContext`] as a DataFusion config extension, which is what allows
/// `find_function_context` to look it up with `extensions.get::<FunctionContext>()`.
impl ConfigExtension for FunctionContext {
|
||||
// NOTE(review): presumably the namespace DataFusion uses for this extension's option
// keys — confirm against the `ConfigExtension` documentation.
const PREFIX: &'static str = "FunctionContext";
|
||||
}
|
||||
|
||||
/// Scalar function trait, modified from databend to adapt datafusion
|
||||
/// TODO(dennis): optimize function by it's features such as monotonicity etc.
|
||||
pub trait Function: fmt::Display + Sync + Send {
|
||||
@@ -99,3 +140,26 @@ pub trait Function: fmt::Display + Sync + Send {
|
||||
}
|
||||
|
||||
pub type FunctionRef = Arc<dyn Function>;
|
||||
|
||||
/// Find the [FunctionContext] in the [ScalarFunctionArgs]. The [FunctionContext] was set
|
||||
/// previously in the DataFusion session context creation, and is passed all the way down to the
|
||||
/// args by DataFusion.
|
||||
pub(crate) fn find_function_context(
|
||||
args: &ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<&FunctionContext> {
|
||||
let Some(x) = args.config_options.extensions.get::<FunctionContext>() else {
|
||||
return Err(DataFusionError::Execution(
|
||||
"function context is not set".to_string(),
|
||||
));
|
||||
};
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
/// Extract UDF arguments (as Arrow's [ArrayRef]) from [ScalarFunctionArgs] directly.
|
||||
pub(crate) fn extract_args<const N: usize>(
|
||||
name: &str,
|
||||
args: &ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<[ArrayRef; N]> {
|
||||
ColumnarValue::values_to_arrays(&args.args)
|
||||
.and_then(|x| datafusion_common::utils::take_function_args(name, x))
|
||||
}
|
||||
|
||||
@@ -16,13 +16,12 @@ use std::fmt;
|
||||
|
||||
use common_query::error::{ArrowComputeSnafu, Result};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::utils;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::compute::kernels::numeric;
|
||||
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::helper;
|
||||
|
||||
/// A function adds an interval value to Timestamp, Date, and return the result.
|
||||
@@ -63,8 +62,7 @@ impl Function for DateAddFunction {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [left, right] = utils::take_function_args(self.name(), args)?;
|
||||
let [left, right] = extract_args(self.name(), &args)?;
|
||||
|
||||
let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
|
||||
Ok(ColumnarValue::Array(result))
|
||||
|
||||
@@ -16,13 +16,12 @@ use std::fmt;
|
||||
|
||||
use common_query::error::{ArrowComputeSnafu, Result};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::utils;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::compute::kernels::numeric;
|
||||
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::helper;
|
||||
|
||||
/// A function subtracts an interval value to Timestamp, Date, and return the result.
|
||||
@@ -63,8 +62,7 @@ impl Function for DateSubFunction {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [left, right] = utils::take_function_args(self.name(), args)?;
|
||||
let [left, right] = extract_args(self.name(), &args)?;
|
||||
|
||||
let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
|
||||
Ok(ColumnarValue::Array(result))
|
||||
|
||||
@@ -16,17 +16,12 @@ use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error;
|
||||
use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::array::ArrayRef;
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::compute::is_null;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::prelude::VectorRef;
|
||||
use datatypes::vectors::Helper;
|
||||
use snafu::{ResultExt, ensure};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, extract_args};
|
||||
|
||||
const NAME: &str = "isnull";
|
||||
|
||||
@@ -53,25 +48,14 @@ impl Function for IsNullFunction {
|
||||
Signature::any(1, Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
_func_ctx: &FunctionContext,
|
||||
columns: &[VectorRef],
|
||||
) -> common_query::error::Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly one, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
let values = &columns[0];
|
||||
let arrow_array = &values.to_arrow_array();
|
||||
let result = is_null(arrow_array).context(ArrowComputeSnafu)?;
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0] = extract_args(self.name(), &args)?;
|
||||
let result = is_null(&arg0)?;
|
||||
|
||||
Helper::try_into_vector(Arc::new(result) as ArrayRef).context(error::FromArrowArraySnafu)
|
||||
Ok(ColumnarValue::Array(Arc::new(result)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,9 +63,9 @@ impl Function for IsNullFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::arrow::array::{AsArray, BooleanArray, Float32Array};
|
||||
use datafusion_expr::TypeSignature;
|
||||
use datatypes::scalars::ScalarVector;
|
||||
use datatypes::vectors::{BooleanVector, Float32Vector};
|
||||
|
||||
use super::*;
|
||||
#[test]
|
||||
@@ -98,9 +82,20 @@ mod tests {
|
||||
);
|
||||
let values = vec![None, Some(3.0), None];
|
||||
|
||||
let args: Vec<VectorRef> = vec![Arc::new(Float32Vector::from(values))];
|
||||
let vector = is_null.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let expect: VectorRef = Arc::new(BooleanVector::from_vec(vec![true, false, true]));
|
||||
let result = is_null
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(Arc::new(Float32Array::from(values)))],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("", DataType::Boolean, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
let ColumnarValue::Array(result) = result else {
|
||||
unreachable!()
|
||||
};
|
||||
let vector = result.as_boolean();
|
||||
let expect = &BooleanArray::from(vec![true, false, true]);
|
||||
assert_eq!(expect, vector);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,13 +21,13 @@ use common_query::error::{self, Result};
|
||||
use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder};
|
||||
use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, utils};
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_expr::type_coercion::aggregates::INTEGERS;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use geohash::Coord;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::scalars::geo::helpers;
|
||||
|
||||
fn ensure_resolution_usize(v: u8) -> datafusion_common::Result<usize> {
|
||||
@@ -77,8 +77,7 @@ impl Function for GeohashFunction {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
|
||||
let [lat_vec, lon_vec, resolutions] = extract_args(self.name(), &args)?;
|
||||
|
||||
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = lat_vec.as_primitive::<Float64Type>();
|
||||
@@ -169,8 +168,7 @@ impl Function for GeohashNeighboursFunction {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
|
||||
let [lat_vec, lon_vec, resolutions] = extract_args(self.name(), &args)?;
|
||||
|
||||
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = lat_vec.as_primitive::<Float64Type>();
|
||||
|
||||
@@ -25,7 +25,7 @@ use datafusion::arrow::array::{
|
||||
use datafusion::arrow::compute;
|
||||
use datafusion::arrow::datatypes::{Float64Type, Int64Type, UInt8Type, UInt64Type};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, ScalarValue, utils};
|
||||
use datafusion_common::{DataFusionError, ScalarValue};
|
||||
use datafusion_expr::type_coercion::aggregates::INTEGERS;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::{DataType, Field};
|
||||
@@ -33,7 +33,7 @@ use derive_more::Display;
|
||||
use h3o::{CellIndex, LatLng, Resolution};
|
||||
use snafu::prelude::*;
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::scalars::geo::helpers;
|
||||
|
||||
static CELL_TYPES: LazyLock<Vec<DataType>> =
|
||||
@@ -85,8 +85,7 @@ impl Function for H3LatLngToCell {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [lat_vec, lon_vec, resolution_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = lat_vec.as_primitive::<Float64Type>();
|
||||
@@ -167,8 +166,7 @@ impl Function for H3LatLngToCellString {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [lat_vec, lon_vec, resolution_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = lat_vec.as_primitive::<Float64Type>();
|
||||
@@ -233,8 +231,7 @@ impl Function for H3CellToString {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_vec.len();
|
||||
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
@@ -272,8 +269,7 @@ impl Function for H3StringToCell {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [string_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [string_vec] = extract_args(self.name(), &args)?;
|
||||
let string_vec = compute::cast(string_vec.as_ref(), &DataType::Utf8View)?;
|
||||
let string_vec = datafusion_common::downcast_value!(string_vec, StringViewArray);
|
||||
|
||||
@@ -323,8 +319,7 @@ impl Function for H3CellCenterLatLng {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_vec.len();
|
||||
let mut builder = ListBuilder::new(Float64Builder::new());
|
||||
@@ -368,8 +363,7 @@ impl Function for H3CellResolution {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_vec.len();
|
||||
let mut builder = UInt8Builder::with_capacity(cell_vec.len());
|
||||
@@ -406,8 +400,7 @@ impl Function for H3CellBase {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_vec.len();
|
||||
let mut builder = UInt8Builder::with_capacity(size);
|
||||
@@ -445,8 +438,7 @@ impl Function for H3CellIsPentagon {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_vec.len();
|
||||
let mut builder = BooleanBuilder::with_capacity(size);
|
||||
@@ -544,8 +536,7 @@ impl Function for H3CellToChildren {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec, res_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_vec, res_vec] = extract_args(self.name(), &args)?;
|
||||
let resolutions = helpers::cast::<UInt8Type>(&res_vec)?;
|
||||
let resolutions = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
@@ -638,8 +629,7 @@ fn calculate_cell_child_property<F>(
|
||||
where
|
||||
F: Fn(CellIndex, Resolution) -> Option<u64>,
|
||||
{
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cells, resolutions] = utils::take_function_args(name, args)?;
|
||||
let [cells, resolutions] = extract_args(name, &args)?;
|
||||
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
|
||||
let resolutions = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
@@ -695,8 +685,7 @@ impl Function for H3ChildPosToCell {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [pos_vec, cell_vec, res_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [pos_vec, cell_vec, res_vec] = extract_args(self.name(), &args)?;
|
||||
let resolutions = helpers::cast::<UInt8Type>(&res_vec)?;
|
||||
let resolutions = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
@@ -747,8 +736,7 @@ impl Function for H3GridDisk {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec, k_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_vec, k_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_vec.len();
|
||||
let mut builder = ListBuilder::new(UInt64Builder::new());
|
||||
@@ -797,8 +785,7 @@ impl Function for H3GridDiskDistances {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec, k_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_vec, k_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_vec.len();
|
||||
let mut builder = ListBuilder::new(UInt64Builder::new());
|
||||
@@ -842,8 +829,7 @@ impl Function for H3GridDistance {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_this_vec, cell_that_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_this_vec, cell_that_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_this_vec.len();
|
||||
let mut builder = Int32Builder::with_capacity(size);
|
||||
@@ -902,8 +888,7 @@ impl Function for H3GridPathCells {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_this_vec, cell_that_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_this_vec, cell_that_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_this_vec.len();
|
||||
let mut builder = ListBuilder::new(UInt64Builder::new());
|
||||
@@ -972,8 +957,7 @@ impl Function for H3CellContains {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cells_vec, cell_this_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cells_vec, cell_this_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_this_vec.len();
|
||||
let mut builder = BooleanBuilder::with_capacity(size);
|
||||
@@ -1027,8 +1011,7 @@ impl Function for H3CellDistanceSphereKm {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_this_vec, cell_that_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_this_vec, cell_that_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_this_vec.len();
|
||||
let mut builder = Float64Builder::with_capacity(size);
|
||||
@@ -1084,8 +1067,7 @@ impl Function for H3CellDistanceEuclideanDegree {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_this_vec, cell_that_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let [cell_this_vec, cell_that_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let size = cell_this_vec.len();
|
||||
let mut builder = Float64Builder::with_capacity(size);
|
||||
|
||||
@@ -15,63 +15,14 @@
|
||||
use datafusion::arrow::array::{ArrayRef, ArrowPrimitiveType};
|
||||
use datafusion::arrow::compute;
|
||||
|
||||
macro_rules! ensure_columns_len {
|
||||
($columns:ident) => {
|
||||
snafu::ensure!(
|
||||
$columns.windows(2).all(|c| c[0].len() == c[1].len()),
|
||||
common_query::error::InvalidFuncArgsSnafu {
|
||||
err_msg: "The length of input columns are in different size"
|
||||
}
|
||||
)
|
||||
};
|
||||
($column_a:ident, $column_b:ident, $($column_n:ident),*) => {
|
||||
snafu::ensure!(
|
||||
{
|
||||
let mut result = $column_a.len() == $column_b.len();
|
||||
$(
|
||||
result = result && ($column_a.len() == $column_n.len());
|
||||
)*
|
||||
result
|
||||
}
|
||||
common_query::error::InvalidFuncArgsSnafu {
|
||||
err_msg: "The length of input columns are in different size"
|
||||
}
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use ensure_columns_len;
|
||||
|
||||
macro_rules! ensure_columns_n {
|
||||
($columns:ident, $n:literal) => {
|
||||
snafu::ensure!(
|
||||
$columns.len() == $n,
|
||||
common_query::error::InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of arguments is not correct, expect {}, provided : {}",
|
||||
stringify!($n),
|
||||
$columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
if $n > 1 {
|
||||
ensure_columns_len!($columns);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use ensure_columns_n;
|
||||
|
||||
macro_rules! ensure_and_coerce {
|
||||
($compare:expr, $coerce:expr) => {{
|
||||
snafu::ensure!(
|
||||
$compare,
|
||||
common_query::error::InvalidFuncArgsSnafu {
|
||||
err_msg: "Argument was outside of acceptable range "
|
||||
}
|
||||
);
|
||||
Ok($coerce)
|
||||
if !$compare {
|
||||
return Err(datafusion_common::DataFusionError::Execution(
|
||||
"argument out of valid range".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(Some($coerce))
|
||||
}};
|
||||
}
|
||||
|
||||
|
||||
@@ -12,21 +12,22 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::ext::{BoxedError, PlainError};
|
||||
use common_error::status_code::StatusCode;
|
||||
use common_query::error::{self, Result};
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{Float64VectorBuilder, MutableVector, VectorRef};
|
||||
use datafusion_common::arrow::array::{Array, AsArray, Float64Builder};
|
||||
use datafusion_common::arrow::compute;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
use derive_more::Display;
|
||||
use geo::algorithm::line_measures::metric_spaces::Euclidean;
|
||||
use geo::{Area, Distance, Haversine};
|
||||
use geo_types::Geometry;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n};
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::scalars::geo::wkt::parse_wkt;
|
||||
|
||||
/// Return WGS84(SRID: 4326) euclidean distance between two geometry object, in degree
|
||||
@@ -47,33 +48,38 @@ impl Function for STDistance {
|
||||
Signature::string(2, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 2);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
|
||||
let wkt_this_vec = &columns[0];
|
||||
let wkt_that_vec = &columns[1];
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let wkt_this_vec = arg0.as_string_view();
|
||||
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
|
||||
let wkt_that_vec = arg1.as_string_view();
|
||||
|
||||
let size = wkt_this_vec.len();
|
||||
let mut results = Float64VectorBuilder::with_capacity(size);
|
||||
let mut builder = Float64Builder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let wkt_this = wkt_this_vec.get(i).as_string();
|
||||
let wkt_that = wkt_that_vec.get(i).as_string();
|
||||
let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
|
||||
let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
|
||||
|
||||
let result = match (wkt_this, wkt_that) {
|
||||
(Some(wkt_this), Some(wkt_that)) => {
|
||||
let geom_this = parse_wkt(&wkt_this)?;
|
||||
let geom_that = parse_wkt(&wkt_that)?;
|
||||
let geom_this = parse_wkt(wkt_this)?;
|
||||
let geom_that = parse_wkt(wkt_that)?;
|
||||
|
||||
Some(Euclidean::distance(&geom_this, &geom_that))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(result);
|
||||
builder.append_option(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,23 +101,28 @@ impl Function for STDistanceSphere {
|
||||
Signature::string(2, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 2);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
|
||||
let wkt_this_vec = &columns[0];
|
||||
let wkt_that_vec = &columns[1];
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let wkt_this_vec = arg0.as_string_view();
|
||||
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
|
||||
let wkt_that_vec = arg1.as_string_view();
|
||||
|
||||
let size = wkt_this_vec.len();
|
||||
let mut results = Float64VectorBuilder::with_capacity(size);
|
||||
let mut builder = Float64Builder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let wkt_this = wkt_this_vec.get(i).as_string();
|
||||
let wkt_that = wkt_that_vec.get(i).as_string();
|
||||
let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
|
||||
let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
|
||||
|
||||
let result = match (wkt_this, wkt_that) {
|
||||
(Some(wkt_this), Some(wkt_that)) => {
|
||||
let geom_this = parse_wkt(&wkt_this)?;
|
||||
let geom_that = parse_wkt(&wkt_that)?;
|
||||
let geom_this = parse_wkt(wkt_this)?;
|
||||
let geom_that = parse_wkt(wkt_that)?;
|
||||
|
||||
match (geom_this, geom_that) {
|
||||
(Geometry::Point(this), Geometry::Point(that)) => {
|
||||
@@ -128,10 +139,10 @@ impl Function for STDistanceSphere {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(result);
|
||||
builder.append_option(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,27 +164,31 @@ impl Function for STArea {
|
||||
Signature::string(1, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 1);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0] = extract_args(self.name(), &args)?;
|
||||
|
||||
let wkt_vec = &columns[0];
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let wkt_vec = arg0.as_string_view();
|
||||
|
||||
let size = wkt_vec.len();
|
||||
let mut results = Float64VectorBuilder::with_capacity(size);
|
||||
let mut builder = Float64Builder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let wkt = wkt_vec.get(i).as_string();
|
||||
let wkt = wkt_vec.is_valid(i).then(|| wkt_vec.value(i));
|
||||
|
||||
let result = if let Some(wkt) = wkt {
|
||||
let geom = parse_wkt(&wkt)?;
|
||||
let geom = parse_wkt(wkt)?;
|
||||
Some(geom.unsigned_area())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
results.push(result);
|
||||
builder.append_option(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,18 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef};
|
||||
use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
|
||||
use datafusion_common::arrow::compute;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
use derive_more::Display;
|
||||
use geo::algorithm::contains::Contains;
|
||||
use geo::algorithm::intersects::Intersects;
|
||||
use geo::algorithm::within::Within;
|
||||
use geo_types::Geometry;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n};
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::scalars::geo::wkt::parse_wkt;
|
||||
|
||||
/// Test if spatial relationship: contains
|
||||
@@ -31,46 +33,11 @@ use crate::scalars::geo::wkt::parse_wkt;
|
||||
#[display("{}", self.name())]
|
||||
pub struct STContains;
|
||||
|
||||
impl Function for STContains {
|
||||
fn name(&self) -> &str {
|
||||
"st_contains"
|
||||
}
|
||||
impl StFunction for STContains {
|
||||
const NAME: &'static str = "st_contains";
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Boolean)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::string(2, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 2);
|
||||
|
||||
let wkt_this_vec = &columns[0];
|
||||
let wkt_that_vec = &columns[1];
|
||||
|
||||
let size = wkt_this_vec.len();
|
||||
let mut results = BooleanVectorBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let wkt_this = wkt_this_vec.get(i).as_string();
|
||||
let wkt_that = wkt_that_vec.get(i).as_string();
|
||||
|
||||
let result = match (wkt_this, wkt_that) {
|
||||
(Some(wkt_this), Some(wkt_that)) => {
|
||||
let geom_this = parse_wkt(&wkt_this)?;
|
||||
let geom_that = parse_wkt(&wkt_that)?;
|
||||
|
||||
Some(geom_this.contains(&geom_that))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
fn invoke(g1: Geometry, g2: Geometry) -> bool {
|
||||
g1.contains(&g2)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,46 +46,11 @@ impl Function for STContains {
|
||||
#[display("{}", self.name())]
|
||||
pub struct STWithin;
|
||||
|
||||
impl Function for STWithin {
|
||||
fn name(&self) -> &str {
|
||||
"st_within"
|
||||
}
|
||||
impl StFunction for STWithin {
|
||||
const NAME: &'static str = "st_within";
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Boolean)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::string(2, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 2);
|
||||
|
||||
let wkt_this_vec = &columns[0];
|
||||
let wkt_that_vec = &columns[1];
|
||||
|
||||
let size = wkt_this_vec.len();
|
||||
let mut results = BooleanVectorBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let wkt_this = wkt_this_vec.get(i).as_string();
|
||||
let wkt_that = wkt_that_vec.get(i).as_string();
|
||||
|
||||
let result = match (wkt_this, wkt_that) {
|
||||
(Some(wkt_this), Some(wkt_that)) => {
|
||||
let geom_this = parse_wkt(&wkt_this)?;
|
||||
let geom_that = parse_wkt(&wkt_that)?;
|
||||
|
||||
Some(geom_this.is_within(&geom_that))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
fn invoke(g1: Geometry, g2: Geometry) -> bool {
|
||||
g1.is_within(&g2)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,9 +59,23 @@ impl Function for STWithin {
|
||||
#[display("{}", self.name())]
|
||||
pub struct STIntersects;
|
||||
|
||||
impl Function for STIntersects {
|
||||
impl StFunction for STIntersects {
|
||||
const NAME: &'static str = "st_intersects";
|
||||
|
||||
fn invoke(g1: Geometry, g2: Geometry) -> bool {
|
||||
g1.intersects(&g2)
|
||||
}
|
||||
}
|
||||
|
||||
trait StFunction {
|
||||
const NAME: &'static str;
|
||||
|
||||
fn invoke(g1: Geometry, g2: Geometry) -> bool;
|
||||
}
|
||||
|
||||
impl<T: StFunction + Display + Send + Sync> Function for T {
|
||||
fn name(&self) -> &str {
|
||||
"st_intersects"
|
||||
T::NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
@@ -140,32 +86,34 @@ impl Function for STIntersects {
|
||||
Signature::string(2, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 2);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
|
||||
let wkt_this_vec = &columns[0];
|
||||
let wkt_that_vec = &columns[1];
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let wkt_this_vec = arg0.as_string_view();
|
||||
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
|
||||
let wkt_that_vec = arg1.as_string_view();
|
||||
|
||||
let size = wkt_this_vec.len();
|
||||
let mut results = BooleanVectorBuilder::with_capacity(size);
|
||||
let mut builder = BooleanBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let wkt_this = wkt_this_vec.get(i).as_string();
|
||||
let wkt_that = wkt_that_vec.get(i).as_string();
|
||||
let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
|
||||
let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
|
||||
|
||||
let result = match (wkt_this, wkt_that) {
|
||||
(Some(wkt_this), Some(wkt_that)) => {
|
||||
let geom_this = parse_wkt(&wkt_this)?;
|
||||
let geom_that = parse_wkt(&wkt_that)?;
|
||||
|
||||
Some(geom_this.intersects(&geom_that))
|
||||
Some(T::invoke(parse_wkt(wkt_this)?, parse_wkt(wkt_that)?))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(result);
|
||||
builder.append_option(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,21 +12,21 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::value::Value;
|
||||
use datatypes::vectors::{MutableVector, StringVectorBuilder, UInt64VectorBuilder, VectorRef};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder, UInt64Builder};
|
||||
use datafusion_common::arrow::datatypes::{DataType, Float64Type};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use derive_more::Display;
|
||||
use s2::cellid::{CellID, MAX_LEVEL};
|
||||
use s2::latlng::LatLng;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::geo::helpers::{ensure_and_coerce, ensure_columns_len, ensure_columns_n};
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::scalars::geo::helpers;
|
||||
use crate::scalars::geo::helpers::ensure_and_coerce;
|
||||
|
||||
static CELL_TYPES: LazyLock<Vec<DataType>> =
|
||||
LazyLock::new(|| vec![DataType::Int64, DataType::UInt64]);
|
||||
@@ -65,18 +65,23 @@ impl Function for S2LatLngToCell {
|
||||
Signature::one_of(signatures, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 2);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
|
||||
let lat_vec = &columns[0];
|
||||
let lon_vec = &columns[1];
|
||||
let arg0 = helpers::cast::<Float64Type>(&arg0)?;
|
||||
let lat_vec = arg0.as_primitive::<Float64Type>();
|
||||
let arg1 = helpers::cast::<Float64Type>(&arg1)?;
|
||||
let lon_vec = arg1.as_primitive::<Float64Type>();
|
||||
|
||||
let size = lat_vec.len();
|
||||
let mut results = UInt64VectorBuilder::with_capacity(size);
|
||||
let mut builder = UInt64Builder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let lat = lat_vec.get(i).as_f64_lossy();
|
||||
let lon = lon_vec.get(i).as_f64_lossy();
|
||||
let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
|
||||
let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
|
||||
|
||||
let result = match (lat, lon) {
|
||||
(Some(lat), Some(lon)) => {
|
||||
@@ -94,10 +99,10 @@ impl Function for S2LatLngToCell {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(result);
|
||||
builder.append_option(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,21 +124,23 @@ impl Function for S2CellLevel {
|
||||
signature_of_cell()
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 1);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [cell_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let cell_vec = &columns[0];
|
||||
let size = cell_vec.len();
|
||||
let mut results = UInt64VectorBuilder::with_capacity(size);
|
||||
let mut builder = UInt64Builder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let cell = cell_from_value(cell_vec.get(i));
|
||||
let res = cell.map(|cell| cell.level());
|
||||
let v = ScalarValue::try_from_array(&cell_vec, i)?;
|
||||
let v = cell_from_value(v).map(|x| x.level());
|
||||
|
||||
results.push(res);
|
||||
builder.append_option(v);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,28 +155,30 @@ impl Function for S2CellToToken {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
signature_of_cell()
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 1);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [cell_vec] = extract_args(self.name(), &args)?;
|
||||
|
||||
let cell_vec = &columns[0];
|
||||
let size = cell_vec.len();
|
||||
let mut results = StringVectorBuilder::with_capacity(size);
|
||||
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let cell = cell_from_value(cell_vec.get(i));
|
||||
let res = cell.map(|cell| cell.to_token());
|
||||
let v = ScalarValue::try_from_array(&cell_vec, i)?;
|
||||
let v = cell_from_value(v).map(|x| x.to_token());
|
||||
|
||||
results.push(res.as_deref());
|
||||
builder.append_option(v.as_deref());
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,23 +200,28 @@ impl Function for S2CellParent {
|
||||
signature_of_cell_and_level()
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 2);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [cell_vec, levels] = extract_args(self.name(), &args)?;
|
||||
|
||||
let cell_vec = &columns[0];
|
||||
let level_vec = &columns[1];
|
||||
let size = cell_vec.len();
|
||||
let mut results = UInt64VectorBuilder::with_capacity(size);
|
||||
let mut builder = UInt64Builder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let cell = cell_from_value(cell_vec.get(i));
|
||||
let level = value_to_level(level_vec.get(i))?;
|
||||
let result = cell.map(|cell| cell.parent(level).0);
|
||||
let cell = ScalarValue::try_from_array(&cell_vec, i).map(cell_from_value)?;
|
||||
let level = ScalarValue::try_from_array(&levels, i).and_then(value_to_level)?;
|
||||
let result = if let (Some(cell), Some(level)) = (cell, level) {
|
||||
Some(cell.parent(level).0)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
results.push(result);
|
||||
builder.append_option(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,24 +247,30 @@ fn signature_of_cell_and_level() -> Signature {
|
||||
Signature::one_of(signatures, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn cell_from_value(v: Value) -> Option<CellID> {
|
||||
fn cell_from_value(v: ScalarValue) -> Option<CellID> {
|
||||
match v {
|
||||
Value::Int64(v) => Some(CellID(v as u64)),
|
||||
Value::UInt64(v) => Some(CellID(v)),
|
||||
ScalarValue::Int64(v) => v.map(|x| CellID(x as u64)),
|
||||
ScalarValue::UInt64(v) => v.map(CellID),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn value_to_level(v: Value) -> Result<u64> {
|
||||
fn value_to_level(v: ScalarValue) -> datafusion_common::Result<Option<u64>> {
|
||||
match v {
|
||||
Value::Int8(v) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i8, v as u64),
|
||||
Value::Int16(v) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i16, v as u64),
|
||||
Value::Int32(v) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i32, v as u64),
|
||||
Value::Int64(v) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i64, v as u64),
|
||||
Value::UInt8(v) => ensure_and_coerce!(v <= MAX_LEVEL as u8, v as u64),
|
||||
Value::UInt16(v) => ensure_and_coerce!(v <= MAX_LEVEL as u16, v as u64),
|
||||
Value::UInt32(v) => ensure_and_coerce!(v <= MAX_LEVEL as u32, v as u64),
|
||||
Value::UInt64(v) => ensure_and_coerce!(v <= MAX_LEVEL, v),
|
||||
_ => unreachable!(),
|
||||
ScalarValue::Int8(Some(v)) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i8, v as u64),
|
||||
ScalarValue::Int16(Some(v)) => {
|
||||
ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i16, v as u64)
|
||||
}
|
||||
ScalarValue::Int32(Some(v)) => {
|
||||
ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i32, v as u64)
|
||||
}
|
||||
ScalarValue::Int64(Some(v)) => {
|
||||
ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i64, v as u64)
|
||||
}
|
||||
ScalarValue::UInt8(Some(v)) => ensure_and_coerce!(v <= MAX_LEVEL as u8, v as u64),
|
||||
ScalarValue::UInt16(Some(v)) => ensure_and_coerce!(v <= MAX_LEVEL as u16, v as u64),
|
||||
ScalarValue::UInt32(Some(v)) => ensure_and_coerce!(v <= MAX_LEVEL as u32, v as u64),
|
||||
ScalarValue::UInt64(Some(v)) => ensure_and_coerce!(v <= MAX_LEVEL, v),
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,22 +12,21 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
|
||||
use common_error::ext::{BoxedError, PlainError};
|
||||
use common_error::status_code::StatusCode;
|
||||
use common_query::error::{self, Result};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
|
||||
use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder};
|
||||
use datafusion_common::arrow::datatypes::{DataType, Float64Type};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use derive_more::Display;
|
||||
use geo_types::{Geometry, Point};
|
||||
use snafu::ResultExt;
|
||||
use wkt::{ToWkt, TryFromWkt};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n};
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::scalars::geo::helpers;
|
||||
|
||||
static COORDINATE_TYPES: LazyLock<Vec<DataType>> =
|
||||
LazyLock::new(|| vec![DataType::Float32, DataType::Float64]);
|
||||
@@ -43,7 +42,7 @@ impl Function for LatLngToPointWkt {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -59,28 +58,33 @@ impl Function for LatLngToPointWkt {
|
||||
Signature::one_of(signatures, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure_columns_n!(columns, 2);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
|
||||
let lat_vec = &columns[0];
|
||||
let lng_vec = &columns[1];
|
||||
let arg0 = helpers::cast::<Float64Type>(&arg0)?;
|
||||
let lat_vec = arg0.as_primitive::<Float64Type>();
|
||||
let arg1 = helpers::cast::<Float64Type>(&arg1)?;
|
||||
let lng_vec = arg1.as_primitive::<Float64Type>();
|
||||
|
||||
let size = lat_vec.len();
|
||||
let mut results = StringVectorBuilder::with_capacity(size);
|
||||
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let lat = lat_vec.get(i).as_f64_lossy();
|
||||
let lng = lng_vec.get(i).as_f64_lossy();
|
||||
let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
|
||||
let lng = lng_vec.is_valid(i).then(|| lng_vec.value(i));
|
||||
|
||||
let result = match (lat, lng) {
|
||||
(Some(lat), Some(lng)) => Some(Point::new(lng, lat).wkt_string()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(result.as_deref());
|
||||
builder.append_option(result.as_deref());
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,18 +16,17 @@
|
||||
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use common_query::error::Result;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, AsArray, UInt64Builder};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::prelude::Vector;
|
||||
use datatypes::scalars::{ScalarVector, ScalarVectorBuilder};
|
||||
use datatypes::vectors::{BinaryVector, MutableVector, UInt64VectorBuilder, VectorRef};
|
||||
use hyperloglogplus::HyperLogLog;
|
||||
use snafu::OptionExt;
|
||||
|
||||
use crate::aggrs::approximate::hll::HllStateType;
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "hll_count";
|
||||
@@ -67,28 +66,27 @@ impl Function for HllCalcFunction {
|
||||
Signature::exact(vec![DataType::Binary], Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
if columns.len() != 1 {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!("hll_count expects 1 argument, got {}", columns.len()),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0] = extract_args(self.name(), &args)?;
|
||||
|
||||
let hll_vec = columns[0]
|
||||
.as_any()
|
||||
.downcast_ref::<BinaryVector>()
|
||||
.with_context(|| DowncastVectorSnafu {
|
||||
err_msg: format!("expect BinaryVector, got {}", columns[0].vector_type_name()),
|
||||
})?;
|
||||
let Some(hll_vec) = arg0.as_binary_opt::<i32>() else {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"'{}' expects argument to be Binary datatype, got {}",
|
||||
self.name(),
|
||||
arg0.data_type()
|
||||
)));
|
||||
};
|
||||
let len = hll_vec.len();
|
||||
let mut builder = UInt64VectorBuilder::with_capacity(len);
|
||||
let mut builder = UInt64Builder::with_capacity(len);
|
||||
|
||||
for i in 0..len {
|
||||
let hll_opt = hll_vec.get_data(i);
|
||||
let hll_opt = hll_vec.is_valid(i).then(|| hll_vec.value(i));
|
||||
|
||||
if hll_opt.is_none() {
|
||||
builder.push_null();
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -99,15 +97,15 @@ impl Function for HllCalcFunction {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
common_telemetry::trace!("Failed to deserialize HyperLogLogPlus: {}", e);
|
||||
builder.push_null();
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
builder.push(Some(hll.count().round() as u64));
|
||||
builder.append_value(hll.count().round() as u64);
|
||||
}
|
||||
|
||||
Ok(builder.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +113,9 @@ impl Function for HllCalcFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::BinaryVector;
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::arrow::array::BinaryArray;
|
||||
use datafusion_common::arrow::datatypes::UInt64Type;
|
||||
|
||||
use super::*;
|
||||
use crate::utils::FixedRandomState;
|
||||
@@ -136,17 +136,27 @@ mod tests {
|
||||
}
|
||||
|
||||
let serialized_bytes = bincode::serialize(&hll).unwrap();
|
||||
let args: Vec<VectorRef> = vec![Arc::new(BinaryVector::from(vec![Some(serialized_bytes)]))];
|
||||
let args = vec![ColumnarValue::Array(Arc::new(BinaryArray::from_iter(
|
||||
vec![Some(serialized_bytes)],
|
||||
)))];
|
||||
|
||||
let result = function.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let result = function
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args,
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
let ColumnarValue::Array(result) = result else {
|
||||
unreachable!()
|
||||
};
|
||||
let result = result.as_primitive::<UInt64Type>();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
// Test cardinality estimate
|
||||
if let datatypes::value::Value::UInt64(v) = result.get(0) {
|
||||
assert_eq!(v, 10);
|
||||
} else {
|
||||
panic!("Expected uint64 value");
|
||||
}
|
||||
assert_eq!(result.value(0), 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -154,20 +164,38 @@ mod tests {
|
||||
let function = HllCalcFunction;
|
||||
|
||||
// Test with invalid number of arguments
|
||||
let args: Vec<VectorRef> = vec![];
|
||||
let result = function.eval(&FunctionContext::default(), &args);
|
||||
let result = function.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![],
|
||||
arg_fields: vec![],
|
||||
number_rows: 0,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
});
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("hll_count expects 1 argument")
|
||||
.contains("Execution error: hll_count function requires 1 argument, got 0")
|
||||
);
|
||||
|
||||
// Test with invalid binary data
|
||||
let args: Vec<VectorRef> = vec![Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])]))]; // Invalid binary data
|
||||
let result = function.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let result = function
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(Arc::new(BinaryArray::from_iter(
|
||||
vec![Some(vec![1, 2, 3])],
|
||||
)))],
|
||||
arg_fields: vec![],
|
||||
number_rows: 0,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
let ColumnarValue::Array(result) = result else {
|
||||
unreachable!()
|
||||
};
|
||||
let result = result.as_primitive::<UInt64Type>();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(matches!(result.get(0), datatypes::value::Value::Null));
|
||||
assert!(result.is_null(0));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,13 +23,13 @@ use datafusion::common::{DFSchema, Result as DfResult};
|
||||
use datafusion::execution::SessionStateBuilder;
|
||||
use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
|
||||
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
|
||||
use datafusion_common::{DataFusionError, utils};
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::array::RecordBatch;
|
||||
use datatypes::arrow::datatypes::{DataType, Field};
|
||||
use snafu::{OptionExt, ensure};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
/// `matches` for full text search.
|
||||
@@ -65,8 +65,7 @@ impl Function for MatchesFunction {
|
||||
|
||||
// TODO: read case-sensitive config
|
||||
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [data_column, patterns] = utils::take_function_args(self.name(), args)?;
|
||||
let [data_column, patterns] = extract_args(self.name(), &args)?;
|
||||
|
||||
if data_column.is_empty() {
|
||||
return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
|
||||
|
||||
@@ -16,17 +16,16 @@
|
||||
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::prelude::Vector;
|
||||
use datatypes::scalars::{ScalarVector, ScalarVectorBuilder};
|
||||
use datatypes::vectors::{BinaryVector, Float64VectorBuilder, MutableVector, VectorRef};
|
||||
use snafu::OptionExt;
|
||||
use common_query::error::Result;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, AsArray, Float64Builder};
|
||||
use datafusion_common::arrow::datatypes::{DataType, Float64Type};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
use uddsketch::UDDSketch;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "uddsketch_calc";
|
||||
@@ -71,30 +70,35 @@ impl Function for UddSketchCalcFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
if columns.len() != 2 {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!("uddsketch_calc expects 2 arguments, got {}", columns.len()),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
|
||||
let perc_vec = &columns[0];
|
||||
let sketch_vec = columns[1]
|
||||
.as_any()
|
||||
.downcast_ref::<BinaryVector>()
|
||||
.with_context(|| DowncastVectorSnafu {
|
||||
err_msg: format!("expect BinaryVector, got {}", columns[1].vector_type_name()),
|
||||
})?;
|
||||
let Some(percentages) = arg0.as_primitive_opt::<Float64Type>() else {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"'{}' expects 1st argument to be Float64 datatype, got {}",
|
||||
self.name(),
|
||||
arg0.data_type()
|
||||
)));
|
||||
};
|
||||
let Some(sketch_vec) = arg1.as_binary_opt::<i32>() else {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"'{}' expects 2nd argument to be Binary datatype, got {}",
|
||||
self.name(),
|
||||
arg1.data_type()
|
||||
)));
|
||||
};
|
||||
let len = sketch_vec.len();
|
||||
let mut builder = Float64VectorBuilder::with_capacity(len);
|
||||
let mut builder = Float64Builder::with_capacity(len);
|
||||
|
||||
for i in 0..len {
|
||||
let perc_opt = perc_vec.get(i).as_f64_lossy();
|
||||
let sketch_opt = sketch_vec.get_data(i);
|
||||
let perc_opt = percentages.is_valid(i).then(|| percentages.value(i));
|
||||
let sketch_opt = sketch_vec.is_valid(i).then(|| sketch_vec.value(i));
|
||||
|
||||
if sketch_opt.is_none() || perc_opt.is_none() {
|
||||
builder.push_null();
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -106,7 +110,7 @@ impl Function for UddSketchCalcFunction {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
common_telemetry::trace!("Failed to deserialize UDDSketch: {}", e);
|
||||
builder.push_null();
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -115,15 +119,15 @@ impl Function for UddSketchCalcFunction {
|
||||
// This is important to avoid panics when calling estimate_quantile on an empty sketch
|
||||
// In practice, this will happen if input is all null
|
||||
if sketch.bucket_iter().count() == 0 {
|
||||
builder.push_null();
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
// Compute the estimated quantile from the sketch
|
||||
let result = sketch.estimate_quantile(perc);
|
||||
builder.push(Some(result));
|
||||
builder.append_value(result);
|
||||
}
|
||||
|
||||
Ok(builder.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +135,8 @@ impl Function for UddSketchCalcFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::{BinaryVector, Float64Vector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::arrow::array::{BinaryArray, Float64Array};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -165,26 +170,32 @@ mod tests {
|
||||
let serialized = bincode::serialize(&sketch).unwrap();
|
||||
let percentiles = vec![0.5, 0.9, 0.95];
|
||||
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(Float64Vector::from_vec(percentiles.clone())),
|
||||
Arc::new(BinaryVector::from(vec![Some(serialized.clone()); 3])),
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(percentiles.clone()))),
|
||||
ColumnarValue::Array(Arc::new(BinaryArray::from_iter_values(vec![serialized; 3]))),
|
||||
];
|
||||
|
||||
let result = function.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let result = function
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args,
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
let ColumnarValue::Array(result) = result else {
|
||||
unreachable!()
|
||||
};
|
||||
let result = result.as_primitive::<Float64Type>();
|
||||
assert_eq!(result.len(), 3);
|
||||
|
||||
// Test median (p50)
|
||||
assert!(
|
||||
matches!(result.get(0), datatypes::value::Value::Float64(v) if (v - expected_p50).abs() < 1e-10)
|
||||
);
|
||||
assert!((result.value(0) - expected_p50).abs() < 1e-10);
|
||||
// Test p90
|
||||
assert!(
|
||||
matches!(result.get(1), datatypes::value::Value::Float64(v) if (v - expected_p90).abs() < 1e-10)
|
||||
);
|
||||
assert!((result.value(1) - expected_p90).abs() < 1e-10);
|
||||
// Test p95
|
||||
assert!(
|
||||
matches!(result.get(2), datatypes::value::Value::Float64(v) if (v - expected_p95).abs() < 1e-10)
|
||||
);
|
||||
assert!((result.value(2) - expected_p95).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -192,23 +203,42 @@ mod tests {
|
||||
let function = UddSketchCalcFunction;
|
||||
|
||||
// Test with invalid number of arguments
|
||||
let args: Vec<VectorRef> = vec![Arc::new(Float64Vector::from_vec(vec![0.95]))];
|
||||
let result = function.eval(&FunctionContext::default(), &args);
|
||||
let result = function.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(Arc::new(Float64Array::from(vec![
|
||||
0.95,
|
||||
])))],
|
||||
arg_fields: vec![],
|
||||
number_rows: 0,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
});
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("uddsketch_calc expects 2 arguments")
|
||||
.contains("Execution error: uddsketch_calc function requires 2 arguments, got 1")
|
||||
);
|
||||
|
||||
// Test with invalid binary data
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(Float64Vector::from_vec(vec![0.95])),
|
||||
Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])])), // Invalid binary data
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(vec![0.95]))),
|
||||
ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![Some(vec![1, 2, 3])]))),
|
||||
];
|
||||
let result = function.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let result = function
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args,
|
||||
arg_fields: vec![],
|
||||
number_rows: 0,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
let ColumnarValue::Array(result) = result else {
|
||||
unreachable!()
|
||||
};
|
||||
let result = result.as_primitive::<Float64Type>();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(matches!(result.get(0), datatypes::value::Value::Null));
|
||||
assert!(result.is_null(0));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,12 +19,12 @@ use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
|
||||
use datafusion::arrow::datatypes::Int64Type;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{ScalarValue, utils};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
|
||||
|
||||
const NAME: &str = "vec_subvector";
|
||||
@@ -71,8 +71,7 @@ impl Function for VectorSubvectorFunction {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?;
|
||||
let [arg0, arg1, arg2] = extract_args(self.name(), &args)?;
|
||||
let arg1 = arg1.as_primitive::<Int64Type>();
|
||||
let arg2 = arg2.as_primitive::<Int64Type>();
|
||||
|
||||
|
||||
@@ -13,16 +13,14 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::{self};
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::prelude::ScalarVector;
|
||||
use datatypes::vectors::{StringVector, UInt32Vector, VectorRef};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
use derive_more::Display;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, find_function_context};
|
||||
|
||||
/// A function to return current schema name.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
@@ -55,17 +53,21 @@ impl Function for DatabaseFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let func_ctx = find_function_context(&args)?;
|
||||
let db = func_ctx.query_ctx.current_schema();
|
||||
|
||||
Ok(Arc::new(StringVector::from_slice(&[&db])) as _)
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(db))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,17 +79,21 @@ impl Function for CurrentSchemaFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let func_ctx = find_function_context(&args)?;
|
||||
let db = func_ctx.query_ctx.current_schema();
|
||||
|
||||
Ok(Arc::new(StringVector::from_slice(&[&db])) as _)
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(db))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,17 +103,23 @@ impl Function for SessionUserFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let func_ctx = find_function_context(&args)?;
|
||||
let user = func_ctx.query_ctx.current_user();
|
||||
|
||||
Ok(Arc::new(StringVector::from_slice(&[user.username()])) as _)
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
|
||||
user.username().to_string(),
|
||||
))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,17 +129,23 @@ impl Function for ReadPreferenceFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let func_ctx = find_function_context(&args)?;
|
||||
let read_preference = func_ctx.query_ctx.read_preference();
|
||||
|
||||
Ok(Arc::new(StringVector::from_slice(&[read_preference.as_ref()])) as _)
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
|
||||
read_preference.to_string(),
|
||||
))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,10 +162,14 @@ impl Function for PgBackendPidFunction {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let func_ctx = find_function_context(&args)?;
|
||||
let pid = func_ctx.query_ctx.process_id();
|
||||
|
||||
Ok(Arc::new(UInt32Vector::from_slice([pid])) as _)
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(pid as u64))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,17 +179,21 @@ impl Function for ConnectionIdFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::UInt64)
|
||||
Ok(DataType::UInt32)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let func_ctx = find_function_context(&args)?;
|
||||
let pid = func_ctx.query_ctx.process_id();
|
||||
|
||||
Ok(Arc::new(UInt32Vector::from_slice([pid])) as _)
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(pid))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,14 +225,17 @@ impl fmt::Display for ReadPreferenceFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use session::context::QueryContextBuilder;
|
||||
|
||||
use super::*;
|
||||
use crate::function::FunctionContext;
|
||||
#[test]
|
||||
fn test_build_function() {
|
||||
let build = DatabaseFunction;
|
||||
assert_eq!("database", build.name());
|
||||
assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap());
|
||||
assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap());
|
||||
assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable));
|
||||
|
||||
let query_ctx = QueryContextBuilder::default()
|
||||
@@ -214,12 +243,24 @@ mod tests {
|
||||
.build()
|
||||
.into();
|
||||
|
||||
let func_ctx = FunctionContext {
|
||||
let mut config_options = ConfigOptions::default();
|
||||
config_options.extensions.insert(FunctionContext {
|
||||
query_ctx,
|
||||
..Default::default()
|
||||
});
|
||||
let config_options = Arc::new(config_options);
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![],
|
||||
arg_fields: vec![],
|
||||
number_rows: 0,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
|
||||
config_options,
|
||||
};
|
||||
let vector = build.eval(&func_ctx, &[]).unwrap();
|
||||
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_db"]));
|
||||
assert_eq!(expect, vector);
|
||||
let result = build.invoke_with_args(args).unwrap();
|
||||
let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) = result else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(s, "test_db");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ use async_trait::async_trait;
|
||||
use common_base::Plugins;
|
||||
use common_catalog::consts::is_readonly_schema;
|
||||
use common_error::ext::BoxedError;
|
||||
use common_function::function::FunctionContext;
|
||||
use common_function::function_factory::ScalarFunctionFactory;
|
||||
use common_query::{Output, OutputData, OutputMeta};
|
||||
use common_recordbatch::adapter::RecordBatchStreamAdapter;
|
||||
@@ -568,6 +569,16 @@ impl QueryEngine for DatafusionQueryEngine {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
state
|
||||
.config_mut()
|
||||
.options_mut()
|
||||
.extensions
|
||||
.insert(FunctionContext {
|
||||
query_ctx: query_ctx.clone(),
|
||||
state: self.engine_state().function_state(),
|
||||
});
|
||||
|
||||
QueryEngineContext::new(state, query_ctx)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,11 +15,13 @@
|
||||
use std::collections::{BTreeMap, BTreeSet, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use common_telemetry::debug;
|
||||
use datafusion::config::{ConfigExtension, ExtensionOptions};
|
||||
use datafusion::datasource::DefaultTableSource;
|
||||
use datafusion::error::Result as DfResult;
|
||||
use datafusion_common::Column;
|
||||
use datafusion_common::alias::AliasGenerator;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
|
||||
use datafusion_expr::expr::{Exists, InSubquery};
|
||||
@@ -27,7 +29,7 @@ use datafusion_expr::utils::expr_to_columns;
|
||||
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Subquery, col as col_fn};
|
||||
use datafusion_optimizer::analyzer::AnalyzerRule;
|
||||
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
|
||||
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
|
||||
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
|
||||
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
|
||||
use table::metadata::TableType;
|
||||
use table::table::adapter::DfTableProviderAdapter;
|
||||
@@ -100,8 +102,42 @@ impl AnalyzerRule for DistPlannerAnalyzer {
|
||||
plan: LogicalPlan,
|
||||
config: &ConfigOptions,
|
||||
) -> datafusion_common::Result<LogicalPlan> {
|
||||
// preprocess the input plan
|
||||
let optimizer_context = OptimizerContext::new();
|
||||
let mut config = config.clone();
|
||||
// Aligned with the behavior in `datafusion_optimizer::OptimizerContext::new()`.
|
||||
config.optimizer.filter_null_join_keys = true;
|
||||
let config = Arc::new(config);
|
||||
|
||||
// The `ConstEvaluator` in `SimplifyExpressions` might evaluate some UDFs early in the
|
||||
// planning stage, by executing them directly. For example, the `database()` function.
|
||||
// So the `ConfigOptions` here (which is set from the session context) should be present
|
||||
// in the UDF's `ScalarFunctionArgs`. However, the default implementation in DataFusion
|
||||
// seems to lose track of it: the `ConfigOptions` is recreated with its default values again.
|
||||
// So we create a custom `OptimizerConfig` with the desired `ConfigOptions`
|
||||
// to work around the issue.
|
||||
struct OptimizerContext {
|
||||
inner: datafusion_optimizer::OptimizerContext,
|
||||
config: Arc<ConfigOptions>,
|
||||
}
|
||||
|
||||
impl OptimizerConfig for OptimizerContext {
|
||||
fn query_execution_start_time(&self) -> DateTime<Utc> {
|
||||
self.inner.query_execution_start_time()
|
||||
}
|
||||
|
||||
fn alias_generator(&self) -> &Arc<AliasGenerator> {
|
||||
self.inner.alias_generator()
|
||||
}
|
||||
|
||||
fn options(&self) -> Arc<ConfigOptions> {
|
||||
self.config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
let optimizer_context = OptimizerContext {
|
||||
inner: datafusion_optimizer::OptimizerContext::new(),
|
||||
config: config.clone(),
|
||||
};
|
||||
|
||||
let plan = SimplifyExpressions::new()
|
||||
.rewrite(plan, &optimizer_context)?
|
||||
.data;
|
||||
|
||||
Reference in New Issue
Block a user