mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
chore(proxy): vendor a subset of rust-postgres (#9930)
Our rust-postgres fork is getting messy. Mostly because proxy wants more
control over the raw protocol than tokio-postgres provides. As such,
it's diverging more and more. Storage and compute also make use of
rust-postgres, but in more normal usage, thus they don't need our crazy
changes.
Idea:
* proxy maintains their subset
* other teams use a minimal patch set against upstream rust-postgres
Reviewing this code will be difficult. To implement it, I
1. Copied tokio-postgres, postgres-protocol and postgres-types from
00940fcdb5
2. Updated their package names with the `2` suffix to make them compile
in the workspace.
3. Updated proxy to use those packages
4. Copied in the code from tokio-postgres-rustls 0.13 (with some patches
applied https://github.com/jbg/tokio-postgres-rustls/pull/32
https://github.com/jbg/tokio-postgres-rustls/pull/33)
5. Removed as much dead code as I could find in the vendored libraries
6. Updated the tokio-postgres-rustls code to use our existing channel
binding implementation
This commit is contained in:
@@ -13,7 +13,6 @@ use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::parse_endpoint_param;
|
||||
@@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::proxy::neon_option;
|
||||
use crate::types::Host;
|
||||
|
||||
@@ -244,7 +244,6 @@ impl ConnCfg {
|
||||
let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
|
||||
let host = match host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
Host::Unix(_) => continue, // unix sockets are not welcome here
|
||||
};
|
||||
|
||||
match connect_once(host, *port).await {
|
||||
@@ -315,7 +314,7 @@ impl ConnCfg {
|
||||
};
|
||||
let client_config = client_config.with_no_client_auth();
|
||||
|
||||
let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
host,
|
||||
|
||||
@@ -414,6 +414,7 @@ impl RequestContextInner {
|
||||
outcome,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(tx) = self.sender.take() {
|
||||
// If type changes, this error handling needs to be updated.
|
||||
let tx: mpsc::UnboundedSender<RequestData> = tx;
|
||||
|
||||
@@ -88,6 +88,7 @@ pub mod jemalloc;
|
||||
pub mod logging;
|
||||
pub mod metrics;
|
||||
pub mod parse;
|
||||
pub mod postgres_rustls;
|
||||
pub mod protocol2;
|
||||
pub mod proxy;
|
||||
pub mod rate_limiter;
|
||||
|
||||
158
proxy/src/postgres_rustls/mod.rs
Normal file
158
proxy/src/postgres_rustls/mod.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rustls::pki_types::ServerName;
|
||||
use rustls::ClientConfig;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
|
||||
mod private {
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use rustls::pki_types::ServerName;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
use crate::config::TlsServerEndPoint;
|
||||
|
||||
pub struct TlsConnectFuture<S> {
|
||||
inner: tokio_rustls::Connect<S>,
|
||||
}
|
||||
|
||||
impl<S> Future for TlsConnectFuture<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = io::Result<RustlsStream<S>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsConnect(pub RustlsConnectData);
|
||||
|
||||
pub struct RustlsConnectData {
|
||||
pub hostname: ServerName<'static>,
|
||||
pub connector: TlsConnector,
|
||||
}
|
||||
|
||||
impl<S> TlsConnect<S> for RustlsConnect
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Stream = RustlsStream<S>;
|
||||
type Error = io::Error;
|
||||
type Future = TlsConnectFuture<S>;
|
||||
|
||||
fn connect(self, stream: S) -> Self::Future {
|
||||
TlsConnectFuture {
|
||||
inner: self.0.connector.connect(self.0.hostname, stream),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsStream<S>(TlsStream<S>);
|
||||
|
||||
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn channel_binding(&self) -> ChannelBinding {
|
||||
let (_, session) = self.0.get_ref();
|
||||
match session.peer_certificates() {
|
||||
Some([cert, ..]) => TlsServerEndPoint::new(cert)
|
||||
.ok()
|
||||
.and_then(|cb| match cb {
|
||||
TlsServerEndPoint::Sha256(hash) => Some(hash),
|
||||
TlsServerEndPoint::Undefined => None,
|
||||
})
|
||||
.map_or_else(ChannelBinding::none, |hash| {
|
||||
ChannelBinding::tls_server_end_point(hash.to_vec())
|
||||
}),
|
||||
_ => ChannelBinding::none(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for RustlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<tokio::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for RustlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<tokio::io::Result<usize>> {
|
||||
Pin::new(&mut self.0).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<tokio::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<tokio::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A `MakeTlsConnect` implementation using `rustls`.
|
||||
///
|
||||
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
|
||||
#[derive(Clone)]
|
||||
pub struct MakeRustlsConnect {
|
||||
config: Arc<ClientConfig>,
|
||||
}
|
||||
|
||||
impl MakeRustlsConnect {
|
||||
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
|
||||
#[must_use]
|
||||
pub fn new(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
config: Arc::new(config),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Stream = private::RustlsStream<S>;
|
||||
type TlsConnect = private::RustlsConnect;
|
||||
type Error = rustls::pki_types::InvalidDnsNameError;
|
||||
|
||||
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
ServerName::try_from(hostname).map(|dns_name| {
|
||||
private::RustlsConnect(private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: Arc::clone(&self.config).into(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,6 @@ use rustls::pki_types;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::CouldRetry;
|
||||
@@ -29,6 +28,7 @@ use crate::control_plane::{
|
||||
self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache,
|
||||
};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
use crate::{sasl, scram};
|
||||
|
||||
|
||||
@@ -333,7 +333,7 @@ impl PoolingBackend {
|
||||
debug!("setting up backend session state");
|
||||
|
||||
// initiates the auth session
|
||||
if let Err(e) = client.query("select auth.init()", &[]).await {
|
||||
if let Err(e) = client.execute("select auth.init()", &[]).await {
|
||||
discard.discard();
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
@@ -6,9 +6,10 @@ use std::task::{ready, Poll};
|
||||
use futures::future::poll_fn;
|
||||
use futures::Future;
|
||||
use smallvec::SmallVec;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::{AsyncMessage, Socket};
|
||||
use tokio_postgres::AsyncMessage;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
#[cfg(test)]
|
||||
@@ -57,7 +58,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
|
||||
mut connection: tokio_postgres::Connection<TcpStream, NoTlsStream>,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<C> {
|
||||
|
||||
@@ -24,10 +24,11 @@ use p256::ecdsa::{Signature, SigningKey};
|
||||
use parking_lot::RwLock;
|
||||
use serde_json::value::RawValue;
|
||||
use signature::Signer;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tokio_postgres::{AsyncMessage, Socket};
|
||||
use tokio_postgres::AsyncMessage;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
|
||||
@@ -163,7 +164,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
|
||||
mut connection: tokio_postgres::Connection<TcpStream, NoTlsStream>,
|
||||
key: SigningKey,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
@@ -286,11 +287,11 @@ impl ClientInnerCommon<tokio_postgres::Client> {
|
||||
let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
|
||||
|
||||
// initiates the auth session
|
||||
self.inner.simple_query("discard all").await?;
|
||||
self.inner.batch_execute("discard all").await?;
|
||||
self.inner
|
||||
.query(
|
||||
.execute(
|
||||
"select auth.jwt_session_init($1)",
|
||||
&[&token as &(dyn ToSql + Sync)],
|
||||
&[&&*token as &(dyn ToSql + Sync)],
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user