Take port from SNI, formatting, make clippy happy

This commit is contained in:
Alexey Kondratov
2023-04-27 12:45:54 +02:00
committed by Stas Kelvich
parent 556fb1642a
commit 81c75586ab
4 changed files with 44 additions and 18 deletions

View File

@@ -1,10 +1,15 @@
use std::{net::SocketAddr, sync::Arc};
use tokio::{net::TcpListener, io::AsyncWriteExt};
use tokio::{io::AsyncWriteExt, net::TcpListener};
use anyhow::{bail, ensure, Context};
use clap::{self, Arg};
use futures::TryFutureExt;
use proxy::{cancellation::CancelMap, auth::{AuthFlow, self}, compute::ConnCfg, console::messages::MetricsAuxInfo};
use proxy::{
auth::{self, AuthFlow},
cancellation::CancelMap,
compute::ConnCfg,
console::messages::MetricsAuxInfo,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use utils::{project_git_version, sentry_init::init_sentry};
@@ -199,28 +204,36 @@ async fn handle_client(
conn_cfg.set_startup_params(&params);
conn_cfg.password(password);
// cut off first part of the sni domain
// Cut off first part of the SNI domain
// We receive required destination details in the format of
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
let sni = stream.get_ref().sni_hostname().unwrap();
let dest: Vec<&str> = sni
.split_once('.').context("invalid sni")?.0
.splitn(3, "--").collect();
.split_once('.')
.context("invalid SNI")?
.0
.splitn(3, "--")
.collect();
let destination = format!("{}.{}.{}", dest[0], dest[1], dest_suffix);
let port = dest[2].parse::<u16>().context("invalid port")?;
info!("destination: {:?}", destination);
info!("destination: {:?}:{}", destination, port);
conn_cfg.host(destination.as_str());
conn_cfg.port(6432); // TODO: it's a pooler and should be passed externally
conn_cfg.port(port);
let mut conn = conn_cfg.connect()
let mut conn = conn_cfg
.connect()
.or_else(|e| stream.throw_error(e))
.await?;
cancel_map.with_session(|session| async {
proxy::proxy::prepare_client_connection(&conn, false, session, &mut stream).await?;
let (stream, read_buf) = stream.into_inner();
conn.stream.write_all(&read_buf).await?;
let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::proxy_pass(stream, conn.stream, &metrics_aux).await
})
.await
cancel_map
.with_session(|session| async {
proxy::proxy::prepare_client_connection(&conn, false, session, &mut stream).await?;
let (stream, read_buf) = stream.into_inner();
conn.stream.write_all(&read_buf).await?;
let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::proxy_pass(stream, conn.stream, &metrics_aux).await
})
.await
}

View File

@@ -7,7 +7,7 @@ use anyhow::bail;
use clap::{self, Arg};
use proxy::config::{self, ProxyConfig};
use std::{borrow::Cow, net::SocketAddr};
use tokio::{net::TcpListener};
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::info;
use utils::{project_git_version, sentry_init::init_sentry};

View File

@@ -125,6 +125,12 @@ impl std::ops::DerefMut for ConnCfg {
}
}
impl Default for ConnCfg {
fn default() -> Self {
Self::new()
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {

View File

@@ -1,6 +1,9 @@
///! A group of high-level tests for connection establishing logic and auth.
use super::*;
use crate::config::TlsConfig;
use crate::{auth, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
use tokio_postgres::config::SslMode;
@@ -133,7 +136,11 @@ async fn dummy_proxy(
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
let server_config = match tls {
Some(tls) => Some(tls.config),
None => None,
};
let (mut stream, _params) = handshake(client, server_config, &cancel_map)
.await?
.context("handshake failed")?;