chore(proxy): enforce single host+port (#9995)

proxy doesn't ever provide multiple hosts/ports, so this code adds a lot
of complexity of error handling for no good reason.

(stacked on #9990)
This commit is contained in:
Conrad Ludgate
2024-12-03 20:00:14 +00:00
committed by Ivan Efremov
parent 5519e42612
commit fbc8c36983
11 changed files with 59 additions and 142 deletions

View File

@@ -161,12 +161,8 @@ async fn authenticate(
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
config.dbname(&db_info.dbname).user(&db_info.user);
ctx.set_dbname(db_info.dbname.into());
ctx.set_user(db_info.user.into());

View File

@@ -29,12 +29,7 @@ impl LocalBackend {
api: http::Endpoint::new(compute_ctl, http::new_client()),
},
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();
cfg.host(&postgres_addr.ip().to_string());
cfg.port(postgres_addr.port());
cfg
},
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),

View File

@@ -104,13 +104,13 @@ pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `postgres_client` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone, Default)]
#[derive(Clone)]
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
/// Creation and initialization routines.
impl ConnCfg {
pub(crate) fn new() -> Self {
Self::default()
pub(crate) fn new(host: String, port: u16) -> Self {
Self(Box::new(postgres_client::Config::new(host, port)))
}
/// Reuse password or auth keys from the other config.
@@ -124,13 +124,9 @@ impl ConnCfg {
}
}
pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
match self.0.get_hosts() {
[postgres_client::config::Host::Tcp(s)] => Ok(s.into()),
// we should not have multiple address or unix addresses.
_ => Err(WakeComputeError::BadComputeAddress(
"invalid compute address".into(),
)),
pub(crate) fn get_host(&self) -> Host {
match self.0.get_host() {
postgres_client::config::Host::Tcp(s) => s.into(),
}
}
@@ -227,43 +223,20 @@ impl ConnCfg {
// We can't reuse connection establishing logic from `postgres_client` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let mut connection_error = None;
let ports = self.0.get_ports();
let hosts = self.0.get_hosts();
// the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
if ports.len() > 1 && ports.len() != hosts.len() {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"bad compute config, \
ports and hosts entries' count does not match: {:?}",
self.0
),
));
}
let port = self.0.get_port();
let host = self.0.get_host();
for (i, host) in hosts.iter().enumerate() {
let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
let host = match host {
Host::Tcp(host) => host.as_str(),
};
let host = match host {
Host::Tcp(host) => host.as_str(),
};
match connect_once(host, *port).await {
Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
Err(err) => {
// We can't throw an error here, as there might be more hosts to try.
warn!("couldn't connect to compute node at {host}:{port}: {err}");
connection_error = Some(err);
}
match connect_once(host, port).await {
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
Err(err)
}
}
Err(connection_error.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!("bad compute config: {:?}", self.0),
)
}))
}
}

View File

@@ -160,11 +160,11 @@ impl MockControlPlane {
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.ssl_mode(postgres_client::config::SslMode::Disable);
let mut config = compute::ConnCfg::new(
self.endpoint.host_str().unwrap_or("localhost").to_owned(),
self.endpoint.port().unwrap_or(5432),
);
config.ssl_mode(postgres_client::config::SslMode::Disable);
let node = NodeInfo {
config,

View File

@@ -241,8 +241,8 @@ impl NeonControlPlaneClient {
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new();
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let mut config = compute::ConnCfg::new(host.to_owned(), port);
config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let node = NodeInfo {
config,

View File

@@ -86,7 +86,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
node_info: &control_plane::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, timeout).await)
}

View File

@@ -158,7 +158,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
Scram::new("password").await?,
));
let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
@@ -241,7 +241,7 @@ async fn connect_failure(
Scram::new("password").await?,
));
let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(channel_binding)
.user("user")
.dbname("db")

View File

@@ -204,7 +204,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let client_err = postgres_client::Config::new()
let client_err = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Disable)
@@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
@@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> {
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.options("project=generic-project-name")
@@ -296,7 +296,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
Scram::new(password).await?,
));
let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Require)
.user("user")
.dbname("db")
@@ -320,7 +320,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
Scram::new("password").await?,
));
let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
@@ -348,7 +348,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
.map(char::from)
.collect();
let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.user("user")
.dbname("db")
.password(&password) // no password will match the mocked secret
@@ -546,7 +546,7 @@ impl TestControlPlaneClient for TestConnectMechanism {
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
config: compute::ConnCfg::new(),
config: compute::ConnCfg::new("test".to_owned(), 5432),
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),

View File

@@ -499,7 +499,7 @@ impl ConnectMechanism for TokioMechanism {
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let mut config = (*node_info.config).clone();
@@ -549,16 +549,12 @@ impl ConnectMechanism for HyperMechanism {
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let port = *node_info.config.get_ports().first().ok_or_else(|| {
HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress(
"local-proxy port missing on compute address".into(),
))
})?;
let port = node_info.config.get_port();
let res = connect_http2(&host, port, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;