servers/postgres/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31    DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43use sql::statements::statement::Statement;
44
45use crate::SqlPlan;
46use crate::error::{DataFusionSnafu, Result};
47use crate::postgres::types::*;
48use crate::postgres::utils::convert_err;
49use crate::postgres::{PostgresServerHandlerInner, fixtures};
50use crate::query_handler::sql::ServerSqlQueryHandlerRef;
51
52#[async_trait]
53impl SimpleQueryHandler for PostgresServerHandlerInner {
54    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
55    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
56    where
57        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
58        C::Error: Debug,
59        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
60    {
61        let query_ctx = self.session.new_query_context();
62        let db = query_ctx.get_db_string();
63        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
64            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
65            .start_timer();
66
67        if query.is_empty() {
68            // early return if query is empty
69            return Ok(vec![Response::EmptyQuery]);
70        }
71
72        let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
73            statements
74                .iter()
75                .map(|s| s.to_string())
76                .collect::<Vec<_>>()
77                .join(";")
78        } else {
79            query.to_string()
80        };
81
82        if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
83            send_warning_opt(client, query_ctx).await?;
84            Ok(resps)
85        } else {
86            let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
87
88            let mut results = Vec::with_capacity(outputs.len());
89
90            for output in outputs {
91                let resp =
92                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
93                results.push(resp);
94            }
95
96            send_warning_opt(client, query_ctx).await?;
97            Ok(results)
98        }
99    }
100}
101
102async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
103where
104    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
105    C::Error: Debug,
106    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
107{
108    if let Some(warning) = query_context.warning() {
109        client
110            .feed(PgWireBackendMessage::NoticeResponse(
111                ErrorInfo::new(
112                    PgErrorSeverity::Warning.to_string(),
113                    PgErrorCode::Ec01000.code(),
114                    warning.clone(),
115                )
116                .into(),
117            ))
118            .await?;
119    }
120
121    Ok(())
122}
123
124pub(crate) fn output_to_query_response(
125    query_ctx: QueryContextRef,
126    output: Result<Output>,
127    field_format: &Format,
128) -> PgWireResult<Response> {
129    match output {
130        Ok(o) => match o.data {
131            OutputData::AffectedRows(rows) => {
132                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
133            }
134            OutputData::Stream(record_stream) => {
135                let schema = record_stream.schema();
136                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
137            }
138            OutputData::RecordBatches(recordbatches) => {
139                let schema = recordbatches.schema();
140                recordbatches_to_query_response(
141                    query_ctx,
142                    recordbatches.as_stream(),
143                    schema,
144                    field_format,
145                )
146            }
147        },
148        Err(e) => Err(convert_err(e)),
149    }
150}
151
152fn recordbatches_to_query_response<S>(
153    query_ctx: QueryContextRef,
154    recordbatches_stream: S,
155    schema: SchemaRef,
156    field_format: &Format,
157) -> PgWireResult<Response>
158where
159    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
160{
161    let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
162    let pg_schema_ref = pg_schema.clone();
163    let data_row_stream = recordbatches_stream
164        .map(move |result| match result {
165            Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
166                query_ctx.clone(),
167                pg_schema_ref.clone(),
168                record_batch,
169            ))
170            .boxed(),
171            Err(e) => stream::once(future::err(convert_err(e))).boxed(),
172        })
173        .flatten();
174
175    Ok(Response::Query(QueryResponse::new(
176        pg_schema,
177        data_row_stream,
178    )))
179}
180
181pub struct DefaultQueryParser {
182    query_handler: ServerSqlQueryHandlerRef,
183    session: Arc<Session>,
184    compatibility_parser: PostgresCompatibilityParser,
185}
186
187impl DefaultQueryParser {
188    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
189        DefaultQueryParser {
190            query_handler,
191            session,
192            compatibility_parser: PostgresCompatibilityParser::new(),
193        }
194    }
195}
196
197#[async_trait]
198impl QueryParser for DefaultQueryParser {
199    type Statement = SqlPlan;
200
201    async fn parse_sql<C>(
202        &self,
203        _client: &C,
204        sql: &str,
205        _types: &[Option<Type>],
206    ) -> PgWireResult<Self::Statement> {
207        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
208        let query_ctx = self.session.new_query_context();
209
210        // do not parse if query is empty or matches rules
211        if sql.is_empty() || fixtures::matches(sql) {
212            return Ok(SqlPlan {
213                query: sql.to_owned(),
214                statement: None,
215                plan: None,
216                schema: None,
217            });
218        }
219
220        let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
221            statements.remove(0).to_string()
222        } else {
223            // bypass the error: it can run into error because of different
224            // versions of sqlparser
225            sql.to_string()
226        };
227
228        let mut stmts = ParserContext::create_with_dialect(
229            &sql,
230            &PostgreSqlDialect {},
231            ParseOptions::default(),
232        )
233        .map_err(convert_err)?;
234        if stmts.len() != 1 {
235            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
236                PgErrorCode::Ec42P14,
237            ))))
238        } else {
239            let stmt = stmts.remove(0);
240
241            let describe_result = self
242                .query_handler
243                .do_describe(stmt.clone(), query_ctx)
244                .await
245                .map_err(convert_err)?;
246
247            let (plan, schema) = if let Some(DescribeResult {
248                logical_plan,
249                schema,
250            }) = describe_result
251            {
252                (Some(logical_plan), Some(schema))
253            } else {
254                (None, None)
255            };
256
257            Ok(SqlPlan {
258                query: sql.clone(),
259                statement: Some(stmt),
260                plan,
261                schema,
262            })
263        }
264    }
265}
266
267#[async_trait]
268impl ExtendedQueryHandler for PostgresServerHandlerInner {
269    type Statement = SqlPlan;
270    type QueryParser = DefaultQueryParser;
271
272    fn query_parser(&self) -> Arc<Self::QueryParser> {
273        self.query_parser.clone()
274    }
275
276    async fn do_query<C>(
277        &self,
278        client: &mut C,
279        portal: &Portal<Self::Statement>,
280        _max_rows: usize,
281    ) -> PgWireResult<Response>
282    where
283        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
284        C::Error: Debug,
285        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
286    {
287        let query_ctx = self.session.new_query_context();
288        let db = query_ctx.get_db_string();
289        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
290            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
291            .start_timer();
292
293        let sql_plan = &portal.statement.statement;
294
295        if sql_plan.query.is_empty() {
296            // early return if query is empty
297            return Ok(Response::EmptyQuery);
298        }
299
300        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
301            send_warning_opt(client, query_ctx).await?;
302            // if the statement matches our predefined rules, return it early
303            return Ok(resps.remove(0));
304        }
305
306        let output = if let Some(plan) = &sql_plan.plan {
307            let plan = plan
308                .clone()
309                .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
310                    plan, portal,
311                )?))
312                .context(DataFusionSnafu)
313                .map_err(convert_err)?;
314            self.query_handler
315                .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
316                .await
317        } else {
318            // manually replace variables in prepared statement when no
319            // logical_plan is generated. This happens when logical plan is not
320            // supported for certain statements.
321            let mut sql = sql_plan.query.clone();
322            for i in 0..portal.parameter_len() {
323                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
324            }
325
326            self.query_handler
327                .do_query(&sql, query_ctx.clone())
328                .await
329                .remove(0)
330        };
331
332        send_warning_opt(client, query_ctx.clone()).await?;
333        output_to_query_response(query_ctx, output, &portal.result_column_format)
334    }
335
336    async fn do_describe_statement<C>(
337        &self,
338        _client: &mut C,
339        stmt: &StoredStatement<Self::Statement>,
340    ) -> PgWireResult<DescribeStatementResponse>
341    where
342        C: ClientInfo + Unpin + Send + Sync,
343    {
344        let sql_plan = &stmt.statement;
345        // client provided parameter types, can be empty if client doesn't try to parse statement
346        let provided_param_types = &stmt.parameter_types;
347        let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
348            let param_types = plan
349                .get_parameter_types()
350                .context(DataFusionSnafu)
351                .map_err(convert_err)?
352                .into_iter()
353                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
354                .collect();
355
356            let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;
357
358            Some(types)
359        } else {
360            None
361        };
362
363        let param_count = if provided_param_types.is_empty() {
364            server_inferenced_types
365                .as_ref()
366                .map(|types| types.len())
367                .unwrap_or(0)
368        } else {
369            provided_param_types.len()
370        };
371
372        let param_types = (0..param_count)
373            .map(|i| {
374                let client_type = provided_param_types.get(i);
375                // use server type when client provided type is None (oid: 0 or other invalid values)
376                match client_type {
377                    Some(Some(client_type)) => client_type.clone(),
378                    _ => server_inferenced_types
379                        .as_ref()
380                        .and_then(|types| types.get(i).cloned())
381                        .unwrap_or(Type::UNKNOWN),
382                }
383            })
384            .collect::<Vec<_>>();
385
386        if let Some(schema) = &sql_plan.schema {
387            schema_to_pg(schema, &Format::UnifiedBinary)
388                .map(|fields| DescribeStatementResponse::new(param_types, fields))
389                .map_err(convert_err)
390        } else {
391            if let Some(mut resp) =
392                fixtures::process(&sql_plan.query, self.session.new_query_context())
393                && let Response::Query(query_response) = resp.remove(0)
394            {
395                return Ok(DescribeStatementResponse::new(
396                    param_types,
397                    (*query_response.row_schema()).clone(),
398                ));
399            }
400
401            Ok(DescribeStatementResponse::new(param_types, vec![]))
402        }
403    }
404
405    async fn do_describe_portal<C>(
406        &self,
407        _client: &mut C,
408        portal: &Portal<Self::Statement>,
409    ) -> PgWireResult<DescribePortalResponse>
410    where
411        C: ClientInfo + Unpin + Send + Sync,
412    {
413        let sql_plan = &portal.statement.statement;
414        let format = &portal.result_column_format;
415
416        match sql_plan.statement.as_ref() {
417            Some(Statement::Query(_)) => {
418                // if the query has a schema, it is managed by datafusion, use the schema
419                if let Some(schema) = &sql_plan.schema {
420                    schema_to_pg(schema, format)
421                        .map(DescribePortalResponse::new)
422                        .map_err(convert_err)
423                } else {
424                    // fallback to NoData
425                    Ok(DescribePortalResponse::new(vec![]))
426                }
427            }
428            // We can cover only part of show statements
429            // these show create statements will return 2 columns
430            Some(Statement::ShowCreateDatabase(_))
431            | Some(Statement::ShowCreateTable(_))
432            | Some(Statement::ShowCreateFlow(_))
433            | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![
434                FieldInfo::new(
435                    "name".to_string(),
436                    None,
437                    None,
438                    Type::TEXT,
439                    format.format_for(0),
440                ),
441                FieldInfo::new(
442                    "create_statement".to_string(),
443                    None,
444                    None,
445                    Type::TEXT,
446                    format.format_for(1),
447                ),
448            ])),
449            // single column show statements
450            Some(Statement::ShowTables(_))
451            | Some(Statement::ShowFlows(_))
452            | Some(Statement::ShowViews(_)) => {
453                Ok(DescribePortalResponse::new(vec![FieldInfo::new(
454                    "name".to_string(),
455                    None,
456                    None,
457                    Type::TEXT,
458                    format.format_for(0),
459                )]))
460            }
461            // we will not support other show statements for extended query protocol at least for now.
462            // because the return columns is not predictable at this stage
463            _ => {
464                // test if query caught by fixture
465                if let Some(mut resp) =
466                    fixtures::process(&sql_plan.query, self.session.new_query_context())
467                    && let Response::Query(query_response) = resp.remove(0)
468                {
469                    Ok(DescribePortalResponse::new(
470                        (*query_response.row_schema()).clone(),
471                    ))
472                } else {
473                    // fallback to NoData
474                    Ok(DescribePortalResponse::new(vec![]))
475                }
476            }
477        }
478    }
479}
480
481impl ErrorHandler for PostgresServerHandlerInner {
482    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
483    where
484        C: ClientInfo,
485    {
486        debug!("Postgres interface error {}", error)
487    }
488}