Compare commits

...

3 Commits

Author SHA1 Message Date
Bojan Serafimov
1c40c26313 Parse search_path option 2022-03-07 18:50:52 -05:00
Bojan Serafimov
a6ace609a7 Fix typo 2022-03-07 17:56:12 -05:00
Bojan Serafimov
29d72e8955 Add proxy test 2022-03-07 14:32:24 -05:00
7 changed files with 99 additions and 28 deletions

View File

@@ -7,11 +7,13 @@ use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
// TODO rename the struct to ClientParams or something
/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
pub options: Option<String>,
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
@@ -25,9 +27,22 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
};
let user = get_param("user")?;
let db = get_param("database")?;
let dbname = get_param("database")?;
Ok(Self { user, dbname: db })
// TODO see what other options should be recognized, possibly all.
let options = match get_param("search_path") {
Ok(path) => Some(format!("-c search_path={}", path)),
Err(_) => None,
};
// TODO investigate why "" is always a key
// TODO warn on unrecognized options?
Ok(Self {
user,
dbname,
options,
})
}
}
@@ -85,6 +100,7 @@ async fn handle_static(
dbname: creds.dbname.clone(),
user: creds.user.clone(),
password: Some(cleartext_password.into()),
options: creds.options,
};
client
@@ -117,15 +133,22 @@ async fn handle_existing_user(
.ok_or_else(|| anyhow!("unexpected password message"))?;
let cplane = CPlaneApi::new(&config.auth_endpoint);
let db_info = cplane
.authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id)
let db_info_response = cplane
.authenticate_proxy_request(&creds, md5_response, &md5_salt, &psql_session_id)
.await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
Ok(DatabaseInfo {
host: db_info_response.host,
port: db_info_response.port,
dbname: db_info_response.dbname,
user: db_info_response.user,
password: db_info_response.password,
options: creds.options,
})
}
async fn handle_new_user(
@@ -135,7 +158,7 @@ async fn handle_new_user(
let psql_session_id = new_psql_session_id();
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async {
let db_info_response = cplane_api::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database
client
.write_message_noflush(&Be::AuthenticationOk)?
@@ -150,7 +173,14 @@ async fn handle_new_user(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
Ok(db_info)
Ok(DatabaseInfo {
host: db_info_response.host,
port: db_info_response.port,
dbname: db_info_response.dbname,
user: db_info_response.user,
password: db_info_response.password,
options: None,
})
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {

View File

@@ -10,6 +10,7 @@ pub struct DatabaseInfo {
pub dbname: String,
pub user: String,
pub password: Option<String>,
pub options: Option<String>,
}
impl DatabaseInfo {
@@ -33,6 +34,10 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(options) = db_info.options {
config.options(&options);
}
if let Some(password) = db_info.password {
config.password(password);
}

View File

@@ -1,25 +1,37 @@
use crate::auth::ClientCredentials;
use crate::compute::DatabaseInfo;
use crate::waiters::{Waiter, Waiters};
use anyhow::{anyhow, bail};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
/// Part of the legacy cplane responses
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfoResponse {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
lazy_static! {
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default();
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfoResponse, String>> = Default::default();
}
/// Give caller an opportunity to wait for cplane's reply.
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
where
F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R,
F: FnOnce(Waiter<'static, Result<DatabaseInfoResponse, String>>) -> R,
R: std::future::Future<Output = anyhow::Result<T>>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
f(waiter).await
}
pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> {
pub fn notify(
psql_session_id: &str,
msg: Result<DatabaseInfoResponse, String>,
) -> anyhow::Result<()> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
@@ -37,11 +49,11 @@ impl<'a> CPlaneApi<'a> {
impl CPlaneApi<'_> {
pub async fn authenticate_proxy_request(
&self,
creds: ClientCredentials,
creds: &ClientCredentials,
md5_response: &[u8],
salt: &[u8; 4],
psql_session_id: &str,
) -> anyhow::Result<DatabaseInfo> {
) -> anyhow::Result<DatabaseInfoResponse> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
url.query_pairs_mut()
.append_pair("login", &creds.user)
@@ -77,7 +89,7 @@ impl CPlaneApi<'_> {
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {
Ready { conn_info: DatabaseInfo },
Ready { conn_info: DatabaseInfoResponse },
Error { error: String },
NotReady { ready: bool }, // TODO: get rid of `ready`
}
@@ -92,13 +104,13 @@ mod tests {
// Ready
let auth: ProxyAuthResponse = serde_json::from_value(json!({
"ready": true,
"conn_info": DatabaseInfo::default(),
"conn_info": DatabaseInfoResponse::default(),
}))
.unwrap();
assert!(matches!(
auth,
ProxyAuthResponse::Ready {
conn_info: DatabaseInfo { .. }
conn_info: DatabaseInfoResponse { .. }
}
));

View File

@@ -1,4 +1,4 @@
use crate::{compute::DatabaseInfo, cplane_api};
use crate::cplane_api;
use anyhow::Context;
use serde::Deserialize;
use std::{
@@ -75,7 +75,7 @@ struct PsqlSessionResponse {
#[derive(Deserialize)]
enum PsqlSessionResult {
Success(DatabaseInfo),
Success(cplane_api::DatabaseInfoResponse),
Failure(String),
}

View File

@@ -1,4 +1,4 @@
use crate::auth;
use crate::auth::{self, ClientCredentials};
use crate::cancellation::{self, CancelClosure, CancelMap};
use crate::compute::DatabaseInfo;
use crate::config::{ProxyConfig, TlsConfig};
@@ -138,7 +138,6 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream.write_message(&Be::ErrorResponse(msg)).await?;
bail!(msg);
}
break Ok(Some((stream, params.try_into()?)));
}
CancelRequest(cancel_key_data) => {

View File

@@ -1,2 +1,14 @@
import pytest
def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;")
def test_proxy_options(static_proxy):
schema_name = "tmp_schema_1"
with static_proxy.connect(schema=schema_name) as conn:
with conn.cursor() as cur:
cur.execute("SHOW search_path;")
search_path = cur.fetchall()[0][0]
assert schema_name == search_path

View File

@@ -242,15 +242,20 @@ class PgProtocol:
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None):
password: Optional[str] = None,
dbname: Optional[str] = None,
schema: Optional[str] = None):
self.host = host
self.port = port
self.username = username
self.password = password
self.dbname = dbname
self.schema = schema
def connstr(self,
*,
dbname: str = 'postgres',
dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None) -> str:
"""
@@ -259,6 +264,8 @@ class PgProtocol:
username = username or self.username
password = password or self.password
dbname = dbname or self.dbname or "postgres"
schema = schema or self.schema
res = f'host={self.host} port={self.port} dbname={dbname}'
if username:
@@ -267,13 +274,17 @@ class PgProtocol:
if password:
res = f'{res} password={password}'
if schema:
res = f"{res} options='-c search_path={schema}'"
return res
# autocommit=True here by default because that's what we need most of the time
def connect(self,
*,
autocommit=True,
dbname: str = 'postgres',
dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None) -> PgConnection:
"""
@@ -282,11 +293,13 @@ class PgProtocol:
This method passes all extra params to connstr.
"""
conn = psycopg2.connect(self.connstr(
dbname=dbname,
username=username,
password=password,
))
conn = psycopg2.connect(
self.connstr(
dbname=dbname,
schema=schema,
username=username,
password=password,
))
# WARNING: this setting affects *all* tests!
conn.autocommit = autocommit
return conn