Add tests for link auth to compute connection

This commit is contained in:
Stas Kelvich
2023-04-28 13:06:54 +03:00
parent 040f736909
commit 9486d76b2a
11 changed files with 88 additions and 17 deletions

View File

@@ -98,6 +98,7 @@ pub(super) async fn authenticate(
value: NodeInfo {
config,
aux: db_info.aux.into(),
allow_self_signed_compute: false, // caller may override
},
})
}

View File

@@ -10,6 +10,7 @@ use std::{borrow::Cow, net::SocketAddr};
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
use utils::{project_git_version, sentry_init::init_sentry};
project_git_version!(GIT_VERSION);
@@ -96,6 +97,14 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
};
let allow_self_signed_compute: bool = args
.get_one::<String>("allow-self-signed-compute")
.unwrap()
.parse()?;
if allow_self_signed_compute {
warn!("allowing self-signed compute certificates");
}
let metric_collection = match (
args.get_one::<String>("metric-collection-endpoint"),
args.get_one::<String>("metric-collection-interval"),
@@ -145,6 +154,7 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute,
}));
Ok(config)
@@ -235,6 +245,12 @@ fn cli() -> clap::Command {
.help("cache for `wake_compute` api method (use `size=0` to disable)")
.default_value(config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO),
)
.arg(
Arg::new("allow-self-signed-compute")
.long("allow-self-signed-compute")
.help("Allow self-signed certificates for compute nodes (for testing)")
.default_value("false"),
)
}
#[cfg(test)]

View File

@@ -220,10 +220,14 @@ pub struct PostgresConnection {
}
impl ConnCfg {
async fn do_connect(&self) -> Result<PostgresConnection, ConnectionError> {
async fn do_connect(
&self,
allow_self_signed_compute: bool,
) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw().await?;
let tls_connector = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(allow_self_signed_compute)
.build()
.unwrap();
let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
@@ -257,8 +261,11 @@ impl ConnCfg {
}
/// Connect to a corresponding compute node.
pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
self.do_connect()
pub async fn connect(
&self,
allow_self_signed_compute: bool,
) -> Result<PostgresConnection, ConnectionError> {
self.do_connect(allow_self_signed_compute)
.inspect_err(|err| {
// Immediately log the error we have at our disposal.
error!("couldn't connect to compute node: {err}");

View File

@@ -12,6 +12,7 @@ pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
}
#[derive(Debug)]

View File

@@ -170,6 +170,9 @@ pub struct NodeInfo {
/// Labels for proxy's metrics.
pub aux: Arc<MetricsAuxInfo>,
/// Whether we should accept self-signed certificates (for testing)
pub allow_self_signed_compute: bool,
}
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;

View File

@@ -93,6 +93,7 @@ impl Api {
let node = NodeInfo {
config,
aux: Default::default(),
allow_self_signed_compute: false,
};
Ok(node)

View File

@@ -106,6 +106,7 @@ impl Api {
let node = NodeInfo {
config,
aux: body.aux.into(),
allow_self_signed_compute: false,
};
Ok(node)

View File

@@ -155,7 +155,7 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
let client = Client::new(stream, creds, &params, session_id, false);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
@@ -194,7 +194,15 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
let allow_self_signed_compute = config.allow_self_signed_compute;
let client = Client::new(
stream,
creds,
&params,
session_id,
allow_self_signed_compute,
);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.await
@@ -297,9 +305,11 @@ async fn connect_to_compute_once(
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
};
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect()
.connect(allow_self_signed_compute)
.inspect_err(invalidate_cache)
.await
}
@@ -420,6 +430,8 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
/// Allow self-signed certificates (for testing).
allow_self_signed_compute: bool,
}
impl<'a, S> Client<'a, S> {
@@ -429,12 +441,14 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
allow_self_signed_compute: bool,
) -> Self {
Self {
stream,
creds,
params,
session_id,
allow_self_signed_compute,
}
}
}
@@ -451,6 +465,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
mut creds,
params,
session_id,
allow_self_signed_compute,
} = self;
let extra = console::ConsoleReqExtra {
@@ -473,6 +488,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
value: mut node_info,
} = auth_result;
node_info.allow_self_signed_compute = allow_self_signed_compute;
let mut node = connect_to_compute(&mut node_info, params, &extra, &creds)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -1820,6 +1820,24 @@ class VanillaPostgres(PgProtocol):
self.pg_bin.run_capture(["initdb", "-D", str(pgdatadir)])
self.configure([f"port = {port}\n"])
def enable_tls(self):
assert not self.running
# generate self-signed certificate
subprocess.run(
["openssl", "req", "-new", "-x509", "-days", "365", "-nodes", "-text",
"-out", self.pgdatadir / "server.crt",
"-keyout", self.pgdatadir / "server.key",
"-subj", "/CN=localhost"]
)
# configure postgresql.conf
self.configure(
[
"ssl = on",
"ssl_cert_file = 'server.crt'",
"ssl_key_file = 'server.key'",
]
)
def configure(self, options: List[str]):
"""Append lines into postgresql.conf file."""
assert not self.running
@@ -1992,6 +2010,7 @@ class NeonProxy(PgProtocol):
# Link auth backend params
*["--auth-backend", "link"],
*["--uri", NeonProxy.link_auth_uri],
*["--allow-self-signed-compute", "true"],
]
@dataclass(frozen=True)
@@ -2012,6 +2031,7 @@ class NeonProxy(PgProtocol):
def __init__(
self,
neon_binpath: Path,
test_output_dir: Path,
proxy_port: int,
http_port: int,
mgmt_port: int,
@@ -2025,6 +2045,7 @@ class NeonProxy(PgProtocol):
self.host = host
self.http_port = http_port
self.neon_binpath = neon_binpath
self.test_output_dir = test_output_dir
self.proxy_port = proxy_port
self.mgmt_port = mgmt_port
self.auth_backend = auth_backend
@@ -2051,7 +2072,8 @@ class NeonProxy(PgProtocol):
*["--metric-collection-interval", self.metric_collection_interval],
]
self._popen = subprocess.Popen(args)
logfile = open(self.test_output_dir / "proxy.log", "w")
self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile)
self._wait_until_ready()
return self
@@ -2119,6 +2141,7 @@ class NeonProxy(PgProtocol):
if create_user:
log.info("creating a new user for link auth test")
local_vanilla_pg.enable_tls()
local_vanilla_pg.start()
local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser")
@@ -2152,7 +2175,7 @@ class NeonProxy(PgProtocol):
@pytest.fixture(scope="function")
def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterator[NeonProxy]:
def link_proxy(port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path) -> Iterator[NeonProxy]:
"""Neon proxy that routes through link auth."""
http_port = port_distributor.get_port()
@@ -2161,6 +2184,7 @@ def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterato
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
@@ -2172,7 +2196,8 @@ def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterato
@pytest.fixture(scope="function")
def static_proxy(
vanilla_pg: VanillaPostgres, port_distributor: PortDistributor, neon_binpath: Path
vanilla_pg: VanillaPostgres, port_distributor: PortDistributor, neon_binpath: Path,
test_output_dir: Path
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
@@ -2191,6 +2216,7 @@ def static_proxy(
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,

View File

@@ -201,7 +201,8 @@ def proxy_metrics_handler(request: Request) -> Response:
@pytest.fixture(scope="session")
def proxy_with_metric_collector(
port_distributor: PortDistributor, neon_binpath: Path, httpserver_listen_address
port_distributor: PortDistributor, neon_binpath: Path, httpserver_listen_address,
test_output_dir: Path
) -> Iterator[NeonProxy]:
"""Neon proxy that routes through link auth and has metric collection enabled."""
@@ -215,6 +216,7 @@ def proxy_with_metric_collector(
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,

View File

@@ -37,7 +37,6 @@ class PgSniRouter(PgProtocol):
neon_binpath: Path,
port: int,
destination: str,
destination_port: int,
tls_cert: Path,
tls_key: Path,
):
@@ -49,7 +48,6 @@ class PgSniRouter(PgProtocol):
self.neon_binpath = neon_binpath
self.port = port
self.destination = destination
self.destination_port = destination_port
self.tls_cert = tls_cert
self.tls_key = tls_key
self._popen: Optional[subprocess.Popen[bytes]] = None
@@ -62,7 +60,6 @@ class PgSniRouter(PgProtocol):
*["--tls-cert", self.tls_cert],
*["--tls-key", self.tls_key],
*["--destination", self.destination],
*["--destination-port", str(self.destination_port)],
]
self._popen = subprocess.Popen(args)
@@ -110,7 +107,7 @@ def test_pg_sni_router(
):
generate_tls_cert(
"external.test", test_output_dir / "router.crt", test_output_dir / "router.key"
"endpoint.namespace.localtest.me", test_output_dir / "router.crt", test_output_dir / "router.key"
)
# Start a stand-alone Postgres to test with
@@ -122,8 +119,7 @@ def test_pg_sni_router(
with PgSniRouter(
neon_binpath=neon_binpath,
port=router_port,
destination="localhost",
destination_port=pg_port,
destination="localtest.me",
tls_cert=test_output_dir / "router.crt",
tls_key=test_output_dir / "router.key",
) as router:
@@ -133,7 +129,7 @@ def test_pg_sni_router(
"select 1",
dbname="postgres",
sslmode="require",
host="localhost.external.test",
host=f"endpoint--namespace--{pg_port}.localtest.me",
hostaddr="127.0.0.1",
)
assert out[0][0] == 1