// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
use common_query::{Output, OutputData};
use common_recordbatch::RecordBatch;
use common_recordbatch::error::Result as RecordBatchResult;
use common_telemetry::{debug, tracing};
use datafusion_common::ParamValues;
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::SchemaRef;
use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
    DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
};
use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::{ClientInfo, ErrorHandler, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;
use query::query_engine::DescribeResult;
use session::Session;
use session::context::QueryContextRef;
use snafu::ResultExt;
use sql::dialect::PostgreSqlDialect;
use sql::parser::{ParseOptions, ParserContext};

use crate::SqlPlan;
use crate::error::{DataFusionSnafu, Result};
use crate::postgres::types::*;
use crate::postgres::utils::convert_err;
use crate::postgres::{PostgresServerHandlerInner, fixtures};
use crate::query_handler::sql::ServerSqlQueryHandlerRef;

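// Simple query protocol handler: the query text is executed as-is (after
// compatibility rewriting) and results are returned in text format.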
#[async_trait]
impl SimpleQueryHandler for PostgresServerHandlerInner {
    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let query_ctx = self.session.new_query_context();
        let db = query_ctx.get_db_string();
        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
            .start_timer();

        if query.is_empty() {
            // early return if query is empty
            return Ok(vec![Response::EmptyQuery]);
        }

        let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
            statements
                .iter()
                .map(|s| s.to_string())
                .collect::<Vec<_>>()
                .join(";")
        } else {
            query.to_string()
        };

        if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
            send_warning_opt(client, query_ctx).await?;
            Ok(resps)
        } else {
            let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;

            let mut results = Vec::with_capacity(outputs.len());

            for output in outputs {
                let resp =
                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
                results.push(resp);
            }

            send_warning_opt(client, query_ctx).await?;
            Ok(results)
        }
    }
}

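/// Sends the query context's warning, if any, to the client as a PostgreSQL
/// `NoticeResponse` with severity `WARNING` and SQLSTATE `01000`.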
async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
where
    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
    C::Error: Debug,
    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
    if let Some(warning) = query_context.warning() {
        client
            .feed(PgWireBackendMessage::NoticeResponse(
                ErrorInfo::new(
                    PgErrorSeverity::Warning.to_string(),
                    PgErrorCode::Ec01000.code(),
                    warning.clone(),
                )
                .into(),
            ))
            .await?;
    }

    Ok(())
}

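/// Converts a query [`Output`] into a pgwire [`Response`]: affected-row counts
/// become an `OK` command tag, while streams and record batches are turned into
/// row-returning query responses.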
pub(crate) fn output_to_query_response(
    query_ctx: QueryContextRef,
    output: Result<Output>,
    field_format: &Format,
) -> PgWireResult<Response> {
    match output {
        Ok(o) => match o.data {
            OutputData::AffectedRows(rows) => {
                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
            }
            OutputData::Stream(record_stream) => {
                let schema = record_stream.schema();
                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
            }
            OutputData::RecordBatches(recordbatches) => {
                let schema = recordbatches.schema();
                recordbatches_to_query_response(
                    query_ctx,
                    recordbatches.as_stream(),
                    schema,
                    field_format,
                )
            }
        },
        Err(e) => Err(convert_err(e)),
    }
}

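/// Builds a row-streaming [`QueryResponse`] from a record batch stream,
/// converting the schema via `schema_to_pg` and flattening each batch into
/// pgwire data rows; stream errors are forwarded to the client.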
fn recordbatches_to_query_response<S>(
    query_ctx: QueryContextRef,
    recordbatches_stream: S,
    schema: SchemaRef,
    field_format: &Format,
) -> PgWireResult<Response>
where
    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
{
    let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
    let pg_schema_ref = pg_schema.clone();
    let data_row_stream = recordbatches_stream
        .map(move |result| match result {
            Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
                query_ctx.clone(),
                pg_schema_ref.clone(),
                record_batch,
            ))
            .boxed(),
            Err(e) => stream::once(future::err(convert_err(e))).boxed(),
        })
        .flatten();

    Ok(Response::Query(QueryResponse::new(
        pg_schema,
        data_row_stream,
    )))
}

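/// Query parser for the extended query protocol. Statements are rewritten by
/// the Postgres compatibility parser, parsed with GreptimeDB's SQL parser, and
/// pre-planned via `do_describe` so parameter and result types can be reported
/// back to the client.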
pub struct DefaultQueryParser {
    query_handler: ServerSqlQueryHandlerRef,
    session: Arc<Session>,
    compatibility_parser: PostgresCompatibilityParser,
}

impl DefaultQueryParser {
    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
        DefaultQueryParser {
            query_handler,
            session,
            compatibility_parser: PostgresCompatibilityParser::new(),
        }
    }
}

#[async_trait]
impl QueryParser for DefaultQueryParser {
    type Statement = SqlPlan;

    async fn parse_sql<C>(
        &self,
        _client: &C,
        sql: &str,
        _types: &[Option<Type>],
    ) -> PgWireResult<Self::Statement> {
        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
        let query_ctx = self.session.new_query_context();

        // do not parse if query is empty or matches rules
        if sql.is_empty() || fixtures::matches(sql) {
            return Ok(SqlPlan {
                query: sql.to_owned(),
                statement: None,
                plan: None,
                schema: None,
            });
        }

        let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
            statements.remove(0).to_string()
        } else {
            // bypass the error: parsing can fail here because of different
            // versions of sqlparser
            sql.to_string()
        };

        let mut stmts = ParserContext::create_with_dialect(
            &sql,
            &PostgreSqlDialect {},
            ParseOptions::default(),
        )
        .map_err(convert_err)?;
        if stmts.len() != 1 {
            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
                PgErrorCode::Ec42P14,
            ))))
        } else {
            let stmt = stmts.remove(0);

            let describe_result = self
                .query_handler
                .do_describe(stmt.clone(), query_ctx)
                .await
                .map_err(convert_err)?;

            let (plan, schema) = if let Some(DescribeResult {
                logical_plan,
                schema,
            }) = describe_result
            {
                (Some(logical_plan), Some(schema))
            } else {
                (None, None)
            };

            Ok(SqlPlan {
                query: sql.clone(),
                statement: Some(stmt),
                plan,
                schema,
            })
        }
    }
}

#[async_trait]
impl ExtendedQueryHandler for PostgresServerHandlerInner {
    type Statement = SqlPlan;
    type QueryParser = DefaultQueryParser;

    fn query_parser(&self) -> Arc<Self::QueryParser> {
        self.query_parser.clone()
    }

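    // Executes a bound portal: parameters are substituted into the prepared
    // logical plan when one exists, otherwise `$n` placeholders are replaced
    // textually in the SQL before execution.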
    async fn do_query<C>(
        &self,
        client: &mut C,
        portal: &Portal<Self::Statement>,
        _max_rows: usize,
    ) -> PgWireResult<Response>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let query_ctx = self.session.new_query_context();
        let db = query_ctx.get_db_string();
        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
            .start_timer();

        let sql_plan = &portal.statement.statement;

        if sql_plan.query.is_empty() {
            // early return if query is empty
            return Ok(Response::EmptyQuery);
        }

        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
            send_warning_opt(client, query_ctx).await?;
            // if the statement matches our predefined rules, return it early
            return Ok(resps.remove(0));
        }

        let output = if let Some(plan) = &sql_plan.plan {
            let plan = plan
                .clone()
                .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
                    plan, portal,
                )?))
                .context(DataFusionSnafu)
                .map_err(convert_err)?;
            self.query_handler
                .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
                .await
        } else {
            // manually replace variables in the prepared statement when no
            // logical plan was generated. This happens when a logical plan is
            // not supported for certain statements.
            let mut sql = sql_plan.query.clone();
            for i in 0..portal.parameter_len() {
                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
            }

            self.query_handler
                .do_query(&sql, query_ctx.clone())
                .await
                .remove(0)
        };

        send_warning_opt(client, query_ctx.clone()).await?;
        output_to_query_response(query_ctx, output, &portal.result_column_format)
    }

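    // Describes a prepared statement: client-declared parameter types are
    // merged with types inferred from the logical plan, and the result schema
    // is reported when one is known.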
    async fn do_describe_statement<C>(
        &self,
        _client: &mut C,
        stmt: &StoredStatement<Self::Statement>,
    ) -> PgWireResult<DescribeStatementResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let sql_plan = &stmt.statement;
        // parameter types provided by the client; can be empty if the client
        // doesn't declare them when parsing the statement
        let provided_param_types = &stmt.parameter_types;
        let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
            let param_types = plan
                .get_parameter_types()
                .context(DataFusionSnafu)
                .map_err(convert_err)?
                .into_iter()
                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
                .collect();

            let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;

            Some(types)
        } else {
            None
        };

        let param_count = if provided_param_types.is_empty() {
            server_inferenced_types
                .as_ref()
                .map(|types| types.len())
                .unwrap_or(0)
        } else {
            provided_param_types.len()
        };

        let param_types = (0..param_count)
            .map(|i| {
                let client_type = provided_param_types.get(i);
                // use the server-inferred type when the client-provided type is
                // None (oid 0 or other invalid values)
                match client_type {
                    Some(Some(client_type)) => client_type.clone(),
                    _ => server_inferenced_types
                        .as_ref()
                        .and_then(|types| types.get(i).cloned())
                        .unwrap_or(Type::UNKNOWN),
                }
            })
            .collect::<Vec<_>>();

        if let Some(schema) = &sql_plan.schema {
            schema_to_pg(schema, &Format::UnifiedBinary)
                .map(|fields| DescribeStatementResponse::new(param_types, fields))
                .map_err(convert_err)
        } else {
            if let Some(mut resp) =
                fixtures::process(&sql_plan.query, self.session.new_query_context())
                && let Response::Query(query_response) = resp.remove(0)
            {
                return Ok(DescribeStatementResponse::new(
                    param_types,
                    (*query_response.row_schema()).clone(),
                ));
            }

            Ok(DescribeStatementResponse::new(param_types, vec![]))
        }
    }

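    // Describes a bound portal using the portal's result column format.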
    async fn do_describe_portal<C>(
        &self,
        _client: &mut C,
        portal: &Portal<Self::Statement>,
    ) -> PgWireResult<DescribePortalResponse>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let sql_plan = &portal.statement.statement;
        let format = &portal.result_column_format;

        if let Some(schema) = &sql_plan.schema {
            schema_to_pg(schema, format)
                .map(DescribePortalResponse::new)
                .map_err(convert_err)
        } else {
            if let Some(mut resp) =
                fixtures::process(&sql_plan.query, self.session.new_query_context())
                && let Response::Query(query_response) = resp.remove(0)
            {
                return Ok(DescribePortalResponse::new(
                    (*query_response.row_schema()).clone(),
                ));
            }

            Ok(DescribePortalResponse::new(vec![]))
        }
    }
}

impl ErrorHandler for PostgresServerHandlerInner {
    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
    where
        C: ClientInfo,
    {
        debug!("Postgres interface error {}", error)
    }
}