From 6506fd14c45bf4fd685e8ba25cbd609502537155 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 2 Feb 2024 16:07:35 +0000 Subject: [PATCH 1/2] proxy: more refactors (#6526) ## Problem not really any problem, just some drive-by changes ## Summary of changes 1. move wake compute 2. move json processing 3. move handle_try_wake 4. move test backend to api provider 5. reduce wake-compute concerns 6. remove duplicate wake-compute loop --- proxy/src/auth/backend.rs | 113 +++---- proxy/src/auth/backend/classic.rs | 2 +- proxy/src/auth/flow.rs | 2 +- proxy/src/bin/proxy.rs | 26 +- proxy/src/console/provider.rs | 23 +- proxy/src/console/provider/neon.rs | 1 - proxy/src/proxy.rs | 1 + proxy/src/proxy/connect_compute.rs | 118 ++----- proxy/src/proxy/tests.rs | 16 +- proxy/src/proxy/wake_compute.rs | 95 ++++++ proxy/src/serverless.rs | 1 + proxy/src/serverless/json.rs | 448 ++++++++++++++++++++++++++ proxy/src/serverless/sql_over_http.rs | 447 +------------------------ 13 files changed, 649 insertions(+), 644 deletions(-) create mode 100644 proxy/src/proxy/wake_compute.rs create mode 100644 proxy/src/serverless/json.rs diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 144c9dcff5..236567163e 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -12,8 +12,7 @@ use crate::console::errors::GetAuthInfoError; use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; use crate::console::AuthSecret; use crate::context::RequestMonitoring; -use crate::proxy::connect_compute::handle_try_wake; -use crate::proxy::retry::retry_after; +use crate::proxy::wake_compute::wake_compute; use crate::proxy::NeonOptions; use crate::stream::Stream; use crate::{ @@ -28,11 +27,26 @@ use crate::{ }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; use futures::TryFutureExt; -use std::borrow::Cow; -use std::ops::ControlFlow; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{error, info, warn}; +use tracing::info; + 
+/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality +pub enum MaybeOwned<'a, T> { + Owned(T), + Borrowed(&'a T), +} + +impl std::ops::Deref for MaybeOwned<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match self { + MaybeOwned::Owned(t) => t, + MaybeOwned::Borrowed(t) => t, + } + } +} /// This type serves two purposes: /// @@ -44,12 +58,9 @@ use tracing::{error, info, warn}; /// backends which require them for the authentication process. pub enum BackendType<'a, T> { /// Cloud API (V2). - Console(Cow<'a, ConsoleBackend>, T), + Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. - Link(Cow<'a, url::ApiUrl>), - #[cfg(test)] - /// Test backend. - Test(&'a dyn TestBackend), + Link(MaybeOwned<'a, url::ApiUrl>), } pub trait TestBackend: Send + Sync + 'static { @@ -67,14 +78,14 @@ impl std::fmt::Display for BackendType<'_, ()> { ConsoleBackend::Console(endpoint) => { fmt.debug_tuple("Console").field(&endpoint.url()).finish() } - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] ConsoleBackend::Postgres(endpoint) => { fmt.debug_tuple("Postgres").field(&endpoint.url()).finish() } + #[cfg(test)] + ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), - #[cfg(test)] - Test(_) => fmt.debug_tuple("Test").finish(), } } } @@ -85,10 +96,8 @@ impl BackendType<'_, T> { pub fn as_ref(&self) -> BackendType<'_, &T> { use BackendType::*; match self { - Console(c, x) => Console(Cow::Borrowed(c), x), - Link(c) => Link(Cow::Borrowed(c)), - #[cfg(test)] - Test(x) => Test(*x), + Console(c, x) => Console(MaybeOwned::Borrowed(c), x), + Link(c) => Link(MaybeOwned::Borrowed(c)), } } } @@ -102,8 +111,6 @@ impl<'a, T> BackendType<'a, T> { match self { Console(c, x) => Console(c, f(x)), Link(c) => Link(c), - #[cfg(test)] - Test(x) => Test(x), } } } @@ -116,8 +123,6 @@ impl<'a, T, E> 
BackendType<'a, Result> { match self { Console(c, x) => x.map(|x| Console(c, x)), Link(c) => Ok(Link(c)), - #[cfg(test)] - Test(x) => Ok(Test(x)), } } } @@ -147,7 +152,7 @@ impl ComputeUserInfo { } pub enum ComputeCredentialKeys { - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Password(Vec), AuthKeys(AuthKeys), } @@ -277,42 +282,6 @@ async fn authenticate_with_secret( classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await } -/// wake a compute (or retrieve an existing compute session from cache) -async fn wake_compute( - ctx: &mut RequestMonitoring, - api: &impl console::Api, - compute_credentials: ComputeCredentials, -) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> { - let mut num_retries = 0; - let mut node = loop { - let wake_res = api.wake_compute(ctx, &compute_credentials.info).await; - match handle_try_wake(wake_res, num_retries) { - Err(e) => { - error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); - return Err(e.into()); - } - Ok(ControlFlow::Continue(e)) => { - warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); - } - Ok(ControlFlow::Break(n)) => break n, - } - - let wait_duration = retry_after(num_retries); - num_retries += 1; - tokio::time::sleep(wait_duration).await; - }; - - ctx.set_project(node.aux.clone()); - - match compute_credentials.keys { - #[cfg(feature = "testing")] - ComputeCredentialKeys::Password(password) => node.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), - }; - - Ok((node, compute_credentials.info)) -} - impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { /// Get compute endpoint name from the credentials. 
pub fn get_endpoint(&self) -> Option { @@ -321,8 +290,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => user_info.endpoint_id.clone(), Link(_) => Some("link".into()), - #[cfg(test)] - Test(_) => Some("test".into()), } } @@ -333,8 +300,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => &user_info.user, Link(_) => "link", - #[cfg(test)] - Test(_) => "test", } } @@ -359,8 +324,20 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { let compute_credentials = auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?; - let (cache_info, user_info) = wake_compute(ctx, &*api, compute_credentials).await?; - (cache_info, BackendType::Console(api, user_info)) + + let mut num_retries = 0; + let mut node = + wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?; + + ctx.set_project(node.aux.clone()); + + match compute_credentials.keys { + #[cfg(any(test, feature = "testing"))] + ComputeCredentialKeys::Password(password) => node.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), + }; + + (node, BackendType::Console(api, compute_credentials.info)) } // NOTE: this auth backend doesn't use client credentials. 
Link(url) => { @@ -373,10 +350,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { BackendType::Link(url), ) } - #[cfg(test)] - Test(_) => { - unreachable!("this function should never be called in the test backend") - } }; info!("user successfully authenticated"); @@ -393,8 +366,6 @@ impl BackendType<'_, ComputeUserInfo> { match self { Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), - #[cfg(test)] - Test(x) => x.get_allowed_ips_and_secret(), } } @@ -409,8 +380,6 @@ impl BackendType<'_, ComputeUserInfo> { match self { Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, Link(_) => Ok(None), - #[cfg(test)] - Test(x) => x.wake_compute().map(Some), } } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 358b335b88..384063ceae 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -20,7 +20,7 @@ pub(super) async fn authenticate( ) -> auth::Result> { let flow = AuthFlow::new(client); let scram_keys = match secret { - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { info!("auth endpoint chooses MD5"); return Err(auth::AuthError::bad_auth_method("MD5")); diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 3151a77263..077178d107 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -172,7 +172,7 @@ pub(super) fn validate_password_and_exchange( secret: AuthSecret, ) -> super::Result> { match secret { - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { // test only Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password( diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 3960b080be..3bbb87808d 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,5 +1,6 @@ use futures::future::Either; use proxy::auth; +use 
proxy::auth::backend::MaybeOwned; use proxy::config::AuthenticationConfig; use proxy::config::CacheOptions; use proxy::config::HttpConfig; @@ -17,9 +18,9 @@ use proxy::usage_metrics; use anyhow::bail; use proxy::config::{self, ProxyConfig}; use proxy::serverless; +use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; -use std::{borrow::Cow, net::SocketAddr}; use tokio::net::TcpListener; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -259,18 +260,13 @@ async fn main() -> anyhow::Result<()> { } if let auth::BackendType::Console(api, _) = &config.auth_backend { - match &**api { - proxy::console::provider::ConsoleBackend::Console(api) => { - let cache = api.caches.project_info.clone(); - if let Some(url) = args.redis_notifications { - info!("Starting redis notifications listener ({url})"); - maintenance_tasks - .spawn(notifications::task_main(url.to_owned(), cache.clone())); - } - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); + if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { + let cache = api.caches.project_info.clone(); + if let Some(url) = args.redis_notifications { + info!("Starting redis notifications listener ({url})"); + maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone())); } - #[cfg(feature = "testing")] - proxy::console::provider::ConsoleBackend::Postgres(_) => {} + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } } @@ -369,18 +365,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let api = console::provider::neon::Api::new(endpoint, caches, locks); let api = console::provider::ConsoleBackend::Console(api); - auth::BackendType::Console(Cow::Owned(api), ()) + auth::BackendType::Console(MaybeOwned::Owned(api), ()) } #[cfg(feature = "testing")] AuthBackend::Postgres => { let url = args.auth_endpoint.parse()?; let api = console::provider::mock::Api::new(url); let api = 
console::provider::ConsoleBackend::Postgres(api); - auth::BackendType::Console(Cow::Owned(api), ()) + auth::BackendType::Console(MaybeOwned::Owned(api), ()) } AuthBackend::Link => { let url = args.uri.parse()?; - auth::BackendType::Link(Cow::Owned(url)) + auth::BackendType::Link(MaybeOwned::Owned(url)) } }; let http_config = HttpConfig { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index ff84db7738..c53d929470 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -1,4 +1,4 @@ -#[cfg(feature = "testing")] +#[cfg(any(test, feature = "testing"))] pub mod mock; pub mod neon; @@ -199,7 +199,7 @@ pub mod errors { /// Auth secret which is managed by the cloud. #[derive(Clone, Eq, PartialEq, Debug)] pub enum AuthSecret { - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] /// Md5 hash of user's password. Md5([u8; 16]), @@ -264,13 +264,16 @@ pub trait Api { ) -> Result; } -#[derive(Clone)] +#[non_exhaustive] pub enum ConsoleBackend { /// Current Cloud API (V2). Console(neon::Api), /// Local mock of Cloud API (V2). 
- #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Postgres(mock::Api), + /// Internal testing + #[cfg(test)] + Test(Box), } #[async_trait] @@ -283,8 +286,10 @@ impl Api for ConsoleBackend { use ConsoleBackend::*; match self { Console(api) => api.get_role_secret(ctx, user_info).await, - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Postgres(api) => api.get_role_secret(ctx, user_info).await, + #[cfg(test)] + Test(_) => unreachable!("this function should never be called in the test backend"), } } @@ -296,8 +301,10 @@ impl Api for ConsoleBackend { use ConsoleBackend::*; match self { Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, + #[cfg(test)] + Test(api) => api.get_allowed_ips_and_secret(), } } @@ -310,8 +317,10 @@ impl Api for ConsoleBackend { match self { Console(api) => api.wake_compute(ctx, user_info).await, - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Postgres(api) => api.wake_compute(ctx, user_info).await, + #[cfg(test)] + Test(api) => api.wake_compute(), } } } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index f22c6d2322..0785419790 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -19,7 +19,6 @@ use tokio::time::Instant; use tokio_postgres::config::SslMode; use tracing::{error, info, info_span, warn, Instrument}; -#[derive(Clone)] pub struct Api { endpoint: http::Endpoint, pub caches: &'static ApiCaches, diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 4aa1f3590d..b68fb26e42 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -5,6 +5,7 @@ pub mod connect_compute; pub mod handshake; pub mod passthrough; pub mod retry; +pub mod wake_compute; use crate::{ auth, diff --git a/proxy/src/proxy/connect_compute.rs 
b/proxy/src/proxy/connect_compute.rs index 8bbe88aa51..58c59dba36 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,15 +1,16 @@ use crate::{ auth, compute::{self, PostgresConnection}, - console::{self, errors::WakeComputeError, Api}, + console::{self, errors::WakeComputeError}, context::RequestMonitoring, - metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES}, - proxy::retry::{retry_after, ShouldRetry}, + metrics::NUM_CONNECTION_FAILURES, + proxy::{ + retry::{retry_after, ShouldRetry}, + wake_compute::wake_compute, + }, }; use async_trait::async_trait; -use hyper::StatusCode; use pq_proto::StartupMessageParams; -use std::ops::ControlFlow; use tokio::time; use tracing::{error, info, warn}; @@ -88,39 +89,6 @@ impl ConnectMechanism for TcpMechanism<'_> { } } -fn report_error(e: &WakeComputeError, retry: bool) { - use crate::console::errors::ApiError; - let retry = bool_to_str(retry); - let kind = match e { - WakeComputeError::BadComputeAddress(_) => "bad_compute_address", - WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error", - WakeComputeError::ApiError(ApiError::Console { - status: StatusCode::LOCKED, - ref text, - }) if text.contains("written data quota exceeded") - || text.contains("the limit for current plan reached") => - { - "quota_exceeded" - } - WakeComputeError::ApiError(ApiError::Console { - status: StatusCode::LOCKED, - .. - }) => "api_console_locked", - WakeComputeError::ApiError(ApiError::Console { - status: StatusCode::BAD_REQUEST, - .. - }) => "api_console_bad_request", - WakeComputeError::ApiError(ApiError::Console { status, .. }) - if status.is_server_error() => - { - "api_console_other_server_error" - } - WakeComputeError::ApiError(ApiError::Console { .. 
}) => "api_console_other_error", - WakeComputeError::TimeoutError => "timeout_error", - }; - NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc(); -} - /// Try to connect to the compute node, retrying if necessary. /// This function might update `node_info`, so we take it by `&mut`. #[tracing::instrument(skip_all)] @@ -137,7 +105,7 @@ where mechanism.update_connect_config(&mut node_info.config); // try once - let (config, err) = match mechanism + let err = match mechanism .connect_once(ctx, &node_info, CONNECT_TIMEOUT) .await { @@ -145,51 +113,27 @@ where ctx.latency_timer.success(); return Ok(res); } - Err(e) => { - error!(error = ?e, "could not connect to compute node"); - (invalidate_cache(node_info), e) - } + Err(e) => e, }; - ctx.latency_timer.cache_miss(); + error!(error = ?err, "could not connect to compute node"); let mut num_retries = 1; - // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node - info!("compute node's state has likely changed; requesting a wake-up"); - let node_info = loop { - let wake_res = match user_info { - auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await, - // nothing to do? 
- auth::BackendType::Link(_) => return Err(err.into()), - // test backend - #[cfg(test)] - auth::BackendType::Test(x) => x.wake_compute(), - }; + match user_info { + auth::BackendType::Console(api, info) => { + // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node + info!("compute node's state has likely changed; requesting a wake-up"); - match handle_try_wake(wake_res, num_retries) { - Err(e) => { - error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); - report_error(&e, false); - return Err(e.into()); - } - // failed to wake up but we can continue to retry - Ok(ControlFlow::Continue(e)) => { - report_error(&e, true); - warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); - } - // successfully woke up a compute node and can break the wakeup loop - Ok(ControlFlow::Break(mut node_info)) => { - node_info.config.reuse_password(&config); - mechanism.update_connect_config(&mut node_info.config); - break node_info; - } + ctx.latency_timer.cache_miss(); + let config = invalidate_cache(node_info); + node_info = wake_compute(&mut num_retries, ctx, api, info).await?; + + node_info.config.reuse_password(&config); + mechanism.update_connect_config(&mut node_info.config); } - - let wait_duration = retry_after(num_retries); - num_retries += 1; - - time::sleep(wait_duration).await; + // nothing to do? + auth::BackendType::Link(_) => {} }; // now that we have a new node, try connect to it repeatedly. @@ -221,23 +165,3 @@ where time::sleep(wait_duration).await; } } - -/// Attempts to wake up the compute node. 
-/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable -/// * Returns Ok(Break(node)) if the wakeup succeeded -/// * Returns Err(e) if there was an error -pub fn handle_try_wake( - result: Result, - num_retries: u32, -) -> Result, WakeComputeError> { - match result { - Err(err) => match &err { - WakeComputeError::ApiError(api) if api.should_retry(num_retries) => { - Ok(ControlFlow::Continue(err)) - } - _ => Err(err), - }, - // Ready to try again. - Ok(new) => Ok(ControlFlow::Break(new)), - } -} diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 1f57d343c4..2000774224 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -5,9 +5,9 @@ mod mitm; use super::connect_compute::ConnectMechanism; use super::retry::ShouldRetry; use super::*; -use crate::auth::backend::{ComputeUserInfo, TestBackend}; +use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend}; use crate::config::CertResolver; -use crate::console::provider::{CachedAllowedIps, CachedRoleSecret}; +use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; @@ -371,6 +371,7 @@ enum ConnectAction { Fail, } +#[derive(Clone)] struct TestConnectMechanism { counter: Arc>, sequence: Vec, @@ -490,9 +491,16 @@ fn helper_create_cached_node_info() -> CachedNodeInfo { fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> (CachedNodeInfo, auth::BackendType<'_, ComputeUserInfo>) { +) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) { let cache = helper_create_cached_node_info(); - let user_info = auth::BackendType::Test(mechanism); + let user_info = auth::BackendType::Console( + MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), + ComputeUserInfo { + endpoint: "endpoint".into(), + user: "user".into(), + options: 
NeonOptions::parse_options_raw(""), + }, + ); (cache, user_info) } diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs new file mode 100644 index 0000000000..925727bdab --- /dev/null +++ b/proxy/src/proxy/wake_compute.rs @@ -0,0 +1,95 @@ +use crate::auth::backend::ComputeUserInfo; +use crate::console::{ + errors::WakeComputeError, + provider::{CachedNodeInfo, ConsoleBackend}, + Api, +}; +use crate::context::RequestMonitoring; +use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES}; +use crate::proxy::retry::retry_after; +use hyper::StatusCode; +use std::ops::ControlFlow; +use tracing::{error, warn}; + +use super::retry::ShouldRetry; + +/// wake a compute (or retrieve an existing compute session from cache) +pub async fn wake_compute( + num_retries: &mut u32, + ctx: &mut RequestMonitoring, + api: &ConsoleBackend, + info: &ComputeUserInfo, +) -> Result { + loop { + let wake_res = api.wake_compute(ctx, info).await; + match handle_try_wake(wake_res, *num_retries) { + Err(e) => { + error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); + report_error(&e, false); + return Err(e); + } + Ok(ControlFlow::Continue(e)) => { + warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); + report_error(&e, true); + } + Ok(ControlFlow::Break(n)) => return Ok(n), + } + + let wait_duration = retry_after(*num_retries); + *num_retries += 1; + tokio::time::sleep(wait_duration).await; + } +} + +/// Attempts to wake up the compute node. +/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable +/// * Returns Ok(Break(node)) if the wakeup succeeded +/// * Returns Err(e) if there was an error +pub fn handle_try_wake( + result: Result, + num_retries: u32, +) -> Result, WakeComputeError> { + match result { + Err(err) => match &err { + WakeComputeError::ApiError(api) if api.should_retry(num_retries) => { + Ok(ControlFlow::Continue(err)) + } + _ => Err(err), + }, + // Ready to try again. 
+ Ok(new) => Ok(ControlFlow::Break(new)), + } +} + +fn report_error(e: &WakeComputeError, retry: bool) { + use crate::console::errors::ApiError; + let retry = bool_to_str(retry); + let kind = match e { + WakeComputeError::BadComputeAddress(_) => "bad_compute_address", + WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error", + WakeComputeError::ApiError(ApiError::Console { + status: StatusCode::LOCKED, + ref text, + }) if text.contains("written data quota exceeded") + || text.contains("the limit for current plan reached") => + { + "quota_exceeded" + } + WakeComputeError::ApiError(ApiError::Console { + status: StatusCode::LOCKED, + .. + }) => "api_console_locked", + WakeComputeError::ApiError(ApiError::Console { + status: StatusCode::BAD_REQUEST, + .. + }) => "api_console_bad_request", + WakeComputeError::ApiError(ApiError::Console { status, .. }) + if status.is_server_error() => + { + "api_console_other_server_error" + } + WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error", + WakeComputeError::TimeoutError => "timeout_error", + }; + NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc(); +} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a2eb7e62cc..7ff93b23b8 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -3,6 +3,7 @@ //! Handles both SQL over HTTP and SQL over Websockets. mod conn_pool; +mod json; mod sql_over_http; mod websocket; diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs new file mode 100644 index 0000000000..05835b23ce --- /dev/null +++ b/proxy/src/serverless/json.rs @@ -0,0 +1,448 @@ +use serde_json::Map; +use serde_json::Value; +use tokio_postgres::types::Kind; +use tokio_postgres::types::Type; +use tokio_postgres::Row; + +// +// Convert json non-string types to strings, so that they can be passed to Postgres +// as parameters. 
+// +pub fn json_to_pg_text(json: Vec) -> Vec> { + json.iter() + .map(|value| { + match value { + // special care for nulls + Value::Null => None, + + // convert to text with escaping + v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), + + // avoid escaping here, as we pass this as a parameter + Value::String(s) => Some(s.to_string()), + + // special care for arrays + Value::Array(_) => json_array_to_pg_array(value), + } + }) + .collect() +} + +// +// Serialize a JSON array to a Postgres array. Contrary to the strings in the params +// in the array we need to escape the strings. Postgres is okay with arrays of form +// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving +// it for Postgres to check. +// +// Example of the same escaping in node-postgres: packages/pg/lib/utils.js +// +fn json_array_to_pg_array(value: &Value) -> Option { + match value { + // special care for nulls + Value::Null => None, + + // convert to text with escaping + // here string needs to be escaped, as it is part of the array + v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()), + v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())), + + // recurse into array + Value::Array(arr) => { + let vals = arr + .iter() + .map(json_array_to_pg_array) + .map(|v| v.unwrap_or_else(|| "NULL".to_string())) + .collect::>() + .join(","); + + Some(format!("{{{}}}", vals)) + } + } +} + +// +// Convert postgres row with text-encoded values to JSON object +// +pub fn pg_text_row_to_json( + row: &Row, + columns: &[Type], + raw_output: bool, + array_mode: bool, +) -> Result { + let iter = row + .columns() + .iter() + .zip(columns) + .enumerate() + .map(|(i, (column, typ))| { + let name = column.name(); + let pg_value = row.as_text(i)?; + let json_value = if raw_output { + match pg_value { + Some(v) => Value::String(v.to_string()), + None => Value::Null, + } + } else { + 
pg_text_to_json(pg_value, typ)? + }; + Ok((name.to_string(), json_value)) + }); + + if array_mode { + // drop keys and aggregate into array + let arr = iter + .map(|r| r.map(|(_key, val)| val)) + .collect::, anyhow::Error>>()?; + Ok(Value::Array(arr)) + } else { + let obj = iter.collect::, anyhow::Error>>()?; + Ok(Value::Object(obj)) + } +} + +// +// Convert postgres text-encoded value to JSON value +// +fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result { + if let Some(val) = pg_value { + if let Kind::Array(elem_type) = pg_type.kind() { + return pg_array_parse(val, elem_type); + } + + match *pg_type { + Type::BOOL => Ok(Value::Bool(val == "t")), + Type::INT2 | Type::INT4 => { + let val = val.parse::()?; + Ok(Value::Number(serde_json::Number::from(val))) + } + Type::FLOAT4 | Type::FLOAT8 => { + let fval = val.parse::()?; + let num = serde_json::Number::from_f64(fval); + if let Some(num) = num { + Ok(Value::Number(num)) + } else { + // Pass Nan, Inf, -Inf as strings + // JS JSON.stringify() does converts them to null, but we + // want to preserve them, so we pass them as strings + Ok(Value::String(val.to_string())) + } + } + Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?), + _ => Ok(Value::String(val.to_string())), + } + } else { + Ok(Value::Null) + } +} + +// +// Parse postgres array into JSON array. +// +// This is a bit involved because we need to handle nested arrays and quoted +// values. Unlike postgres we don't check that all nested arrays have the same +// dimensions, we just return them as is. 
+// +fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result { + _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v) +} + +fn _pg_array_parse( + pg_array: &str, + elem_type: &Type, + nested: bool, +) -> Result<(Value, usize), anyhow::Error> { + let mut pg_array_chr = pg_array.char_indices(); + let mut level = 0; + let mut quote = false; + let mut entries: Vec = Vec::new(); + let mut entry = String::new(); + + // skip bounds decoration + if let Some('[') = pg_array.chars().next() { + for (_, c) in pg_array_chr.by_ref() { + if c == '=' { + break; + } + } + } + + fn push_checked( + entry: &mut String, + entries: &mut Vec, + elem_type: &Type, + ) -> Result<(), anyhow::Error> { + if !entry.is_empty() { + // While in usual postgres response we get nulls as None and everything else + // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while + // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs + // here while we have quotation info and convert them to None. 
+ if entry == "NULL" { + entries.push(pg_text_to_json(None, elem_type)?); + } else { + entries.push(pg_text_to_json(Some(entry), elem_type)?); + } + entry.clear(); + } + + Ok(()) + } + + while let Some((mut i, mut c)) = pg_array_chr.next() { + let mut escaped = false; + + if c == '\\' { + escaped = true; + (i, c) = pg_array_chr.next().unwrap(); + } + + match c { + '{' if !quote => { + level += 1; + if level > 1 { + let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?; + entries.push(res); + for _ in 0..off - 1 { + pg_array_chr.next(); + } + } + } + '}' if !quote => { + level -= 1; + if level == 0 { + push_checked(&mut entry, &mut entries, elem_type)?; + if nested { + return Ok((Value::Array(entries), i)); + } + } + } + '"' if !escaped => { + if quote { + // end of quoted string, so push it manually without any checks + // for emptiness or nulls + entries.push(pg_text_to_json(Some(&entry), elem_type)?); + entry.clear(); + } + quote = !quote; + } + ',' if !quote => { + push_checked(&mut entry, &mut entries, elem_type)?; + } + _ => { + entry.push(c); + } + } + } + + if level != 0 { + return Err(anyhow::anyhow!("unbalanced array")); + } + + Ok((Value::Array(entries), 0)) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_atomic_types_to_pg_params() { + let json = vec![Value::Bool(true), Value::Bool(false)]; + let pg_params = json_to_pg_text(json); + assert_eq!( + pg_params, + vec![Some("true".to_owned()), Some("false".to_owned())] + ); + + let json = vec![Value::Number(serde_json::Number::from(42))]; + let pg_params = json_to_pg_text(json); + assert_eq!(pg_params, vec![Some("42".to_owned())]); + + let json = vec![Value::String("foo\"".to_string())]; + let pg_params = json_to_pg_text(json); + assert_eq!(pg_params, vec![Some("foo\"".to_owned())]); + + let json = vec![Value::Null]; + let pg_params = json_to_pg_text(json); + assert_eq!(pg_params, vec![None]); + } + + #[test] + fn test_json_array_to_pg_array() { + // 
atoms and escaping + let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]"; + let json: Value = serde_json::from_str(json).unwrap(); + let pg_params = json_to_pg_text(vec![json]); + assert_eq!( + pg_params, + vec![Some( + "{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned() + )] + ); + + // nested arrays + let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]"; + let json: Value = serde_json::from_str(json).unwrap(); + let pg_params = json_to_pg_text(vec![json]); + assert_eq!( + pg_params, + vec![Some( + "{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned() + )] + ); + // array of objects + let json = r#"[{"foo": 1},{"bar": 2}]"#; + let json: Value = serde_json::from_str(json).unwrap(); + let pg_params = json_to_pg_text(vec![json]); + assert_eq!( + pg_params, + vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())] + ); + } + + #[test] + fn test_atomic_types_parse() { + assert_eq!( + pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(), + json!("foo") + ); + assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null)); + assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42)); + assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42)); + assert_eq!( + pg_text_to_json(Some("42"), &Type::INT8).unwrap(), + json!("42") + ); + assert_eq!( + pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(), + json!(42.42) + ); + assert_eq!( + pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(), + json!(42.42) + ); + assert_eq!( + pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(), + json!("NaN") + ); + assert_eq!( + pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(), + json!("Infinity") + ); + assert_eq!( + pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(), + json!("-Infinity") + ); + + let json: Value = + serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}") + .unwrap(); + assert_eq!( + pg_text_to_json( + 
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#), + &Type::JSONB + ) + .unwrap(), + json + ); + } + + #[test] + fn test_pg_array_parse_text() { + fn pt(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::TEXT).unwrap() + } + assert_eq!( + pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#), + json!(["aa\"\\,a", "cha", "bbbb"]) + ); + assert_eq!( + pt(r#"{{"foo","bar"},{"bee","bop"}}"#), + json!([["foo", "bar"], ["bee", "bop"]]) + ); + assert_eq!( + pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#), + json!([[[["foo", null, "bop", "bup"]]]]) + ); + assert_eq!( + pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#), + json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]]) + ); + } + + #[test] + fn test_pg_array_parse_bool() { + fn pb(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::BOOL).unwrap() + } + assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true])); + assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]])); + assert_eq!( + pb(r#"{{t,f},{f,t}}"#), + json!([[true, false], [false, true]]) + ); + assert_eq!( + pb(r#"{{t,NULL},{NULL,f}}"#), + json!([[true, null], [null, false]]) + ); + } + + #[test] + fn test_pg_array_parse_numbers() { + fn pn(pg_arr: &str, ty: &Type) -> Value { + pg_array_parse(pg_arr, ty).unwrap() + } + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0])); + assert_eq!( + pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4), + json!([1.1, 2.2, 3.3]) + ); + assert_eq!( + pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8), + json!([1.1, 2.2, 3.3]) + ); + assert_eq!( + pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4), + json!(["NaN", "Infinity", "-Infinity"]) + ); + assert_eq!( + pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8), + json!(["NaN", "Infinity", "-Infinity"]) + ); + } + + #[test] + fn 
test_pg_array_with_decoration() { + fn p(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::INT2).unwrap() + } + assert_eq!( + p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#), + json!([[[1, 2, 3], [4, 5, 6]]]) + ); + } + + #[test] + fn test_pg_array_parse_json() { + fn pt(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::JSONB).unwrap() + } + assert_eq!(pt(r#"{"{}"}"#), json!([{}])); + assert_eq!( + pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#), + json!([{"foo": 1, "bar": 2}]) + ); + assert_eq!( + pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#), + json!([{"foo": 1}, {"bar": 2}]) + ); + assert_eq!( + pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#), + json!([[{"foo": 1}, {"bar": 2}]]) + ); + } +} diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 27c2134221..96bf39c915 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -12,16 +12,12 @@ use hyper::Response; use hyper::StatusCode; use hyper::{Body, HeaderMap, Request}; use serde_json::json; -use serde_json::Map; use serde_json::Value; use tokio_postgres::error::DbError; use tokio_postgres::error::ErrorPosition; -use tokio_postgres::types::Kind; -use tokio_postgres::types::Type; use tokio_postgres::GenericClient; use tokio_postgres::IsolationLevel; use tokio_postgres::ReadyForQueryStatus; -use tokio_postgres::Row; use tokio_postgres::Transaction; use tracing::error; use tracing::instrument; @@ -40,6 +36,7 @@ use crate::RoleName; use super::conn_pool::ConnInfo; use super::conn_pool::GlobalConnPool; +use super::json::{json_to_pg_text, pg_text_row_to_json}; use super::SERVERLESS_DRIVER_SNI; #[derive(serde::Deserialize)] @@ -72,62 +69,6 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); -// -// Convert json non-string types to strings, so that they can be passed to Postgres -// as parameters. 
-// -fn json_to_pg_text(json: Vec) -> Vec> { - json.iter() - .map(|value| { - match value { - // special care for nulls - Value::Null => None, - - // convert to text with escaping - v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), - - // avoid escaping here, as we pass this as a parameter - Value::String(s) => Some(s.to_string()), - - // special care for arrays - Value::Array(_) => json_array_to_pg_array(value), - } - }) - .collect() -} - -// -// Serialize a JSON array to a Postgres array. Contrary to the strings in the params -// in the array we need to escape the strings. Postgres is okay with arrays of form -// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving -// it for Postgres to check. -// -// Example of the same escaping in node-postgres: packages/pg/lib/utils.js -// -fn json_array_to_pg_array(value: &Value) -> Option { - match value { - // special care for nulls - Value::Null => None, - - // convert to text with escaping - // here string needs to be escaped, as it is part of the array - v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()), - v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())), - - // recurse into array - Value::Array(arr) => { - let vals = arr - .iter() - .map(json_array_to_pg_array) - .map(|v| v.unwrap_or_else(|| "NULL".to_string())) - .collect::>() - .join(","); - - Some(format!("{{{}}}", vals)) - } - } -} - fn get_conn_info( ctx: &mut RequestMonitoring, headers: &HeaderMap, @@ -611,389 +552,3 @@ async fn query_to_json( }), )) } - -// -// Convert postgres row with text-encoded values to JSON object -// -pub fn pg_text_row_to_json( - row: &Row, - columns: &[Type], - raw_output: bool, - array_mode: bool, -) -> Result { - let iter = row - .columns() - .iter() - .zip(columns) - .enumerate() - .map(|(i, (column, typ))| { - let name = column.name(); - let pg_value = row.as_text(i)?; - let json_value = if raw_output { - 
match pg_value { - Some(v) => Value::String(v.to_string()), - None => Value::Null, - } - } else { - pg_text_to_json(pg_value, typ)? - }; - Ok((name.to_string(), json_value)) - }); - - if array_mode { - // drop keys and aggregate into array - let arr = iter - .map(|r| r.map(|(_key, val)| val)) - .collect::, anyhow::Error>>()?; - Ok(Value::Array(arr)) - } else { - let obj = iter.collect::, anyhow::Error>>()?; - Ok(Value::Object(obj)) - } -} - -// -// Convert postgres text-encoded value to JSON value -// -pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result { - if let Some(val) = pg_value { - if let Kind::Array(elem_type) = pg_type.kind() { - return pg_array_parse(val, elem_type); - } - - match *pg_type { - Type::BOOL => Ok(Value::Bool(val == "t")), - Type::INT2 | Type::INT4 => { - let val = val.parse::()?; - Ok(Value::Number(serde_json::Number::from(val))) - } - Type::FLOAT4 | Type::FLOAT8 => { - let fval = val.parse::()?; - let num = serde_json::Number::from_f64(fval); - if let Some(num) = num { - Ok(Value::Number(num)) - } else { - // Pass Nan, Inf, -Inf as strings - // JS JSON.stringify() does converts them to null, but we - // want to preserve them, so we pass them as strings - Ok(Value::String(val.to_string())) - } - } - Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?), - _ => Ok(Value::String(val.to_string())), - } - } else { - Ok(Value::Null) - } -} - -// -// Parse postgres array into JSON array. -// -// This is a bit involved because we need to handle nested arrays and quoted -// values. Unlike postgres we don't check that all nested arrays have the same -// dimensions, we just return them as is. 
-// -fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result { - _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v) -} - -fn _pg_array_parse( - pg_array: &str, - elem_type: &Type, - nested: bool, -) -> Result<(Value, usize), anyhow::Error> { - let mut pg_array_chr = pg_array.char_indices(); - let mut level = 0; - let mut quote = false; - let mut entries: Vec = Vec::new(); - let mut entry = String::new(); - - // skip bounds decoration - if let Some('[') = pg_array.chars().next() { - for (_, c) in pg_array_chr.by_ref() { - if c == '=' { - break; - } - } - } - - fn push_checked( - entry: &mut String, - entries: &mut Vec, - elem_type: &Type, - ) -> Result<(), anyhow::Error> { - if !entry.is_empty() { - // While in usual postgres response we get nulls as None and everything else - // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while - // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs - // here while we have quotation info and convert them to None. 
- if entry == "NULL" { - entries.push(pg_text_to_json(None, elem_type)?); - } else { - entries.push(pg_text_to_json(Some(entry), elem_type)?); - } - entry.clear(); - } - - Ok(()) - } - - while let Some((mut i, mut c)) = pg_array_chr.next() { - let mut escaped = false; - - if c == '\\' { - escaped = true; - (i, c) = pg_array_chr.next().unwrap(); - } - - match c { - '{' if !quote => { - level += 1; - if level > 1 { - let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?; - entries.push(res); - for _ in 0..off - 1 { - pg_array_chr.next(); - } - } - } - '}' if !quote => { - level -= 1; - if level == 0 { - push_checked(&mut entry, &mut entries, elem_type)?; - if nested { - return Ok((Value::Array(entries), i)); - } - } - } - '"' if !escaped => { - if quote { - // end of quoted string, so push it manually without any checks - // for emptiness or nulls - entries.push(pg_text_to_json(Some(&entry), elem_type)?); - entry.clear(); - } - quote = !quote; - } - ',' if !quote => { - push_checked(&mut entry, &mut entries, elem_type)?; - } - _ => { - entry.push(c); - } - } - } - - if level != 0 { - return Err(anyhow::anyhow!("unbalanced array")); - } - - Ok((Value::Array(entries), 0)) -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_atomic_types_to_pg_params() { - let json = vec![Value::Bool(true), Value::Bool(false)]; - let pg_params = json_to_pg_text(json); - assert_eq!( - pg_params, - vec![Some("true".to_owned()), Some("false".to_owned())] - ); - - let json = vec![Value::Number(serde_json::Number::from(42))]; - let pg_params = json_to_pg_text(json); - assert_eq!(pg_params, vec![Some("42".to_owned())]); - - let json = vec![Value::String("foo\"".to_string())]; - let pg_params = json_to_pg_text(json); - assert_eq!(pg_params, vec![Some("foo\"".to_owned())]); - - let json = vec![Value::Null]; - let pg_params = json_to_pg_text(json); - assert_eq!(pg_params, vec![None]); - } - - #[test] - fn test_json_array_to_pg_array() { - // 
atoms and escaping - let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]"; - let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]); - assert_eq!( - pg_params, - vec![Some( - "{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned() - )] - ); - - // nested arrays - let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]"; - let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]); - assert_eq!( - pg_params, - vec![Some( - "{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned() - )] - ); - // array of objects - let json = r#"[{"foo": 1},{"bar": 2}]"#; - let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]); - assert_eq!( - pg_params, - vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())] - ); - } - - #[test] - fn test_atomic_types_parse() { - assert_eq!( - pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(), - json!("foo") - ); - assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null)); - assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42)); - assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42)); - assert_eq!( - pg_text_to_json(Some("42"), &Type::INT8).unwrap(), - json!("42") - ); - assert_eq!( - pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(), - json!(42.42) - ); - assert_eq!( - pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(), - json!(42.42) - ); - assert_eq!( - pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(), - json!("NaN") - ); - assert_eq!( - pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(), - json!("Infinity") - ); - assert_eq!( - pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(), - json!("-Infinity") - ); - - let json: Value = - serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}") - .unwrap(); - assert_eq!( - pg_text_to_json( - 
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#), - &Type::JSONB - ) - .unwrap(), - json - ); - } - - #[test] - fn test_pg_array_parse_text() { - fn pt(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::TEXT).unwrap() - } - assert_eq!( - pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#), - json!(["aa\"\\,a", "cha", "bbbb"]) - ); - assert_eq!( - pt(r#"{{"foo","bar"},{"bee","bop"}}"#), - json!([["foo", "bar"], ["bee", "bop"]]) - ); - assert_eq!( - pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#), - json!([[[["foo", null, "bop", "bup"]]]]) - ); - assert_eq!( - pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#), - json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]]) - ); - } - - #[test] - fn test_pg_array_parse_bool() { - fn pb(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::BOOL).unwrap() - } - assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true])); - assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]])); - assert_eq!( - pb(r#"{{t,f},{f,t}}"#), - json!([[true, false], [false, true]]) - ); - assert_eq!( - pb(r#"{{t,NULL},{NULL,f}}"#), - json!([[true, null], [null, false]]) - ); - } - - #[test] - fn test_pg_array_parse_numbers() { - fn pn(pg_arr: &str, ty: &Type) -> Value { - pg_array_parse(pg_arr, ty).unwrap() - } - assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3])); - assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3])); - assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"])); - assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0])); - assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0])); - assert_eq!( - pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4), - json!([1.1, 2.2, 3.3]) - ); - assert_eq!( - pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8), - json!([1.1, 2.2, 3.3]) - ); - assert_eq!( - pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4), - json!(["NaN", "Infinity", "-Infinity"]) - ); - assert_eq!( - pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8), - json!(["NaN", "Infinity", "-Infinity"]) - ); - } - - #[test] - fn 
test_pg_array_with_decoration() { - fn p(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::INT2).unwrap() - } - assert_eq!( - p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#), - json!([[[1, 2, 3], [4, 5, 6]]]) - ); - } - #[test] - fn test_pg_array_parse_json() { - fn pt(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::JSONB).unwrap() - } - assert_eq!(pt(r#"{"{}"}"#), json!([{}])); - assert_eq!( - pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#), - json!([{"foo": 1, "bar": 2}]) - ); - assert_eq!( - pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#), - json!([{"foo": 1}, {"bar": 2}]) - ); - assert_eq!( - pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#), - json!([[{"foo": 1}, {"bar": 2}]]) - ); - } -} From 7e2436695decac52fd0fc5eec11441d0a7e8d407 Mon Sep 17 00:00:00 2001 From: John Spray Date: Fri, 2 Feb 2024 16:57:11 +0000 Subject: [PATCH 2/2] storage controller: use AWS Secrets Manager for database URL, etc (#6585) ## Problem Passing secrets in via CLI/environment is awkward when using helm for deployment, and not ideal for security (secrets may show up in ps, /proc). We can bypass these issues by simply connecting directly to the AWS Secrets Manager service at runtime. ## Summary of changes - Add dependency on aws-sdk-secretsmanager - Update other aws dependencies to latest, to match transitive dependency versions - Add `Secrets` type in attachment service, using AWS SDK to load if secrets are not provided on the command line. 
--- Cargo.lock | 242 ++++++++++--------- Cargo.toml | 11 +- control_plane/attachment_service/Cargo.toml | 2 + control_plane/attachment_service/src/main.rs | 110 ++++++++- libs/utils/src/auth.rs | 4 + workspace_hack/Cargo.toml | 2 +- 6 files changed, 249 insertions(+), 122 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ea5a29a142..90991ab0a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -275,6 +275,8 @@ name = "attachment_service" version = "0.1.0" dependencies = [ "anyhow", + "aws-config", + "aws-sdk-secretsmanager", "camino", "clap", "control_plane", @@ -304,12 +306,11 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "aws-config" -version = "1.0.1" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80c950a809d39bc9480207cb1cfc879ace88ea7e3a4392a8e9999e45d6e5692e" +checksum = "8b30c39ebe61f75d1b3785362b1586b41991873c9ab3e317a9181c246fb71d82" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-sdk-sso", "aws-sdk-ssooidc", @@ -324,7 +325,7 @@ dependencies = [ "bytes", "fastrand 2.0.0", "hex", - "http", + "http 0.2.9", "hyper", "ring 0.17.6", "time", @@ -335,9 +336,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.0.1" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c1317e1a3514b103cf7d5828bbab3b4d30f56bd22d684f8568bc51b6cfbbb1c" +checksum = "33cc49dcdd31c8b6e79850a179af4c367669150c7ac0135f176c61bec81a70f7" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -345,30 +346,13 @@ dependencies = [ "zeroize", ] -[[package]] -name = "aws-http" -version = "0.60.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "361c4310fdce94328cc2d1ca0c8a48c13f43009c61d3367585685a50ca8c66b6" -dependencies = [ - "aws-smithy-runtime-api", - "aws-smithy-types", - "aws-types", - "bytes", - "http", - "http-body", - "pin-project-lite", - "tracing", -] 
- [[package]] name = "aws-runtime" -version = "1.0.1" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ed7ef604a15fd0d4d9e43701295161ea6b504b63c44990ead352afea2bc15e9" +checksum = "eb031bff99877c26c28895766f7bb8484a05e24547e370768d6cc9db514662aa" dependencies = [ "aws-credential-types", - "aws-http", "aws-sigv4", "aws-smithy-async", "aws-smithy-eventstream", @@ -376,21 +360,23 @@ dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", + "bytes", "fastrand 2.0.0", - "http", + "http 0.2.9", + "http-body", "percent-encoding", + "pin-project-lite", "tracing", "uuid", ] [[package]] name = "aws-sdk-s3" -version = "1.4.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dcafc2fe52cc30b2d56685e2fa6a879ba50d79704594852112337a472ddbd24" +checksum = "951f7730f51a2155c711c85c79f337fbc02a577fa99d2a0a8059acfce5392113" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-sigv4", "aws-smithy-async", @@ -404,23 +390,22 @@ dependencies = [ "aws-smithy-xml", "aws-types", "bytes", - "http", + "http 0.2.9", "http-body", "once_cell", "percent-encoding", - "regex", + "regex-lite", "tracing", "url", ] [[package]] -name = "aws-sdk-sso" -version = "1.3.0" +name = "aws-sdk-secretsmanager" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0619ab97a5ca8982e7de073cdc66f93e5f6a1b05afc09e696bec1cb3607cd4df" +checksum = "0a0b64e61e7d632d9df90a2e0f32630c68c24960cab1d27d848718180af883d3" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-smithy-async", "aws-smithy-http", @@ -430,19 +415,42 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "http", - "regex", + "fastrand 2.0.0", + "http 0.2.9", + "once_cell", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sso" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f486420a66caad72635bc2ce0ff6581646e0d32df02aa39dc983bfe794955a5b" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.9", + "once_cell", + "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.3.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04b9f5474cc0f35d829510b2ec8c21e352309b46bf9633c5a81fb9321e9b1c7" +checksum = "39ddccf01d82fce9b4a15c8ae8608211ee7db8ed13a70b514bbfe41df3d24841" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-smithy-async", "aws-smithy-http", @@ -452,19 +460,19 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "http", - "regex", + "http 0.2.9", + "once_cell", + "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.3.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5700da387716ccfc30b27f44b008f457e1baca5b0f05b6b95455778005e3432a" +checksum = "1a591f8c7e6a621a501b2b5d2e88e1697fcb6274264523a6ad4d5959889a41ce" dependencies = [ "aws-credential-types", - "aws-http", "aws-runtime", "aws-smithy-async", "aws-smithy-http", @@ -475,16 +483,17 @@ dependencies = [ "aws-smithy-types", "aws-smithy-xml", "aws-types", - "http", - "regex", + "http 0.2.9", + "once_cell", + "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.0.1" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380adcc8134ad8bbdfeb2ace7626a869914ee266322965276cbc54066186d236" +checksum = "c371c6b0ac54d4605eb6f016624fb5c7c2925d315fdf600ac1bf21b19d5f1742" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -496,11 +505,11 @@ dependencies = [ "form_urlencoded", "hex", "hmac", - "http", + "http 0.2.9", + "http 1.0.0", "once_cell", "p256", 
"percent-encoding", - "regex", "ring 0.17.6", "sha2", "subtle", @@ -511,9 +520,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.0.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e37ca17d25fe1e210b6d4bdf59b81caebfe99f986201a1228cb5061233b4b13" +checksum = "72ee2d09cce0ef3ae526679b522835d63e75fb427aca5413cd371e490d52dcc6" dependencies = [ "futures-util", "pin-project-lite", @@ -522,9 +531,9 @@ dependencies = [ [[package]] name = "aws-smithy-checksums" -version = "0.60.0" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5a373ec01aede3dd066ec018c1bc4e8f5dd11b2c11c59c8eef1a5c68101f397" +checksum = "be2acd1b9c6ae5859999250ed5a62423aedc5cf69045b844432de15fa2f31f2b" dependencies = [ "aws-smithy-http", "aws-smithy-types", @@ -532,7 +541,7 @@ dependencies = [ "crc32c", "crc32fast", "hex", - "http", + "http 0.2.9", "http-body", "md-5", "pin-project-lite", @@ -543,9 +552,9 @@ dependencies = [ [[package]] name = "aws-smithy-eventstream" -version = "0.60.0" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c669e1e5fc0d79561bf7a122b118bd50c898758354fe2c53eb8f2d31507cbc3" +checksum = "e6363078f927f612b970edf9d1903ef5cef9a64d1e8423525ebb1f0a1633c858" dependencies = [ "aws-smithy-types", "bytes", @@ -554,9 +563,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.0" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b1de8aee22f67de467b2e3d0dd0fb30859dc53f579a63bd5381766b987db644" +checksum = "dab56aea3cd9e1101a0a999447fb346afb680ab1406cebc44b32346e25b4117d" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -564,7 +573,7 @@ dependencies = [ "bytes", "bytes-utils", "futures-core", - "http", + "http 0.2.9", "http-body", "once_cell", "percent-encoding", @@ -575,18 +584,18 @@ dependencies = [ [[package]] name = 
"aws-smithy-json" -version = "0.60.0" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a46dd338dc9576d6a6a5b5a19bd678dcad018ececee11cf28ecd7588bd1a55c" +checksum = "fd3898ca6518f9215f62678870064398f00031912390efd03f1f6ef56d83aa8e" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-query" -version = "0.60.0" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feb5b8c7a86d4b6399169670723b7e6f21a39fc833a30f5c5a2f997608178129" +checksum = "bda4b1dfc9810e35fba8a620e900522cd1bd4f9578c446e82f49d1ce41d2e9f9" dependencies = [ "aws-smithy-types", "urlencoding", @@ -594,9 +603,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.0.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "273479291efc55e7b0bce985b139d86b6031adb8e50f65c1f712f20ba38f6388" +checksum = "fafdab38f40ad7816e7da5dec279400dd505160780083759f01441af1bbb10ea" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -605,7 +614,7 @@ dependencies = [ "bytes", "fastrand 2.0.0", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-rustls", @@ -619,14 +628,14 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.0.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6cebff0d977b6b6feed2fd07db52aac58ba3ccaf26cdd49f1af4add5061bef9" +checksum = "c18276dd28852f34b3bf501f4f3719781f4999a51c7bff1a5c6dc8c4529adc29" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", - "http", + "http 0.2.9", "pin-project-lite", "tokio", "tracing", @@ -635,15 +644,15 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.0.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7f48b3f27ddb40ab19892a5abda331f403e3cb877965e4e51171447807104af" +checksum = 
"bb3e134004170d3303718baa2a4eb4ca64ee0a1c0a7041dca31b38be0fb414f3" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", - "http", + "http 0.2.9", "http-body", "itoa", "num-integer", @@ -658,24 +667,24 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.0" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ec40d74a67fd395bc3f6b4ccbdf1543672622d905ef3f979689aea5b730cb95" +checksum = "8604a11b25e9ecaf32f9aa56b9fe253c5e2f606a3477f0071e96d3155a5ed218" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.0.1" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8403fc56b1f3761e8efe45771ddc1165e47ec3417c68e68a4519b5cb030159ca" +checksum = "789bbe008e65636fe1b6dbbb374c40c8960d1232b96af5ff4aec349f9c4accf4" dependencies = [ "aws-credential-types", "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", - "http", + "http 0.2.9", "rustc_version", "tracing", ] @@ -692,7 +701,7 @@ dependencies = [ "bitflags 1.3.2", "bytes", "futures-util", - "http", + "http 0.2.9", "http-body", "hyper", "itoa", @@ -724,7 +733,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", + "http 0.2.9", "http-body", "mime", "rustversion", @@ -2003,9 +2012,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -2013,9 +2022,9 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = 
"dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" @@ -2030,9 +2039,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-lite" @@ -2051,9 +2060,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", @@ -2062,15 +2071,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-timer" @@ -2080,9 +2089,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -2186,7 +2195,7 @@ 
dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 2.0.1", "slab", "tokio", @@ -2337,6 +2346,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -2344,7 +2364,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -2407,7 +2427,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", @@ -2426,7 +2446,7 @@ version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7" dependencies = [ - "http", + "http 0.2.9", "hyper", "log", "rustls", @@ -3108,7 +3128,7 @@ dependencies = [ "base64 0.13.1", "chrono", "getrandom 0.2.11", - "http", + "http 0.2.9", "rand 0.8.5", "serde", "serde_json", @@ -3210,7 +3230,7 @@ checksum = "c7594ec0e11d8e33faf03530a4c49af7064ebba81c1480e01be67d90b356508b" dependencies = [ "async-trait", "bytes", - "http", + "http 0.2.9", "opentelemetry_api", "reqwest", ] @@ -3223,7 +3243,7 @@ checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275" dependencies = [ "async-trait", "futures-core", - "http", + "http 0.2.9", "opentelemetry-http", "opentelemetry-proto", "opentelemetry-semantic-conventions", @@ -4323,6 +4343,12 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "regex-lite" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" + [[package]] name = 
"regex-syntax" version = "0.6.29" @@ -4392,7 +4418,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-rustls", @@ -4433,7 +4459,7 @@ checksum = "4531c89d50effe1fac90d095c8b133c20c5c714204feee0bfc3fd158e784209d" dependencies = [ "anyhow", "async-trait", - "http", + "http 0.2.9", "reqwest", "serde", "task-local-extensions", @@ -4451,7 +4477,7 @@ dependencies = [ "chrono", "futures", "getrandom 0.2.11", - "http", + "http 0.2.9", "hyper", "parking_lot 0.11.2", "reqwest", @@ -4538,7 +4564,7 @@ version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945" dependencies = [ - "http", + "http 0.2.9", "hyper", "lazy_static", "percent-encoding", @@ -5868,7 +5894,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-timeout", @@ -6083,7 +6109,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http", + "http 0.2.9", "httparse", "log", "rand 0.8.5", diff --git a/Cargo.toml b/Cargo.toml index d3006985ab..0cfe522ff9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,11 +48,12 @@ azure_storage_blobs = "0.18" flate2 = "1.0.26" async-stream = "0.3" async-trait = "0.1" -aws-config = { version = "1.0", default-features = false, features=["rustls"] } -aws-sdk-s3 = "1.0" -aws-smithy-async = { version = "1.0", default-features = false, features=["rt-tokio"] } -aws-smithy-types = "1.0" -aws-credential-types = "1.0" +aws-config = { version = "1.1.4", default-features = false, features=["rustls"] } +aws-sdk-s3 = "1.14" +aws-sdk-secretsmanager = { version = "1.14.0" } +aws-smithy-async = { version = "1.1.4", default-features = false, features=["rt-tokio"] } +aws-smithy-types = "1.1.4" +aws-credential-types = "1.1.4" axum = { version = "0.6.20", features = ["ws"] } base64 = "0.13.0" bincode = "1.3" diff --git a/control_plane/attachment_service/Cargo.toml 
b/control_plane/attachment_service/Cargo.toml index 210a898747..1d3831eea0 100644 --- a/control_plane/attachment_service/Cargo.toml +++ b/control_plane/attachment_service/Cargo.toml @@ -6,6 +6,8 @@ license.workspace = true [dependencies] anyhow.workspace = true +aws-config.workspace = true +aws-sdk-secretsmanager.workspace = true camino.workspace = true clap.workspace = true futures.workspace = true diff --git a/control_plane/attachment_service/src/main.rs b/control_plane/attachment_service/src/main.rs index 7c716a9f53..ed65437ba2 100644 --- a/control_plane/attachment_service/src/main.rs +++ b/control_plane/attachment_service/src/main.rs @@ -8,6 +8,7 @@ use anyhow::anyhow; use attachment_service::http::make_router; use attachment_service::persistence::Persistence; use attachment_service::service::{Config, Service}; +use aws_config::{self, BehaviorVersion, Region}; use camino::Utf8PathBuf; use clap::Parser; use metrics::launch_timestamp::LaunchTimestamp; @@ -46,6 +47,100 @@ struct Cli { database_url: String, } +/// Secrets may either be provided on the command line (for testing), or loaded from AWS SecretManager: this +/// type encapsulates the logic to decide which and do the loading. 
+struct Secrets { + database_url: String, + public_key: Option, + jwt_token: Option, +} + +impl Secrets { + const DATABASE_URL_SECRET: &'static str = "rds-neon-storage-controller-url"; + const JWT_TOKEN_SECRET: &'static str = "neon-storage-controller-pageserver-jwt-token"; + const PUBLIC_KEY_SECRET: &'static str = "neon-storage-controller-public-key"; + + async fn load(args: &Cli) -> anyhow::Result { + if args.database_url.is_empty() { + Self::load_aws_sm().await + } else { + Self::load_cli(args) + } + } + + async fn load_aws_sm() -> anyhow::Result { + let Ok(region) = std::env::var("AWS_REGION") else { + anyhow::bail!("AWS_REGION is not set, cannot load secrets automatically: either set this, or use CLI args to supply secrets"); + }; + let config = aws_config::defaults(BehaviorVersion::v2023_11_09()) + .region(Region::new(region.clone())) + .load() + .await; + + let asm = aws_sdk_secretsmanager::Client::new(&config); + + let Some(database_url) = asm + .get_secret_value() + .secret_id(Self::DATABASE_URL_SECRET) + .send() + .await? + .secret_string() + .map(str::to_string) + else { + anyhow::bail!( + "Database URL secret not found at {region}/{}", + Self::DATABASE_URL_SECRET + ) + }; + + let jwt_token = asm + .get_secret_value() + .secret_id(Self::JWT_TOKEN_SECRET) + .send() + .await? + .secret_string() + .map(str::to_string); + if jwt_token.is_none() { + tracing::warn!("No pageserver JWT token set: this will only work if authentication is disabled on the pageserver"); + } + + let public_key = asm + .get_secret_value() + .secret_id(Self::PUBLIC_KEY_SECRET) + .send() + .await? 
+ .secret_string() + .map(str::to_string); + let public_key = match public_key { + Some(key) => Some(JwtAuth::from_key(key)?), + None => { + tracing::warn!( + "No public key set: incoming HTTP requests will not be authenticated" + ); + None + } + }; + + Ok(Self { + database_url, + public_key, + jwt_token, + }) + } + + fn load_cli(args: &Cli) -> anyhow::Result { + let public_key = match &args.public_key { + None => None, + Some(key_path) => Some(JwtAuth::from_key_path(key_path)?), + }; + Ok(Self { + database_url: args.database_url.clone(), + public_key, + jwt_token: args.jwt_token.clone(), + }) + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let launch_ts = Box::leak(Box::new(LaunchTimestamp::generate())); @@ -66,23 +161,22 @@ async fn main() -> anyhow::Result<()> { args.listen ); + let secrets = Secrets::load(&args).await?; + let config = Config { - jwt_token: args.jwt_token, + jwt_token: secrets.jwt_token, }; let json_path = args.path; - let persistence = Arc::new(Persistence::new(args.database_url, json_path.clone())); + let persistence = Arc::new(Persistence::new(secrets.database_url, json_path.clone())); let service = Service::spawn(config, persistence.clone()).await?; let http_listener = tcp_listener::bind(args.listen)?; - let auth = if let Some(public_key_path) = &args.public_key { - let jwt_auth = JwtAuth::from_key_path(public_key_path)?; - Some(Arc::new(SwappableJwtAuth::new(jwt_auth))) - } else { - None - }; + let auth = secrets + .public_key + .map(|jwt_auth| Arc::new(SwappableJwtAuth::new(jwt_auth))); let router = make_router(service, auth) .build() .map_err(|err| anyhow!(err))?; diff --git a/libs/utils/src/auth.rs b/libs/utils/src/auth.rs index 66b1f6e866..15c3f2af1b 100644 --- a/libs/utils/src/auth.rs +++ b/libs/utils/src/auth.rs @@ -127,6 +127,10 @@ impl JwtAuth { Ok(Self::new(decoding_keys)) } + pub fn from_key(key: String) -> Result { + Ok(Self::new(vec![DecodingKey::from_ed_pem(key.as_bytes())?])) + } + /// Attempt to decode the 
token with the internal decoding keys. /// /// The function tries the stored decoding keys in succession, diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 8fd49956cc..f58b912a77 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -15,7 +15,7 @@ publish = false [dependencies] anyhow = { version = "1", features = ["backtrace"] } aws-config = { version = "1", default-features = false, features = ["rustls", "sso"] } -aws-runtime = { version = "1", default-features = false, features = ["event-stream", "sigv4a"] } +aws-runtime = { version = "1", default-features = false, features = ["event-stream", "http-02x", "sigv4a"] } aws-sigv4 = { version = "1", features = ["http0-compat", "sign-eventstream", "sigv4a"] } aws-smithy-async = { version = "1", default-features = false, features = ["rt-tokio"] } aws-smithy-http = { version = "0.60", default-features = false, features = ["event-stream"] }