Compare commits

...

3 Commits

Author SHA1 Message Date
Bojan Serafimov
1c40c26313 Parse search_path option 2022-03-07 18:50:52 -05:00
Bojan Serafimov
a6ace609a7 Fix typo 2022-03-07 17:56:12 -05:00
Bojan Serafimov
29d72e8955 Add proxy test 2022-03-07 14:32:24 -05:00
7 changed files with 99 additions and 28 deletions

View File

@@ -7,11 +7,13 @@ use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe}; use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
// TODO rename the struct to ClientParams or something
/// Various client credentials which we use for authentication. /// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials { pub struct ClientCredentials {
pub user: String, pub user: String,
pub dbname: String, pub dbname: String,
pub options: Option<String>,
} }
impl TryFrom<HashMap<String, String>> for ClientCredentials { impl TryFrom<HashMap<String, String>> for ClientCredentials {
@@ -25,9 +27,22 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
}; };
let user = get_param("user")?; let user = get_param("user")?;
let db = get_param("database")?; let dbname = get_param("database")?;
Ok(Self { user, dbname: db }) // TODO see what other options should be recognized, possibly all.
let options = match get_param("search_path") {
Ok(path) => Some(format!("-c search_path={}", path)),
Err(_) => None,
};
// TODO investigate why "" is always a key
// TODO warn on unrecognized options?
Ok(Self {
user,
dbname,
options,
})
} }
} }
@@ -85,6 +100,7 @@ async fn handle_static(
dbname: creds.dbname.clone(), dbname: creds.dbname.clone(),
user: creds.user.clone(), user: creds.user.clone(),
password: Some(cleartext_password.into()), password: Some(cleartext_password.into()),
options: creds.options,
}; };
client client
@@ -117,15 +133,22 @@ async fn handle_existing_user(
.ok_or_else(|| anyhow!("unexpected password message"))?; .ok_or_else(|| anyhow!("unexpected password message"))?;
let cplane = CPlaneApi::new(&config.auth_endpoint); let cplane = CPlaneApi::new(&config.auth_endpoint);
let db_info = cplane let db_info_response = cplane
.authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id) .authenticate_proxy_request(&creds, md5_response, &md5_salt, &psql_session_id)
.await?; .await?;
client client
.write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?; .write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info) Ok(DatabaseInfo {
host: db_info_response.host,
port: db_info_response.port,
dbname: db_info_response.dbname,
user: db_info_response.user,
password: db_info_response.password,
options: creds.options,
})
} }
async fn handle_new_user( async fn handle_new_user(
@@ -135,7 +158,7 @@ async fn handle_new_user(
let psql_session_id = new_psql_session_id(); let psql_session_id = new_psql_session_id();
let greeting = hello_message(&config.redirect_uri, &psql_session_id); let greeting = hello_message(&config.redirect_uri, &psql_session_id);
let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async { let db_info_response = cplane_api::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database // Give user a URL to spawn a new database
client client
.write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&Be::AuthenticationOk)?
@@ -150,7 +173,14 @@ async fn handle_new_user(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?; client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
Ok(db_info) Ok(DatabaseInfo {
host: db_info_response.host,
port: db_info_response.port,
dbname: db_info_response.dbname,
user: db_info_response.user,
password: db_info_response.password,
options: None,
})
} }
fn hello_message(redirect_uri: &str, session_id: &str) -> String { fn hello_message(redirect_uri: &str, session_id: &str) -> String {

View File

@@ -10,6 +10,7 @@ pub struct DatabaseInfo {
pub dbname: String, pub dbname: String,
pub user: String, pub user: String,
pub password: Option<String>, pub password: Option<String>,
pub options: Option<String>,
} }
impl DatabaseInfo { impl DatabaseInfo {
@@ -33,6 +34,10 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
.dbname(&db_info.dbname) .dbname(&db_info.dbname)
.user(&db_info.user); .user(&db_info.user);
if let Some(options) = db_info.options {
config.options(&options);
}
if let Some(password) = db_info.password { if let Some(password) = db_info.password {
config.password(password); config.password(password);
} }

View File

@@ -1,25 +1,37 @@
use crate::auth::ClientCredentials; use crate::auth::ClientCredentials;
use crate::compute::DatabaseInfo;
use crate::waiters::{Waiter, Waiters}; use crate::waiters::{Waiter, Waiters};
use anyhow::{anyhow, bail}; use anyhow::{anyhow, bail};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Part of the legacy cplane responses
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfoResponse {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
lazy_static! { lazy_static! {
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default(); static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfoResponse, String>> = Default::default();
} }
/// Give caller an opportunity to wait for cplane's reply. /// Give caller an opportunity to wait for cplane's reply.
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T> pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
where where
F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R, F: FnOnce(Waiter<'static, Result<DatabaseInfoResponse, String>>) -> R,
R: std::future::Future<Output = anyhow::Result<T>>, R: std::future::Future<Output = anyhow::Result<T>>,
{ {
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
f(waiter).await f(waiter).await
} }
pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> { pub fn notify(
psql_session_id: &str,
msg: Result<DatabaseInfoResponse, String>,
) -> anyhow::Result<()> {
CPLANE_WAITERS.notify(psql_session_id, msg) CPLANE_WAITERS.notify(psql_session_id, msg)
} }
@@ -37,11 +49,11 @@ impl<'a> CPlaneApi<'a> {
impl CPlaneApi<'_> { impl CPlaneApi<'_> {
pub async fn authenticate_proxy_request( pub async fn authenticate_proxy_request(
&self, &self,
creds: ClientCredentials, creds: &ClientCredentials,
md5_response: &[u8], md5_response: &[u8],
salt: &[u8; 4], salt: &[u8; 4],
psql_session_id: &str, psql_session_id: &str,
) -> anyhow::Result<DatabaseInfo> { ) -> anyhow::Result<DatabaseInfoResponse> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?; let mut url = reqwest::Url::parse(self.auth_endpoint)?;
url.query_pairs_mut() url.query_pairs_mut()
.append_pair("login", &creds.user) .append_pair("login", &creds.user)
@@ -77,7 +89,7 @@ impl CPlaneApi<'_> {
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)] #[serde(untagged)]
enum ProxyAuthResponse { enum ProxyAuthResponse {
Ready { conn_info: DatabaseInfo }, Ready { conn_info: DatabaseInfoResponse },
Error { error: String }, Error { error: String },
NotReady { ready: bool }, // TODO: get rid of `ready` NotReady { ready: bool }, // TODO: get rid of `ready`
} }
@@ -92,13 +104,13 @@ mod tests {
// Ready // Ready
let auth: ProxyAuthResponse = serde_json::from_value(json!({ let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": true, "ready": true,
"conn_info": DatabaseInfo::default(), "conn_info": DatabaseInfoResponse::default(),
})) }))
.unwrap(); .unwrap();
assert!(matches!( assert!(matches!(
auth, auth,
ProxyAuthResponse::Ready { ProxyAuthResponse::Ready {
conn_info: DatabaseInfo { .. } conn_info: DatabaseInfoResponse { .. }
} }
)); ));

View File

@@ -1,4 +1,4 @@
use crate::{compute::DatabaseInfo, cplane_api}; use crate::cplane_api;
use anyhow::Context; use anyhow::Context;
use serde::Deserialize; use serde::Deserialize;
use std::{ use std::{
@@ -75,7 +75,7 @@ struct PsqlSessionResponse {
#[derive(Deserialize)] #[derive(Deserialize)]
enum PsqlSessionResult { enum PsqlSessionResult {
Success(DatabaseInfo), Success(cplane_api::DatabaseInfoResponse),
Failure(String), Failure(String),
} }

View File

@@ -1,4 +1,4 @@
use crate::auth; use crate::auth::{self, ClientCredentials};
use crate::cancellation::{self, CancelClosure, CancelMap}; use crate::cancellation::{self, CancelClosure, CancelMap};
use crate::compute::DatabaseInfo; use crate::compute::DatabaseInfo;
use crate::config::{ProxyConfig, TlsConfig}; use crate::config::{ProxyConfig, TlsConfig};
@@ -138,7 +138,6 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream.write_message(&Be::ErrorResponse(msg)).await?; stream.write_message(&Be::ErrorResponse(msg)).await?;
bail!(msg); bail!(msg);
} }
break Ok(Some((stream, params.try_into()?))); break Ok(Some((stream, params.try_into()?)));
} }
CancelRequest(cancel_key_data) => { CancelRequest(cancel_key_data) => {

View File

@@ -1,2 +1,14 @@
import pytest
def test_proxy_select_1(static_proxy): def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;") static_proxy.safe_psql("select 1;")
def test_proxy_options(static_proxy):
schema_name = "tmp_schema_1"
with static_proxy.connect(schema=schema_name) as conn:
with conn.cursor() as cur:
cur.execute("SHOW search_path;")
search_path = cur.fetchall()[0][0]
assert schema_name == search_path

View File

@@ -242,15 +242,20 @@ class PgProtocol:
host: str, host: str,
port: int, port: int,
username: Optional[str] = None, username: Optional[str] = None,
password: Optional[str] = None): password: Optional[str] = None,
dbname: Optional[str] = None,
schema: Optional[str] = None):
self.host = host self.host = host
self.port = port self.port = port
self.username = username self.username = username
self.password = password self.password = password
self.dbname = dbname
self.schema = schema
def connstr(self, def connstr(self,
*, *,
dbname: str = 'postgres', dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None, username: Optional[str] = None,
password: Optional[str] = None) -> str: password: Optional[str] = None) -> str:
""" """
@@ -259,6 +264,8 @@ class PgProtocol:
username = username or self.username username = username or self.username
password = password or self.password password = password or self.password
dbname = dbname or self.dbname or "postgres"
schema = schema or self.schema
res = f'host={self.host} port={self.port} dbname={dbname}' res = f'host={self.host} port={self.port} dbname={dbname}'
if username: if username:
@@ -267,13 +274,17 @@ class PgProtocol:
if password: if password:
res = f'{res} password={password}' res = f'{res} password={password}'
if schema:
res = f"{res} options='-c search_path={schema}'"
return res return res
# autocommit=True here by default because that's what we need most of the time # autocommit=True here by default because that's what we need most of the time
def connect(self, def connect(self,
*, *,
autocommit=True, autocommit=True,
dbname: str = 'postgres', dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None, username: Optional[str] = None,
password: Optional[str] = None) -> PgConnection: password: Optional[str] = None) -> PgConnection:
""" """
@@ -282,11 +293,13 @@ class PgProtocol:
This method passes all extra params to connstr. This method passes all extra params to connstr.
""" """
conn = psycopg2.connect(self.connstr( conn = psycopg2.connect(
dbname=dbname, self.connstr(
username=username, dbname=dbname,
password=password, schema=schema,
)) username=username,
password=password,
))
# WARNING: this setting affects *all* tests! # WARNING: this setting affects *all* tests!
conn.autocommit = autocommit conn.autocommit = autocommit
return conn return conn