From 1c40c263139b2d576bf7a40cd6d25a6fb647c0aa Mon Sep 17 00:00:00 2001 From: Bojan Serafimov Date: Mon, 7 Mar 2022 18:50:52 -0500 Subject: [PATCH] Parse search_path option --- proxy/src/auth.rs | 44 ++++++++++++++++++++++---- proxy/src/compute.rs | 5 +++ proxy/src/cplane_api.rs | 30 ++++++++++++------ proxy/src/mgmt.rs | 4 +-- proxy/src/proxy.rs | 3 +- test_runner/batch_others/test_proxy.py | 1 - 6 files changed, 66 insertions(+), 21 deletions(-) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index a5bdaeaeca..f56222bfcf 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -7,11 +7,13 @@ use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncWrite}; use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe}; +// TODO rename the struct to ClientParams or something /// Various client credentials which we use for authentication. #[derive(Debug, PartialEq, Eq)] pub struct ClientCredentials { pub user: String, pub dbname: String, + pub options: Option, } impl TryFrom> for ClientCredentials { @@ -25,9 +27,22 @@ impl TryFrom> for ClientCredentials { }; let user = get_param("user")?; - let db = get_param("database")?; + let dbname = get_param("database")?; - Ok(Self { user, dbname: db }) + // TODO see what other options should be recognized, possibly all. + let options = match get_param("search_path") { + Ok(path) => Some(format!("-c search_path={}", path)), + Err(_) => None, + }; + + // TODO investigate why "" is always a key + // TODO warn on unrecognized options? + + Ok(Self { + user, + dbname, + options, + }) } } @@ -85,6 +100,7 @@ async fn handle_static( dbname: creds.dbname.clone(), user: creds.user.clone(), password: Some(cleartext_password.into()), + options: creds.options, }; client @@ -117,15 +133,22 @@ async fn handle_existing_user( .ok_or_else(|| anyhow!("unexpected password message"))?; let cplane = CPlaneApi::new(&config.auth_endpoint); - let db_info = cplane - .authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id) + let db_info_response = cplane + .authenticate_proxy_request(&creds, md5_response, &md5_salt, &psql_session_id) .await?; client .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&BeParameterStatusMessage::encoding())?; - Ok(db_info) + Ok(DatabaseInfo { + host: db_info_response.host, + port: db_info_response.port, + dbname: db_info_response.dbname, + user: db_info_response.user, + password: db_info_response.password, + options: creds.options, + }) } async fn handle_new_user( @@ -135,7 +158,7 @@ async fn handle_new_user( let psql_session_id = new_psql_session_id(); let greeting = hello_message(&config.redirect_uri, &psql_session_id); - let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async { + let db_info_response = cplane_api::with_waiter(psql_session_id, |waiter| async { // Give user a URL to spawn a new database client .write_message_noflush(&Be::AuthenticationOk)? @@ -150,7 +173,14 @@ async fn handle_new_user( client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?; - Ok(db_info) + Ok(DatabaseInfo { + host: db_info_response.host, + port: db_info_response.port, + dbname: db_info_response.dbname, + user: db_info_response.user, + password: db_info_response.password, + options: None, + }) } fn hello_message(redirect_uri: &str, session_id: &str) -> String { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 7c294bd488..bd7a58ad58 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -10,6 +10,7 @@ pub struct DatabaseInfo { pub dbname: String, pub user: String, pub password: Option, + pub options: Option, } impl DatabaseInfo { @@ -33,6 +34,10 @@ impl From for tokio_postgres::Config { .dbname(&db_info.dbname) .user(&db_info.user); + if let Some(options) = db_info.options { + config.options(&options); + } + if let Some(password) = db_info.password { config.password(password); } diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index 187809717f..719848515a 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -1,25 +1,37 @@ use crate::auth::ClientCredentials; -use crate::compute::DatabaseInfo; use crate::waiters::{Waiter, Waiters}; use anyhow::{anyhow, bail}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +/// Part of the legacy cplane responses +#[derive(Serialize, Deserialize, Debug, Default)] +pub struct DatabaseInfoResponse { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + pub password: Option, +} + lazy_static! { - static ref CPLANE_WAITERS: Waiters> = Default::default(); + static ref CPLANE_WAITERS: Waiters> = Default::default(); } /// Give caller an opportunity to wait for cplane's reply. pub async fn with_waiter(psql_session_id: impl Into, f: F) -> anyhow::Result where - F: FnOnce(Waiter<'static, Result>) -> R, + F: FnOnce(Waiter<'static, Result>) -> R, R: std::future::Future>, { let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; f(waiter).await } -pub fn notify(psql_session_id: &str, msg: Result) -> anyhow::Result<()> { +pub fn notify( + psql_session_id: &str, + msg: Result, +) -> anyhow::Result<()> { CPLANE_WAITERS.notify(psql_session_id, msg) } @@ -37,11 +49,11 @@ impl<'a> CPlaneApi<'a> { impl CPlaneApi<'_> { pub async fn authenticate_proxy_request( &self, - creds: ClientCredentials, + creds: &ClientCredentials, md5_response: &[u8], salt: &[u8; 4], psql_session_id: &str, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut url = reqwest::Url::parse(self.auth_endpoint)?; url.query_pairs_mut() .append_pair("login", &creds.user) @@ -77,7 +89,7 @@ impl CPlaneApi<'_> { #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] enum ProxyAuthResponse { - Ready { conn_info: DatabaseInfo }, + Ready { conn_info: DatabaseInfoResponse }, Error { error: String }, NotReady { ready: bool }, // TODO: get rid of `ready` } @@ -92,13 +104,13 @@ mod tests { // Ready let auth: ProxyAuthResponse = serde_json::from_value(json!({ "ready": true, - "conn_info": DatabaseInfo::default(), + "conn_info": DatabaseInfoResponse::default(), })) .unwrap(); assert!(matches!( auth, ProxyAuthResponse::Ready { - conn_info: DatabaseInfo { .. } + conn_info: DatabaseInfoResponse { .. } } )); diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 55b49b441f..d39abc4188 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -1,4 +1,4 @@ -use crate::{compute::DatabaseInfo, cplane_api}; +use crate::cplane_api; use anyhow::Context; use serde::Deserialize; use std::{ @@ -75,7 +75,7 @@ struct PsqlSessionResponse { #[derive(Deserialize)] enum PsqlSessionResult { - Success(DatabaseInfo), + Success(cplane_api::DatabaseInfoResponse), Failure(String), } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 1dc301b792..5be6a15b4e 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,4 +1,4 @@ -use crate::auth; +use crate::auth::{self, ClientCredentials}; use crate::cancellation::{self, CancelClosure, CancelMap}; use crate::compute::DatabaseInfo; use crate::config::{ProxyConfig, TlsConfig}; @@ -138,7 +138,6 @@ async fn handshake( stream.write_message(&Be::ErrorResponse(msg)).await?; bail!(msg); } - break Ok(Some((stream, params.try_into()?))); } CancelRequest(cancel_key_data) => { diff --git a/test_runner/batch_others/test_proxy.py b/test_runner/batch_others/test_proxy.py index d2039f9758..39802b25a7 100644 --- a/test_runner/batch_others/test_proxy.py +++ b/test_runner/batch_others/test_proxy.py @@ -5,7 +5,6 @@ def test_proxy_select_1(static_proxy): static_proxy.safe_psql("select 1;") -@pytest.mark.xfail # Proxy eats the extra connection options def test_proxy_options(static_proxy): schema_name = "tmp_schema_1" with static_proxy.connect(schema=schema_name) as conn: