From 6506fd14c45bf4fd685e8ba25cbd609502537155 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 2 Feb 2024 16:07:35 +0000 Subject: [PATCH] proxy: more refactors (#6526) ## Problem not really any problem, just some drive-by changes ## Summary of changes 1. move wake compute 2. move json processing 3. move handle_try_wake 4. move test backend to api provider 5. reduce wake-compute concerns 6. remove duplicate wake-compute loop --- proxy/src/auth/backend.rs | 113 +++---- proxy/src/auth/backend/classic.rs | 2 +- proxy/src/auth/flow.rs | 2 +- proxy/src/bin/proxy.rs | 26 +- proxy/src/console/provider.rs | 23 +- proxy/src/console/provider/neon.rs | 1 - proxy/src/proxy.rs | 1 + proxy/src/proxy/connect_compute.rs | 118 ++----- proxy/src/proxy/tests.rs | 16 +- proxy/src/proxy/wake_compute.rs | 95 ++++++ proxy/src/serverless.rs | 1 + proxy/src/serverless/json.rs | 448 ++++++++++++++++++++++++++ proxy/src/serverless/sql_over_http.rs | 447 +------------------------ 13 files changed, 649 insertions(+), 644 deletions(-) create mode 100644 proxy/src/proxy/wake_compute.rs create mode 100644 proxy/src/serverless/json.rs diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 144c9dcff5..236567163e 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -12,8 +12,7 @@ use crate::console::errors::GetAuthInfoError; use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; use crate::console::AuthSecret; use crate::context::RequestMonitoring; -use crate::proxy::connect_compute::handle_try_wake; -use crate::proxy::retry::retry_after; +use crate::proxy::wake_compute::wake_compute; use crate::proxy::NeonOptions; use crate::stream::Stream; use crate::{ @@ -28,11 +27,26 @@ use crate::{ }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; use futures::TryFutureExt; -use std::borrow::Cow; -use std::ops::ControlFlow; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{error, info, warn}; +use tracing::info; + +/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality +pub enum MaybeOwned<'a, T> { + Owned(T), + Borrowed(&'a T), +} + +impl std::ops::Deref for MaybeOwned<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match self { + MaybeOwned::Owned(t) => t, + MaybeOwned::Borrowed(t) => t, + } + } +} /// This type serves two purposes: /// @@ -44,12 +58,9 @@ use tracing::{error, info, warn}; /// backends which require them for the authentication process. pub enum BackendType<'a, T> { /// Cloud API (V2). - Console(Cow<'a, ConsoleBackend>, T), + Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. - Link(Cow<'a, url::ApiUrl>), - #[cfg(test)] - /// Test backend. - Test(&'a dyn TestBackend), + Link(MaybeOwned<'a, url::ApiUrl>), } pub trait TestBackend: Send + Sync + 'static { @@ -67,14 +78,14 @@ impl std::fmt::Display for BackendType<'_, ()> { ConsoleBackend::Console(endpoint) => { fmt.debug_tuple("Console").field(&endpoint.url()).finish() } - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] ConsoleBackend::Postgres(endpoint) => { fmt.debug_tuple("Postgres").field(&endpoint.url()).finish() } + #[cfg(test)] + ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), - #[cfg(test)] - Test(_) => fmt.debug_tuple("Test").finish(), } } } @@ -85,10 +96,8 @@ impl BackendType<'_, T> { pub fn as_ref(&self) -> BackendType<'_, &T> { use BackendType::*; match self { - Console(c, x) => Console(Cow::Borrowed(c), x), - Link(c) => Link(Cow::Borrowed(c)), - #[cfg(test)] - Test(x) => Test(*x), + Console(c, x) => Console(MaybeOwned::Borrowed(c), x), + Link(c) => Link(MaybeOwned::Borrowed(c)), } } } @@ -102,8 +111,6 @@ impl<'a, T> BackendType<'a, T> { match self { Console(c, x) => Console(c, f(x)), Link(c) => Link(c), - #[cfg(test)] - Test(x) => Test(x), } } } @@ -116,8 +123,6 @@ impl<'a, T, E> BackendType<'a, Result> { match self { Console(c, x) => x.map(|x| Console(c, x)), Link(c) => Ok(Link(c)), - #[cfg(test)] - Test(x) => Ok(Test(x)), } } } @@ -147,7 +152,7 @@ impl ComputeUserInfo { } pub enum ComputeCredentialKeys { - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Password(Vec), AuthKeys(AuthKeys), } @@ -277,42 +282,6 @@ async fn authenticate_with_secret( classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await } -/// wake a compute (or retrieve an existing compute session from cache) -async fn wake_compute( - ctx: &mut RequestMonitoring, - api: &impl console::Api, - compute_credentials: ComputeCredentials, -) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> { - let mut num_retries = 0; - let mut node = loop { - let wake_res = api.wake_compute(ctx, &compute_credentials.info).await; - match handle_try_wake(wake_res, num_retries) { - Err(e) => { - error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); - return Err(e.into()); - } - Ok(ControlFlow::Continue(e)) => { - warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); - } - Ok(ControlFlow::Break(n)) => break n, - } - - let wait_duration = retry_after(num_retries); - num_retries += 1; - tokio::time::sleep(wait_duration).await; - }; - - ctx.set_project(node.aux.clone()); - - match compute_credentials.keys { - #[cfg(feature = "testing")] - ComputeCredentialKeys::Password(password) => node.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), - }; - - Ok((node, compute_credentials.info)) -} - impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { /// Get compute endpoint name from the credentials. pub fn get_endpoint(&self) -> Option { @@ -321,8 +290,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => user_info.endpoint_id.clone(), Link(_) => Some("link".into()), - #[cfg(test)] - Test(_) => Some("test".into()), } } @@ -333,8 +300,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => &user_info.user, Link(_) => "link", - #[cfg(test)] - Test(_) => "test", } } @@ -359,8 +324,20 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { let compute_credentials = auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?; - let (cache_info, user_info) = wake_compute(ctx, &*api, compute_credentials).await?; - (cache_info, BackendType::Console(api, user_info)) + + let mut num_retries = 0; + let mut node = + wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?; + + ctx.set_project(node.aux.clone()); + + match compute_credentials.keys { + #[cfg(any(test, feature = "testing"))] + ComputeCredentialKeys::Password(password) => node.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), + }; + + (node, BackendType::Console(api, compute_credentials.info)) } // NOTE: this auth backend doesn't use client credentials. Link(url) => { @@ -373,10 +350,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { BackendType::Link(url), ) } - #[cfg(test)] - Test(_) => { - unreachable!("this function should never be called in the test backend") - } }; info!("user successfully authenticated"); @@ -393,8 +366,6 @@ impl BackendType<'_, ComputeUserInfo> { match self { Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), - #[cfg(test)] - Test(x) => x.get_allowed_ips_and_secret(), } } @@ -409,8 +380,6 @@ impl BackendType<'_, ComputeUserInfo> { match self { Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, Link(_) => Ok(None), - #[cfg(test)] - Test(x) => x.wake_compute().map(Some), } } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 358b335b88..384063ceae 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -20,7 +20,7 @@ pub(super) async fn authenticate( ) -> auth::Result> { let flow = AuthFlow::new(client); let scram_keys = match secret { - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { info!("auth endpoint chooses MD5"); return Err(auth::AuthError::bad_auth_method("MD5")); diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 3151a77263..077178d107 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -172,7 +172,7 @@ pub(super) fn validate_password_and_exchange( secret: AuthSecret, ) -> super::Result> { match secret { - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { // test only Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password( diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 3960b080be..3bbb87808d 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,5 +1,6 @@ use futures::future::Either; use proxy::auth; +use proxy::auth::backend::MaybeOwned; use proxy::config::AuthenticationConfig; use proxy::config::CacheOptions; use proxy::config::HttpConfig; @@ -17,9 +18,9 @@ use proxy::usage_metrics; use anyhow::bail; use proxy::config::{self, ProxyConfig}; use proxy::serverless; +use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; -use std::{borrow::Cow, net::SocketAddr}; use tokio::net::TcpListener; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -259,18 +260,13 @@ async fn main() -> anyhow::Result<()> { } if let auth::BackendType::Console(api, _) = &config.auth_backend { - match &**api { - proxy::console::provider::ConsoleBackend::Console(api) => { - let cache = api.caches.project_info.clone(); - if let Some(url) = args.redis_notifications { - info!("Starting redis notifications listener ({url})"); - maintenance_tasks - .spawn(notifications::task_main(url.to_owned(), cache.clone())); - } - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); + if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { + let cache = api.caches.project_info.clone(); + if let Some(url) = args.redis_notifications { + info!("Starting redis notifications listener ({url})"); + maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone())); } - #[cfg(feature = "testing")] - proxy::console::provider::ConsoleBackend::Postgres(_) => {} + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } } @@ -369,18 +365,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let api = console::provider::neon::Api::new(endpoint, caches, locks); let api = console::provider::ConsoleBackend::Console(api); - auth::BackendType::Console(Cow::Owned(api), ()) + auth::BackendType::Console(MaybeOwned::Owned(api), ()) } #[cfg(feature = "testing")] AuthBackend::Postgres => { let url = args.auth_endpoint.parse()?; let api = console::provider::mock::Api::new(url); let api = console::provider::ConsoleBackend::Postgres(api); - auth::BackendType::Console(Cow::Owned(api), ()) + auth::BackendType::Console(MaybeOwned::Owned(api), ()) } AuthBackend::Link => { let url = args.uri.parse()?; - auth::BackendType::Link(Cow::Owned(url)) + auth::BackendType::Link(MaybeOwned::Owned(url)) } }; let http_config = HttpConfig { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index ff84db7738..c53d929470 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -1,4 +1,4 @@ -#[cfg(feature = "testing")] +#[cfg(any(test, feature = "testing"))] pub mod mock; pub mod neon; @@ -199,7 +199,7 @@ pub mod errors { /// Auth secret which is managed by the cloud. #[derive(Clone, Eq, PartialEq, Debug)] pub enum AuthSecret { - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] /// Md5 hash of user's password. Md5([u8; 16]), @@ -264,13 +264,16 @@ pub trait Api { ) -> Result; } -#[derive(Clone)] +#[non_exhaustive] pub enum ConsoleBackend { /// Current Cloud API (V2). Console(neon::Api), /// Local mock of Cloud API (V2). - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Postgres(mock::Api), + /// Internal testing + #[cfg(test)] + Test(Box), } #[async_trait] @@ -283,8 +286,10 @@ impl Api for ConsoleBackend { use ConsoleBackend::*; match self { Console(api) => api.get_role_secret(ctx, user_info).await, - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Postgres(api) => api.get_role_secret(ctx, user_info).await, + #[cfg(test)] + Test(_) => unreachable!("this function should never be called in the test backend"), } } @@ -296,8 +301,10 @@ impl Api for ConsoleBackend { use ConsoleBackend::*; match self { Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, + #[cfg(test)] + Test(api) => api.get_allowed_ips_and_secret(), } } @@ -310,8 +317,10 @@ impl Api for ConsoleBackend { match self { Console(api) => api.wake_compute(ctx, user_info).await, - #[cfg(feature = "testing")] + #[cfg(any(test, feature = "testing"))] Postgres(api) => api.wake_compute(ctx, user_info).await, + #[cfg(test)] + Test(api) => api.wake_compute(), } } } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index f22c6d2322..0785419790 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -19,7 +19,6 @@ use tokio::time::Instant; use tokio_postgres::config::SslMode; use tracing::{error, info, info_span, warn, Instrument}; -#[derive(Clone)] pub struct Api { endpoint: http::Endpoint, pub caches: &'static ApiCaches, diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 4aa1f3590d..b68fb26e42 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -5,6 +5,7 @@ pub mod connect_compute; pub mod handshake; pub mod passthrough; pub mod retry; +pub mod wake_compute; use crate::{ auth, diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 8bbe88aa51..58c59dba36 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,15 +1,16 @@ use crate::{ auth, compute::{self, PostgresConnection}, - console::{self, errors::WakeComputeError, Api}, + console::{self, errors::WakeComputeError}, context::RequestMonitoring, - metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES}, - proxy::retry::{retry_after, ShouldRetry}, + metrics::NUM_CONNECTION_FAILURES, + proxy::{ + retry::{retry_after, ShouldRetry}, + wake_compute::wake_compute, + }, }; use async_trait::async_trait; -use hyper::StatusCode; use pq_proto::StartupMessageParams; -use std::ops::ControlFlow; use tokio::time; use tracing::{error, info, warn}; @@ -88,39 +89,6 @@ impl ConnectMechanism for TcpMechanism<'_> { } } -fn report_error(e: &WakeComputeError, retry: bool) { - use crate::console::errors::ApiError; - let retry = bool_to_str(retry); - let kind = match e { - WakeComputeError::BadComputeAddress(_) => "bad_compute_address", - WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error", - WakeComputeError::ApiError(ApiError::Console { - status: StatusCode::LOCKED, - ref text, - }) if text.contains("written data quota exceeded") - || text.contains("the limit for current plan reached") => - { - "quota_exceeded" - } - WakeComputeError::ApiError(ApiError::Console { - status: StatusCode::LOCKED, - .. - }) => "api_console_locked", - WakeComputeError::ApiError(ApiError::Console { - status: StatusCode::BAD_REQUEST, - .. - }) => "api_console_bad_request", - WakeComputeError::ApiError(ApiError::Console { status, .. }) - if status.is_server_error() => - { - "api_console_other_server_error" - } - WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error", - WakeComputeError::TimeoutError => "timeout_error", - }; - NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc(); -} - /// Try to connect to the compute node, retrying if necessary. /// This function might update `node_info`, so we take it by `&mut`. #[tracing::instrument(skip_all)] @@ -137,7 +105,7 @@ where mechanism.update_connect_config(&mut node_info.config); // try once - let (config, err) = match mechanism + let err = match mechanism .connect_once(ctx, &node_info, CONNECT_TIMEOUT) .await { @@ -145,51 +113,27 @@ where ctx.latency_timer.success(); return Ok(res); } - Err(e) => { - error!(error = ?e, "could not connect to compute node"); - (invalidate_cache(node_info), e) - } + Err(e) => e, }; - ctx.latency_timer.cache_miss(); + error!(error = ?err, "could not connect to compute node"); let mut num_retries = 1; - // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node - info!("compute node's state has likely changed; requesting a wake-up"); - let node_info = loop { - let wake_res = match user_info { - auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await, - // nothing to do? - auth::BackendType::Link(_) => return Err(err.into()), - // test backend - #[cfg(test)] - auth::BackendType::Test(x) => x.wake_compute(), - }; + match user_info { + auth::BackendType::Console(api, info) => { + // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node + info!("compute node's state has likely changed; requesting a wake-up"); - match handle_try_wake(wake_res, num_retries) { - Err(e) => { - error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); - report_error(&e, false); - return Err(e.into()); - } - // failed to wake up but we can continue to retry - Ok(ControlFlow::Continue(e)) => { - report_error(&e, true); - warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); - } - // successfully woke up a compute node and can break the wakeup loop - Ok(ControlFlow::Break(mut node_info)) => { - node_info.config.reuse_password(&config); - mechanism.update_connect_config(&mut node_info.config); - break node_info; - } + ctx.latency_timer.cache_miss(); + let config = invalidate_cache(node_info); + node_info = wake_compute(&mut num_retries, ctx, api, info).await?; + + node_info.config.reuse_password(&config); + mechanism.update_connect_config(&mut node_info.config); } - - let wait_duration = retry_after(num_retries); - num_retries += 1; - - time::sleep(wait_duration).await; + // nothing to do? + auth::BackendType::Link(_) => {} }; // now that we have a new node, try connect to it repeatedly. @@ -221,23 +165,3 @@ where time::sleep(wait_duration).await; } } - -/// Attempts to wake up the compute node. -/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable -/// * Returns Ok(Break(node)) if the wakeup succeeded -/// * Returns Err(e) if there was an error -pub fn handle_try_wake( - result: Result, - num_retries: u32, -) -> Result, WakeComputeError> { - match result { - Err(err) => match &err { - WakeComputeError::ApiError(api) if api.should_retry(num_retries) => { - Ok(ControlFlow::Continue(err)) - } - _ => Err(err), - }, - // Ready to try again. - Ok(new) => Ok(ControlFlow::Break(new)), - } -} diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 1f57d343c4..2000774224 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -5,9 +5,9 @@ mod mitm; use super::connect_compute::ConnectMechanism; use super::retry::ShouldRetry; use super::*; -use crate::auth::backend::{ComputeUserInfo, TestBackend}; +use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend}; use crate::config::CertResolver; -use crate::console::provider::{CachedAllowedIps, CachedRoleSecret}; +use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; @@ -371,6 +371,7 @@ enum ConnectAction { Fail, } +#[derive(Clone)] struct TestConnectMechanism { counter: Arc>, sequence: Vec, @@ -490,9 +491,16 @@ fn helper_create_cached_node_info() -> CachedNodeInfo { fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> (CachedNodeInfo, auth::BackendType<'_, ComputeUserInfo>) { +) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) { let cache = helper_create_cached_node_info(); - let user_info = auth::BackendType::Test(mechanism); + let user_info = auth::BackendType::Console( + MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), + ComputeUserInfo { + endpoint: "endpoint".into(), + user: "user".into(), + options: NeonOptions::parse_options_raw(""), + }, + ); (cache, user_info) } diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs new file mode 100644 index 0000000000..925727bdab --- /dev/null +++ b/proxy/src/proxy/wake_compute.rs @@ -0,0 +1,95 @@ +use crate::auth::backend::ComputeUserInfo; +use crate::console::{ + errors::WakeComputeError, + provider::{CachedNodeInfo, ConsoleBackend}, + Api, +}; +use crate::context::RequestMonitoring; +use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES}; +use crate::proxy::retry::retry_after; +use hyper::StatusCode; +use std::ops::ControlFlow; +use tracing::{error, warn}; + +use super::retry::ShouldRetry; + +/// wake a compute (or retrieve an existing compute session from cache) +pub async fn wake_compute( + num_retries: &mut u32, + ctx: &mut RequestMonitoring, + api: &ConsoleBackend, + info: &ComputeUserInfo, +) -> Result { + loop { + let wake_res = api.wake_compute(ctx, info).await; + match handle_try_wake(wake_res, *num_retries) { + Err(e) => { + error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); + report_error(&e, false); + return Err(e); + } + Ok(ControlFlow::Continue(e)) => { + warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); + report_error(&e, true); + } + Ok(ControlFlow::Break(n)) => return Ok(n), + } + + let wait_duration = retry_after(*num_retries); + *num_retries += 1; + tokio::time::sleep(wait_duration).await; + } +} + +/// Attempts to wake up the compute node. +/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable +/// * Returns Ok(Break(node)) if the wakeup succeeded +/// * Returns Err(e) if there was an error +pub fn handle_try_wake( + result: Result, + num_retries: u32, +) -> Result, WakeComputeError> { + match result { + Err(err) => match &err { + WakeComputeError::ApiError(api) if api.should_retry(num_retries) => { + Ok(ControlFlow::Continue(err)) + } + _ => Err(err), + }, + // Ready to try again. + Ok(new) => Ok(ControlFlow::Break(new)), + } +} + +fn report_error(e: &WakeComputeError, retry: bool) { + use crate::console::errors::ApiError; + let retry = bool_to_str(retry); + let kind = match e { + WakeComputeError::BadComputeAddress(_) => "bad_compute_address", + WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error", + WakeComputeError::ApiError(ApiError::Console { + status: StatusCode::LOCKED, + ref text, + }) if text.contains("written data quota exceeded") + || text.contains("the limit for current plan reached") => + { + "quota_exceeded" + } + WakeComputeError::ApiError(ApiError::Console { + status: StatusCode::LOCKED, + .. + }) => "api_console_locked", + WakeComputeError::ApiError(ApiError::Console { + status: StatusCode::BAD_REQUEST, + .. + }) => "api_console_bad_request", + WakeComputeError::ApiError(ApiError::Console { status, .. }) + if status.is_server_error() => + { + "api_console_other_server_error" + } + WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error", + WakeComputeError::TimeoutError => "timeout_error", + }; + NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc(); +} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a2eb7e62cc..7ff93b23b8 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -3,6 +3,7 @@ //! Handles both SQL over HTTP and SQL over Websockets. mod conn_pool; +mod json; mod sql_over_http; mod websocket; diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs new file mode 100644 index 0000000000..05835b23ce --- /dev/null +++ b/proxy/src/serverless/json.rs @@ -0,0 +1,448 @@ +use serde_json::Map; +use serde_json::Value; +use tokio_postgres::types::Kind; +use tokio_postgres::types::Type; +use tokio_postgres::Row; + +// +// Convert json non-string types to strings, so that they can be passed to Postgres +// as parameters. +// +pub fn json_to_pg_text(json: Vec) -> Vec> { + json.iter() + .map(|value| { + match value { + // special care for nulls + Value::Null => None, + + // convert to text with escaping + v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), + + // avoid escaping here, as we pass this as a parameter + Value::String(s) => Some(s.to_string()), + + // special care for arrays + Value::Array(_) => json_array_to_pg_array(value), + } + }) + .collect() +} + +// +// Serialize a JSON array to a Postgres array. Contrary to the strings in the params +// in the array we need to escape the strings. Postgres is okay with arrays of form +// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving +// it for Postgres to check. +// +// Example of the same escaping in node-postgres: packages/pg/lib/utils.js +// +fn json_array_to_pg_array(value: &Value) -> Option { + match value { + // special care for nulls + Value::Null => None, + + // convert to text with escaping + // here string needs to be escaped, as it is part of the array + v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()), + v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())), + + // recurse into array + Value::Array(arr) => { + let vals = arr + .iter() + .map(json_array_to_pg_array) + .map(|v| v.unwrap_or_else(|| "NULL".to_string())) + .collect::>() + .join(","); + + Some(format!("{{{}}}", vals)) + } + } +} + +// +// Convert postgres row with text-encoded values to JSON object +// +pub fn pg_text_row_to_json( + row: &Row, + columns: &[Type], + raw_output: bool, + array_mode: bool, +) -> Result { + let iter = row + .columns() + .iter() + .zip(columns) + .enumerate() + .map(|(i, (column, typ))| { + let name = column.name(); + let pg_value = row.as_text(i)?; + let json_value = if raw_output { + match pg_value { + Some(v) => Value::String(v.to_string()), + None => Value::Null, + } + } else { + pg_text_to_json(pg_value, typ)? + }; + Ok((name.to_string(), json_value)) + }); + + if array_mode { + // drop keys and aggregate into array + let arr = iter + .map(|r| r.map(|(_key, val)| val)) + .collect::, anyhow::Error>>()?; + Ok(Value::Array(arr)) + } else { + let obj = iter.collect::, anyhow::Error>>()?; + Ok(Value::Object(obj)) + } +} + +// +// Convert postgres text-encoded value to JSON value +// +fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result { + if let Some(val) = pg_value { + if let Kind::Array(elem_type) = pg_type.kind() { + return pg_array_parse(val, elem_type); + } + + match *pg_type { + Type::BOOL => Ok(Value::Bool(val == "t")), + Type::INT2 | Type::INT4 => { + let val = val.parse::()?; + Ok(Value::Number(serde_json::Number::from(val))) + } + Type::FLOAT4 | Type::FLOAT8 => { + let fval = val.parse::()?; + let num = serde_json::Number::from_f64(fval); + if let Some(num) = num { + Ok(Value::Number(num)) + } else { + // Pass Nan, Inf, -Inf as strings + // JS JSON.stringify() does converts them to null, but we + // want to preserve them, so we pass them as strings + Ok(Value::String(val.to_string())) + } + } + Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?), + _ => Ok(Value::String(val.to_string())), + } + } else { + Ok(Value::Null) + } +} + +// +// Parse postgres array into JSON array. +// +// This is a bit involved because we need to handle nested arrays and quoted +// values. Unlike postgres we don't check that all nested arrays have the same +// dimensions, we just return them as is. +// +fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result { + _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v) +} + +fn _pg_array_parse( + pg_array: &str, + elem_type: &Type, + nested: bool, +) -> Result<(Value, usize), anyhow::Error> { + let mut pg_array_chr = pg_array.char_indices(); + let mut level = 0; + let mut quote = false; + let mut entries: Vec = Vec::new(); + let mut entry = String::new(); + + // skip bounds decoration + if let Some('[') = pg_array.chars().next() { + for (_, c) in pg_array_chr.by_ref() { + if c == '=' { + break; + } + } + } + + fn push_checked( + entry: &mut String, + entries: &mut Vec, + elem_type: &Type, + ) -> Result<(), anyhow::Error> { + if !entry.is_empty() { + // While in usual postgres response we get nulls as None and everything else + // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while + // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs + // here while we have quotation info and convert them to None. + if entry == "NULL" { + entries.push(pg_text_to_json(None, elem_type)?); + } else { + entries.push(pg_text_to_json(Some(entry), elem_type)?); + } + entry.clear(); + } + + Ok(()) + } + + while let Some((mut i, mut c)) = pg_array_chr.next() { + let mut escaped = false; + + if c == '\\' { + escaped = true; + (i, c) = pg_array_chr.next().unwrap(); + } + + match c { + '{' if !quote => { + level += 1; + if level > 1 { + let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?; + entries.push(res); + for _ in 0..off - 1 { + pg_array_chr.next(); + } + } + } + '}' if !quote => { + level -= 1; + if level == 0 { + push_checked(&mut entry, &mut entries, elem_type)?; + if nested { + return Ok((Value::Array(entries), i)); + } + } + } + '"' if !escaped => { + if quote { + // end of quoted string, so push it manually without any checks + // for emptiness or nulls + entries.push(pg_text_to_json(Some(&entry), elem_type)?); + entry.clear(); + } + quote = !quote; + } + ',' if !quote => { + push_checked(&mut entry, &mut entries, elem_type)?; + } + _ => { + entry.push(c); + } + } + } + + if level != 0 { + return Err(anyhow::anyhow!("unbalanced array")); + } + + Ok((Value::Array(entries), 0)) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_atomic_types_to_pg_params() { + let json = vec![Value::Bool(true), Value::Bool(false)]; + let pg_params = json_to_pg_text(json); + assert_eq!( + pg_params, + vec![Some("true".to_owned()), Some("false".to_owned())] + ); + + let json = vec![Value::Number(serde_json::Number::from(42))]; + let pg_params = json_to_pg_text(json); + assert_eq!(pg_params, vec![Some("42".to_owned())]); + + let json = vec![Value::String("foo\"".to_string())]; + let pg_params = json_to_pg_text(json); + assert_eq!(pg_params, vec![Some("foo\"".to_owned())]); + + let json = vec![Value::Null]; + let pg_params = json_to_pg_text(json); + assert_eq!(pg_params, vec![None]); + } + + #[test] + fn test_json_array_to_pg_array() { + // atoms and escaping + let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]"; + let json: Value = serde_json::from_str(json).unwrap(); + let pg_params = json_to_pg_text(vec![json]); + assert_eq!( + pg_params, + vec![Some( + "{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned() + )] + ); + + // nested arrays + let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]"; + let json: Value = serde_json::from_str(json).unwrap(); + let pg_params = json_to_pg_text(vec![json]); + assert_eq!( + pg_params, + vec![Some( + "{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned() + )] + ); + // array of objects + let json = r#"[{"foo": 1},{"bar": 2}]"#; + let json: Value = serde_json::from_str(json).unwrap(); + let pg_params = json_to_pg_text(vec![json]); + assert_eq!( + pg_params, + vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())] + ); + } + + #[test] + fn test_atomic_types_parse() { + assert_eq!( + pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(), + json!("foo") + ); + assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null)); + assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42)); + assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42)); + assert_eq!( + pg_text_to_json(Some("42"), &Type::INT8).unwrap(), + json!("42") + ); + assert_eq!( + pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(), + json!(42.42) + ); + assert_eq!( + pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(), + json!(42.42) + ); + assert_eq!( + pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(), + json!("NaN") + ); + assert_eq!( + pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(), + json!("Infinity") + ); + assert_eq!( + pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(), + json!("-Infinity") + ); + + let json: Value = + serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}") + .unwrap(); + assert_eq!( + pg_text_to_json( + Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#), + &Type::JSONB + ) + .unwrap(), + json + ); + } + + #[test] + fn test_pg_array_parse_text() { + fn pt(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::TEXT).unwrap() + } + assert_eq!( + pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#), + json!(["aa\"\\,a", "cha", "bbbb"]) + ); + assert_eq!( + pt(r#"{{"foo","bar"},{"bee","bop"}}"#), + json!([["foo", "bar"], ["bee", "bop"]]) + ); + assert_eq!( + pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#), + json!([[[["foo", null, "bop", "bup"]]]]) + ); + assert_eq!( + pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#), + json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]]) + ); + } + + #[test] + fn test_pg_array_parse_bool() { + fn pb(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::BOOL).unwrap() + } + assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true])); + assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]])); + assert_eq!( + pb(r#"{{t,f},{f,t}}"#), + json!([[true, false], [false, true]]) + ); + assert_eq!( + pb(r#"{{t,NULL},{NULL,f}}"#), + json!([[true, null], [null, false]]) + ); + } + + #[test] + fn test_pg_array_parse_numbers() { + fn pn(pg_arr: &str, ty: &Type) -> Value { + pg_array_parse(pg_arr, ty).unwrap() + } + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0])); + assert_eq!( + pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4), + json!([1.1, 2.2, 3.3]) + ); + assert_eq!( + pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8), + json!([1.1, 2.2, 3.3]) + ); + assert_eq!( + pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4), + json!(["NaN", "Infinity", "-Infinity"]) + ); + assert_eq!( + pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8), + json!(["NaN", "Infinity", "-Infinity"]) + ); + } + + #[test] + fn test_pg_array_with_decoration() { + fn p(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::INT2).unwrap() + } + assert_eq!( + p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#), + json!([[[1, 2, 3], [4, 5, 6]]]) + ); + } + + #[test] + fn test_pg_array_parse_json() { + fn pt(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::JSONB).unwrap() + } + assert_eq!(pt(r#"{"{}"}"#), json!([{}])); + assert_eq!( + pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#), + json!([{"foo": 1, "bar": 2}]) + ); + assert_eq!( + pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#), + json!([{"foo": 1}, {"bar": 2}]) + ); + assert_eq!( + pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#), + json!([[{"foo": 1}, {"bar": 2}]]) + ); + } +} diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 27c2134221..96bf39c915 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -12,16 +12,12 @@ use hyper::Response; use hyper::StatusCode; use hyper::{Body, HeaderMap, Request}; use serde_json::json; -use serde_json::Map; use serde_json::Value; use tokio_postgres::error::DbError; use tokio_postgres::error::ErrorPosition; -use tokio_postgres::types::Kind; -use tokio_postgres::types::Type; use tokio_postgres::GenericClient; use tokio_postgres::IsolationLevel; use tokio_postgres::ReadyForQueryStatus; -use tokio_postgres::Row; use tokio_postgres::Transaction; use tracing::error; use tracing::instrument; @@ -40,6 +36,7 @@ use crate::RoleName; use super::conn_pool::ConnInfo; use super::conn_pool::GlobalConnPool; +use super::json::{json_to_pg_text, pg_text_row_to_json}; use super::SERVERLESS_DRIVER_SNI; #[derive(serde::Deserialize)] @@ -72,62 +69,6 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); -// -// Convert json non-string types to strings, so that they can be passed to Postgres -// as parameters. -// -fn json_to_pg_text(json: Vec) -> Vec> { - json.iter() - .map(|value| { - match value { - // special care for nulls - Value::Null => None, - - // convert to text with escaping - v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), - - // avoid escaping here, as we pass this as a parameter - Value::String(s) => Some(s.to_string()), - - // special care for arrays - Value::Array(_) => json_array_to_pg_array(value), - } - }) - .collect() -} - -// -// Serialize a JSON array to a Postgres array. Contrary to the strings in the params -// in the array we need to escape the strings. Postgres is okay with arrays of form -// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving -// it for Postgres to check. -// -// Example of the same escaping in node-postgres: packages/pg/lib/utils.js -// -fn json_array_to_pg_array(value: &Value) -> Option { - match value { - // special care for nulls - Value::Null => None, - - // convert to text with escaping - // here string needs to be escaped, as it is part of the array - v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()), - v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())), - - // recurse into array - Value::Array(arr) => { - let vals = arr - .iter() - .map(json_array_to_pg_array) - .map(|v| v.unwrap_or_else(|| "NULL".to_string())) - .collect::>() - .join(","); - - Some(format!("{{{}}}", vals)) - } - } -} - fn get_conn_info( ctx: &mut RequestMonitoring, headers: &HeaderMap, @@ -611,389 +552,3 @@ async fn query_to_json( }), )) } - -// -// Convert postgres row with text-encoded values to JSON object -// -pub fn pg_text_row_to_json( - row: &Row, - columns: &[Type], - raw_output: bool, - array_mode: bool, -) -> Result { - let iter = row - .columns() - .iter() - .zip(columns) - .enumerate() - .map(|(i, (column, typ))| { - let name = column.name(); - let pg_value = row.as_text(i)?; - let json_value = if raw_output { - match pg_value { - Some(v) => Value::String(v.to_string()), - None => Value::Null, - } - } else { - pg_text_to_json(pg_value, typ)? - }; - Ok((name.to_string(), json_value)) - }); - - if array_mode { - // drop keys and aggregate into array - let arr = iter - .map(|r| r.map(|(_key, val)| val)) - .collect::, anyhow::Error>>()?; - Ok(Value::Array(arr)) - } else { - let obj = iter.collect::, anyhow::Error>>()?; - Ok(Value::Object(obj)) - } -} - -// -// Convert postgres text-encoded value to JSON value -// -pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result { - if let Some(val) = pg_value { - if let Kind::Array(elem_type) = pg_type.kind() { - return pg_array_parse(val, elem_type); - } - - match *pg_type { - Type::BOOL => Ok(Value::Bool(val == "t")), - Type::INT2 | Type::INT4 => { - let val = val.parse::()?; - Ok(Value::Number(serde_json::Number::from(val))) - } - Type::FLOAT4 | Type::FLOAT8 => { - let fval = val.parse::()?; - let num = serde_json::Number::from_f64(fval); - if let Some(num) = num { - Ok(Value::Number(num)) - } else { - // Pass Nan, Inf, -Inf as strings - // JS JSON.stringify() does converts them to null, but we - // want to preserve them, so we pass them as strings - Ok(Value::String(val.to_string())) - } - } - Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?), - _ => Ok(Value::String(val.to_string())), - } - } else { - Ok(Value::Null) - } -} - -// -// Parse postgres array into JSON array. -// -// This is a bit involved because we need to handle nested arrays and quoted -// values. Unlike postgres we don't check that all nested arrays have the same -// dimensions, we just return them as is. -// -fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result { - _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v) -} - -fn _pg_array_parse( - pg_array: &str, - elem_type: &Type, - nested: bool, -) -> Result<(Value, usize), anyhow::Error> { - let mut pg_array_chr = pg_array.char_indices(); - let mut level = 0; - let mut quote = false; - let mut entries: Vec = Vec::new(); - let mut entry = String::new(); - - // skip bounds decoration - if let Some('[') = pg_array.chars().next() { - for (_, c) in pg_array_chr.by_ref() { - if c == '=' { - break; - } - } - } - - fn push_checked( - entry: &mut String, - entries: &mut Vec, - elem_type: &Type, - ) -> Result<(), anyhow::Error> { - if !entry.is_empty() { - // While in usual postgres response we get nulls as None and everything else - // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while - // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs - // here while we have quotation info and convert them to None. - if entry == "NULL" { - entries.push(pg_text_to_json(None, elem_type)?); - } else { - entries.push(pg_text_to_json(Some(entry), elem_type)?); - } - entry.clear(); - } - - Ok(()) - } - - while let Some((mut i, mut c)) = pg_array_chr.next() { - let mut escaped = false; - - if c == '\\' { - escaped = true; - (i, c) = pg_array_chr.next().unwrap(); - } - - match c { - '{' if !quote => { - level += 1; - if level > 1 { - let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?; - entries.push(res); - for _ in 0..off - 1 { - pg_array_chr.next(); - } - } - } - '}' if !quote => { - level -= 1; - if level == 0 { - push_checked(&mut entry, &mut entries, elem_type)?; - if nested { - return Ok((Value::Array(entries), i)); - } - } - } - '"' if !escaped => { - if quote { - // end of quoted string, so push it manually without any checks - // for emptiness or nulls - entries.push(pg_text_to_json(Some(&entry), elem_type)?); - entry.clear(); - } - quote = !quote; - } - ',' if !quote => { - push_checked(&mut entry, &mut entries, elem_type)?; - } - _ => { - entry.push(c); - } - } - } - - if level != 0 { - return Err(anyhow::anyhow!("unbalanced array")); - } - - Ok((Value::Array(entries), 0)) -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_atomic_types_to_pg_params() { - let json = vec![Value::Bool(true), Value::Bool(false)]; - let pg_params = json_to_pg_text(json); - assert_eq!( - pg_params, - vec![Some("true".to_owned()), Some("false".to_owned())] - ); - - let json = vec![Value::Number(serde_json::Number::from(42))]; - let pg_params = json_to_pg_text(json); - assert_eq!(pg_params, vec![Some("42".to_owned())]); - - let json = vec![Value::String("foo\"".to_string())]; - let pg_params = json_to_pg_text(json); - assert_eq!(pg_params, vec![Some("foo\"".to_owned())]); - - let json = vec![Value::Null]; - let pg_params = json_to_pg_text(json); - assert_eq!(pg_params, vec![None]); - } - - #[test] - fn test_json_array_to_pg_array() { - // atoms and escaping - let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]"; - let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]); - assert_eq!( - pg_params, - vec![Some( - "{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned() - )] - ); - - // nested arrays - let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]"; - let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]); - assert_eq!( - pg_params, - vec![Some( - "{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned() - )] - ); - // array of objects - let json = r#"[{"foo": 1},{"bar": 2}]"#; - let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]); - assert_eq!( - pg_params, - vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())] - ); - } - - #[test] - fn test_atomic_types_parse() { - assert_eq!( - pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(), - json!("foo") - ); - assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null)); - assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42)); - assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42)); - assert_eq!( - pg_text_to_json(Some("42"), &Type::INT8).unwrap(), - json!("42") - ); - assert_eq!( - pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(), - json!(42.42) - ); - assert_eq!( - pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(), - json!(42.42) - ); - assert_eq!( - pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(), - json!("NaN") - ); - assert_eq!( - pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(), - json!("Infinity") - ); - assert_eq!( - pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(), - json!("-Infinity") - ); - - let json: Value = - serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}") - .unwrap(); - assert_eq!( - pg_text_to_json( - Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#), - &Type::JSONB - ) - .unwrap(), - json - ); - } - - #[test] - fn test_pg_array_parse_text() { - fn pt(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::TEXT).unwrap() - } - assert_eq!( - pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#), - json!(["aa\"\\,a", "cha", "bbbb"]) - ); - assert_eq!( - pt(r#"{{"foo","bar"},{"bee","bop"}}"#), - json!([["foo", "bar"], ["bee", "bop"]]) - ); - assert_eq!( - pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#), - json!([[[["foo", null, "bop", "bup"]]]]) - ); - assert_eq!( - pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#), - json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]]) - ); - } - - #[test] - fn test_pg_array_parse_bool() { - fn pb(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::BOOL).unwrap() - } - assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true])); - assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]])); - assert_eq!( - pb(r#"{{t,f},{f,t}}"#), - json!([[true, false], [false, true]]) - ); - assert_eq!( - pb(r#"{{t,NULL},{NULL,f}}"#), - json!([[true, null], [null, false]]) - ); - } - - #[test] - fn test_pg_array_parse_numbers() { - fn pn(pg_arr: &str, ty: &Type) -> Value { - pg_array_parse(pg_arr, ty).unwrap() - } - assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3])); - assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3])); - assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"])); - assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0])); - assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0])); - assert_eq!( - pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4), - json!([1.1, 2.2, 3.3]) - ); - assert_eq!( - pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8), - json!([1.1, 2.2, 3.3]) - ); - assert_eq!( - pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4), - json!(["NaN", "Infinity", "-Infinity"]) - ); - assert_eq!( - pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8), - json!(["NaN", "Infinity", "-Infinity"]) - ); - } - - #[test] - fn test_pg_array_with_decoration() { - fn p(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::INT2).unwrap() - } - assert_eq!( - p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#), - json!([[[1, 2, 3], [4, 5, 6]]]) - ); - } - #[test] - fn test_pg_array_parse_json() { - fn pt(pg_arr: &str) -> Value { - pg_array_parse(pg_arr, &Type::JSONB).unwrap() - } - assert_eq!(pt(r#"{"{}"}"#), json!([{}])); - assert_eq!( - pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#), - json!([{"foo": 1, "bar": 2}]) - ); - assert_eq!( - pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#), - json!([{"foo": 1}, {"bar": 2}]) - ); - assert_eq!( - pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#), - json!([[{"foo": 1}, {"bar": 2}]]) - ); - } -}