Compare commits

...

14 Commits

Author SHA1 Message Date
Conrad Ludgate
f9df93655f fix weird error change 2025-07-28 12:16:06 +01:00
Conrad Ludgate
3778081b7b eliminate needless reconnect 2025-07-28 12:04:00 +01:00
Conrad Ludgate
2ddbd3cc80 allow longer JWTs 2025-07-28 12:04:00 +01:00
Conrad Ludgate
5a293242f8 hack around the fact that the TLS cert is not a wildcard 2025-07-28 12:04:00 +01:00
Conrad Ludgate
ccea44becd disable channel binding 2025-07-28 12:04:00 +01:00
Conrad Ludgate
2c915e2f3d support mtls 2025-07-28 12:04:00 +01:00
Conrad Ludgate
2b22e0b069 use SNI for cancellation routing 2025-07-28 12:04:00 +01:00
Conrad Ludgate
739ecc6f6d send compute IP address to regional ingress 2025-07-28 12:04:00 +01:00
Conrad Ludgate
725aed694b do not replace cancelkeydata 2025-07-28 12:04:00 +01:00
Conrad Ludgate
d0e579c026 delay authenticationok until after connect_to_compute 2025-07-28 12:04:00 +01:00
Conrad Ludgate
56cc55d24a do not validate passwords, just forward them onto postgres
dont get access controls

add a new cleartext authsecret instead
2025-07-28 12:03:58 +01:00
Conrad Ludgate
da6419a45a expose lakebase-v1 as a flag 2025-07-28 11:59:27 +01:00
Conrad Ludgate
e5f5c79eb1 add lakebase_v1 cplane impl 2025-07-28 11:59:27 +01:00
Conrad Ludgate
314babb0cb add ignored lints for the sake of the diff 2025-07-28 11:59:27 +01:00
27 changed files with 359 additions and 89 deletions

View File

@@ -45,6 +45,10 @@ pub(super) async fn authenticate(
server_key: secret.server_key.as_bytes(),
}
}
AuthSecret::Cleartext => {
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
return super::hacks::authenticate_cleartext(ctx, creds, client, secret, config).await;
}
};
Ok(ComputeCredentials {

View File

@@ -192,7 +192,8 @@ async fn authenticate(
};
let conn_info = compute::ConnectInfo {
host: db_info.host.into(),
host: db_info.host.as_ref().into(),
server_name: db_info.host.into_string(),
port: db_info.port,
ssl_mode,
host_addr: None,

View File

@@ -33,6 +33,7 @@ impl LocalBackend {
conn_info: ConnectInfo {
host_addr: Some(postgres_addr.ip()),
host: postgres_addr.ip().to_string().into(),
server_name: postgres_addr.ip().to_string(),
port: postgres_addr.port(),
ssl_mode: SslMode::Disable,
},

View File

@@ -74,6 +74,11 @@ impl std::fmt::Display for Backend<'_, ()> {
.debug_tuple("ControlPlane::ProxyV1")
.field(&endpoint.url())
.finish(),
ControlPlaneClient::LakebaseV1(lb) => fmt
.debug_tuple("ControlPlane::LakebaseV1")
.field(&lb.namespace)
.field(&lb.port)
.finish(),
#[cfg(any(test, feature = "testing"))]
ControlPlaneClient::PostgresMock(endpoint) => {
let url = endpoint.url();
@@ -169,6 +174,8 @@ impl ComputeUserInfo {
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ComputeCredentialKeys {
/// We don't convert passwords into auth keys, we just pass passwords onto postgres.
Password(Vec<u8>),
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
}
@@ -239,11 +246,13 @@ async fn auth_quirks(
let secret = if let Some(secret) = role_access.secret {
secret
} else {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
// // If we don't have an authentication secret, we mock one to
// // prevent malicious probing (possible due to missing protocol steps).
// // This mocked secret will never lead to successful authentication.
// info!("authentication info not found, mocking it");
// AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
AuthSecret::Cleartext
};
match authenticate_with_secret(

View File

@@ -62,6 +62,9 @@ impl ComputeUserInfoMaybeEndpoint {
pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
let (subdomain, common_name) = sni.split_once('.')?;
if subdomain.starts_with("instance-") {
return Some(EndpointId::from(subdomain));
}
if !common_names.contains(common_name) {
return None;
}

View File

@@ -116,7 +116,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
)
.await?;
if let sasl::Outcome::Success(_) = &outcome {
if let sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(_)) = &outcome {
self.stream.write_message(BeMessage::AuthenticationOk);
}
@@ -187,5 +187,8 @@ pub(crate) async fn validate_password_and_exchange(
postgres_client::config::AuthKeys::ScramSha256(keys),
)))
}
AuthSecret::Cleartext => Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
password.to_vec(),
))),
}
}

