feat: add built-in functions h3 and geohash

This commit is contained in:
Ning Sun
2024-08-31 10:14:27 -07:00
parent 841e66c810
commit 420446f19f
9 changed files with 348 additions and 3 deletions

52
Cargo.lock generated
View File

@@ -1950,6 +1950,8 @@ dependencies = [
"common-version",
"datafusion",
"datatypes",
"geohash",
"h3o",
"num",
"num-traits",
"once_cell",
@@ -3813,6 +3815,12 @@ dependencies = [
"num-traits",
]
[[package]]
name = "float_eq"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
[[package]]
name = "flow"
version = "0.9.2"
@@ -4211,6 +4219,27 @@ dependencies = [
"version_check",
]
[[package]]
name = "geo-types"
version = "0.7.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff16065e5720f376fbced200a5ae0f47ace85fd70b7e54269790281353b6d61"
dependencies = [
"approx",
"num-traits",
"serde",
]
[[package]]
name = "geohash"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fb94b1a65401d6cbf22958a9040aa364812c26674f841bee538b12c135db1e6"
dependencies = [
"geo-types",
"libm",
]
[[package]]
name = "gethostname"
version = "0.2.3"
@@ -4301,6 +4330,25 @@ dependencies = [
"tracing",
]
[[package]]
name = "h3o"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0de3592e1f699692aa0525c42ff7879ec3ee7e36329af20967bc910a1cdc39c7"
dependencies = [
"ahash 0.8.11",
"either",
"float_eq",
"h3o-bit",
"libm",
]
[[package]]
name = "h3o-bit"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb45e8060378c0353781abf67e1917b545a6b710d0342d85b70c125af7ef320"
[[package]]
name = "half"
version = "1.8.3"
@@ -4717,7 +4765,7 @@ dependencies = [
"httpdate",
"itoa",
"pin-project-lite",
"socket2 0.4.10",
"socket2 0.5.7",
"tokio",
"tower-service",
"tracing",
@@ -8512,7 +8560,7 @@ dependencies = [
"indoc",
"libc",
"memoffset 0.9.1",
"parking_lot 0.11.2",
"parking_lot 0.12.3",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",

View File

@@ -7,6 +7,10 @@ license.workspace = true
[lints]
workspace = true
[features]
default = ["geo"]
geo = ["geohash", "h3o"]
[dependencies]
api.workspace = true
arc-swap = "1.0"
@@ -35,6 +39,8 @@ sql.workspace = true
statrs = "0.16"
store-api.workspace = true
table.workspace = true
geohash = { version = "0.13", optional = true }
h3o = { version = "0.6", optional = true }
[dev-dependencies]
ron = "0.7"

View File

@@ -116,6 +116,10 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
SystemFunction::register(&function_registry);
TableFunction::register(&function_registry);
// Geo functions
#[cfg(feature = "geo")]
crate::scalars::geo::GeoFunctions::register(&function_registry);
Arc::new(function_registry)
});

View File

@@ -15,6 +15,8 @@
pub mod aggregate;
pub(crate) mod date;
pub mod expression;
#[cfg(feature = "geo")]
pub mod geo;
pub mod matches;
pub mod math;
pub mod numpy;

View File

@@ -0,0 +1,31 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
mod geohash;
mod h3;
use geohash::GeohashFunction;
use h3::H3Function;
use crate::function_registry::FunctionRegistry;
/// Registration entry point for the geospatial scalar functions of this module.
pub(crate) struct GeoFunctions;

impl GeoFunctions {
    /// Adds [`GeohashFunction`] and then [`H3Function`] to `registry`.
    pub fn register(registry: &FunctionRegistry) {
        let geohash_fn = Arc::new(GeohashFunction);
        registry.register(geohash_fn);

        let h3_fn = Arc::new(H3Function);
        registry.register(h3_fn);
    }
}

View File

@@ -0,0 +1,114 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt;
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::error::{self, InvalidFuncArgsSnafu, Result};
use common_query::prelude::{Signature, TypeSignature};
use datafusion::logical_expr::Volatility;
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::value::Value;
use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
use geohash::Coord;
use snafu::{ensure, ResultExt};
use crate::function::{Function, FunctionContext};
/// Function that returns the geohash string for a given geospatial coordinate.
///
/// Arguments are latitude, longitude and an `Int64` resolution that is passed
/// to `geohash::encode` as the length of the encoded string.
#[derive(Clone, Debug, Default)]
pub struct GeohashFunction;
/// Name returned by `Function::name` and written by the `Display` impl.
const NAME: &str = "geohash";
impl Function for GeohashFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::string_datatype())
}
fn signature(&self) -> Signature {
Signature::one_of(
vec![
TypeSignature::Exact(vec![
ConcreteDataType::float32_datatype(),
ConcreteDataType::float32_datatype(),
ConcreteDataType::int64_datatype(),
]),
TypeSignature::Exact(vec![
ConcreteDataType::float64_datatype(),
ConcreteDataType::float64_datatype(),
ConcreteDataType::int64_datatype(),
]),
],
Volatility::Stable,
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 3, provided : {}",
columns.len()
),
}
);
let lat_vec = &columns[0];
let lon_vec = &columns[1];
let resolution_vec = &columns[2];
let size = lat_vec.len();
let mut results = StringVectorBuilder::with_capacity(size);
for i in 0..size {
let lat = lat_vec.get(i).as_f64();
let lon = lon_vec.get(i).as_f64();
let resolution = resolution_vec.get(i);
let result = match (lat, lon, resolution) {
(Some(lat), Some(lon), Value::Int64(r)) => {
let coord = Coord { x: lon, y: lat };
let encoded = geohash::encode(coord, r as usize)
.map_err(|e| {
BoxedError::new(PlainError::new(
format!("Geohash error: {}", e.to_string()),
StatusCode::EngineExecuteQuery,
))
})
.context(error::ExecuteSnafu)?;
Some(encoded)
}
_ => None,
};
results.push(result.as_deref());
}
Ok(results.to_vector())
}
}
/// Displays the function as its registered name.
impl fmt::Display for GeohashFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Writes the name verbatim; formatter width/padding flags are not applied.
        f.write_str(NAME)
    }
}

