mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-26 01:20:38 +00:00
[proxy] Implement compute node info cache (#3331)
This patch adds a timed LRU cache implementation and a compute node info cache on top of that. Cache entries might expire on their own (default ttl=5mins) or become invalid due to real-world events, e.g. compute node scale-to-zero event, so we add a connection retry loop with a wake-up call. Solved problems: - [x] Find a decent LRU implementation. - [x] Implement timed LRU on top of that. - [x] Cache results of `proxy_wake_compute` API call. - [x] Don't invalidate newer cache entries for the same key. - [x] Add cmdline configuration knobs (requires some refactoring). - [x] Add failed connection estab metric. - [x] Refactor auth backends to make things simpler (retries, cache placement, etc). - [x] Address review comments (add code comments + cleanup). - [x] Retry `/proxy_wake_compute` if we couldn't connect to a compute (e.g. stalled cache entry). - [x] Add high-level description for `TimedLru`. TODOs (will be addressed later): - [ ] Add cache metrics (hit, spurious hit, miss). - [ ] Synchronize http requests across concurrent per-client tasks (https://github.com/neondatabase/neon/pull/3331#issuecomment-1399216069). - [ ] Cache results of `proxy_get_role_secret` API call.
This commit is contained in:
@@ -63,13 +63,13 @@ impl KickSession<'_> {
|
||||
/// Compute node connection params.
|
||||
#[derive(Deserialize)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub host: Box<str>,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
pub dbname: Box<str>,
|
||||
pub user: Box<str>,
|
||||
/// Console always provides a password, but it might
|
||||
/// be inconvenient for debug with local PG instance.
|
||||
pub password: Option<String>,
|
||||
pub password: Option<Box<str>>,
|
||||
pub aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
|
||||
112
proxy/src/console/mgmt.rs
Normal file
112
proxy/src/console/mgmt.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use crate::{
|
||||
console::messages::{DatabaseInfo, KickSession},
|
||||
waiters::{self, Waiter, Waiters},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use once_cell::sync::Lazy;
|
||||
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
|
||||
use std::{
|
||||
net::{TcpListener, TcpStream},
|
||||
thread,
|
||||
};
|
||||
use tracing::{error, info, info_span};
|
||||
use utils::{
|
||||
postgres_backend::{self, AuthType, PostgresBackend},
|
||||
postgres_backend_async::QueryError,
|
||||
};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
|
||||
|
||||
/// Give caller an opportunity to wait for the cloud's reply.
|
||||
pub async fn with_waiter<R, T, E>(
|
||||
psql_session_id: impl Into<String>,
|
||||
action: impl FnOnce(Waiter<'static, ComputeReady>) -> R,
|
||||
) -> Result<T, E>
|
||||
where
|
||||
R: std::future::Future<Output = Result<T, E>>,
|
||||
E: From<waiters::RegisterError>,
|
||||
{
|
||||
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
||||
action(waiter).await
|
||||
}
|
||||
|
||||
pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Console management API listener thread.
|
||||
/// It spawns console response handlers needed for the link auth.
|
||||
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("mgmt has shut down");
|
||||
}
|
||||
|
||||
listener
|
||||
.set_nonblocking(false)
|
||||
.context("failed to set listener to blocking")?;
|
||||
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
|
||||
info!("accepted connection from {peer_addr}");
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set client socket option")?;
|
||||
|
||||
// TODO: replace with async tasks.
|
||||
thread::spawn(move || {
|
||||
let tid = std::thread::current().id();
|
||||
let span = info_span!("mgmt", thread = format_args!("{tid:?}"));
|
||||
let _enter = span.enter();
|
||||
|
||||
info!("started a new console management API thread");
|
||||
scopeguard::defer! {
|
||||
info!("console management API thread is about to finish");
|
||||
}
|
||||
|
||||
if let Err(e) = handle_connection(socket) {
|
||||
error!("thread failed with an error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
|
||||
pgbackend.run(&mut MgmtHandler)
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
pub type ComputeReady = Result<DatabaseInfo, String>;
|
||||
|
||||
// TODO: replace with an http-based protocol.
|
||||
struct MgmtHandler;
|
||||
impl postgres_backend::Handler for MgmtHandler {
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
|
||||
try_process_query(pgb, query).map_err(|e| {
|
||||
error!("failed to process response: {e:?}");
|
||||
e
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
|
||||
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
|
||||
|
||||
let span = info_span!("event", session_id = resp.session_id);
|
||||
let _enter = span.enter();
|
||||
info!("got response: {:?}", resp.result);
|
||||
|
||||
match notify(resp.session_id, Ok(resp.result)) {
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to deliver response to per-client task");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
194
proxy/src/console/provider.rs
Normal file
194
proxy/src/console/provider.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::ClientCredentials,
|
||||
cache::{timed_lru, TimedLru},
|
||||
compute, scram,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod errors {
|
||||
use crate::error::{io_error, UserFacingError};
|
||||
use reqwest::StatusCode as HttpStatusCode;
|
||||
use thiserror::Error;
|
||||
|
||||
/// A go-to error message which doesn't leak any detail.
|
||||
const REQUEST_FAILED: &str = "Console request failed";
|
||||
|
||||
/// Common console API error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
/// Error returned by the console itself.
|
||||
#[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
|
||||
Console {
|
||||
status: HttpStatusCode,
|
||||
text: Box<str>,
|
||||
},
|
||||
|
||||
/// Various IO errors like broken pipe or malformed payload.
|
||||
#[error("{REQUEST_FAILED}: {0}")]
|
||||
Transport(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Returns HTTP status code if it's the reason for failure.
|
||||
pub fn http_status_code(&self) -> Option<HttpStatusCode> {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
Console { status, .. } => Some(*status),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ApiError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
// To minimize risks, only select errors are forwarded to users.
|
||||
// Ask @neondatabase/control-plane for review before adding more.
|
||||
Console { status, .. } => match *status {
|
||||
HttpStatusCode::NOT_FOUND => {
|
||||
// Status 404: failed to get a project-related resource.
|
||||
format!("{REQUEST_FAILED}: endpoint cannot be found")
|
||||
}
|
||||
HttpStatusCode::NOT_ACCEPTABLE => {
|
||||
// Status 406: endpoint is disabled (we don't allow connections).
|
||||
format!("{REQUEST_FAILED}: endpoint is disabled")
|
||||
}
|
||||
HttpStatusCode::LOCKED => {
|
||||
// Status 423: project might be in maintenance mode (or bad state).
|
||||
format!("{REQUEST_FAILED}: endpoint is temporary unavailable")
|
||||
}
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
},
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
|
||||
impl From<reqwest::Error> for ApiError {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Console responded with a malformed auth secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use GetAuthInfoError::*;
|
||||
match self {
|
||||
// We absolutely should not leak any secrets!
|
||||
BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(Box<str>),
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for WakeComputeError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use WakeComputeError::*;
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extra query params we'd like to pass to the console.
|
||||
pub struct ConsoleReqExtra<'a> {
|
||||
/// A unique identifier for a connection.
|
||||
pub session_id: uuid::Uuid,
|
||||
/// Name of client application, if set.
|
||||
pub application_name: Option<&'a str>,
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthInfo {
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
pub struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub config: compute::ConnCfg,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: Arc<MetricsAuxInfo>,
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
|
||||
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
#[async_trait]
|
||||
pub trait Api {
|
||||
/// Get the client's auth secret for authentication.
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
/// Various caches for [`console`].
|
||||
pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub node_info: NodeInfoCache,
|
||||
}
|
||||
135
proxy/src/console/provider/mock.rs
Normal file
135
proxy/src/console/provider/mock.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
//! Mock console backend which relies on a user-provided postgres instance.
|
||||
|
||||
use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum MockApiError {
|
||||
#[error("Failed to read password: {0}")]
|
||||
PasswordNotSet(tokio_postgres::Error),
|
||||
}
|
||||
|
||||
impl From<MockApiError> for ApiError {
|
||||
fn from(e: MockApiError) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Error> for ApiError {
|
||||
fn from(e: tokio_postgres::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: ApiUrl,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
pub fn new(endpoint: ApiUrl) -> Self {
|
||||
Self { endpoint }
|
||||
}
|
||||
|
||||
pub fn url(&self) -> &str {
|
||||
self.endpoint.as_str()
|
||||
}
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
async {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
|
||||
let rows = client.query(query, &[&creds.user]).await?;
|
||||
|
||||
// We can get at most one row, because `rolname` is unique.
|
||||
let row = match rows.get(0) {
|
||||
Some(row) => row,
|
||||
// This means that the user doesn't exist, so there can be no secret.
|
||||
// However, this is still a *valid* outcome which is very similar
|
||||
// to getting `404 Not found` from the Neon console.
|
||||
None => {
|
||||
warn!("user '{}' does not exist", creds.user);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let entry = row
|
||||
.try_get("rolpassword")
|
||||
.map_err(MockApiError::PasswordNotSet)?;
|
||||
|
||||
info!("got a secret: {entry}"); // safe since it's not a prod scenario
|
||||
let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram);
|
||||
Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5)))
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("postgres", url = self.endpoint.as_str()))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.dbname(creds.dbname)
|
||||
.user(creds.user);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: Default::default(),
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
self.do_get_auth_info(creds).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute(creds)
|
||||
.map_ok(CachedNodeInfo::new_uncached)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_md5(input: &str) -> Option<[u8; 16]> {
|
||||
let text = input.strip_prefix("md5")?;
|
||||
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
|
||||
Some(bytes)
|
||||
}
|
||||
196
proxy/src/console/provider/neon.rs
Normal file
196
proxy/src/console/provider/neon.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
//! Production console backend.
|
||||
|
||||
use super::{
|
||||
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, http, scram};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use reqwest::StatusCode as HttpStatusCode;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub fn new(endpoint: http::Endpoint, caches: &'static ApiCaches) -> Self {
|
||||
Self { endpoint, caches }
|
||||
}
|
||||
|
||||
pub fn url(&self) -> &str {
|
||||
self.endpoint.url().as_str()
|
||||
}
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name),
|
||||
("project", Some(creds.project().expect("impossible"))),
|
||||
("role", Some(creds.user)),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
let body = match parse_body::<GetRoleSecret>(response).await {
|
||||
Ok(body) => body,
|
||||
// Error 404 is special: it's ok not to have a secret.
|
||||
Err(e) => match e.http_status_code() {
|
||||
Some(HttpStatusCode::NOT_FOUND) => return Ok(None),
|
||||
_otherwise => return Err(e.into()),
|
||||
},
|
||||
};
|
||||
|
||||
let secret = scram::ServerSecret::parse(&body.role_secret)
|
||||
.map(AuthInfo::Scram)
|
||||
.ok_or(GetAuthInfoError::BadSecret)?;
|
||||
|
||||
Ok(Some(secret))
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let project = creds.project().expect("impossible");
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name),
|
||||
("project", Some(project)),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
let body = parse_body::<WakeCompute>(response).await?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&body.address) {
|
||||
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(host)
|
||||
.port(port)
|
||||
.dbname(creds.dbname)
|
||||
.user(creds.user);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: body.aux.into(),
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
|
||||
self.do_get_auth_info(extra, creds).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key = creds.project().expect("impossible");
|
||||
|
||||
// Every time we do a wakeup http request, the compute node will stay up
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
if let Some(cached) = self.caches.node_info.get(key) {
|
||||
info!(key = key, "found cached compute node info");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let node = self.do_wake_compute(extra, creds).await?;
|
||||
let (_, cached) = self.caches.node_info.insert(key.into(), node);
|
||||
info!(key = key, "created a cache entry for compute node info");
|
||||
|
||||
Ok(cached)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
response: reqwest::Response,
|
||||
) -> Result<T, ApiError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
// We shouldn't log raw body because it may contain secrets.
|
||||
info!("request succeeded, processing the body");
|
||||
return Ok(response.json().await?);
|
||||
}
|
||||
|
||||
// Don't throw an error here because it's not as important
|
||||
// as the fact that the request itself has failed.
|
||||
let body = response.json().await.unwrap_or_else(|e| {
|
||||
warn!("failed to parse error body: {e}");
|
||||
ConsoleError {
|
||||
error: "reason unclear (malformed error message)".into(),
|
||||
}
|
||||
});
|
||||
|
||||
let text = body.error;
|
||||
error!("console responded with an error ({status}): {text}");
|
||||
Err(ApiError::Console { status, text })
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
let (host, port) = input.split_once(':')?;
|
||||
Some((host, port.parse().ok()?))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_host_port() {
|
||||
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
|
||||
assert_eq!(host, "127.0.0.1");
|
||||
assert_eq!(port, 5432);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user