Proxy: improve http-pool (#6577)

## Problem

The password check logic for the sql-over-http is a bit non-intuitive. 

## Summary of changes

1. Perform scram auth using the same logic as for websocket cleartext
password.
2. Split establish connection logic and connection pool.
3. Parallelize param parsing logic with authentication + wake compute.
4. Limit the total number of clients
This commit is contained in:
Anna Khanova
2024-02-08 12:57:05 +01:00
committed by GitHub
parent c52495774d
commit c63e3e7e84
16 changed files with 753 additions and 478 deletions

1
Cargo.lock generated
View File

@@ -4079,6 +4079,7 @@ dependencies = [
"clap",
"consumption_metrics",
"dashmap",
"env_logger",
"futures",
"git-version",
"hashbrown 0.13.2",

View File

@@ -19,6 +19,7 @@ chrono.workspace = true
clap.workspace = true
consumption_metrics.workspace = true
dashmap.workspace = true
env_logger.workspace = true
futures.workspace = true
git-version.workspace = true
hashbrown.workspace = true

View File

@@ -68,6 +68,7 @@ pub trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, ()> {
@@ -358,6 +359,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
}
impl BackendType<'_, ComputeUserInfo> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Link(_) => Ok(Cached::new_uncached(None)),
}
}
pub async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,

View File

