diff --git a/Cargo.lock b/Cargo.lock index 5b80ec5e93..38158b7aec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1031,9 +1031,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.5.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" dependencies = [ "serde", ] diff --git a/Cargo.toml b/Cargo.toml index 91fa6a2607..a35823e0c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,7 +74,7 @@ bindgen = "0.70" bit_field = "0.10.2" bstr = "1.0" byteorder = "1.4" -bytes = "1.0" +bytes = "1.9" camino = "1.1.6" cfg-if = "1.0.0" chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 43dfbc22a4..94714359a3 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -100,7 +100,7 @@ impl StartupMessageParamsBuilder { #[derive(Debug, Clone, Default)] pub struct StartupMessageParams { - params: Bytes, + pub params: Bytes, } impl StartupMessageParams { diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs index 19aa3c1e9a..f2200a40ce 100644 --- a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -117,7 +117,7 @@ enum Credentials { /// A regular password as a vector of bytes. Password(Vec), /// A precomputed pair of keys. - Keys(Box>), + Keys(ScramKeys), } enum State { @@ -176,7 +176,7 @@ impl ScramSha256 { /// Constructs a new instance which will use the provided key pair for authentication. pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 { - let password = Credentials::Keys(keys.into()); + let password = Credentials::Keys(keys); ScramSha256::new_inner(password, channel_binding, nonce()) } diff --git a/libs/proxy/postgres-protocol2/src/message/frontend.rs b/libs/proxy/postgres-protocol2/src/message/frontend.rs index 5d0a8ff8c8..bc6168f337 100644 --- a/libs/proxy/postgres-protocol2/src/message/frontend.rs +++ b/libs/proxy/postgres-protocol2/src/message/frontend.rs @@ -255,22 +255,34 @@ pub fn ssl_request(buf: &mut BytesMut) { } #[inline] -pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()> -where - I: IntoIterator, -{ +pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> { write_body(buf, |buf| { // postgres protocol version 3.0(196608) in bigger-endian buf.put_i32(0x00_03_00_00); - for (key, value) in parameters { - write_cstr(key.as_bytes(), buf)?; - write_cstr(value.as_bytes(), buf)?; - } + buf.put_slice(¶meters.params); buf.put_u8(0); Ok(()) }) } +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct StartupMessageParams { + pub params: BytesMut, +} + +impl StartupMessageParams { + /// Set parameter's value by its name. + pub fn insert(&mut self, name: &str, value: &str) { + if name.contains('\0') || value.contains('\0') { + panic!("startup parameter name or value contained a null") + } + self.params.put_slice(name.as_bytes()); + self.params.put_u8(0); + self.params.put_slice(value.as_bytes()); + self.params.put_u8(0); + } +} + #[inline] pub fn sync(buf: &mut BytesMut) { buf.put_u8(b'S'); diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index 7412db785b..0ec46198ce 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -35,9 +35,7 @@ impl FallibleIterator for BackendMessages { } } -pub struct PostgresCodec { - pub max_message_size: Option, -} +pub struct PostgresCodec; impl Encoder for PostgresCodec { type Error = io::Error; @@ -66,15 +64,6 @@ impl Decoder for PostgresCodec { break; } - if let Some(max) = self.max_message_size { - if len > max { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "message too large", - )); - } - } - match header.tag() { backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index fd10ef6f20..11a361a81b 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -6,6 +6,7 @@ use crate::connect_raw::RawConnection; use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; use crate::{Client, Connection, Error}; +use postgres_protocol2::message::frontend::StartupMessageParams; use std::fmt; use std::str; use std::time::Duration; @@ -14,16 +15,6 @@ use tokio::io::{AsyncRead, AsyncWrite}; pub use postgres_protocol2::authentication::sasl::ScramKeys; use tokio::net::TcpStream; -/// Properties required of a session. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -#[non_exhaustive] -pub enum TargetSessionAttrs { - /// No special properties are required. - Any, - /// The session must allow writes. - ReadWrite, -} - /// TLS configuration. #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[non_exhaustive] @@ -73,94 +64,20 @@ pub enum AuthKeys { } /// Connection configuration. -/// -/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats: -/// -/// # Key-Value -/// -/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain -/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped. -/// -/// ## Keys -/// -/// * `user` - The username to authenticate with. Required. -/// * `password` - The password to authenticate with. -/// * `dbname` - The name of the database to connect to. Defaults to the username. -/// * `options` - Command line options used to configure the server. -/// * `application_name` - Sets the `application_name` parameter on the server. -/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used -/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. -/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the -/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts -/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting -/// with the `connect` method. -/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be -/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if -/// omitted or the empty string. -/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames -/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout. -/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that -/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server -/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. -/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel -/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. -/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. -/// -/// ## Examples -/// -/// ```not_rust -/// host=localhost user=postgres connect_timeout=10 keepalives=0 -/// ``` -/// -/// ```not_rust -/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' -/// ``` -/// -/// ```not_rust -/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write -/// ``` -/// -/// # Url -/// -/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, -/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple -/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, -/// as the path component of the URL specifies the database name. -/// -/// ## Examples -/// -/// ```not_rust -/// postgresql://user@localhost -/// ``` -/// -/// ```not_rust -/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 -/// ``` -/// -/// ```not_rust -/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write -/// ``` -/// -/// ```not_rust -/// postgresql:///mydb?user=user&host=/var/lib/postgresql -/// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { pub(crate) host: Host, pub(crate) port: u16, - pub(crate) user: Option, pub(crate) password: Option>, pub(crate) auth_keys: Option>, - pub(crate) dbname: Option, - pub(crate) options: Option, - pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) connect_timeout: Option, - pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, - pub(crate) replication_mode: Option, - pub(crate) max_backend_message_size: Option, + pub(crate) server_params: StartupMessageParams, + + database: bool, + username: bool, } impl Config { @@ -169,18 +86,15 @@ impl Config { Config { host: Host::Tcp(host), port, - user: None, password: None, auth_keys: None, - dbname: None, - options: None, - application_name: None, ssl_mode: SslMode::Prefer, connect_timeout: None, - target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, - replication_mode: None, - max_backend_message_size: None, + server_params: StartupMessageParams::default(), + + database: false, + username: false, } } @@ -188,14 +102,13 @@ impl Config { /// /// Required. pub fn user(&mut self, user: &str) -> &mut Config { - self.user = Some(user.to_string()); - self + self.set_param("user", user) } /// Gets the user to authenticate with, if one has been configured with /// the `user` method. - pub fn get_user(&self) -> Option<&str> { - self.user.as_deref() + pub fn user_is_set(&self) -> bool { + self.username } /// Sets the password to authenticate with. @@ -231,40 +144,26 @@ impl Config { /// /// Defaults to the user. pub fn dbname(&mut self, dbname: &str) -> &mut Config { - self.dbname = Some(dbname.to_string()); - self + self.set_param("database", dbname) } /// Gets the name of the database to connect to, if one has been configured /// with the `dbname` method. - pub fn get_dbname(&self) -> Option<&str> { - self.dbname.as_deref() + pub fn db_is_set(&self) -> bool { + self.database } - /// Sets command line options used to configure the server. - pub fn options(&mut self, options: &str) -> &mut Config { - self.options = Some(options.to_string()); + pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config { + if name == "database" { + self.database = true; + } else if name == "user" { + self.username = true; + } + + self.server_params.insert(name, value); self } - /// Gets the command line options used to configure the server, if the - /// options have been set with the `options` method. - pub fn get_options(&self) -> Option<&str> { - self.options.as_deref() - } - - /// Sets the value of the `application_name` runtime parameter. - pub fn application_name(&mut self, application_name: &str) -> &mut Config { - self.application_name = Some(application_name.to_string()); - self - } - - /// Gets the value of the `application_name` runtime parameter, if it has - /// been set with the `application_name` method. - pub fn get_application_name(&self) -> Option<&str> { - self.application_name.as_deref() - } - /// Sets the SSL configuration. /// /// Defaults to `prefer`. @@ -303,23 +202,6 @@ impl Config { self.connect_timeout.as_ref() } - /// Sets the requirements of the session. - /// - /// This can be used to connect to the primary server in a clustered database rather than one of the read-only - /// secondary servers. Defaults to `Any`. - pub fn target_session_attrs( - &mut self, - target_session_attrs: TargetSessionAttrs, - ) -> &mut Config { - self.target_session_attrs = target_session_attrs; - self - } - - /// Gets the requirements of the session. - pub fn get_target_session_attrs(&self) -> TargetSessionAttrs { - self.target_session_attrs - } - /// Sets the channel binding behavior. /// /// Defaults to `prefer`. @@ -333,28 +215,6 @@ impl Config { self.channel_binding } - /// Set replication mode. - pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { - self.replication_mode = Some(replication_mode); - self - } - - /// Get replication mode. - pub fn get_replication_mode(&self) -> Option { - self.replication_mode - } - - /// Set limit for backend messages size. - pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { - self.max_backend_message_size = Some(max_backend_message_size); - self - } - - /// Get limit for backend messages size. - pub fn get_max_backend_message_size(&self) -> Option { - self.max_backend_message_size - } - /// Opens a connection to a PostgreSQL database. /// /// Requires the `runtime` Cargo feature (enabled by default). @@ -392,18 +252,13 @@ impl fmt::Debug for Config { } f.debug_struct("Config") - .field("user", &self.user) .field("password", &self.password.as_ref().map(|_| Redaction {})) - .field("dbname", &self.dbname) - .field("options", &self.options) - .field("application_name", &self.application_name) .field("ssl_mode", &self.ssl_mode) .field("host", &self.host) .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) - .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) - .field("replication", &self.replication_mode) + .field("server_params", &self.server_params) .finish() } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 75a58e6eac..e0cb69748d 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -1,14 +1,11 @@ use crate::client::SocketConfig; use crate::codec::BackendMessage; -use crate::config::{Host, TargetSessionAttrs}; +use crate::config::Host; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; -use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage}; -use futures_util::{future, pin_mut, Future, FutureExt, Stream}; +use crate::{Client, Config, Connection, Error, RawConnection}; use postgres_protocol2::message::backend::Message; -use std::io; -use std::task::Poll; use tokio::net::TcpStream; use tokio::sync::mpsc; @@ -72,47 +69,7 @@ where .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) .collect(); - let mut connection = Connection::new(stream, delayed, parameters, receiver); - - if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { - let rows = client.simple_query_raw("SHOW transaction_read_only"); - pin_mut!(rows); - - let rows = future::poll_fn(|cx| { - if connection.poll_unpin(cx)?.is_ready() { - return Poll::Ready(Err(Error::closed())); - } - - rows.as_mut().poll(cx) - }) - .await?; - pin_mut!(rows); - - loop { - let next = future::poll_fn(|cx| { - if connection.poll_unpin(cx)?.is_ready() { - return Poll::Ready(Some(Err(Error::closed()))); - } - - rows.as_mut().poll_next(cx) - }); - - match next.await.transpose()? { - Some(SimpleQueryMessage::Row(row)) => { - if row.try_get(0)? == Some("on") { - return Err(Error::connect(io::Error::new( - io::ErrorKind::PermissionDenied, - "database does not allow writes", - ))); - } else { - break; - } - } - Some(_) => {} - None => return Err(Error::unexpected_message()), - } - } - } + let connection = Connection::new(stream, delayed, parameters, receiver); Ok((client, connection)) } diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index 390f133002..66db85e07d 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, AuthKeys, Config, ReplicationMode}; +use crate::config::{self, AuthKeys, Config}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -96,12 +96,7 @@ where let stream = connect_tls(stream, config.ssl_mode, tls).await?; let mut stream = StartupStream { - inner: Framed::new( - stream, - PostgresCodec { - max_message_size: config.max_backend_message_size, - }, - ), + inner: Framed::new(stream, PostgresCodec), buf: BackendMessages::empty(), delayed_notice: Vec::new(), }; @@ -124,28 +119,8 @@ where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { - let mut params = vec![("client_encoding", "UTF8")]; - if let Some(user) = &config.user { - params.push(("user", &**user)); - } - if let Some(dbname) = &config.dbname { - params.push(("database", &**dbname)); - } - if let Some(options) = &config.options { - params.push(("options", &**options)); - } - if let Some(application_name) = &config.application_name { - params.push(("application_name", &**application_name)); - } - if let Some(replication_mode) = &config.replication_mode { - match replication_mode { - ReplicationMode::Physical => params.push(("replication", "true")), - ReplicationMode::Logical => params.push(("replication", "database")), - } - } - let mut buf = BytesMut::new(); - frontend::startup_message(params, &mut buf).map_err(Error::encode)?; + frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?; stream .send(FrontendMessage::Raw(buf.freeze())) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index bcb0ef40bd..7bc5587a25 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -70,11 +70,12 @@ impl ReportableError for CancelError { impl CancellationHandler

