diff --git a/Cargo.lock b/Cargo.lock index 603e034ed3..2e300e46f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2271,6 +2271,7 @@ dependencies = [ "hex", "hmac 0.12.1", "hyper", + "itertools", "md5", "metrics", "once_cell", diff --git a/libs/utils/src/pq_proto.rs b/libs/utils/src/pq_proto.rs index 2f8dcf31d3..dde76039d7 100644 --- a/libs/utils/src/pq_proto.rs +++ b/libs/utils/src/pq_proto.rs @@ -7,11 +7,14 @@ use anyhow::{bail, ensure, Context, Result}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_protocol::PG_EPOCH; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::future::Future; -use std::io::{self, Cursor}; -use std::str; -use std::time::{Duration, SystemTime}; +use std::{ + borrow::Cow, + collections::HashMap, + future::Future, + io::{self, Cursor}, + str, + time::{Duration, SystemTime}, +}; use tokio::io::AsyncReadExt; use tracing::{trace, warn}; @@ -53,7 +56,67 @@ pub enum FeStartupPacket { }, } -pub type StartupMessageParams = HashMap; +#[derive(Debug)] +pub struct StartupMessageParams { + params: HashMap, +} + +impl StartupMessageParams { + /// Get parameter's value by its name. + pub fn get(&self, name: &str) -> Option<&str> { + self.params.get(name).map(|s| s.as_str()) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + /// [`None`] means that there's no `options` in [`Self`]. + pub fn options_raw(&self) -> Option> { + // See `postgres: pg_split_opts`. + let mut last_was_escape = false; + let iter = self + .get("options")? + .split(move |c: char| { + // We split by non-escaped whitespace symbols. + let should_split = c.is_ascii_whitespace() && !last_was_escape; + last_was_escape = c == '\\' && !last_was_escape; + should_split + }) + .filter(|s| !s.is_empty()); + + Some(iter) + } + + /// Split command-line options according to PostgreSQL's logic, + /// applying all escape sequences (using owned strings as needed). + /// [`None`] means that there's no `options` in [`Self`]. + pub fn options_escaped(&self) -> Option>> { + // See `postgres: pg_split_opts`. + let iter = self.options_raw()?.map(|s| { + let mut preserve_next_escape = false; + let escape = |c| { + // We should remove '\\' unless it's preceded by '\\'. + let should_remove = c == '\\' && !preserve_next_escape; + preserve_next_escape = should_remove; + should_remove + }; + + match s.contains('\\') { + true => Cow::Owned(s.replace(escape, "")), + false => Cow::Borrowed(s), + } + }); + + Some(iter) + } + + // This function is mostly useful in tests. + #[doc(hidden)] + pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { + Self { + params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(), + } + } +} #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] pub struct CancelKeyData { @@ -237,9 +300,9 @@ impl FeStartupPacket { stream.read_exact(params_bytes.as_mut()).await?; // Parse params depending on request code - let most_sig_16_bits = request_code >> 16; - let least_sig_16_bits = request_code & ((1 << 16) - 1); - let message = match (most_sig_16_bits, least_sig_16_bits) { + let req_hi = request_code >> 16; + let req_lo = request_code & ((1 << 16) - 1); + let message = match (req_hi, req_lo) { (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { ensure!(params_len == 8, "expected 8 bytes for CancelRequest params"); let mut cursor = Cursor::new(params_bytes); @@ -248,49 +311,44 @@ impl FeStartupPacket { cancel_key: cursor.read_i32().await?, }) } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest, + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { + // Requested upgrade to SSL (aka TLS) + FeStartupPacket::SslRequest + } (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { + // Requested upgrade to GSSAPI FeStartupPacket::GssEncRequest } (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { bail!("Unrecognized request code {}", unrecognized_code) } + // TODO bail if protocol major_version is not 3? (major_version, minor_version) => { - // TODO bail if protocol major_version is not 3? - // Parse null-terminated (String) pairs of param name / param value - let params_str = str::from_utf8(¶ms_bytes).unwrap(); - let mut params_tokens = params_str.split('\0'); - let mut params: HashMap = HashMap::new(); - while let Some(name) = params_tokens.next() { - let value = params_tokens + // Parse pairs of null-terminated strings (key, value). + // See `postgres: ProcessStartupPacket, build_startup_packet`. + let mut tokens = str::from_utf8(¶ms_bytes) + .context("StartupMessage params: invalid utf-8")? + .strip_suffix('\0') // drop packet's own null terminator + .context("StartupMessage params: missing null terminator")? + .split_terminator('\0'); + + let mut params = HashMap::new(); + while let Some(name) = tokens.next() { + let value = tokens .next() - .context("expected even number of params in StartupMessage")?; - if name == "options" { - // parsing options arguments "...&options=%3D+=..." - // '%3D' is '=' and '+' is ' ' + .context("StartupMessage params: key without value")?; - // Note: we allow users that don't have SNI capabilities, - // to pass a special keyword argument 'project' - // to be used to determine the cluster name by the proxy. - - //TODO: write unit test for this and refactor in its own function. - for cmdopt in value.split(' ') { - let nameval: Vec<&str> = cmdopt.split('=').collect(); - if nameval.len() == 2 { - params.insert(nameval[0].to_string(), nameval[1].to_string()); - } - } - } else { - params.insert(name.to_string(), value.to_string()); - } + params.insert(name.to_owned(), value.to_owned()); } + FeStartupPacket::StartupMessage { major_version, minor_version, - params, + params: StartupMessageParams { params }, } } }; + Ok(Some(FeMessage::StartupPacket(message))) }) } @@ -967,6 +1025,33 @@ mod tests { assert_eq!(zf, zf_parsed); } + #[test] + fn test_startup_message_params_options_escaped() { + fn split_options(params: &StartupMessageParams) -> Vec> { + params + .options_escaped() + .expect("options are None") + .collect() + } + + let make_params = |options| StartupMessageParams::new([("options", options)]); + + let params = StartupMessageParams::new([]); + assert!(matches!(params.options_escaped(), None)); + + let params = make_params(""); + assert!(split_options(¶ms).is_empty()); + + let params = make_params("foo"); + assert_eq!(split_options(¶ms), ["foo"]); + + let params = make_params(" foo bar "); + assert_eq!(split_options(¶ms), ["foo", "bar"]); + + let params = make_params("foo\\ bar \\ \\\\ baz\\ lol"); + assert_eq!(split_options(¶ms), ["foo bar", " \\", "baz ", "lol"]); + } + // Make sure that `read` is sync/async callable async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) { let _ = FeMessage::read(&mut [].as_ref()); diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index d3f7ea5fdc..5a450793f1 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -15,6 +15,7 @@ hashbrown = "0.12" hex = "0.4.3" hmac = "0.12.1" hyper = "0.14" +itertools = "0.10.3" once_cell = "1.13.0" md5 = "0.7.0" parking_lot = "0.12" diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index bb7e7ef67b..9c43620ffb 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -127,7 +127,7 @@ impl BackendType> { } } -impl BackendType { +impl BackendType> { /// Authenticate the client via the requested backend, possibly using credentials. pub async fn authenticate( mut self, @@ -149,7 +149,7 @@ impl BackendType { // Finally we may finish the initialization of `creds`. // TODO: add missing type safety to ClientCredentials. - creds.project = Some(payload.project); + creds.project = Some(payload.project.into()); let mut config = match &self { Console(creds) => { diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs index 87906679ea..e239320e9b 100644 --- a/proxy/src/auth/backend/console.rs +++ b/proxy/src/auth/backend/console.rs @@ -121,7 +121,7 @@ pub enum AuthInfo { #[must_use] pub(super) struct Api<'a> { endpoint: &'a ApiUrl, - creds: &'a ClientCredentials, + creds: &'a ClientCredentials<'a>, } impl<'a> Api<'a> { @@ -143,7 +143,7 @@ impl<'a> Api<'a> { url.path_segments_mut().push("proxy_get_role_secret"); url.query_pairs_mut() .append_pair("project", self.creds.project().expect("impossible")) - .append_pair("role", &self.creds.user); + .append_pair("role", self.creds.user); // TODO: use a proper logger println!("cplane request: {url}"); @@ -187,8 +187,8 @@ impl<'a> Api<'a> { config .host(host) .port(port) - .dbname(&self.creds.dbname) - .user(&self.creds.user); + .dbname(self.creds.dbname) + .user(self.creds.user); Ok(config) } diff --git a/proxy/src/auth/backend/legacy_console.rs b/proxy/src/auth/backend/legacy_console.rs index 17ba44e833..b99a004dcd 100644 --- a/proxy/src/auth/backend/legacy_console.rs +++ b/proxy/src/auth/backend/legacy_console.rs @@ -56,7 +56,7 @@ enum ProxyAuthResponse { NotReady { ready: bool }, // TODO: get rid of `ready` } -impl ClientCredentials { +impl ClientCredentials<'_> { fn is_existing_user(&self) -> bool { self.user.ends_with("@zenith") } @@ -64,15 +64,15 @@ impl ClientCredentials { async fn authenticate_proxy_client( auth_endpoint: &reqwest::Url, - creds: &ClientCredentials, + creds: &ClientCredentials<'_>, md5_response: &str, salt: &[u8; 4], psql_session_id: &str, ) -> Result { let mut url = auth_endpoint.clone(); url.query_pairs_mut() - .append_pair("login", &creds.user) - .append_pair("database", &creds.dbname) + .append_pair("login", creds.user) + .append_pair("database", creds.dbname) .append_pair("md5response", md5_response) .append_pair("salt", &hex::encode(salt)) .append_pair("psql_session_id", psql_session_id); @@ -103,7 +103,7 @@ async fn authenticate_proxy_client( async fn handle_existing_user( auth_endpoint: &reqwest::Url, client: &mut PqStream, - creds: &ClientCredentials, + creds: &ClientCredentials<'_>, ) -> auth::Result { let psql_session_id = super::link::new_psql_session_id(); let md5_salt = rand::random(); @@ -136,7 +136,7 @@ async fn handle_existing_user( pub async fn handle_user( auth_endpoint: &reqwest::Url, auth_link_uri: &reqwest::Url, - creds: &ClientCredentials, + creds: &ClientCredentials<'_>, client: &mut PqStream, ) -> auth::Result { if creds.is_existing_user() { diff --git a/proxy/src/auth/backend/postgres.rs b/proxy/src/auth/backend/postgres.rs index 183fa52ec1..2055ee14c8 100644 --- a/proxy/src/auth/backend/postgres.rs +++ b/proxy/src/auth/backend/postgres.rs @@ -17,7 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; #[must_use] pub(super) struct Api<'a> { endpoint: &'a ApiUrl, - creds: &'a ClientCredentials, + creds: &'a ClientCredentials<'a>, } // Helps eliminate graceless `.map_err` calls without introducing another ctor. @@ -87,8 +87,8 @@ impl<'a> Api<'a> { config .host(self.endpoint.host_str().unwrap_or("localhost")) .port(self.endpoint.port().unwrap_or(5432)) - .dbname(&self.creds.dbname) - .user(&self.creds.user); + .dbname(self.creds.dbname) + .user(self.creds.user); Ok(config) } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 4c72da1c48..ea71eba010 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,6 +1,7 @@ //! User credentials used in authentication. use crate::error::UserFacingError; +use std::borrow::Cow; use thiserror::Error; use utils::pq_proto::StartupMessageParams; @@ -27,51 +28,59 @@ impl UserFacingError for ClientCredsParseError {} /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ClientCredentials { - pub user: String, - pub dbname: String, - pub project: Option, +pub struct ClientCredentials<'a> { + pub user: &'a str, + pub dbname: &'a str, + pub project: Option>, } -impl ClientCredentials { +impl ClientCredentials<'_> { pub fn project(&self) -> Option<&str> { self.project.as_deref() } } -impl ClientCredentials { +impl<'a> ClientCredentials<'a> { pub fn parse( - mut options: StartupMessageParams, + params: &'a StartupMessageParams, sni: Option<&str>, common_name: Option<&str>, ) -> Result { use ClientCredsParseError::*; - // Some parameters are absolutely necessary, others not so much. - let mut get_param = |key| options.remove(key).ok_or(MissingKey(key)); - // Some parameters are stored in the startup message. + let get_param = |key| params.get(key).ok_or(MissingKey(key)); let user = get_param("user")?; let dbname = get_param("database")?; - let project_a = get_param("project").ok(); + + // Project name might be passed via PG's command-line options. + let project_a = params.options_raw().and_then(|options| { + for opt in options { + if let Some(value) = opt.strip_prefix("project=") { + return Some(Cow::Borrowed(value)); + } + } + None + }); // Alternative project name is in fact a subdomain from SNI. // NOTE: we do not consider SNI if `common_name` is missing. let project_b = sni .zip(common_name) .map(|(sni, cn)| { - // TODO: what if SNI is present but just a common name? subdomain_from_sni(sni, cn) - .ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned())) + .ok_or_else(|| InconsistentSni(sni.into(), cn.into())) + .map(Cow::<'static, str>::Owned) }) .transpose()?; let project = match (project_a, project_b) { // Invariant: if we have both project name variants, they should match. - (Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))), - (a, b) => a.or(b).map(|name| { - // Invariant: project name may not contain certain characters. - check_project_name(name).map_err(MalformedProjectName) + (Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a.into(), b.into()))), + // Invariant: project name may not contain certain characters. + (a, b) => a.or(b).map(|name| match project_name_valid(&name) { + false => Err(MalformedProjectName(name.into())), + true => Ok(name), }), } .transpose()?; @@ -84,12 +93,8 @@ impl ClientCredentials { } } -fn check_project_name(name: String) -> Result { - if name.chars().all(|c| c.is_alphanumeric() || c == '-') { - Ok(name) - } else { - Err(name) - } +fn project_name_valid(name: &str) -> bool { + name.chars().all(|c| c.is_alphanumeric() || c == '-') } fn subdomain_from_sni(sni: &str, common_name: &str) -> Option { @@ -102,18 +107,14 @@ fn subdomain_from_sni(sni: &str, common_name: &str) -> Option { mod tests { use super::*; - fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams { - StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned()))) - } - #[test] #[ignore = "TODO: fix how database is handled"] fn parse_bare_minimum() -> anyhow::Result<()> { // According to postgresql, only `user` should be required. - let options = make_options([("user", "john_doe")]); + let options = StartupMessageParams::new([("user", "john_doe")]); // TODO: check that `creds.dbname` is None. - let creds = ClientCredentials::parse(options, None, None)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); Ok(()) @@ -121,9 +122,9 @@ mod tests { #[test] fn parse_missing_project() -> anyhow::Result<()> { - let options = make_options([("user", "john_doe"), ("database", "world")]); + let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]); - let creds = ClientCredentials::parse(options, None, None)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project, None); @@ -133,12 +134,12 @@ mod tests { #[test] fn parse_project_from_sni() -> anyhow::Result<()> { - let options = make_options([("user", "john_doe"), ("database", "world")]); + let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]); let sni = Some("foo.localhost"); let common_name = Some("localhost"); - let creds = ClientCredentials::parse(options, sni, common_name)?; + let creds = ClientCredentials::parse(&options, sni, common_name)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("foo")); @@ -148,13 +149,13 @@ mod tests { #[test] fn parse_project_from_options() -> anyhow::Result<()> { - let options = make_options([ + let options = StartupMessageParams::new([ ("user", "john_doe"), ("database", "world"), - ("project", "bar"), + ("options", "-ckey=1 project=bar -c geqo=off"), ]); - let creds = ClientCredentials::parse(options, None, None)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -164,16 +165,16 @@ mod tests { #[test] fn parse_projects_identical() -> anyhow::Result<()> { - let options = make_options([ + let options = StartupMessageParams::new([ ("user", "john_doe"), ("database", "world"), - ("project", "baz"), + ("options", "project=baz"), ]); let sni = Some("baz.localhost"); let common_name = Some("localhost"); - let creds = ClientCredentials::parse(options, sni, common_name)?; + let creds = ClientCredentials::parse(&options, sni, common_name)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("baz")); @@ -183,17 +184,17 @@ mod tests { #[test] fn parse_projects_different() { - let options = make_options([ + let options = StartupMessageParams::new([ ("user", "john_doe"), ("database", "world"), - ("project", "first"), + ("options", "project=first"), ]); let sni = Some("second.localhost"); let common_name = Some("localhost"); assert!(matches!( - ClientCredentials::parse(options, sni, common_name).expect_err("should fail"), + ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"), ClientCredsParseError::InconsistentProjectNames(_, _) )); } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a801313635..b7412b6f5b 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -95,7 +95,7 @@ impl<'a> Session<'a> { /// Store the cancel token for the given session. /// This enables query cancellation in [`crate::proxy::handshake`]. - pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData { + pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData { self.cancel_map .0 .lock() diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 3bad36661b..4ae44ded57 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,9 +1,11 @@ use crate::{cancellation::CancelClosure, error::UserFacingError}; use futures::TryFutureExt; +use itertools::Itertools; use std::{io, net::SocketAddr}; use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::NoTls; +use utils::pq_proto::StartupMessageParams; #[derive(Debug, Error)] pub enum ConnectionError { @@ -110,7 +112,42 @@ pub struct PostgresConnection { impl NodeInfo { /// Connect to a corresponding compute node. - pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> { + pub async fn connect( + mut self, + params: &StartupMessageParams, + ) -> Result<(PostgresConnection, CancelClosure), ConnectionError> { + if let Some(options) = params.options_raw() { + // We must drop all proxy-specific parameters. + #[allow(unstable_name_collisions)] + let options: String = options + .filter(|opt| !opt.starts_with("project=")) + .intersperse(" ") // TODO: use impl from std once it's stabilized + .collect(); + + self.config.options(&options); + } + + if let Some(app_name) = params.get("application_name") { + self.config.application_name(app_name); + } + + if let Some(replication) = params.get("replication") { + use tokio_postgres::config::ReplicationMode; + match replication { + "true" | "on" | "yes" | "1" => { + self.config.replication_mode(ReplicationMode::Physical); + } + "database" => { + self.config.replication_mode(ReplicationMode::Logical); + } + _other => {} + } + } + + // TODO: extend the list of the forwarded startup parameters. + // Currently, tokio-postgres doesn't allow us to pass + // arbitrary parameters, but the ones above are a good start. + let (socket_addr, mut stream) = self .connect_raw() .await diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 29be79c886..72cb822910 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,6 +1,6 @@ use crate::auth; use crate::cancellation::{self, CancelMap}; -use crate::config::{ProxyConfig, TlsConfig}; +use crate::config::{AuthUrls, ProxyConfig, TlsConfig}; use crate::stream::{MetricsStream, PqStream, Stream}; use anyhow::{bail, Context}; use futures::TryFutureExt; @@ -93,20 +93,21 @@ async fn handle_client( None => return Ok(()), // it's a cancellation request }; + // Extract credentials which we're going to use for auth. let creds = { let sni = stream.get_ref().sni_hostname(); let common_name = tls.and_then(|tls| tls.common_name.as_deref()); let result = config .auth_backend - .map(|_| auth::ClientCredentials::parse(params, sni, common_name)) + .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new(stream, creds); + let client = Client::new(stream, creds, ¶ms); cancel_map - .with_session(|session| client.connect_to_db(config, session)) + .with_session(|session| client.connect_to_db(&config.auth_urls, session)) .await } @@ -174,38 +175,57 @@ async fn handshake( } /// Thin connection context. -struct Client { +struct Client<'a, S> { /// The underlying libpq protocol stream. stream: PqStream, /// Client credentials that we care about. - creds: auth::BackendType, + creds: auth::BackendType>, + /// KV-dictionary with PostgreSQL connection params. + params: &'a StartupMessageParams, } -impl Client { +impl<'a, S> Client<'a, S> { /// Construct a new connection context. - fn new(stream: PqStream, creds: auth::BackendType) -> Self { - Self { stream, creds } + fn new( + stream: PqStream, + creds: auth::BackendType>, + params: &'a StartupMessageParams, + ) -> Self { + Self { + stream, + creds, + params, + } } } -impl Client { +impl Client<'_, S> { /// Let the client authenticate and connect to the designated compute node. async fn connect_to_db( self, - config: &ProxyConfig, + urls: &AuthUrls, session: cancellation::Session<'_>, ) -> anyhow::Result<()> { - let Self { mut stream, creds } = self; + let Self { + mut stream, + creds, + params, + } = self; // Authenticate and connect to a compute node. - let auth = creds.authenticate(&config.auth_urls, &mut stream).await; + let auth = creds.authenticate(urls, &mut stream).await; let node = async { auth }.or_else(|e| stream.throw_error(e)).await?; + let reported_auth_ok = node.reported_auth_ok; - let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?; - let cancel_key_data = session.enable_cancellation(cancel_closure); + let (db, cancel_closure) = node + .connect(params) + .or_else(|e| stream.throw_error(e)) + .await?; + + let cancel_key_data = session.enable_query_cancellation(cancel_closure); // Report authentication success if we haven't done this already. - if !node.reported_auth_ok { + if !reported_auth_ok { stream .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&BeParameterStatusMessage::encoding())?; diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index c90c2a0446..3e301259ed 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -11,7 +11,6 @@ use anyhow::{bail, Context, Result}; use postgres_ffi::PG_TLI; use regex::Regex; -use std::str::FromStr; use std::sync::Arc; use tracing::info; use utils::{ @@ -67,18 +66,22 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { // ztenant id and ztimeline id are passed in connection string params fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupPacket) -> Result<()> { if let FeStartupPacket::StartupMessage { params, .. } = sm { - self.ztenantid = match params.get("ztenantid") { - Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map? - _ => None, - }; - - self.ztimelineid = match params.get("ztimelineid") { - Some(z) => Some(ZTimelineId::from_str(z)?), - _ => None, - }; + if let Some(options) = params.options_raw() { + for opt in options { + match opt.split_once('=') { + Some(("ztenantid", value)) => { + self.ztenantid = Some(value.parse()?); + } + Some(("ztimelineid", value)) => { + self.ztimelineid = Some(value.parse()?); + } + _ => continue, + } + } + } if let Some(app_name) = params.get("application_name") { - self.appname = Some(app_name.clone()); + self.appname = Some(app_name.to_owned()); } Ok(()) diff --git a/test_runner/batch_others/test_proxy.py b/test_runner/batch_others/test_proxy.py index 1efb795140..bd02841dc0 100644 --- a/test_runner/batch_others/test_proxy.py +++ b/test_runner/batch_others/test_proxy.py @@ -134,12 +134,8 @@ async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProx # Pass extra options to the server. -# -# Currently, proxy eats the extra connection options, so this fails. -# See https://github.com/neondatabase/neon/issues/1287 -@pytest.mark.xfail def test_proxy_options(static_proxy): - with static_proxy.connect(options="-cproxytest.option=value") as conn: + with static_proxy.connect(options="project=irrelevant -cproxytest.option=value") as conn: with conn.cursor() as cur: cur.execute("SHOW proxytest.option") value = cur.fetchall()[0][0]