View File

@@ -5,9 +5,7 @@ use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
#[cfg(any(test, feature = "testing"))]
use anyhow::Context;
use anyhow::{bail, ensure};
use anyhow::{Context, bail, ensure};
use arc_swap::ArcSwapOption;
#[cfg(any(test, feature = "testing"))]
use camino::Utf8PathBuf;
@@ -39,6 +37,7 @@ use crate::config::{
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
};
use crate::context::parquet::ParquetUploadArgs;
use crate::control_plane::client::lakebase_v1::LakebaseClient;
use crate::http::health_server::AppMetrics;
use crate::metrics::{Metrics, ServiceInfo};
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
@@ -66,6 +65,9 @@ enum AuthBackendType {
#[clap(alias("cplane-v1"))]
ControlPlane,
#[clap(alias("lakebase-v1"))]
Lakebase,
#[clap(alias("link"))]
ConsoleRedirect,
@@ -132,6 +134,9 @@ struct ProxyCliArgs {
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
#[clap(short = 'c', long, alias = "ssl-cert")]
tls_cert: Option<PathBuf>,
/// path to mTLS certs for client postgres connections
#[clap(long)]
mtls_certs: Option<PathBuf>,
/// Allow writing TLS session keys to the given file pointed to by the environment variable `SSLKEYLOGFILE`.
#[clap(long, alias = "allow-ssl-keylogfile")]
allow_tls_keylogfile: bool,
@@ -623,6 +628,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
cert_path,
args.mtls_certs.as_deref(),
args.certs_dir.as_deref(),
args.allow_tls_keylogfile,
)?),
@@ -734,6 +740,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
match &args.auth_backend {
AuthBackendType::ControlPlane => {}
AuthBackendType::Lakebase => {}
#[cfg(any(test, feature = "testing"))]
AuthBackendType::Postgres => {}
#[cfg(any(test, feature = "testing"))]
@@ -828,6 +835,19 @@ fn build_auth_backend(
Ok(Either::Left(config))
}
AuthBackendType::Lakebase => {
let url: url::Url = args.auth_endpoint.parse()?;
let namespace = url.host_str().context("missing hostname as namespace")?;
let port = url.port().unwrap_or(5432);
let api = LakebaseClient::new(namespace.to_owned(), port);
let api = control_plane::client::ControlPlaneClient::LakebaseV1(api);
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
#[cfg(any(test, feature = "testing"))]
AuthBackendType::Postgres => {
let mut url: ApiUrl = args.auth_endpoint.parse()?;

View File

@@ -137,7 +137,7 @@ pub(crate) struct AuthInfo {
/// None for local-proxy, as we use trust-based localhost auth.
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
/// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
auth: Option<Auth>,
pub(crate) auth: Option<Auth>,
server_params: StartupMessageParams,
channel_binding: ChannelBinding,
@@ -151,6 +151,7 @@ pub(crate) struct AuthInfo {
pub struct ConnectInfo {
pub host_addr: Option<IpAddr>,
pub host: Host,
pub server_name: String,
pub port: u16,
pub ssl_mode: SslMode,
}
@@ -176,6 +177,7 @@ impl AuthInfo {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
}
ComputeCredentialKeys::Password(pw) => Some(Auth::Password(pw)),
ComputeCredentialKeys::JwtPayload(_) => None,
},
server_params: StartupMessageParams::default(),
@@ -303,6 +305,7 @@ impl ConnectInfo {
// require for our business.
let port = self.port;
let host = &*self.host;
let server_name = &*self.server_name;
let addrs = match self.host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
@@ -312,7 +315,7 @@ impl ConnectInfo {
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?,
tls::connect_tls(stream, self.ssl_mode, config, server_name, tls).await?,
)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");

View File

@@ -493,6 +493,7 @@ pub(crate) async fn refresh_config_inner(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
None,
false,
)
})

View File

@@ -15,6 +15,7 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::ClientRequestError;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pqproto::CancelKeyData;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::{
ErrorSource, connect_compute, forward_compute_params_to_client, send_client_greeting,
@@ -179,7 +180,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
HandshakeData::Cancel(_, cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
@@ -207,7 +208,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone());
let (node_info, mut auth_info, user_info) = match backend
let (node_info, mut auth_info, _user_info) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
@@ -229,37 +230,36 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.authenticate(ctx, &mut node)
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
send_client_greeting(ctx, &config.greetings, &mut stream);
send_client_greeting(ctx, &config.greetings, &mut stream, node.socket_addr);
let session = cancellation_handler.get_key();
// let session = cancellation_handler.get_key();
let (process_id, secret_key) =
forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream)
.await?;
let (_process_id, _secret_key) =
forward_compute_params_to_client(ctx, None, &mut stream, &mut node.stream).await?;
let stream = stream.flush_and_into_inner().await?;
let hostname = node.hostname.to_string();
// let hostname = node.hostname.to_string();
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
session
.maintain_cancel_key(
session_id,
cancel,
&CancelClosure {
socket_addr: node.socket_addr,
cancel_token: RawCancelToken {
ssl_mode: node.ssl_mode,
process_id,
secret_key,
},
hostname,
user_info,
},
&config.connect_to_compute,
)
.await;
});
// let session_id = ctx.session_id();
let (cancel_on_shutdown, _cancel) = tokio::sync::oneshot::channel();
// tokio::spawn(async move {
// session
// .maintain_cancel_key(
// session_id,
// cancel,
// &CancelClosure {
// socket_addr: node.socket_addr,
// cancel_token: RawCancelToken {
// ssl_mode: node.ssl_mode,
// process_id,
// secret_key,
// },
// hostname,
// user_info,
// },
// &config.connect_to_compute,
// )
// .await;
// });
Ok(Some(ProxyPassthrough {
client: stream,

View File

@@ -31,7 +31,7 @@ use crate::control_plane::{
use crate::metrics::Metrics;
use crate::proxy::retry::CouldRetry;
use crate::rate_limiter::WakeComputeRateLimiter;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::types::{EndpointCacheKey, EndpointId, Host, RoleName};
use crate::{compute, http, scram};
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
@@ -313,7 +313,7 @@ impl NeonControlPlaneClient {
Some(_) => SslMode::Require,
None => SslMode::Disable,
};
let host = match body.server_name {
let host: Host = match body.server_name {
Some(host) => host.into(),
None => host.into(),
};
@@ -321,6 +321,7 @@ impl NeonControlPlaneClient {
let node = NodeInfo {
conn_info: compute::ConnectInfo {
host_addr,
server_name: host.to_string(),
host,
port,
ssl_mode,

View File

@@ -0,0 +1,104 @@
//! Production console backend.
use std::sync::Arc;
use ::http::HeaderName;
use postgres_client::config::SslMode;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::compute::ConnectInfo;
use crate::context::RequestContext;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::messages::{ColdStartInfo, EndpointRateLimitConfig, MetricsAuxInfo};
use crate::control_plane::{
AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, NodeInfo, RoleAccessControl,
};
use crate::intern::{BranchIdInt, ProjectIdInt};
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]
pub struct LakebaseClient {
pub namespace: String,
pub port: u16,
}
impl LakebaseClient {
/// Construct an API object containing the auth parameters.
pub fn new(namespace: String, port: u16) -> Self {
Self { namespace, port }
}
}
impl super::ControlPlaneApi for LakebaseClient {
#[tracing::instrument(skip_all)]
async fn get_role_access_control(
&self,
_ctx: &RequestContext,
_endpoint: &EndpointId,
_role: &RoleName,
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
Ok(RoleAccessControl { secret: None })
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_access_control(
&self,
_ctx: &RequestContext,
_endpoint: &EndpointId,
_role: &RoleName,
) -> Result<EndpointAccessControl, GetAuthInfoError> {
Ok(EndpointAccessControl {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
})
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
_ctx: &RequestContext,
_endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
Err(GetEndpointJwksError::ControlPlane(
ControlPlaneError::Transport(std::io::Error::other("unsupported")),
))
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
let instance_id = user_info.endpoint.as_str();
let namespace = self.namespace.as_str();
let host = format!("{instance_id}.{namespace}.svc.cluster.local").into();
let server_name = format!("{instance_id}.online-tables.dev.databricks.com");
let port = self.port;
Ok(CachedNodeInfo::new_uncached(NodeInfo {
conn_info: ConnectInfo {
host_addr: None,
host,
server_name,
port,
ssl_mode: SslMode::Require,
},
aux: MetricsAuxInfo {
endpoint_id: user_info.endpoint.normalize_intern(),
project_id: ProjectIdInt::from(&ProjectId::from("unknown")),
branch_id: BranchIdInt::from(&BranchId::from("unknown")),
compute_id: user_info.endpoint.as_str().into(),
cold_start_info: ColdStartInfo::WarmCached,
},
}))
}
}

View File

@@ -175,12 +175,14 @@ impl MockControlPlane {
let conn_info = match self.endpoint.host_str() {
None => ConnectInfo {
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
server_name: "localhost".into(),
host: "localhost".into(),
port,
ssl_mode: SslMode::Disable,
},
Some(host) => ConnectInfo {
host_addr: IpAddr::from_str(host).ok(),
server_name: host.into(),
host: host.into(),
port,
ssl_mode: SslMode::Disable,

View File

@@ -1,4 +1,5 @@
pub mod cplane_proxy_v1;
pub mod lakebase_v1;
#[cfg(any(test, feature = "testing"))]
pub mod mock;
@@ -28,6 +29,8 @@ use crate::types::EndpointId;
pub enum ControlPlaneClient {
/// Proxy V1 control plane API
ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
/// Lakebase V1 mocked API.
LakebaseV1(lakebase_v1::LakebaseClient),
/// Local mock control plane.
#[cfg(any(test, feature = "testing"))]
PostgresMock(mock::MockControlPlane),
@@ -46,6 +49,7 @@ impl ControlPlaneApi for ControlPlaneClient {
) -> Result<RoleAccessControl, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
Self::LakebaseV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
#[cfg(test)]
@@ -63,6 +67,7 @@ impl ControlPlaneApi for ControlPlaneClient {
) -> Result<EndpointAccessControl, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
Self::LakebaseV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
#[cfg(test)]
@@ -77,6 +82,7 @@ impl ControlPlaneApi for ControlPlaneClient {
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
match self {
Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
Self::LakebaseV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(test)]
@@ -91,6 +97,7 @@ impl ControlPlaneApi for ControlPlaneClient {
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
match self {
Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
Self::LakebaseV1(api) => api.wake_compute(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
#[cfg(test)]

View File

@@ -43,6 +43,8 @@ pub mod mgmt;
pub(crate) enum AuthSecret {
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
/// Do not authenticate, just take the cleartext password and give it to postgres.
Cleartext,
}
#[derive(Default)]

View File

@@ -75,6 +75,14 @@
)]
// List of temporarily allowed lints to unblock beta/nightly.
#![allow(unknown_lints)]
#![expect(
unused_imports,
dead_code,
reason = "
We are making minimal changes to proxy for lakebase-v2 integration.
I don't want to delete code that will eventually be merged back in.
"
)]
pub mod binary;

View File

@@ -50,7 +50,7 @@ impl ReportableError for HandshakeError {
pub(crate) enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData),
Cancel(Option<String>, CancelKeyData),
}
/// Establish a (most probably, secure) connection with the client.
@@ -234,8 +234,17 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
return Err(HandshakeError::ProtocolViolation);
}
FeStartupPacket::CancelRequest(cancel_key_data) => {
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data));
let server_name = match stream.get_ref() {
Stream::Raw { .. } => None,
Stream::Tls { tls, .. } => tls.get_ref().1.server_name().map(String::from),
};
info!(
session_type = "cancellation",
server_name, "successful handshake"
);
break Ok(HandshakeData::Cancel(server_name, cancel_key_data));
}
}
}

View File

@@ -3,19 +3,23 @@ pub mod handshake;
pub mod inprocess;
pub mod passthrough;
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use futures::FutureExt;
use smol_str::ToSmolStr;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use crate::auth;
use crate::auth::{self, Backend};
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
pub use crate::pglb::copy_bidirectional::ErrorSource;
@@ -266,7 +270,28 @@ pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await??
{
HandshakeData::Startup(client, params) => (client, params),
HandshakeData::Cancel(cancel_key_data) => {
HandshakeData::Cancel(server_name, cancel_key_data) => {
if let Backend::ControlPlane(api, ()) = auth_backend
&& let ControlPlaneClient::LakebaseV1(lakebase) = &**api
{
let pod_suffix = format!(".{}.pod.cluster.local", lakebase.namespace);
let pod_ip = server_name
.as_deref()
.and_then(|server_name| server_name.strip_suffix(&pod_suffix))
.and_then(|pod_ip| IpAddr::from_str(&pod_ip.replace('-', ".")).ok());
if let Some(pod_ip) = pod_ip {
cancellations.spawn(async move {
let stream = TcpStream::connect((pod_ip, lakebase.port)).await?;
crate::pqproto::cancel(stream, cancel_key_data).await?;
anyhow::Ok(())
});
}
return Ok(None);
}
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);

View File

@@ -9,6 +9,7 @@ use bytes::{Buf, BufMut};
use itertools::Itertools;
use rand::distr::{Distribution, StandardUniform};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
@@ -53,6 +54,18 @@ impl fmt::Debug for ProtocolVersion {
}
}
pub async fn cancel(mut s: TcpStream, key: CancelKeyData) -> io::Result<()> {
s.write_all(
StartupHeader {
len: 16_u32.into(),
version: CANCEL_REQUEST_CODE,
}
.as_bytes(),
)
.await?;
s.write_all(key.as_bytes()).await
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;

View File

@@ -64,7 +64,7 @@ pub(crate) async fn connect_to_compute_and_auth(
match res {
Ok(()) => return Ok(node),
Err(e) => {
if attempt < 2
if attempt < 1
&& let Backend::ControlPlane(cplane, user_info) = user_info
&& let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane
&& e.should_retry_wake_compute()

View File

@@ -8,6 +8,7 @@ pub(crate) mod wake_compute;
use std::collections::HashSet;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use futures::TryStreamExt;
@@ -24,6 +25,7 @@ use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tracing::Instrument;
use crate::auth::backend::ComputeCredentialKeys;
use crate::cancellation::{CancelClosure, CancellationHandler};
use crate::compute::{ComputeConnection, PostgresError, RustlsStream};
use crate::config::ProxyConfig;
@@ -41,7 +43,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
_cancellation_handler: Arc<CancellationHandler>,
client: &mut PqStream<Stream<S>>,
mode: &ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -88,6 +90,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let unauthenticated = matches!(creds.keys, ComputeCredentialKeys::Password(_));
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
auth_info.set_startup_params(params, params_compat);
@@ -109,39 +113,45 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
send_client_greeting(ctx, &config.greetings, client);
send_client_greeting(ctx, &config.greetings, client, node.socket_addr);
let auth::Backend::ControlPlane(_, user_info) = backend else {
let auth::Backend::ControlPlane(_, _user_info) = backend else {
unreachable!("ensured above");
};
let session = cancellation_handler.get_key();
// If we have a password, that means we didn't validate the password and convert
// them into scram keys. Therefore we can only announce authentication ok now.
if unauthenticated {
client.write_message(BeMessage::AuthenticationOk);
}
let (process_id, secret_key) =
forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?;
let hostname = node.hostname.to_string();
// let session = cancellation_handler.get_key();
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = oneshot::channel();
tokio::spawn(async move {
session
.maintain_cancel_key(
session_id,
cancel,
&CancelClosure {
socket_addr: node.socket_addr,
cancel_token: RawCancelToken {
ssl_mode: node.ssl_mode,
process_id,
secret_key,
},
hostname,
user_info,
},
&config.connect_to_compute,
)
.await;
});
let (_process_id, _secret_key) =
forward_compute_params_to_client(ctx, None, client, &mut node.stream).await?;
// let hostname = node.hostname.to_string();
// let session_id = ctx.session_id();
let (cancel_on_shutdown, _cancel) = oneshot::channel();
// tokio::spawn(async move {
// session
// .maintain_cancel_key(
// session_id,
// cancel,
// &CancelClosure {
// socket_addr: node.socket_addr,
// cancel_token: RawCancelToken {
// ssl_mode: node.ssl_mode,
// process_id,
// secret_key,
// },
// hostname,
// user_info,
// },
// &config.connect_to_compute,
// )
// .await;
// });
Ok((node, cancel_on_shutdown))
}
@@ -151,6 +161,7 @@ pub(crate) fn send_client_greeting(
ctx: &RequestContext,
greetings: &String,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
socket_addr: SocketAddr,
) {
// Expose session_id to clients if we have a greeting message.
if !greetings.is_empty() {
@@ -158,6 +169,12 @@ pub(crate) fn send_client_greeting(
client.write_message(BeMessage::NoticeResponse(session_msg.as_str()));
}
// needed for RI to know what IP to send cancellation to.
client.write_message(BeMessage::ParameterStatus {
name: "upstream_ip".as_bytes(),
value: socket_addr.ip().to_string().as_bytes(),
});
// Forward recorded latencies for probing requests
if let Some(testodrome_id) = ctx.get_testodrome_id() {
client.write_message(BeMessage::ParameterStatus {
@@ -191,7 +208,7 @@ pub(crate) fn send_client_greeting(
pub(crate) async fn forward_compute_params_to_client(
ctx: &RequestContext,
cancel_key_data: CancelKeyData,
cancel_key_data: Option<CancelKeyData>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
compute: &mut StartupStream<TcpStream, RustlsStream>,
) -> Result<(i32, i32), ClientRequestError> {
@@ -210,9 +227,16 @@ pub(crate) async fn forward_compute_params_to_client(
match msg {
// Send our cancellation key data instead.
Some(Message::BackendKeyData(body)) => {
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
process_id = body.process_id();
secret_key = body.secret_key();
let cancel_key_data = cancel_key_data.unwrap_or_else(|| {
let pid = process_id as u32;
let key = secret_key as u32;
CancelKeyData(((pid as u64) << 32 | (key as u64)).into())
});
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
}
// Forward all postgres connection params to the client.
Some(Message::ParameterStatus(body)) => {

View File

@@ -48,7 +48,7 @@ async fn proxy_mitm(
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);

View File

@@ -1,7 +1,8 @@
//! A group of high-level tests for connection establishing logic and auth.
#![allow(clippy::unimplemented)]
mod mitm;
// disabled as we removed support for channel binding.
// mod mitm;
use std::sync::Arc;
use std::time::Duration;
@@ -181,7 +182,7 @@ async fn dummy_proxy(
) -> anyhow::Result<()> {
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
};
auth.authenticate(&mut stream).await?;
@@ -296,7 +297,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
));
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Require)
.channel_binding(postgres_client::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password(password)
@@ -531,6 +532,7 @@ fn helper_create_uncached_node_info() -> NodeInfo {
NodeInfo {
conn_info: compute::ConnectInfo {
host: "test".into(),
server_name: "test".into(),
port: 5432,
ssl_mode: SslMode::Disable,
host_addr: None,

View File

@@ -800,7 +800,7 @@ async fn handle_rest_inner(
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
Some(payload)
}
ComputeCredentialKeys::AuthKeys(_) => None,
ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::Password(_) => None,
};
// read the role from the jwt claims (and set it to the "anon" role if not present)

View File

@@ -86,10 +86,10 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
/// Read a postgres password message, which will respect the max length requested.
/// This is not cancel safe.
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
// passwords are usually pretty short
// passwords are usually pretty short, but JWTs are quite long.
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: u32 = 512;
const MAX_PASSWORD_LENGTH: u32 = 2048;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
.await
}

View File

@@ -4,8 +4,10 @@ use std::sync::Arc;
use anyhow::{Context, bail};
use itertools::Itertools;
use rustls::RootCertStore;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::WebPkiClientVerifier;
use rustls::sign::CertifiedKey;
use x509_cert::der::{Reader, SliceReader};
@@ -24,12 +26,37 @@ pub struct TlsConfig {
pub fn configure_tls(
key_path: &Path,
cert_path: &Path,
mtls_certs_dir: Option<&Path>,
certs_dir: Option<&Path>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
// add default certificate
let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
let verifier = match mtls_certs_dir {
Some(dir) => {
let mut roots = RootCertStore::empty();
for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.is_file() {
let cert_chain_bytes = std::fs::read(&path).context(format!(
"Failed to read TLS cert file at '{}.'",
path.display()
))?;
for cert in rustls_pemfile::certs(&mut &cert_chain_bytes[..]) {
roots.add(cert?)?;
}
}
}
WebPkiClientVerifier::builder(Arc::new(roots)).build()?
}
None => WebPkiClientVerifier::no_client_auth(),
};
// add extra certificates
if let Some(certs_dir) = certs_dir {
for entry in std::fs::read_dir(certs_dir)? {
@@ -55,7 +82,7 @@ pub fn configure_tls(
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_client_cert_verifier(verifier)
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
@@ -161,7 +188,8 @@ fn process_key_cert(
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
// let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_server_end_point = TlsServerEndPoint::Undefined;
let certificate = SliceReader::new(first_cert)
.context("Failed to parse cerficiate")?

View File

@@ -156,7 +156,7 @@ def test_auth_errors(static_proxy: NeonProxy):
with pytest.raises(psycopg2.Error) as exprinfo:
static_proxy.connect(user="pinocchio")
text = str(exprinfo.value).strip()
assert text.find("password authentication failed for user 'pinocchio'") != -1
assert text.find("password authentication failed for user \"pinocchio\"") != -1
static_proxy.safe_psql(
"create role pinocchio with login password 'magic'",