View File

@@ -0,0 +1,122 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt;
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::error::{self, InvalidFuncArgsSnafu, Result};
use common_query::prelude::{Signature, TypeSignature};
use datafusion::logical_expr::Volatility;
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::value::Value;
use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
use h3o::{LatLng, Resolution};
use snafu::{ensure, ResultExt};
use crate::function::{Function, FunctionContext};
/// Function that returns the H3 cell index, rendered as a string, for a given
/// geospatial coordinate.
///
/// Arguments are latitude, longitude and an `Int64` resolution that is
/// converted into an `h3o::Resolution`.
#[derive(Clone, Debug, Default)]
pub struct H3Function;
/// Name returned by `Function::name` and written by the `Display` impl.
const NAME: &str = "h3";
impl Function for H3Function {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::string_datatype())
}
fn signature(&self) -> Signature {
Signature::one_of(
vec![
TypeSignature::Exact(vec![
ConcreteDataType::float32_datatype(),
ConcreteDataType::float32_datatype(),
ConcreteDataType::int64_datatype(),
]),
TypeSignature::Exact(vec![
ConcreteDataType::float64_datatype(),
ConcreteDataType::float64_datatype(),
ConcreteDataType::int64_datatype(),
]),
],
Volatility::Stable,
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 3, provided : {}",
columns.len()
),
}
);
let lat_vec = &columns[0];
let lon_vec = &columns[1];
let resolution_vec = &columns[2];
let size = lat_vec.len();
let mut results = StringVectorBuilder::with_capacity(size);
for i in 0..size {
let lat = lat_vec.get(i).as_f64();
let lon = lon_vec.get(i).as_f64();
let resolution = resolution_vec.get(i);
let result = match (lat, lon, resolution) {
(Some(lat), Some(lon), Value::Int64(r)) => {
let coord = LatLng::new(lat, lon)
.map_err(|e| {
BoxedError::new(PlainError::new(
format!("H3 error: {}", e.to_string()),
StatusCode::EngineExecuteQuery,
))
})
.context(error::ExecuteSnafu)?;
let r = Resolution::try_from(r as u8)
.map_err(|e| {
BoxedError::new(PlainError::new(
format!("H3 error: {}", e.to_string()),
StatusCode::EngineExecuteQuery,
))
})
.context(error::ExecuteSnafu)?;
let encoded = coord.to_cell(r).to_string();
Some(encoded)
}
_ => None,
};
results.push(result.as_deref());
}
Ok(results.to_vector())
}
}
/// Displays the function as its registered name.
impl fmt::Display for H3Function {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Writes the name verbatim; formatter width/padding flags are not applied.
        f.write_str(NAME)
    }
}

View File

@@ -172,12 +172,13 @@ impl ErrorExt for Error {
Error::DataTypes { .. }
| Error::CreateRecordBatches { .. }
| Error::PollStream { .. }
| Error::Format { .. }
| Error::ToArrowScalar { .. }
| Error::ProjectArrowRecordBatch { .. }
| Error::PhysicalExpr { .. } => StatusCode::Internal,
Error::PollStream { .. } => StatusCode::EngineExecuteQuery,
Error::ArrowCompute { .. } => StatusCode::IllegalState,
Error::ColumnNotExists { .. } => StatusCode::TableColumnNotFound,

View File

@@ -268,6 +268,23 @@ impl Value {
}
}
/// Cast Value to f64. Return None if it's not castable.
///
/// Float and integer variants are converted; every other variant gives `None`.
pub fn as_f64(&self) -> Option<f64> {
    let converted = match self {
        Value::Float32(v) => f64::from(v.0),
        Value::Float64(v) => v.0,
        Value::Int8(v) => f64::from(*v),
        Value::Int16(v) => f64::from(*v),
        Value::Int32(v) => f64::from(*v),
        // 64-bit integers have no lossless `From` conversion to f64.
        Value::Int64(v) => *v as f64,
        Value::UInt8(v) => f64::from(*v),
        Value::UInt16(v) => f64::from(*v),
        Value::UInt32(v) => f64::from(*v),
        Value::UInt64(v) => *v as f64,
        _ => return None,
    };
    Some(converted)
}
/// Returns the logical type of the value.
pub fn logical_type_id(&self) -> LogicalTypeId {
match self {