call wake compute

This commit is contained in:
Conrad Ludgate
2024-09-12 17:42:21 +01:00
parent ce200a53e8
commit f95ddef4e0
3 changed files with 62 additions and 53 deletions

View File

@@ -1,11 +1,9 @@
use std::{sync::Arc, time::Duration};
use clap::Parser;
use futures::TryStreamExt;
use pq_proto::FeStartupPacket;
use proxy::{
auth::{self, backend::AuthRateLimiter},
auth_proxy::{self, backend::MaybeOwned, AuthProxyStream, Backend},
auth::backend::AuthRateLimiter,
auth_proxy::{backend::MaybeOwned, Backend},
config::{self, AuthenticationConfig, CacheOptions, ProjectInfoCacheOptions},
console::{
caches::ApiCaches,
@@ -14,23 +12,18 @@ use proxy::{
},
http,
metrics::Metrics,
proxy::NeonOptions,
proxy::{handle_stream, AuthProxyConfig},
rate_limiter::{RateBucketInfo, WakeComputeRateLimiter},
scram::threadpool::ThreadPool,
stream::AuthProxyStreamExt,
PglbCodec, PglbControlMessage, PglbMessage,
};
use quinn::{
crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream,
VarInt,
};
use quinn::{crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, VarInt};
use tokio::{
io::{join, AsyncWriteExt},
io::AsyncWriteExt,
select,
signal::unix::{signal, SignalKind},
time::interval,
};
use tokio_util::{codec::Framed, task::TaskTracker};
use tokio_util::task::TaskTracker;
/// Neon proxy/router
#[derive(Parser)]
@@ -177,7 +170,7 @@ async fn main() {
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
};
let config = Box::leak(Box::new(Config { backend, auth }));
let config = Box::leak(Box::new(AuthProxyConfig { backend, auth }));
loop {
select! {
@@ -208,11 +201,6 @@ async fn main() {
conn.close(VarInt::from_u32(1), b"graceful shutdown");
}
struct Config {
backend: Backend<'static, ()>,
auth: AuthenticationConfig,
}
#[derive(Copy, Clone, Debug)]
struct NoVerify;
@@ -250,36 +238,3 @@ impl danger::ServerCertVerifier for NoVerify {
vec![quinn::rustls::SignatureScheme::ECDSA_NISTP256_SHA256]
}
}
async fn handle_stream(config: &'static Config, send: SendStream, recv: RecvStream) {
let mut stream: AuthProxyStream = Framed::new(join(recv, send), PglbCodec);
let first_msg = stream.try_next().await.unwrap();
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg
else {
panic!("invalid first msg")
};
let startup = stream.read_startup_packet().await.unwrap();
let FeStartupPacket::StartupMessage { version, params } = startup else {
panic!("invalid startup message")
};
// Extract credentials which we're going to use for auth.
let user_info = auth::ComputeUserInfoMaybeEndpoint {
user: params.get("user").unwrap().into(),
endpoint_id: first_msg
.server_name
.as_deref()
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
options: NeonOptions::parse_params(&params),
};
let user_info = config.backend.as_ref().map(|()| user_info);
let user_info = match user_info.authenticate(&mut stream, &config.auth).await {
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await.unwrap();
}
};
}

View File

@@ -7,9 +7,20 @@ pub(crate) mod handshake;
pub(crate) mod passthrough;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use connect_compute::ComputeConnectBackend;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use futures::TryStreamExt;
use pq_proto::FeStartupPacket;
use quinn::RecvStream;
use quinn::SendStream;
use tokio::io::join;
use tokio_util::codec::Framed;
use crate::auth_proxy::AuthProxyStream;
use crate::stream::AuthProxyStreamExt;
use crate::PglbControlMessage;
use crate::PglbMessage;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
@@ -431,3 +442,46 @@ pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
let (_, [k, v]) = cap.extract();
Some((k, v))
}
pub struct AuthProxyConfig {
pub backend: crate::auth_proxy::Backend<'static, ()>,
pub auth: crate::config::AuthenticationConfig,
}
pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, recv: RecvStream) {
let mut stream: AuthProxyStream = Framed::new(join(recv, send), crate::PglbCodec);
let first_msg = stream.try_next().await.unwrap();
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg
else {
panic!("invalid first msg")
};
let startup = stream.read_startup_packet().await.unwrap();
let FeStartupPacket::StartupMessage { version: _, params } = startup else {
panic!("invalid startup message")
};
// Extract credentials which we're going to use for auth.
let user_info = auth::ComputeUserInfoMaybeEndpoint {
user: params.get("user").unwrap().into(),
endpoint_id: first_msg
.server_name
.as_deref()
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
options: NeonOptions::parse_params(&params),
};
let user_info = config.backend.as_ref().map(|()| user_info);
let user_info = match user_info.authenticate(&mut stream, &config.auth).await {
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await.unwrap();
}
};
user_info
.wake_compute(&RequestMonitoring::test())
.await
.unwrap();
}

View File

@@ -299,7 +299,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
}
#[allow(async_fn_in_trait)]
pub trait AuthProxyStreamExt {
pub(crate) trait AuthProxyStreamExt {
/// Write the message into an internal buffer, but don't flush the underlying stream.
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;