1use std::fmt::Debug;
16use std::pin::Pin;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use common_query::{Output, OutputData};
21use common_recordbatch::RecordBatch;
22use common_recordbatch::error::Result as RecordBatchResult;
23use common_telemetry::{debug, info, tracing};
24use datafusion::sql::sqlparser::ast::{CopyOption, CopyTarget, Statement as SqlParserStatement};
25use datafusion_common::ParamValues;
26use datafusion_expr::LogicalPlan;
27use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
28use datatypes::prelude::ConcreteDataType;
29use datatypes::schema::{Schema, SchemaRef};
30use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
31use pgwire::api::portal::{Format, Portal};
32use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
33use pgwire::api::results::{
34 CopyCsvOptions, CopyEncoder, CopyResponse, CopyTextOptions, DataRowEncoder,
35 DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
36};
37use pgwire::api::stmt::{QueryParser, StoredStatement};
38use pgwire::api::{ClientInfo, ErrorHandler, Type};
39use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
40use pgwire::messages::PgWireBackendMessage;
41use pgwire::messages::copy::CopyData;
42use pgwire::messages::data::DataRow;
43use query::planner::DfLogicalPlanner;
44use query::query_engine::DescribeResult;
45use session::Session;
46use session::context::QueryContextRef;
47use snafu::ResultExt;
48use sql::dialect::PostgreSqlDialect;
49use sql::parser::{ParseOptions, ParserContext};
50use sql::statements::statement::Statement;
51
52use crate::SqlPlan;
53use crate::error::{DataFusionSnafu, InferParameterTypesSnafu, Result};
54use crate::postgres::types::*;
55use crate::postgres::utils::convert_err;
56use crate::postgres::{PostgresServerHandlerInner, fixtures};
57use crate::query_handler::sql::ServerSqlQueryHandlerRef;
58
59#[async_trait]
60impl SimpleQueryHandler for PostgresServerHandlerInner {
61 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
62 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
63 where
64 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
65 C::Error: Debug,
66 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
67 {
68 let query_ctx = self.session.new_query_context();
69 let db = query_ctx.get_db_string();
70 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
71 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
72 .start_timer();
73
74 if query.is_empty() {
75 return Ok(vec![Response::EmptyQuery]);
77 }
78
79 let parsed_query = self.query_parser.compatibility_parser.parse(query);
80
81 let query = if let Ok(statements) = &parsed_query {
82 statements
83 .iter()
84 .map(|s| s.to_string())
85 .collect::<Vec<_>>()
86 .join(";")
87 } else {
88 query.to_string()
89 };
90
91 if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
92 send_warning_opt(client, query_ctx).await?;
93 Ok(resps)
94 } else {
95 let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
96
97 let mut results = Vec::with_capacity(outputs.len());
98
99 let statements = parsed_query.ok();
100 for (idx, output) in outputs.into_iter().enumerate() {
101 let copy_format = statements
102 .as_ref()
103 .and_then(|stmts| stmts.get(idx))
104 .and_then(check_copy_to_stdout);
105 let resp = if let Some(format) = ©_format {
106 output_to_copy_response(query_ctx.clone(), output, format)?
107 } else {
108 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?
109 };
110 results.push(resp);
111 }
112
113 send_warning_opt(client, query_ctx).await?;
114 Ok(results)
115 }
116 }
117}
118
119async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
120where
121 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
122 C::Error: Debug,
123 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
124{
125 if let Some(warning) = query_context.warning() {
126 client
127 .feed(PgWireBackendMessage::NoticeResponse(
128 ErrorInfo::new(
129 PgErrorSeverity::Warning.to_string(),
130 PgErrorCode::Ec01000.code(),
131 warning.clone(),
132 )
133 .into(),
134 ))
135 .await?;
136 }
137
138 Ok(())
139}
140
141pub(crate) fn output_to_query_response(
142 query_ctx: QueryContextRef,
143 output: Result<Output>,
144 field_format: &Format,
145) -> PgWireResult<Response> {
146 match output {
147 Ok(o) => match o.data {
148 OutputData::AffectedRows(rows) => {
149 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
150 }
151 OutputData::Stream(record_stream) => {
152 let schema = record_stream.schema();
153 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
154 }
155 OutputData::RecordBatches(recordbatches) => {
156 let schema = recordbatches.schema();
157 recordbatches_to_query_response(
158 query_ctx,
159 recordbatches.as_stream(),
160 schema,
161 field_format,
162 )
163 }
164 },
165 Err(e) => Err(convert_err(e)),
166 }
167}
168
169type RowStream<T> = Pin<Box<dyn Stream<Item = PgWireResult<T>> + Send + Unpin>>;
170
171fn recordbatches_to_query_response<S>(
172 query_ctx: QueryContextRef,
173 recordbatches_stream: S,
174 schema: SchemaRef,
175 field_format: &Format,
176) -> PgWireResult<Response>
177where
178 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
179{
180 let format_options = format_options_from_query_ctx(&query_ctx);
181 let pg_schema = Arc::new(
182 schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?,
183 );
184
185 let encoder = DataRowEncoder::new(pg_schema.clone());
186 let row_stream = RecordBatchRowStream::new(
187 query_ctx.clone(),
188 pg_schema.clone(),
189 schema.clone(),
190 recordbatches_stream,
191 encoder,
192 );
193
194 let data_row_stream: RowStream<DataRow> = Box::pin(
195 row_stream
196 .map(move |result| match result {
197 Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<DataRow>,
198 Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<DataRow>,
199 })
200 .flatten(),
201 );
202
203 Ok(Response::Query(QueryResponse::new(
204 pg_schema,
205 data_row_stream,
206 )))
207}
208
209pub(crate) fn output_to_copy_response(
210 query_ctx: QueryContextRef,
211 output: Result<Output>,
212 format: &str,
213) -> PgWireResult<Response> {
214 match output {
215 Ok(o) => match o.data {
216 OutputData::AffectedRows(_) => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
217 "ERROR".to_string(),
218 "42601".to_string(),
219 "COPY cannot be used with non-query statements".to_string(),
220 )))),
221 OutputData::Stream(record_stream) => {
222 let schema = record_stream.schema();
223 recordbatches_to_copy_response(query_ctx, record_stream, schema, format)
224 }
225 OutputData::RecordBatches(recordbatches) => {
226 let schema = recordbatches.schema();
227 recordbatches_to_copy_response(query_ctx, recordbatches.as_stream(), schema, format)
228 }
229 },
230 Err(e) => Err(convert_err(e)),
231 }
232}
233
234fn recordbatches_to_copy_response<S>(
235 query_ctx: QueryContextRef,
236 recordbatches_stream: S,
237 schema: SchemaRef,
238 format: &str,
239) -> PgWireResult<Response>
240where
241 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
242{
243 let format_options = format_options_from_query_ctx(&query_ctx);
244 let pg_fields = schema_to_pg(schema.as_ref(), &Format::UnifiedText, Some(format_options))
245 .map_err(convert_err)?;
246
247 let copy_format = match format.to_lowercase().as_str() {
248 "binary" => 1,
249 _ => 0,
250 };
251
252 let pg_schema = Arc::new(pg_fields);
253 let num_columns = pg_schema.len();
254
255 let copy_encoder = match format.to_lowercase().as_str() {
256 "csv" => CopyEncoder::new_csv(pg_schema.clone(), CopyCsvOptions::default()),
257 "binary" => CopyEncoder::new_binary(pg_schema.clone()),
258 _ => CopyEncoder::new_text(pg_schema.clone(), CopyTextOptions::default()),
259 };
260
261 let row_stream = RecordBatchRowStream::new(
262 query_ctx.clone(),
263 pg_schema.clone(),
264 schema.clone(),
265 recordbatches_stream,
266 copy_encoder,
267 );
268
269 let copy_stream: RowStream<CopyData> = Box::pin(
270 row_stream
271 .map(move |result| match result {
272 Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<CopyData>,
273 Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<CopyData>,
274 })
275 .flatten(),
276 );
277
278 Ok(Response::CopyOut(CopyResponse::new(
279 copy_format,
280 num_columns,
281 copy_stream,
282 )))
283}
284
285pub struct DefaultQueryParser {
286 query_handler: ServerSqlQueryHandlerRef,
287 session: Arc<Session>,
288 compatibility_parser: PostgresCompatibilityParser,
289}
290
291impl DefaultQueryParser {
292 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
293 DefaultQueryParser {
294 query_handler,
295 session,
296 compatibility_parser: PostgresCompatibilityParser::new(),
297 }
298 }
299}
300
301#[derive(Clone, Debug)]
303pub struct PgSqlPlan {
304 pub(crate) plan: SqlPlan,
305 pub(crate) copy_to_stdout_format: Option<String>,
306}
307
308#[async_trait]
309impl QueryParser for DefaultQueryParser {
310 type Statement = PgSqlPlan;
311
312 async fn parse_sql<C>(
313 &self,
314 _client: &C,
315 sql: &str,
316 _types: &[Option<Type>],
317 ) -> PgWireResult<Self::Statement> {
318 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
319 let query_ctx = self.session.new_query_context();
320
321 if sql.is_empty() {
323 return Ok(PgSqlPlan {
324 plan: SqlPlan::Empty,
325 copy_to_stdout_format: None,
326 });
327 }
328
329 if fixtures::matches(sql) {
330 return Ok(PgSqlPlan {
331 plan: SqlPlan::Shortcut(sql.to_string()),
332 copy_to_stdout_format: None,
333 });
334 }
335
336 let parsed_statements = self.compatibility_parser.parse(sql);
337 let (sql, copy_to_stdout_format) = if let Ok(mut statements) = parsed_statements {
338 let first_stmt = statements.remove(0);
339 let format = check_copy_to_stdout(&first_stmt);
340 (first_stmt.to_string(), format)
341 } else {
342 (sql.to_string(), None)
345 };
346
347 let mut stmts = ParserContext::create_with_dialect(
348 &sql,
349 &PostgreSqlDialect {},
350 ParseOptions::default(),
351 )
352 .map_err(convert_err)?;
353 if stmts.len() != 1 {
354 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
355 PgErrorCode::Ec42P14,
356 ))))
357 } else {
358 let stmt = stmts.remove(0);
359
360 if let Some(logical_plan) = self
361 .query_handler
362 .do_describe(stmt.clone(), query_ctx)
363 .await
364 .map_err(convert_err)?
365 .map(|DescribeResult { logical_plan }| logical_plan)
366 {
367 Ok(PgSqlPlan {
368 plan: SqlPlan::Plan(logical_plan, stmt),
369 copy_to_stdout_format,
370 })
371 } else {
372 Ok(PgSqlPlan {
373 plan: SqlPlan::Statement(stmt, sql),
374 copy_to_stdout_format,
375 })
376 }
377 }
378 }
379
380 fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
381 Err(PgWireError::ApiError(
384 "get_parameter_types is not expected to be called".into(),
385 ))
386 }
387
388 fn get_result_schema(
389 &self,
390 _stmt: &Self::Statement,
391 _column_format: Option<&Format>,
392 ) -> PgWireResult<Vec<FieldInfo>> {
393 Err(PgWireError::ApiError(
396 "get_result_schema is not expected to be called".into(),
397 ))
398 }
399}
400
401#[async_trait]
402impl ExtendedQueryHandler for PostgresServerHandlerInner {
403 type Statement = PgSqlPlan;
404 type QueryParser = DefaultQueryParser;
405
406 fn query_parser(&self) -> Arc<Self::QueryParser> {
407 self.query_parser.clone()
408 }
409
410 async fn do_query<C>(
411 &self,
412 client: &mut C,
413 portal: &Portal<Self::Statement>,
414 _max_rows: usize,
415 ) -> PgWireResult<Response>
416 where
417 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
418 C::Error: Debug,
419 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
420 {
421 let query_ctx = self.session.new_query_context();
422 let db = query_ctx.get_db_string();
423 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
424 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
425 .start_timer();
426
427 let pg_sql_plan = &portal.statement.statement;
428 let sql_plan = &pg_sql_plan.plan;
429
430 let output = match sql_plan {
431 SqlPlan::Empty => {
432 return Ok(Response::EmptyQuery);
434 }
435 SqlPlan::Shortcut(query) => {
436 if let Some(mut resps) = fixtures::process(query, query_ctx.clone()) {
437 send_warning_opt(client, query_ctx).await?;
438 return Ok(resps.remove(0));
440 } else {
441 return Ok(Response::EmptyQuery);
443 }
444 }
445 SqlPlan::Plan(plan, stmt) => {
446 let values = parameters_to_scalar_values(plan, portal)?;
447 let plan = plan
448 .clone()
449 .replace_params_with_values(&ParamValues::List(
450 values.into_iter().map(Into::into).collect(),
451 ))
452 .context(DataFusionSnafu)
453 .map_err(convert_err)?;
454 self.query_handler
455 .do_exec_plan(plan, Some(stmt.clone()), query_ctx.clone())
456 .await
457 }
458 SqlPlan::Statement(_stmt, query) => {
459 self.query_handler
465 .do_query(query, query_ctx.clone())
466 .await
467 .remove(0)
468 }
469 };
470
471 send_warning_opt(client, query_ctx.clone()).await?;
472
473 if let Some(format) = &pg_sql_plan.copy_to_stdout_format {
474 output_to_copy_response(query_ctx, output, format)
475 } else {
476 output_to_query_response(query_ctx, output, &portal.result_column_format)
477 }
478 }
479
480 async fn do_describe_statement<C>(
481 &self,
482 _client: &mut C,
483 stmt: &StoredStatement<Self::Statement>,
484 ) -> PgWireResult<DescribeStatementResponse>
485 where
486 C: ClientInfo + Unpin + Send + Sync,
487 {
488 let sql_plan = &stmt.statement.plan;
489 let provided_param_types = &stmt.parameter_types;
491 let server_inferenced_types = if let SqlPlan::Plan(plan, _) = &sql_plan {
492 let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
493 .context(InferParameterTypesSnafu)
494 .map_err(convert_err)?
495 .into_iter()
496 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
497 .collect();
498
499 let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?;
500
501 Some(types)
502 } else {
503 None
504 };
505
506 let param_count = if provided_param_types.is_empty() {
507 server_inferenced_types
508 .as_ref()
509 .map(|types| types.len())
510 .unwrap_or(0)
511 } else {
512 provided_param_types.len()
513 };
514
515 let param_types = (0..param_count)
516 .map(|i| {
517 let client_type = provided_param_types.get(i);
518 match client_type {
520 Some(Some(client_type)) => client_type.clone(),
521 _ => server_inferenced_types
522 .as_ref()
523 .and_then(|types| types.get(i).cloned())
524 .unwrap_or(Type::UNKNOWN),
525 }
526 })
527 .collect::<Vec<_>>();
528
529 let fields = describe_fields(sql_plan, &Format::UnifiedText, &self.session)?;
530
531 Ok(DescribeStatementResponse::new(param_types, fields))
532 }
533
534 async fn do_describe_portal<C>(
535 &self,
536 _client: &mut C,
537 portal: &Portal<Self::Statement>,
538 ) -> PgWireResult<DescribePortalResponse>
539 where
540 C: ClientInfo + Unpin + Send + Sync,
541 {
542 let sql_plan = &portal.statement.statement.plan;
543 let format = &portal.result_column_format;
544
545 let fields = describe_fields(sql_plan, format, &self.session)?;
546
547 Ok(DescribePortalResponse::new(fields))
548 }
549}
550
551fn describe_fields(
552 sql_plan: &SqlPlan,
553 format: &Format,
554 session: &Arc<Session>,
555) -> PgWireResult<Vec<FieldInfo>> {
556 match sql_plan {
557 SqlPlan::Plan(plan, _) if !matches!(plan, LogicalPlan::Dml(_) | LogicalPlan::Ddl(_)) => {
559 let schema: Schema = plan.schema().clone().try_into().map_err(convert_err)?;
560 schema_to_pg(&schema, format, None).map_err(convert_err)
561 }
562 SqlPlan::Statement(
565 Statement::ShowCreateDatabase(_)
566 | Statement::ShowCreateTable(_)
567 | Statement::ShowCreateFlow(_)
568 | Statement::ShowCreateView(_),
569 _,
570 ) => Ok(vec![
571 FieldInfo::new(
572 "name".to_string(),
573 None,
574 None,
575 Type::TEXT,
576 format.format_for(0),
577 ),
578 FieldInfo::new(
579 "create_statement".to_string(),
580 None,
581 None,
582 Type::TEXT,
583 format.format_for(1),
584 ),
585 ]),
586 #[cfg(feature = "enterprise")]
587 SqlPlan::Statement(Statement::ShowCreateTrigger(_), _) => Ok(vec![
588 FieldInfo::new(
589 "name".to_string(),
590 None,
591 None,
592 Type::TEXT,
593 format.format_for(0),
594 ),
595 FieldInfo::new(
596 "create_statement".to_string(),
597 None,
598 None,
599 Type::TEXT,
600 format.format_for(1),
601 ),
602 ]),
603 SqlPlan::Statement(
605 Statement::ShowTables(_) | Statement::ShowFlows(_) | Statement::ShowViews(_),
606 _,
607 ) => Ok(vec![FieldInfo::new(
608 "name".to_string(),
609 None,
610 None,
611 Type::TEXT,
612 format.format_for(0),
613 )]),
614 #[cfg(feature = "enterprise")]
615 SqlPlan::Statement(Statement::ShowTriggers(_), _) => Ok(vec![FieldInfo::new(
616 "name".to_string(),
617 None,
618 None,
619 Type::TEXT,
620 format.format_for(0),
621 )]),
622 SqlPlan::Shortcut(query) => {
625 if let Some(mut resp) = fixtures::process(query, session.new_query_context())
627 && let Response::Query(query_response) = resp.remove(0)
628 {
629 Ok((*query_response.row_schema()).clone())
630 } else {
631 Ok(vec![])
633 }
634 }
635 _ => {
636 Ok(vec![])
638 }
639 }
640}
641
642impl ErrorHandler for PostgresServerHandlerInner {
643 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
644 where
645 C: ClientInfo,
646 {
647 match error {
648 PgWireError::IoError(e) => debug!("Postgres client disconnected: {}", e),
649 _ => info!("Postgres interface error: {}", error),
650 }
651 }
652}
653
654fn check_copy_to_stdout(statement: &SqlParserStatement) -> Option<String> {
655 if let SqlParserStatement::Copy {
656 target, options, ..
657 } = statement
658 && matches!(target, CopyTarget::Stdout)
659 {
660 for opt in options {
661 if let CopyOption::Format(format_ident) = opt {
662 return Some(format_ident.value.to_lowercase());
663 }
664 }
665 return Some("txt".to_string());
666 }
667
668 None
669}
670
#[cfg(test)]
mod tests {
    use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

