mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 21:42:56 +00:00
[proxy] Implement per-tenant traffic metrics
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::{BackendType, ConsoleReqExtra, DatabaseInfo};
|
||||
pub use backend::{BackendType, ConsoleReqExtra};
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::ClientCredentials;
|
||||
|
||||
@@ -12,7 +12,6 @@ use crate::{
|
||||
waiters::{self, Waiter, Waiters},
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Cow;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
@@ -36,45 +35,6 @@ pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), wait
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Compute node connection params provided by the cloud.
|
||||
/// Note how it implements serde traits, since we receive it over the wire.
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit personal and sensitive info.
|
||||
impl std::fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
fn from(db_info: DatabaseInfo) -> Self {
|
||||
let mut config = tokio_postgres::Config::new();
|
||||
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
/// Extra query params we'd like to pass to the console.
|
||||
pub struct ConsoleReqExtra<'a> {
|
||||
/// A unique identifier for a connection.
|
||||
@@ -158,54 +118,107 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A product of successful authentication.
|
||||
pub struct AuthSuccess<T> {
|
||||
/// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client?
|
||||
pub reported_auth_ok: bool,
|
||||
/// Something to be considered a positive result.
|
||||
pub value: T,
|
||||
}
|
||||
|
||||
impl<T> AuthSuccess<T> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`AuthSuccess<T>`] to [`AuthSuccess<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> AuthSuccess<R> {
|
||||
AuthSuccess {
|
||||
reported_auth_ok: self.reported_auth_ok,
|
||||
value: f(self.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
pub struct NodeInfo {
|
||||
/// Project from [`auth::ClientCredentials`].
|
||||
pub project: String,
|
||||
/// Compute node connection params.
|
||||
pub config: compute::ConnCfg,
|
||||
}
|
||||
|
||||
impl BackendType<'_, ClientCredentials<'_>> {
|
||||
/// Do something special if user didn't provide the `project` parameter.
|
||||
async fn try_password_hack(
|
||||
&mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<Option<AuthSuccess<NodeInfo>>> {
|
||||
use BackendType::*;
|
||||
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the project name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let fetch_magic_payload = async {
|
||||
warn!("project name not specified, resorting to the password hack auth flow");
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
info!(project = &payload.project, "received missing parameter");
|
||||
auth::Result::Ok(payload)
|
||||
};
|
||||
|
||||
// TODO: find a proper way to merge those very similar blocks.
|
||||
let (mut config, payload) = match self {
|
||||
Console(endpoint, creds) if creds.project.is_none() => {
|
||||
let payload = fetch_magic_payload.await?;
|
||||
|
||||
let mut creds = creds.as_ref();
|
||||
creds.project = Some(payload.project.as_str().into());
|
||||
let config = console::Api::new(endpoint, extra, &creds)
|
||||
.wake_compute()
|
||||
.await?;
|
||||
|
||||
(config, payload)
|
||||
}
|
||||
Postgres(endpoint, creds) if creds.project.is_none() => {
|
||||
let payload = fetch_magic_payload.await?;
|
||||
|
||||
let mut creds = creds.as_ref();
|
||||
creds.project = Some(payload.project.as_str().into());
|
||||
let config = postgres::Api::new(endpoint, &creds).wake_compute().await?;
|
||||
|
||||
(config, payload)
|
||||
}
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
config.password(payload.password);
|
||||
Ok(Some(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: NodeInfo {
|
||||
project: payload.project,
|
||||
config,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
pub async fn authenticate(
|
||||
mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> super::Result<compute::NodeInfo> {
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
use BackendType::*;
|
||||
|
||||
if let Console(_, creds) | Postgres(_, creds) = &mut self {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the project name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
if creds.project().is_none() {
|
||||
warn!("project name not specified, resorting to the password hack auth flow");
|
||||
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
// Finally we may finish the initialization of `creds`.
|
||||
// TODO: add missing type safety to ClientCredentials.
|
||||
info!(project = &payload.project, "received missing parameter");
|
||||
creds.project = Some(payload.project.into());
|
||||
|
||||
let mut config = match &self {
|
||||
Console(endpoint, creds) => {
|
||||
console::Api::new(endpoint, extra, creds)
|
||||
.wake_compute()
|
||||
.await?
|
||||
}
|
||||
Postgres(endpoint, creds) => {
|
||||
postgres::Api::new(endpoint, creds).wake_compute().await?
|
||||
}
|
||||
_ => unreachable!("see the patterns above"),
|
||||
};
|
||||
|
||||
// We should use a password from payload as well.
|
||||
config.password(payload.password);
|
||||
|
||||
info!("user successfully authenticated (using the password hack)");
|
||||
return Ok(compute::NodeInfo {
|
||||
reported_auth_ok: false,
|
||||
config,
|
||||
});
|
||||
}
|
||||
// Handle cases when `project` is missing in `creds`.
|
||||
// TODO: type safety: return `creds` with irrefutable `project`.
|
||||
if let Some(res) = self.try_password_hack(extra, client).await? {
|
||||
info!("user successfully authenticated (using the password hack)");
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
let res = match self {
|
||||
@@ -215,22 +228,34 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
project = creds.project(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
assert!(creds.project.is_some());
|
||||
console::Api::new(&endpoint, extra, &creds)
|
||||
.handle_user(client)
|
||||
.await
|
||||
.await?
|
||||
.map(|config| NodeInfo {
|
||||
project: creds.project.unwrap().into_owned(),
|
||||
config,
|
||||
})
|
||||
}
|
||||
Postgres(endpoint, creds) => {
|
||||
info!("performing mock authentication using a local postgres instance");
|
||||
|
||||
assert!(creds.project.is_some());
|
||||
postgres::Api::new(&endpoint, &creds)
|
||||
.handle_user(client)
|
||||
.await
|
||||
.await?
|
||||
.map(|config| NodeInfo {
|
||||
project: creds.project.unwrap().into_owned(),
|
||||
config,
|
||||
})
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link(url) => {
|
||||
info!("performing link authentication");
|
||||
link::handle_user(&url, client).await
|
||||
link::handle_user(&url, client).await?
|
||||
}
|
||||
}?;
|
||||
};
|
||||
|
||||
info!("user successfully authenticated");
|
||||
Ok(res)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
//! Cloud API V2.
|
||||
|
||||
use super::ConsoleReqExtra;
|
||||
use super::{AuthSuccess, ConsoleReqExtra};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute::{self, ComputeConnCfg},
|
||||
compute,
|
||||
error::{io_error, UserFacingError},
|
||||
http, scram,
|
||||
stream::PqStream,
|
||||
@@ -128,7 +128,7 @@ impl<'a> Api<'a> {
|
||||
pub(super) async fn handle_user(
|
||||
self,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
) -> auth::Result<AuthSuccess<compute::ConnCfg>> {
|
||||
handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ impl<'a> Api<'a> {
|
||||
}
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg, WakeComputeError> {
|
||||
pub(super) async fn wake_compute(&self) -> Result<compute::ConnCfg, WakeComputeError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
let req = self
|
||||
.endpoint
|
||||
@@ -195,7 +195,7 @@ impl<'a> Api<'a> {
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
let mut config = ComputeConnCfg::new();
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(host)
|
||||
.port(port)
|
||||
@@ -213,10 +213,10 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
|
||||
endpoint: &'a Endpoint,
|
||||
get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo,
|
||||
wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute,
|
||||
) -> auth::Result<compute::NodeInfo>
|
||||
) -> auth::Result<AuthSuccess<compute::ConnCfg>>
|
||||
where
|
||||
GetAuthInfo: Future<Output = Result<AuthInfo, GetAuthInfoError>>,
|
||||
WakeCompute: Future<Output = Result<ComputeConnCfg, WakeComputeError>>,
|
||||
WakeCompute: Future<Output = Result<compute::ConnCfg, WakeComputeError>>,
|
||||
{
|
||||
info!("fetching user's authentication info");
|
||||
let auth_info = get_auth_info(endpoint).await?;
|
||||
@@ -243,9 +243,9 @@ where
|
||||
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(keys));
|
||||
}
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
config,
|
||||
value: config,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use super::{AuthSuccess, NodeInfo};
|
||||
use crate::{auth, compute, error::UserFacingError, stream::PqStream, waiters};
|
||||
use pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
use thiserror::Error;
|
||||
@@ -49,7 +50,7 @@ pub fn new_psql_session_id() -> String {
|
||||
pub async fn handle_user(
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let span = info_span!("link", psql_session_id = &psql_session_id);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
@@ -71,8 +72,22 @@ pub async fn handle_user(
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: true,
|
||||
config: db_info.into(),
|
||||
value: NodeInfo {
|
||||
project: db_info.project,
|
||||
config,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
//! Local mock of Cloud API V2.
|
||||
|
||||
use super::{
|
||||
console::{self, AuthInfo, GetAuthInfoError, TransportError, WakeComputeError},
|
||||
AuthSuccess,
|
||||
};
|
||||
use crate::{
|
||||
auth::{
|
||||
self,
|
||||
backend::console::{self, AuthInfo, GetAuthInfoError, TransportError, WakeComputeError},
|
||||
ClientCredentials,
|
||||
},
|
||||
compute::{self, ComputeConnCfg},
|
||||
auth::{self, ClientCredentials},
|
||||
compute,
|
||||
error::io_error,
|
||||
scram,
|
||||
stream::PqStream,
|
||||
@@ -37,7 +37,7 @@ impl<'a> Api<'a> {
|
||||
pub(super) async fn handle_user(
|
||||
self,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
) -> auth::Result<AuthSuccess<compute::ConnCfg>> {
|
||||
// We reuse user handling logic from a production module.
|
||||
console::handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
|
||||
}
|
||||
@@ -82,8 +82,8 @@ impl<'a> Api<'a> {
|
||||
}
|
||||
|
||||
/// We don't need to wake anything locally, so we just return the connection info.
|
||||
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg, WakeComputeError> {
|
||||
let mut config = ComputeConnCfg::new();
|
||||
pub(super) async fn wake_compute(&self) -> Result<compute::ConnCfg, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
|
||||
@@ -36,11 +36,23 @@ pub struct ClientCredentials<'a> {
|
||||
}
|
||||
|
||||
impl ClientCredentials<'_> {
|
||||
#[inline]
|
||||
pub fn project(&self) -> Option<&str> {
|
||||
self.project.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ClientCredentials<'a> {
|
||||
#[inline]
|
||||
pub fn as_ref(&'a self) -> ClientCredentials<'a> {
|
||||
Self {
|
||||
user: self.user,
|
||||
dbname: self.dbname,
|
||||
project: self.project().map(Cow::Borrowed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ClientCredentials<'a> {
|
||||
pub fn parse(
|
||||
params: &'a StartupMessageParams,
|
||||
|
||||
@@ -40,17 +40,36 @@ impl UserFacingError for ConnectionError {
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
|
||||
|
||||
pub type ComputeConnCfg = tokio_postgres::Config;
|
||||
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `tokio_postgres` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
#[repr(transparent)]
|
||||
pub struct ConnCfg(pub tokio_postgres::Config);
|
||||
|
||||
/// Various compute node info for establishing connection etc.
|
||||
pub struct NodeInfo {
|
||||
/// Did we send [`pq_proto::BeMessage::AuthenticationOk`]?
|
||||
pub reported_auth_ok: bool,
|
||||
/// Compute node connection params.
|
||||
pub config: tokio_postgres::Config,
|
||||
impl ConnCfg {
|
||||
/// Construct a new connection config.
|
||||
pub fn new() -> Self {
|
||||
Self(tokio_postgres::Config::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
impl std::ops::Deref for ConnCfg {
|
||||
type Target = tokio_postgres::Config;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// For now, let's make it easier to setup the config.
|
||||
impl std::ops::DerefMut for ConnCfg {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
/// Establish a raw TCP connection to the compute node.
|
||||
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
|
||||
use tokio_postgres::config::Host;
|
||||
|
||||
@@ -68,8 +87,8 @@ impl NodeInfo {
|
||||
// because it has no means for extracting the underlying socket which we
|
||||
// require for our business.
|
||||
let mut connection_error = None;
|
||||
let ports = self.config.get_ports();
|
||||
let hosts = self.config.get_hosts();
|
||||
let ports = self.0.get_ports();
|
||||
let hosts = self.0.get_hosts();
|
||||
// the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
|
||||
if ports.len() > 1 && ports.len() != hosts.len() {
|
||||
return Err(io::Error::new(
|
||||
@@ -77,7 +96,7 @@ impl NodeInfo {
|
||||
format!(
|
||||
"couldn't connect: bad compute config, \
|
||||
ports and hosts entries' count does not match: {:?}",
|
||||
self.config
|
||||
self.0
|
||||
),
|
||||
));
|
||||
}
|
||||
@@ -103,7 +122,7 @@ impl NodeInfo {
|
||||
Err(connection_error.unwrap_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("couldn't connect: bad compute config: {:?}", self.config),
|
||||
format!("couldn't connect: bad compute config: {:?}", self.0),
|
||||
)
|
||||
}))
|
||||
}
|
||||
@@ -116,7 +135,7 @@ pub struct PostgresConnection {
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
impl ConnCfg {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(
|
||||
mut self,
|
||||
@@ -130,21 +149,21 @@ impl NodeInfo {
|
||||
.intersperse(" ") // TODO: use impl from std once it's stabilized
|
||||
.collect();
|
||||
|
||||
self.config.options(&options);
|
||||
self.0.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.config.application_name(app_name);
|
||||
self.0.application_name(app_name);
|
||||
}
|
||||
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.config.replication_mode(ReplicationMode::Physical);
|
||||
self.0.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
"database" => {
|
||||
self.config.replication_mode(ReplicationMode::Logical);
|
||||
self.0.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
@@ -160,7 +179,7 @@ impl NodeInfo {
|
||||
.map_err(|_| ConnectionError::FailedToConnectToCompute)?;
|
||||
|
||||
// TODO: establish a secure connection to the DB
|
||||
let (client, conn) = self.config.connect_raw(&mut stream, NoTls).await?;
|
||||
let (client, conn) = self.0.connect_raw(&mut stream, NoTls).await?;
|
||||
let version = conn
|
||||
.parameter("server_version")
|
||||
.ok_or(ConnectionError::FailedToFetchPgVersion)?
|
||||
|
||||
@@ -6,16 +6,11 @@ use std::{
|
||||
net::{TcpListener, TcpStream},
|
||||
thread,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, info_span};
|
||||
use utils::postgres_backend::{self, AuthType, PostgresBackend};
|
||||
|
||||
/// TODO: move all of that to auth-backend/link.rs when we ditch legacy-console backend
|
||||
|
||||
///
|
||||
/// Main proxy listener loop.
|
||||
///
|
||||
/// Listens for connections, and launches a new handler thread for each.
|
||||
///
|
||||
/// Console management API listener thread.
|
||||
/// It spawns console response handlers needed for the link auth.
|
||||
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("mgmt has shut down");
|
||||
@@ -24,6 +19,7 @@ pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
listener
|
||||
.set_nonblocking(false)
|
||||
.context("failed to set listener to blocking")?;
|
||||
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
|
||||
info!("accepted connection from {peer_addr}");
|
||||
@@ -31,9 +27,19 @@ pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
.set_nodelay(true)
|
||||
.context("failed to set client socket option")?;
|
||||
|
||||
// TODO: replace with async tasks.
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = handle_connection(socket) {
|
||||
error!("{err}");
|
||||
let tid = std::thread::current().id();
|
||||
let span = info_span!("mgmt", thread = format_args!("{tid:?}"));
|
||||
let _enter = span.enter();
|
||||
|
||||
info!("started a new console management API thread");
|
||||
scopeguard::defer! {
|
||||
info!("console management API thread is about to finish");
|
||||
}
|
||||
|
||||
if let Err(e) = handle_connection(socket) {
|
||||
error!("thread failed with an error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -44,44 +50,21 @@ fn handle_connection(socket: TcpStream) -> anyhow::Result<()> {
|
||||
pgbackend.run(&mut MgmtHandler)
|
||||
}
|
||||
|
||||
struct MgmtHandler;
|
||||
|
||||
/// Serialized examples:
|
||||
// {
|
||||
// "session_id": "71d6d03e6d93d99a",
|
||||
// "result": {
|
||||
// "Success": {
|
||||
// "host": "127.0.0.1",
|
||||
// "port": 5432,
|
||||
// "dbname": "stas",
|
||||
// "user": "stas",
|
||||
// "password": "mypass"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// {
|
||||
// "session_id": "71d6d03e6d93d99a",
|
||||
// "result": {
|
||||
// "Failure": "oops"
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // to test manually by sending a query to mgmt interface:
|
||||
// psql -h 127.0.0.1 -p 9999 -c '{"session_id":"4f10dde522e14739","result":{"Success":{"host":"127.0.0.1","port":5432,"dbname":"stas","user":"stas","password":"stas"}}}'
|
||||
#[derive(Deserialize)]
|
||||
/// Known as `kickResponse` in the console.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PsqlSessionResponse {
|
||||
session_id: String,
|
||||
result: PsqlSessionResult,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
enum PsqlSessionResult {
|
||||
Success(auth::DatabaseInfo),
|
||||
Success(DatabaseInfo),
|
||||
Failure(String),
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
pub type ComputeReady = Result<auth::DatabaseInfo, String>;
|
||||
pub type ComputeReady = Result<DatabaseInfo, String>;
|
||||
|
||||
impl PsqlSessionResult {
|
||||
fn into_compute_ready(self) -> ComputeReady {
|
||||
@@ -92,25 +75,51 @@ impl PsqlSessionResult {
|
||||
}
|
||||
}
|
||||
|
||||
impl postgres_backend::Handler for MgmtHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let res = try_process_query(pgb, query_string);
|
||||
// intercept and log error message
|
||||
if res.is_err() {
|
||||
error!("mgmt query failed: {res:?}");
|
||||
}
|
||||
res
|
||||
/// Compute node connection params provided by the console.
|
||||
/// This struct and its parents are mgmt API implementation
|
||||
/// detail and thus should remain in this module.
|
||||
// TODO: restore deserialization tests from git history.
|
||||
#[derive(Deserialize)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
/// Console always provides a password, but it might
|
||||
/// be inconvenient for debug with local PG instance.
|
||||
pub password: Option<String>,
|
||||
pub project: String,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
impl std::fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.field("dbname", &self.dbname)
|
||||
.field("user", &self.user)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::Result<()> {
|
||||
info!("got mgmt query [redacted]"); // Content contains password, don't print it
|
||||
// TODO: replace with an http-based protocol.
|
||||
struct MgmtHandler;
|
||||
impl postgres_backend::Handler for MgmtHandler {
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> {
|
||||
try_process_query(pgb, query).map_err(|e| {
|
||||
error!("failed to process response: {e:?}");
|
||||
e
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
|
||||
fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> {
|
||||
let resp: PsqlSessionResponse = serde_json::from_str(query)?;
|
||||
|
||||
let span = info_span!("event", session_id = resp.session_id);
|
||||
let _enter = span.enter();
|
||||
info!("got response: {:?}", resp.result);
|
||||
|
||||
match auth::backend::notify(&resp.session_id, resp.result.into_compute_ready()) {
|
||||
Ok(()) => {
|
||||
@@ -119,9 +128,50 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to deliver response to per-client task");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
// with password
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
"project": "hello_world",
|
||||
}))?;
|
||||
|
||||
// without password
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"project": "hello_world",
|
||||
}))?;
|
||||
|
||||
// new field (forward compatibility)
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"project": "hello_world",
|
||||
"N.E.W": "forward compatibility check",
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::config::{ProxyConfig, TlsConfig};
|
||||
use crate::stream::{MeasuredStream, PqStream, Stream};
|
||||
use anyhow::{bail, Context};
|
||||
use futures::TryFutureExt;
|
||||
use metrics::{register_int_counter, IntCounter};
|
||||
use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
|
||||
use once_cell::sync::Lazy;
|
||||
use pq_proto::{BeMessage as Be, *};
|
||||
use std::sync::Arc;
|
||||
@@ -30,10 +30,16 @@ static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounter> = Lazy::new(|| {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounter> = Lazy::new(|| {
|
||||
register_int_counter!(
|
||||
"proxy_io_bytes_total",
|
||||
"Number of bytes sent/received between any client and backend."
|
||||
static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||
register_int_counter_vec!(
|
||||
"proxy_io_bytes_per_client",
|
||||
"Number of bytes sent/received between client and backend.",
|
||||
&[
|
||||
// Received (rx) / sent (tx).
|
||||
"direction",
|
||||
// Proxy can keep calling it `project` internally.
|
||||
"endpoint_id"
|
||||
]
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
@@ -230,16 +236,17 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
application_name: params.get("application_name"),
|
||||
};
|
||||
|
||||
// Authenticate and connect to a compute node.
|
||||
let auth = creds
|
||||
.authenticate(&extra, &mut stream)
|
||||
.instrument(info_span!("auth"))
|
||||
.await;
|
||||
|
||||
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
|
||||
let reported_auth_ok = node.reported_auth_ok;
|
||||
let auth_result = async {
|
||||
// `&mut stream` doesn't let us merge those 2 lines.
|
||||
let res = creds.authenticate(&extra, &mut stream).await;
|
||||
async { res }.or_else(|e| stream.throw_error(e)).await
|
||||
}
|
||||
.instrument(info_span!("auth"))
|
||||
.await?;
|
||||
|
||||
let node = auth_result.value;
|
||||
let (db, cancel_closure) = node
|
||||
.config
|
||||
.connect(params)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
@@ -247,7 +254,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
let cancel_key_data = session.enable_query_cancellation(cancel_closure);
|
||||
|
||||
// Report authentication success if we haven't done this already.
|
||||
if !reported_auth_ok {
|
||||
// Note that we do this only (for the most part) after we've connected
|
||||
// to a compute (see above) which performs its own authentication.
|
||||
if !auth_result.reported_auth_ok {
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
@@ -261,17 +270,23 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
/// This function will be called for writes to either direction.
|
||||
fn inc_proxied(cnt: usize) {
|
||||
// Consider inventing something more sophisticated
|
||||
// if this ever becomes a bottleneck (cacheline bouncing).
|
||||
NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64);
|
||||
}
|
||||
// TODO: add more identifiers.
|
||||
let metric_id = node.project;
|
||||
|
||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx", &metric_id]);
|
||||
let mut client = MeasuredStream::new(stream.into_inner(), |cnt| {
|
||||
// Number of bytes we sent to the client (outbound).
|
||||
m_sent.inc_by(cnt as u64);
|
||||
});
|
||||
|
||||
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx", &metric_id]);
|
||||
let mut db = MeasuredStream::new(db.stream, |cnt| {
|
||||
// Number of bytes the client sent to the compute node (inbound).
|
||||
m_recv.inc_by(cnt as u64);
|
||||
});
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
let mut db = MeasuredStream::new(db.stream, inc_proxied);
|
||||
let mut client = MeasuredStream::new(stream.into_inner(), inc_proxied);
|
||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
|
||||
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user