diff --git a/Cargo.lock b/Cargo.lock index 83ed5f9263..9f55fb71ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -601,6 +601,7 @@ dependencies = [ "once_cell", "pageserver_api", "postgres", + "postgres_connection", "regex", "reqwest", "safekeeper_api", @@ -2405,6 +2406,18 @@ dependencies = [ "postgres-protocol", ] +[[package]] +name = "postgres_connection" +version = "0.1.0" +dependencies = [ + "anyhow", + "once_cell", + "postgres", + "tokio-postgres", + "url", + "workspace_hack", +] + [[package]] name = "postgres_ffi" version = "0.1.0" diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index a9d30b4a86..2ab48fa76c 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -23,6 +23,7 @@ url = "2.2.2" # Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api # instead, so that recompile times are better. pageserver_api = { path = "../libs/pageserver_api" } +postgres_connection = { path = "../libs/postgres_connection" } safekeeper_api = { path = "../libs/safekeeper_api" } utils = { path = "../libs/utils" } workspace_hack = { version = "0.1", path = "../workspace_hack" } diff --git a/control_plane/src/connection.rs b/control_plane/src/connection.rs deleted file mode 100644 index cca837de6e..0000000000 --- a/control_plane/src/connection.rs +++ /dev/null @@ -1,57 +0,0 @@ -use url::Url; - -#[derive(Debug)] -pub struct PgConnectionConfig { - url: Url, -} - -impl PgConnectionConfig { - pub fn host(&self) -> &str { - self.url.host_str().expect("BUG: no host") - } - - pub fn port(&self) -> u16 { - self.url.port().expect("BUG: no port") - } - - /// Return a `:` string. - pub fn raw_address(&self) -> String { - format!("{}:{}", self.host(), self.port()) - } - - /// Connect using postgres protocol with TLS disabled. - pub fn connect_no_tls(&self) -> Result { - postgres::Client::connect(self.url.as_str(), postgres::NoTls) - } -} - -impl std::str::FromStr for PgConnectionConfig { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - let mut url: Url = s.parse()?; - - match url.scheme() { - "postgres" | "postgresql" => {} - other => anyhow::bail!("invalid scheme: {other}"), - } - - // It's not a valid connection url if host is unavailable. - if url.host().is_none() { - anyhow::bail!(url::ParseError::EmptyHost); - } - - // E.g. `postgres:bar`. - if url.cannot_be_a_base() { - anyhow::bail!("URL cannot be a base"); - } - - // Set the default PG port if it's missing. - if url.port().is_none() { - url.set_port(Some(5432)) - .expect("BUG: couldn't set the default port"); - } - - Ok(Self { url }) - } -} diff --git a/control_plane/src/lib.rs b/control_plane/src/lib.rs index c3b47fe81b..7c1007b133 100644 --- a/control_plane/src/lib.rs +++ b/control_plane/src/lib.rs @@ -9,7 +9,6 @@ mod background_process; pub mod compute; -pub mod connection; pub mod etcd; pub mod local_env; pub mod pageserver; diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index aec6f5bc2c..1736b1e9fe 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -6,11 +6,11 @@ use std::path::{Path, PathBuf}; use std::process::Child; use std::{io, result}; -use crate::connection::PgConnectionConfig; use anyhow::{bail, Context}; use pageserver_api::models::{ TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo, }; +use postgres_connection::{parse_host_port, PgConnectionConfig}; use reqwest::blocking::{Client, RequestBuilder, Response}; use reqwest::{IntoUrl, Method}; use thiserror::Error; @@ -77,30 +77,24 @@ pub struct PageServerNode { impl PageServerNode { pub fn from_env(env: &LocalEnv) -> PageServerNode { + let (host, port) = parse_host_port(&env.pageserver.listen_pg_addr) + .expect("Unable to parse listen_pg_addr"); + let port = port.unwrap_or(5432); let password = if env.pageserver.auth_type == AuthType::NeonJWT { - &env.pageserver.auth_token + Some(env.pageserver.auth_token.clone()) } else { - "" + None }; Self { - pg_connection_config: Self::pageserver_connection_config( - password, - &env.pageserver.listen_pg_addr, - ), + pg_connection_config: PgConnectionConfig::new_host_port(host, port) + .set_password(password), env: env.clone(), http_client: Client::new(), http_base_url: format!("http://{}/v1", env.pageserver.listen_http_addr), } } - /// Construct libpq connection string for connecting to the pageserver. - fn pageserver_connection_config(password: &str, listen_addr: &str) -> PgConnectionConfig { - format!("postgresql://no_user:{password}@{listen_addr}/no_db") - .parse() - .unwrap() - } - pub fn initialize( &self, create_tenant: Option, diff --git a/control_plane/src/safekeeper.rs b/control_plane/src/safekeeper.rs index 0bc35b3680..57aee9cfeb 100644 --- a/control_plane/src/safekeeper.rs +++ b/control_plane/src/safekeeper.rs @@ -5,12 +5,12 @@ use std::sync::Arc; use std::{io, result}; use anyhow::Context; +use postgres_connection::PgConnectionConfig; use reqwest::blocking::{Client, RequestBuilder, Response}; use reqwest::{IntoUrl, Method}; use thiserror::Error; use utils::{http::error::HttpErrorBody, id::NodeId}; -use crate::connection::PgConnectionConfig; use crate::pageserver::PageServerNode; use crate::{ background_process, @@ -86,10 +86,7 @@ impl SafekeeperNode { /// Construct libpq connection string for connecting to this safekeeper. fn safekeeper_connection_config(port: u16) -> PgConnectionConfig { - // TODO safekeeper authentication not implemented yet - format!("postgresql://no_user@127.0.0.1:{port}/no_db") - .parse() - .unwrap() + PgConnectionConfig::new_host_port(url::Host::parse("127.0.0.1").unwrap(), port) } pub fn datadir_path_by_id(env: &LocalEnv, sk_id: NodeId) -> PathBuf { diff --git a/libs/postgres_connection/Cargo.toml b/libs/postgres_connection/Cargo.toml new file mode 100644 index 0000000000..e7b5b49077 --- /dev/null +++ b/libs/postgres_connection/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "postgres_connection" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0" +postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" } +tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" } +url = "2.2.2" +workspace_hack = { version = "0.1", path = "../../workspace_hack" } + +[dev-dependencies] +once_cell = "1.13.0" diff --git a/libs/postgres_connection/src/lib.rs b/libs/postgres_connection/src/lib.rs new file mode 100644 index 0000000000..7edd5b7be6 --- /dev/null +++ b/libs/postgres_connection/src/lib.rs @@ -0,0 +1,196 @@ +use anyhow::{bail, Context}; +use std::fmt; +use url::Host; + +/// Parses a string of format either `host:port` or `host` into a corresponding pair. +/// The `host` part should be a correct `url::Host`, while `port` (if present) should be +/// a valid decimal u16 of digits only. +pub fn parse_host_port>(host_port: S) -> Result<(Host, Option), anyhow::Error> { + let (host, port) = match host_port.as_ref().rsplit_once(':') { + Some((host, port)) => ( + host, + // +80 is a valid u16, but not a valid port + if port.chars().all(|c| c.is_ascii_digit()) { + Some(port.parse::().context("Unable to parse port")?) + } else { + bail!("Port contains a non-ascii-digit") + }, + ), + None => (host_port.as_ref(), None), // No colons, no port specified + }; + let host = Host::parse(host).context("Unable to parse host")?; + Ok((host, port)) +} + +#[cfg(test)] +mod tests_parse_host_port { + use crate::parse_host_port; + use url::Host; + + #[test] + fn test_normal() { + let (host, port) = parse_host_port("hello:123").unwrap(); + assert_eq!(host, Host::Domain("hello".to_owned())); + assert_eq!(port, Some(123)); + } + + #[test] + fn test_no_port() { + let (host, port) = parse_host_port("hello").unwrap(); + assert_eq!(host, Host::Domain("hello".to_owned())); + assert_eq!(port, None); + } + + #[test] + fn test_ipv6() { + let (host, port) = parse_host_port("[::1]:123").unwrap(); + assert_eq!(host, Host::::Ipv6(std::net::Ipv6Addr::LOCALHOST)); + assert_eq!(port, Some(123)); + } + + #[test] + fn test_invalid_host() { + assert!(parse_host_port("hello world").is_err()); + } + + #[test] + fn test_invalid_port() { + assert!(parse_host_port("hello:+80").is_err()); + } +} + +pub struct PgConnectionConfig { + host: Host, + port: u16, + password: Option, +} + +/// A simplified PostgreSQL connection configuration. Supports only a subset of possible +/// settings for simplicity. A password getter or `to_connection_string` methods are not +/// added by design to avoid accidentally leaking password through logging, command line +/// arguments to a child process, or likewise. +impl PgConnectionConfig { + pub fn new_host_port(host: Host, port: u16) -> Self { + PgConnectionConfig { + host, + port, + password: None, + } + } + + pub fn host(&self) -> &Host { + &self.host + } + + pub fn port(&self) -> u16 { + self.port + } + + pub fn set_host(mut self, h: Host) -> Self { + self.host = h; + self + } + + pub fn set_port(mut self, p: u16) -> Self { + self.port = p; + self + } + + pub fn set_password(mut self, s: Option) -> Self { + self.password = s; + self + } + + /// Return a `:` string. + pub fn raw_address(&self) -> String { + format!("{}:{}", self.host(), self.port()) + } + + /// Build a client library-specific connection configuration. + /// Used for testing and when we need to add some obscure configuration + /// elements at the last moment. + pub fn to_tokio_postgres_config(&self) -> tokio_postgres::Config { + // Use `tokio_postgres::Config` instead of `postgres::Config` because + // the former supports more options to fiddle with later. + let mut config = tokio_postgres::Config::new(); + config.host(&self.host().to_string()).port(self.port); + if let Some(password) = &self.password { + config.password(password); + } + config + } + + /// Connect using postgres protocol with TLS disabled. + pub fn connect_no_tls(&self) -> Result { + postgres::Config::from(self.to_tokio_postgres_config()).connect(postgres::NoTls) + } +} + +impl fmt::Debug for PgConnectionConfig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // We want `password: Some(REDACTED-STRING)`, not `password: Some("REDACTED-STRING")` + // so even if the password is `REDACTED-STRING` (quite unlikely) there is no confusion. + // Hence `format_args!()`, it returns a "safe" string which is not escaped by `Debug`. + f.debug_struct("PgConnectionConfig") + .field("host", &self.host) + .field("port", &self.port) + .field( + "password", + &self + .password + .as_ref() + .map(|_| format_args!("REDACTED-STRING")), + ) + .finish() + } +} + +#[cfg(test)] +mod tests_pg_connection_config { + use crate::PgConnectionConfig; + use once_cell::sync::Lazy; + use url::Host; + + static STUB_HOST: Lazy = Lazy::new(|| Host::Domain("stub.host.example".to_owned())); + + #[test] + fn test_no_password() { + let cfg = PgConnectionConfig::new_host_port(STUB_HOST.clone(), 123); + assert_eq!(cfg.host(), &*STUB_HOST); + assert_eq!(cfg.port(), 123); + assert_eq!(cfg.raw_address(), "stub.host.example:123"); + assert_eq!( + format!("{:?}", cfg), + "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: None }" + ); + } + + #[test] + fn test_ipv6() { + // May be a special case because hostname contains a colon. + let cfg = PgConnectionConfig::new_host_port(Host::parse("[::1]").unwrap(), 123); + assert_eq!( + cfg.host(), + &Host::::Ipv6(std::net::Ipv6Addr::LOCALHOST) + ); + assert_eq!(cfg.port(), 123); + assert_eq!(cfg.raw_address(), "[::1]:123"); + assert_eq!( + format!("{:?}", cfg), + "PgConnectionConfig { host: Ipv6(::1), port: 123, password: None }" + ); + } + + #[test] + fn test_with_password() { + let cfg = PgConnectionConfig::new_host_port(STUB_HOST.clone(), 123) + .set_password(Some("password".to_owned())); + assert_eq!(cfg.host(), &*STUB_HOST); + assert_eq!(cfg.port(), 123); + assert_eq!(cfg.raw_address(), "stub.host.example:123"); + assert_eq!( + format!("{:?}", cfg), + "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: Some(REDACTED-STRING) }" + ); + } +}