mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 01:12:56 +00:00
fix
This commit is contained in:
@@ -4,6 +4,7 @@ pub mod jwt;
|
||||
mod link;
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -23,6 +24,7 @@ use crate::context::RequestMonitoring;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::handshake::KtlsAsyncReadReady;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
|
||||
use crate::stream::Stream;
|
||||
@@ -274,7 +276,9 @@ async fn auth_quirks(
|
||||
ctx: &RequestMonitoring,
|
||||
api: &impl console::Api,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqStream<
|
||||
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
|
||||
>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -358,7 +362,9 @@ async fn authenticate_with_secret(
|
||||
ctx: &RequestMonitoring,
|
||||
secret: AuthSecret,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqStream<
|
||||
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
|
||||
>,
|
||||
unauthenticated_password: Option<Vec<u8>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -417,7 +423,9 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
ctx: &RequestMonitoring,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqStream<
|
||||
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
|
||||
>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -542,7 +550,7 @@ mod tests {
|
||||
CachedNodeInfo,
|
||||
},
|
||||
context::RequestMonitoring,
|
||||
proxy::NeonOptions,
|
||||
proxy::{tests::DummyClient, NeonOptions},
|
||||
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
|
||||
scram::{threadpool::ThreadPool, ServerSecret},
|
||||
stream::{PqStream, Stream},
|
||||
@@ -650,7 +658,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
@@ -727,7 +735,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
@@ -779,7 +787,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
use super::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::{
|
||||
auth::{self, backend::ComputeCredentialKeys, AuthFlow},
|
||||
@@ -5,6 +7,7 @@ use crate::{
|
||||
config::AuthenticationConfig,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
proxy::handshake::KtlsAsyncReadReady,
|
||||
sasl,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
@@ -14,7 +17,9 @@ use tracing::{info, warn};
|
||||
pub(super) async fn authenticate(
|
||||
ctx: &RequestMonitoring,
|
||||
creds: ComputeUserInfo,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut PqStream<
|
||||
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
|
||||
>,
|
||||
config: &'static AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
use super::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
|
||||
};
|
||||
@@ -7,6 +9,7 @@ use crate::{
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
intern::EndpointIdInt,
|
||||
proxy::handshake::KtlsAsyncReadReady,
|
||||
sasl,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
@@ -20,7 +23,9 @@ use tracing::{info, warn};
|
||||
pub async fn authenticate_cleartext(
|
||||
ctx: &RequestMonitoring,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqStream<
|
||||
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
|
||||
>,
|
||||
secret: AuthSecret,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
@@ -62,7 +67,9 @@ pub async fn authenticate_cleartext(
|
||||
pub async fn password_hack_no_authentication(
|
||||
ctx: &RequestMonitoring,
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqStream<
|
||||
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
|
||||
>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
@@ -6,13 +6,14 @@ use crate::{
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
intern::EndpointIdInt,
|
||||
proxy::handshake::KtlsAsyncReadReady,
|
||||
sasl,
|
||||
scram::{self, threadpool::ThreadPool},
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use std::{io, sync::Arc};
|
||||
use std::{io, os::fd::AsRawFd, sync::Arc};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
@@ -70,7 +71,7 @@ impl AuthMethod for CleartextPassword {
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, S, State> {
|
||||
pub struct AuthFlow<'a, S: AsRawFd, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
@@ -79,7 +80,7 @@ pub struct AuthFlow<'a, S, State> {
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'a, S, Begin> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
@@ -105,7 +106,9 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>
|
||||
AuthFlow<'_, S, PasswordHack>
|
||||
{
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
@@ -124,7 +127,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>
|
||||
AuthFlow<'_, S, CleartextPassword>
|
||||
{
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
@@ -149,7 +154,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
|
||||
let Scram(secret, ctx) = self.state;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::os::fd::AsRawFd;
|
||||
/// A stand-alone program that routes connections, e.g. from
|
||||
/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
|
||||
///
|
||||
@@ -9,6 +10,7 @@ use futures::future::Either;
|
||||
use itertools::Itertools;
|
||||
use proxy::context::RequestMonitoring;
|
||||
use proxy::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use proxy::proxy::handshake::KtlsAsyncReadReady;
|
||||
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
|
||||
use rustls::pki_types::PrivateKeyDer;
|
||||
use tokio::net::TcpListener;
|
||||
@@ -197,7 +199,7 @@ async fn task_main(
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>(
|
||||
ctx: &RequestMonitoring,
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
@@ -248,7 +250,7 @@ async fn handle_client(
|
||||
ctx: RequestMonitoring,
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
|
||||
|
||||
|
||||
@@ -285,7 +285,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
|
||||
let args = ProxyCliArgs::parse();
|
||||
let config = build_config(&args)?;
|
||||
let config = build_config(&args).await?;
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
info!("Using region: {}", args.aws_region);
|
||||
@@ -529,16 +529,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
async fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||
Metrics::install(thread_pool.metrics.clone());
|
||||
|
||||
let tls_config = match (&args.tls_key, &args.tls_cert) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
key_path,
|
||||
cert_path,
|
||||
args.certs_dir.as_ref(),
|
||||
)?),
|
||||
(Some(key_path), Some(cert_path)) => {
|
||||
Some(config::configure_tls(key_path, cert_path, args.certs_dir.as_ref()).await?)
|
||||
}
|
||||
(None, None) => None,
|
||||
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
@@ -76,7 +76,7 @@ impl TlsConfig {
|
||||
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
|
||||
|
||||
/// Configure TLS for the main endpoint.
|
||||
pub fn configure_tls(
|
||||
pub async fn configure_tls(
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
certs_dir: Option<&String>,
|
||||
@@ -114,8 +114,9 @@ pub fn configure_tls(
|
||||
#[cfg(target_os = "linux")]
|
||||
let provider = {
|
||||
let mut provider = provider;
|
||||
let compat = ktls::CompatibleCiphers::new()?;
|
||||
provider.cipher_suites.retain(|s| compat.is_supported(s));
|
||||
let compat = ktls::CompatibleCiphers::new().await?;
|
||||
provider.cipher_suites.retain(|s| compat.is_compatible(*s));
|
||||
provider
|
||||
};
|
||||
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
|
||||
@@ -28,7 +28,7 @@ impl<S: AsRawFd> AsRawFd for ChainRW<S> {
|
||||
}
|
||||
|
||||
#[cfg(all(target_os = "linux", not(test)))]
|
||||
impl<S: ktls::AsyncReadReady> AsRawFd for ChainRW<S> {
|
||||
impl<S: ktls::AsyncReadReady> ktls::AsyncReadReady for ChainRW<S> {
|
||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
if self.buf.is_empty() {
|
||||
self.inner.poll_read_ready(cx)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
pub mod tests;
|
||||
|
||||
pub mod connect_compute;
|
||||
mod copy_bidirectional;
|
||||
|
||||
@@ -50,6 +50,8 @@ impl ReportableError for HandshakeError {
|
||||
match self {
|
||||
HandshakeError::EarlyData => crate::error::ErrorKind::User,
|
||||
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
|
||||
#[cfg(all(target_os = "linux", not(test)))]
|
||||
HandshakeError::KtlsUpgradeError(_) => crate::error::ErrorKind::Service,
|
||||
// This error should not happen, but will if we have no default certificate and
|
||||
// the client sends no SNI extension.
|
||||
// If they provide SNI then we can be sure there is a certificate that matches.
|
||||
@@ -64,7 +66,7 @@ impl ReportableError for HandshakeError {
|
||||
}
|
||||
}
|
||||
|
||||
pub enum HandshakeData<S> {
|
||||
pub enum HandshakeData<S: AsRawFd> {
|
||||
Startup(
|
||||
PqStream<Stream<S>>,
|
||||
Option<EndpointId>,
|
||||
@@ -201,7 +203,7 @@ where
|
||||
#[cfg(any(not(target_os = "linux"), test))]
|
||||
tls: Box::pin(tls_stream),
|
||||
#[cfg(all(target_os = "linux", not(test)))]
|
||||
tls: Box::pin(ktls::config_ktls_server(tls_stream)?),
|
||||
tls: ktls::config_ktls_server(tls_stream).await?,
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
use crate::{
|
||||
cancellation,
|
||||
compute::PostgresConnection,
|
||||
@@ -10,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use super::{copy_bidirectional::ErrorSource, handshake::KtlsAsyncReadReady};
|
||||
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -57,7 +59,7 @@ pub async fn proxy_pass(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct ProxyPassthrough<P, S> {
|
||||
pub struct ProxyPassthrough<P, S: AsRawFd> {
|
||||
pub client: Stream<S>,
|
||||
pub compute: PostgresConnection,
|
||||
pub aux: MetricsAuxInfo,
|
||||
@@ -67,7 +69,7 @@ pub struct ProxyPassthrough<P, S> {
|
||||
pub cancel: cancellation::Session<P>,
|
||||
}
|
||||
|
||||
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
|
||||
impl<P, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> ProxyPassthrough<P, S> {
|
||||
pub async fn proxy_pass(self) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
|
||||
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
|
||||
|
||||
@@ -64,7 +64,7 @@ fn generate_certs(
|
||||
))
|
||||
}
|
||||
|
||||
struct DummyClient(DuplexStream);
|
||||
pub struct DummyClient(pub DuplexStream);
|
||||
|
||||
impl AsRawFd for DummyClient {
|
||||
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
|
||||
@@ -170,7 +170,9 @@ fn generate_tls_config<'a>(
|
||||
|
||||
#[async_trait]
|
||||
trait TestAuth: Sized {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
async fn authenticate<
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady,
|
||||
>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -199,7 +201,9 @@ impl Scram {
|
||||
|
||||
#[async_trait]
|
||||
impl TestAuth for Scram {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
async fn authenticate<
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady,
|
||||
>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
|
||||
@@ -197,6 +197,8 @@ impl MaybeTlsAcceptor for rustls::ServerConfig {
|
||||
|
||||
#[cfg(all(target_os = "linux", not(test)))]
|
||||
return ktls::config_ktls_server(tls)
|
||||
.await
|
||||
.map(|s| Box::pin(s) as _)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
|
||||
|
||||
#[cfg(any(not(target_os = "linux"), test))]
|
||||
|
||||
@@ -52,8 +52,8 @@ impl<S> AsRawFd for WebSocketRw<S> {
|
||||
}
|
||||
}
|
||||
#[cfg(all(target_os = "linux", not(test)))]
|
||||
impl<S: ktls::AsyncReadReady> AsRawFd for ChainRW<S> {
|
||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
impl<S> ktls::AsyncReadReady for WebSocketRw<S> {
|
||||
fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!("ktls should not need to be used for websocket rw")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use crate::config::TlsServerEndPoint;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::handshake::KtlsAsyncReadReady;
|
||||
use bytes::BytesMut;
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
@@ -172,7 +174,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
}
|
||||
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
pub enum Stream<S> {
|
||||
pub enum Stream<S: AsRawFd> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { raw: S },
|
||||
@@ -182,16 +184,16 @@ pub enum Stream<S> {
|
||||
tls: Pin<Box<TlsStream<S>>>,
|
||||
|
||||
#[cfg(all(target_os = "linux", not(test)))]
|
||||
tls: Pin<Box<ktls::KtlsStream<S>>>,
|
||||
tls: ktls::KtlsStream<S>,
|
||||
|
||||
/// Channel binding parameter
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
},
|
||||
}
|
||||
|
||||
impl<S: Unpin> Unpin for Stream<S> {}
|
||||
impl<S: Unpin + AsRawFd> Unpin for Stream<S> {}
|
||||
|
||||
impl<S> Stream<S> {
|
||||
impl<S: AsRawFd> Stream<S> {
|
||||
/// Construct a new instance from a raw stream.
|
||||
pub fn from_raw(raw: S) -> Self {
|
||||
Self::Raw { raw }
|
||||
@@ -218,7 +220,7 @@ pub enum StreamUpgradeError {
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd> Stream<S> {
|
||||
/// If possible, upgrade raw stream into a secure TLS-based stream.
|
||||
pub async fn upgrade(
|
||||
self,
|
||||
@@ -239,7 +241,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AsyncRead for Stream<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
@@ -247,12 +249,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
|
||||
Self::Tls { tls, .. } => tls.as_mut().poll_read(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AsyncWrite for Stream<S> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
@@ -260,7 +262,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
|
||||
Self::Tls { tls, .. } => tls.as_mut().poll_write(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,7 +272,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
|
||||
Self::Tls { tls, .. } => tls.as_mut().poll_flush(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +282,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
|
||||
Self::Tls { tls, .. } => tls.as_mut().poll_shutdown(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user