Skip to main content

servers/postgres/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::pin::Pin;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use common_query::{Output, OutputData};
21use common_recordbatch::RecordBatch;
22use common_recordbatch::error::Result as RecordBatchResult;
23use common_telemetry::{debug, info, tracing};
24use datafusion::sql::sqlparser::ast::{CopyOption, CopyTarget, Statement as SqlParserStatement};
25use datafusion_common::ParamValues;
26use datafusion_expr::LogicalPlan;
27use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
28use datatypes::prelude::ConcreteDataType;
29use datatypes::schema::{Schema, SchemaRef};
30use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
31use pgwire::api::portal::{Format, Portal};
32use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
33use pgwire::api::results::{
34    CopyCsvOptions, CopyEncoder, CopyResponse, CopyTextOptions, DataRowEncoder,
35    DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
36};
37use pgwire::api::stmt::{QueryParser, StoredStatement};
38use pgwire::api::{ClientInfo, ErrorHandler, Type};
39use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
40use pgwire::messages::PgWireBackendMessage;
41use pgwire::messages::copy::CopyData;
42use pgwire::messages::data::DataRow;
43use query::planner::DfLogicalPlanner;
44use query::query_engine::DescribeResult;
45use session::Session;
46use session::context::QueryContextRef;
47use snafu::ResultExt;
48use sql::dialect::PostgreSqlDialect;
49use sql::parser::{ParseOptions, ParserContext};
50use sql::statements::statement::Statement;
51
52use crate::SqlPlan;
53use crate::error::{DataFusionSnafu, InferParameterTypesSnafu, Result};
54use crate::postgres::types::*;
55use crate::postgres::utils::convert_err;
56use crate::postgres::{PostgresServerHandlerInner, fixtures};
57use crate::query_handler::sql::ServerSqlQueryHandlerRef;
58
59#[async_trait]
60impl SimpleQueryHandler for PostgresServerHandlerInner {
61    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
62    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
63    where
64        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
65        C::Error: Debug,
66        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
67    {
68        let query_ctx = self.session.new_query_context();
69        let db = query_ctx.get_db_string();
70        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
71            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
72            .start_timer();
73
74        if query.is_empty() {
75            // early return if query is empty
76            return Ok(vec![Response::EmptyQuery]);
77        }
78
79        let parsed_query = self.query_parser.compatibility_parser.parse(query);
80
81        let query = if let Ok(statements) = &parsed_query {
82            statements
83                .iter()
84                .map(|s| s.to_string())
85                .collect::<Vec<_>>()
86                .join(";")
87        } else {
88            query.to_string()
89        };
90
91        if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
92            send_warning_opt(client, query_ctx).await?;
93            Ok(resps)
94        } else {
95            let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
96
97            let mut results = Vec::with_capacity(outputs.len());
98
99            let statements = parsed_query.ok();
100            for (idx, output) in outputs.into_iter().enumerate() {
101                let copy_format = statements
102                    .as_ref()
103                    .and_then(|stmts| stmts.get(idx))
104                    .and_then(check_copy_to_stdout);
105                let resp = if let Some(format) = &copy_format {
106                    output_to_copy_response(query_ctx.clone(), output, format)?
107                } else {
108                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?
109                };
110                results.push(resp);
111            }
112
113            send_warning_opt(client, query_ctx).await?;
114            Ok(results)
115        }
116    }
117}
118
119async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
120where
121    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
122    C::Error: Debug,
123    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
124{
125    if let Some(warning) = query_context.warning() {
126        client
127            .feed(PgWireBackendMessage::NoticeResponse(
128                ErrorInfo::new(
129                    PgErrorSeverity::Warning.to_string(),
130                    PgErrorCode::Ec01000.code(),
131                    warning.clone(),
132                )
133                .into(),
134            ))
135            .await?;
136    }
137
138    Ok(())
139}
140
141pub(crate) fn output_to_query_response(
142    query_ctx: QueryContextRef,
143    output: Result<Output>,
144    field_format: &Format,
145) -> PgWireResult<Response> {
146    match output {
147        Ok(o) => match o.data {
148            OutputData::AffectedRows(rows) => {
149                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
150            }
151            OutputData::Stream(record_stream) => {
152                let schema = record_stream.schema();
153                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
154            }
155            OutputData::RecordBatches(recordbatches) => {
156                let schema = recordbatches.schema();
157                recordbatches_to_query_response(
158                    query_ctx,
159                    recordbatches.as_stream(),
160                    schema,
161                    field_format,
162                )
163            }
164        },
165        Err(e) => Err(convert_err(e)),
166    }
167}
168
/// A pinned, boxed, `Send + Unpin` stream of fallible wire-protocol items.
///
/// Used in this file with `DataRow` (query responses) and `CopyData`
/// (COPY-out responses) as the item type.
type RowStream<T> = Pin<Box<dyn Stream<Item = PgWireResult<T>> + Send + Unpin>>;
170
171fn recordbatches_to_query_response<S>(
172    query_ctx: QueryContextRef,
173    recordbatches_stream: S,
174    schema: SchemaRef,
175    field_format: &Format,
176) -> PgWireResult<Response>
177where
178    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
179{
180    let format_options = format_options_from_query_ctx(&query_ctx);
181    let pg_schema = Arc::new(
182        schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?,
183    );
184
185    let encoder = DataRowEncoder::new(pg_schema.clone());
186    let row_stream = RecordBatchRowStream::new(
187        query_ctx.clone(),
188        pg_schema.clone(),
189        schema.clone(),
190        recordbatches_stream,
191        encoder,
192    );
193
194    let data_row_stream: RowStream<DataRow> = Box::pin(
195        row_stream
196            .map(move |result| match result {
197                Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<DataRow>,
198                Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<DataRow>,
199            })
200            .flatten(),
201    );
202
203    Ok(Response::Query(QueryResponse::new(
204        pg_schema,
205        data_row_stream,
206    )))
207}
208
209pub(crate) fn output_to_copy_response(
210    query_ctx: QueryContextRef,
211    output: Result<Output>,
212    format: &str,
213) -> PgWireResult<Response> {
214    match output {
215        Ok(o) => match o.data {
216            OutputData::AffectedRows(_) => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
217                "ERROR".to_string(),
218                "42601".to_string(),
219                "COPY cannot be used with non-query statements".to_string(),
220            )))),
221            OutputData::Stream(record_stream) => {
222                let schema = record_stream.schema();
223                recordbatches_to_copy_response(query_ctx, record_stream, schema, format)
224            }
225            OutputData::RecordBatches(recordbatches) => {
226                let schema = recordbatches.schema();
227                recordbatches_to_copy_response(query_ctx, recordbatches.as_stream(), schema, format)
228            }
229        },
230        Err(e) => Err(convert_err(e)),
231    }
232}
233
234fn recordbatches_to_copy_response<S>(
235    query_ctx: QueryContextRef,
236    recordbatches_stream: S,
237    schema: SchemaRef,
238    format: &str,
239) -> PgWireResult<Response>
240where
241    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
242{
243    let format_options = format_options_from_query_ctx(&query_ctx);
244    let pg_fields = schema_to_pg(schema.as_ref(), &Format::UnifiedText, Some(format_options))
245        .map_err(convert_err)?;
246
247    let copy_format = match format.to_lowercase().as_str() {
248        "binary" => 1,
249        _ => 0,
250    };
251
252    let pg_schema = Arc::new(pg_fields);
253    let num_columns = pg_schema.len();
254
255    let copy_encoder = match format.to_lowercase().as_str() {
256        "csv" => CopyEncoder::new_csv(pg_schema.clone(), CopyCsvOptions::default()),
257        "binary" => CopyEncoder::new_binary(pg_schema.clone()),
258        _ => CopyEncoder::new_text(pg_schema.clone(), CopyTextOptions::default()),
259    };
260
261    let row_stream = RecordBatchRowStream::new(
262        query_ctx.clone(),
263        pg_schema.clone(),
264        schema.clone(),
265        recordbatches_stream,
266        copy_encoder,
267    );
268
269    let copy_stream: RowStream<CopyData> = Box::pin(
270        row_stream
271            .map(move |result| match result {
272                Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<CopyData>,
273                Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<CopyData>,
274            })
275            .flatten(),
276    );
277
278    Ok(Response::CopyOut(CopyResponse::new(
279        copy_format,
280        num_columns,
281        copy_stream,
282    )))
283}
284
285pub struct DefaultQueryParser {
286    query_handler: ServerSqlQueryHandlerRef,
287    session: Arc<Session>,
288    compatibility_parser: PostgresCompatibilityParser,
289}
290
291impl DefaultQueryParser {
292    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
293        DefaultQueryParser {
294            query_handler,
295            session,
296            compatibility_parser: PostgresCompatibilityParser::new(),
297        }
298    }
299}
300
/// A container type of parse result types.
///
/// This is the statement type stored by the extended query protocol.
#[derive(Clone, Debug)]
pub struct PgSqlPlan {
    /// The parsed statement: empty, a fixture shortcut, a logical plan, or a
    /// raw statement.
    pub(crate) plan: SqlPlan,
    /// The COPY format name when the statement was recognized as
    /// `COPY ... TO STDOUT`; `None` for every other statement.
    pub(crate) copy_to_stdout_format: Option<String>,
}
307
308#[async_trait]
309impl QueryParser for DefaultQueryParser {
310    type Statement = PgSqlPlan;
311
312    async fn parse_sql<C>(
313        &self,
314        _client: &C,
315        sql: &str,
316        _types: &[Option<Type>],
317    ) -> PgWireResult<Self::Statement> {
318        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
319        let query_ctx = self.session.new_query_context();
320
321        // do not parse if query is empty or matches rules
322        if sql.is_empty() {
323            return Ok(PgSqlPlan {
324                plan: SqlPlan::Empty,
325                copy_to_stdout_format: None,
326            });
327        }
328
329        if fixtures::matches(sql) {
330            return Ok(PgSqlPlan {
331                plan: SqlPlan::Shortcut(sql.to_string()),
332                copy_to_stdout_format: None,
333            });
334        }
335
336        let parsed_statements = self.compatibility_parser.parse(sql);
337        let (sql, copy_to_stdout_format) = if let Ok(mut statements) = parsed_statements {
338            let first_stmt = statements.remove(0);
339            let format = check_copy_to_stdout(&first_stmt);
340            (first_stmt.to_string(), format)
341        } else {
342            // bypass the error: it can run into error because of different
343            // versions of sqlparser
344            (sql.to_string(), None)
345        };
346
347        let mut stmts = ParserContext::create_with_dialect(
348            &sql,
349            &PostgreSqlDialect {},
350            ParseOptions::default(),
351        )
352        .map_err(convert_err)?;
353        if stmts.len() != 1 {
354            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
355                PgErrorCode::Ec42P14,
356            ))))
357        } else {
358            let stmt = stmts.remove(0);
359
360            if let Some(logical_plan) = self
361                .query_handler
362                .do_describe(stmt.clone(), query_ctx)
363                .await
364                .map_err(convert_err)?
365                .map(|DescribeResult { logical_plan }| logical_plan)
366            {
367                Ok(PgSqlPlan {
368                    plan: SqlPlan::Plan(logical_plan, stmt),
369                    copy_to_stdout_format,
370                })
371            } else {
372                Ok(PgSqlPlan {
373                    plan: SqlPlan::Statement(stmt, sql),
374                    copy_to_stdout_format,
375                })
376            }
377        }
378    }
379
380    fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
381        // we have our own implementation of describes in ExtendedQueryHandler
382        // so we don't use these methods
383        Err(PgWireError::ApiError(
384            "get_parameter_types is not expected to be called".into(),
385        ))
386    }
387
388    fn get_result_schema(
389        &self,
390        _stmt: &Self::Statement,
391        _column_format: Option<&Format>,
392    ) -> PgWireResult<Vec<FieldInfo>> {
393        // we have our own implementation of describes in ExtendedQueryHandler
394        // so we don't use these methods
395        Err(PgWireError::ApiError(
396            "get_result_schema is not expected to be called".into(),
397        ))
398    }
399}
400
401#[async_trait]
402impl ExtendedQueryHandler for PostgresServerHandlerInner {
403    type Statement = PgSqlPlan;
404    type QueryParser = DefaultQueryParser;
405
406    fn query_parser(&self) -> Arc<Self::QueryParser> {
407        self.query_parser.clone()
408    }
409
410    async fn do_query<C>(
411        &self,
412        client: &mut C,
413        portal: &Portal<Self::Statement>,
414        _max_rows: usize,
415    ) -> PgWireResult<Response>
416    where
417        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
418        C::Error: Debug,
419        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
420    {
421        let query_ctx = self.session.new_query_context();
422        let db = query_ctx.get_db_string();
423        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
424            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
425            .start_timer();
426
427        let pg_sql_plan = &portal.statement.statement;
428        let sql_plan = &pg_sql_plan.plan;
429
430        let output = match sql_plan {
431            SqlPlan::Empty => {
432                // early return if query is empty
433                return Ok(Response::EmptyQuery);
434            }
435            SqlPlan::Shortcut(query) => {
436                if let Some(mut resps) = fixtures::process(query, query_ctx.clone()) {
437                    send_warning_opt(client, query_ctx).await?;
438                    // if the statement matches our predefined rules, return it early
439                    return Ok(resps.remove(0));
440                } else {
441                    // unreachable logic
442                    return Ok(Response::EmptyQuery);
443                }
444            }
445            SqlPlan::Plan(plan, stmt) => {
446                let values = parameters_to_scalar_values(plan, portal)?;
447                let plan = plan
448                    .clone()
449                    .replace_params_with_values(&ParamValues::List(
450                        values.into_iter().map(Into::into).collect(),
451                    ))
452                    .context(DataFusionSnafu)
453                    .map_err(convert_err)?;
454                self.query_handler
455                    .do_exec_plan(plan, Some(stmt.clone()), query_ctx.clone())
456                    .await
457            }
458            SqlPlan::Statement(_stmt, query) => {
459                // We won't replace params from statement manually any more.
460                // Newer version of datafusion can generate plan for SELECT/INSERT/UPDATE/DELETE.
461                // Only CREATE TABLE and others minor statements cannot generate sql plan,
462                // in this case, we assume these statements will not carry parameters
463                // and execute them directly.
464                self.query_handler
465                    .do_query(query, query_ctx.clone())
466                    .await
467                    .remove(0)
468            }
469        };
470
471        send_warning_opt(client, query_ctx.clone()).await?;
472
473        if let Some(format) = &pg_sql_plan.copy_to_stdout_format {
474            output_to_copy_response(query_ctx, output, format)
475        } else {
476            output_to_query_response(query_ctx, output, &portal.result_column_format)
477        }
478    }
479
480    async fn do_describe_statement<C>(
481        &self,
482        _client: &mut C,
483        stmt: &StoredStatement<Self::Statement>,
484    ) -> PgWireResult<DescribeStatementResponse>
485    where
486        C: ClientInfo + Unpin + Send + Sync,
487    {
488        let sql_plan = &stmt.statement.plan;
489        // client provided parameter types, can be empty if client doesn't try to parse statement
490        let provided_param_types = &stmt.parameter_types;
491        let server_inferenced_types = if let SqlPlan::Plan(plan, _) = &sql_plan {
492            let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
493                .context(InferParameterTypesSnafu)
494                .map_err(convert_err)?
495                .into_iter()
496                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
497                .collect();
498
499            let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;
500
501            Some(types)
502        } else {
503            None
504        };
505
506        let param_count = if provided_param_types.is_empty() {
507            server_inferenced_types
508                .as_ref()
509                .map(|types| types.len())
510                .unwrap_or(0)
511        } else {
512            provided_param_types.len()
513        };
514
515        let param_types = (0..param_count)
516            .map(|i| {
517                let client_type = provided_param_types.get(i);
518                // use server type when client provided type is None (oid: 0 or other invalid values)
519                match client_type {
520                    Some(Some(client_type)) => client_type.clone(),
521                    _ => server_inferenced_types
522                        .as_ref()
523                        .and_then(|types| types.get(i).cloned())
524                        .unwrap_or(Type::UNKNOWN),
525                }
526            })
527            .collect::<Vec<_>>();
528
529        let fields = describe_fields(sql_plan, &Format::UnifiedText, &self.session)?;
530
531        Ok(DescribeStatementResponse::new(param_types, fields))
532    }
533
534    async fn do_describe_portal<C>(
535        &self,
536        _client: &mut C,
537        portal: &Portal<Self::Statement>,
538    ) -> PgWireResult<DescribePortalResponse>
539    where
540        C: ClientInfo + Unpin + Send + Sync,
541    {
542        let sql_plan = &portal.statement.statement.plan;
543        let format = &portal.result_column_format;
544
545        let fields = describe_fields(sql_plan, format, &self.session)?;
546
547        Ok(DescribePortalResponse::new(fields))
548    }
549}
550
551fn describe_fields(
552    sql_plan: &SqlPlan,
553    format: &Format,
554    session: &Arc<Session>,
555) -> PgWireResult<Vec<FieldInfo>> {
556    match sql_plan {
557        // query
558        SqlPlan::Plan(plan, _) if !matches!(plan, LogicalPlan::Dml(_) | LogicalPlan::Ddl(_)) => {
559            let schema: Schema = plan.schema().clone().try_into().map_err(convert_err)?;
560            schema_to_pg(&schema, format, None).map_err(convert_err)
561        }
562        // We can cover only part of show statements
563        // these show create statements will return 2 columns
564        SqlPlan::Statement(
565            Statement::ShowCreateDatabase(_)
566            | Statement::ShowCreateTable(_)
567            | Statement::ShowCreateFlow(_)
568            | Statement::ShowCreateView(_),
569            _,
570        ) => Ok(vec![
571            FieldInfo::new(
572                "name".to_string(),
573                None,
574                None,
575                Type::TEXT,
576                format.format_for(0),
577            ),
578            FieldInfo::new(
579                "create_statement".to_string(),
580                None,
581                None,
582                Type::TEXT,
583                format.format_for(1),
584            ),
585        ]),
586        #[cfg(feature = "enterprise")]
587        SqlPlan::Statement(Statement::ShowCreateTrigger(_), _) => Ok(vec![
588            FieldInfo::new(
589                "name".to_string(),
590                None,
591                None,
592                Type::TEXT,
593                format.format_for(0),
594            ),
595            FieldInfo::new(
596                "create_statement".to_string(),
597                None,
598                None,
599                Type::TEXT,
600                format.format_for(1),
601            ),
602        ]),
603        // single column show statements
604        SqlPlan::Statement(
605            Statement::ShowTables(_) | Statement::ShowFlows(_) | Statement::ShowViews(_),
606            _,
607        ) => Ok(vec![FieldInfo::new(
608            "name".to_string(),
609            None,
610            None,
611            Type::TEXT,
612            format.format_for(0),
613        )]),
614        #[cfg(feature = "enterprise")]
615        SqlPlan::Statement(Statement::ShowTriggers(_), _) => Ok(vec![FieldInfo::new(
616            "name".to_string(),
617            None,
618            None,
619            Type::TEXT,
620            format.format_for(0),
621        )]),
622        // we will not support other show statements for extended query protocol at least for now.
623        // because the return columns is not predictable at this stage
624        SqlPlan::Shortcut(query) => {
625            // test if query caught by fixture
626            if let Some(mut resp) = fixtures::process(query, session.new_query_context())
627                && let Response::Query(query_response) = resp.remove(0)
628            {
629                Ok((*query_response.row_schema()).clone())
630            } else {
631                // fallback to NoData
632                Ok(vec![])
633            }
634        }
635        _ => {
636            // NoData
637            Ok(vec![])
638        }
639    }
640}
641
642impl ErrorHandler for PostgresServerHandlerInner {
643    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
644    where
645        C: ClientInfo,
646    {
647        match error {
648            PgWireError::IoError(e) => debug!("Postgres client disconnected: {}", e),
649            _ => info!("Postgres interface error: {}", error),
650        }
651    }
652}
653
654fn check_copy_to_stdout(statement: &SqlParserStatement) -> Option<String> {
655    if let SqlParserStatement::Copy {
656        target, options, ..
657    } = statement
658        && matches!(target, CopyTarget::Stdout)
659    {
660        for opt in options {
661            if let CopyOption::Format(format_ident) = opt {
662                return Some(format_ident.value.to_lowercase());
663            }
664        }
665        return Some("txt".to_string());
666    }
667
668    None
669}
670
#[cfg(test)]
mod tests {
    use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

