From 9470bc9fe0f79f889aaf616b8148bc75988f6271 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Mon, 14 Nov 2022 18:32:31 +0300 Subject: [PATCH] [proxy] Implement per-tenant traffic metrics --- proxy/src/auth.rs | 2 +- proxy/src/auth/backend.rs | 193 +++++++++++++++----------- proxy/src/auth/backend/console.rs | 18 +-- proxy/src/auth/backend/link.rs | 21 ++- proxy/src/auth/backend/postgres.rs | 18 +-- proxy/src/auth/credentials.rs | 12 ++ proxy/src/compute.rs | 55 +++++--- proxy/src/mgmt.rs | 156 ++++++++++++++------- proxy/src/proxy.rs | 59 +++++--- test_runner/fixtures/neon_fixtures.py | 3 +- test_runner/regress/test_proxy.py | 126 ++++++----------- 11 files changed, 378 insertions(+), 285 deletions(-) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 2df4f9d920..f272f9adc1 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,7 +1,7 @@ //! Client authentication mechanisms. pub mod backend; -pub use backend::{BackendType, ConsoleReqExtra, DatabaseInfo}; +pub use backend::{BackendType, ConsoleReqExtra}; mod credentials; pub use credentials::ClientCredentials; diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index bb919770c1..4b937f017a 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -12,7 +12,6 @@ use crate::{ waiters::{self, Waiter, Waiters}, }; use once_cell::sync::Lazy; -use serde::{Deserialize, Serialize}; use std::borrow::Cow; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; @@ -36,45 +35,6 @@ pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), wait CPLANE_WAITERS.notify(psql_session_id, msg) } -/// Compute node connection params provided by the cloud. -/// Note how it implements serde traits, since we receive it over the wire. -#[derive(Serialize, Deserialize, Default)] -pub struct DatabaseInfo { - pub host: String, - pub port: u16, - pub dbname: String, - pub user: String, - pub password: Option, -} - -// Manually implement debug to omit personal and sensitive info. -impl std::fmt::Debug for DatabaseInfo { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct("DatabaseInfo") - .field("host", &self.host) - .field("port", &self.port) - .finish_non_exhaustive() - } -} - -impl From for tokio_postgres::Config { - fn from(db_info: DatabaseInfo) -> Self { - let mut config = tokio_postgres::Config::new(); - - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); - - if let Some(password) = db_info.password { - config.password(password); - } - - config - } -} - /// Extra query params we'd like to pass to the console. pub struct ConsoleReqExtra<'a> { /// A unique identifier for a connection. @@ -158,54 +118,107 @@ impl<'a, T, E> BackendType<'a, Result> { } } +/// A product of successful authentication. +pub struct AuthSuccess { + /// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client? + pub reported_auth_ok: bool, + /// Something to be considered a positive result. + pub value: T, +} + +impl AuthSuccess { + /// Very similar to [`std::option::Option::map`]. + /// Maps [`AuthSuccess`] to [`AuthSuccess`] by applying + /// a function to a contained value. + pub fn map(self, f: impl FnOnce(T) -> R) -> AuthSuccess { + AuthSuccess { + reported_auth_ok: self.reported_auth_ok, + value: f(self.value), + } + } +} + +/// Info for establishing a connection to a compute node. +/// This is what we get after auth succeeded, but not before! +pub struct NodeInfo { + /// Project from [`auth::ClientCredentials`]. + pub project: String, + /// Compute node connection params. + pub config: compute::ConnCfg, +} + impl BackendType<'_, ClientCredentials<'_>> { + /// Do something special if user didn't provide the `project` parameter. + async fn try_password_hack( + &mut self, + extra: &ConsoleReqExtra<'_>, + client: &mut stream::PqStream, + ) -> auth::Result>> { + use BackendType::*; + + // If there's no project so far, that entails that client doesn't + // support SNI or other means of passing the project name. + // We now expect to see a very specific payload in the place of password. + let fetch_magic_payload = async { + warn!("project name not specified, resorting to the password hack auth flow"); + let payload = AuthFlow::new(client) + .begin(auth::PasswordHack) + .await? + .authenticate() + .await?; + + info!(project = &payload.project, "received missing parameter"); + auth::Result::Ok(payload) + }; + + // TODO: find a proper way to merge those very similar blocks. + let (mut config, payload) = match self { + Console(endpoint, creds) if creds.project.is_none() => { + let payload = fetch_magic_payload.await?; + + let mut creds = creds.as_ref(); + creds.project = Some(payload.project.as_str().into()); + let config = console::Api::new(endpoint, extra, &creds) + .wake_compute() + .await?; + + (config, payload) + } + Postgres(endpoint, creds) if creds.project.is_none() => { + let payload = fetch_magic_payload.await?; + + let mut creds = creds.as_ref(); + creds.project = Some(payload.project.as_str().into()); + let config = postgres::Api::new(endpoint, &creds).wake_compute().await?; + + (config, payload) + } + _ => return Ok(None), + }; + + config.password(payload.password); + Ok(Some(AuthSuccess { + reported_auth_ok: false, + value: NodeInfo { + project: payload.project, + config, + }, + })) + } + /// Authenticate the client via the requested backend, possibly using credentials. pub async fn authenticate( mut self, extra: &ConsoleReqExtra<'_>, client: &mut stream::PqStream, - ) -> super::Result { + ) -> auth::Result> { use BackendType::*; - if let Console(_, creds) | Postgres(_, creds) = &mut self { - // If there's no project so far, that entails that client doesn't - // support SNI or other means of passing the project name. - // We now expect to see a very specific payload in the place of password. - if creds.project().is_none() { - warn!("project name not specified, resorting to the password hack auth flow"); - - let payload = AuthFlow::new(client) - .begin(auth::PasswordHack) - .await? - .authenticate() - .await?; - - // Finally we may finish the initialization of `creds`. - // TODO: add missing type safety to ClientCredentials. - info!(project = &payload.project, "received missing parameter"); - creds.project = Some(payload.project.into()); - - let mut config = match &self { - Console(endpoint, creds) => { - console::Api::new(endpoint, extra, creds) - .wake_compute() - .await? - } - Postgres(endpoint, creds) => { - postgres::Api::new(endpoint, creds).wake_compute().await? - } - _ => unreachable!("see the patterns above"), - }; - - // We should use a password from payload as well. - config.password(payload.password); - - info!("user successfully authenticated (using the password hack)"); - return Ok(compute::NodeInfo { - reported_auth_ok: false, - config, - }); - } + // Handle cases when `project` is missing in `creds`. + // TODO: type safety: return `creds` with irrefutable `project`. + if let Some(res) = self.try_password_hack(extra, client).await? { + info!("user successfully authenticated (using the password hack)"); + return Ok(res); } let res = match self { @@ -215,22 +228,34 @@ impl BackendType<'_, ClientCredentials<'_>> { project = creds.project(), "performing authentication using the console" ); + + assert!(creds.project.is_some()); console::Api::new(&endpoint, extra, &creds) .handle_user(client) - .await + .await? + .map(|config| NodeInfo { + project: creds.project.unwrap().into_owned(), + config, + }) } Postgres(endpoint, creds) => { info!("performing mock authentication using a local postgres instance"); + + assert!(creds.project.is_some()); postgres::Api::new(&endpoint, &creds) .handle_user(client) - .await + .await? + .map(|config| NodeInfo { + project: creds.project.unwrap().into_owned(), + config, + }) } // NOTE: this auth backend doesn't use client credentials. Link(url) => { info!("performing link authentication"); - link::handle_user(&url, client).await + link::handle_user(&url, client).await? } - }?; + }; info!("user successfully authenticated"); Ok(res) diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs index cf99aa08ef..929dfb33f7 100644 --- a/proxy/src/auth/backend/console.rs +++ b/proxy/src/auth/backend/console.rs @@ -1,9 +1,9 @@ //! Cloud API V2. -use super::ConsoleReqExtra; +use super::{AuthSuccess, ConsoleReqExtra}; use crate::{ auth::{self, AuthFlow, ClientCredentials}, - compute::{self, ComputeConnCfg}, + compute, error::{io_error, UserFacingError}, http, scram, stream::PqStream, @@ -128,7 +128,7 @@ impl<'a> Api<'a> { pub(super) async fn handle_user( self, client: &mut PqStream, - ) -> auth::Result { + ) -> auth::Result> { handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await } @@ -164,7 +164,7 @@ impl<'a> Api<'a> { } /// Wake up the compute node and return the corresponding connection info. - pub(super) async fn wake_compute(&self) -> Result { + pub(super) async fn wake_compute(&self) -> Result { let request_id = uuid::Uuid::new_v4().to_string(); let req = self .endpoint @@ -195,7 +195,7 @@ impl<'a> Api<'a> { Some(x) => x, }; - let mut config = ComputeConnCfg::new(); + let mut config = compute::ConnCfg::new(); config .host(host) .port(port) @@ -213,10 +213,10 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>( endpoint: &'a Endpoint, get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo, wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute, -) -> auth::Result +) -> auth::Result> where GetAuthInfo: Future>, - WakeCompute: Future>, + WakeCompute: Future>, { info!("fetching user's authentication info"); let auth_info = get_auth_info(endpoint).await?; @@ -243,9 +243,9 @@ where config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(keys)); } - Ok(compute::NodeInfo { + Ok(AuthSuccess { reported_auth_ok: false, - config, + value: config, }) } diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index 96c6f0ba18..440a55f194 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -1,3 +1,4 @@ +use super::{AuthSuccess, NodeInfo}; use crate::{auth, compute, error::UserFacingError, stream::PqStream, waiters}; use pq_proto::{BeMessage as Be, BeParameterStatusMessage}; use thiserror::Error; @@ -49,7 +50,7 @@ pub fn new_psql_session_id() -> String { pub async fn handle_user( link_uri: &reqwest::Url, client: &mut PqStream, -) -> auth::Result { +) -> auth::Result> { let psql_session_id = new_psql_session_id(); let span = info_span!("link", psql_session_id = &psql_session_id); let greeting = hello_message(link_uri, &psql_session_id); @@ -71,8 +72,22 @@ pub async fn handle_user( client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; - Ok(compute::NodeInfo { + let mut config = compute::ConnCfg::new(); + config + .host(&db_info.host) + .port(db_info.port) + .dbname(&db_info.dbname) + .user(&db_info.user); + + if let Some(password) = db_info.password { + config.password(password); + } + + Ok(AuthSuccess { reported_auth_ok: true, - config: db_info.into(), + value: NodeInfo { + project: db_info.project, + config, + }, }) } diff --git a/proxy/src/auth/backend/postgres.rs b/proxy/src/auth/backend/postgres.rs index 2055ee14c8..e56b62622a 100644 --- a/proxy/src/auth/backend/postgres.rs +++ b/proxy/src/auth/backend/postgres.rs @@ -1,12 +1,12 @@ //! Local mock of Cloud API V2. +use super::{ + console::{self, AuthInfo, GetAuthInfoError, TransportError, WakeComputeError}, + AuthSuccess, +}; use crate::{ - auth::{ - self, - backend::console::{self, AuthInfo, GetAuthInfoError, TransportError, WakeComputeError}, - ClientCredentials, - }, - compute::{self, ComputeConnCfg}, + auth::{self, ClientCredentials}, + compute, error::io_error, scram, stream::PqStream, @@ -37,7 +37,7 @@ impl<'a> Api<'a> { pub(super) async fn handle_user( self, client: &mut PqStream, - ) -> auth::Result { + ) -> auth::Result> { // We reuse user handling logic from a production module. console::handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await } @@ -82,8 +82,8 @@ impl<'a> Api<'a> { } /// We don't need to wake anything locally, so we just return the connection info. - pub(super) async fn wake_compute(&self) -> Result { - let mut config = ComputeConnCfg::new(); + pub(super) async fn wake_compute(&self) -> Result { + let mut config = compute::ConnCfg::new(); config .host(self.endpoint.host_str().unwrap_or("localhost")) .port(self.endpoint.port().unwrap_or(5432)) diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 907f99b8e0..4f3238e4ff 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -36,11 +36,23 @@ pub struct ClientCredentials<'a> { } impl ClientCredentials<'_> { + #[inline] pub fn project(&self) -> Option<&str> { self.project.as_deref() } } +impl<'a> ClientCredentials<'a> { + #[inline] + pub fn as_ref(&'a self) -> ClientCredentials<'a> { + Self { + user: self.user, + dbname: self.dbname, + project: self.project().map(Cow::Borrowed), + } + } +} + impl<'a> ClientCredentials<'a> { pub fn parse( params: &'a StartupMessageParams, diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 4771c774a1..4c5edb9673 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -40,17 +40,36 @@ impl UserFacingError for ConnectionError { /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; -pub type ComputeConnCfg = tokio_postgres::Config; +/// A config for establishing a connection to compute node. +/// Eventually, `tokio_postgres` will be replaced with something better. +/// Newtype allows us to implement methods on top of it. +#[repr(transparent)] +pub struct ConnCfg(pub tokio_postgres::Config); -/// Various compute node info for establishing connection etc. -pub struct NodeInfo { - /// Did we send [`pq_proto::BeMessage::AuthenticationOk`]? - pub reported_auth_ok: bool, - /// Compute node connection params. - pub config: tokio_postgres::Config, +impl ConnCfg { + /// Construct a new connection config. + pub fn new() -> Self { + Self(tokio_postgres::Config::new()) + } } -impl NodeInfo { +impl std::ops::Deref for ConnCfg { + type Target = tokio_postgres::Config; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// For now, let's make it easier to setup the config. +impl std::ops::DerefMut for ConnCfg { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl ConnCfg { + /// Establish a raw TCP connection to the compute node. async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> { use tokio_postgres::config::Host; @@ -68,8 +87,8 @@ impl NodeInfo { // because it has no means for extracting the underlying socket which we // require for our business. let mut connection_error = None; - let ports = self.config.get_ports(); - let hosts = self.config.get_hosts(); + let ports = self.0.get_ports(); + let hosts = self.0.get_hosts(); // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array if ports.len() > 1 && ports.len() != hosts.len() { return Err(io::Error::new( @@ -77,7 +96,7 @@ impl NodeInfo { format!( "couldn't connect: bad compute config, \ ports and hosts entries' count does not match: {:?}", - self.config + self.0 ), )); } @@ -103,7 +122,7 @@ impl NodeInfo { Err(connection_error.unwrap_or_else(|| { io::Error::new( io::ErrorKind::Other, - format!("couldn't connect: bad compute config: {:?}", self.config), + format!("couldn't connect: bad compute config: {:?}", self.0), ) })) } @@ -116,7 +135,7 @@ pub struct PostgresConnection { pub version: String, } -impl NodeInfo { +impl ConnCfg { /// Connect to a corresponding compute node. pub async fn connect( mut self, @@ -130,21 +149,21 @@ impl NodeInfo { .intersperse(" ") // TODO: use impl from std once it's stabilized .collect(); - self.config.options(&options); + self.0.options(&options); } if let Some(app_name) = params.get("application_name") { - self.config.application_name(app_name); + self.0.application_name(app_name); } if let Some(replication) = params.get("replication") { use tokio_postgres::config::ReplicationMode; match replication { "true" | "on" | "yes" | "1" => { - self.config.replication_mode(ReplicationMode::Physical); + self.0.replication_mode(ReplicationMode::Physical); } "database" => { - self.config.replication_mode(ReplicationMode::Logical); + self.0.replication_mode(ReplicationMode::Logical); } _other => {} } @@ -160,7 +179,7 @@ impl NodeInfo { .map_err(|_| ConnectionError::FailedToConnectToCompute)?; // TODO: establish a secure connection to the DB - let (client, conn) = self.config.connect_raw(&mut stream, NoTls).await?; + let (client, conn) = self.0.connect_raw(&mut stream, NoTls).await?; let version = conn .parameter("server_version") .ok_or(ConnectionError::FailedToFetchPgVersion)? diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 06d1a4f106..23e10b5a9b 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -6,16 +6,11 @@ use std::{ net::{TcpListener, TcpStream}, thread, }; -use tracing::{error, info}; +use tracing::{error, info, info_span}; use utils::postgres_backend::{self, AuthType, PostgresBackend}; -/// TODO: move all of that to auth-backend/link.rs when we ditch legacy-console backend - -/// -/// Main proxy listener loop. -/// -/// Listens for connections, and launches a new handler thread for each. -/// +/// Console management API listener thread. +/// It spawns console response handlers needed for the link auth. pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> { scopeguard::defer! { info!("mgmt has shut down"); @@ -24,6 +19,7 @@ pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> { listener .set_nonblocking(false) .context("failed to set listener to blocking")?; + loop { let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?; info!("accepted connection from {peer_addr}"); @@ -31,9 +27,19 @@ pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> { .set_nodelay(true) .context("failed to set client socket option")?; + // TODO: replace with async tasks. thread::spawn(move || { - if let Err(err) = handle_connection(socket) { - error!("{err}"); + let tid = std::thread::current().id(); + let span = info_span!("mgmt", thread = format_args!("{tid:?}")); + let _enter = span.enter(); + + info!("started a new console management API thread"); + scopeguard::defer! { + info!("console management API thread is about to finish"); + } + + if let Err(e) = handle_connection(socket) { + error!("thread failed with an error: {e}"); } }); } @@ -44,44 +50,21 @@ fn handle_connection(socket: TcpStream) -> anyhow::Result<()> { pgbackend.run(&mut MgmtHandler) } -struct MgmtHandler; - -/// Serialized examples: -// { -// "session_id": "71d6d03e6d93d99a", -// "result": { -// "Success": { -// "host": "127.0.0.1", -// "port": 5432, -// "dbname": "stas", -// "user": "stas", -// "password": "mypass" -// } -// } -// } -// { -// "session_id": "71d6d03e6d93d99a", -// "result": { -// "Failure": "oops" -// } -// } -// -// // to test manually by sending a query to mgmt interface: -// psql -h 127.0.0.1 -p 9999 -c '{"session_id":"4f10dde522e14739","result":{"Success":{"host":"127.0.0.1","port":5432,"dbname":"stas","user":"stas","password":"stas"}}}' -#[derive(Deserialize)] +/// Known as `kickResponse` in the console. +#[derive(Debug, Deserialize)] struct PsqlSessionResponse { session_id: String, result: PsqlSessionResult, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] enum PsqlSessionResult { - Success(auth::DatabaseInfo), + Success(DatabaseInfo), Failure(String), } /// A message received by `mgmt` when a compute node is ready. -pub type ComputeReady = Result; +pub type ComputeReady = Result; impl PsqlSessionResult { fn into_compute_ready(self) -> ComputeReady { @@ -92,25 +75,51 @@ impl PsqlSessionResult { } } -impl postgres_backend::Handler for MgmtHandler { - fn process_query( - &mut self, - pgb: &mut PostgresBackend, - query_string: &str, - ) -> anyhow::Result<()> { - let res = try_process_query(pgb, query_string); - // intercept and log error message - if res.is_err() { - error!("mgmt query failed: {res:?}"); - } - res +/// Compute node connection params provided by the console. +/// This struct and its parents are mgmt API implementation +/// detail and thus should remain in this module. +// TODO: restore deserialization tests from git history. +#[derive(Deserialize)] +pub struct DatabaseInfo { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + /// Console always provides a password, but it might + /// be inconvenient for debug with local PG instance. + pub password: Option, + pub project: String, +} + +// Manually implement debug to omit sensitive info. +impl std::fmt::Debug for DatabaseInfo { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("DatabaseInfo") + .field("host", &self.host) + .field("port", &self.port) + .field("dbname", &self.dbname) + .field("user", &self.user) + .finish_non_exhaustive() } } -fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::Result<()> { - info!("got mgmt query [redacted]"); // Content contains password, don't print it +// TODO: replace with an http-based protocol. +struct MgmtHandler; +impl postgres_backend::Handler for MgmtHandler { + fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> { + try_process_query(pgb, query).map_err(|e| { + error!("failed to process response: {e:?}"); + e + }) + } +} - let resp: PsqlSessionResponse = serde_json::from_str(query_string)?; +fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> { + let resp: PsqlSessionResponse = serde_json::from_str(query)?; + + let span = info_span!("event", session_id = resp.session_id); + let _enter = span.enter(); + info!("got response: {:?}", resp.result); match auth::backend::notify(&resp.session_id, resp.result.into_compute_ready()) { Ok(()) => { @@ -119,9 +128,50 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; } Err(e) => { + error!("failed to deliver response to per-client task"); pgb.write_message(&BeMessage::ErrorResponse(&e.to_string()))?; } } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_db_info() -> anyhow::Result<()> { + // with password + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "password": "password", + "project": "hello_world", + }))?; + + // without password + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "project": "hello_world", + }))?; + + // new field (forward compatibility) + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "project": "hello_world", + "N.E.W": "forward compatibility check", + }))?; + + Ok(()) + } +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 9257fcd650..5988faccec 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -4,7 +4,7 @@ use crate::config::{ProxyConfig, TlsConfig}; use crate::stream::{MeasuredStream, PqStream, Stream}; use anyhow::{bail, Context}; use futures::TryFutureExt; -use metrics::{register_int_counter, IntCounter}; +use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec}; use once_cell::sync::Lazy; use pq_proto::{BeMessage as Be, *}; use std::sync::Arc; @@ -30,10 +30,16 @@ static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy = Lazy::new(|| { .unwrap() }); -static NUM_BYTES_PROXIED_COUNTER: Lazy = Lazy::new(|| { - register_int_counter!( - "proxy_io_bytes_total", - "Number of bytes sent/received between any client and backend." +static NUM_BYTES_PROXIED_COUNTER: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "proxy_io_bytes_per_client", + "Number of bytes sent/received between client and backend.", + &[ + // Received (rx) / sent (tx). + "direction", + // Proxy can keep calling it `project` internally. + "endpoint_id" + ] ) .unwrap() }); @@ -230,16 +236,17 @@ impl Client<'_, S> { application_name: params.get("application_name"), }; - // Authenticate and connect to a compute node. - let auth = creds - .authenticate(&extra, &mut stream) - .instrument(info_span!("auth")) - .await; - - let node = async { auth }.or_else(|e| stream.throw_error(e)).await?; - let reported_auth_ok = node.reported_auth_ok; + let auth_result = async { + // `&mut stream` doesn't let us merge those 2 lines. + let res = creds.authenticate(&extra, &mut stream).await; + async { res }.or_else(|e| stream.throw_error(e)).await + } + .instrument(info_span!("auth")) + .await?; + let node = auth_result.value; let (db, cancel_closure) = node + .config .connect(params) .or_else(|e| stream.throw_error(e)) .await?; @@ -247,7 +254,9 @@ impl Client<'_, S> { let cancel_key_data = session.enable_query_cancellation(cancel_closure); // Report authentication success if we haven't done this already. - if !reported_auth_ok { + // Note that we do this only (for the most part) after we've connected + // to a compute (see above) which performs its own authentication. + if !auth_result.reported_auth_ok { stream .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&BeParameterStatusMessage::encoding())?; @@ -261,17 +270,23 @@ impl Client<'_, S> { .write_message(&BeMessage::ReadyForQuery) .await?; - /// This function will be called for writes to either direction. - fn inc_proxied(cnt: usize) { - // Consider inventing something more sophisticated - // if this ever becomes a bottleneck (cacheline bouncing). - NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64); - } + // TODO: add more identifiers. + let metric_id = node.project; + + let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx", &metric_id]); + let mut client = MeasuredStream::new(stream.into_inner(), |cnt| { + // Number of bytes we sent to the client (outbound). + m_sent.inc_by(cnt as u64); + }); + + let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx", &metric_id]); + let mut db = MeasuredStream::new(db.stream, |cnt| { + // Number of bytes the client sent to the compute node (inbound). + m_recv.inc_by(cnt as u64); + }); // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - let mut db = MeasuredStream::new(db.stream, inc_proxied); - let mut client = MeasuredStream::new(stream.into_inner(), inc_proxied); let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?; Ok(()) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index ffb1df6701..4d17992046 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2096,7 +2096,8 @@ class NeonProxy(PgProtocol): def start(self): """ - Starts a proxy with option '--auth-backend postgres' and a postgres instance already provided though '--auth-endpoint '." + Starts a proxy with option '--auth-backend postgres' and a postgres instance + already provided though '--auth-endpoint '." """ assert self._popen is None assert self.auth_endpoint is not None diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index b8cfb21a5b..e868d6b616 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -1,5 +1,4 @@ import json -import subprocess from urllib.parse import urlparse import psycopg2 @@ -29,108 +28,65 @@ def test_password_hack(static_proxy: NeonProxy): static_proxy.safe_psql("select 1", sslsni=0, user=user, password=magic) -def get_session_id_from_uri_line(uri_prefix, uri_line): +def get_session_id(uri_prefix, uri_line): assert uri_prefix in uri_line url_parts = urlparse(uri_line) psql_session_id = url_parts.path[1:] - assert psql_session_id.isalnum(), "session_id should only contain alphanumeric chars." - link_auth_uri_prefix = uri_line[: -len(url_parts.path)] - # invariant: the prefix must match the uri_prefix. - assert ( - link_auth_uri_prefix == uri_prefix - ), f"Line='{uri_line}' should contain a http auth link of form '{uri_prefix}/'." - # invariant: the entire link_auth_uri should be on its own line, module spaces. - assert " ".join(uri_line.split(" ")) == f"{uri_prefix}/{psql_session_id}" + assert psql_session_id.isalnum(), "session_id should only contain alphanumeric chars" return psql_session_id -def create_and_send_db_info(local_vanilla_pg, psql_session_id, mgmt_port): - pg_user = "proxy" - pg_password = "password" - - local_vanilla_pg.start() - query = f"create user {pg_user} with login superuser password '{pg_password}'" - local_vanilla_pg.safe_psql(query) - - port = local_vanilla_pg.default_options["port"] - host = local_vanilla_pg.default_options["host"] - dbname = local_vanilla_pg.default_options["dbname"] - - db_info_dict = { - "session_id": psql_session_id, - "result": { - "Success": { - "host": host, - "port": port, - "dbname": dbname, - "user": pg_user, - "password": pg_password, - } - }, - } - db_info_str = json.dumps(db_info_dict) - cmd_args = [ - "psql", - "-h", - "127.0.0.1", # localhost - "-p", - f"{mgmt_port}", - "-c", - db_info_str, - ] - - log.info(f"Sending to proxy the user and db info: {' '.join(cmd_args)}") - p = subprocess.Popen(cmd_args, stdout=subprocess.PIPE) - out, err = p.communicate() - assert "ok" in str(out) - - -async def get_uri_line_from_process_welcome_notice(link_auth_uri_prefix, proc): - """ - Returns the line from the welcome notice from proc containing link_auth_uri_prefix. - :param link_auth_uri_prefix: the uri prefix used to indicate the line of interest - :param proc: the process to read the welcome message from. - :return: a line containing the full link authentication uri. - """ - max_num_lines_of_welcome_message = 15 - for attempt in range(max_num_lines_of_welcome_message): - raw_line = await proc.stderr.readline() - line = raw_line.decode("utf-8").strip() +async def find_auth_link(link_auth_uri_prefix, proc): + for _ in range(100): + line = (await proc.stderr.readline()).decode("utf-8").strip() + log.info(f"psql line: {line}") if link_auth_uri_prefix in line: + log.info(f"SUCCESS, found auth url: {line}") return line - assert False, f"did not find line containing '{link_auth_uri_prefix}'" + + +async def activate_link_auth(local_vanilla_pg, link_proxy, psql_session_id): + pg_user = "proxy" + + log.info("creating a new user for link auth test") + local_vanilla_pg.start() + local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser") + + db_info = json.dumps( + { + "session_id": psql_session_id, + "result": { + "Success": { + "host": local_vanilla_pg.default_options["host"], + "port": local_vanilla_pg.default_options["port"], + "dbname": local_vanilla_pg.default_options["dbname"], + "user": pg_user, + "project": "irrelevant", + } + }, + } + ) + + log.info("sending session activation message") + psql = await PSQL(host=link_proxy.host, port=link_proxy.mgmt_port).run(db_info) + out = (await psql.stdout.read()).decode("utf-8").strip() + assert out == "ok" @pytest.mark.asyncio async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): - """ - Test copied and modified from: test_project_psql_link_auth test from cloud/tests_e2e/tests/test_project.py - Step 1. establish connection to the proxy - Step 2. retrieve session_id: - Step 2.1: read welcome message - Step 2.2: parse session_id - Step 3. create a vanilla_pg and send user and db info via command line (using Popen) a psql query via mgmt port to proxy. - Step 4. assert that select 1 has been executed correctly. - """ - - psql = PSQL( - host=link_proxy.host, - port=link_proxy.proxy_port, - ) - proc = await psql.run("select 42") + psql = await PSQL(host=link_proxy.host, port=link_proxy.proxy_port).run("select 42") uri_prefix = link_proxy.link_auth_uri_prefix - line_str = await get_uri_line_from_process_welcome_notice(uri_prefix, proc) + link = await find_auth_link(uri_prefix, psql) - psql_session_id = get_session_id_from_uri_line(uri_prefix, line_str) - log.info(f"Parsed psql_session_id='{psql_session_id}' from Neon welcome message.") + psql_session_id = get_session_id(uri_prefix, link) + await activate_link_auth(vanilla_pg, link_proxy, psql_session_id) - create_and_send_db_info(vanilla_pg, psql_session_id, link_proxy.mgmt_port) - - assert proc.stdout is not None - out = (await proc.stdout.read()).decode("utf-8").strip() + assert psql.stdout is not None + out = (await psql.stdout.read()).decode("utf-8").strip() assert out == "42"