feat: use server inferenced types on statement describe (#2032)

* feat: use server inferenced types on statement describe

* feat: add support for server inferenced type

* feat: allow parameter type inferencing

* chore: update comments

* fix: lint issue

* style: comfort rustfmt

* Update src/servers/src/postgres/types.rs

Co-authored-by: Yingwen <realevenyag@gmail.com>

---------

Co-authored-by: Yingwen <realevenyag@gmail.com>
This commit is contained in:
Ning Sun
2023-08-09 10:57:56 +08:00
committed by GitHub
parent aa6452c86c
commit d18eb18b32
6 changed files with 120 additions and 41 deletions

9
Cargo.lock generated
View File

@@ -6544,9 +6544,9 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.15.0"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2de42ee35f9694def25c37c15f564555411d9904b48e33680618ee7359080dc"
checksum = "593c5af58c6394873b84c6fabf31f97e49ab29a56809e7fd240c1bcc4e5d272f"
dependencies = [
"async-trait",
"base64 0.21.2",
@@ -9882,6 +9882,7 @@ dependencies = [
"table",
"tempfile",
"tokio",
"tokio-postgres",
"tonic 0.9.2",
"tower",
"uuid",
@@ -11487,9 +11488,9 @@ dependencies = [
[[package]]
name = "x509-certificate"
version = "0.20.0"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2133ce6c08c050a5b368730a67c53a603ffd4a4a6c577c5218675a19f7782c05"
checksum = "5e5d27c90840e84503cf44364de338794d5d5680bdd1da6272d13f80b0769ee0"
dependencies = [
"bcder",
"bytes",

View File

@@ -58,7 +58,7 @@ openmetrics-parser = "0.4"
opensrv-mysql = "0.4"
opentelemetry-proto.workspace = true
parking_lot = "0.12"
pgwire = "0.15"
pgwire = "0.16"
pin-project = "1.0"
postgres-types = { version = "0.2", features = ["with-chrono-0_4"] }
promql-parser = "0.1.1"

View File

@@ -257,10 +257,20 @@ impl ExtendedQueryHandler for PostgresServerHandler {
{
let (param_types, sql_plan, format) = match target {
StatementOrPortal::Statement(stmt) => {
let param_types = Some(stmt.parameter_types().clone());
// TODO(sunng87): return server inferenced param_types if client
// not specified
(param_types, stmt.statement(), &Format::UnifiedBinary)
let sql_plan = stmt.statement();
if let Some(plan) = &sql_plan.plan {
let param_types = plan
.get_param_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let types = param_types_to_pg_types(&param_types)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
(Some(types), sql_plan, &Format::UnifiedBinary)
} else {
let param_types = Some(stmt.parameter_types().clone());
(param_types, sql_plan, &Format::UnifiedBinary)
}
}
StatementOrPortal::Portal(portal) => (
None,

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::ops::Deref;
use chrono::{NaiveDate, NaiveDateTime};
@@ -161,34 +162,37 @@ pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWir
match param_type {
&Type::VARCHAR | &Type::TEXT => Ok(format!(
"'{}'",
portal.parameter::<String>(idx)?.as_deref().unwrap_or("")
portal
.parameter::<String>(idx, param_type)?
.as_deref()
.unwrap_or("")
)),
&Type::BOOL => Ok(portal
.parameter::<bool>(idx)?
.parameter::<bool>(idx, param_type)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::INT4 => Ok(portal
.parameter::<i32>(idx)?
.parameter::<i32>(idx, param_type)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::INT8 => Ok(portal
.parameter::<i64>(idx)?
.parameter::<i64>(idx, param_type)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::FLOAT4 => Ok(portal
.parameter::<f32>(idx)?
.parameter::<f32>(idx, param_type)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::FLOAT8 => Ok(portal
.parameter::<f64>(idx)?
.parameter::<f64>(idx, param_type)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::DATE => Ok(portal
.parameter::<NaiveDate>(idx)?
.parameter::<NaiveDate>(idx, param_type)?
.map(|v| v.format("%Y-%m-%d").to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::TIMESTAMP => Ok(portal
.parameter::<NaiveDateTime>(idx)?
.parameter::<NaiveDateTime>(idx, param_type)?
.map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
.unwrap_or_else(|| "".to_owned())),
_ => Err(invalid_parameter_error(
@@ -245,24 +249,30 @@ pub(super) fn parameters_to_scalar_values(
)),
));
}
if client_param_types.len() != param_count {
return Err(invalid_parameter_error(
"invalid_parameter_count",
Some(&format!(
"Expected: {}, found: {}",
client_param_types.len(),
param_count
)),
));
}
for (idx, client_type) in client_param_types.iter().enumerate() {
let Some(Some(server_type)) = param_types.get(&format!("${}", idx + 1)) else {
continue;
for idx in 0..param_count {
let server_type =
if let Some(Some(server_infer_type)) = param_types.get(&format!("${}", idx + 1)) {
server_infer_type
} else {
// at the moment we require type information inferenced by
// server so here we return error if the type is unknown from
// server-side.
//
// It might be possible to parse the parameter just using client
// specified type, we will implement that if there is a case.
return Err(invalid_parameter_error("unknown_parameter_type", None));
};
let client_type = if let Some(client_given_type) = client_param_types.get(idx) {
client_given_type.clone()
} else {
type_gt_to_pg(server_type).map_err(|e| PgWireError::ApiError(Box::new(e)))?
};
let value = match client_type {
let value = match &client_type {
&Type::VARCHAR | &Type::TEXT => {
let data = portal.parameter::<String>(idx)?;
let data = portal.parameter::<String>(idx, &client_type)?;
match server_type {
ConcreteDataType::String(_) => ScalarValue::Utf8(data),
_ => {
@@ -277,7 +287,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::BOOL => {
let data = portal.parameter::<bool>(idx)?;
let data = portal.parameter::<bool>(idx, &client_type)?;
match server_type {
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
_ => {
@@ -292,7 +302,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::INT2 => {
let data = portal.parameter::<i16>(idx)?;
let data = portal.parameter::<i16>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
@@ -318,7 +328,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::INT4 => {
let data = portal.parameter::<i32>(idx)?;
let data = portal.parameter::<i32>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
@@ -344,7 +354,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::INT8 => {
let data = portal.parameter::<i64>(idx)?;
let data = portal.parameter::<i64>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
@@ -370,7 +380,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::FLOAT4 => {
let data = portal.parameter::<f32>(idx)?;
let data = portal.parameter::<f32>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
@@ -394,7 +404,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::FLOAT8 => {
let data = portal.parameter::<f64>(idx)?;
let data = portal.parameter::<f64>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
@@ -418,7 +428,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::TIMESTAMP => {
let data = portal.parameter::<NaiveDateTime>(idx)?;
let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?;
match server_type {
ConcreteDataType::Timestamp(unit) => match *unit {
TimestampType::Second(_) => {
@@ -452,7 +462,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::DATE => {
let data = portal.parameter::<NaiveDate>(idx)?;
let data = portal.parameter::<NaiveDate>(idx, &client_type)?;
match server_type {
ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| {
(d - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()).num_days() as i32
@@ -469,7 +479,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::BYTEA => {
let data = portal.parameter::<Vec<u8>>(idx)?;
let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
match server_type {
ConcreteDataType::String(_) => {
ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string()))
@@ -491,12 +501,29 @@ pub(super) fn parameters_to_scalar_values(
Some(&format!("Found type: {}", client_type)),
))?,
};
results.push(value);
}
Ok(results)
}
pub(super) fn param_types_to_pg_types(
param_types: &HashMap<String, Option<ConcreteDataType>>,
) -> Result<Vec<Type>> {
let param_count = param_types.len();
let mut types = Vec::with_capacity(param_count);
for i in 0..param_count {
if let Some(Some(param_type)) = param_types.get(&format!("${}", i + 1)) {
let pg_type = type_gt_to_pg(param_type)?;
types.push(pg_type);
} else {
types.push(Type::UNKNOWN);
}
}
Ok(types)
}
#[cfg(test)]
mod test {
use std::sync::Arc;

View File

@@ -71,3 +71,4 @@ prost.workspace = true
script = { workspace = true }
session = { workspace = true, features = ["testing"] }
store-api = { workspace = true }
tokio-postgres = "0.7"

View File

@@ -22,6 +22,7 @@ use tests_integration::test_util::{
setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server,
setup_pg_server_with_user_provider, StorageType,
};
use tokio_postgres::NoTls;
#[macro_export]
macro_rules! sql_test {
@@ -57,6 +58,7 @@ macro_rules! sql_tests {
test_mysql_crud,
test_postgres_auth,
test_postgres_crud,
test_postgres_parameter_inference,
);
)*
};
@@ -332,3 +334,41 @@ pub async fn test_postgres_crud(store_type: StorageType) {
let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}
pub async fn test_postgres_parameter_inference(store_type: StorageType) {
let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;
let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
.await
.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
// Create demo table
let _ = client
.simple_query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)")
.await
.unwrap();
let d = NaiveDate::from_yo_opt(2015, 100).unwrap();
let dt = d.and_hms_opt(0, 0, 0).unwrap();
let _ = client
.execute(
"INSERT INTO demo VALUES($1, $2, $3, $4)",
&[&0i64, &dt, &d, &dt],
)
.await
.unwrap();
let rows = client
.query("SELECT * FROM demo WHERE i = $1", &[&0i64])
.await
.unwrap();
assert_eq!(1, rows.len());
let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}