mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-15 04:00:38 +00:00
Merge pull request #5060 from neondatabase/releases/2023-08-22
Release 2023-08-22
This commit is contained in:
@@ -13,6 +13,7 @@ bytes = { workspace = true, features = ["serde"] }
|
||||
chrono.workspace = true
|
||||
clap.workspace = true
|
||||
consumption_metrics.workspace = true
|
||||
dashmap.workspace = true
|
||||
futures.workspace = true
|
||||
git-version.workspace = true
|
||||
hashbrown.workspace = true
|
||||
@@ -29,7 +30,7 @@ metrics.workspace = true
|
||||
once_cell.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
parking_lot.workspace = true
|
||||
pbkdf2.workspace = true
|
||||
pbkdf2 = { workspace = true, features = ["simple", "std"] }
|
||||
pin-project-lite.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
pq_proto.workspace = true
|
||||
|
||||
@@ -36,7 +36,18 @@ pub(super) async fn authenticate(
|
||||
AuthInfo::Scram(secret) => {
|
||||
info!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret);
|
||||
let client_key = match flow.begin(scram).await?.authenticate().await? {
|
||||
|
||||
let auth_flow = flow.begin(scram).await.map_err(|error| {
|
||||
warn!(?error, "error sending scram acknowledgement");
|
||||
error
|
||||
})?;
|
||||
|
||||
let auth_outcome = auth_flow.authenticate().await.map_err(|error| {
|
||||
warn!(?error, "error processing scram messages");
|
||||
error
|
||||
})?;
|
||||
|
||||
let client_key = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
@@ -51,7 +62,6 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
};
|
||||
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
let mut num_retries = 0;
|
||||
let mut node = loop {
|
||||
let wake_res = api.wake_compute(extra, creds).await;
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::Mutex;
|
||||
use dashmap::DashMap;
|
||||
use futures::future::poll_fn;
|
||||
use parking_lot::RwLock;
|
||||
use pbkdf2::{
|
||||
password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Params, Pbkdf2,
|
||||
};
|
||||
use pq_proto::StartupMessageParams;
|
||||
use std::fmt;
|
||||
use std::sync::atomic::{self, AtomicUsize};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::{
|
||||
fmt,
|
||||
task::{ready, Poll},
|
||||
};
|
||||
use tokio::time;
|
||||
use tokio_postgres::AsyncMessage;
|
||||
|
||||
use crate::{auth, console};
|
||||
use crate::{compute, config};
|
||||
@@ -13,8 +24,8 @@ use super::sql_over_http::MAX_RESPONSE_SIZE;
|
||||
|
||||
use crate::proxy::ConnectMechanism;
|
||||
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::{error, warn};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
pub const APP_NAME: &str = "sql_over_http";
|
||||
const MAX_CONNS_PER_ENDPOINT: usize = 20;
|
||||
@@ -42,23 +53,44 @@ impl fmt::Display for ConnInfo {
|
||||
}
|
||||
|
||||
struct ConnPoolEntry {
|
||||
conn: tokio_postgres::Client,
|
||||
conn: Client,
|
||||
_last_access: std::time::Instant,
|
||||
}
|
||||
|
||||
// Per-endpoint connection pool, (dbname, username) -> Vec<ConnPoolEntry>
|
||||
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub struct EndpointConnPool {
|
||||
pools: HashMap<(String, String), Vec<ConnPoolEntry>>,
|
||||
pools: HashMap<(String, String), DbUserConnPool>,
|
||||
total_conns: usize,
|
||||
}
|
||||
|
||||
/// 4096 is the number of rounds that SCRAM-SHA-256 recommends.
|
||||
/// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway.
|
||||
///
|
||||
/// Still takes 1.4ms to hash on my hardware.
|
||||
/// We don't want to ruin the latency improvements of using the pool by making password verification take too long
|
||||
const PARAMS: Params = Params {
|
||||
rounds: 4096,
|
||||
output_length: 32,
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DbUserConnPool {
|
||||
conns: Vec<ConnPoolEntry>,
|
||||
password_hash: Option<PasswordHashString>,
|
||||
}
|
||||
|
||||
pub struct GlobalConnPool {
|
||||
// endpoint -> per-endpoint connection pool
|
||||
//
|
||||
// That should be a fairly conteded map, so return reference to the per-endpoint
|
||||
// pool as early as possible and release the lock.
|
||||
global_pool: Mutex<HashMap<String, Arc<Mutex<EndpointConnPool>>>>,
|
||||
global_pool: DashMap<String, Arc<RwLock<EndpointConnPool>>>,
|
||||
|
||||
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
|
||||
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
|
||||
/// It's only used for diagnostics.
|
||||
global_pool_size: AtomicUsize,
|
||||
|
||||
// Maximum number of connections per one endpoint.
|
||||
// Can mix different (dbname, username) connections.
|
||||
@@ -67,85 +99,173 @@ pub struct GlobalConnPool {
|
||||
max_conns_per_endpoint: usize,
|
||||
|
||||
proxy_config: &'static crate::config::ProxyConfig,
|
||||
|
||||
// Using a lock to remove any race conditions.
|
||||
// Eg cleaning up connections while a new connection is returned
|
||||
closed: RwLock<bool>,
|
||||
}
|
||||
|
||||
impl GlobalConnPool {
|
||||
pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
global_pool: Mutex::new(HashMap::new()),
|
||||
global_pool: DashMap::new(),
|
||||
global_pool_size: AtomicUsize::new(0),
|
||||
max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT,
|
||||
proxy_config: config,
|
||||
closed: RwLock::new(false),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
*self.closed.write() = true;
|
||||
|
||||
self.global_pool.retain(|_, endpoint_pool| {
|
||||
let mut pool = endpoint_pool.write();
|
||||
// by clearing this hashmap, we remove the slots that a connection can be returned to.
|
||||
// when returning, it drops the connection if the slot doesn't exist
|
||||
pool.pools.clear();
|
||||
pool.total_conns = 0;
|
||||
|
||||
false
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
&self,
|
||||
conn_info: &ConnInfo,
|
||||
force_new: bool,
|
||||
) -> anyhow::Result<tokio_postgres::Client> {
|
||||
let mut client: Option<tokio_postgres::Client> = None;
|
||||
session_id: uuid::Uuid,
|
||||
) -> anyhow::Result<Client> {
|
||||
let mut client: Option<Client> = None;
|
||||
|
||||
let mut hash_valid = false;
|
||||
if !force_new {
|
||||
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
|
||||
let mut hash = None;
|
||||
|
||||
// find a pool entry by (dbname, username) if exists
|
||||
let mut pool = pool.lock();
|
||||
let pool_entries = pool.pools.get_mut(&conn_info.db_and_user());
|
||||
if let Some(pool_entries) = pool_entries {
|
||||
if let Some(entry) = pool_entries.pop() {
|
||||
client = Some(entry.conn);
|
||||
pool.total_conns -= 1;
|
||||
{
|
||||
let pool = pool.read();
|
||||
if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) {
|
||||
if !pool_entries.conns.is_empty() {
|
||||
hash = pool_entries.password_hash.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// a connection exists in the pool, verify the password hash
|
||||
if let Some(hash) = hash {
|
||||
let pw = conn_info.password.clone();
|
||||
let validate = tokio::task::spawn_blocking(move || {
|
||||
Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash())
|
||||
})
|
||||
.await?;
|
||||
|
||||
// if the hash is invalid, don't error
|
||||
// we will continue with the regular connection flow
|
||||
if validate.is_ok() {
|
||||
hash_valid = true;
|
||||
let mut pool = pool.write();
|
||||
if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
|
||||
if let Some(entry) = pool_entries.conns.pop() {
|
||||
client = Some(entry.conn);
|
||||
pool.total_conns -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ok return cached connection if found and establish a new one otherwise
|
||||
if let Some(client) = client {
|
||||
if client.is_closed() {
|
||||
let new_client = if let Some(client) = client {
|
||||
if client.inner.is_closed() {
|
||||
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
connect_to_compute(self.proxy_config, conn_info).await
|
||||
connect_to_compute(self.proxy_config, conn_info, session_id).await
|
||||
} else {
|
||||
info!("pool: reusing connection '{conn_info}'");
|
||||
Ok(client)
|
||||
client.session.send(session_id)?;
|
||||
return Ok(client);
|
||||
}
|
||||
} else {
|
||||
info!("pool: opening a new connection '{conn_info}'");
|
||||
connect_to_compute(self.proxy_config, conn_info).await
|
||||
connect_to_compute(self.proxy_config, conn_info, session_id).await
|
||||
};
|
||||
|
||||
match &new_client {
|
||||
// clear the hash. it's no longer valid
|
||||
// TODO: update tokio-postgres fork to allow access to this error kind directly
|
||||
Err(err)
|
||||
if hash_valid && err.to_string().contains("password authentication failed") =>
|
||||
{
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
|
||||
let mut pool = pool.write();
|
||||
if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
|
||||
entry.password_hash = None;
|
||||
}
|
||||
}
|
||||
// new password is valid and we should insert/update it
|
||||
Ok(_) if !force_new && !hash_valid => {
|
||||
let pw = conn_info.password.clone();
|
||||
let new_hash = tokio::task::spawn_blocking(move || {
|
||||
let salt = SaltString::generate(rand::rngs::OsRng);
|
||||
Pbkdf2
|
||||
.hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt)
|
||||
.map(|s| s.serialize())
|
||||
})
|
||||
.await??;
|
||||
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
|
||||
let mut pool = pool.write();
|
||||
pool.pools
|
||||
.entry(conn_info.db_and_user())
|
||||
.or_default()
|
||||
.password_hash = Some(new_hash);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
new_client
|
||||
}
|
||||
|
||||
pub async fn put(
|
||||
&self,
|
||||
conn_info: &ConnInfo,
|
||||
client: tokio_postgres::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
|
||||
pub fn put(&self, conn_info: &ConnInfo, client: Client) -> anyhow::Result<()> {
|
||||
// We want to hold this open while we return. This ensures that the pool can't close
|
||||
// while we are in the middle of returning the connection.
|
||||
let closed = self.closed.read();
|
||||
if *closed {
|
||||
info!("pool: throwing away connection '{conn_info}' because pool is closed");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if client.inner.is_closed() {
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is closed");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
|
||||
|
||||
// return connection to the pool
|
||||
let mut total_conns;
|
||||
let mut returned = false;
|
||||
let mut per_db_size = 0;
|
||||
{
|
||||
let mut pool = pool.lock();
|
||||
total_conns = pool.total_conns;
|
||||
let total_conns = {
|
||||
let mut pool = pool.write();
|
||||
|
||||
let pool_entries: &mut Vec<ConnPoolEntry> = pool
|
||||
.pools
|
||||
.entry(conn_info.db_and_user())
|
||||
.or_insert_with(|| Vec::with_capacity(1));
|
||||
if total_conns < self.max_conns_per_endpoint {
|
||||
pool_entries.push(ConnPoolEntry {
|
||||
conn: client,
|
||||
_last_access: std::time::Instant::now(),
|
||||
});
|
||||
if pool.total_conns < self.max_conns_per_endpoint {
|
||||
// we create this db-user entry in get, so it should not be None
|
||||
if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
|
||||
pool_entries.conns.push(ConnPoolEntry {
|
||||
conn: client,
|
||||
_last_access: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
total_conns += 1;
|
||||
returned = true;
|
||||
per_db_size = pool_entries.len();
|
||||
returned = true;
|
||||
per_db_size = pool_entries.conns.len();
|
||||
|
||||
pool.total_conns += 1;
|
||||
pool.total_conns += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pool.total_conns
|
||||
};
|
||||
|
||||
// do logging outside of the mutex
|
||||
if returned {
|
||||
@@ -157,25 +277,35 @@ impl GlobalConnPool {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_endpoint_pool(&self, endpoint: &String) -> Arc<Mutex<EndpointConnPool>> {
|
||||
fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc<RwLock<EndpointConnPool>> {
|
||||
// fast path
|
||||
if let Some(pool) = self.global_pool.get(endpoint) {
|
||||
return pool.clone();
|
||||
}
|
||||
|
||||
// slow path
|
||||
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
|
||||
pools: HashMap::new(),
|
||||
total_conns: 0,
|
||||
}));
|
||||
|
||||
// find or create a pool for this endpoint
|
||||
let mut created = false;
|
||||
let mut global_pool = self.global_pool.lock();
|
||||
let pool = global_pool
|
||||
let pool = self
|
||||
.global_pool
|
||||
.entry(endpoint.clone())
|
||||
.or_insert_with(|| {
|
||||
created = true;
|
||||
Arc::new(Mutex::new(EndpointConnPool {
|
||||
pools: HashMap::new(),
|
||||
total_conns: 0,
|
||||
}))
|
||||
new_pool
|
||||
})
|
||||
.clone();
|
||||
let global_pool_size = global_pool.len();
|
||||
drop(global_pool);
|
||||
|
||||
// log new global pool size
|
||||
if created {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_add(1, atomic::Ordering::Relaxed)
|
||||
+ 1;
|
||||
info!(
|
||||
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
|
||||
);
|
||||
@@ -187,11 +317,12 @@ impl GlobalConnPool {
|
||||
|
||||
struct TokioMechanism<'a> {
|
||||
conn_info: &'a ConnInfo,
|
||||
session_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TokioMechanism<'_> {
|
||||
type Connection = tokio_postgres::Client;
|
||||
type Connection = Client;
|
||||
type ConnectError = tokio_postgres::Error;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
@@ -200,7 +331,7 @@ impl ConnectMechanism for TokioMechanism<'_> {
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
connect_to_compute_once(node_info, self.conn_info, timeout).await
|
||||
connect_to_compute_once(node_info, self.conn_info, timeout, self.session_id).await
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
@@ -213,7 +344,8 @@ impl ConnectMechanism for TokioMechanism<'_> {
|
||||
async fn connect_to_compute(
|
||||
config: &config::ProxyConfig,
|
||||
conn_info: &ConnInfo,
|
||||
) -> anyhow::Result<tokio_postgres::Client> {
|
||||
session_id: uuid::Uuid,
|
||||
) -> anyhow::Result<Client> {
|
||||
let tls = config.tls_config.as_ref();
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
|
||||
@@ -244,17 +376,27 @@ async fn connect_to_compute(
|
||||
.await?
|
||||
.context("missing cache entry from wake_compute")?;
|
||||
|
||||
crate::proxy::connect_to_compute(&TokioMechanism { conn_info }, node_info, &extra, &creds).await
|
||||
crate::proxy::connect_to_compute(
|
||||
&TokioMechanism {
|
||||
conn_info,
|
||||
session_id,
|
||||
},
|
||||
node_info,
|
||||
&extra,
|
||||
&creds,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_to_compute_once(
|
||||
node_info: &console::CachedNodeInfo,
|
||||
conn_info: &ConnInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
|
||||
mut session: uuid::Uuid,
|
||||
) -> Result<Client, tokio_postgres::Error> {
|
||||
let mut config = (*node_info.config).clone();
|
||||
|
||||
let (client, connection) = config
|
||||
let (client, mut connection) = config
|
||||
.user(&conn_info.username)
|
||||
.password(&conn_info.password)
|
||||
.dbname(&conn_info.dbname)
|
||||
@@ -263,11 +405,53 @@ async fn connect_to_compute_once(
|
||||
.connect(tokio_postgres::NoTls)
|
||||
.await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(session);
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
let span = info_span!(parent: None, "connection", %conn_info, %conn_id);
|
||||
span.in_scope(|| {
|
||||
info!(%session, "new connection");
|
||||
});
|
||||
|
||||
Ok(client)
|
||||
tokio::spawn(
|
||||
poll_fn(move |cx| {
|
||||
if matches!(rx.has_changed(), Ok(true)) {
|
||||
session = *rx.borrow_and_update();
|
||||
info!(%session, "changed session");
|
||||
}
|
||||
|
||||
let message = ready!(connection.poll_message(cx));
|
||||
|
||||
match message {
|
||||
Some(Ok(AsyncMessage::Notice(notice))) => {
|
||||
info!(%session, "notice: {}", notice);
|
||||
Poll::Pending
|
||||
}
|
||||
Some(Ok(AsyncMessage::Notification(notif))) => {
|
||||
warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
|
||||
Poll::Pending
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
warn!(%session, "unknown message");
|
||||
Poll::Pending
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!(%session, "connection error: {}", e);
|
||||
Poll::Ready(())
|
||||
}
|
||||
None => Poll::Ready(()),
|
||||
}
|
||||
})
|
||||
.instrument(span)
|
||||
);
|
||||
|
||||
Ok(Client {
|
||||
inner: client,
|
||||
session: tx,
|
||||
})
|
||||
}
|
||||
|
||||
pub struct Client {
|
||||
pub inner: tokio_postgres::Client,
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-outp
|
||||
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
|
||||
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
@@ -179,6 +180,7 @@ pub async fn handle(
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
session_id: uuid::Uuid,
|
||||
) -> anyhow::Result<(Value, HashMap<HeaderName, HeaderValue>)> {
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
@@ -194,7 +196,7 @@ pub async fn handle(
|
||||
// Allow connection pooling only if explicitly requested
|
||||
let allow_pool = false;
|
||||
|
||||
// isolation level and read only
|
||||
// isolation level, read only and deferrable
|
||||
|
||||
let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
|
||||
let txn_isolation_level = match txn_isolation_level_raw {
|
||||
@@ -208,8 +210,8 @@ pub async fn handle(
|
||||
None => None,
|
||||
};
|
||||
|
||||
let txn_read_only_raw = headers.get(&TXN_READ_ONLY).cloned();
|
||||
let txn_read_only = txn_read_only_raw.as_ref() == Some(&HEADER_VALUE_TRUE);
|
||||
let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
|
||||
let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
|
||||
|
||||
let request_content_length = match request.body().size_hint().upper() {
|
||||
Some(v) => v,
|
||||
@@ -228,24 +230,27 @@ pub async fn handle(
|
||||
let body = hyper::body::to_bytes(request.into_body()).await?;
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
|
||||
let mut client = conn_pool.get(&conn_info, !allow_pool).await?;
|
||||
let mut client = conn_pool.get(&conn_info, !allow_pool, session_id).await?;
|
||||
|
||||
//
|
||||
// Now execute the query and return the result
|
||||
//
|
||||
let result = match payload {
|
||||
Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode)
|
||||
Payload::Single(query) => query_to_json(&client.inner, query, raw_output, array_mode)
|
||||
.await
|
||||
.map(|x| (x, HashMap::default())),
|
||||
Payload::Batch(batch_query) => {
|
||||
let mut results = Vec::new();
|
||||
let mut builder = client.build_transaction();
|
||||
let mut builder = client.inner.build_transaction();
|
||||
if let Some(isolation_level) = txn_isolation_level {
|
||||
builder = builder.isolation_level(isolation_level);
|
||||
}
|
||||
if txn_read_only {
|
||||
builder = builder.read_only(true);
|
||||
}
|
||||
if txn_deferrable {
|
||||
builder = builder.deferrable(true);
|
||||
}
|
||||
let transaction = builder.start().await?;
|
||||
for query in batch_query.queries {
|
||||
let result = query_to_json(&transaction, query, raw_output, array_mode).await;
|
||||
@@ -259,21 +264,31 @@ pub async fn handle(
|
||||
}
|
||||
transaction.commit().await?;
|
||||
let mut headers = HashMap::default();
|
||||
headers.insert(
|
||||
TXN_READ_ONLY.clone(),
|
||||
HeaderValue::try_from(txn_read_only.to_string())?,
|
||||
);
|
||||
if let Some(txn_isolation_level_raw) = txn_isolation_level_raw {
|
||||
headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level_raw);
|
||||
if txn_read_only {
|
||||
headers.insert(
|
||||
TXN_READ_ONLY.clone(),
|
||||
HeaderValue::try_from(txn_read_only.to_string())?,
|
||||
);
|
||||
}
|
||||
if txn_deferrable {
|
||||
headers.insert(
|
||||
TXN_DEFERRABLE.clone(),
|
||||
HeaderValue::try_from(txn_deferrable.to_string())?,
|
||||
);
|
||||
}
|
||||
if let Some(txn_isolation_level) = txn_isolation_level_raw {
|
||||
headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
|
||||
}
|
||||
Ok((json!({ "results": results }), headers))
|
||||
}
|
||||
};
|
||||
|
||||
if allow_pool {
|
||||
let current_span = tracing::Span::current();
|
||||
// return connection to the pool
|
||||
tokio::task::spawn(async move {
|
||||
let _ = conn_pool.put(&conn_info, client).await;
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let _span = current_span.enter();
|
||||
let _ = conn_pool.put(&conn_info, client);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ async fn ws_handler(
|
||||
// TODO: that deserves a refactor as now this function also handles http json client besides websockets.
|
||||
// Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead.
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
let result = sql_over_http::handle(request, sni_hostname, conn_pool)
|
||||
let result = sql_over_http::handle(request, sni_hostname, conn_pool, session_id)
|
||||
.instrument(info_span!("sql-over-http"))
|
||||
.await;
|
||||
let status_code = match result {
|
||||
@@ -269,6 +269,18 @@ pub async fn task_main(
|
||||
|
||||
let conn_pool: Arc<GlobalConnPool> = GlobalConnPool::new(config);
|
||||
|
||||
// shutdown the connection pool
|
||||
tokio::spawn({
|
||||
let cancellation_token = cancellation_token.clone();
|
||||
let conn_pool = conn_pool.clone();
|
||||
async move {
|
||||
cancellation_token.cancelled().await;
|
||||
tokio::task::spawn_blocking(move || conn_pool.shutdown())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
|
||||
Some(config) => config.into(),
|
||||
@@ -307,7 +319,7 @@ pub async fn task_main(
|
||||
ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
|
||||
.instrument(info_span!(
|
||||
"ws-client",
|
||||
session = format_args!("{session_id}")
|
||||
session = %session_id
|
||||
))
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use super::{messages::ServerMessage, Mechanism};
|
||||
use crate::stream::PqStream;
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
/// Abstracts away all peculiarities of the libpq's protocol.
|
||||
pub struct SaslStream<'a, S> {
|
||||
@@ -68,7 +69,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
) -> super::Result<Outcome<M::Output>> {
|
||||
loop {
|
||||
let input = self.recv().await?;
|
||||
let step = mechanism.exchange(input)?;
|
||||
let step = mechanism.exchange(input).map_err(|error| {
|
||||
info!(?error, "error during SASL exchange");
|
||||
error
|
||||
})?;
|
||||
|
||||
use super::Step;
|
||||
return Ok(match step {
|
||||
|
||||
Reference in New Issue
Block a user