diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 4b937f017a..4adf0ed940 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -8,7 +8,9 @@ pub use console::{GetAuthInfoError, WakeComputeError}; use crate::{ auth::{self, AuthFlow, ClientCredentials}, - compute, http, mgmt, stream, url, + compute, + console::messages::MetricsAuxInfo, + http, mgmt, stream, url, waiters::{self, Waiter, Waiters}, }; use once_cell::sync::Lazy; @@ -126,25 +128,13 @@ pub struct AuthSuccess { pub value: T, } -impl AuthSuccess { - /// Very similar to [`std::option::Option::map`]. - /// Maps [`AuthSuccess`] to [`AuthSuccess`] by applying - /// a function to a contained value. - pub fn map(self, f: impl FnOnce(T) -> R) -> AuthSuccess { - AuthSuccess { - reported_auth_ok: self.reported_auth_ok, - value: f(self.value), - } - } -} - /// Info for establishing a connection to a compute node. /// This is what we get after auth succeeded, but not before! pub struct NodeInfo { - /// Project from [`auth::ClientCredentials`]. - pub project: String, /// Compute node connection params. pub config: compute::ConnCfg, + /// Labels for proxy's metrics. + pub aux: MetricsAuxInfo, } impl BackendType<'_, ClientCredentials<'_>> { @@ -172,37 +162,34 @@ impl BackendType<'_, ClientCredentials<'_>> { }; // TODO: find a proper way to merge those very similar blocks. - let (mut config, payload) = match self { + let (mut node, payload) = match self { Console(endpoint, creds) if creds.project.is_none() => { let payload = fetch_magic_payload.await?; let mut creds = creds.as_ref(); creds.project = Some(payload.project.as_str().into()); - let config = console::Api::new(endpoint, extra, &creds) + let node = console::Api::new(endpoint, extra, &creds) .wake_compute() .await?; - (config, payload) + (node, payload) } Postgres(endpoint, creds) if creds.project.is_none() => { let payload = fetch_magic_payload.await?; let mut creds = creds.as_ref(); creds.project = Some(payload.project.as_str().into()); - let config = postgres::Api::new(endpoint, &creds).wake_compute().await?; + let node = postgres::Api::new(endpoint, &creds).wake_compute().await?; - (config, payload) + (node, payload) } _ => return Ok(None), }; - config.password(payload.password); + node.config.password(payload.password); Ok(Some(AuthSuccess { reported_auth_ok: false, - value: NodeInfo { - project: payload.project, - config, - }, + value: node, })) } @@ -233,10 +220,6 @@ impl BackendType<'_, ClientCredentials<'_>> { console::Api::new(&endpoint, extra, &creds) .handle_user(client) .await? - .map(|config| NodeInfo { - project: creds.project.unwrap().into_owned(), - config, - }) } Postgres(endpoint, creds) => { info!("performing mock authentication using a local postgres instance"); @@ -245,10 +228,6 @@ impl BackendType<'_, ClientCredentials<'_>> { postgres::Api::new(&endpoint, &creds) .handle_user(client) .await? - .map(|config| NodeInfo { - project: creds.project.unwrap().into_owned(), - config, - }) } // NOTE: this auth backend doesn't use client credentials. Link(url) => { diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs index 040870fc8e..b3e3fd0c10 100644 --- a/proxy/src/auth/backend/console.rs +++ b/proxy/src/auth/backend/console.rs @@ -1,16 +1,16 @@ //! Cloud API V2. -use super::{AuthSuccess, ConsoleReqExtra}; +use super::{AuthSuccess, ConsoleReqExtra, NodeInfo}; use crate::{ auth::{self, AuthFlow, ClientCredentials}, compute, + console::messages::{ConsoleError, GetRoleSecret, WakeCompute}, error::{io_error, UserFacingError}, http, sasl, scram, stream::PqStream, }; use futures::TryFutureExt; use reqwest::StatusCode as HttpStatusCode; -use serde::Deserialize; use std::future::Future; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; @@ -136,24 +136,6 @@ impl UserFacingError for WakeComputeError { } } -/// Console's response which holds client's auth secret. -#[derive(Deserialize, Debug)] -struct GetRoleSecret { - role_secret: Box, -} - -/// Console's response which holds compute node's `host:port` pair. -#[derive(Deserialize, Debug)] -struct WakeCompute { - address: Box, -} - -/// Console's error response with human-readable description. -#[derive(Deserialize, Debug)] -struct ConsoleError { - error: Box, -} - /// Auth secret which is managed by the cloud. pub enum AuthInfo { /// Md5 hash of user's password. @@ -194,7 +176,7 @@ impl<'a> Api<'a> { pub(super) async fn handle_user( &'a self, client: &mut PqStream, - ) -> auth::Result> { + ) -> auth::Result> { handle_user(client, self, Self::get_auth_info, Self::wake_compute).await } } @@ -238,7 +220,7 @@ impl Api<'_> { } /// Wake up the compute node and return the corresponding connection info. - pub async fn wake_compute(&self) -> Result { + pub async fn wake_compute(&self) -> Result { let request_id = uuid::Uuid::new_v4().to_string(); async { let request = self @@ -269,7 +251,10 @@ impl Api<'_> { .dbname(self.creds.dbname) .user(self.creds.user); - Ok(config) + Ok(NodeInfo { + config, + aux: body.aux, + }) } .map_err(crate::error::log_error) .instrument(info_span!("wake_compute", id = request_id)) @@ -284,11 +269,11 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>( endpoint: &'a Endpoint, get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo, wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute, -) -> auth::Result> +) -> auth::Result> where Endpoint: AsRef>, GetAuthInfo: Future, GetAuthInfoError>>, - WakeCompute: Future>, + WakeCompute: Future>, { let creds = endpoint.as_ref(); @@ -325,19 +310,20 @@ where } }; - let mut config = wake_compute(endpoint).await?; + let mut node = wake_compute(endpoint).await?; if let Some(keys) = scram_keys { - config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(keys)); + use tokio_postgres::config::AuthKeys; + node.config.auth_keys(AuthKeys::ScramSha256(keys)); } Ok(AuthSuccess { reported_auth_ok: false, - value: config, + value: node, }) } /// Parse http response body, taking status code into account. -async fn parse_body Deserialize<'a>>( +async fn parse_body serde::Deserialize<'a>>( response: reqwest::Response, ) -> Result { let status = response.status(); diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index 641519ac50..e16bbc70e4 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -86,8 +86,8 @@ pub async fn handle_user( Ok(AuthSuccess { reported_auth_ok: true, value: NodeInfo { - project: db_info.project, config, + aux: db_info.aux, }, }) } diff --git a/proxy/src/auth/backend/postgres.rs b/proxy/src/auth/backend/postgres.rs index 8f16dc9fa8..260342f103 100644 --- a/proxy/src/auth/backend/postgres.rs +++ b/proxy/src/auth/backend/postgres.rs @@ -2,7 +2,7 @@ use super::{ console::{self, AuthInfo, GetAuthInfoError, WakeComputeError}, - AuthSuccess, + AuthSuccess, NodeInfo, }; use crate::{ auth::{self, ClientCredentials}, @@ -57,7 +57,7 @@ impl<'a> Api<'a> { pub(super) async fn handle_user( &'a self, client: &mut PqStream, - ) -> auth::Result> { + ) -> auth::Result> { // We reuse user handling logic from a production module. console::handle_user(client, self, Self::get_auth_info, Self::wake_compute).await } @@ -103,7 +103,7 @@ impl Api<'_> { } /// We don't need to wake anything locally, so we just return the connection info. - pub async fn wake_compute(&self) -> Result { + pub async fn wake_compute(&self) -> Result { let mut config = compute::ConnCfg::new(); config .host(self.endpoint.host_str().unwrap_or("localhost")) @@ -111,7 +111,10 @@ impl Api<'_> { .dbname(self.creds.dbname) .user(self.creds.user); - Ok(config) + Ok(NodeInfo { + config, + aux: Default::default(), + }) } } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 71421a4a65..094db73061 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -43,7 +43,7 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; /// Eventually, `tokio_postgres` will be replaced with something better. /// Newtype allows us to implement methods on top of it. #[repr(transparent)] -pub struct ConnCfg(pub tokio_postgres::Config); +pub struct ConnCfg(Box); impl ConnCfg { /// Construct a new connection config. diff --git a/proxy/src/console.rs b/proxy/src/console.rs new file mode 100644 index 0000000000..78f09ac9e1 --- /dev/null +++ b/proxy/src/console.rs @@ -0,0 +1,5 @@ +///! Various stuff for dealing with the Neon Console. +///! Later we might move some API wrappers here. + +/// Payloads used in the console's APIs. +pub mod messages; diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs new file mode 100644 index 0000000000..63a97069b8 --- /dev/null +++ b/proxy/src/console/messages.rs @@ -0,0 +1,190 @@ +use serde::Deserialize; +use std::fmt; + +/// Generic error response with human-readable description. +/// Note that we can't always present it to user as is. +#[derive(Debug, Deserialize)] +pub struct ConsoleError { + pub error: Box, +} + +/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`]. +/// Returned by the `/proxy_get_role_secret` API method. +#[derive(Deserialize)] +pub struct GetRoleSecret { + pub role_secret: Box, +} + +// Manually implement debug to omit sensitive info. +impl fmt::Debug for GetRoleSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GetRoleSecret").finish_non_exhaustive() + } +} + +/// Response which holds compute node's `host:port` pair. +/// Returned by the `/proxy_wake_compute` API method. +#[derive(Debug, Deserialize)] +pub struct WakeCompute { + pub address: Box, + pub aux: MetricsAuxInfo, +} + +/// Async response which concludes the link auth flow. +/// Also known as `kickResponse` in the console. +#[derive(Debug, Deserialize)] +pub struct KickSession<'a> { + /// Session ID is assigned by the proxy. + pub session_id: &'a str, + + /// Compute node connection params. + #[serde(deserialize_with = "KickSession::parse_db_info")] + pub result: DatabaseInfo, +} + +impl KickSession<'_> { + fn parse_db_info<'de, D>(des: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + enum Wrapper { + // Currently, console only reports `Success`. + // `Failure(String)` used to be here... RIP. + Success(DatabaseInfo), + } + + Wrapper::deserialize(des).map(|x| match x { + Wrapper::Success(info) => info, + }) + } +} + +/// Compute node connection params. +#[derive(Deserialize)] +pub struct DatabaseInfo { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + /// Console always provides a password, but it might + /// be inconvenient for debug with local PG instance. + pub password: Option, + pub aux: MetricsAuxInfo, +} + +// Manually implement debug to omit sensitive info. +impl fmt::Debug for DatabaseInfo { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("DatabaseInfo") + .field("host", &self.host) + .field("port", &self.port) + .field("dbname", &self.dbname) + .field("user", &self.user) + .finish_non_exhaustive() + } +} + +/// Various labels for prometheus metrics. +/// Also known as `ProxyMetricsAuxInfo` in the console. +#[derive(Debug, Deserialize, Default)] +pub struct MetricsAuxInfo { + pub endpoint_id: Box, + pub project_id: Box, + pub branch_id: Box, +} + +impl MetricsAuxInfo { + /// Definitions of labels for traffic metric. + pub const TRAFFIC_LABELS: &'static [&'static str] = &[ + // Received (rx) / sent (tx). + "direction", + // ID of a project. + "project_id", + // ID of an endpoint within a project. + "endpoint_id", + // ID of a branch within a project (snapshot). + "branch_id", + ]; + + /// Values of labels for traffic metric. + // TODO: add more type safety (validate arity & positions). + pub fn traffic_labels(&self, direction: &'static str) -> [&str; 4] { + [ + direction, + &self.project_id, + &self.endpoint_id, + &self.branch_id, + ] + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn dummy_aux() -> serde_json::Value { + json!({ + "endpoint_id": "endpoint", + "project_id": "project", + "branch_id": "branch", + }) + } + + #[test] + fn parse_kick_session() -> anyhow::Result<()> { + // This is what the console's kickResponse looks like. + let json = json!({ + "session_id": "deadbeef", + "result": { + "Success": { + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "password": "password", + "aux": dummy_aux(), + } + } + }); + let _: KickSession = serde_json::from_str(&json.to_string())?; + + Ok(()) + } + + #[test] + fn parse_db_info() -> anyhow::Result<()> { + // with password + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "password": "password", + "aux": dummy_aux(), + }))?; + + // without password + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "aux": dummy_aux(), + }))?; + + // new field (forward compatibility) + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "project": "hello_world", + "N.E.W": "forward compatibility check", + "aux": dummy_aux(), + }))?; + + Ok(()) + } +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 2855d1f900..89ea9142a9 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -8,6 +8,7 @@ mod auth; mod cancellation; mod compute; mod config; +mod console; mod error; mod http; mod mgmt; diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 23e10b5a9b..2e0a502e7f 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -1,7 +1,9 @@ -use crate::auth; +use crate::{ + auth, + console::messages::{DatabaseInfo, KickSession}, +}; use anyhow::Context; use pq_proto::{BeMessage, SINGLE_COL_ROWDESC}; -use serde::Deserialize; use std::{ net::{TcpListener, TcpStream}, thread, @@ -50,59 +52,9 @@ fn handle_connection(socket: TcpStream) -> anyhow::Result<()> { pgbackend.run(&mut MgmtHandler) } -/// Known as `kickResponse` in the console. -#[derive(Debug, Deserialize)] -struct PsqlSessionResponse { - session_id: String, - result: PsqlSessionResult, -} - -#[derive(Debug, Deserialize)] -enum PsqlSessionResult { - Success(DatabaseInfo), - Failure(String), -} - /// A message received by `mgmt` when a compute node is ready. pub type ComputeReady = Result; -impl PsqlSessionResult { - fn into_compute_ready(self) -> ComputeReady { - match self { - Self::Success(db_info) => Ok(db_info), - Self::Failure(message) => Err(message), - } - } -} - -/// Compute node connection params provided by the console. -/// This struct and its parents are mgmt API implementation -/// detail and thus should remain in this module. -// TODO: restore deserialization tests from git history. -#[derive(Deserialize)] -pub struct DatabaseInfo { - pub host: String, - pub port: u16, - pub dbname: String, - pub user: String, - /// Console always provides a password, but it might - /// be inconvenient for debug with local PG instance. - pub password: Option, - pub project: String, -} - -// Manually implement debug to omit sensitive info. -impl std::fmt::Debug for DatabaseInfo { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct("DatabaseInfo") - .field("host", &self.host) - .field("port", &self.port) - .field("dbname", &self.dbname) - .field("user", &self.user) - .finish_non_exhaustive() - } -} - // TODO: replace with an http-based protocol. struct MgmtHandler; impl postgres_backend::Handler for MgmtHandler { @@ -115,13 +67,13 @@ impl postgres_backend::Handler for MgmtHandler { } fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> { - let resp: PsqlSessionResponse = serde_json::from_str(query)?; + let resp: KickSession = serde_json::from_str(query)?; let span = info_span!("event", session_id = resp.session_id); let _enter = span.enter(); info!("got response: {:?}", resp.result); - match auth::backend::notify(&resp.session_id, resp.result.into_compute_ready()) { + match auth::backend::notify(resp.session_id, Ok(resp.result)) { Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? @@ -135,43 +87,3 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<( Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn parse_db_info() -> anyhow::Result<()> { - // with password - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - "password": "password", - "project": "hello_world", - }))?; - - // without password - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - "project": "hello_world", - }))?; - - // new field (forward compatibility) - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - "project": "hello_world", - "N.E.W": "forward compatibility check", - }))?; - - Ok(()) - } -} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 713388c625..382f7cd918 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -11,7 +11,7 @@ use anyhow::{bail, Context}; use futures::TryFutureExt; use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec}; use once_cell::sync::Lazy; -use pq_proto::{BeMessage as Be, *}; +use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, info_span, Instrument}; @@ -39,12 +39,7 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy = Lazy::new(|| { register_int_counter_vec!( "proxy_io_bytes_per_client", "Number of bytes sent/received between client and backend.", - &[ - // Received (rx) / sent (tx). - "direction", - // Proxy can keep calling it `project` internally. - "endpoint_id" - ] + crate::console::messages::MetricsAuxInfo::TRAFFIC_LABELS, ) .unwrap() }); @@ -271,19 +266,16 @@ impl Client<'_, S> { stream .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&BeMessage::ReadyForQuery) + .write_message(&Be::ReadyForQuery) .await?; - // TODO: add more identifiers. - let metric_id = node.project; - - let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx", &metric_id]); + let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx")); let mut client = MeasuredStream::new(stream.into_inner(), |cnt| { // Number of bytes we sent to the client (outbound). m_sent.inc_by(cnt as u64); }); - let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx", &metric_id]); + let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx")); let mut db = MeasuredStream::new(db.stream, |cnt| { // Number of bytes the client sent to the compute node (inbound). m_recv.inc_by(cnt as u64); diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 2f023844d0..ed429df421 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -140,7 +140,7 @@ async fn dummy_proxy( stream .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&BeMessage::ReadyForQuery) + .write_message(&Be::ReadyForQuery) .await?; Ok(()) diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index bcea4d970c..e13ba51f4b 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -63,7 +63,11 @@ async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProx "port": local_vanilla_pg.default_options["port"], "dbname": local_vanilla_pg.default_options["dbname"], "user": pg_user, - "project": "irrelevant", + "aux": { + "project_id": "project", + "endpoint_id": "endpoint", + "branch_id": "branch", + }, } }, }