From 420446f19f49594bf047b7c0e2fd12fa92c09e8e Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sat, 31 Aug 2024 10:14:27 -0700 Subject: [PATCH] feat: add built-in functions h3 and geohash --- Cargo.lock | 52 +++++++- src/common/function/Cargo.toml | 6 + src/common/function/src/function_registry.rs | 4 + src/common/function/src/scalars.rs | 2 + src/common/function/src/scalars/geo.rs | 31 +++++ .../function/src/scalars/geo/geohash.rs | 114 ++++++++++++++++ src/common/function/src/scalars/geo/h3.rs | 122 ++++++++++++++++++ src/common/recordbatch/src/error.rs | 3 +- src/datatypes/src/value.rs | 17 +++ 9 files changed, 348 insertions(+), 3 deletions(-) create mode 100644 src/common/function/src/scalars/geo.rs create mode 100644 src/common/function/src/scalars/geo/geohash.rs create mode 100644 src/common/function/src/scalars/geo/h3.rs diff --git a/Cargo.lock b/Cargo.lock index 18d16b2682..7ca8f4303d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1950,6 +1950,8 @@ dependencies = [ "common-version", "datafusion", "datatypes", + "geohash", + "h3o", "num", "num-traits", "once_cell", @@ -3813,6 +3815,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "float_eq" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" + [[package]] name = "flow" version = "0.9.2" @@ -4211,6 +4219,27 @@ dependencies = [ "version_check", ] +[[package]] +name = "geo-types" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ff16065e5720f376fbced200a5ae0f47ace85fd70b7e54269790281353b6d61" +dependencies = [ + "approx", + "num-traits", + "serde", +] + +[[package]] +name = "geohash" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fb94b1a65401d6cbf22958a9040aa364812c26674f841bee538b12c135db1e6" +dependencies = [ + "geo-types", + "libm", +] + [[package]] name = "gethostname" version = "0.2.3" @@ 
-4301,6 +4330,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h3o" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de3592e1f699692aa0525c42ff7879ec3ee7e36329af20967bc910a1cdc39c7" +dependencies = [ + "ahash 0.8.11", + "either", + "float_eq", + "h3o-bit", + "libm", +] + +[[package]] +name = "h3o-bit" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fb45e8060378c0353781abf67e1917b545a6b710d0342d85b70c125af7ef320" + [[package]] name = "half" version = "1.8.3" @@ -4717,7 +4765,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2 0.5.7", "tokio", "tower-service", "tracing", @@ -8512,7 +8560,7 @@ dependencies = [ "indoc", "libc", "memoffset 0.9.1", - "parking_lot 0.11.2", + "parking_lot 0.12.3", "portable-atomic", "pyo3-build-config", "pyo3-ffi", diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index e7d6ee870f..eec49bbb00 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -7,6 +7,10 @@ license.workspace = true [lints] workspace = true +[features] +default = ["geo"] +geo = ["geohash", "h3o"] + [dependencies] api.workspace = true arc-swap = "1.0" @@ -35,6 +39,8 @@ sql.workspace = true statrs = "0.16" store-api.workspace = true table.workspace = true +geohash = { version = "0.13", optional = true } +h3o = { version = "0.6", optional = true } [dev-dependencies] ron = "0.7" diff --git a/src/common/function/src/function_registry.rs b/src/common/function/src/function_registry.rs index c2a315d51d..ed863c16aa 100644 --- a/src/common/function/src/function_registry.rs +++ b/src/common/function/src/function_registry.rs @@ -116,6 +116,10 @@ pub static FUNCTION_REGISTRY: Lazy> = Lazy::new(|| { SystemFunction::register(&function_registry); TableFunction::register(&function_registry); + // Geo functions + #[cfg(feature = "geo")] + 
crate::scalars::geo::GeoFunctions::register(&function_registry); + Arc::new(function_registry) }); diff --git a/src/common/function/src/scalars.rs b/src/common/function/src/scalars.rs index 2b3f463e94..f8dc570d12 100644 --- a/src/common/function/src/scalars.rs +++ b/src/common/function/src/scalars.rs @@ -15,6 +15,8 @@ pub mod aggregate; pub(crate) mod date; pub mod expression; +#[cfg(feature = "geo")] +pub mod geo; pub mod matches; pub mod math; pub mod numpy; diff --git a/src/common/function/src/scalars/geo.rs b/src/common/function/src/scalars/geo.rs new file mode 100644 index 0000000000..4b126f20f0 --- /dev/null +++ b/src/common/function/src/scalars/geo.rs @@ -0,0 +1,31 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; +mod geohash; +mod h3; + +use geohash::GeohashFunction; +use h3::H3Function; + +use crate::function_registry::FunctionRegistry; + +pub(crate) struct GeoFunctions; + +impl GeoFunctions { + pub fn register(registry: &FunctionRegistry) { + registry.register(Arc::new(GeohashFunction)); + registry.register(Arc::new(H3Function)); + } +} diff --git a/src/common/function/src/scalars/geo/geohash.rs b/src/common/function/src/scalars/geo/geohash.rs new file mode 100644 index 0000000000..b61ed536c2 --- /dev/null +++ b/src/common/function/src/scalars/geo/geohash.rs @@ -0,0 +1,114 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt; + +use common_error::ext::{BoxedError, PlainError}; +use common_error::status_code::StatusCode; +use common_query::error::{self, InvalidFuncArgsSnafu, Result}; +use common_query::prelude::{Signature, TypeSignature}; +use datafusion::logical_expr::Volatility; +use datatypes::prelude::ConcreteDataType; +use datatypes::scalars::ScalarVectorBuilder; +use datatypes::value::Value; +use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef}; +use geohash::Coord; +use snafu::{ensure, ResultExt}; + +use crate::function::{Function, FunctionContext}; + +/// Function that returns the geohash string for a given geospatial coordinate. 
+#[derive(Clone, Debug, Default)] +pub struct GeohashFunction; + +const NAME: &str = "geohash"; + +impl Function for GeohashFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::string_datatype()) + } + + fn signature(&self) -> Signature { + Signature::one_of( + vec![ + TypeSignature::Exact(vec![ + ConcreteDataType::float32_datatype(), + ConcreteDataType::float32_datatype(), + ConcreteDataType::int64_datatype(), + ]), + TypeSignature::Exact(vec![ + ConcreteDataType::float64_datatype(), + ConcreteDataType::float64_datatype(), + ConcreteDataType::int64_datatype(), + ]), + ], + Volatility::Stable, + ) + } + + fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + ensure!( + columns.len() == 3, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect 3, provided : {}", + columns.len() + ), + } + ); + + let lat_vec = &columns[0]; + let lon_vec = &columns[1]; + let resolution_vec = &columns[2]; + + let size = lat_vec.len(); + let mut results = StringVectorBuilder::with_capacity(size); + + for i in 0..size { + let lat = lat_vec.get(i).as_f64(); + let lon = lon_vec.get(i).as_f64(); + let resolution = resolution_vec.get(i); + + let result = match (lat, lon, resolution) { + (Some(lat), Some(lon), Value::Int64(r)) => { + let coord = Coord { x: lon, y: lat }; + let encoded = geohash::encode(coord, r as usize) + .map_err(|e| { + BoxedError::new(PlainError::new( + format!("Geohash error: {}", e.to_string()), + StatusCode::EngineExecuteQuery, + )) + }) + .context(error::ExecuteSnafu)?; + Some(encoded) + } + _ => None, + }; + + results.push(result.as_deref()); + } + + Ok(results.to_vector()) + } +} + +impl fmt::Display for GeohashFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME) + } +} diff --git a/src/common/function/src/scalars/geo/h3.rs b/src/common/function/src/scalars/geo/h3.rs new 
file mode 100644 index 0000000000..7248633e10 --- /dev/null +++ b/src/common/function/src/scalars/geo/h3.rs @@ -0,0 +1,122 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt; + +use common_error::ext::{BoxedError, PlainError}; +use common_error::status_code::StatusCode; +use common_query::error::{self, InvalidFuncArgsSnafu, Result}; +use common_query::prelude::{Signature, TypeSignature}; +use datafusion::logical_expr::Volatility; +use datatypes::prelude::ConcreteDataType; +use datatypes::scalars::ScalarVectorBuilder; +use datatypes::value::Value; +use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef}; +use h3o::{LatLng, Resolution}; +use snafu::{ensure, ResultExt}; + +use crate::function::{Function, FunctionContext}; + +/// Function that returns the H3 cell index string for a given geospatial coordinate. 
+#[derive(Clone, Debug, Default)] +pub struct H3Function; + +const NAME: &str = "h3"; + +impl Function for H3Function { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::string_datatype()) + } + + fn signature(&self) -> Signature { + Signature::one_of( + vec![ + TypeSignature::Exact(vec![ + ConcreteDataType::float32_datatype(), + ConcreteDataType::float32_datatype(), + ConcreteDataType::int64_datatype(), + ]), + TypeSignature::Exact(vec![ + ConcreteDataType::float64_datatype(), + ConcreteDataType::float64_datatype(), + ConcreteDataType::int64_datatype(), + ]), + ], + Volatility::Stable, + ) + } + + fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + ensure!( + columns.len() == 3, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect 3, provided : {}", + columns.len() + ), + } + ); + + let lat_vec = &columns[0]; + let lon_vec = &columns[1]; + let resolution_vec = &columns[2]; + + let size = lat_vec.len(); + let mut results = StringVectorBuilder::with_capacity(size); + + for i in 0..size { + let lat = lat_vec.get(i).as_f64(); + let lon = lon_vec.get(i).as_f64(); + let resolution = resolution_vec.get(i); + + let result = match (lat, lon, resolution) { + (Some(lat), Some(lon), Value::Int64(r)) => { + let coord = LatLng::new(lat, lon) + .map_err(|e| { + BoxedError::new(PlainError::new( + format!("H3 error: {}", e.to_string()), + StatusCode::EngineExecuteQuery, + )) + }) + .context(error::ExecuteSnafu)?; + let r = Resolution::try_from(r as u8) + .map_err(|e| { + BoxedError::new(PlainError::new( + format!("H3 error: {}", e.to_string()), + StatusCode::EngineExecuteQuery, + )) + }) + .context(error::ExecuteSnafu)?; + let encoded = coord.to_cell(r).to_string(); + Some(encoded) + } + _ => None, + }; + + results.push(result.as_deref()); + } + + Ok(results.to_vector()) + } +} + +impl fmt::Display for H3Function { + fn fmt(&self, 
f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME) + } +} diff --git a/src/common/recordbatch/src/error.rs b/src/common/recordbatch/src/error.rs index f2114f645f..3eb90b05e7 100644 --- a/src/common/recordbatch/src/error.rs +++ b/src/common/recordbatch/src/error.rs @@ -172,12 +172,13 @@ impl ErrorExt for Error { Error::DataTypes { .. } | Error::CreateRecordBatches { .. } - | Error::PollStream { .. } | Error::Format { .. } | Error::ToArrowScalar { .. } | Error::ProjectArrowRecordBatch { .. } | Error::PhysicalExpr { .. } => StatusCode::Internal, + Error::PollStream { .. } => StatusCode::EngineExecuteQuery, + Error::ArrowCompute { .. } => StatusCode::IllegalState, Error::ColumnNotExists { .. } => StatusCode::TableColumnNotFound, diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index fdb6b38bb6..136de0363e 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -268,6 +268,23 @@ impl Value { } } + /// Cast Value to f64. Return None if it's not castable. + pub fn as_f64(&self) -> Option { + match self { + Value::Float32(v) => Some(v.0 as _), + Value::Float64(v) => Some(v.0), + Value::Int8(v) => Some(*v as _), + Value::Int16(v) => Some(*v as _), + Value::Int32(v) => Some(*v as _), + Value::Int64(v) => Some(*v as _), + Value::UInt8(v) => Some(*v as _), + Value::UInt16(v) => Some(*v as _), + Value::UInt32(v) => Some(*v as _), + Value::UInt64(v) => Some(*v as _), + _ => None, + } + } + /// Returns the logical type of the value. pub fn logical_type_id(&self) -> LogicalTypeId { match self {