diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index cb8e692c18..a7ce944d36 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -6,6 +6,7 @@ pub use link::LinkAuthError; use serde::{Deserialize, Serialize}; use tokio_postgres::config::AuthKeys; +use crate::console::provider::neon::UserRowLevel; use crate::proxy::{handle_try_wake, retry_after, LatencyTimer}; use crate::{ auth::{self, ClientCredentials}, @@ -328,7 +329,7 @@ impl BackendType<'_, ClientCredentials<'_>> { dbname: String, username: String, policies: Vec, - ) -> anyhow::Result { + ) -> anyhow::Result { use BackendType::*; match self { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 3988c50280..b67b17b3fe 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -1,6 +1,8 @@ pub mod mock; pub mod neon; +use self::neon::UserRowLevel; + use super::messages::MetricsAuxInfo; use crate::{ auth::{backend::Policy, ClientCredentials}, @@ -257,7 +259,7 @@ pub trait Api { dbname: String, username: String, policies: Vec, - ) -> anyhow::Result; + ) -> anyhow::Result; } /// Various caches for [`console`](super). diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 68bc909b00..694679c2e5 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -2,6 +2,7 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, + neon::UserRowLevel, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; use crate::{ @@ -136,7 +137,7 @@ impl super::Api for Api { _dbname: String, _username: String, _policies: Vec, - ) -> anyhow::Result { + ) -> anyhow::Result { Err(anyhow::anyhow!("unimplemented")) } } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index e6beeb4903..f7a5040362 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -15,7 +15,7 @@ use serde::{Deserialize, Serialize}; use std::{net::SocketAddr, sync::Arc}; use tokio::time::Instant; use tokio_postgres::config::SslMode; -use tracing::{error, info, info_span, warn, Instrument}; +use tracing::{debug, error, info, info_span, warn, Instrument}; #[derive(Clone)] pub struct Api { @@ -151,7 +151,7 @@ impl Api { dbname: String, username: String, policies: Vec, - ) -> anyhow::Result { + ) -> anyhow::Result { let project = creds.project().expect("impossible"); let request_id = uuid::Uuid::new_v4().to_string(); async { @@ -189,7 +189,9 @@ impl Api { info!(duration = ?start.elapsed(), "received http response"); let body = parse_body::(response).await?; - Ok(body.password) + debug!(user = %body.username, pw=%body.password, "please don't merge this in production"); + + Ok(body) } .map_err(crate::error::log_error) .instrument(info_span!("http", id = request_id)) @@ -213,8 +215,9 @@ struct Target { } #[derive(Deserialize)] -struct UserRowLevel { - password: String, +pub struct UserRowLevel { + pub username: String, + pub password: String, } #[async_trait] @@ -274,7 +277,7 @@ impl super::Api for Api { dbname: String, username: String, policies: Vec, - ) -> anyhow::Result { + ) -> anyhow::Result { self.do_ensure_row_level(extra, creds, dbname, username, policies) .await } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 6b48dc1c0e..3b2e313c81 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -22,7 +22,7 @@ use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use crate::{ auth::{self, backend::Policy}, - console, + console::{self, provider::neon::UserRowLevel}, proxy::{ neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER, @@ -368,7 +368,7 @@ struct TokioMechanism<'a> { conn_info: &'a ConnInfo, session_id: uuid::Uuid, conn_id: uuid::Uuid, - password: Option, + row_level: Option, } #[async_trait] @@ -388,7 +388,7 @@ impl ConnectMechanism for TokioMechanism<'_> { timeout, self.conn_id, self.session_id, - self.password.as_deref(), + &self.row_level, ) .await } @@ -436,9 +436,9 @@ async fn connect_to_compute( .await? .context("missing cache entry from wake_compute")?; - let mut password = None; + let mut row_level = None; if let Some(policies) = &conn_info.policies { - password = Some( + row_level = Some( creds .ensure_row_level( &extra, @@ -455,7 +455,7 @@ async fn connect_to_compute( conn_id, conn_info, session_id, - password, + row_level, }, node_info, &extra, @@ -471,13 +471,24 @@ async fn connect_to_compute_once( timeout: time::Duration, conn_id: uuid::Uuid, mut session: uuid::Uuid, - password: Option<&str>, + row_level: &Option, ) -> Result { let mut config = (*node_info.config).clone(); + let username = row_level + .as_ref() + .map(|r| &r.username) + .unwrap_or(&conn_info.username); + info!(%username, dbname = %conn_info.dbname, "connecting"); + let (client, mut connection) = config - .user(&conn_info.username) - .password(password.unwrap_or(&conn_info.password)) + .user(username) + .password( + row_level + .as_ref() + .map(|r| &r.password) + .unwrap_or(&conn_info.password), + ) .dbname(&conn_info.dbname) .connect_timeout(timeout) .connect(tokio_postgres::NoTls)