{ /// Run async action within an ephemeral session identified by [`CancelKeyData`]. pub(crate) fn get_session(self: Arc) -> Session

{ - // HACK: We'd rather get the real backend_pid but postgres_client doesn't - // expose it and we don't want to do another roundtrip to query - // for it. The client will be able to notice that this is not the - // actual backend_pid, but backend_pid is not used for anything - // so it doesn't matter. + // we intentionally generate a random "backend pid" and "secret key" here. + // we use the corresponding u64 as an identifier for the + // actual endpoint+pid+secret for postgres/pgbouncer. + // + // if we forwarded the backend_pid from postgres to the client, there would be a lot + // of overlap between our computes as most pids are small (~100). let key = loop { let key = rand::random(); diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index ab0ff4b795..4113b5bb80 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -131,49 +131,37 @@ impl ConnCfg { } /// Apply startup message params to the connection config. - pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) { - // Only set `user` if it's not present in the config. - // Console redirect auth flow takes username from the console's response. - if let (None, Some(user)) = (self.get_user(), params.get("user")) { - self.user(user); + pub(crate) fn set_startup_params( + &mut self, + params: &StartupMessageParams, + arbitrary_params: bool, + ) { + if !arbitrary_params { + self.set_param("client_encoding", "UTF8"); } - - // Only set `dbname` if it's not present in the config. - // Console redirect auth flow takes dbname from the console's response. - if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) { - self.dbname(dbname); - } - - // Don't add `options` if they were only used for specifying a project. - // Connection pools don't support `options`, because they affect backend startup. - if let Some(options) = filtered_options(params) { - self.options(&options); - } - - if let Some(app_name) = params.get("application_name") { - self.application_name(app_name); - } - - // TODO: This is especially ugly... - if let Some(replication) = params.get("replication") { - use postgres_client::config::ReplicationMode; - match replication { - "true" | "on" | "yes" | "1" => { - self.replication_mode(ReplicationMode::Physical); + for (k, v) in params.iter() { + match k { + // Only set `user` if it's not present in the config. + // Console redirect auth flow takes username from the console's response. + "user" if self.user_is_set() => continue, + "database" if self.db_is_set() => continue, + "options" => { + if let Some(options) = filtered_options(v) { + self.set_param(k, &options); + } } - "database" => { - self.replication_mode(ReplicationMode::Logical); + "user" | "database" | "application_name" | "replication" => { + self.set_param(k, v); } - _other => {} + + // if we allow arbitrary params, then we forward them through. + // this is a flag for a period of backwards compatibility + k if arbitrary_params => { + self.set_param(k, v); + } + _ => {} } } - - // TODO: extend the list of the forwarded startup parameters. - // Currently, tokio-postgres doesn't allow us to pass - // arbitrary parameters, but the ones above are a good start. - // - // This and the reverse params problem can be better addressed - // in a bespoke connection machinery (a new library for that sake). } } @@ -347,10 +335,9 @@ impl ConnCfg { } /// Retrieve `options` from a startup message, dropping all proxy-secific flags. -fn filtered_options(params: &StartupMessageParams) -> Option { +fn filtered_options(options: &str) -> Option { #[allow(unstable_name_collisions)] - let options: String = params - .options_raw()? + let options: String = StartupMessageParams::parse_options_raw(options) .filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none()) .intersperse(" ") // TODO: use impl from std once it's stabilized .collect(); @@ -427,27 +414,24 @@ mod tests { #[test] fn test_filtered_options() { // Empty options is unlikely to be useful anyway. - let params = StartupMessageParams::new([("options", "")]); - assert_eq!(filtered_options(¶ms), None); + let params = ""; + assert_eq!(filtered_options(params), None); // It's likely that clients will only use options to specify endpoint/project. - let params = StartupMessageParams::new([("options", "project=foo")]); - assert_eq!(filtered_options(¶ms), None); + let params = "project=foo"; + assert_eq!(filtered_options(params), None); // Same, because unescaped whitespaces are no-op. - let params = StartupMessageParams::new([("options", " project=foo ")]); - assert_eq!(filtered_options(¶ms).as_deref(), None); + let params = " project=foo "; + assert_eq!(filtered_options(params).as_deref(), None); - let params = StartupMessageParams::new([("options", r"\ project=foo \ ")]); - assert_eq!(filtered_options(¶ms).as_deref(), Some(r"\ \ ")); + let params = r"\ project=foo \ "; + assert_eq!(filtered_options(params).as_deref(), Some(r"\ \ ")); - let params = StartupMessageParams::new([("options", "project = foo")]); - assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo")); + let params = "project = foo"; + assert_eq!(filtered_options(params).as_deref(), Some("project = foo")); - let params = StartupMessageParams::new([( - "options", - "project = foo neon_endpoint_type:read_write neon_lsn:0/2", - )]); - assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo")); + let params = "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true"; + assert_eq!(filtered_options(params).as_deref(), Some("project = foo")); } } diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 8f78df1964..7db1179eea 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -206,6 +206,7 @@ pub(crate) async fn handle_client( let mut node = connect_to_compute( ctx, &TcpMechanism { + params_compat: true, params: ¶ms, locks: &config.connect_compute_locks, }, diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 585dce7bae..a3027abd7c 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -66,6 +66,8 @@ pub(crate) trait ComputeConnectBackend { } pub(crate) struct TcpMechanism<'a> { + pub(crate) params_compat: bool, + /// KV-dictionary with PostgreSQL connection params. pub(crate) params: &'a StartupMessageParams, @@ -92,7 +94,7 @@ impl ConnectMechanism for TcpMechanism<'_> { } fn update_connect_config(&self, config: &mut compute::ConnCfg) { - config.set_startup_params(self.params); + config.set_startup_params(self.params, self.params_compat); } } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index af97fb3d71..f74eb5940f 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -338,9 +338,17 @@ pub(crate) async fn handle_client( } }; + let params_compat = match &user_info { + auth::Backend::ControlPlane(_, info) => { + info.info.options.get(NeonOptions::PARAMS_COMPAT).is_some() + } + auth::Backend::Local(_) => false, + }; + let mut node = connect_to_compute( ctx, &TcpMechanism { + params_compat, params: ¶ms, locks: &config.connect_compute_locks, }, @@ -409,19 +417,47 @@ pub(crate) async fn prepare_client_connection

( pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>); impl NeonOptions { + // proxy options: + + /// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute. + const PARAMS_COMPAT: &str = "proxy_params_compat"; + + // cplane options: + + /// `LSN` allows provisioning an ephemeral compute with time-travel to the provided LSN. + const LSN: &str = "lsn"; + + /// `ENDPOINT_TYPE` allows configuring an ephemeral compute to be read_only or read_write. + const ENDPOINT_TYPE: &str = "endpoint_type"; + pub(crate) fn parse_params(params: &StartupMessageParams) -> Self { params .options_raw() .map(Self::parse_from_iter) .unwrap_or_default() } + pub(crate) fn parse_options_raw(options: &str) -> Self { Self::parse_from_iter(StartupMessageParams::parse_options_raw(options)) } + pub(crate) fn get(&self, key: &str) -> Option { + self.0 + .iter() + .find_map(|(k, v)| (k == key).then_some(v)) + .cloned() + } + pub(crate) fn is_ephemeral(&self) -> bool { - // Currently, neon endpoint options are all reserved for ephemeral endpoints. - !self.0.is_empty() + self.0.iter().any(|(k, _)| match &**k { + // This is not a cplane option, we know it does not create ephemeral computes. + Self::PARAMS_COMPAT => false, + Self::LSN => true, + Self::ENDPOINT_TYPE => true, + // err on the side of caution. any cplane options we don't know about + // might lead to ephemeral computes. + _ => true, + }) } fn parse_from_iter<'a>(options: impl Iterator) -> Self { diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index d72331c7bf..59c9ac27b8 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -55,7 +55,13 @@ async fn proxy_mitm( // give the end_server the startup parameters let mut buf = BytesMut::new(); - frontend::startup_message(startup.iter(), &mut buf).unwrap(); + frontend::startup_message( + &postgres_protocol::message::frontend::StartupMessageParams { + params: startup.params.into(), + }, + &mut buf, + ) + .unwrap(); end_server.send(buf.freeze()).await.unwrap(); // proxy messages between end_client and end_server diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 53345431e3..911b349416 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -252,7 +252,7 @@ async fn handshake_raw() -> anyhow::Result<()> { let _conn = postgres_client::Config::new("test".to_owned(), 5432) .user("john_doe") .dbname("earth") - .options("project=generic-project-name") + .set_param("options", "project=generic-project-name") .ssl_mode(SslMode::Prefer) .connect_raw(server, NoTls) .await?; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 55d2e47fd3..251aa47084 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -309,10 +309,13 @@ impl PoolingBackend { .config .user(&conn_info.user_info.user) .dbname(&conn_info.dbname) - .options(&format!( - "-c pg_session_jwt.jwk={}", - serde_json::to_string(&jwk).expect("serializing jwk to json should not fail") - )); + .set_param( + "options", + &format!( + "-c pg_session_jwt.jwk={}", + serde_json::to_string(&jwk).expect("serializing jwk to json should not fail") + ), + ); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let (client, connection) = config.connect(postgres_client::NoTls).await?; diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 9c579373e8..60c4a23936 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -269,7 +269,7 @@ class PgProtocol: for match in re.finditer(r"-c(\w*)=(\w*)", options): key = match.group(1) val = match.group(2) - if "server_options" in conn_options: + if "server_settings" in conn_options: conn_options["server_settings"].update({key: val}) else: conn_options["server_settings"] = {key: val} diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 5a01d90d85..d8df2efc78 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -5,6 +5,7 @@ import json import subprocess import time import urllib.parse +from contextlib import closing from typing import TYPE_CHECKING import psycopg2 @@ -131,6 +132,24 @@ def test_proxy_options(static_proxy: NeonProxy, option_name: str): assert out[0][0] == " str" +@pytest.mark.asyncio +async def test_proxy_arbitrary_params(static_proxy: NeonProxy): + with closing( + await static_proxy.connect_async(server_settings={"IntervalStyle": "iso_8601"}) + ) as conn: + out = await conn.fetchval("select to_json('0 seconds'::interval)") + assert out == '"00:00:00"' + + options = "neon_proxy_params_compat:true" + with closing( + await static_proxy.connect_async( + server_settings={"IntervalStyle": "iso_8601", "options": options} + ) + ) as conn: + out = await conn.fetchval("select to_json('0 seconds'::interval)") + assert out == '"PT0S"' + + def test_auth_errors(static_proxy: NeonProxy): """ Check that we throw very specific errors in some unsuccessful auth scenarios.