mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-19 22:20:37 +00:00
Merge branch 'cloneable/pglb-workers' into pglb
This commit is contained in:
@@ -7,23 +7,32 @@ use std::{
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use bytes::BytesMut;
|
||||
use futures::sink::SinkExt;
|
||||
use indexmap::IndexMap;
|
||||
use proxy::config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL};
|
||||
use quinn::{Connection, Endpoint};
|
||||
use proxy::{
|
||||
config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL},
|
||||
ConnectionInitiatedPayload, PglbCodec,
|
||||
};
|
||||
use quinn::{Connection, Endpoint, RecvStream, SendStream};
|
||||
use rand::Rng;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use tokio::{
|
||||
io::AsyncReadExt,
|
||||
io::{join, AsyncReadExt, Join},
|
||||
net::{TcpListener, TcpStream},
|
||||
time::timeout,
|
||||
};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_util::codec::Framed;
|
||||
use tracing::{error, warn};
|
||||
|
||||
type AuthConnId = usize;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AuthConnState {
|
||||
conns: Mutex<IndexMap<AuthConnId, AuthConn>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct AuthConn {
|
||||
conn: Connection,
|
||||
// latency info...
|
||||
@@ -46,7 +55,11 @@ async fn main() -> Result<()> {
|
||||
|
||||
let frontend_config = frontent_tls_config("pglb-fe", "pglb-fe")?;
|
||||
|
||||
let _frontend_handle = tokio::spawn(start_frontend("0.0.0.0:5432".parse()?, frontend_config));
|
||||
let _frontend_handle = tokio::spawn(start_frontend(
|
||||
"0.0.0.0:5432".parse()?,
|
||||
frontend_config,
|
||||
auth_connections.clone(),
|
||||
));
|
||||
|
||||
quinn_handle.await.unwrap();
|
||||
|
||||
@@ -151,17 +164,28 @@ fn frontent_tls_config(hostname: &str, common_name: &str) -> Result<TlsConfig> {
|
||||
})
|
||||
}
|
||||
|
||||
async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result<Infallible> {
|
||||
async fn start_frontend(
|
||||
addr: SocketAddr,
|
||||
tls: TlsConfig,
|
||||
state: Arc<AuthConnState>,
|
||||
) -> Result<Infallible> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let workers = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, peer_addr)) => {
|
||||
let tls = tls.clone();
|
||||
connections.spawn_local(handle_frontend_connection(socket, peer_addr, tls));
|
||||
Ok((socket, client_addr)) => {
|
||||
let w = Worker {
|
||||
state: Some(WorkerState::Start {
|
||||
client_stream: socket,
|
||||
client_addr,
|
||||
}),
|
||||
tls: tls.clone(),
|
||||
auth_conns: Arc::clone(&state),
|
||||
};
|
||||
workers.spawn(w.start());
|
||||
}
|
||||
Err(e) => {
|
||||
error!("connection accept error: {e}");
|
||||
@@ -170,38 +194,207 @@ async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result<Infallible>
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_frontend_connection(mut stream: TcpStream, _peer_addr: SocketAddr, tls: TlsConfig) {
|
||||
match stream.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!("socket option error: {e}");
|
||||
return;
|
||||
#[derive(Debug)]
|
||||
struct Worker {
|
||||
state: Option<WorkerState>,
|
||||
tls: TlsConfig,
|
||||
auth_conns: Arc<AuthConnState>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum WorkerState {
|
||||
Start {
|
||||
client_stream: TcpStream,
|
||||
client_addr: SocketAddr,
|
||||
},
|
||||
AuthConnect {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
payload: ConnectionInitiatedPayload,
|
||||
},
|
||||
AuthPassthrough {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
},
|
||||
ComputeConnect {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
},
|
||||
ComputePassthrough {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
compute_conn: (),
|
||||
},
|
||||
End,
|
||||
}
|
||||
|
||||
impl Worker {
|
||||
async fn start(mut self) -> Result<()> {
|
||||
loop {
|
||||
match self.state.take().expect("state should be specified") {
|
||||
WorkerState::Start {
|
||||
client_stream,
|
||||
client_addr,
|
||||
} => {
|
||||
self.state = Some(self.handle_start(client_stream, client_addr).await?);
|
||||
}
|
||||
|
||||
WorkerState::AuthConnect {
|
||||
client_stream,
|
||||
payload,
|
||||
} => {
|
||||
self.state = Some(self.handle_auth_connect(client_stream, payload).await?);
|
||||
}
|
||||
|
||||
WorkerState::AuthPassthrough {
|
||||
client_stream,
|
||||
auth_stream,
|
||||
} => {
|
||||
self.state = Some(
|
||||
self.handle_auth_passthrough(client_stream, auth_stream)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
WorkerState::ComputeConnect {
|
||||
client_stream,
|
||||
auth_stream,
|
||||
} => {
|
||||
self.state = Some(
|
||||
self.handle_compute_connect(client_stream, auth_stream)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
WorkerState::ComputePassthrough {
|
||||
client_stream,
|
||||
compute_conn,
|
||||
} => {
|
||||
self.state = Some(
|
||||
self.handle_compute_passthrough(client_stream, compute_conn)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
WorkerState::End => {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: HAProxy protocol?
|
||||
|
||||
let tls_requested = match handle_ssl_request_message(&mut stream).await {
|
||||
Ok(tls_requested) => tls_requested,
|
||||
Err(e) => {
|
||||
error!("check_for_ssl_request: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if tls_requested {
|
||||
let (stream, ep, sn) = match tls_upgrade(stream, tls).await {
|
||||
Ok((stream, ep, sn)) => (stream, ep, sn),
|
||||
async fn handle_start(
|
||||
&self,
|
||||
mut client_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> Result<WorkerState> {
|
||||
match client_stream.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!("tls_upgrade: {e}");
|
||||
return;
|
||||
bail!("socket option error: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: send auth msg with tls ep and server name
|
||||
} else {
|
||||
// TODO: send auth msg without server name
|
||||
// TODO: HAProxy protocol
|
||||
|
||||
let tls_requested = match handle_ssl_request_message(&mut client_stream).await {
|
||||
Ok(tls_requested) => tls_requested,
|
||||
Err(e) => {
|
||||
bail!("check_for_ssl_request: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
let (client_stream, payload) = if tls_requested {
|
||||
let (stream, tls_server_end_point, server_name) =
|
||||
match tls_upgrade(client_stream, self.tls.clone()).await {
|
||||
Ok((stream, ep, sn)) => (stream, ep, sn),
|
||||
Err(e) => {
|
||||
bail!("tls_upgrade: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
stream,
|
||||
ConnectionInitiatedPayload {
|
||||
tls_server_end_point,
|
||||
server_name,
|
||||
ip_addr: peer_addr.ip(),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// TODO: support unsecured connections?
|
||||
bail!("closing non-TLS connection");
|
||||
};
|
||||
|
||||
Ok(WorkerState::AuthConnect {
|
||||
client_stream,
|
||||
payload,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_auth_connect(
|
||||
&self,
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
payload: ConnectionInitiatedPayload,
|
||||
) -> Result<WorkerState> {
|
||||
let auth_conn = {
|
||||
let conns = self.auth_conns.conns.lock().unwrap();
|
||||
if conns.is_empty() {
|
||||
bail!("no auth proxies avaiable");
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
conns
|
||||
.get_index(rng.gen_range(0..conns.len()))
|
||||
.unwrap()
|
||||
.1
|
||||
.clone()
|
||||
|
||||
// TODO: check closed?
|
||||
};
|
||||
|
||||
let (send, recv) = auth_conn.conn.open_bi().await?;
|
||||
let mut auth_stream = Framed::new(join(recv, send), PglbCodec);
|
||||
|
||||
auth_stream
|
||||
.send(proxy::PglbMessage::Control(
|
||||
proxy::PglbControlMessage::ConnectionInitiated(payload),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(WorkerState::AuthPassthrough {
|
||||
client_stream,
|
||||
auth_stream,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_auth_passthrough(
|
||||
&self,
|
||||
_client_stream: TlsStream<TcpStream>,
|
||||
_auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
) -> Result<WorkerState> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn handle_compute_connect(
|
||||
&self,
|
||||
_client_stream: TlsStream<TcpStream>,
|
||||
_auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
) -> Result<WorkerState> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn handle_compute_passthrough(
|
||||
&self,
|
||||
_client_stream: TlsStream<TcpStream>,
|
||||
_compute_conn: (),
|
||||
) -> Result<WorkerState> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TlsConfig {
|
||||
config: Arc<rustls::ServerConfig>,
|
||||
cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
|
||||
@@ -253,9 +446,3 @@ async fn tls_upgrade(
|
||||
|
||||
Ok((tls_stream, tls_server_end_point, server_name))
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TlsConfig {
|
||||
config: Arc<rustls::ServerConfig>,
|
||||
cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
@@ -283,6 +283,7 @@ impl EndpointId {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PglbCodec;
|
||||
|
||||
impl tokio_util::codec::Encoder<PglbMessage> for PglbCodec {
|
||||
@@ -404,7 +405,7 @@ pub enum PglbControlMessage {
|
||||
ComputeEstablish,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ConnectionInitiatedPayload {
|
||||
pub tls_server_end_point: TlsServerEndPoint,
|
||||
pub server_name: Option<String>,
|
||||
|
||||
Reference in New Issue
Block a user