From 05994656857273920182c78c079ae302dc6e4fb4 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Fri, 2 Dec 2022 11:49:39 +0800 Subject: [PATCH] feat: inject current database/schema into query context for postgres protocol (#685) * feat: inject current database/schema into query context * test: avoid duplicate server setup --- src/servers/src/postgres/handler.rs | 19 ++++++++-- src/servers/tests/postgres/mod.rs | 59 ++++++++++++++++++++--------- 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 66d7099522..6cf82465a0 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -42,14 +42,27 @@ impl PostgresServerHandler { } } +const CLIENT_METADATA_DATABASE: &str = "database"; + +fn query_context_from_client_info(client: &C) -> Arc +where + C: ClientInfo, +{ + let query_context = QueryContext::new(); + if let Some(current_schema) = client.metadata().get(CLIENT_METADATA_DATABASE) { + query_context.set_current_schema(current_schema); + } + + Arc::new(query_context) +} + #[async_trait] impl SimpleQueryHandler for PostgresServerHandler { - async fn do_query(&self, _client: &C, query: &str) -> PgWireResult> + async fn do_query(&self, client: &C, query: &str) -> PgWireResult> where C: ClientInfo + Unpin + Send + Sync, { - // TODO(LFC): Sessions in pg server. - let query_ctx = Arc::new(QueryContext::new()); + let query_ctx = query_context_from_client_info(client); let output = self .query_handler .do_query(query, query_ctx) diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 552891424f..8abc5ff760 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -16,6 +16,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, SystemTime}; +use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_runtime::Builder as RuntimeBuilder; use rand::rngs::StdRng; use rand::Rng; @@ -80,7 +81,6 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); - let postgres_server = create_postgres_server(table, with_pwd, Default::default())?; let result = postgres_server.shutdown().await; assert!(result @@ -136,14 +136,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn test_query_pg_concurrently() -> Result<()> { - common_telemetry::init_default_ut_logging(); - - let table = MemTable::default_numbers_table(); - - let pg_server = create_postgres_server(table, false, Default::default())?; - let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = pg_server.start(listening).await.unwrap(); - let server_port = server_addr.port(); + let server_port = start_test_server(Default::default()).await?; let threads = 4; let expect_executed_queries_per_worker = 300; @@ -211,13 +204,7 @@ async fn test_server_secure_require_client_plain() -> Result<()> { cert_path: "tests/ssl/server.crt".to_owned(), key_path: "tests/ssl/server.key".to_owned(), }); - - let table = MemTable::default_numbers_table(); - let pg_server = create_postgres_server(table, false, server_tls)?; - let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = pg_server.start(listening).await.unwrap(); - let server_port = server_addr.port(); - + let server_port = start_test_server(server_tls).await?; let r = create_plain_connection(server_port, false).await; assert!(r.is_err()); Ok(()) @@ -238,12 +225,35 @@ async fn test_server_secure_require_client_secure() -> Result<()> { Ok(()) } -async fn do_simple_query(server_tls: Arc, client_tls: bool) -> Result<()> { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_using_db() -> Result<()> { + let server_port = start_test_server(Arc::new(TlsOption::default())).await?; + + let client = create_connection_with_given_db(server_port, "testdb") + .await + .unwrap(); + let result = client.simple_query("SELECT uint32s FROM numbers").await; + assert!(result.is_err()); + + let client = create_connection_with_given_db(server_port, DEFAULT_SCHEMA_NAME) + .await + .unwrap(); + let result = client.simple_query("SELECT uint32s FROM numbers").await; + assert!(result.is_ok()); + Ok(()) +} + +async fn start_test_server(server_tls: Arc) -> Result { + common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); let pg_server = create_postgres_server(table, false, server_tls)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = pg_server.start(listening).await.unwrap(); - let server_port = server_addr.port(); + Ok(server_addr.port()) +} + +async fn do_simple_query(server_tls: Arc, client_tls: bool) -> Result<()> { + let server_port = start_test_server(server_tls).await?; if !client_tls { let client = create_plain_connection(server_port, false).await.unwrap(); @@ -303,6 +313,19 @@ async fn create_plain_connection( Ok(client) } +async fn create_connection_with_given_db( + port: u16, + db: &str, +) -> std::result::Result { + let url = format!( + "host=127.0.0.1 port={} connect_timeout=2 dbname={}", + port, db + ); + let (client, conn) = tokio_postgres::connect(&url, NoTls).await?; + tokio::spawn(conn); + Ok(client) +} + fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> { match resp { &SimpleQueryMessage::Row(ref r) => r.get(col_index),