From 0e2fd8e2bdadb7a4629e21a9601214b7adecf220 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 3 Mar 2025 21:10:12 -0800 Subject: [PATCH] feat: rewrite `json_encode_path` to `geo_path` using compound type (#5640) * function impl Signed-off-by: Ruihang Xia * tune type Signed-off-by: Ruihang Xia * fix clippy and suggestions Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- src/common/function/src/aggr.rs | 2 + src/common/function/src/aggr/geo_path.rs | 433 ++++++++++++++++++ src/query/src/datafusion/planner.rs | 11 +- .../src/query_engine/default_serializer.rs | 3 +- .../standalone/common/function/geo.result | 24 +- .../cases/standalone/common/function/geo.sql | 4 +- 6 files changed, 457 insertions(+), 20 deletions(-) create mode 100644 src/common/function/src/aggr/geo_path.rs diff --git a/src/common/function/src/aggr.rs b/src/common/function/src/aggr.rs index be271d4d20..24bcb86618 100644 --- a/src/common/function/src/aggr.rs +++ b/src/common/function/src/aggr.rs @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod geo_path; mod hll; mod uddsketch_state; +pub use geo_path::{GeoPathAccumulator, GEO_PATH_NAME}; pub(crate) use hll::HllStateType; pub use hll::{HllState, HLL_MERGE_NAME, HLL_NAME}; pub use uddsketch_state::{UddSketchState, UDDSKETCH_STATE_NAME}; diff --git a/src/common/function/src/aggr/geo_path.rs b/src/common/function/src/aggr/geo_path.rs new file mode 100644 index 0000000000..d5a2f71b57 --- /dev/null +++ b/src/common/function/src/aggr/geo_path.rs @@ -0,0 +1,433 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use datafusion::arrow::array::{Array, ArrayRef}; +use datafusion::common::cast::as_primitive_array; +use datafusion::error::{DataFusionError, Result as DfResult}; +use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility}; +use datafusion::prelude::create_udaf; +use datafusion_common::cast::{as_list_array, as_struct_array}; +use datafusion_common::utils::SingleRowListArrayBuilder; +use datafusion_common::ScalarValue; +use datatypes::arrow::array::{Float64Array, Int64Array, ListArray, StructArray}; +use datatypes::arrow::datatypes::{ + DataType, Field, Float64Type, Int64Type, TimeUnit, TimestampNanosecondType, +}; +use datatypes::compute::{self, sort_to_indices}; + +pub const GEO_PATH_NAME: &str = "geo_path"; + +const LATITUDE_FIELD: &str = "lat"; +const LONGITUDE_FIELD: &str = "lng"; +const TIMESTAMP_FIELD: &str = "timestamp"; +const DEFAULT_LIST_FIELD_NAME: &str = "item"; + +#[derive(Debug, Default)] +pub struct GeoPathAccumulator { + lat: Vec>, + lng: Vec>, + timestamp: Vec>, +} + +impl GeoPathAccumulator { + pub fn new() -> Self { + Self::default() + } + + pub fn udf_impl() -> AggregateUDF { + create_udaf( + GEO_PATH_NAME, + // Input types: lat, lng, timestamp + vec![ + DataType::Float64, + DataType::Float64, + DataType::Timestamp(TimeUnit::Nanosecond, None), + ], + // Output type: list of points {[lat], [lng]} + Arc::new(DataType::Struct( + vec![ + Field::new( + LATITUDE_FIELD, + DataType::List(Arc::new(Field::new( + DEFAULT_LIST_FIELD_NAME, + DataType::Float64, + true, + ))), + false, + ), + Field::new( + LONGITUDE_FIELD, + DataType::List(Arc::new(Field::new( + DEFAULT_LIST_FIELD_NAME, + DataType::Float64, + true, + ))), + false, + ), + ] + .into(), + )), + Volatility::Immutable, + // Create the accumulator + Arc::new(|_| Ok(Box::new(GeoPathAccumulator::new()))), + // Intermediate state types + Arc::new(vec![DataType::Struct( + vec![ + Field::new( + LATITUDE_FIELD, + DataType::List(Arc::new(Field::new( + DEFAULT_LIST_FIELD_NAME, + DataType::Float64, + true, + ))), + false, + ), + Field::new( + LONGITUDE_FIELD, + DataType::List(Arc::new(Field::new( + DEFAULT_LIST_FIELD_NAME, + DataType::Float64, + true, + ))), + false, + ), + Field::new( + TIMESTAMP_FIELD, + DataType::List(Arc::new(Field::new( + DEFAULT_LIST_FIELD_NAME, + DataType::Int64, + true, + ))), + false, + ), + ] + .into(), + )]), + ) + } +} + +impl DfAccumulator for GeoPathAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> { + if values.len() != 3 { + return Err(DataFusionError::Internal(format!( + "Expected 3 columns for geo_path, got {}", + values.len() + ))); + } + + let lat_array = as_primitive_array::(&values[0])?; + let lng_array = as_primitive_array::(&values[1])?; + let ts_array = as_primitive_array::(&values[2])?; + + let size = lat_array.len(); + self.lat.reserve(size); + self.lng.reserve(size); + + for idx in 0..size { + self.lat.push(if lat_array.is_null(idx) { + None + } else { + Some(lat_array.value(idx)) + }); + + self.lng.push(if lng_array.is_null(idx) { + None + } else { + Some(lng_array.value(idx)) + }); + + self.timestamp.push(if ts_array.is_null(idx) { + None + } else { + Some(ts_array.value(idx)) + }); + } + + Ok(()) + } + + fn evaluate(&mut self) -> DfResult { + let unordered_lng_array = Float64Array::from(self.lng.clone()); + let unordered_lat_array = Float64Array::from(self.lat.clone()); + let ts_array = Int64Array::from(self.timestamp.clone()); + + let ordered_indices = sort_to_indices(&ts_array, None, None)?; + let lat_array = compute::take(&unordered_lat_array, &ordered_indices, None)?; + let lng_array = compute::take(&unordered_lng_array, &ordered_indices, None)?; + + let lat_list = Arc::new(SingleRowListArrayBuilder::new(lat_array).build_list_array()); + let lng_list = Arc::new(SingleRowListArrayBuilder::new(lng_array).build_list_array()); + + let result = ScalarValue::Struct(Arc::new(StructArray::new( + vec![ + Field::new( + LATITUDE_FIELD, + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + false, + ), + Field::new( + LONGITUDE_FIELD, + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + false, + ), + ] + .into(), + vec![lat_list, lng_list], + None, + ))); + + Ok(result) + } + + fn size(&self) -> usize { + // Base size of GeoPathAccumulator struct fields + let mut total_size = std::mem::size_of::(); + + // Size of vectors (approximation) + total_size += self.lat.capacity() * std::mem::size_of::>(); + total_size += self.lng.capacity() * std::mem::size_of::>(); + total_size += self.timestamp.capacity() * std::mem::size_of::>(); + + total_size + } + + fn state(&mut self) -> datafusion::error::Result> { + let lat_array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(self.lat.clone()), + ])); + let lng_array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(self.lng.clone()), + ])); + let ts_array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(self.timestamp.clone()), + ])); + + let state_struct = StructArray::new( + vec![ + Field::new( + LATITUDE_FIELD, + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + false, + ), + Field::new( + LONGITUDE_FIELD, + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + false, + ), + Field::new( + TIMESTAMP_FIELD, + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + ] + .into(), + vec![lat_array, lng_array, ts_array], + None, + ); + + Ok(vec![ScalarValue::Struct(Arc::new(state_struct))]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> { + if states.len() != 1 { + return Err(DataFusionError::Internal(format!( + "Expected 1 states for geo_path, got {}", + states.len() + ))); + } + + for state in states { + let state = as_struct_array(state)?; + let lat_list = as_list_array(state.column(0))?.value(0); + let lat_array = as_primitive_array::(&lat_list)?; + let lng_list = as_list_array(state.column(1))?.value(0); + let lng_array = as_primitive_array::(&lng_list)?; + let ts_list = as_list_array(state.column(2))?.value(0); + let ts_array = as_primitive_array::(&ts_list)?; + + self.lat.extend(lat_array); + self.lng.extend(lng_array); + self.timestamp.extend(ts_array); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use datafusion::arrow::array::{Float64Array, TimestampNanosecondArray}; + use datafusion::scalar::ScalarValue; + + use super::*; + + #[test] + fn test_geo_path_basic() { + let mut accumulator = GeoPathAccumulator::new(); + + // Create test data + let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])); + let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0])); + let ts_array = Arc::new(TimestampNanosecondArray::from(vec![100, 200, 300])); + + // Update batch + accumulator + .update_batch(&[lat_array, lng_array, ts_array]) + .unwrap(); + + // Evaluate + let result = accumulator.evaluate().unwrap(); + if let ScalarValue::Struct(struct_array) = result { + // Verify structure + let fields = struct_array.fields().clone(); + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name(), LATITUDE_FIELD); + assert_eq!(fields[1].name(), LONGITUDE_FIELD); + + // Verify data + let columns = struct_array.columns(); + assert_eq!(columns.len(), 2); + + // Check latitude values + let lat_list = as_list_array(&columns[0]).unwrap().value(0); + let lat_array = as_primitive_array::(&lat_list).unwrap(); + assert_eq!(lat_array.len(), 3); + assert_eq!(lat_array.value(0), 1.0); + assert_eq!(lat_array.value(1), 2.0); + assert_eq!(lat_array.value(2), 3.0); + + // Check longitude values + let lng_list = as_list_array(&columns[1]).unwrap().value(0); + let lng_array = as_primitive_array::(&lng_list).unwrap(); + assert_eq!(lng_array.len(), 3); + assert_eq!(lng_array.value(0), 4.0); + assert_eq!(lng_array.value(1), 5.0); + assert_eq!(lng_array.value(2), 6.0); + } else { + panic!("Expected Struct scalar value"); + } + } + + #[test] + fn test_geo_path_sort_by_timestamp() { + let mut accumulator = GeoPathAccumulator::new(); + + // Create test data with unordered timestamps + let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])); + let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0])); + let ts_array = Arc::new(TimestampNanosecondArray::from(vec![300, 100, 200])); + + // Update batch + accumulator + .update_batch(&[lat_array, lng_array, ts_array]) + .unwrap(); + + // Evaluate + let result = accumulator.evaluate().unwrap(); + if let ScalarValue::Struct(struct_array) = result { + // Extract arrays + let columns = struct_array.columns(); + + // Check latitude values + let lat_list = as_list_array(&columns[0]).unwrap().value(0); + let lat_array = as_primitive_array::(&lat_list).unwrap(); + assert_eq!(lat_array.len(), 3); + assert_eq!(lat_array.value(0), 2.0); // timestamp 100 + assert_eq!(lat_array.value(1), 3.0); // timestamp 200 + assert_eq!(lat_array.value(2), 1.0); // timestamp 300 + + // Check longitude values (should be sorted by timestamp) + let lng_list = as_list_array(&columns[1]).unwrap().value(0); + let lng_array = as_primitive_array::(&lng_list).unwrap(); + assert_eq!(lng_array.len(), 3); + assert_eq!(lng_array.value(0), 5.0); // timestamp 100 + assert_eq!(lng_array.value(1), 6.0); // timestamp 200 + assert_eq!(lng_array.value(2), 4.0); // timestamp 300 + } else { + panic!("Expected Struct scalar value"); + } + } + + #[test] + fn test_geo_path_merge() { + let mut accumulator1 = GeoPathAccumulator::new(); + let mut accumulator2 = GeoPathAccumulator::new(); + + // Create test data for first accumulator + let lat_array1 = Arc::new(Float64Array::from(vec![1.0])); + let lng_array1 = Arc::new(Float64Array::from(vec![4.0])); + let ts_array1 = Arc::new(TimestampNanosecondArray::from(vec![100])); + + // Create test data for second accumulator + let lat_array2 = Arc::new(Float64Array::from(vec![2.0])); + let lng_array2 = Arc::new(Float64Array::from(vec![5.0])); + let ts_array2 = Arc::new(TimestampNanosecondArray::from(vec![200])); + + // Update batches + accumulator1 + .update_batch(&[lat_array1, lng_array1, ts_array1]) + .unwrap(); + accumulator2 + .update_batch(&[lat_array2, lng_array2, ts_array2]) + .unwrap(); + + // Get states + let state1 = accumulator1.state().unwrap(); + let state2 = accumulator2.state().unwrap(); + + // Create a merged accumulator + let mut merged = GeoPathAccumulator::new(); + + // Extract the struct arrays from the states + let state_array1 = match &state1[0] { + ScalarValue::Struct(array) => array.clone(), + _ => panic!("Expected Struct scalar value"), + }; + + let state_array2 = match &state2[0] { + ScalarValue::Struct(array) => array.clone(), + _ => panic!("Expected Struct scalar value"), + }; + + // Merge state arrays + merged.merge_batch(&[state_array1]).unwrap(); + merged.merge_batch(&[state_array2]).unwrap(); + + // Evaluate merged result + let result = merged.evaluate().unwrap(); + if let ScalarValue::Struct(struct_array) = result { + // Extract arrays + let columns = struct_array.columns(); + + // Check latitude values + let lat_list = as_list_array(&columns[0]).unwrap().value(0); + let lat_array = as_primitive_array::(&lat_list).unwrap(); + assert_eq!(lat_array.len(), 2); + assert_eq!(lat_array.value(0), 1.0); // timestamp 100 + assert_eq!(lat_array.value(1), 2.0); // timestamp 200 + + // Check longitude values (should be sorted by timestamp) + let lng_list = as_list_array(&columns[1]).unwrap().value(0); + let lng_array = as_primitive_array::(&lng_list).unwrap(); + assert_eq!(lng_array.len(), 2); + assert_eq!(lng_array.value(0), 4.0); // timestamp 100 + assert_eq!(lng_array.value(1), 5.0); // timestamp 200 + } else { + panic!("Expected Struct scalar value"); + } + } +} diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 13e95ee560..912393690d 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -19,7 +19,8 @@ use std::sync::Arc; use arrow_schema::DataType; use catalog::table_source::DfTableSourceProvider; use common_function::aggr::{ - HllState, UddSketchState, HLL_MERGE_NAME, HLL_NAME, UDDSKETCH_STATE_NAME, + GeoPathAccumulator, HllState, UddSketchState, GEO_PATH_NAME, HLL_MERGE_NAME, HLL_NAME, + UDDSKETCH_STATE_NAME, }; use common_function::scalars::udf::create_udf; use common_query::logical_plan::create_aggregate_function; @@ -167,12 +168,12 @@ impl ContextProvider for DfContextProviderAdapter { fn get_aggregate_meta(&self, name: &str) -> Option> { if name == UDDSKETCH_STATE_NAME { return Some(Arc::new(UddSketchState::udf_impl())); - } - if name == HLL_NAME { + } else if name == HLL_NAME { return Some(Arc::new(HllState::state_udf_impl())); - } - if name == HLL_MERGE_NAME { + } else if name == HLL_MERGE_NAME { return Some(Arc::new(HllState::merge_udf_impl())); + } else if name == GEO_PATH_NAME { + return Some(Arc::new(GeoPathAccumulator::udf_impl())); } self.engine_state.aggregate_function(name).map_or_else( diff --git a/src/query/src/query_engine/default_serializer.rs b/src/query/src/query_engine/default_serializer.rs index 63ae3ab4fa..23d6789866 100644 --- a/src/query/src/query_engine/default_serializer.rs +++ b/src/query/src/query_engine/default_serializer.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use common_error::ext::BoxedError; -use common_function::aggr::{HllState, UddSketchState}; +use common_function::aggr::{GeoPathAccumulator, HllState, UddSketchState}; use common_function::function_registry::FUNCTION_REGISTRY; use common_function::scalars::udf::create_udf; use common_query::error::RegisterUdfSnafu; @@ -131,6 +131,7 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder { let _ = session_state.register_udaf(Arc::new(UddSketchState::udf_impl())); let _ = session_state.register_udaf(Arc::new(HllState::state_udf_impl())); let _ = session_state.register_udaf(Arc::new(HllState::merge_udf_impl())); + let _ = session_state.register_udaf(Arc::new(GeoPathAccumulator::udf_impl())); } let logical_plan = DFLogicalSubstraitConvertor .decode(message, session_state) diff --git a/tests/cases/standalone/common/function/geo.result b/tests/cases/standalone/common/function/geo.result index b9ae2ba580..7b051a35ed 100644 --- a/tests/cases/standalone/common/function/geo.result +++ b/tests/cases/standalone/common/function/geo.result @@ -333,15 +333,15 @@ FROM cell_cte; | 9263763445276221387 | 808f7fc59ef01fcb | 30 | 9277415232383221760 | +---------------------+---------------------------------+------------------------------+----------------------------------------+ -SELECT json_encode_path(37.76938, -122.3889, 1728083375::TimestampSecond); +SELECT UNNEST(geo_path(37.76938, -122.3889, 1728083375::TimestampSecond)); -+----------------------------------------------------------------------------------------------------------------------+ -| json_encode_path(Float64(37.76938),Float64(-122.3889),arrow_cast(Int64(1728083375),Utf8("Timestamp(Second, None)"))) | -+----------------------------------------------------------------------------------------------------------------------+ -| [[-122.3889,37.76938]] | -+----------------------------------------------------------------------------------------------------------------------+ ++--------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+ +| unnest_placeholder(geo_path(Float64(37.76938),Float64(-122.3889),arrow_cast(Int64(1728083375),Utf8("Timestamp(Second, None)")))).lat | unnest_placeholder(geo_path(Float64(37.76938),Float64(-122.3889),arrow_cast(Int64(1728083375),Utf8("Timestamp(Second, None)")))).lng | ++--------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+ +| [37.76938] | [-122.3889] | ++--------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+ -SELECT json_encode_path(lat, lon, ts) +SELECT UNNEST(geo_path(lat, lon, ts)) FROM( SELECT 37.76938 AS lat, -122.3889 AS lon, 1728083375::TimestampSecond AS ts UNION ALL @@ -352,11 +352,11 @@ FROM( SELECT 37.77001 AS lat, -122.3888 AS lon, 1728083372::TimestampSecond AS ts ); -+-------------------------------------------------------------------------------------+ -| json_encode_path(lat,lon,ts) | -+-------------------------------------------------------------------------------------+ -| [[-122.3888,37.77001],[-122.3839,37.76928],[-122.3889,37.76938],[-122.382,37.7693]] | -+-------------------------------------------------------------------------------------+ ++----------------------------------------------+----------------------------------------------+ +| unnest_placeholder(geo_path(lat,lon,ts)).lat | unnest_placeholder(geo_path(lat,lon,ts)).lng | ++----------------------------------------------+----------------------------------------------+ +| [37.77001, 37.76928, 37.76938, 37.7693] | [-122.3888, -122.3839, -122.3889, -122.382] | ++----------------------------------------------+----------------------------------------------+ SELECT wkt_point_from_latlng(37.76938, -122.3889) AS point; diff --git a/tests/cases/standalone/common/function/geo.sql b/tests/cases/standalone/common/function/geo.sql index fe424eb228..89bd1e6a44 100644 --- a/tests/cases/standalone/common/function/geo.sql +++ b/tests/cases/standalone/common/function/geo.sql @@ -119,9 +119,9 @@ SELECT cell, s2_cell_parent(cell, 3) FROM cell_cte; -SELECT json_encode_path(37.76938, -122.3889, 1728083375::TimestampSecond); +SELECT UNNEST(geo_path(37.76938, -122.3889, 1728083375::TimestampSecond)); -SELECT json_encode_path(lat, lon, ts) +SELECT UNNEST(geo_path(lat, lon, ts)) FROM( SELECT 37.76938 AS lat, -122.3889 AS lon, 1728083375::TimestampSecond AS ts UNION ALL