diff --git a/src/common/function/src/aggr.rs b/src/common/function/src/aggr.rs
index be271d4d20..24bcb86618 100644
--- a/src/common/function/src/aggr.rs
+++ b/src/common/function/src/aggr.rs
@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+mod geo_path;
mod hll;
mod uddsketch_state;
+pub use geo_path::{GeoPathAccumulator, GEO_PATH_NAME};
pub(crate) use hll::HllStateType;
pub use hll::{HllState, HLL_MERGE_NAME, HLL_NAME};
pub use uddsketch_state::{UddSketchState, UDDSKETCH_STATE_NAME};
diff --git a/src/common/function/src/aggr/geo_path.rs b/src/common/function/src/aggr/geo_path.rs
new file mode 100644
index 0000000000..d5a2f71b57
--- /dev/null
+++ b/src/common/function/src/aggr/geo_path.rs
@@ -0,0 +1,433 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use datafusion::arrow::array::{Array, ArrayRef};
+use datafusion::common::cast::as_primitive_array;
+use datafusion::error::{DataFusionError, Result as DfResult};
+use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility};
+use datafusion::prelude::create_udaf;
+use datafusion_common::cast::{as_list_array, as_struct_array};
+use datafusion_common::utils::SingleRowListArrayBuilder;
+use datafusion_common::ScalarValue;
+use datatypes::arrow::array::{Float64Array, Int64Array, ListArray, StructArray};
+use datatypes::arrow::datatypes::{
+ DataType, Field, Float64Type, Int64Type, TimeUnit, TimestampNanosecondType,
+};
+use datatypes::compute::{self, sort_to_indices};
+
+pub const GEO_PATH_NAME: &str = "geo_path";
+
+const LATITUDE_FIELD: &str = "lat";
+const LONGITUDE_FIELD: &str = "lng";
+const TIMESTAMP_FIELD: &str = "timestamp";
+const DEFAULT_LIST_FIELD_NAME: &str = "item";
+
+#[derive(Debug, Default)]
+pub struct GeoPathAccumulator {
+ lat: Vec>,
+ lng: Vec >,
+ timestamp: Vec >,
+}
+
+impl GeoPathAccumulator {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ pub fn udf_impl() -> AggregateUDF {
+ create_udaf(
+ GEO_PATH_NAME,
+ // Input types: lat, lng, timestamp
+ vec![
+ DataType::Float64,
+ DataType::Float64,
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ ],
+ // Output type: list of points {[lat], [lng]}
+ Arc::new(DataType::Struct(
+ vec![
+ Field::new(
+ LATITUDE_FIELD,
+ DataType::List(Arc::new(Field::new(
+ DEFAULT_LIST_FIELD_NAME,
+ DataType::Float64,
+ true,
+ ))),
+ false,
+ ),
+ Field::new(
+ LONGITUDE_FIELD,
+ DataType::List(Arc::new(Field::new(
+ DEFAULT_LIST_FIELD_NAME,
+ DataType::Float64,
+ true,
+ ))),
+ false,
+ ),
+ ]
+ .into(),
+ )),
+ Volatility::Immutable,
+ // Create the accumulator
+ Arc::new(|_| Ok(Box::new(GeoPathAccumulator::new()))),
+ // Intermediate state types
+ Arc::new(vec![DataType::Struct(
+ vec![
+ Field::new(
+ LATITUDE_FIELD,
+ DataType::List(Arc::new(Field::new(
+ DEFAULT_LIST_FIELD_NAME,
+ DataType::Float64,
+ true,
+ ))),
+ false,
+ ),
+ Field::new(
+ LONGITUDE_FIELD,
+ DataType::List(Arc::new(Field::new(
+ DEFAULT_LIST_FIELD_NAME,
+ DataType::Float64,
+ true,
+ ))),
+ false,
+ ),
+ Field::new(
+ TIMESTAMP_FIELD,
+ DataType::List(Arc::new(Field::new(
+ DEFAULT_LIST_FIELD_NAME,
+ DataType::Int64,
+ true,
+ ))),
+ false,
+ ),
+ ]
+ .into(),
+ )]),
+ )
+ }
+}
+
+impl DfAccumulator for GeoPathAccumulator {
+ fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
+ if values.len() != 3 {
+ return Err(DataFusionError::Internal(format!(
+ "Expected 3 columns for geo_path, got {}",
+ values.len()
+ )));
+ }
+
+ let lat_array = as_primitive_array::(&values[0])?;
+ let lng_array = as_primitive_array::(&values[1])?;
+ let ts_array = as_primitive_array::(&values[2])?;
+
+ let size = lat_array.len();
+ self.lat.reserve(size);
+ self.lng.reserve(size);
+
+ for idx in 0..size {
+ self.lat.push(if lat_array.is_null(idx) {
+ None
+ } else {
+ Some(lat_array.value(idx))
+ });
+
+ self.lng.push(if lng_array.is_null(idx) {
+ None
+ } else {
+ Some(lng_array.value(idx))
+ });
+
+ self.timestamp.push(if ts_array.is_null(idx) {
+ None
+ } else {
+ Some(ts_array.value(idx))
+ });
+ }
+
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> DfResult {
+ let unordered_lng_array = Float64Array::from(self.lng.clone());
+ let unordered_lat_array = Float64Array::from(self.lat.clone());
+ let ts_array = Int64Array::from(self.timestamp.clone());
+
+ let ordered_indices = sort_to_indices(&ts_array, None, None)?;
+ let lat_array = compute::take(&unordered_lat_array, &ordered_indices, None)?;
+ let lng_array = compute::take(&unordered_lng_array, &ordered_indices, None)?;
+
+ let lat_list = Arc::new(SingleRowListArrayBuilder::new(lat_array).build_list_array());
+ let lng_list = Arc::new(SingleRowListArrayBuilder::new(lng_array).build_list_array());
+
+ let result = ScalarValue::Struct(Arc::new(StructArray::new(
+ vec![
+ Field::new(
+ LATITUDE_FIELD,
+ DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
+ false,
+ ),
+ Field::new(
+ LONGITUDE_FIELD,
+ DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
+ false,
+ ),
+ ]
+ .into(),
+ vec![lat_list, lng_list],
+ None,
+ )));
+
+ Ok(result)
+ }
+
+ fn size(&self) -> usize {
+ // Base size of GeoPathAccumulator struct fields
+ let mut total_size = std::mem::size_of::();
+
+ // Size of vectors (approximation)
+ total_size += self.lat.capacity() * std::mem::size_of::>();
+ total_size += self.lng.capacity() * std::mem::size_of:: >();
+ total_size += self.timestamp.capacity() * std::mem::size_of:: >();
+
+ total_size
+ }
+
+ fn state(&mut self) -> datafusion::error::Result> {
+ let lat_array = Arc::new(ListArray::from_iter_primitive::(vec![
+ Some(self.lat.clone()),
+ ]));
+ let lng_array = Arc::new(ListArray::from_iter_primitive::(vec![
+ Some(self.lng.clone()),
+ ]));
+ let ts_array = Arc::new(ListArray::from_iter_primitive::(vec![
+ Some(self.timestamp.clone()),
+ ]));
+
+ let state_struct = StructArray::new(
+ vec![
+ Field::new(
+ LATITUDE_FIELD,
+ DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
+ false,
+ ),
+ Field::new(
+ LONGITUDE_FIELD,
+ DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
+ false,
+ ),
+ Field::new(
+ TIMESTAMP_FIELD,
+ DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
+ false,
+ ),
+ ]
+ .into(),
+ vec![lat_array, lng_array, ts_array],
+ None,
+ );
+
+ Ok(vec![ScalarValue::Struct(Arc::new(state_struct))])
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
+ if states.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "Expected 1 states for geo_path, got {}",
+ states.len()
+ )));
+ }
+
+ for state in states {
+ let state = as_struct_array(state)?;
+ let lat_list = as_list_array(state.column(0))?.value(0);
+ let lat_array = as_primitive_array::(&lat_list)?;
+ let lng_list = as_list_array(state.column(1))?.value(0);
+ let lng_array = as_primitive_array::(&lng_list)?;
+ let ts_list = as_list_array(state.column(2))?.value(0);
+ let ts_array = as_primitive_array::(&ts_list)?;
+
+ self.lat.extend(lat_array);
+ self.lng.extend(lng_array);
+ self.timestamp.extend(ts_array);
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use datafusion::arrow::array::{Float64Array, TimestampNanosecondArray};
+ use datafusion::scalar::ScalarValue;
+
+ use super::*;
+
+ #[test]
+ fn test_geo_path_basic() {
+ let mut accumulator = GeoPathAccumulator::new();
+
+ // Create test data
+ let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
+ let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0]));
+ let ts_array = Arc::new(TimestampNanosecondArray::from(vec![100, 200, 300]));
+
+ // Update batch
+ accumulator
+ .update_batch(&[lat_array, lng_array, ts_array])
+ .unwrap();
+
+ // Evaluate
+ let result = accumulator.evaluate().unwrap();
+ if let ScalarValue::Struct(struct_array) = result {
+ // Verify structure
+ let fields = struct_array.fields().clone();
+ assert_eq!(fields.len(), 2);
+ assert_eq!(fields[0].name(), LATITUDE_FIELD);
+ assert_eq!(fields[1].name(), LONGITUDE_FIELD);
+
+ // Verify data
+ let columns = struct_array.columns();
+ assert_eq!(columns.len(), 2);
+
+ // Check latitude values
+ let lat_list = as_list_array(&columns[0]).unwrap().value(0);
+ let lat_array = as_primitive_array::(&lat_list).unwrap();
+ assert_eq!(lat_array.len(), 3);
+ assert_eq!(lat_array.value(0), 1.0);
+ assert_eq!(lat_array.value(1), 2.0);
+ assert_eq!(lat_array.value(2), 3.0);
+
+ // Check longitude values
+ let lng_list = as_list_array(&columns[1]).unwrap().value(0);
+ let lng_array = as_primitive_array::(&lng_list).unwrap();
+ assert_eq!(lng_array.len(), 3);
+ assert_eq!(lng_array.value(0), 4.0);
+ assert_eq!(lng_array.value(1), 5.0);
+ assert_eq!(lng_array.value(2), 6.0);
+ } else {
+ panic!("Expected Struct scalar value");
+ }
+ }
+
+ #[test]
+ fn test_geo_path_sort_by_timestamp() {
+ let mut accumulator = GeoPathAccumulator::new();
+
+ // Create test data with unordered timestamps
+ let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
+ let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0]));
+ let ts_array = Arc::new(TimestampNanosecondArray::from(vec![300, 100, 200]));
+
+ // Update batch
+ accumulator
+ .update_batch(&[lat_array, lng_array, ts_array])
+ .unwrap();
+
+ // Evaluate
+ let result = accumulator.evaluate().unwrap();
+ if let ScalarValue::Struct(struct_array) = result {
+ // Extract arrays
+ let columns = struct_array.columns();
+
+ // Check latitude values
+ let lat_list = as_list_array(&columns[0]).unwrap().value(0);
+ let lat_array = as_primitive_array::(&lat_list).unwrap();
+ assert_eq!(lat_array.len(), 3);
+ assert_eq!(lat_array.value(0), 2.0); // timestamp 100
+ assert_eq!(lat_array.value(1), 3.0); // timestamp 200
+ assert_eq!(lat_array.value(2), 1.0); // timestamp 300
+
+ // Check longitude values (should be sorted by timestamp)
+ let lng_list = as_list_array(&columns[1]).unwrap().value(0);
+ let lng_array = as_primitive_array::(&lng_list).unwrap();
+ assert_eq!(lng_array.len(), 3);
+ assert_eq!(lng_array.value(0), 5.0); // timestamp 100
+ assert_eq!(lng_array.value(1), 6.0); // timestamp 200
+ assert_eq!(lng_array.value(2), 4.0); // timestamp 300
+ } else {
+ panic!("Expected Struct scalar value");
+ }
+ }
+
+ #[test]
+ fn test_geo_path_merge() {
+ let mut accumulator1 = GeoPathAccumulator::new();
+ let mut accumulator2 = GeoPathAccumulator::new();
+
+ // Create test data for first accumulator
+ let lat_array1 = Arc::new(Float64Array::from(vec![1.0]));
+ let lng_array1 = Arc::new(Float64Array::from(vec![4.0]));
+ let ts_array1 = Arc::new(TimestampNanosecondArray::from(vec![100]));
+
+ // Create test data for second accumulator
+ let lat_array2 = Arc::new(Float64Array::from(vec![2.0]));
+ let lng_array2 = Arc::new(Float64Array::from(vec![5.0]));
+ let ts_array2 = Arc::new(TimestampNanosecondArray::from(vec![200]));
+
+ // Update batches
+ accumulator1
+ .update_batch(&[lat_array1, lng_array1, ts_array1])
+ .unwrap();
+ accumulator2
+ .update_batch(&[lat_array2, lng_array2, ts_array2])
+ .unwrap();
+
+ // Get states
+ let state1 = accumulator1.state().unwrap();
+ let state2 = accumulator2.state().unwrap();
+
+ // Create a merged accumulator
+ let mut merged = GeoPathAccumulator::new();
+
+ // Extract the struct arrays from the states
+ let state_array1 = match &state1[0] {
+ ScalarValue::Struct(array) => array.clone(),
+ _ => panic!("Expected Struct scalar value"),
+ };
+
+ let state_array2 = match &state2[0] {
+ ScalarValue::Struct(array) => array.clone(),
+ _ => panic!("Expected Struct scalar value"),
+ };
+
+ // Merge state arrays
+ merged.merge_batch(&[state_array1]).unwrap();
+ merged.merge_batch(&[state_array2]).unwrap();
+
+ // Evaluate merged result
+ let result = merged.evaluate().unwrap();
+ if let ScalarValue::Struct(struct_array) = result {
+ // Extract arrays
+ let columns = struct_array.columns();
+
+ // Check latitude values
+ let lat_list = as_list_array(&columns[0]).unwrap().value(0);
+ let lat_array = as_primitive_array::(&lat_list).unwrap();
+ assert_eq!(lat_array.len(), 2);
+ assert_eq!(lat_array.value(0), 1.0); // timestamp 100
+ assert_eq!(lat_array.value(1), 2.0); // timestamp 200
+
+ // Check longitude values (should be sorted by timestamp)
+ let lng_list = as_list_array(&columns[1]).unwrap().value(0);
+ let lng_array = as_primitive_array::(&lng_list).unwrap();
+ assert_eq!(lng_array.len(), 2);
+ assert_eq!(lng_array.value(0), 4.0); // timestamp 100
+ assert_eq!(lng_array.value(1), 5.0); // timestamp 200
+ } else {
+ panic!("Expected Struct scalar value");
+ }
+ }
+}
diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs
index 13e95ee560..912393690d 100644
--- a/src/query/src/datafusion/planner.rs
+++ b/src/query/src/datafusion/planner.rs
@@ -19,7 +19,8 @@ use std::sync::Arc;
use arrow_schema::DataType;
use catalog::table_source::DfTableSourceProvider;
use common_function::aggr::{
- HllState, UddSketchState, HLL_MERGE_NAME, HLL_NAME, UDDSKETCH_STATE_NAME,
+ GeoPathAccumulator, HllState, UddSketchState, GEO_PATH_NAME, HLL_MERGE_NAME, HLL_NAME,
+ UDDSKETCH_STATE_NAME,
};
use common_function::scalars::udf::create_udf;
use common_query::logical_plan::create_aggregate_function;
@@ -167,12 +168,12 @@ impl ContextProvider for DfContextProviderAdapter {
fn get_aggregate_meta(&self, name: &str) -> Option> {
if name == UDDSKETCH_STATE_NAME {
return Some(Arc::new(UddSketchState::udf_impl()));
- }
- if name == HLL_NAME {
+ } else if name == HLL_NAME {
return Some(Arc::new(HllState::state_udf_impl()));
- }
- if name == HLL_MERGE_NAME {
+ } else if name == HLL_MERGE_NAME {
return Some(Arc::new(HllState::merge_udf_impl()));
+ } else if name == GEO_PATH_NAME {
+ return Some(Arc::new(GeoPathAccumulator::udf_impl()));
}
self.engine_state.aggregate_function(name).map_or_else(
diff --git a/src/query/src/query_engine/default_serializer.rs b/src/query/src/query_engine/default_serializer.rs
index 63ae3ab4fa..23d6789866 100644
--- a/src/query/src/query_engine/default_serializer.rs
+++ b/src/query/src/query_engine/default_serializer.rs
@@ -15,7 +15,7 @@
use std::sync::Arc;
use common_error::ext::BoxedError;
-use common_function::aggr::{HllState, UddSketchState};
+use common_function::aggr::{GeoPathAccumulator, HllState, UddSketchState};
use common_function::function_registry::FUNCTION_REGISTRY;
use common_function::scalars::udf::create_udf;
use common_query::error::RegisterUdfSnafu;
@@ -131,6 +131,7 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder {
let _ = session_state.register_udaf(Arc::new(UddSketchState::udf_impl()));
let _ = session_state.register_udaf(Arc::new(HllState::state_udf_impl()));
let _ = session_state.register_udaf(Arc::new(HllState::merge_udf_impl()));
+ let _ = session_state.register_udaf(Arc::new(GeoPathAccumulator::udf_impl()));
}
let logical_plan = DFLogicalSubstraitConvertor
.decode(message, session_state)
diff --git a/tests/cases/standalone/common/function/geo.result b/tests/cases/standalone/common/function/geo.result
index b9ae2ba580..7b051a35ed 100644
--- a/tests/cases/standalone/common/function/geo.result
+++ b/tests/cases/standalone/common/function/geo.result
@@ -333,15 +333,15 @@ FROM cell_cte;
| 9263763445276221387 | 808f7fc59ef01fcb | 30 | 9277415232383221760 |
+---------------------+---------------------------------+------------------------------+----------------------------------------+
-SELECT json_encode_path(37.76938, -122.3889, 1728083375::TimestampSecond);
+SELECT UNNEST(geo_path(37.76938, -122.3889, 1728083375::TimestampSecond));
-+----------------------------------------------------------------------------------------------------------------------+
-| json_encode_path(Float64(37.76938),Float64(-122.3889),arrow_cast(Int64(1728083375),Utf8("Timestamp(Second, None)"))) |
-+----------------------------------------------------------------------------------------------------------------------+
-| [[-122.3889,37.76938]] |
-+----------------------------------------------------------------------------------------------------------------------+
++--------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+
+| unnest_placeholder(geo_path(Float64(37.76938),Float64(-122.3889),arrow_cast(Int64(1728083375),Utf8("Timestamp(Second, None)")))).lat | unnest_placeholder(geo_path(Float64(37.76938),Float64(-122.3889),arrow_cast(Int64(1728083375),Utf8("Timestamp(Second, None)")))).lng |
++--------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+
+| [37.76938] | [-122.3889] |
++--------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+
-SELECT json_encode_path(lat, lon, ts)
+SELECT UNNEST(geo_path(lat, lon, ts))
FROM(
SELECT 37.76938 AS lat, -122.3889 AS lon, 1728083375::TimestampSecond AS ts
UNION ALL
@@ -352,11 +352,11 @@ FROM(
SELECT 37.77001 AS lat, -122.3888 AS lon, 1728083372::TimestampSecond AS ts
);
-+-------------------------------------------------------------------------------------+
-| json_encode_path(lat,lon,ts) |
-+-------------------------------------------------------------------------------------+
-| [[-122.3888,37.77001],[-122.3839,37.76928],[-122.3889,37.76938],[-122.382,37.7693]] |
-+-------------------------------------------------------------------------------------+
++----------------------------------------------+----------------------------------------------+
+| unnest_placeholder(geo_path(lat,lon,ts)).lat | unnest_placeholder(geo_path(lat,lon,ts)).lng |
++----------------------------------------------+----------------------------------------------+
+| [37.77001, 37.76928, 37.76938, 37.7693] | [-122.3888, -122.3839, -122.3889, -122.382] |
++----------------------------------------------+----------------------------------------------+
SELECT wkt_point_from_latlng(37.76938, -122.3889) AS point;
diff --git a/tests/cases/standalone/common/function/geo.sql b/tests/cases/standalone/common/function/geo.sql
index fe424eb228..89bd1e6a44 100644
--- a/tests/cases/standalone/common/function/geo.sql
+++ b/tests/cases/standalone/common/function/geo.sql
@@ -119,9 +119,9 @@ SELECT cell,
s2_cell_parent(cell, 3)
FROM cell_cte;
-SELECT json_encode_path(37.76938, -122.3889, 1728083375::TimestampSecond);
+SELECT UNNEST(geo_path(37.76938, -122.3889, 1728083375::TimestampSecond));
-SELECT json_encode_path(lat, lon, ts)
+SELECT UNNEST(geo_path(lat, lon, ts))
FROM(
SELECT 37.76938 AS lat, -122.3889 AS lon, 1728083375::TimestampSecond AS ts
UNION ALL