mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-05 21:02:58 +00:00
feat: use server inferenced types on statement describe (#2032)
* feat: use server inferenced types on statement describe * feat: add support for server inferenced type * feat: allow parameter type inferencing * chore: update comments * fix: lint issue * style: comfort rustfmt * Update src/servers/src/postgres/types.rs Co-authored-by: Yingwen <realevenyag@gmail.com> --------- Co-authored-by: Yingwen <realevenyag@gmail.com>
This commit is contained in:
9
Cargo.lock
generated
9
Cargo.lock
generated
@@ -6544,9 +6544,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pgwire"
|
||||
version = "0.15.0"
|
||||
version = "0.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2de42ee35f9694def25c37c15f564555411d9904b48e33680618ee7359080dc"
|
||||
checksum = "593c5af58c6394873b84c6fabf31f97e49ab29a56809e7fd240c1bcc4e5d272f"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.21.2",
|
||||
@@ -9882,6 +9882,7 @@ dependencies = [
|
||||
"table",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tonic 0.9.2",
|
||||
"tower",
|
||||
"uuid",
|
||||
@@ -11487,9 +11488,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "x509-certificate"
|
||||
version = "0.20.0"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2133ce6c08c050a5b368730a67c53a603ffd4a4a6c577c5218675a19f7782c05"
|
||||
checksum = "5e5d27c90840e84503cf44364de338794d5d5680bdd1da6272d13f80b0769ee0"
|
||||
dependencies = [
|
||||
"bcder",
|
||||
"bytes",
|
||||
|
||||
@@ -58,7 +58,7 @@ openmetrics-parser = "0.4"
|
||||
opensrv-mysql = "0.4"
|
||||
opentelemetry-proto.workspace = true
|
||||
parking_lot = "0.12"
|
||||
pgwire = "0.15"
|
||||
pgwire = "0.16"
|
||||
pin-project = "1.0"
|
||||
postgres-types = { version = "0.2", features = ["with-chrono-0_4"] }
|
||||
promql-parser = "0.1.1"
|
||||
|
||||
@@ -257,10 +257,20 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
{
|
||||
let (param_types, sql_plan, format) = match target {
|
||||
StatementOrPortal::Statement(stmt) => {
|
||||
let param_types = Some(stmt.parameter_types().clone());
|
||||
// TODO(sunng87): return server inferenced param_types if client
|
||||
// not specified
|
||||
(param_types, stmt.statement(), &Format::UnifiedBinary)
|
||||
let sql_plan = stmt.statement();
|
||||
if let Some(plan) = &sql_plan.plan {
|
||||
let param_types = plan
|
||||
.get_param_types()
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
|
||||
let types = param_types_to_pg_types(¶m_types)
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
|
||||
(Some(types), sql_plan, &Format::UnifiedBinary)
|
||||
} else {
|
||||
let param_types = Some(stmt.parameter_types().clone());
|
||||
(param_types, sql_plan, &Format::UnifiedBinary)
|
||||
}
|
||||
}
|
||||
StatementOrPortal::Portal(portal) => (
|
||||
None,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Deref;
|
||||
|
||||
use chrono::{NaiveDate, NaiveDateTime};
|
||||
@@ -161,34 +162,37 @@ pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWir
|
||||
match param_type {
|
||||
&Type::VARCHAR | &Type::TEXT => Ok(format!(
|
||||
"'{}'",
|
||||
portal.parameter::<String>(idx)?.as_deref().unwrap_or("")
|
||||
portal
|
||||
.parameter::<String>(idx, param_type)?
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
)),
|
||||
&Type::BOOL => Ok(portal
|
||||
.parameter::<bool>(idx)?
|
||||
.parameter::<bool>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INT4 => Ok(portal
|
||||
.parameter::<i32>(idx)?
|
||||
.parameter::<i32>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INT8 => Ok(portal
|
||||
.parameter::<i64>(idx)?
|
||||
.parameter::<i64>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::FLOAT4 => Ok(portal
|
||||
.parameter::<f32>(idx)?
|
||||
.parameter::<f32>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::FLOAT8 => Ok(portal
|
||||
.parameter::<f64>(idx)?
|
||||
.parameter::<f64>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::DATE => Ok(portal
|
||||
.parameter::<NaiveDate>(idx)?
|
||||
.parameter::<NaiveDate>(idx, param_type)?
|
||||
.map(|v| v.format("%Y-%m-%d").to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::TIMESTAMP => Ok(portal
|
||||
.parameter::<NaiveDateTime>(idx)?
|
||||
.parameter::<NaiveDateTime>(idx, param_type)?
|
||||
.map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
_ => Err(invalid_parameter_error(
|
||||
@@ -245,24 +249,30 @@ pub(super) fn parameters_to_scalar_values(
|
||||
)),
|
||||
));
|
||||
}
|
||||
if client_param_types.len() != param_count {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_count",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
client_param_types.len(),
|
||||
param_count
|
||||
)),
|
||||
));
|
||||
}
|
||||
|
||||
for (idx, client_type) in client_param_types.iter().enumerate() {
|
||||
let Some(Some(server_type)) = param_types.get(&format!("${}", idx + 1)) else {
|
||||
continue;
|
||||
for idx in 0..param_count {
|
||||
let server_type =
|
||||
if let Some(Some(server_infer_type)) = param_types.get(&format!("${}", idx + 1)) {
|
||||
server_infer_type
|
||||
} else {
|
||||
// at the moment we require type information inferenced by
|
||||
// server so here we return error if the type is unknown from
|
||||
// server-side.
|
||||
//
|
||||
// It might be possible to parse the parameter just using client
|
||||
// specified type, we will implement that if there is a case.
|
||||
return Err(invalid_parameter_error("unknown_parameter_type", None));
|
||||
};
|
||||
|
||||
let client_type = if let Some(client_given_type) = client_param_types.get(idx) {
|
||||
client_given_type.clone()
|
||||
} else {
|
||||
type_gt_to_pg(server_type).map_err(|e| PgWireError::ApiError(Box::new(e)))?
|
||||
};
|
||||
let value = match client_type {
|
||||
|
||||
let value = match &client_type {
|
||||
&Type::VARCHAR | &Type::TEXT => {
|
||||
let data = portal.parameter::<String>(idx)?;
|
||||
let data = portal.parameter::<String>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::String(_) => ScalarValue::Utf8(data),
|
||||
_ => {
|
||||
@@ -277,7 +287,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::BOOL => {
|
||||
let data = portal.parameter::<bool>(idx)?;
|
||||
let data = portal.parameter::<bool>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
|
||||
_ => {
|
||||
@@ -292,7 +302,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::INT2 => {
|
||||
let data = portal.parameter::<i16>(idx)?;
|
||||
let data = portal.parameter::<i16>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
|
||||
@@ -318,7 +328,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::INT4 => {
|
||||
let data = portal.parameter::<i32>(idx)?;
|
||||
let data = portal.parameter::<i32>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
|
||||
@@ -344,7 +354,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::INT8 => {
|
||||
let data = portal.parameter::<i64>(idx)?;
|
||||
let data = portal.parameter::<i64>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
|
||||
@@ -370,7 +380,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::FLOAT4 => {
|
||||
let data = portal.parameter::<f32>(idx)?;
|
||||
let data = portal.parameter::<f32>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
|
||||
@@ -394,7 +404,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::FLOAT8 => {
|
||||
let data = portal.parameter::<f64>(idx)?;
|
||||
let data = portal.parameter::<f64>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
|
||||
@@ -418,7 +428,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::TIMESTAMP => {
|
||||
let data = portal.parameter::<NaiveDateTime>(idx)?;
|
||||
let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Timestamp(unit) => match *unit {
|
||||
TimestampType::Second(_) => {
|
||||
@@ -452,7 +462,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::DATE => {
|
||||
let data = portal.parameter::<NaiveDate>(idx)?;
|
||||
let data = portal.parameter::<NaiveDate>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| {
|
||||
(d - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()).num_days() as i32
|
||||
@@ -469,7 +479,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
}
|
||||
}
|
||||
&Type::BYTEA => {
|
||||
let data = portal.parameter::<Vec<u8>>(idx)?;
|
||||
let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
|
||||
match server_type {
|
||||
ConcreteDataType::String(_) => {
|
||||
ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string()))
|
||||
@@ -491,12 +501,29 @@ pub(super) fn parameters_to_scalar_values(
|
||||
Some(&format!("Found type: {}", client_type)),
|
||||
))?,
|
||||
};
|
||||
|
||||
results.push(value);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
pub(super) fn param_types_to_pg_types(
|
||||
param_types: &HashMap<String, Option<ConcreteDataType>>,
|
||||
) -> Result<Vec<Type>> {
|
||||
let param_count = param_types.len();
|
||||
let mut types = Vec::with_capacity(param_count);
|
||||
for i in 0..param_count {
|
||||
if let Some(Some(param_type)) = param_types.get(&format!("${}", i + 1)) {
|
||||
let pg_type = type_gt_to_pg(param_type)?;
|
||||
types.push(pg_type);
|
||||
} else {
|
||||
types.push(Type::UNKNOWN);
|
||||
}
|
||||
}
|
||||
Ok(types)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -71,3 +71,4 @@ prost.workspace = true
|
||||
script = { workspace = true }
|
||||
session = { workspace = true, features = ["testing"] }
|
||||
store-api = { workspace = true }
|
||||
tokio-postgres = "0.7"
|
||||
|
||||
@@ -22,6 +22,7 @@ use tests_integration::test_util::{
|
||||
setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server,
|
||||
setup_pg_server_with_user_provider, StorageType,
|
||||
};
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! sql_test {
|
||||
@@ -57,6 +58,7 @@ macro_rules! sql_tests {
|
||||
test_mysql_crud,
|
||||
test_postgres_auth,
|
||||
test_postgres_crud,
|
||||
test_postgres_parameter_inference,
|
||||
);
|
||||
)*
|
||||
};
|
||||
@@ -332,3 +334,41 @@ pub async fn test_postgres_crud(store_type: StorageType) {
|
||||
let _ = fe_pg_server.shutdown().await;
|
||||
guard.remove_all().await;
|
||||
}
|
||||
|
||||
pub async fn test_postgres_parameter_inference(store_type: StorageType) {
|
||||
let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;
|
||||
|
||||
let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
connection.await.unwrap();
|
||||
});
|
||||
|
||||
// Create demo table
|
||||
let _ = client
|
||||
.simple_query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let d = NaiveDate::from_yo_opt(2015, 100).unwrap();
|
||||
let dt = d.and_hms_opt(0, 0, 0).unwrap();
|
||||
let _ = client
|
||||
.execute(
|
||||
"INSERT INTO demo VALUES($1, $2, $3, $4)",
|
||||
&[&0i64, &dt, &d, &dt],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let rows = client
|
||||
.query("SELECT * FROM demo WHERE i = $1", &[&0i64])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(1, rows.len());
|
||||
|
||||
let _ = fe_pg_server.shutdown().await;
|
||||
guard.remove_all().await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user