From d18eb18b3233e004131fd4ffe9852ae4b02a5536 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 9 Aug 2023 10:57:56 +0800 Subject: [PATCH] feat: use server inferenced types on statement describe (#2032) * feat: use server inferenced types on statement describe * feat: add support for server inferenced type * feat: allow parameter type inferencing * chore: update comments * fix: lint issue * style: comfort rustfmt * Update src/servers/src/postgres/types.rs Co-authored-by: Yingwen --------- Co-authored-by: Yingwen --- Cargo.lock | 9 +-- src/servers/Cargo.toml | 2 +- src/servers/src/postgres/handler.rs | 18 ++++-- src/servers/src/postgres/types.rs | 91 +++++++++++++++++++---------- tests-integration/Cargo.toml | 1 + tests-integration/tests/sql.rs | 40 +++++++++++++ 6 files changed, 120 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c9e9888625..1d0ad5ca78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6544,9 +6544,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.15.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2de42ee35f9694def25c37c15f564555411d9904b48e33680618ee7359080dc" +checksum = "593c5af58c6394873b84c6fabf31f97e49ab29a56809e7fd240c1bcc4e5d272f" dependencies = [ "async-trait", "base64 0.21.2", @@ -9882,6 +9882,7 @@ dependencies = [ "table", "tempfile", "tokio", + "tokio-postgres", "tonic 0.9.2", "tower", "uuid", @@ -11487,9 +11488,9 @@ dependencies = [ [[package]] name = "x509-certificate" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2133ce6c08c050a5b368730a67c53a603ffd4a4a6c577c5218675a19f7782c05" +checksum = "5e5d27c90840e84503cf44364de338794d5d5680bdd1da6272d13f80b0769ee0" dependencies = [ "bcder", "bytes", diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index a5e5aa5f45..331e526a95 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -58,7 +58,7 @@ 
openmetrics-parser = "0.4" opensrv-mysql = "0.4" opentelemetry-proto.workspace = true parking_lot = "0.12" -pgwire = "0.15" +pgwire = "0.16" pin-project = "1.0" postgres-types = { version = "0.2", features = ["with-chrono-0_4"] } promql-parser = "0.1.1" diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index da8355603f..536ca60417 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -257,10 +257,20 @@ impl ExtendedQueryHandler for PostgresServerHandler { { let (param_types, sql_plan, format) = match target { StatementOrPortal::Statement(stmt) => { - let param_types = Some(stmt.parameter_types().clone()); - // TODO(sunng87): return server inferenced param_types if client - // not specified - (param_types, stmt.statement(), &Format::UnifiedBinary) + let sql_plan = stmt.statement(); + if let Some(plan) = &sql_plan.plan { + let param_types = plan + .get_param_types() + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + let types = param_types_to_pg_types(&param_types) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + (Some(types), sql_plan, &Format::UnifiedBinary) + } else { + let param_types = Some(stmt.parameter_types().clone()); + (param_types, sql_plan, &Format::UnifiedBinary) + } } StatementOrPortal::Portal(portal) => ( None, diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index e511cebe4c..5447023301 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::collections::HashMap; use std::ops::Deref; use chrono::{NaiveDate, NaiveDateTime}; @@ -161,34 +162,37 @@ pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWir match param_type { &Type::VARCHAR | &Type::TEXT => Ok(format!( "'{}'", - portal.parameter::<String>(idx)?.as_deref().unwrap_or("") + portal + .parameter::<String>(idx, param_type)? + .as_deref() + .unwrap_or("") )), &Type::BOOL => Ok(portal - .parameter::<bool>(idx)? + .parameter::<bool>(idx, param_type)? .map(|v| v.to_string()) .unwrap_or_else(|| "".to_owned())), &Type::INT4 => Ok(portal - .parameter::<i32>(idx)? + .parameter::<i32>(idx, param_type)? .map(|v| v.to_string()) .unwrap_or_else(|| "".to_owned())), &Type::INT8 => Ok(portal - .parameter::<i64>(idx)? + .parameter::<i64>(idx, param_type)? .map(|v| v.to_string()) .unwrap_or_else(|| "".to_owned())), &Type::FLOAT4 => Ok(portal - .parameter::<f32>(idx)? + .parameter::<f32>(idx, param_type)? .map(|v| v.to_string()) .unwrap_or_else(|| "".to_owned())), &Type::FLOAT8 => Ok(portal - .parameter::<f64>(idx)? + .parameter::<f64>(idx, param_type)? .map(|v| v.to_string()) .unwrap_or_else(|| "".to_owned())), &Type::DATE => Ok(portal - .parameter::<NaiveDate>(idx)? + .parameter::<NaiveDate>(idx, param_type)? .map(|v| v.format("%Y-%m-%d").to_string()) .unwrap_or_else(|| "".to_owned())), &Type::TIMESTAMP => Ok(portal - .parameter::<NaiveDateTime>(idx)? + .parameter::<NaiveDateTime>(idx, param_type)? 
.map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string()) .unwrap_or_else(|| "".to_owned())), _ => Err(invalid_parameter_error( @@ -245,24 +249,30 @@ pub(super) fn parameters_to_scalar_values( )), )); } - if client_param_types.len() != param_count { - return Err(invalid_parameter_error( - "invalid_parameter_count", - Some(&format!( - "Expected: {}, found: {}", - client_param_types.len(), - param_count - )), - )); - } - for (idx, client_type) in client_param_types.iter().enumerate() { - let Some(Some(server_type)) = param_types.get(&format!("${}", idx + 1)) else { - continue; + for idx in 0..param_count { + let server_type = + if let Some(Some(server_infer_type)) = param_types.get(&format!("${}", idx + 1)) { + server_infer_type + } else { + // at the moment we require type information inferenced by + // server so here we return error if the type is unknown from + // server-side. + // + // It might be possible to parse the parameter just using client + // specified type, we will implement that if there is a case. + return Err(invalid_parameter_error("unknown_parameter_type", None)); + }; + + let client_type = if let Some(client_given_type) = client_param_types.get(idx) { + client_given_type.clone() + } else { + type_gt_to_pg(server_type).map_err(|e| PgWireError::ApiError(Box::new(e)))? 
}; - let value = match client_type { + + let value = match &client_type { &Type::VARCHAR | &Type::TEXT => { - let data = portal.parameter::<String>(idx)?; + let data = portal.parameter::<String>(idx, &client_type)?; match server_type { ConcreteDataType::String(_) => ScalarValue::Utf8(data), _ => { @@ -277,7 +287,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::BOOL => { - let data = portal.parameter::<bool>(idx)?; + let data = portal.parameter::<bool>(idx, &client_type)?; match server_type { ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data), _ => { @@ -292,7 +302,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::INT2 => { - let data = portal.parameter::<i16>(idx)?; + let data = portal.parameter::<i16>(idx, &client_type)?; match server_type { ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), ConcreteDataType::Int16(_) => ScalarValue::Int16(data), @@ -318,7 +328,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::INT4 => { - let data = portal.parameter::<i32>(idx)?; + let data = portal.parameter::<i32>(idx, &client_type)?; match server_type { ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), @@ -344,7 +354,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::INT8 => { - let data = portal.parameter::<i64>(idx)?; + let data = portal.parameter::<i64>(idx, &client_type)?; match server_type { ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), @@ -370,7 +380,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::FLOAT4 => { - let data = portal.parameter::<f32>(idx)?; + let data = portal.parameter::<f32>(idx, &client_type)?; match server_type { ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), @@ -394,7 +404,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::FLOAT8 => { - let data = 
portal.parameter::<f64>(idx)?; + let data = portal.parameter::<f64>(idx, &client_type)?; match server_type { ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), @@ -418,7 +428,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::TIMESTAMP => { - let data = portal.parameter::<NaiveDateTime>(idx)?; + let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?; match server_type { ConcreteDataType::Timestamp(unit) => match *unit { TimestampType::Second(_) => { @@ -452,7 +462,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::DATE => { - let data = portal.parameter::<NaiveDate>(idx)?; + let data = portal.parameter::<NaiveDate>(idx, &client_type)?; match server_type { ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| { (d - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()).num_days() as i32 @@ -469,7 +479,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::BYTEA => { - let data = portal.parameter::<Vec<u8>>(idx)?; + let data = portal.parameter::<Vec<u8>>(idx, &client_type)?; match server_type { ConcreteDataType::String(_) => { ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string())) @@ -491,12 +501,29 @@ pub(super) fn parameters_to_scalar_values( Some(&format!("Found type: {}", client_type)), ))?, }; + results.push(value); } Ok(results) } +pub(super) fn param_types_to_pg_types( + param_types: &HashMap<String, Option<ConcreteDataType>>, +) -> Result<Vec<Type>> { + let param_count = param_types.len(); + let mut types = Vec::with_capacity(param_count); + for i in 0..param_count { + if let Some(Some(param_type)) = param_types.get(&format!("${}", i + 1)) { + let pg_type = type_gt_to_pg(param_type)?; + types.push(pg_type); + } else { + types.push(Type::UNKNOWN); + } + } + Ok(types) +} + #[cfg(test)] mod test { use std::sync::Arc; diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index 87af890a68..a92911cb32 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -71,3 +71,4 @@ prost.workspace = true script = 
{ workspace = true } session = { workspace = true, features = ["testing"] } store-api = { workspace = true } +tokio-postgres = "0.7" diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index e4652d9665..44c4c4d30a 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -22,6 +22,7 @@ use tests_integration::test_util::{ setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server, setup_pg_server_with_user_provider, StorageType, }; +use tokio_postgres::NoTls; #[macro_export] macro_rules! sql_test { @@ -57,6 +58,7 @@ macro_rules! sql_tests { test_mysql_crud, test_postgres_auth, test_postgres_crud, + test_postgres_parameter_inference, ); )* }; @@ -332,3 +334,41 @@ pub async fn test_postgres_crud(store_type: StorageType) { let _ = fe_pg_server.shutdown().await; guard.remove_all().await; } + +pub async fn test_postgres_parameter_inference(store_type: StorageType) { + let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await; + + let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) + .await + .unwrap(); + + tokio::spawn(async move { + connection.await.unwrap(); + }); + + // Create demo table + let _ = client + .simple_query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)") + .await + .unwrap(); + + let d = NaiveDate::from_yo_opt(2015, 100).unwrap(); + let dt = d.and_hms_opt(0, 0, 0).unwrap(); + let _ = client + .execute( + "INSERT INTO demo VALUES($1, $2, $3, $4)", + &[&0i64, &dt, &d, &dt], + ) + .await + .unwrap(); + + let rows = client + .query("SELECT * FROM demo WHERE i = $1", &[&0i64]) + .await + .unwrap(); + + assert_eq!(1, rows.len()); + + let _ = fe_pg_server.shutdown().await; + guard.remove_all().await; +}