mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
abstract out auth
This commit is contained in:
@@ -18,7 +18,6 @@ use futures::TryStreamExt;
|
||||
use postgres_protocol::authentication::sasl;
|
||||
use postgres_protocol::authentication::sasl::ChannelBinding;
|
||||
use postgres_protocol::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol::message::backend;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use pq_proto::FeStartupPacket;
|
||||
@@ -28,9 +27,12 @@ use tokio::io::join;
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
use crate::auth::backend::ComputeCredentials;
|
||||
use crate::auth_proxy::AuthProxyStream;
|
||||
use crate::auth_proxy::TLS_SERVER_END_POINT;
|
||||
use crate::console::NodeInfo;
|
||||
use crate::stream::AuthProxyStreamExt;
|
||||
use crate::ConnectionInitiatedPayload;
|
||||
use crate::PglbControlMessage;
|
||||
use crate::PglbMessage;
|
||||
use crate::{
|
||||
@@ -470,7 +472,7 @@ pub async fn handle_stream(
|
||||
|
||||
// recv connection metadata
|
||||
let first_msg = stream.try_next().await?;
|
||||
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg
|
||||
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(conn_info))) = first_msg
|
||||
else {
|
||||
panic!("invalid first msg")
|
||||
};
|
||||
@@ -481,31 +483,7 @@ pub async fn handle_stream(
|
||||
panic!("invalid startup message")
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let user_info = auth::ComputeUserInfoMaybeEndpoint {
|
||||
user: params.get("user").context("missing user")?.into(),
|
||||
endpoint_id: first_msg
|
||||
.server_name
|
||||
.as_deref()
|
||||
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
|
||||
options: NeonOptions::parse_params(¶ms),
|
||||
};
|
||||
|
||||
// authenticate the user
|
||||
let user_info = config.backend.as_ref().map(|()| user_info);
|
||||
let res = TLS_SERVER_END_POINT
|
||||
.scope(
|
||||
first_msg.tls_server_end_point,
|
||||
user_info.authenticate(&mut stream, &config.auth),
|
||||
)
|
||||
.await;
|
||||
|
||||
let user_info = match res {
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
return stream.throw_error(e).await?;
|
||||
}
|
||||
};
|
||||
let user_info = auth_with_user(&mut stream, config, &conn_info, ¶ms).await?;
|
||||
|
||||
// wake the compute
|
||||
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
|
||||
@@ -526,41 +504,7 @@ pub async fn handle_stream(
|
||||
frontend::startup_message(params.iter(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
// start auth with compute
|
||||
|
||||
let AuthKeys::ScramSha256(scram_keys) = node_info
|
||||
.config
|
||||
.get_auth_keys()
|
||||
.context("missing auth keys")?;
|
||||
|
||||
let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported());
|
||||
|
||||
// TODO: "negotiate" the auth mechanism
|
||||
|
||||
// send auth message
|
||||
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else {
|
||||
bail!("invalid message");
|
||||
};
|
||||
|
||||
match backend::Message::parse(&mut buf2)? {
|
||||
Some(Message::AuthenticationSaslContinue(body)) => scram.update(body.data()).await?,
|
||||
_ => bail!("invalid message"),
|
||||
};
|
||||
|
||||
frontend::sasl_response(scram.message(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
let PglbMessage::Postgres(mut buf2) = stream.try_next().await?.context("missing")? else {
|
||||
bail!("invalid message");
|
||||
};
|
||||
|
||||
match backend::Message::parse(&mut buf2)? {
|
||||
Some(Message::AuthenticationSaslFinal(body)) => scram.finish(body.data())?,
|
||||
_ => bail!("invalid message"),
|
||||
};
|
||||
auth_with_compute(&mut stream, &node_info).await?;
|
||||
|
||||
stream
|
||||
.send(PglbMessage::Control(PglbControlMessage::ComputeEstablish))
|
||||
@@ -568,3 +512,92 @@ pub async fn handle_stream(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn auth_with_user(
|
||||
stream: &mut AuthProxyStream,
|
||||
config: &'static AuthProxyConfig,
|
||||
conn_info: &ConnectionInitiatedPayload,
|
||||
params: &StartupMessageParams,
|
||||
) -> anyhow::Result<crate::auth_proxy::Backend<'static, ComputeCredentials>> {
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let user_info = auth::ComputeUserInfoMaybeEndpoint {
|
||||
user: params.get("user").context("missing user")?.into(),
|
||||
endpoint_id: conn_info
|
||||
.server_name
|
||||
.as_deref()
|
||||
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
|
||||
options: NeonOptions::parse_params(params),
|
||||
};
|
||||
|
||||
// authenticate the user
|
||||
let user_info = config.backend.as_ref().map(|()| user_info);
|
||||
let res = TLS_SERVER_END_POINT
|
||||
.scope(
|
||||
conn_info.tls_server_end_point,
|
||||
user_info.authenticate(stream, &config.auth),
|
||||
)
|
||||
.await;
|
||||
|
||||
let user_info = match res {
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
return stream.throw_error(e).await?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(user_info)
|
||||
}
|
||||
|
||||
async fn auth_with_compute(
|
||||
stream: &mut AuthProxyStream,
|
||||
node_info: &NodeInfo,
|
||||
) -> anyhow::Result<()> {
|
||||
let AuthKeys::ScramSha256(scram_keys) = node_info
|
||||
.config
|
||||
.get_auth_keys()
|
||||
.context("missing auth keys")?;
|
||||
|
||||
// compute offers sasl
|
||||
stream
|
||||
.read_backend_message(|m| match m {
|
||||
Message::AuthenticationSasl(_body) => Ok(()),
|
||||
_ => bail!("invalid message"),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// send auth message
|
||||
let mut scram = ScramSha256::new_with_keys(scram_keys, ChannelBinding::unsupported());
|
||||
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
let cont_body = stream
|
||||
.read_backend_message(|m| match m {
|
||||
Message::AuthenticationSaslContinue(body) => Ok(body),
|
||||
_ => bail!("invalid message"),
|
||||
})
|
||||
.await?;
|
||||
scram.update(cont_body.data()).await?;
|
||||
|
||||
frontend::sasl_response(scram.message(), &mut buf)?;
|
||||
stream.send(PglbMessage::Postgres(buf.split())).await?;
|
||||
|
||||
let final_body = stream
|
||||
.read_backend_message(|m| match m {
|
||||
Message::AuthenticationSaslFinal(body) => Ok(body),
|
||||
_ => bail!("invalid message"),
|
||||
})
|
||||
.await?;
|
||||
scram.finish(final_body.data())?;
|
||||
|
||||
// wait for ok.
|
||||
stream
|
||||
.read_backend_message(|m| match m {
|
||||
Message::AuthenticationOk => Ok(()),
|
||||
_ => bail!("invalid message"),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3,9 +3,11 @@ use crate::config::TlsServerEndPoint;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::PglbMessage;
|
||||
use anyhow::{bail, Context};
|
||||
use bytes::BytesMut;
|
||||
|
||||
use futures::{SinkExt, TryStreamExt};
|
||||
use postgres_protocol::message::backend;
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
@@ -317,6 +319,11 @@ pub(crate) trait AuthProxyStreamExt {
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage>;
|
||||
|
||||
async fn read_password_message(&mut self) -> io::Result<bytes::Bytes>;
|
||||
|
||||
async fn read_backend_message<T>(
|
||||
&mut self,
|
||||
f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
|
||||
) -> anyhow::Result<T>;
|
||||
}
|
||||
|
||||
impl AuthProxyStreamExt for AuthProxyStream {
|
||||
@@ -417,4 +424,15 @@ impl AuthProxyStreamExt for AuthProxyStream {
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_backend_message<T>(
|
||||
&mut self,
|
||||
f: impl FnOnce(backend::Message) -> anyhow::Result<T>,
|
||||
) -> anyhow::Result<T> {
|
||||
let PglbMessage::Postgres(mut buf) = self.try_next().await?.context("missing")? else {
|
||||
bail!("invalid message");
|
||||
};
|
||||
let message = backend::Message::parse(&mut buf)?.context("missing")?;
|
||||
f(message)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user