mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-26 17:40:37 +00:00
Compare commits
14 Commits
amasterov/
...
conrad/lak
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f9df93655f | ||
|
|
3778081b7b | ||
|
|
2ddbd3cc80 | ||
|
|
5a293242f8 | ||
|
|
ccea44becd | ||
|
|
2c915e2f3d | ||
|
|
2b22e0b069 | ||
|
|
739ecc6f6d | ||
|
|
725aed694b | ||
|
|
d0e579c026 | ||
|
|
56cc55d24a | ||
|
|
da6419a45a | ||
|
|
e5f5c79eb1 | ||
|
|
314babb0cb |
@@ -45,6 +45,10 @@ pub(super) async fn authenticate(
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
}
|
||||
}
|
||||
AuthSecret::Cleartext => {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
return super::hacks::authenticate_cleartext(ctx, creds, client, secret, config).await;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ComputeCredentials {
|
||||
|
||||
@@ -192,7 +192,8 @@ async fn authenticate(
|
||||
};
|
||||
|
||||
let conn_info = compute::ConnectInfo {
|
||||
host: db_info.host.into(),
|
||||
host: db_info.host.as_ref().into(),
|
||||
server_name: db_info.host.into_string(),
|
||||
port: db_info.port,
|
||||
ssl_mode,
|
||||
host_addr: None,
|
||||
|
||||
@@ -33,6 +33,7 @@ impl LocalBackend {
|
||||
conn_info: ConnectInfo {
|
||||
host_addr: Some(postgres_addr.ip()),
|
||||
host: postgres_addr.ip().to_string().into(),
|
||||
server_name: postgres_addr.ip().to_string(),
|
||||
port: postgres_addr.port(),
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
|
||||
@@ -74,6 +74,11 @@ impl std::fmt::Display for Backend<'_, ()> {
|
||||
.debug_tuple("ControlPlane::ProxyV1")
|
||||
.field(&endpoint.url())
|
||||
.finish(),
|
||||
ControlPlaneClient::LakebaseV1(lb) => fmt
|
||||
.debug_tuple("ControlPlane::LakebaseV1")
|
||||
.field(&lb.namespace)
|
||||
.field(&lb.port)
|
||||
.finish(),
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ControlPlaneClient::PostgresMock(endpoint) => {
|
||||
let url = endpoint.url();
|
||||
@@ -169,6 +174,8 @@ impl ComputeUserInfo {
|
||||
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) enum ComputeCredentialKeys {
|
||||
/// We don't convert passwords into auth keys, we just pass passwords onto postgres.
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
JwtPayload(Vec<u8>),
|
||||
}
|
||||
@@ -239,11 +246,13 @@ async fn auth_quirks(
|
||||
let secret = if let Some(secret) = role_access.secret {
|
||||
secret
|
||||
} else {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
|
||||
// // If we don't have an authentication secret, we mock one to
|
||||
// // prevent malicious probing (possible due to missing protocol steps).
|
||||
// // This mocked secret will never lead to successful authentication.
|
||||
// info!("authentication info not found, mocking it");
|
||||
// AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
|
||||
|
||||
AuthSecret::Cleartext
|
||||
};
|
||||
|
||||
match authenticate_with_secret(
|
||||
|
||||
@@ -62,6 +62,9 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
|
||||
pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
|
||||
let (subdomain, common_name) = sni.split_once('.')?;
|
||||
if subdomain.starts_with("instance-") {
|
||||
return Some(EndpointId::from(subdomain));
|
||||
}
|
||||
if !common_names.contains(common_name) {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
if let sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(_)) = &outcome {
|
||||
self.stream.write_message(BeMessage::AuthenticationOk);
|
||||
}
|
||||
|
||||
@@ -187,5 +187,8 @@ pub(crate) async fn validate_password_and_exchange(
|
||||
postgres_client::config::AuthKeys::ScramSha256(keys),
|
||||
)))
|
||||
}
|
||||
AuthSecret::Cleartext => Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
|
||||
password.to_vec(),
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,9 +5,7 @@ use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
use anyhow::Context;
|
||||
use anyhow::{bail, ensure};
|
||||
use anyhow::{Context, bail, ensure};
|
||||
use arc_swap::ArcSwapOption;
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
use camino::Utf8PathBuf;
|
||||
@@ -39,6 +37,7 @@ use crate::config::{
|
||||
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
|
||||
};
|
||||
use crate::context::parquet::ParquetUploadArgs;
|
||||
use crate::control_plane::client::lakebase_v1::LakebaseClient;
|
||||
use crate::http::health_server::AppMetrics;
|
||||
use crate::metrics::{Metrics, ServiceInfo};
|
||||
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
|
||||
@@ -66,6 +65,9 @@ enum AuthBackendType {
|
||||
#[clap(alias("cplane-v1"))]
|
||||
ControlPlane,
|
||||
|
||||
#[clap(alias("lakebase-v1"))]
|
||||
Lakebase,
|
||||
|
||||
#[clap(alias("link"))]
|
||||
ConsoleRedirect,
|
||||
|
||||
@@ -132,6 +134,9 @@ struct ProxyCliArgs {
|
||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
#[clap(short = 'c', long, alias = "ssl-cert")]
|
||||
tls_cert: Option<PathBuf>,
|
||||
/// path to mTLS certs for client postgres connections
|
||||
#[clap(long)]
|
||||
mtls_certs: Option<PathBuf>,
|
||||
/// Allow writing TLS session keys to the given file pointed to by the environment variable `SSLKEYLOGFILE`.
|
||||
#[clap(long, alias = "allow-ssl-keylogfile")]
|
||||
allow_tls_keylogfile: bool,
|
||||
@@ -623,6 +628,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
key_path,
|
||||
cert_path,
|
||||
args.mtls_certs.as_deref(),
|
||||
args.certs_dir.as_deref(),
|
||||
args.allow_tls_keylogfile,
|
||||
)?),
|
||||
@@ -734,6 +740,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
|
||||
match &args.auth_backend {
|
||||
AuthBackendType::ControlPlane => {}
|
||||
AuthBackendType::Lakebase => {}
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthBackendType::Postgres => {}
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
@@ -828,6 +835,19 @@ fn build_auth_backend(
|
||||
Ok(Either::Left(config))
|
||||
}
|
||||
|
||||
AuthBackendType::Lakebase => {
|
||||
let url: url::Url = args.auth_endpoint.parse()?;
|
||||
let namespace = url.host_str().context("missing hostname as namespace")?;
|
||||
let port = url.port().unwrap_or(5432);
|
||||
|
||||
let api = LakebaseClient::new(namespace.to_owned(), port);
|
||||
let api = control_plane::client::ControlPlaneClient::LakebaseV1(api);
|
||||
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
|
||||
let config = Box::leak(Box::new(auth_backend));
|
||||
|
||||
Ok(Either::Left(config))
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthBackendType::Postgres => {
|
||||
let mut url: ApiUrl = args.auth_endpoint.parse()?;
|
||||
|
||||
@@ -137,7 +137,7 @@ pub(crate) struct AuthInfo {
|
||||
/// None for local-proxy, as we use trust-based localhost auth.
|
||||
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
|
||||
/// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
|
||||
auth: Option<Auth>,
|
||||
pub(crate) auth: Option<Auth>,
|
||||
server_params: StartupMessageParams,
|
||||
|
||||
channel_binding: ChannelBinding,
|
||||
@@ -151,6 +151,7 @@ pub(crate) struct AuthInfo {
|
||||
pub struct ConnectInfo {
|
||||
pub host_addr: Option<IpAddr>,
|
||||
pub host: Host,
|
||||
pub server_name: String,
|
||||
pub port: u16,
|
||||
pub ssl_mode: SslMode,
|
||||
}
|
||||
@@ -176,6 +177,7 @@ impl AuthInfo {
|
||||
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
|
||||
Some(Auth::Scram(Box::new(auth_keys)))
|
||||
}
|
||||
ComputeCredentialKeys::Password(pw) => Some(Auth::Password(pw)),
|
||||
ComputeCredentialKeys::JwtPayload(_) => None,
|
||||
},
|
||||
server_params: StartupMessageParams::default(),
|
||||
@@ -303,6 +305,7 @@ impl ConnectInfo {
|
||||
// require for our business.
|
||||
let port = self.port;
|
||||
let host = &*self.host;
|
||||
let server_name = &*self.server_name;
|
||||
|
||||
let addrs = match self.host_addr {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
@@ -312,7 +315,7 @@ impl ConnectInfo {
|
||||
match connect_once(&*addrs).await {
|
||||
Ok((sockaddr, stream)) => Ok((
|
||||
sockaddr,
|
||||
tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?,
|
||||
tls::connect_tls(stream, self.ssl_mode, config, server_name, tls).await?,
|
||||
)),
|
||||
Err(err) => {
|
||||
warn!("couldn't connect to compute node at {host}:{port}: {err}");
|
||||
|
||||
@@ -493,6 +493,7 @@ pub(crate) async fn refresh_config_inner(
|
||||
tls_config.key_path.as_ref(),
|
||||
tls_config.cert_path.as_ref(),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
})
|
||||
|
||||
@@ -15,6 +15,7 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::pglb::ClientRequestError;
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::pqproto::CancelKeyData;
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::{
|
||||
ErrorSource, connect_compute, forward_compute_params_to_client, send_client_greeting,
|
||||
@@ -179,7 +180,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
HandshakeData::Cancel(_, cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
@@ -207,7 +208,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, mut auth_info, user_info) = match backend
|
||||
let (node_info, mut auth_info, _user_info) = match backend
|
||||
.authenticate(ctx, &config.authentication_config, &mut stream)
|
||||
.await
|
||||
{
|
||||
@@ -229,37 +230,36 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.authenticate(ctx, &mut node)
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
send_client_greeting(ctx, &config.greetings, &mut stream);
|
||||
send_client_greeting(ctx, &config.greetings, &mut stream, node.socket_addr);
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
// let session = cancellation_handler.get_key();
|
||||
|
||||
let (process_id, secret_key) =
|
||||
forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream)
|
||||
.await?;
|
||||
let (_process_id, _secret_key) =
|
||||
forward_compute_params_to_client(ctx, None, &mut stream, &mut node.stream).await?;
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
let hostname = node.hostname.to_string();
|
||||
// let hostname = node.hostname.to_string();
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
cancel,
|
||||
&CancelClosure {
|
||||
socket_addr: node.socket_addr,
|
||||
cancel_token: RawCancelToken {
|
||||
ssl_mode: node.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
hostname,
|
||||
user_info,
|
||||
},
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
// let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, _cancel) = tokio::sync::oneshot::channel();
|
||||
// tokio::spawn(async move {
|
||||
// session
|
||||
// .maintain_cancel_key(
|
||||
// session_id,
|
||||
// cancel,
|
||||
// &CancelClosure {
|
||||
// socket_addr: node.socket_addr,
|
||||
// cancel_token: RawCancelToken {
|
||||
// ssl_mode: node.ssl_mode,
|
||||
// process_id,
|
||||
// secret_key,
|
||||
// },
|
||||
// hostname,
|
||||
// user_info,
|
||||
// },
|
||||
// &config.connect_to_compute,
|
||||
// )
|
||||
// .await;
|
||||
// });
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
|
||||
@@ -31,7 +31,7 @@ use crate::control_plane::{
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
use crate::rate_limiter::WakeComputeRateLimiter;
|
||||
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
|
||||
use crate::types::{EndpointCacheKey, EndpointId, Host, RoleName};
|
||||
use crate::{compute, http, scram};
|
||||
|
||||
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
@@ -313,7 +313,7 @@ impl NeonControlPlaneClient {
|
||||
Some(_) => SslMode::Require,
|
||||
None => SslMode::Disable,
|
||||
};
|
||||
let host = match body.server_name {
|
||||
let host: Host = match body.server_name {
|
||||
Some(host) => host.into(),
|
||||
None => host.into(),
|
||||
};
|
||||
@@ -321,6 +321,7 @@ impl NeonControlPlaneClient {
|
||||
let node = NodeInfo {
|
||||
conn_info: compute::ConnectInfo {
|
||||
host_addr,
|
||||
server_name: host.to_string(),
|
||||
host,
|
||||
port,
|
||||
ssl_mode,
|
||||
|
||||
104
proxy/src/control_plane/client/lakebase_v1.rs
Normal file
104
proxy/src/control_plane/client/lakebase_v1.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
//! Production console backend.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use ::http::HeaderName;
|
||||
use postgres_client::config::SslMode;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::compute::ConnectInfo;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
};
|
||||
use crate::control_plane::messages::{ColdStartInfo, EndpointRateLimitConfig, MetricsAuxInfo};
|
||||
use crate::control_plane::{
|
||||
AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, NodeInfo, RoleAccessControl,
|
||||
};
|
||||
use crate::intern::{BranchIdInt, ProjectIdInt};
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
|
||||
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LakebaseClient {
|
||||
pub namespace: String,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl LakebaseClient {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub fn new(namespace: String, port: u16) -> Self {
|
||||
Self { namespace, port }
|
||||
}
|
||||
}
|
||||
|
||||
impl super::ControlPlaneApi for LakebaseClient {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_access_control(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
_endpoint: &EndpointId,
|
||||
_role: &RoleName,
|
||||
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
|
||||
Ok(RoleAccessControl { secret: None })
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_endpoint_access_control(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
_endpoint: &EndpointId,
|
||||
_role: &RoleName,
|
||||
) -> Result<EndpointAccessControl, GetAuthInfoError> {
|
||||
Ok(EndpointAccessControl {
|
||||
allowed_ips: Arc::new(vec![]),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
_endpoint: &EndpointId,
|
||||
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
|
||||
Err(GetEndpointJwksError::ControlPlane(
|
||||
ControlPlaneError::Transport(std::io::Error::other("unsupported")),
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let instance_id = user_info.endpoint.as_str();
|
||||
let namespace = self.namespace.as_str();
|
||||
|
||||
let host = format!("{instance_id}.{namespace}.svc.cluster.local").into();
|
||||
let server_name = format!("{instance_id}.online-tables.dev.databricks.com");
|
||||
let port = self.port;
|
||||
|
||||
Ok(CachedNodeInfo::new_uncached(NodeInfo {
|
||||
conn_info: ConnectInfo {
|
||||
host_addr: None,
|
||||
host,
|
||||
server_name,
|
||||
port,
|
||||
ssl_mode: SslMode::Require,
|
||||
},
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: user_info.endpoint.normalize_intern(),
|
||||
project_id: ProjectIdInt::from(&ProjectId::from("unknown")),
|
||||
branch_id: BranchIdInt::from(&BranchId::from("unknown")),
|
||||
compute_id: user_info.endpoint.as_str().into(),
|
||||
cold_start_info: ColdStartInfo::WarmCached,
|
||||
},
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -175,12 +175,14 @@ impl MockControlPlane {
|
||||
let conn_info = match self.endpoint.host_str() {
|
||||
None => ConnectInfo {
|
||||
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||
server_name: "localhost".into(),
|
||||
host: "localhost".into(),
|
||||
port,
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
Some(host) => ConnectInfo {
|
||||
host_addr: IpAddr::from_str(host).ok(),
|
||||
server_name: host.into(),
|
||||
host: host.into(),
|
||||
port,
|
||||
ssl_mode: SslMode::Disable,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod cplane_proxy_v1;
|
||||
pub mod lakebase_v1;
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
pub mod mock;
|
||||
|
||||
@@ -28,6 +29,8 @@ use crate::types::EndpointId;
|
||||
pub enum ControlPlaneClient {
|
||||
/// Proxy V1 control plane API
|
||||
ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
|
||||
/// Lakebase V1 mocked API.
|
||||
LakebaseV1(lakebase_v1::LakebaseClient),
|
||||
/// Local mock control plane.
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
PostgresMock(mock::MockControlPlane),
|
||||
@@ -46,6 +49,7 @@ impl ControlPlaneApi for ControlPlaneClient {
|
||||
) -> Result<RoleAccessControl, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
|
||||
Self::LakebaseV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
|
||||
#[cfg(test)]
|
||||
@@ -63,6 +67,7 @@ impl ControlPlaneApi for ControlPlaneClient {
|
||||
) -> Result<EndpointAccessControl, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
|
||||
Self::LakebaseV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
|
||||
#[cfg(test)]
|
||||
@@ -77,6 +82,7 @@ impl ControlPlaneApi for ControlPlaneClient {
|
||||
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
Self::LakebaseV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(test)]
|
||||
@@ -91,6 +97,7 @@ impl ControlPlaneApi for ControlPlaneClient {
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
|
||||
Self::LakebaseV1(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -43,6 +43,8 @@ pub mod mgmt;
|
||||
pub(crate) enum AuthSecret {
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
/// Do not authenticate, just take the cleartext password and give it to postgres.
|
||||
Cleartext,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
||||
@@ -75,6 +75,14 @@
|
||||
)]
|
||||
// List of temporarily allowed lints to unblock beta/nightly.
|
||||
#![allow(unknown_lints)]
|
||||
#![expect(
|
||||
unused_imports,
|
||||
dead_code,
|
||||
reason = "
|
||||
We are making minimal changes to proxy for lakebase-v2 integration.
|
||||
I don't want to delete code that will eventually be merged back in.
|
||||
"
|
||||
)]
|
||||
|
||||
pub mod binary;
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ impl ReportableError for HandshakeError {
|
||||
|
||||
pub(crate) enum HandshakeData<S> {
|
||||
Startup(PqStream<Stream<S>>, StartupMessageParams),
|
||||
Cancel(CancelKeyData),
|
||||
Cancel(Option<String>, CancelKeyData),
|
||||
}
|
||||
|
||||
/// Establish a (most probably, secure) connection with the client.
|
||||
@@ -234,8 +234,17 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
FeStartupPacket::CancelRequest(cancel_key_data) => {
|
||||
info!(session_type = "cancellation", "successful handshake");
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data));
|
||||
let server_name = match stream.get_ref() {
|
||||
Stream::Raw { .. } => None,
|
||||
Stream::Tls { tls, .. } => tls.get_ref().1.server_name().map(String::from),
|
||||
};
|
||||
|
||||
info!(
|
||||
session_type = "cancellation",
|
||||
server_name, "successful handshake"
|
||||
);
|
||||
|
||||
break Ok(HandshakeData::Cancel(server_name, cancel_key_data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,19 +3,23 @@ pub mod handshake;
|
||||
pub mod inprocess;
|
||||
pub mod passthrough;
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use smol_str::ToSmolStr;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use crate::auth;
|
||||
use crate::auth::{self, Backend};
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ControlPlaneClient;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
pub use crate::pglb::copy_bidirectional::ErrorSource;
|
||||
@@ -266,7 +270,28 @@ pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(client, params) => (client, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
HandshakeData::Cancel(server_name, cancel_key_data) => {
|
||||
if let Backend::ControlPlane(api, ()) = auth_backend
|
||||
&& let ControlPlaneClient::LakebaseV1(lakebase) = &**api
|
||||
{
|
||||
let pod_suffix = format!(".{}.pod.cluster.local", lakebase.namespace);
|
||||
|
||||
let pod_ip = server_name
|
||||
.as_deref()
|
||||
.and_then(|server_name| server_name.strip_suffix(&pod_suffix))
|
||||
.and_then(|pod_ip| IpAddr::from_str(&pod_ip.replace('-', ".")).ok());
|
||||
|
||||
if let Some(pod_ip) = pod_ip {
|
||||
cancellations.spawn(async move {
|
||||
let stream = TcpStream::connect((pod_ip, lakebase.port)).await?;
|
||||
crate::pqproto::cancel(stream, cancel_key_data).await?;
|
||||
anyhow::Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
|
||||
@@ -9,6 +9,7 @@ use bytes::{Buf, BufMut};
|
||||
use itertools::Itertools;
|
||||
use rand::distr::{Distribution, StandardUniform};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
|
||||
|
||||
pub type ErrorCode = [u8; 5];
|
||||
@@ -53,6 +54,18 @@ impl fmt::Debug for ProtocolVersion {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cancel(mut s: TcpStream, key: CancelKeyData) -> io::Result<()> {
|
||||
s.write_all(
|
||||
StartupHeader {
|
||||
len: 16_u32.into(),
|
||||
version: CANCEL_REQUEST_CODE,
|
||||
}
|
||||
.as_bytes(),
|
||||
)
|
||||
.await?;
|
||||
s.write_all(key.as_bytes()).await
|
||||
}
|
||||
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
|
||||
|
||||
@@ -64,7 +64,7 @@ pub(crate) async fn connect_to_compute_and_auth(
|
||||
match res {
|
||||
Ok(()) => return Ok(node),
|
||||
Err(e) => {
|
||||
if attempt < 2
|
||||
if attempt < 1
|
||||
&& let Backend::ControlPlane(cplane, user_info) = user_info
|
||||
&& let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane
|
||||
&& e.should_retry_wake_compute()
|
||||
|
||||
@@ -8,6 +8,7 @@ pub(crate) mod wake_compute;
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::TryStreamExt;
|
||||
@@ -24,6 +25,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::cancellation::{CancelClosure, CancellationHandler};
|
||||
use crate::compute::{ComputeConnection, PostgresError, RustlsStream};
|
||||
use crate::config::ProxyConfig;
|
||||
@@ -41,7 +43,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
_cancellation_handler: Arc<CancellationHandler>,
|
||||
client: &mut PqStream<Stream<S>>,
|
||||
mode: &ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -88,6 +90,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
|
||||
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
|
||||
};
|
||||
let unauthenticated = matches!(creds.keys, ComputeCredentialKeys::Password(_));
|
||||
|
||||
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
|
||||
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
|
||||
auth_info.set_startup_params(params, params_compat);
|
||||
@@ -109,39 +113,45 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
send_client_greeting(ctx, &config.greetings, client);
|
||||
send_client_greeting(ctx, &config.greetings, client, node.socket_addr);
|
||||
|
||||
let auth::Backend::ControlPlane(_, user_info) = backend else {
|
||||
let auth::Backend::ControlPlane(_, _user_info) = backend else {
|
||||
unreachable!("ensured above");
|
||||
};
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
// If we have a password, that means we didn't validate the password and convert
|
||||
// them into scram keys. Therefore we can only announce authentication ok now.
|
||||
if unauthenticated {
|
||||
client.write_message(BeMessage::AuthenticationOk);
|
||||
}
|
||||
|
||||
let (process_id, secret_key) =
|
||||
forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?;
|
||||
let hostname = node.hostname.to_string();
|
||||
// let session = cancellation_handler.get_key();
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
cancel,
|
||||
&CancelClosure {
|
||||
socket_addr: node.socket_addr,
|
||||
cancel_token: RawCancelToken {
|
||||
ssl_mode: node.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
hostname,
|
||||
user_info,
|
||||
},
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
let (_process_id, _secret_key) =
|
||||
forward_compute_params_to_client(ctx, None, client, &mut node.stream).await?;
|
||||
// let hostname = node.hostname.to_string();
|
||||
|
||||
// let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, _cancel) = oneshot::channel();
|
||||
// tokio::spawn(async move {
|
||||
// session
|
||||
// .maintain_cancel_key(
|
||||
// session_id,
|
||||
// cancel,
|
||||
// &CancelClosure {
|
||||
// socket_addr: node.socket_addr,
|
||||
// cancel_token: RawCancelToken {
|
||||
// ssl_mode: node.ssl_mode,
|
||||
// process_id,
|
||||
// secret_key,
|
||||
// },
|
||||
// hostname,
|
||||
// user_info,
|
||||
// },
|
||||
// &config.connect_to_compute,
|
||||
// )
|
||||
// .await;
|
||||
// });
|
||||
|
||||
Ok((node, cancel_on_shutdown))
|
||||
}
|
||||
@@ -151,6 +161,7 @@ pub(crate) fn send_client_greeting(
|
||||
ctx: &RequestContext,
|
||||
greetings: &String,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
socket_addr: SocketAddr,
|
||||
) {
|
||||
// Expose session_id to clients if we have a greeting message.
|
||||
if !greetings.is_empty() {
|
||||
@@ -158,6 +169,12 @@ pub(crate) fn send_client_greeting(
|
||||
client.write_message(BeMessage::NoticeResponse(session_msg.as_str()));
|
||||
}
|
||||
|
||||
// needed for RI to know what IP to send cancellation to.
|
||||
client.write_message(BeMessage::ParameterStatus {
|
||||
name: "upstream_ip".as_bytes(),
|
||||
value: socket_addr.ip().to_string().as_bytes(),
|
||||
});
|
||||
|
||||
// Forward recorded latencies for probing requests
|
||||
if let Some(testodrome_id) = ctx.get_testodrome_id() {
|
||||
client.write_message(BeMessage::ParameterStatus {
|
||||
@@ -191,7 +208,7 @@ pub(crate) fn send_client_greeting(
|
||||
|
||||
pub(crate) async fn forward_compute_params_to_client(
|
||||
ctx: &RequestContext,
|
||||
cancel_key_data: CancelKeyData,
|
||||
cancel_key_data: Option<CancelKeyData>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
compute: &mut StartupStream<TcpStream, RustlsStream>,
|
||||
) -> Result<(i32, i32), ClientRequestError> {
|
||||
@@ -210,9 +227,16 @@ pub(crate) async fn forward_compute_params_to_client(
|
||||
match msg {
|
||||
// Send our cancellation key data instead.
|
||||
Some(Message::BackendKeyData(body)) => {
|
||||
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
process_id = body.process_id();
|
||||
secret_key = body.secret_key();
|
||||
|
||||
let cancel_key_data = cancel_key_data.unwrap_or_else(|| {
|
||||
let pid = process_id as u32;
|
||||
let key = secret_key as u32;
|
||||
CancelKeyData(((pid as u64) << 32 | (key as u64)).into())
|
||||
});
|
||||
|
||||
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
}
|
||||
// Forward all postgres connection params to the client.
|
||||
Some(Message::ParameterStatus(body)) => {
|
||||
|
||||
@@ -48,7 +48,7 @@ async fn proxy_mitm(
|
||||
.unwrap()
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
|
||||
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
|
||||
};
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
//! A group of high-level tests for connection establishing logic and auth.
|
||||
#![allow(clippy::unimplemented)]
|
||||
|
||||
mod mitm;
|
||||
// disabled as we removed support for channel binding.
|
||||
// mod mitm;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -181,7 +182,7 @@ async fn dummy_proxy(
|
||||
) -> anyhow::Result<()> {
|
||||
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
|
||||
};
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
@@ -296,7 +297,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
));
|
||||
|
||||
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
.channel_binding(postgres_client::config::ChannelBinding::Require)
|
||||
.channel_binding(postgres_client::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
@@ -531,6 +532,7 @@ fn helper_create_uncached_node_info() -> NodeInfo {
|
||||
NodeInfo {
|
||||
conn_info: compute::ConnectInfo {
|
||||
host: "test".into(),
|
||||
server_name: "test".into(),
|
||||
port: 5432,
|
||||
ssl_mode: SslMode::Disable,
|
||||
host_addr: None,
|
||||
|
||||
@@ -800,7 +800,7 @@ async fn handle_rest_inner(
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
Some(payload)
|
||||
}
|
||||
ComputeCredentialKeys::AuthKeys(_) => None,
|
||||
ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::Password(_) => None,
|
||||
};
|
||||
|
||||
// read the role from the jwt claims (and set it to the "anon" role if not present)
|
||||
|
||||
@@ -86,10 +86,10 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Read a postgres password message, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
|
||||
// passwords are usually pretty short
|
||||
// passwords are usually pretty short, but JWTs are quite long.
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: u32 = 512;
|
||||
const MAX_PASSWORD_LENGTH: u32 = 2048;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -4,8 +4,10 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, bail};
|
||||
use itertools::Itertools;
|
||||
use rustls::RootCertStore;
|
||||
use rustls::crypto::ring::{self, sign};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use rustls::server::WebPkiClientVerifier;
|
||||
use rustls::sign::CertifiedKey;
|
||||
use x509_cert::der::{Reader, SliceReader};
|
||||
|
||||
@@ -24,12 +26,37 @@ pub struct TlsConfig {
|
||||
pub fn configure_tls(
|
||||
key_path: &Path,
|
||||
cert_path: &Path,
|
||||
mtls_certs_dir: Option<&Path>,
|
||||
certs_dir: Option<&Path>,
|
||||
allow_tls_keylogfile: bool,
|
||||
) -> anyhow::Result<TlsConfig> {
|
||||
// add default certificate
|
||||
let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
|
||||
|
||||
let verifier = match mtls_certs_dir {
|
||||
Some(dir) => {
|
||||
let mut roots = RootCertStore::empty();
|
||||
|
||||
for entry in std::fs::read_dir(dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.is_file() {
|
||||
let cert_chain_bytes = std::fs::read(&path).context(format!(
|
||||
"Failed to read TLS cert file at '{}.'",
|
||||
path.display()
|
||||
))?;
|
||||
|
||||
for cert in rustls_pemfile::certs(&mut &cert_chain_bytes[..]) {
|
||||
roots.add(cert?)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
WebPkiClientVerifier::builder(Arc::new(roots)).build()?
|
||||
}
|
||||
None => WebPkiClientVerifier::no_client_auth(),
|
||||
};
|
||||
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
for entry in std::fs::read_dir(certs_dir)? {
|
||||
@@ -55,7 +82,7 @@ pub fn configure_tls(
|
||||
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
.context("ring should support TLS1.2 and TLS1.3")?
|
||||
.with_no_client_auth()
|
||||
.with_client_cert_verifier(verifier)
|
||||
.with_cert_resolver(cert_resolver.clone());
|
||||
|
||||
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
|
||||
@@ -161,7 +188,8 @@ fn process_key_cert(
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
// let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
let tls_server_end_point = TlsServerEndPoint::Undefined;
|
||||
|
||||
let certificate = SliceReader::new(first_cert)
|
||||
.context("Failed to parse cerficiate")?
|
||||
|
||||
@@ -156,7 +156,7 @@ def test_auth_errors(static_proxy: NeonProxy):
|
||||
with pytest.raises(psycopg2.Error) as exprinfo:
|
||||
static_proxy.connect(user="pinocchio")
|
||||
text = str(exprinfo.value).strip()
|
||||
assert text.find("password authentication failed for user 'pinocchio'") != -1
|
||||
assert text.find("password authentication failed for user \"pinocchio\"") != -1
|
||||
|
||||
static_proxy.safe_psql(
|
||||
"create role pinocchio with login password 'magic'",
|
||||
|
||||
Reference in New Issue
Block a user