    use super::*;

    /// Parses `sql` with the compatibility parser and returns its first statement.
    fn parse_copy_statement(sql: &str) -> SqlParserStatement {
        PostgresCompatibilityParser::new()
            .parse(sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Runs `check_copy_to_stdout` on the first statement parsed from `sql`.
    fn detected_format(sql: &str) -> Option<String> {
        check_copy_to_stdout(&parse_copy_statement(sql))
    }

    #[test]
    fn test_check_copy_out_with_csv_format() {
        let format = detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT CSV)");
        assert_eq!(format.as_deref(), Some("csv"));
    }

    #[test]
    fn test_check_copy_out_with_txt_format() {
        let format = detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT TXT)");
        assert_eq!(format.as_deref(), Some("txt"));
    }

    #[test]
    fn test_check_copy_out_with_binary_format() {
        let format = detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT BINARY)");
        assert_eq!(format.as_deref(), Some("binary"));
    }

    #[test]
    fn test_check_copy_out_without_format() {
        // No FORMAT option: the default format name is reported.
        let format = detected_format("COPY (SELECT 1) TO STDOUT");
        assert_eq!(format.as_deref(), Some("txt"));
    }

    #[test]
    fn test_check_copy_out_to_file() {
        // A file target is not a COPY-to-stdout, regardless of the format.
        let format = detected_format("COPY (SELECT 1) TO '/path/to/file.csv' WITH (FORMAT CSV)");
        assert_eq!(format, None);
    }

    #[test]
    fn test_check_copy_out_case_insensitive() {
        let cases = [
            ("COPY (SELECT 1) TO STDOUT WITH (FORMAT csv)", "csv"),
            ("COPY (SELECT 1) TO STDOUT WITH (FORMAT binary)", "binary"),
        ];
        for (sql, expected) in cases {
            assert_eq!(detected_format(sql).as_deref(), Some(expected));
        }
    }

    #[test]
    fn test_check_copy_out_with_multiple_options() {
        // The FORMAT option is found whether it comes first or last.
        let cases = [
            (
                "COPY (SELECT 1) TO STDOUT WITH (FORMAT csv, DELIMITER ',', HEADER)",
                "csv",
            ),
            (
                "COPY (SELECT 1) TO STDOUT WITH (DELIMITER ',', HEADER, FORMAT binary)",
                "binary",
            ),
        ];
        for (sql, expected) in cases {
            assert_eq!(detected_format(sql).as_deref(), Some(expected));
        }
    }
}
735}