# proxy refactor serverless (#4685)
## Problem
Our serverless backend was a bit jumbled. As a comment indicated, we
were handling SQL-over-HTTP in our `websocket.rs` file.

I've extracted the `sql_over_http` and `websocket` files from the
`http` module and moved them into a new module called `serverless`.
## Summary of changes
```sh
mkdir proxy/src/serverless
mv proxy/src/http/{conn_pool,sql_over_http,websocket}.rs proxy/src/serverless/
mv proxy/src/http/server.rs proxy/src/http/health_server.rs
mv proxy/src/metrics proxy/src/usage_metrics.rs
```
I have also extracted the hyper server and handler from `websocket.rs`
into `serverless.rs`.
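
The new `serverless.rs` entry point itself isn't shown in this diff. As a rough sketch only (the sub-module names are inferred from the `mv` commands above; everything else is an assumption), its wiring might look like:

```rust
// Hypothetical sketch of proxy/src/serverless.rs after the refactor — not the
// actual file from this PR. It only shows the sub-module declarations implied
// by the moved files; the extracted hyper server and request handler are omitted.
pub mod conn_pool;
pub mod sql_over_http;
pub mod websocket;
```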
### proxy/src/serverless/conn_pool.rs (new file, 616 lines)
```rust
use anyhow::Context;
use async_trait::async_trait;
use dashmap::DashMap;
use futures::future::poll_fn;
use parking_lot::RwLock;
use pbkdf2::{
    password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
    Params, Pbkdf2,
};
use pq_proto::StartupMessageParams;
use std::{collections::HashMap, sync::Arc};
use std::{
    fmt,
    task::{ready, Poll},
};
use std::{
    ops::Deref,
    sync::atomic::{self, AtomicUsize},
};
use tokio::time;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};

use crate::{
    auth, console,
    proxy::{LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER},
    usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
};
use crate::{compute, config};

use crate::proxy::ConnectMechanism;

use tracing::{error, warn, Span};
use tracing::{info, info_span, Instrument};

pub const APP_NAME: &str = "sql_over_http";
const MAX_CONNS_PER_ENDPOINT: usize = 20;

#[derive(Debug, Clone)]
pub struct ConnInfo {
    pub username: String,
    pub dbname: String,
    pub hostname: String,
    pub password: String,
}

impl ConnInfo {
    // hm, change to hasher to avoid cloning?
    pub fn db_and_user(&self) -> (String, String) {
        (self.dbname.clone(), self.username.clone())
    }
}

impl fmt::Display for ConnInfo {
    // use custom display to avoid logging password
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
    }
}

struct ConnPoolEntry {
    conn: ClientInner,
    _last_access: std::time::Instant,
}

// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub struct EndpointConnPool {
    pools: HashMap<(String, String), DbUserConnPool>,
    total_conns: usize,
}

/// 4096 is the number of rounds that SCRAM-SHA-256 recommends.
/// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway.
///
/// Still takes 1.4ms to hash on my hardware.
/// We don't want to ruin the latency improvements of using the pool by making password verification take too long
const PARAMS: Params = Params {
    rounds: 4096,
    output_length: 32,
};

#[derive(Default)]
pub struct DbUserConnPool {
    conns: Vec<ConnPoolEntry>,
    password_hash: Option<PasswordHashString>,
}

pub struct GlobalConnPool {
    // endpoint -> per-endpoint connection pool
    //
    // That should be a fairly contended map, so return reference to the per-endpoint
    // pool as early as possible and release the lock.
    global_pool: DashMap<String, Arc<RwLock<EndpointConnPool>>>,

    /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
    /// That seems like far too much effort, so we're using a relaxed increment counter instead.
    /// It's only used for diagnostics.
    global_pool_size: AtomicUsize,

    // Maximum number of connections per one endpoint.
    // Can mix different (dbname, username) connections.
    // When running out of free slots for a particular endpoint,
    // falls back to opening a new connection for each request.
    max_conns_per_endpoint: usize,

    proxy_config: &'static crate::config::ProxyConfig,

    // Using a lock to remove any race conditions.
    // Eg cleaning up connections while a new connection is returned
    closed: RwLock<bool>,
}

impl GlobalConnPool {
    pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
        Arc::new(Self {
            global_pool: DashMap::new(),
            global_pool_size: AtomicUsize::new(0),
            max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT,
            proxy_config: config,
            closed: RwLock::new(false),
        })
    }

    pub fn shutdown(&self) {
        *self.closed.write() = true;

        self.global_pool.retain(|_, endpoint_pool| {
            let mut pool = endpoint_pool.write();
            // by clearing this hashmap, we remove the slots that a connection can be returned to.
            // when returning, it drops the connection if the slot doesn't exist
            pool.pools.clear();
            pool.total_conns = 0;

            false
        });
    }

    pub async fn get(
        self: &Arc<Self>,
        conn_info: &ConnInfo,
        force_new: bool,
        session_id: uuid::Uuid,
    ) -> anyhow::Result<Client> {
        let mut client: Option<ClientInner> = None;
        let mut latency_timer = LatencyTimer::new("http");

        let pool = if force_new {
            None
        } else {
            Some((conn_info.clone(), self.clone()))
        };

        let mut hash_valid = false;
        if !force_new {
            let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
            let mut hash = None;

            // find a pool entry by (dbname, username) if exists
            {
                let pool = pool.read();
                if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
                    if !pool_entries.conns.is_empty() {
                        hash = pool_entries.password_hash.clone();
                    }
                }
            }

            // a connection exists in the pool, verify the password hash
            if let Some(hash) = hash {
                let pw = conn_info.password.clone();
                let validate = tokio::task::spawn_blocking(move || {
                    Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
                })
                .await?;

                // if the hash is invalid, don't error
                // we will continue with the regular connection flow
                if validate.is_ok() {
                    hash_valid = true;
                    let mut pool = pool.write();
                    if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
                        if let Some(entry) = pool_entries.conns.pop() {
                            client = Some(entry.conn);
                            pool.total_conns -= 1;
                        }
                    }
                }
            }
        }

        // ok return cached connection if found and establish a new one otherwise
        let new_client = if let Some(client) = client {
            if client.inner.is_closed() {
                let conn_id = uuid::Uuid::new_v4();
                info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one");
                connect_to_compute(
                    self.proxy_config,
                    conn_info,
                    conn_id,
                    session_id,
                    latency_timer,
                )
                .await
            } else {
                info!("pool: reusing connection '{conn_info}'");
                client.session.send(session_id)?;
                latency_timer.pool_hit();
                latency_timer.success();
                return Ok(Client {
                    conn_id: client.conn_id,
                    inner: Some(client),
                    span: Span::current(),
                    pool,
                });
            }
        } else {
            let conn_id = uuid::Uuid::new_v4();
            info!(%conn_id, "pool: opening a new connection '{conn_info}'");
            connect_to_compute(
                self.proxy_config,
                conn_info,
                conn_id,
                session_id,
                latency_timer,
            )
            .await
        };

        match &new_client {
            // clear the hash. it's no longer valid
            // TODO: update tokio-postgres fork to allow access to this error kind directly
            Err(err)
                if hash_valid && err.to_string().contains("password authentication failed") =>
            {
                let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
                let mut pool = pool.write();
                if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
                    entry.password_hash = None;
                }
            }
            // new password is valid and we should insert/update it
            Ok(_) if !force_new && !hash_valid => {
                let pw = conn_info.password.clone();
                let new_hash = tokio::task::spawn_blocking(move || {
                    let salt = SaltString::generate(rand::rngs::OsRng);
                    Pbkdf2
                        .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
                        .map(|s| s.serialize())
                })
                .await??;

                let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
                let mut pool = pool.write();
                pool.pools
                    .entry(conn_info.db_and_user())
                    .or_default()
                    .password_hash = Some(new_hash);
            }
            _ => {}
        }

        new_client.map(|inner| Client {
            conn_id: inner.conn_id,
            inner: Some(inner),
            span: Span::current(),
            pool,
        })
    }

    fn put(&self, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> {
        let conn_id = client.conn_id;

        // We want to hold this open while we return. This ensures that the pool can't close
        // while we are in the middle of returning the connection.
        let closed = self.closed.read();
        if *closed {
            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is closed");
            return Ok(());
        }

        if client.inner.is_closed() {
            info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
            return Ok(());
        }

        let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);

        // return connection to the pool
        let mut returned = false;
        let mut per_db_size = 0;
        let total_conns = {
            let mut pool = pool.write();

            if pool.total_conns < self.max_conns_per_endpoint {
                // we create this db-user entry in get, so it should not be None
                if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
                    pool_entries.conns.push(ConnPoolEntry {
                        conn: client,
                        _last_access: std::time::Instant::now(),
                    });

                    returned = true;
                    per_db_size = pool_entries.conns.len();

                    pool.total_conns += 1;
                }
            }

            pool.total_conns
        };

        // do logging outside of the mutex
        if returned {
            info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
        } else {
            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
        }

        Ok(())
    }

    fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc<RwLock<EndpointConnPool>> {
        // fast path
        if let Some(pool) = self.global_pool.get(endpoint) {
            return pool.clone();
        }

        // slow path
        let new_pool = Arc::new(RwLock::new(EndpointConnPool {
            pools: HashMap::new(),
            total_conns: 0,
        }));

        // find or create a pool for this endpoint
        let mut created = false;
        let pool = self
            .global_pool
            .entry(endpoint.clone())
            .or_insert_with(|| {
                created = true;
                new_pool
            })
            .clone();

        // log new global pool size
        if created {
            let global_pool_size = self
                .global_pool_size
                .fetch_add(1, atomic::Ordering::Relaxed)
                + 1;
            info!(
                "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
            );
        }

        pool
    }
}

struct TokioMechanism<'a> {
    conn_info: &'a ConnInfo,
    session_id: uuid::Uuid,
    conn_id: uuid::Uuid,
}

#[async_trait]
impl ConnectMechanism for TokioMechanism<'_> {
    type Connection = ClientInner;
    type ConnectError = tokio_postgres::Error;
    type Error = anyhow::Error;

    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
        connect_to_compute_once(
            node_info,
            self.conn_info,
            timeout,
            self.conn_id,
            self.session_id,
        )
        .await
    }

    fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}

// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare few structures
// that this code expects.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
    config: &config::ProxyConfig,
    conn_info: &ConnInfo,
    conn_id: uuid::Uuid,
    session_id: uuid::Uuid,
    latency_timer: LatencyTimer,
) -> anyhow::Result<ClientInner> {
    let tls = config.tls_config.as_ref();
    let common_names = tls.and_then(|tls| tls.common_names.clone());

    let credential_params = StartupMessageParams::new([
        ("user", &conn_info.username),
        ("database", &conn_info.dbname),
        ("application_name", APP_NAME),
    ]);

    let creds = config
        .auth_backend
        .as_ref()
        .map(|_| {
            auth::ClientCredentials::parse(
                &credential_params,
                Some(&conn_info.hostname),
                common_names,
            )
        })
        .transpose()?;
    let extra = console::ConsoleReqExtra {
        session_id: uuid::Uuid::new_v4(),
        application_name: Some(APP_NAME),
    };

    let node_info = creds
        .wake_compute(&extra)
        .await?
        .context("missing cache entry from wake_compute")?;

    crate::proxy::connect_to_compute(
        &TokioMechanism {
            conn_id,
            conn_info,
            session_id,
        },
        node_info,
        &extra,
        &creds,
        latency_timer,
    )
    .await
}

async fn connect_to_compute_once(
    node_info: &console::CachedNodeInfo,
    conn_info: &ConnInfo,
    timeout: time::Duration,
    conn_id: uuid::Uuid,
    mut session: uuid::Uuid,
) -> Result<ClientInner, tokio_postgres::Error> {
    let mut config = (*node_info.config).clone();

    let (client, mut connection) = config
        .user(&conn_info.username)
        .password(&conn_info.password)
        .dbname(&conn_info.dbname)
        .connect_timeout(timeout)
        .connect(tokio_postgres::NoTls)
        .await?;

    let (tx, mut rx) = tokio::sync::watch::channel(session);

    let span = info_span!(parent: None, "connection", %conn_id);
    span.in_scope(|| {
        info!(%conn_info, %session, "new connection");
    });
    let ids = Ids {
        endpoint_id: node_info.aux.endpoint_id.to_string(),
        branch_id: node_info.aux.branch_id.to_string(),
    };

    tokio::spawn(
        async move {
            NUM_DB_CONNECTIONS_OPENED_COUNTER.with_label_values(&["http"]).inc();
            scopeguard::defer! {
                NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
            }
            poll_fn(move |cx| {
                if matches!(rx.has_changed(), Ok(true)) {
                    session = *rx.borrow_and_update();
                    info!(%session, "changed session");
                }

                loop {
                    let message = ready!(connection.poll_message(cx));

                    match message {
                        Some(Ok(AsyncMessage::Notice(notice))) => {
                            info!(%session, "notice: {}", notice);
                        }
                        Some(Ok(AsyncMessage::Notification(notif))) => {
                            warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
                        }
                        Some(Ok(_)) => {
                            warn!(%session, "unknown message");
                        }
                        Some(Err(e)) => {
                            error!(%session, "connection error: {}", e);
                            return Poll::Ready(())
                        }
                        None => {
                            info!("connection closed");
                            return Poll::Ready(())
                        }
                    }
                }
            }).await
        }
        .instrument(span),
    );

    Ok(ClientInner {
        inner: client,
        session: tx,
        ids,
        conn_id,
    })
}

struct ClientInner {
    inner: tokio_postgres::Client,
    session: tokio::sync::watch::Sender<uuid::Uuid>,
    ids: Ids,
    conn_id: uuid::Uuid,
}

impl Client {
    pub fn metrics(&self) -> Arc<MetricCounter> {
        USAGE_METRICS.register(self.inner.as_ref().unwrap().ids.clone())
    }
}

pub struct Client {
    conn_id: uuid::Uuid,
    span: Span,
    inner: Option<ClientInner>,
    pool: Option<(ConnInfo, Arc<GlobalConnPool>)>,
}

pub struct Discard<'a> {
    conn_id: uuid::Uuid,
    pool: &'a mut Option<(ConnInfo, Arc<GlobalConnPool>)>,
}

impl Client {
    pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
        let Self {
            inner,
            pool,
            conn_id,
            span: _,
        } = self;
        (
            &mut inner
                .as_mut()
                .expect("client inner should not be removed")
                .inner,
            Discard {
                pool,
                conn_id: *conn_id,
            },
        )
    }

    pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
        self.inner().1.check_idle(status)
    }
    pub fn discard(&mut self) {
        self.inner().1.discard()
    }
}

impl Discard<'_> {
    pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
        if status != ReadyForQueryStatus::Idle {
            if let Some((conn_info, _)) = self.pool.take() {
                info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
            }
        }
    }
    pub fn discard(&mut self) {
        if let Some((conn_info, _)) = self.pool.take() {
            info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
        }
    }
}

impl Deref for Client {
    type Target = tokio_postgres::Client;

    fn deref(&self) -> &Self::Target {
        &self
            .inner
            .as_ref()
            .expect("client inner should not be removed")
            .inner
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        let client = self
            .inner
            .take()
            .expect("client inner should not be removed");
        if let Some((conn_info, conn_pool)) = self.pool.take() {
            let current_span = self.span.clone();
            // return connection to the pool
            tokio::task::spawn_blocking(move || {
                let _span = current_span.enter();
                let _ = conn_pool.put(&conn_info, client);
            });
        }
    }
}
```
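
For orientation, here is a minimal, hypothetical usage sketch of the pool above. The `get`/`Deref`/`Drop` behaviour is taken from the code; the surrounding function, query, and parameter choices are assumptions for illustration:

```rust
// Illustrative only: how sql_over_http-style code drives the pool.
// The function, query, and literal values are assumptions for this sketch.
async fn run_query(
    pool: std::sync::Arc<GlobalConnPool>,
    conn_info: ConnInfo,
    session_id: uuid::Uuid,
) -> anyhow::Result<()> {
    // `force_new = true` would bypass pooling (the HTTP handler passes `!allow_pool`).
    let mut client = pool.get(&conn_info, false, session_id).await?;

    // `Client` derefs to `tokio_postgres::Client`, so queries work directly.
    let row = client.query_one("SELECT 1::int4", &[]).await?;
    let one: i32 = row.get(0);
    assert_eq!(one, 1);

    // Dropping `client` returns the connection to the pool (see `impl Drop`),
    // unless `discard()` was called because the connection looked broken.
    drop(client);
    Ok(())
}
```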
### proxy/src/serverless/sql_over_http.rs (new file, 890 lines)
```rust
use std::sync::Arc;

use anyhow::bail;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Row;
use tokio_postgres::Transaction;
use tracing::error;
use tracing::instrument;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;

use crate::config::HttpConfig;
use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER};

use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;

#[derive(serde::Deserialize)]
struct QueryData {
    query: String,
    params: Vec<serde_json::Value>,
}

#[derive(serde::Deserialize)]
struct BatchQueryData {
    queries: Vec<QueryData>,
}

#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
    Single(QueryData),
    Batch(BatchQueryData),
}

const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB

static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");

static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");

//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
    json.iter()
        .map(|value| {
            match value {
                // special care for nulls
                Value::Null => None,

                // convert to text with escaping
                v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),

                // avoid escaping here, as we pass this as a parameter
                Value::String(s) => Some(s.to_string()),

                // special care for arrays
                Value::Array(_) => json_array_to_pg_array(value),
            }
        })
        .collect()
}

//
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
// in the array we need to escape the strings. Postgres is okay with arrays of form
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
// it for Postgres to check.
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
    match value {
        // special care for nulls
        Value::Null => None,

        // convert to text with escaping
        // here string needs to be escaped, as it is part of the array
        v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
        v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),

        // recurse into array
        Value::Array(arr) => {
            let vals = arr
                .iter()
                .map(json_array_to_pg_array)
                .map(|v| v.unwrap_or_else(|| "NULL".to_string()))
                .collect::<Vec<_>>()
                .join(",");

            Some(format!("{{{}}}", vals))
        }
    }
}

fn get_conn_info(
    headers: &HeaderMap,
    sni_hostname: Option<String>,
) -> Result<ConnInfo, anyhow::Error> {
    let connection_string = headers
        .get("Neon-Connection-String")
        .ok_or(anyhow::anyhow!("missing connection string"))?
        .to_str()?;

    let connection_url = Url::parse(connection_string)?;

    let protocol = connection_url.scheme();
    if protocol != "postgres" && protocol != "postgresql" {
        return Err(anyhow::anyhow!(
            "connection string must start with postgres: or postgresql:"
        ));
    }

    let mut url_path = connection_url
        .path_segments()
        .ok_or(anyhow::anyhow!("missing database name"))?;

    let dbname = url_path
        .next()
        .ok_or(anyhow::anyhow!("invalid database name"))?;

    let username = connection_url.username();
    if username.is_empty() {
        return Err(anyhow::anyhow!("missing username"));
    }

    let password = connection_url
        .password()
        .ok_or(anyhow::anyhow!("no password"))?;

    // TLS certificate selector now based on SNI hostname, so if we are running here
    // we are sure that SNI hostname is set to one of the configured domain names.
    let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;

    let hostname = connection_url
        .host_str()
        .ok_or(anyhow::anyhow!("no host"))?;

    let host_header = headers
        .get("host")
        .and_then(|h| h.to_str().ok())
        .and_then(|h| h.split(':').next());

    if hostname != sni_hostname {
        return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
    } else if let Some(h) = host_header {
        if h != hostname {
            return Err(anyhow::anyhow!("mismatched host header and hostname"));
        }
    }

    Ok(ConnInfo {
        username: username.to_owned(),
        dbname: dbname.to_owned(),
        hostname: hostname.to_owned(),
        password: password.to_owned(),
    })
}

// TODO: return different http error codes
pub async fn handle(
    request: Request<Body>,
    sni_hostname: Option<String>,
    conn_pool: Arc<GlobalConnPool>,
    session_id: uuid::Uuid,
    config: &'static HttpConfig,
) -> Result<Response<Body>, ApiError> {
    let result = tokio::time::timeout(
        config.sql_over_http_timeout,
        handle_inner(request, sni_hostname, conn_pool, session_id),
    )
    .await;
    let mut response = match result {
        Ok(r) => match r {
            Ok(r) => r,
            Err(e) => {
                let message = format!("{:?}", e);
                let code = e.downcast_ref::<tokio_postgres::Error>().and_then(|e| {
                    e.code()
                        .map(|s| serde_json::to_value(s.code()).unwrap_or_default())
                });
                let code = match code {
                    Some(c) => c,
                    None => Value::Null,
                };
                error!(
                    ?code,
                    "sql-over-http per-client task finished with an error: {e:#}"
                );
                // TODO: this shouldn't always be bad request.
                json_response(
                    StatusCode::BAD_REQUEST,
                    json!({ "message": message, "code": code }),
                )?
            }
        },
        Err(_) => {
            let message = format!(
                "HTTP-Connection timed out, execution time exceeded {} seconds",
                config.sql_over_http_timeout.as_secs()
            );
            error!(message);
            json_response(
                StatusCode::GATEWAY_TIMEOUT,
                json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
            )?
        }
    };
    response.headers_mut().insert(
        "Access-Control-Allow-Origin",
        hyper::http::HeaderValue::from_static("*"),
    );
    Ok(response)
}

#[instrument(name = "sql-over-http", skip_all)]
async fn handle_inner(
    request: Request<Body>,
    sni_hostname: Option<String>,
    conn_pool: Arc<GlobalConnPool>,
    session_id: uuid::Uuid,
) -> anyhow::Result<Response<Body>> {
    NUM_CONNECTIONS_ACCEPTED_COUNTER
        .with_label_values(&["http"])
        .inc();
    scopeguard::defer! {
        NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
    }

    //
    // Determine the destination and connection params
    //
    let headers = request.headers();
    let conn_info = get_conn_info(headers, sni_hostname)?;

    // Determine the output options. Default behaviour is 'false'. Anything that is not
    // strictly 'true' assumed to be false.
    let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
    let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);

    // Allow connection pooling only if explicitly requested
    let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);

    // isolation level, read only and deferrable

    let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
    let txn_isolation_level = match txn_isolation_level_raw {
        Some(ref x) => Some(match x.as_bytes() {
            b"Serializable" => IsolationLevel::Serializable,
            b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
            b"ReadCommitted" => IsolationLevel::ReadCommitted,
            b"RepeatableRead" => IsolationLevel::RepeatableRead,
            _ => bail!("invalid isolation level"),
        }),
        None => None,
    };

    let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
    let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);

    let request_content_length = match request.body().size_hint().upper() {
        Some(v) => v,
        None => MAX_REQUEST_SIZE + 1,
    };

    // we don't have a streaming request support yet so this is to prevent OOM
    // from a malicious user sending an extremely large request body
    if request_content_length > MAX_REQUEST_SIZE {
        return Err(anyhow::anyhow!(
            "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
        ));
    }

    //
    // Read the query and query params from the request body
    //
    let body = hyper::body::to_bytes(request.into_body()).await?;
    let payload: Payload = serde_json::from_slice(&body)?;

    let mut client = conn_pool.get(&conn_info, !allow_pool, session_id).await?;

    let mut response = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "application/json");

    //
    // Now execute the query and return the result
    //
    let mut size = 0;
    let result =
        match payload {
            Payload::Single(stmt) => {
                let (status, results) =
                    query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
                        .await
                        .map_err(|e| {
                            client.discard();
                            e
                        })?;
                client.check_idle(status);
                results
            }
            Payload::Batch(statements) => {
                let (inner, mut discard) = client.inner();
                let mut builder = inner.build_transaction();
                if let Some(isolation_level) = txn_isolation_level {
                    builder = builder.isolation_level(isolation_level);
                }
                if txn_read_only {
                    builder = builder.read_only(true);
                }
                if txn_deferrable {
                    builder = builder.deferrable(true);
                }

                let transaction = builder.start().await.map_err(|e| {
                    // if we cannot start a transaction, we should return immediately
                    // and not return to the pool. connection is clearly broken
                    discard.discard();
                    e
                })?;

                let results =
                    match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
                        .await
                    {
                        Ok(results) => {
                            let status = transaction.commit().await.map_err(|e| {
                                // if we cannot commit - for now don't return connection to pool
                                // TODO: get a query status from the error
                                discard.discard();
                                e
                            })?;
                            discard.check_idle(status);
                            results
                        }
                        Err(err) => {
                            let status = transaction.rollback().await.map_err(|e| {
                                // if we cannot rollback - for now don't return connection to pool
                                // TODO: get a query status from the error
                                discard.discard();
                                e
                            })?;
                            discard.check_idle(status);
                            return Err(err);
                        }
                    };

                if txn_read_only {
                    response = response.header(
                        TXN_READ_ONLY.clone(),
                        HeaderValue::try_from(txn_read_only.to_string())?,
                    );
                }
                if txn_deferrable {
                    response = response.header(
                        TXN_DEFERRABLE.clone(),
                        HeaderValue::try_from(txn_deferrable.to_string())?,
                    );
                }
                if let Some(txn_isolation_level) = txn_isolation_level_raw {
                    response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
                }
                json!({ "results": results })
            }
        };

    let metrics = client.metrics();

    // how could this possibly fail
    let body = serde_json::to_string(&result).expect("json serialization should not fail");
    let len = body.len();
    let response = response
        .body(Body::from(body))
        // only fails if invalid status code or invalid header/values are given.
        // these are not user configurable so it cannot fail dynamically
        .expect("building response payload should not fail");

    // count the egress bytes - we miss the TLS and header overhead but oh well...
    // moving this later in the stack is going to be a lot of effort and ehhhh
    metrics.record_egress(len as u64);

    Ok(response)
}

async fn query_batch(
    transaction: &Transaction<'_>,
    queries: BatchQueryData,
    total_size: &mut usize,
    raw_output: bool,
    array_mode: bool,
) -> anyhow::Result<Vec<Value>> {
    let mut results = Vec::with_capacity(queries.queries.len());
    let mut current_size = 0;
    for stmt in queries.queries {
        // TODO: maybe we should check that the transaction bit is set here
        let (_, values) =
            query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
        results.push(values);
    }
    *total_size += current_size;
    Ok(results)
}

async fn query_to_json<T: GenericClient>(
    client: &T,
    data: QueryData,
    current_size: &mut usize,
    raw_output: bool,
    array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
    let query_params = json_to_pg_text(data.params);
    let row_stream = client.query_raw_txt(&data.query, query_params).await?;

    // Manually drain the stream into a vector to leave row_stream hanging
    // around to get a command tag. Also check that the response is not too
    // big.
    pin_mut!(row_stream);
    let mut rows: Vec<tokio_postgres::Row> = Vec::new();
    while let Some(row) = row_stream.next().await {
        let row = row?;
        *current_size += row.body_len();
        rows.push(row);
        // we don't have a streaming response support yet so this is to prevent OOM
        // from a malicious query (eg a cross join)
        if *current_size > MAX_RESPONSE_SIZE {
            return Err(anyhow::anyhow!(
                "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
            ));
        }
    }

    let ready = row_stream.ready_status();

    // grab the command tag and number of rows affected
    let command_tag = row_stream.command_tag().unwrap_or_default();
    let mut command_tag_split = command_tag.split(' ');
    let command_tag_name = command_tag_split.next().unwrap_or_default();
    let command_tag_count = if command_tag_name == "INSERT" {
        // INSERT returns OID first and then number of rows
        command_tag_split.nth(1)
    } else {
        // other commands return number of rows (if any)
        command_tag_split.next()
    }
    .and_then(|s| s.parse::<i64>().ok());

    let fields = if !rows.is_empty() {
        rows[0]
            .columns()
            .iter()
            .map(|c| {
                json!({
                    "name": Value::String(c.name().to_owned()),
                    "dataTypeID": Value::Number(c.type_().oid().into()),
                    "tableID": c.table_oid(),
                    "columnID": c.column_id(),
                    "dataTypeSize": c.type_size(),
                    "dataTypeModifier": c.type_modifier(),
                    "format": "text",
                })
            })
            .collect::<Vec<_>>()
    } else {
        Vec::new()
    };

    // convert rows to JSON
    let rows = rows
        .iter()
        .map(|row| pg_text_row_to_json(row, raw_output, array_mode))
        .collect::<Result<Vec<_>, _>>()?;

    // resulting JSON format is based on the format of node-postgres result
    Ok((
        ready,
        json!({
            "command": command_tag_name,
            "rowCount": command_tag_count,
            "rows": rows,
            "fields": fields,
            "rowAsArray": array_mode,
        }),
    ))
}

//
// Convert postgres row with text-encoded values to JSON object
//
pub fn pg_text_row_to_json(
    row: &Row,
    raw_output: bool,
    array_mode: bool,
) -> Result<Value, anyhow::Error> {
    let iter = row.columns().iter().enumerate().map(|(i, column)| {
        let name = column.name();
        let pg_value = row.as_text(i)?;
        let json_value = if raw_output {
            match pg_value {
                Some(v) => Value::String(v.to_string()),
                None => Value::Null,
            }
        } else {
            pg_text_to_json(pg_value, column.type_())?
        };
        Ok((name.to_string(), json_value))
    });

    if array_mode {
        // drop keys and aggregate into array
        let arr = iter
            .map(|r| r.map(|(_key, val)| val))
            .collect::<Result<Vec<Value>, anyhow::Error>>()?;
        Ok(Value::Array(arr))
    } else {
        let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
        Ok(Value::Object(obj))
    }
}

//
// Convert postgres text-encoded value to JSON value
//
pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
    if let Some(val) = pg_value {
        if let Kind::Array(elem_type) = pg_type.kind() {
            return pg_array_parse(val, elem_type);
        }

        match *pg_type {
            Type::BOOL => Ok(Value::Bool(val == "t")),
            Type::INT2 | Type::INT4 => {
                let val = val.parse::<i32>()?;
                Ok(Value::Number(serde_json::Number::from(val)))
            }
            Type::FLOAT4 | Type::FLOAT8 => {
                let fval = val.parse::<f64>()?;
                let num = serde_json::Number::from_f64(fval);
                if let Some(num) = num {
                    Ok(Value::Number(num))
                } else {
                    // Pass NaN, Inf, -Inf as strings
                    // JS JSON.stringify() converts them to null, but we
                    // want to preserve them, so we pass them as strings
                    Ok(Value::String(val.to_string()))
                }
            }
            Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
            _ => Ok(Value::String(val.to_string())),
        }
    } else {
        Ok(Value::Null)
    }
}

//
// Parse postgres array into JSON array.
//
// This is a bit involved because we need to handle nested arrays and quoted
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
    _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
}

fn _pg_array_parse(
    pg_array: &str,
    elem_type: &Type,
    nested: bool,
) -> Result<(Value, usize), anyhow::Error> {
    let mut pg_array_chr = pg_array.char_indices();
    let mut level = 0;
    let mut quote = false;
    let mut entries: Vec<Value> = Vec::new();
    let mut entry = String::new();

    // skip bounds decoration
    if let Some('[') = pg_array.chars().next() {
        for (_, c) in pg_array_chr.by_ref() {
            if c == '=' {
                break;
            }
        }
    }

    fn push_checked(
        entry: &mut String,
        entries: &mut Vec<Value>,
        elem_type: &Type,
    ) -> Result<(), anyhow::Error> {
        if !entry.is_empty() {
            // While in usual postgres response we get nulls as None and everything else
            // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
            // string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
            // here while we have quotation info and convert them to None.
            if entry == "NULL" {
                entries.push(pg_text_to_json(None, elem_type)?);
            } else {
                entries.push(pg_text_to_json(Some(entry), elem_type)?);
            }
            entry.clear();
        }

        Ok(())
    }

    while let Some((mut i, mut c)) = pg_array_chr.next() {
        let mut escaped = false;

        if c == '\\' {
            escaped = true;
            (i, c) = pg_array_chr.next().unwrap();
        }

        match c {
            '{' if !quote => {
                level += 1;
                if level > 1 {
                    let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?;
                    entries.push(res);
                    for _ in 0..off - 1 {
                        pg_array_chr.next();
                    }
                }
            }
            '}' if !quote => {
                level -= 1;
                if level == 0 {
                    push_checked(&mut entry, &mut entries, elem_type)?;
                    if nested {
                        return Ok((Value::Array(entries), i));
                    }
                }
            }
            '"' if !escaped => {
                if quote {
                    // end of quoted string, so push it manually without any checks
                    // for emptiness or nulls
                    entries.push(pg_text_to_json(Some(&entry), elem_type)?);
                    entry.clear();
                }
                quote = !quote;
            }
            ',' if !quote => {
                push_checked(&mut entry, &mut entries, elem_type)?;
            }
            _ => {
                entry.push(c);
            }
        }
    }

    if level != 0 {
        return Err(anyhow::anyhow!("unbalanced array"));
    }

    Ok((Value::Array(entries), 0))
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_atomic_types_to_pg_params() {
        let json = vec![Value::Bool(true), Value::Bool(false)];
        let pg_params = json_to_pg_text(json);
        assert_eq!(
            pg_params,
            vec![Some("true".to_owned()), Some("false".to_owned())]
        );

        let json = vec![Value::Number(serde_json::Number::from(42))];
        let pg_params = json_to_pg_text(json);
        assert_eq!(pg_params, vec![Some("42".to_owned())]);

        let json = vec![Value::String("foo\"".to_string())];
        let pg_params = json_to_pg_text(json);
        assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);

        let json = vec![Value::Null];
        let pg_params = json_to_pg_text(json);
        assert_eq!(pg_params, vec![None]);
    }

    #[test]
    fn test_json_array_to_pg_array() {
        // atoms and escaping
        let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
        let json: Value = serde_json::from_str(json).unwrap();
        let pg_params = json_to_pg_text(vec![json]);
        assert_eq!(
            pg_params,
            vec![Some(
                "{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned()
            )]
        );

        // nested arrays
        let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
        let json: Value = serde_json::from_str(json).unwrap();
        let pg_params = json_to_pg_text(vec![json]);
        assert_eq!(
            pg_params,
            vec![Some(
                "{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
            )]
        );
        // array of objects
        let json = r#"[{"foo": 1},{"bar": 2}]"#;
        let json: Value = serde_json::from_str(json).unwrap();
        let pg_params = json_to_pg_text(vec![json]);
        assert_eq!(
            pg_params,
            vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
        );
    }

    #[test]
    fn test_atomic_types_parse() {
        assert_eq!(
            pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
            json!("foo")
        );
        assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
        assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
        assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
        assert_eq!(
            pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
            json!("42")
        );
        assert_eq!(
            pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
            json!(42.42)
        );
        assert_eq!(
            pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
            json!(42.42)
        );
        assert_eq!(
            pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
            json!("NaN")
        );
        assert_eq!(
            pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
            json!("Infinity")
        );
        assert_eq!(
            pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
            json!("-Infinity")
        );

        let json: Value =
            serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}")
                .unwrap();
        assert_eq!(
            pg_text_to_json(
                Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
                &Type::JSONB
            )
            .unwrap(),
            json
        );
    }

    #[test]
    fn test_pg_array_parse_text() {
        fn pt(pg_arr: &str) -> Value {
            pg_array_parse(pg_arr, &Type::TEXT).unwrap()
        }
        assert_eq!(
            pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
            json!(["aa\"\\,a", "cha", "bbbb"])
        );
        assert_eq!(
            pt(r#"{{"foo","bar"},{"bee","bop"}}"#),
            json!([["foo", "bar"], ["bee", "bop"]])
        );
        assert_eq!(
            pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#),
            json!([[[["foo", null, "bop", "bup"]]]])
        );
        assert_eq!(
            pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#),
            json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]])
        );
    }

    #[test]
    fn test_pg_array_parse_bool() {
        fn pb(pg_arr: &str) -> Value {
            pg_array_parse(pg_arr, &Type::BOOL).unwrap()
        }
        assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
        assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
        assert_eq!(
            pb(r#"{{t,f},{f,t}}"#),
            json!([[true, false], [false, true]])
        );
        assert_eq!(
            pb(r#"{{t,NULL},{NULL,f}}"#),
            json!([[true, null], [null, false]])
        );
    }

    #[test]
    fn test_pg_array_parse_numbers() {
        fn pn(pg_arr: &str, ty: &Type) -> Value {
            pg_array_parse(pg_arr, ty).unwrap()
        }
        assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
        assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
        assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"]));
        assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0]));
        assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0]));
        assert_eq!(
            pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4),
            json!([1.1, 2.2, 3.3])
        );
        assert_eq!(
            pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8),
            json!([1.1, 2.2, 3.3])
        );
        assert_eq!(
            pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4),
            json!(["NaN", "Infinity", "-Infinity"])
        );
        assert_eq!(
            pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8),
            json!(["NaN", "Infinity", "-Infinity"])
        );
    }

    #[test]
    fn test_pg_array_with_decoration() {
        fn p(pg_arr: &str) -> Value {
            pg_array_parse(pg_arr, &Type::INT2).unwrap()
        }
        assert_eq!(
            p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
            json!([[[1, 2, 3], [4, 5, 6]]])
        );
    }
    #[test]
    fn test_pg_array_parse_json() {
        fn pt(pg_arr: &str) -> Value {
            pg_array_parse(pg_arr, &Type::JSONB).unwrap()
        }
        assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
        assert_eq!(
            pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#),
            json!([{"foo": 1, "bar": 2}])
        );
        assert_eq!(
            pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#),
            json!([{"foo": 1}, {"bar": 2}])
        );
        assert_eq!(
            pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#),
            json!([[{"foo": 1}, {"bar": 2}]])
        );
    }
}
```
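
To make the handler above concrete, here is a hedged sketch of a client request. The header names (`Neon-Connection-String`, `neon-raw-text-output`, `neon-array-mode`, `neon-pool-opt-in`) and the `query`/`params` body shape come from the code; the endpoint URL, connection string, and the use of `reqwest` are assumptions for illustration:

```rust
// Illustrative client for the SQL-over-HTTP handler above. The URL and
// connection string are hypothetical; header names and body shape match
// what handle_inner() expects.
use serde_json::json;

async fn send_query() -> anyhow::Result<()> {
    let client = reqwest::Client::new();
    let resp = client
        .post("https://ep-example-123456.us-east-2.aws.neon.tech/sql") // hypothetical URL
        .header(
            "Neon-Connection-String",
            "postgres://user:password@ep-example-123456.us-east-2.aws.neon.tech/neondb",
        )
        .header("neon-raw-text-output", "true") // return values as text strings
        .header("neon-array-mode", "true") // rows as arrays instead of objects
        .header("neon-pool-opt-in", "true") // opt in to the connection pool
        .json(&json!({
            "query": "SELECT $1::int4 AS one",
            "params": [1],
        }))
        .send()
        .await?;
    println!("{}", resp.text().await?);
    Ok(())
}
```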
### proxy/src/serverless/websocket.rs (new file, 146 lines)
```rust
use crate::{
    cancellation::CancelMap,
    config::ProxyConfig,
    error::io_error,
    proxy::{handle_client, ClientMode},
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use pin_project_lite::pin_project;

use std::{
    pin::Pin,
    task::{ready, Context, Poll},
};
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::warn;

// TODO: use `std::sync::Exclusive` once it's stabilized.
// Tracking issue: https://github.com/rust-lang/rust/issues/98407.
use sync_wrapper::SyncWrapper;

pin_project! {
    /// This is a wrapper around a [`WebSocketStream`] that
    /// implements [`AsyncRead`] and [`AsyncWrite`].
    pub struct WebSocketRw {
        #[pin]
        stream: SyncWrapper<WebSocketStream<Upgraded>>,
        bytes: Bytes,
    }
}

impl WebSocketRw {
    pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
        Self {
            stream: stream.into(),
            bytes: Bytes::new(),
        }
    }
}

impl AsyncWrite for WebSocketRw {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut stream = self.project().stream.get_pin_mut();

        ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
        match stream.as_mut().start_send(Message::Binary(buf.into())) {
            Ok(()) => Poll::Ready(Ok(buf.len())),
            Err(e) => Poll::Ready(Err(io_error(e))),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = self.project().stream.get_pin_mut();
        stream.poll_flush(cx).map_err(io_error)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = self.project().stream.get_pin_mut();
        stream.poll_close(cx).map_err(io_error)
    }
}

impl AsyncRead for WebSocketRw {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if buf.remaining() > 0 {
            let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
            let len = std::cmp::min(bytes.len(), buf.remaining());
            buf.put_slice(&bytes[..len]);
            self.consume(len);
        }

        Poll::Ready(Ok(()))
    }
}

impl AsyncBufRead for WebSocketRw {
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        // Please refer to poll_fill_buf's documentation.
        const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));

        let mut this = self.project();
        loop {
            if !this.bytes.chunk().is_empty() {
                let chunk = (*this.bytes).chunk();
                return Poll::Ready(Ok(chunk));
            }

            let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
            match res.transpose().map_err(io_error)? {
                Some(message) => match message {
                    Message::Ping(_) => {}
                    Message::Pong(_) => {}
                    Message::Text(text) => {
                        // We expect to see only binary messages.
                        let error = "unexpected text message in the websocket";
                        warn!(length = text.len(), error);
                        return Poll::Ready(Err(io_error(error)));
                    }
                    Message::Frame(_) => {
                        // This case is impossible according to Frame's doc.
                        panic!("unexpected raw frame in the websocket");
                    }
                    Message::Binary(chunk) => {
                        assert!(this.bytes.is_empty());
                        *this.bytes = Bytes::from(chunk);
                    }
                    Message::Close(_) => return EOF,
                },
                None => return EOF,
            }
        }
    }

    fn consume(self: Pin<&mut Self>, amount: usize) {
        self.project().bytes.advance(amount);
    }
}

pub async fn serve_websocket(
    websocket: HyperWebsocket,
    config: &'static ProxyConfig,
    cancel_map: &CancelMap,
    session_id: uuid::Uuid,
    hostname: Option<String>,
) -> anyhow::Result<()> {
    let websocket = websocket.await?;
    handle_client(
        config,
        cancel_map,
        session_id,
        WebSocketRw::new(websocket),
        ClientMode::Websockets { hostname },
    )
    .await?;
    Ok(())
}
```
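
The hyper server and handler that call `serve_websocket` were moved into `serverless.rs`, which is not part of the files shown here. As a rough sketch only (the `hyper_tungstenite::upgrade` call and all names other than `serve_websocket` are assumptions), the upgrade path might look like:

```rust
// Sketch only: the real wiring lives in serverless.rs (not shown in this diff).
// It assumes the hyper_tungstenite upgrade API; the function name and the
// &'static lifetimes are illustrative choices, not taken from the PR.
use hyper::{Body, Request, Response};

async fn ws_handler(
    mut request: Request<Body>,
    config: &'static ProxyConfig,
    cancel_map: &'static CancelMap,
    session_id: uuid::Uuid,
    hostname: Option<String>,
) -> anyhow::Result<Response<Body>> {
    // Split the request into the 101 Switching Protocols response and the
    // pending websocket handshake.
    let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)?;

    // Drive the postgres-over-websocket session in the background while the
    // upgrade response is returned to the client.
    tokio::spawn(async move {
        if let Err(e) = serve_websocket(websocket, config, cancel_map, session_id, hostname).await {
            tracing::error!("websocket session failed: {e:#}");
        }
    });

    Ok(response)
}
```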