mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-13 08:22:55 +00:00
Merged changes.
This commit is contained in:
@@ -269,11 +269,15 @@ impl FeStartupPacket {
|
||||
.next()
|
||||
.context("expected even number of params in StartupMessage")?;
|
||||
if name == "options" {
|
||||
// deprecated way of passing params as cmd line args
|
||||
for cmdopt in value.split(' ') {
|
||||
let nameval: Vec<&str> = cmdopt.split('=').collect();
|
||||
//parsing options arguments "..&options=<var>:<val>,.."
|
||||
//extended example and set of options:
|
||||
//https://github.com/neondatabase/neon/blob/main/docs/rfcs/016-connection-routing.md#connection-url
|
||||
for cmdopt in value.split(',') {
|
||||
let nameval: Vec<&str> = cmdopt.split(':').collect();
|
||||
if nameval.len() == 2 {
|
||||
params.insert(nameval[0].to_string(), nameval[1].to_string());
|
||||
} else {
|
||||
//todo: inform user / throw error message if options format is wrong.
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -26,6 +26,10 @@ pub struct ClientCredentials {
|
||||
// New console API requires SNI info to determine the cluster name.
|
||||
// Other Auth backends don't need it.
|
||||
pub sni_data: Option<String>,
|
||||
|
||||
// cluster_option is passed as argument from options from url.
|
||||
// To be used to determine cluster name in case sni_data is missing.
|
||||
pub project_option: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
@@ -37,10 +41,10 @@ impl ClientCredentials {
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProjectNameError {
|
||||
#[error("SNI is missing, please upgrade the postgres client library")]
|
||||
#[error("SNI info is missing. EITHER please upgrade the postgres client library OR pass the project name as a parameter: '..&options=project:<project name>..'.")]
|
||||
Missing,
|
||||
|
||||
#[error("SNI is malformed")]
|
||||
#[error("SNI is malformed.")]
|
||||
Bad,
|
||||
}
|
||||
|
||||
@@ -49,10 +53,22 @@ impl UserFacingError for ProjectNameError {}
|
||||
impl ClientCredentials {
|
||||
/// Determine project name from SNI.
|
||||
pub fn project_name(&self) -> Result<&str, ProjectNameError> {
|
||||
// Currently project name is passed as a top level domain
|
||||
let sni = self.sni_data.as_ref().ok_or(ProjectNameError::Missing)?;
|
||||
let (first, _) = sni.split_once('.').ok_or(ProjectNameError::Bad)?;
|
||||
Ok(first)
|
||||
let ret = match &self.sni_data {
|
||||
//if sni_data exists, use it to determine project name
|
||||
Some(sni_data) => {
|
||||
sni_data
|
||||
.split_once('.')
|
||||
.ok_or(ProjectNameError::Bad)?
|
||||
.0
|
||||
}
|
||||
//otherwise use project_option if it was manually set thought ..&options=project:<name> parameter
|
||||
None => self
|
||||
.project_option
|
||||
.as_ref()
|
||||
.ok_or(ProjectNameError::Missing)?
|
||||
.as_str(),
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,11 +84,17 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
|
||||
let user = get_param("user")?;
|
||||
let dbname = get_param("database")?;
|
||||
let project = get_param("project");
|
||||
let project_option = match project {
|
||||
Ok(project) => Some(project),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
user,
|
||||
dbname,
|
||||
sni_data: None,
|
||||
project_option,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
251
proxy/src/auth_backend/console.rs
Normal file
251
proxy/src/auth_backend/console.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
//! Declaration of Cloud API V2.
|
||||
|
||||
use crate::{
|
||||
auth::{self, AuthFlow},
|
||||
compute, scram,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::stream::PqStream;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConsoleAuthError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Bad authentication secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error("Bad client credentials: {0:?}")]
|
||||
BadCredentials(crate::auth::ClientCredentials),
|
||||
|
||||
#[error("SNI info is missing. EITHER please upgrade the postgres client library OR pass ..&options=cluster:<project name>.. parameter")]
|
||||
SniMissingAndProjectNameMissing,
|
||||
|
||||
#[error("Unexpected SNI content")]
|
||||
SniWrong,
|
||||
|
||||
#[error(transparent)]
|
||||
BadUrl(#[from] url::ParseError),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// HTTP status (other than 200) returned by the console.
|
||||
#[error("Console responded with an HTTP status: {0}")]
|
||||
HttpStatus(reqwest::StatusCode),
|
||||
|
||||
#[error(transparent)]
|
||||
Transport(#[from] reqwest::Error),
|
||||
|
||||
#[error("Console responded with a malformed JSON: '{0}'")]
|
||||
MalformedResponse(#[from] serde_json::Error),
|
||||
|
||||
#[error("Console responded with a malformed compute address: '{0}'")]
|
||||
MalformedComputeAddress(String),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct GetRoleSecretResponse {
|
||||
role_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct GetWakeComputeResponse {
|
||||
address: String,
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthInfo {
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
/// Compute node connection params provided by the cloud.
|
||||
/// Note how it implements serde traits, since we receive it over the wire.
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
|
||||
/// [Cloud API V1](super::legacy) returns cleartext password,
|
||||
/// but [Cloud API V2](super::api) implements [SCRAM](crate::scram)
|
||||
/// authentication, so we can leverage this method and cope without password.
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit personal and sensitive info.
|
||||
impl std::fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
fn from(db_info: DatabaseInfo) -> Self {
|
||||
let mut config = tokio_postgres::Config::new();
|
||||
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_auth_info(
|
||||
auth_endpoint: &str,
|
||||
user: &str,
|
||||
cluster: &str,
|
||||
) -> Result<AuthInfo, ConsoleAuthError> {
|
||||
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_get_role_secret"))?;
|
||||
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", cluster)
|
||||
.append_pair("role", user);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {}", url);
|
||||
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let response: GetRoleSecretResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
|
||||
scram::ServerSecret::parse(response.role_secret.as_str())
|
||||
.map(AuthInfo::Scram)
|
||||
.ok_or(ConsoleAuthError::BadSecret)
|
||||
}
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
auth_endpoint: &str,
|
||||
cluster: &str,
|
||||
) -> Result<(String, u16), ConsoleAuthError> {
|
||||
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_wake_compute"))?;
|
||||
url.query_pairs_mut().append_pair("project", cluster);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {}", url);
|
||||
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let response: GetWakeComputeResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
let (host, port) = response
|
||||
.address
|
||||
.split_once(':')
|
||||
.ok_or_else(|| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
|
||||
let port: u16 = port
|
||||
.parse()
|
||||
.map_err(|_| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
|
||||
|
||||
Ok((host.to_string(), port))
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &str,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
|
||||
// Determine cluster name from SNI (creds.sni_data) or from creds.cluster_option.
|
||||
let cluster = match &creds.sni_data {
|
||||
//if sni_data exists, use it
|
||||
Some(sni_data) => {
|
||||
sni_data
|
||||
.split_once('.')
|
||||
.ok_or(ConsoleAuthError::SniWrong)?
|
||||
.0
|
||||
}
|
||||
//otherwise use cluster_option if it was manually set thought ..&options=cluster:<name> parameter
|
||||
None => creds
|
||||
.cluster_option
|
||||
.as_ref()
|
||||
.ok_or(ConsoleAuthError::SniMissingAndProjectNameMissing)?
|
||||
.as_str(),
|
||||
};
|
||||
|
||||
let user = creds.user.as_str();
|
||||
|
||||
// Step 1: get the auth secret
|
||||
let auth_info = get_auth_info(auth_endpoint, user, cluster).await?;
|
||||
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match auth_info {
|
||||
AuthInfo::Md5(_) => {
|
||||
// TODO: decide if we should support MD5 in api v2
|
||||
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
let scram = auth::Scram(&secret);
|
||||
Some(compute::ScramKeys {
|
||||
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
// Step 2: wake compute
|
||||
let (host, port) = wake_compute(auth_endpoint, cluster).await?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info: DatabaseInfo {
|
||||
host,
|
||||
port,
|
||||
dbname: creds.dbname.clone(),
|
||||
user: creds.user.clone(),
|
||||
password: None,
|
||||
},
|
||||
scram_keys,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
}))?;
|
||||
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user