    use super::*;

    /// Parses `sql` with the compatibility parser and returns its first statement.
    fn parse_copy_statement(sql: &str) -> SqlParserStatement {
        let parser = PostgresCompatibilityParser::new();
        let statements = parser.parse(sql).unwrap();
        statements.into_iter().next().unwrap()
    }

    /// Parses `sql` and reports the COPY-to-stdout format detected for it.
    fn copy_format_of(sql: &str) -> Option<String> {
        check_copy_to_stdout(&parse_copy_statement(sql))
    }

    #[test]
    fn test_check_copy_out_with_csv_format() {
        let format = copy_format_of("COPY (SELECT 1) TO STDOUT WITH (FORMAT CSV)");
        assert_eq!(format.as_deref(), Some("csv"));
    }

    #[test]
    fn test_check_copy_out_with_txt_format() {
        let format = copy_format_of("COPY (SELECT 1) TO STDOUT WITH (FORMAT TXT)");
        assert_eq!(format.as_deref(), Some("txt"));
    }

    #[test]
    fn test_check_copy_out_with_binary_format() {
        let format = copy_format_of("COPY (SELECT 1) TO STDOUT WITH (FORMAT BINARY)");
        assert_eq!(format.as_deref(), Some("binary"));
    }

    #[test]
    fn test_check_copy_out_without_format() {
        let format = copy_format_of("COPY (SELECT 1) TO STDOUT");
        assert_eq!(format.as_deref(), Some("txt"));
    }

    #[test]
    fn test_check_copy_out_to_file() {
        let format = copy_format_of("COPY (SELECT 1) TO '/path/to/file.csv' WITH (FORMAT CSV)");
        assert_eq!(format, None);
    }

    #[test]
    fn test_check_copy_out_case_insensitive() {
        let format = copy_format_of("COPY (SELECT 1) TO STDOUT WITH (FORMAT csv)");
        assert_eq!(format.as_deref(), Some("csv"));

        let format = copy_format_of("COPY (SELECT 1) TO STDOUT WITH (FORMAT binary)");
        assert_eq!(format.as_deref(), Some("binary"));
    }

    #[test]
    fn test_check_copy_out_with_multiple_options() {
        let format =
            copy_format_of("COPY (SELECT 1) TO STDOUT WITH (FORMAT csv, DELIMITER ',', HEADER)");
        assert_eq!(format.as_deref(), Some("csv"));

        let format =
            copy_format_of("COPY (SELECT 1) TO STDOUT WITH (DELIMITER ',', HEADER, FORMAT binary)");
        assert_eq!(format.as_deref(), Some("binary"));
    }
}