@@ -167,7 +167,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
}
pub(super) fn validate_password_and_exchange(
pub(crate) fn validate_password_and_exchange(
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {

View File

@@ -165,6 +165,10 @@ struct SqlOverHttpArgs {
#[clap(long, default_value_t = 20)]
sql_over_http_pool_max_conns_per_endpoint: usize,
/// How many connections to pool for each endpoint. Excess connections are discarded
#[clap(long, default_value_t = 20000)]
sql_over_http_pool_max_total_conns: usize,
/// How long pooled connections should remain idle for before closing
#[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
sql_over_http_idle_timeout: tokio::time::Duration,
@@ -387,6 +391,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
};
let authentication_config = AuthenticationConfig {

View File

@@ -188,6 +188,7 @@ impl super::Api for Api {
ep,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);
}
// When we just got a secret, we don't need to invalidate it.
Ok(Cached::new_uncached(auth_info.secret))
@@ -221,6 +222,7 @@ impl super::Api for Api {
self.caches
.project_info
.insert_allowed_ips(&project_id, ep, allowed_ips.clone());
ctx.set_project_id(project_id);
}
Ok((
Cached::new_uncached(allowed_ips),

View File

@@ -89,6 +89,10 @@ impl RequestMonitoring {
self.project = Some(x.project_id);
}
pub fn set_project_id(&mut self, project_id: ProjectId) {
self.project = Some(project_id);
}
pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
crate::metrics::CONNECTING_ENDPOINTS
.with_label_values(&[self.protocol])

View File

@@ -1,8 +1,10 @@
use ::metrics::{
exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge_vec, Histogram,
HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGaugeVec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge,
register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec,
IntCounterVec, IntGauge, IntGaugeVec,
};
use metrics::{register_int_counter_pair, IntCounterPair};
use once_cell::sync::Lazy;
use tokio::time;
@@ -112,6 +114,44 @@ pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
.unwrap()
});
pub static HTTP_CONTENT_LENGTH: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_conn_content_length_bytes",
"Time it took for proxy to establish a connection to the compute endpoint",
// largest bucket = 3^16 * 0.05ms = 2.15s
exponential_buckets(8.0, 2.0, 20).unwrap()
)
.unwrap()
});
pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_pool_reclaimation_lag_seconds",
"Time it takes to reclaim unused connection pools",
// 1us -> 65ms
exponential_buckets(1e-6, 2.0, 16).unwrap(),
)
.unwrap()
});
pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
register_int_counter_pair!(
"proxy_http_pool_endpoints_registered_total",
"Number of endpoints we have registered pools for",
"proxy_http_pool_endpoints_unregistered_total",
"Number of endpoints we have unregistered pools for",
)
.unwrap()
});
pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"proxy_http_pool_opened_connections",
"Number of opened connections to a database.",
)
.unwrap()
});
#[derive(Clone)]
pub struct LatencyTimer {
// time since the stopwatch was started

View File

@@ -34,21 +34,6 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg
node_info.invalidate().config
}
/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)]
async fn connect_to_compute_once(
ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, compute::ConnectionError> {
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(ctx, allow_self_signed_compute, timeout)
.await
}
#[async_trait]
pub trait ConnectMechanism {
type Connection;
@@ -75,13 +60,18 @@ impl ConnectMechanism for TcpMechanism<'_> {
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
connect_to_compute_once(ctx, node_info, timeout).await
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(ctx, allow_self_signed_compute, timeout)
.await
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {

View File

@@ -478,6 +478,9 @@ impl TestBackend for TestConnectMechanism {
{
unimplemented!("not used in tests")
}
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
}
fn helper_create_cached_node_info() -> CachedNodeInfo {

View File

@@ -2,6 +2,7 @@
//!
//! Handles both SQL over HTTP and SQL over Websockets.
mod backend;
mod conn_pool;
mod json;
mod sql_over_http;
@@ -18,11 +19,11 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::{cancellation::CancelMap, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
@@ -54,12 +55,13 @@ pub async fn task_main(
info!("websocket server has shut down");
}
let conn_pool = conn_pool::GlobalConnPool::new(config);
let conn_pool2 = Arc::clone(&conn_pool);
tokio::spawn(async move {
conn_pool2.gc_worker(StdRng::from_entropy()).await;
});
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
{
let conn_pool = Arc::clone(&conn_pool);
tokio::spawn(async move {
conn_pool.gc_worker(StdRng::from_entropy()).await;
});
}
// shutdown the connection pool
tokio::spawn({
@@ -73,6 +75,11 @@ pub async fn task_main(
}
});
let backend = Arc::new(PoolingBackend {
pool: Arc::clone(&conn_pool),
config,
});
let tls_config = match config.tls_config.as_ref() {
Some(config) => config,
None => {
@@ -106,7 +113,7 @@ pub async fn task_main(
let client_addr = io.client_addr();
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -119,7 +126,7 @@ pub async fn task_main(
Ok(MetricService::new(hyper::service::service_fn(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -130,8 +137,7 @@ pub async fn task_main(
request_handler(
req,
config,
tls_config,
conn_pool,
backend,
ws_connections,
cancel_map,
session_id,
@@ -200,8 +206,7 @@ where
async fn request_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
tls: &'static TlsConfig,
conn_pool: Arc<conn_pool::GlobalConnPool>,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
@@ -248,15 +253,7 @@ async fn request_handler(
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
sql_over_http::handle(
tls,
&config.http_config,
&mut ctx,
request,
sni_hostname,
conn_pool,
)
.await
sql_over_http::handle(config, &mut ctx, request, sni_hostname, backend).await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")

View File

@@ -0,0 +1,157 @@
use std::{sync::Arc, time::Duration};
use anyhow::Context;
use async_trait::async_trait;
use tracing::info;
use crate::{
auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
compute,
config::ProxyConfig,
console::CachedNodeInfo,
context::RequestMonitoring,
proxy::connect_compute::ConnectMechanism,
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub config: &'static ProxyConfig,
}
impl PoolingBackend {
pub async fn authenticate(
&self,
ctx: &mut RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<ComputeCredentialKeys, AuthError> {
let user_info = conn_info.user_info.clone();
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(AuthError::ip_address_not_allowed());
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => backend.get_role_secret(ctx).await?,
};
let secret = match cached_secret.value.clone() {
Some(secret) => secret,
None => {
// If we don't have an authentication secret, for the http flow we can just return an error.
info!("authentication info not found");
return Err(AuthError::auth_failed(&*user_info.user));
}
};
let auth_outcome =
crate::auth::validate_password_and_exchange(conn_info.password.as_bytes(), secret)?;
match auth_outcome {
crate::sasl::Outcome::Success(key) => Ok(key),
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
Err(AuthError::auth_failed(&*conn_info.user_info.user))
}
}
}
// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare few structures
// that this code expects.
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub async fn connect_to_compute(
&self,
ctx: &mut RequestMonitoring,
conn_info: ConnInfo,
keys: ComputeCredentialKeys,
force_new: bool,
) -> anyhow::Result<Client<tokio_postgres::Client>> {
let maybe_client = if !force_new {
info!("pool: looking for an existing connection");
self.pool.get(ctx, &conn_info).await?
} else {
info!("pool: pool is disabled");
None
};
if let Some(client) = maybe_client {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
ctx.set_application(Some(APP_NAME));
let backend = self
.config
.auth_backend
.as_ref()
.map(|_| conn_info.user_info.clone());
let mut node_info = backend
.wake_compute(ctx)
.await?
.context("missing cache entry from wake_compute")?;
match keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
};
ctx.set_project(node_info.aux.clone());
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
},
node_info,
&backend,
)
.await
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
}
#[async_trait]
impl ConnectMechanism for TokioMechanism {
type Connection = Client<tokio_postgres::Client>;
type ConnectError = tokio_postgres::Error;
type Error = anyhow::Error;
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let mut config = (*node_info.config).clone();
let config = config
.user(&self.conn_info.user_info.user)
.password(&*self.conn_info.password)
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
Ok(poll_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}

File diff suppressed because it is too large Load Diff

View File

@@ -9,23 +9,23 @@ use tokio_postgres::Row;
// as parameters.
//
pub fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter()
.map(|value| {
match value {
// special care for nulls
Value::Null => None,
json.iter().map(json_value_to_pg_text).collect()
}
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
fn json_value_to_pg_text(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.to_string()),
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
}
})
.collect()
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.to_string()),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
}
}
//

View File

@@ -13,6 +13,7 @@ use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::join;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
use tokio_postgres::GenericClient;
@@ -20,6 +21,7 @@ use tokio_postgres::IsolationLevel;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Transaction;
use tracing::error;
use tracing::info;
use tracing::instrument;
use url::Url;
use utils::http::error::ApiError;
@@ -27,22 +29,25 @@ use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::config::HttpConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::RoleName;
use super::backend::PoolingBackend;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
use super::json::{json_to_pg_text, pg_text_row_to_json};
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::SERVERLESS_DRIVER_SNI;
#[derive(serde::Deserialize)]
struct QueryData {
query: String,
params: Vec<serde_json::Value>,
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
}
#[derive(serde::Deserialize)]
@@ -69,6 +74,15 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
Ok(json_to_pg_text(json))
}
fn get_conn_info(
ctx: &mut RequestMonitoring,
headers: &HeaderMap,
@@ -171,16 +185,15 @@ fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Err
// TODO: return different http error codes
pub async fn handle(
tls: &'static TlsConfig,
config: &'static HttpConfig,
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.request_timeout,
handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
config.http_config.request_timeout,
handle_inner(config, ctx, request, sni_hostname, backend),
)
.await;
let mut response = match result {
@@ -265,7 +278,7 @@ pub async fn handle(
Err(_) => {
let message = format!(
"HTTP-Connection timed out, execution time exeeded {} seconds",
config.request_timeout.as_secs()
config.http_config.request_timeout.as_secs()
);
error!(message);
json_response(
@@ -283,22 +296,36 @@ pub async fn handle(
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
async fn handle_inner(
tls: &'static TlsConfig,
config: &'static HttpConfig,
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
backend: Arc<PoolingBackend>,
) -> anyhow::Result<Response<Body>> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&["http"])
.with_label_values(&[ctx.protocol])
.guard();
info!(
protocol = ctx.protocol,
"handling interactive connection from client"
);
//
// Determine the destination and connection params
//
let headers = request.headers();
let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?;
// TLS config should be there.
let conn_info = get_conn_info(
ctx,
headers,
sni_hostname,
config.tls_config.as_ref().unwrap(),
)?;
info!(
user = conn_info.user_info.user.as_str(),
project = conn_info.user_info.endpoint.as_str(),
"credentials"
);
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.
@@ -307,8 +334,8 @@ async fn handle_inner(
// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
let allow_pool =
!config.pool_options.opt_in || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
let allow_pool = !config.http_config.pool_options.opt_in
|| headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
// isolation level, read only and deferrable
@@ -333,6 +360,8 @@ async fn handle_inner(
None => MAX_REQUEST_SIZE + 1,
};
drop(paused);
info!(request_content_length, "request size in bytes");
HTTP_CONTENT_LENGTH.observe(request_content_length as f64);
// we don't have a streaming request support yet so this is to prevent OOM
// from a malicious user sending an extremely large request body
@@ -342,13 +371,28 @@ async fn handle_inner(
));
}
//
// Read the query and query params from the request body
//
let body = hyper::body::to_bytes(request.into_body()).await?;
let payload: Payload = serde_json::from_slice(&body)?;
let fetch_and_process_request = async {
let body = hyper::body::to_bytes(request.into_body())
.await
.map_err(anyhow::Error::from)?;
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
};
let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;
let authenticate_and_connect = async {
let keys = backend.authenticate(ctx, &conn_info).await?;
backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await
};
// Run both operations in parallel
let (payload_result, auth_and_connect_result) =
join!(fetch_and_process_request, authenticate_and_connect,);
// Handle the results
let payload = payload_result?; // Handle errors appropriately
let mut client = auth_and_connect_result?; // Handle errors appropriately
let mut response = Response::builder()
.status(StatusCode::OK)
@@ -482,7 +526,7 @@ async fn query_to_json<T: GenericClient>(
raw_output: bool,
array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
let query_params = json_to_pg_text(data.params);
let query_params = data.params;
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
// Manually drain the stream into a vector to leave row_stream hanging

View File

@@ -393,11 +393,11 @@ def test_sql_over_http_batch(static_proxy: NeonProxy):
def test_sql_over_http_pool(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
def get_pid(status: int, pw: str) -> Any:
def get_pid(status: int, pw: str, user="http_auth") -> Any:
return static_proxy.http_query(
GET_CONNECTION_PID_QUERY,
[],
user="http_auth",
user=user,
password=pw,
expected_code=status,
)
@@ -418,20 +418,14 @@ def test_sql_over_http_pool(static_proxy: NeonProxy):
static_proxy.safe_psql("alter user http_auth with password 'http2'")
# after password change, should open a new connection to verify it
pid2 = get_pid(200, "http2")["rows"][0]["pid"]
assert pid1 != pid2
# after password change, shouldn't open a new connection because it checks password in proxy.
rows = get_pid(200, "http2")["rows"]
assert rows == [{"pid": pid1}]
time.sleep(0.02)
# query should be on an existing connection
pid = get_pid(200, "http2")["rows"][0]["pid"]
assert pid in [pid1, pid2]
time.sleep(0.02)
# old password should not work
res = get_pid(400, "http")
# incorrect user shouldn't reveal that the user doesn't exists
res = get_pid(400, "http", user="http_auth2")
assert "password authentication failed for user" in res["message"]