mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 17:02:56 +00:00
Refactor Client into EstablishedConnection.
The name "Client" was a bit ambiguous. Instead of encapsulating all the data needed to establish the connection, change it so that it encapsulates the streams, after the connection has been established. With that, "EstablishedConnection" is a fitting name for it.
This commit is contained in:
@@ -115,14 +115,16 @@ pub async fn handle_ws_client(
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(
|
||||
let conn = EstablishedConnection::connect_to_db(
|
||||
stream,
|
||||
creds,
|
||||
¶ms,
|
||||
session_id,
|
||||
cancel_map.new_session()?,
|
||||
);
|
||||
client.handle_connection(true).await
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
conn.handle_connection().await
|
||||
}
|
||||
|
||||
/// Handle an incoming client connection, handshake and authentication.
|
||||
@@ -160,14 +162,16 @@ async fn handle_client(
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(
|
||||
let conn = EstablishedConnection::connect_to_db(
|
||||
stream,
|
||||
creds,
|
||||
¶ms,
|
||||
session_id,
|
||||
cancel_map.new_session()?,
|
||||
);
|
||||
client.handle_connection(false).await
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
conn.handle_connection().await
|
||||
}
|
||||
|
||||
/// Establish a (most probably, secure) connection with the client.
|
||||
@@ -237,62 +241,35 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
/// Thin connection context.
|
||||
struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
params: &'a StartupMessageParams,
|
||||
/// Unique connection ID.
|
||||
session_id: uuid::Uuid,
|
||||
struct EstablishedConnection<'a, S> {
|
||||
client_stream: MeasuredStream<S>,
|
||||
db_stream: MeasuredStream<tokio::net::TcpStream>,
|
||||
|
||||
session: cancellation::Session<'a>,
|
||||
/// Hold on to the Session for as long as the connection is alive, so that
|
||||
/// it can be cancelled.
|
||||
_session: cancellation::Session<'a>,
|
||||
}
|
||||
|
||||
impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(
|
||||
stream: PqStream<S>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
session: cancellation::Session<'a>,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
creds,
|
||||
params,
|
||||
session_id,
|
||||
session,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
async fn handle_connection(self, use_cleartext_password_flow: bool) -> anyhow::Result<()> {
|
||||
let (mut client, mut db) = self.connect_to_db(use_cleartext_password_flow).await?;
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> EstablishedConnection<'_, S> {
|
||||
async fn handle_connection(mut self) -> anyhow::Result<()> {
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
|
||||
let _ = tokio::io::copy_bidirectional(&mut self.client_stream, &mut self.db_stream).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
/// On return, the connection is fully established in both ways, and we can start
|
||||
/// forwarding the bytes.
|
||||
#[instrument(skip_all)]
|
||||
async fn connect_to_db(
|
||||
self,
|
||||
async fn connect_to_db<'a>(
|
||||
mut stream: PqStream<S>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
session: cancellation::Session<'a>,
|
||||
use_cleartext_password_flow: bool,
|
||||
) -> anyhow::Result<(MeasuredStream<S>, MeasuredStream<tokio::net::TcpStream>)> {
|
||||
let Self {
|
||||
mut stream,
|
||||
creds,
|
||||
params,
|
||||
session_id,
|
||||
session,
|
||||
} = self;
|
||||
|
||||
) -> anyhow::Result<EstablishedConnection<S>> {
|
||||
let extra = auth::ConsoleReqExtra {
|
||||
session_id, // aka this connection's id
|
||||
application_name: params.get("application_name"),
|
||||
@@ -340,11 +317,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
.await?;
|
||||
|
||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx"));
|
||||
let client = MeasuredStream::new(stream.into_inner(), m_sent);
|
||||
let client_stream = MeasuredStream::new(stream.into_inner(), m_sent);
|
||||
|
||||
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx"));
|
||||
let db = MeasuredStream::new(db.stream, m_recv);
|
||||
let db_stream = MeasuredStream::new(db.stream, m_recv);
|
||||
|
||||
Ok((client, db))
|
||||
Ok(EstablishedConnection {
|
||||
client_stream,
|
||||
db_stream,
|
||||
_session: session,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user