mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-05 11:40:37 +00:00
Compare commits
3 Commits
split-prox
...
bojan/prox
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1c40c26313 | ||
|
|
a6ace609a7 | ||
|
|
29d72e8955 |
@@ -7,11 +7,13 @@ use std::collections::HashMap;
|
|||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
|
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
|
||||||
|
|
||||||
|
// TODO rename the struct to ClientParams or something
|
||||||
/// Various client credentials which we use for authentication.
|
/// Various client credentials which we use for authentication.
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct ClientCredentials {
|
pub struct ClientCredentials {
|
||||||
pub user: String,
|
pub user: String,
|
||||||
pub dbname: String,
|
pub dbname: String,
|
||||||
|
pub options: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||||
@@ -25,9 +27,22 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let user = get_param("user")?;
|
let user = get_param("user")?;
|
||||||
let db = get_param("database")?;
|
let dbname = get_param("database")?;
|
||||||
|
|
||||||
Ok(Self { user, dbname: db })
|
// TODO see what other options should be recognized, possibly all.
|
||||||
|
let options = match get_param("search_path") {
|
||||||
|
Ok(path) => Some(format!("-c search_path={}", path)),
|
||||||
|
Err(_) => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO investigate why "" is always a key
|
||||||
|
// TODO warn on unrecognized options?
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
user,
|
||||||
|
dbname,
|
||||||
|
options,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,6 +100,7 @@ async fn handle_static(
|
|||||||
dbname: creds.dbname.clone(),
|
dbname: creds.dbname.clone(),
|
||||||
user: creds.user.clone(),
|
user: creds.user.clone(),
|
||||||
password: Some(cleartext_password.into()),
|
password: Some(cleartext_password.into()),
|
||||||
|
options: creds.options,
|
||||||
};
|
};
|
||||||
|
|
||||||
client
|
client
|
||||||
@@ -117,15 +133,22 @@ async fn handle_existing_user(
|
|||||||
.ok_or_else(|| anyhow!("unexpected password message"))?;
|
.ok_or_else(|| anyhow!("unexpected password message"))?;
|
||||||
|
|
||||||
let cplane = CPlaneApi::new(&config.auth_endpoint);
|
let cplane = CPlaneApi::new(&config.auth_endpoint);
|
||||||
let db_info = cplane
|
let db_info_response = cplane
|
||||||
.authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id)
|
.authenticate_proxy_request(&creds, md5_response, &md5_salt, &psql_session_id)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
client
|
client
|
||||||
.write_message_noflush(&Be::AuthenticationOk)?
|
.write_message_noflush(&Be::AuthenticationOk)?
|
||||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||||
|
|
||||||
Ok(db_info)
|
Ok(DatabaseInfo {
|
||||||
|
host: db_info_response.host,
|
||||||
|
port: db_info_response.port,
|
||||||
|
dbname: db_info_response.dbname,
|
||||||
|
user: db_info_response.user,
|
||||||
|
password: db_info_response.password,
|
||||||
|
options: creds.options,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_new_user(
|
async fn handle_new_user(
|
||||||
@@ -135,7 +158,7 @@ async fn handle_new_user(
|
|||||||
let psql_session_id = new_psql_session_id();
|
let psql_session_id = new_psql_session_id();
|
||||||
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
|
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
|
||||||
|
|
||||||
let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async {
|
let db_info_response = cplane_api::with_waiter(psql_session_id, |waiter| async {
|
||||||
// Give user a URL to spawn a new database
|
// Give user a URL to spawn a new database
|
||||||
client
|
client
|
||||||
.write_message_noflush(&Be::AuthenticationOk)?
|
.write_message_noflush(&Be::AuthenticationOk)?
|
||||||
@@ -150,7 +173,14 @@ async fn handle_new_user(
|
|||||||
|
|
||||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
|
client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
|
||||||
|
|
||||||
Ok(db_info)
|
Ok(DatabaseInfo {
|
||||||
|
host: db_info_response.host,
|
||||||
|
port: db_info_response.port,
|
||||||
|
dbname: db_info_response.dbname,
|
||||||
|
user: db_info_response.user,
|
||||||
|
password: db_info_response.password,
|
||||||
|
options: None,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ pub struct DatabaseInfo {
|
|||||||
pub dbname: String,
|
pub dbname: String,
|
||||||
pub user: String,
|
pub user: String,
|
||||||
pub password: Option<String>,
|
pub password: Option<String>,
|
||||||
|
pub options: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DatabaseInfo {
|
impl DatabaseInfo {
|
||||||
@@ -33,6 +34,10 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
|
|||||||
.dbname(&db_info.dbname)
|
.dbname(&db_info.dbname)
|
||||||
.user(&db_info.user);
|
.user(&db_info.user);
|
||||||
|
|
||||||
|
if let Some(options) = db_info.options {
|
||||||
|
config.options(&options);
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(password) = db_info.password {
|
if let Some(password) = db_info.password {
|
||||||
config.password(password);
|
config.password(password);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +1,37 @@
|
|||||||
use crate::auth::ClientCredentials;
|
use crate::auth::ClientCredentials;
|
||||||
use crate::compute::DatabaseInfo;
|
|
||||||
use crate::waiters::{Waiter, Waiters};
|
use crate::waiters::{Waiter, Waiters};
|
||||||
use anyhow::{anyhow, bail};
|
use anyhow::{anyhow, bail};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Part of the legacy cplane responses
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||||
|
pub struct DatabaseInfoResponse {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub dbname: String,
|
||||||
|
pub user: String,
|
||||||
|
pub password: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default();
|
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfoResponse, String>> = Default::default();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Give caller an opportunity to wait for cplane's reply.
|
/// Give caller an opportunity to wait for cplane's reply.
|
||||||
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
|
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
|
||||||
where
|
where
|
||||||
F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R,
|
F: FnOnce(Waiter<'static, Result<DatabaseInfoResponse, String>>) -> R,
|
||||||
R: std::future::Future<Output = anyhow::Result<T>>,
|
R: std::future::Future<Output = anyhow::Result<T>>,
|
||||||
{
|
{
|
||||||
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
||||||
f(waiter).await
|
f(waiter).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> {
|
pub fn notify(
|
||||||
|
psql_session_id: &str,
|
||||||
|
msg: Result<DatabaseInfoResponse, String>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,11 +49,11 @@ impl<'a> CPlaneApi<'a> {
|
|||||||
impl CPlaneApi<'_> {
|
impl CPlaneApi<'_> {
|
||||||
pub async fn authenticate_proxy_request(
|
pub async fn authenticate_proxy_request(
|
||||||
&self,
|
&self,
|
||||||
creds: ClientCredentials,
|
creds: &ClientCredentials,
|
||||||
md5_response: &[u8],
|
md5_response: &[u8],
|
||||||
salt: &[u8; 4],
|
salt: &[u8; 4],
|
||||||
psql_session_id: &str,
|
psql_session_id: &str,
|
||||||
) -> anyhow::Result<DatabaseInfo> {
|
) -> anyhow::Result<DatabaseInfoResponse> {
|
||||||
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
|
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
|
||||||
url.query_pairs_mut()
|
url.query_pairs_mut()
|
||||||
.append_pair("login", &creds.user)
|
.append_pair("login", &creds.user)
|
||||||
@@ -77,7 +89,7 @@ impl CPlaneApi<'_> {
|
|||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
enum ProxyAuthResponse {
|
enum ProxyAuthResponse {
|
||||||
Ready { conn_info: DatabaseInfo },
|
Ready { conn_info: DatabaseInfoResponse },
|
||||||
Error { error: String },
|
Error { error: String },
|
||||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||||
}
|
}
|
||||||
@@ -92,13 +104,13 @@ mod tests {
|
|||||||
// Ready
|
// Ready
|
||||||
let auth: ProxyAuthResponse = serde_json::from_value(json!({
|
let auth: ProxyAuthResponse = serde_json::from_value(json!({
|
||||||
"ready": true,
|
"ready": true,
|
||||||
"conn_info": DatabaseInfo::default(),
|
"conn_info": DatabaseInfoResponse::default(),
|
||||||
}))
|
}))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
auth,
|
auth,
|
||||||
ProxyAuthResponse::Ready {
|
ProxyAuthResponse::Ready {
|
||||||
conn_info: DatabaseInfo { .. }
|
conn_info: DatabaseInfoResponse { .. }
|
||||||
}
|
}
|
||||||
));
|
));
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use crate::{compute::DatabaseInfo, cplane_api};
|
use crate::cplane_api;
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::{
|
use std::{
|
||||||
@@ -75,7 +75,7 @@ struct PsqlSessionResponse {
|
|||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
enum PsqlSessionResult {
|
enum PsqlSessionResult {
|
||||||
Success(DatabaseInfo),
|
Success(cplane_api::DatabaseInfoResponse),
|
||||||
Failure(String),
|
Failure(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use crate::auth;
|
use crate::auth::{self, ClientCredentials};
|
||||||
use crate::cancellation::{self, CancelClosure, CancelMap};
|
use crate::cancellation::{self, CancelClosure, CancelMap};
|
||||||
use crate::compute::DatabaseInfo;
|
use crate::compute::DatabaseInfo;
|
||||||
use crate::config::{ProxyConfig, TlsConfig};
|
use crate::config::{ProxyConfig, TlsConfig};
|
||||||
@@ -138,7 +138,6 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
|||||||
stream.write_message(&Be::ErrorResponse(msg)).await?;
|
stream.write_message(&Be::ErrorResponse(msg)).await?;
|
||||||
bail!(msg);
|
bail!(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
break Ok(Some((stream, params.try_into()?)));
|
break Ok(Some((stream, params.try_into()?)));
|
||||||
}
|
}
|
||||||
CancelRequest(cancel_key_data) => {
|
CancelRequest(cancel_key_data) => {
|
||||||
|
|||||||
@@ -1,2 +1,14 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def test_proxy_select_1(static_proxy):
|
def test_proxy_select_1(static_proxy):
|
||||||
static_proxy.safe_psql("select 1;")
|
static_proxy.safe_psql("select 1;")
|
||||||
|
|
||||||
|
|
||||||
|
def test_proxy_options(static_proxy):
|
||||||
|
schema_name = "tmp_schema_1"
|
||||||
|
with static_proxy.connect(schema=schema_name) as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute("SHOW search_path;")
|
||||||
|
search_path = cur.fetchall()[0][0]
|
||||||
|
assert schema_name == search_path
|
||||||
|
|||||||
@@ -242,15 +242,20 @@ class PgProtocol:
|
|||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
password: Optional[str] = None):
|
password: Optional[str] = None,
|
||||||
|
dbname: Optional[str] = None,
|
||||||
|
schema: Optional[str] = None):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.username = username
|
self.username = username
|
||||||
self.password = password
|
self.password = password
|
||||||
|
self.dbname = dbname
|
||||||
|
self.schema = schema
|
||||||
|
|
||||||
def connstr(self,
|
def connstr(self,
|
||||||
*,
|
*,
|
||||||
dbname: str = 'postgres',
|
dbname: Optional[str] = None,
|
||||||
|
schema: Optional[str] = None,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
password: Optional[str] = None) -> str:
|
password: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -259,6 +264,8 @@ class PgProtocol:
|
|||||||
|
|
||||||
username = username or self.username
|
username = username or self.username
|
||||||
password = password or self.password
|
password = password or self.password
|
||||||
|
dbname = dbname or self.dbname or "postgres"
|
||||||
|
schema = schema or self.schema
|
||||||
res = f'host={self.host} port={self.port} dbname={dbname}'
|
res = f'host={self.host} port={self.port} dbname={dbname}'
|
||||||
|
|
||||||
if username:
|
if username:
|
||||||
@@ -267,13 +274,17 @@ class PgProtocol:
|
|||||||
if password:
|
if password:
|
||||||
res = f'{res} password={password}'
|
res = f'{res} password={password}'
|
||||||
|
|
||||||
|
if schema:
|
||||||
|
res = f"{res} options='-c search_path={schema}'"
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# autocommit=True here by default because that's what we need most of the time
|
# autocommit=True here by default because that's what we need most of the time
|
||||||
def connect(self,
|
def connect(self,
|
||||||
*,
|
*,
|
||||||
autocommit=True,
|
autocommit=True,
|
||||||
dbname: str = 'postgres',
|
dbname: Optional[str] = None,
|
||||||
|
schema: Optional[str] = None,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
password: Optional[str] = None) -> PgConnection:
|
password: Optional[str] = None) -> PgConnection:
|
||||||
"""
|
"""
|
||||||
@@ -282,11 +293,13 @@ class PgProtocol:
|
|||||||
This method passes all extra params to connstr.
|
This method passes all extra params to connstr.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
conn = psycopg2.connect(self.connstr(
|
conn = psycopg2.connect(
|
||||||
dbname=dbname,
|
self.connstr(
|
||||||
username=username,
|
dbname=dbname,
|
||||||
password=password,
|
schema=schema,
|
||||||
))
|
username=username,
|
||||||
|
password=password,
|
||||||
|
))
|
||||||
# WARNING: this setting affects *all* tests!
|
# WARNING: this setting affects *all* tests!
|
||||||
conn.autocommit = autocommit
|
conn.autocommit = autocommit
|
||||||
return conn
|
return conn
|
||||||
|
|||||||
Reference in New Issue
Block a user