Parse search_path option

This commit is contained in:
Bojan Serafimov
2022-03-07 18:50:52 -05:00
parent a6ace609a7
commit 1c40c26313
6 changed files with 66 additions and 21 deletions

View File

@@ -7,11 +7,13 @@ use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
// TODO rename the struct to ClientParams or something
/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
pub options: Option<String>,
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
@@ -25,9 +27,22 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
};
let user = get_param("user")?;
let db = get_param("database")?;
let dbname = get_param("database")?;
Ok(Self { user, dbname: db })
// TODO see what other options should be recognized, possibly all.
let options = match get_param("search_path") {
Ok(path) => Some(format!("-c search_path={}", path)),
Err(_) => None,
};
// TODO investigate why "" is always a key
// TODO warn on unrecognized options?
Ok(Self {
user,
dbname,
options,
})
}
}
@@ -85,6 +100,7 @@ async fn handle_static(
dbname: creds.dbname.clone(),
user: creds.user.clone(),
password: Some(cleartext_password.into()),
options: creds.options,
};
client
@@ -117,15 +133,22 @@ async fn handle_existing_user(
.ok_or_else(|| anyhow!("unexpected password message"))?;
let cplane = CPlaneApi::new(&config.auth_endpoint);
let db_info = cplane
.authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id)
let db_info_response = cplane
.authenticate_proxy_request(&creds, md5_response, &md5_salt, &psql_session_id)
.await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
Ok(DatabaseInfo {
host: db_info_response.host,
port: db_info_response.port,
dbname: db_info_response.dbname,
user: db_info_response.user,
password: db_info_response.password,
options: creds.options,
})
}
async fn handle_new_user(
@@ -135,7 +158,7 @@ async fn handle_new_user(
let psql_session_id = new_psql_session_id();
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async {
let db_info_response = cplane_api::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database
client
.write_message_noflush(&Be::AuthenticationOk)?
@@ -150,7 +173,14 @@ async fn handle_new_user(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
Ok(db_info)
Ok(DatabaseInfo {
host: db_info_response.host,
port: db_info_response.port,
dbname: db_info_response.dbname,
user: db_info_response.user,
password: db_info_response.password,
options: None,
})
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {

View File

@@ -10,6 +10,7 @@ pub struct DatabaseInfo {
pub dbname: String,
pub user: String,
pub password: Option<String>,
pub options: Option<String>,
}
impl DatabaseInfo {
@@ -33,6 +34,10 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(options) = db_info.options {
config.options(&options);
}
if let Some(password) = db_info.password {
config.password(password);
}

View File

@@ -1,25 +1,37 @@
use crate::auth::ClientCredentials;
use crate::compute::DatabaseInfo;
use crate::waiters::{Waiter, Waiters};
use anyhow::{anyhow, bail};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
/// Part of the legacy cplane responses
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfoResponse {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
lazy_static! {
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default();
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfoResponse, String>> = Default::default();
}
/// Give caller an opportunity to wait for cplane's reply.
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
where
F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R,
F: FnOnce(Waiter<'static, Result<DatabaseInfoResponse, String>>) -> R,
R: std::future::Future<Output = anyhow::Result<T>>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
f(waiter).await
}
pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> {
pub fn notify(
psql_session_id: &str,
msg: Result<DatabaseInfoResponse, String>,
) -> anyhow::Result<()> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
@@ -37,11 +49,11 @@ impl<'a> CPlaneApi<'a> {
impl CPlaneApi<'_> {
pub async fn authenticate_proxy_request(
&self,
creds: ClientCredentials,
creds: &ClientCredentials,
md5_response: &[u8],
salt: &[u8; 4],
psql_session_id: &str,
) -> anyhow::Result<DatabaseInfo> {
) -> anyhow::Result<DatabaseInfoResponse> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
url.query_pairs_mut()
.append_pair("login", &creds.user)
@@ -77,7 +89,7 @@ impl CPlaneApi<'_> {
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {
Ready { conn_info: DatabaseInfo },
Ready { conn_info: DatabaseInfoResponse },
Error { error: String },
NotReady { ready: bool }, // TODO: get rid of `ready`
}
@@ -92,13 +104,13 @@ mod tests {
// Ready
let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": true,
"conn_info": DatabaseInfo::default(),
"conn_info": DatabaseInfoResponse::default(),
}))
.unwrap();
assert!(matches!(
auth,
ProxyAuthResponse::Ready {
conn_info: DatabaseInfo { .. }
conn_info: DatabaseInfoResponse { .. }
}
));

View File

@@ -1,4 +1,4 @@
use crate::{compute::DatabaseInfo, cplane_api};
use crate::cplane_api;
use anyhow::Context;
use serde::Deserialize;
use std::{
@@ -75,7 +75,7 @@ struct PsqlSessionResponse {
#[derive(Deserialize)]
enum PsqlSessionResult {
Success(DatabaseInfo),
Success(cplane_api::DatabaseInfoResponse),
Failure(String),
}

View File

@@ -1,4 +1,4 @@
use crate::auth;
use crate::auth::{self, ClientCredentials};
use crate::cancellation::{self, CancelClosure, CancelMap};
use crate::compute::DatabaseInfo;
use crate::config::{ProxyConfig, TlsConfig};
@@ -138,7 +138,6 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream.write_message(&Be::ErrorResponse(msg)).await?;
bail!(msg);
}
break Ok(Some((stream, params.try_into()?)));
}
CancelRequest(cancel_key_data) => {

View File

@@ -5,7 +5,6 @@ def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;")
@pytest.mark.xfail # Proxy eats the extra connection options
def test_proxy_options(static_proxy):
schema_name = "tmp_schema_1"
with static_proxy.connect(schema=schema_name) as conn: