From 0bff41e038f3cf8e39ecb0c0318ef5cb6cbd4fec Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 2 Sep 2024 11:43:33 -0700 Subject: [PATCH] refactor: address review comments --- .../function/src/scalars/geo/geohash.rs | 63 ++++++++++++------- src/common/function/src/scalars/geo/h3.rs | 61 ++++++++++++------ src/datatypes/src/value.rs | 2 +- 3 files changed, 84 insertions(+), 42 deletions(-) diff --git a/src/common/function/src/scalars/geo/geohash.rs b/src/common/function/src/scalars/geo/geohash.rs index fb39b46634..2daa8223cc 100644 --- a/src/common/function/src/scalars/geo/geohash.rs +++ b/src/common/function/src/scalars/geo/geohash.rs @@ -44,21 +44,32 @@ impl Function for GeohashFunction { } fn signature(&self) -> Signature { - Signature::one_of( - vec![ - TypeSignature::Exact(vec![ - ConcreteDataType::float32_datatype(), - ConcreteDataType::float32_datatype(), - ConcreteDataType::int64_datatype(), - ]), - TypeSignature::Exact(vec![ - ConcreteDataType::float64_datatype(), - ConcreteDataType::float64_datatype(), - ConcreteDataType::int64_datatype(), - ]), - ], - Volatility::Stable, - ) + let mut signatures = Vec::new(); + for coord_type in &[ + ConcreteDataType::float32_datatype(), + ConcreteDataType::float64_datatype(), + ] { + for resolution_type in &[ + ConcreteDataType::int8_datatype(), + ConcreteDataType::int16_datatype(), + ConcreteDataType::int32_datatype(), + ConcreteDataType::int64_datatype(), + ConcreteDataType::uint8_datatype(), + ConcreteDataType::uint16_datatype(), + ConcreteDataType::uint32_datatype(), + ConcreteDataType::uint64_datatype(), + ] { + signatures.push(TypeSignature::Exact(vec![ + // latitude + coord_type.clone(), + // longitude + coord_type.clone(), + // resolution + resolution_type.clone(), + ])); + } + } + Signature::one_of(signatures, Volatility::Stable) } fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { @@ -80,14 +91,24 @@ impl Function for GeohashFunction { let mut results = StringVectorBuilder::with_capacity(size); for i in 0..size { - let lat = lat_vec.get(i).as_f64(); - let lon = lon_vec.get(i).as_f64(); - let resolution = resolution_vec.get(i); + let lat = lat_vec.get(i).as_f64_lossy(); + let lon = lon_vec.get(i).as_f64_lossy(); + let r = match resolution_vec.get(i) { + Value::Int8(v) => v as usize, + Value::Int16(v) => v as usize, + Value::Int32(v) => v as usize, + Value::Int64(v) => v as usize, + Value::UInt8(v) => v as usize, + Value::UInt16(v) => v as usize, + Value::UInt32(v) => v as usize, + Value::UInt64(v) => v as usize, + _ => unreachable!(), + }; - let result = match (lat, lon, resolution) { - (Some(lat), Some(lon), Value::Int64(r)) => { + let result = match (lat, lon) { + (Some(lat), Some(lon)) => { let coord = Coord { x: lon, y: lat }; - let encoded = geohash::encode(coord, r as usize) + let encoded = geohash::encode(coord, r) .map_err(|e| { BoxedError::new(PlainError::new( format!("Geohash error: {}", e), diff --git a/src/common/function/src/scalars/geo/h3.rs b/src/common/function/src/scalars/geo/h3.rs index 7b58150079..7497549ca9 100644 --- a/src/common/function/src/scalars/geo/h3.rs +++ b/src/common/function/src/scalars/geo/h3.rs @@ -44,21 +44,32 @@ impl Function for H3Function { } fn signature(&self) -> Signature { - Signature::one_of( - vec![ - TypeSignature::Exact(vec![ - ConcreteDataType::float32_datatype(), - ConcreteDataType::float32_datatype(), - ConcreteDataType::int64_datatype(), - ]), - TypeSignature::Exact(vec![ - ConcreteDataType::float64_datatype(), - ConcreteDataType::float64_datatype(), - ConcreteDataType::int64_datatype(), - ]), - ], - Volatility::Stable, - ) + let mut signatures = Vec::new(); + for coord_type in &[ + ConcreteDataType::float32_datatype(), + ConcreteDataType::float64_datatype(), + ] { + for resolution_type in &[ + ConcreteDataType::int8_datatype(), + ConcreteDataType::int16_datatype(), + ConcreteDataType::int32_datatype(), + ConcreteDataType::int64_datatype(), + ConcreteDataType::uint8_datatype(), + ConcreteDataType::uint16_datatype(), + ConcreteDataType::uint32_datatype(), + ConcreteDataType::uint64_datatype(), + ] { + signatures.push(TypeSignature::Exact(vec![ + // latitude + coord_type.clone(), + // longitude + coord_type.clone(), + // resolution + resolution_type.clone(), + ])); + } + } + Signature::one_of(signatures, Volatility::Stable) } fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { @@ -80,12 +91,22 @@ impl Function for H3Function { let mut results = StringVectorBuilder::with_capacity(size); for i in 0..size { - let lat = lat_vec.get(i).as_f64(); - let lon = lon_vec.get(i).as_f64(); - let resolution = resolution_vec.get(i); + let lat = lat_vec.get(i).as_f64_lossy(); + let lon = lon_vec.get(i).as_f64_lossy(); + let r = match resolution_vec.get(i) { + Value::Int8(v) => v as u8, + Value::Int16(v) => v as u8, + Value::Int32(v) => v as u8, + Value::Int64(v) => v as u8, + Value::UInt8(v) => v, + Value::UInt16(v) => v as u8, + Value::UInt32(v) => v as u8, + Value::UInt64(v) => v as u8, + _ => unreachable!(), + }; - let result = match (lat, lon, resolution) { - (Some(lat), Some(lon), Value::Int64(r)) => { + let result = match (lat, lon) { + (Some(lat), Some(lon)) => { let coord = LatLng::new(lat, lon) .map_err(|e| { BoxedError::new(PlainError::new( diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index 136de0363e..256514c7d2 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -269,7 +269,7 @@ impl Value { } /// Cast Value to f32. Return None if it's not castable; - pub fn as_f64(&self) -> Option { + pub fn as_f64_lossy(&self) -> Option { match self { Value::Float32(v) => Some(v.0 as _), Value::Float64(v) => Some(v.0),