proxy: cancel http queries on timeout (#7031)

## Problem

On HTTP query timeout, we should try to cancel the current in-flight SQL query.

## Summary of changes

Trigger a cancellation command in postgres once the timeout is reached.
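
Rather than dropping the request future when the timer fires, the handler now arms a `CancellationToken`, races the query against it, and, if the token wins, sends an out-of-band cancel request to Postgres via `tokio_postgres`'s cancel token. A minimal sketch of that pattern, assuming a plain `tokio_postgres::Client` and illustrative names (`query_with_timeout` is not part of the proxy code):

```rust
use std::pin::pin;
use std::time::Duration;

use futures::future::{select, Either};
use tokio_postgres::NoTls;
use tokio_util::sync::CancellationToken;

/// Run one query, asking postgres to cancel it server-side if `timeout` elapses.
/// Returns Ok(None) when the query was cancelled.
async fn query_with_timeout(
    client: &tokio_postgres::Client,
    sql: &str,
    timeout: Duration,
) -> anyhow::Result<Option<Vec<tokio_postgres::Row>>> {
    // A background timer flips a CancellationToken instead of dropping the query future.
    let cancel = CancellationToken::new();
    let timer = {
        let cancel = cancel.clone();
        tokio::spawn(async move {
            tokio::time::sleep(timeout).await;
            cancel.cancel();
        })
    };

    // Race the query against the cancellation signal.
    let query = pin!(client.query(sql, &[]));
    let cancelled = pin!(cancel.cancelled());
    let result = match select(query, cancelled).await {
        Either::Left((rows, _cancelled)) => Some(rows?),
        Either::Right((_, _query)) => {
            // Timed out: cancel the in-flight query out-of-band. The connection
            // stays open; only the current statement is interrupted.
            if let Err(err) = client.cancel_token().cancel_query(NoTls).await {
                tracing::error!(?err, "could not cancel query");
            }
            None
        }
    };
    timer.abort();
    Ok(result)
}
```

In the proxy itself the timed-out query future is additionally polled for a short grace period after the cancel request, so the handler can check whether the connection reported a clean `QUERY_CANCELED` error and is still safe to return to the pool; otherwise the connection is discarded.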
Author: Conrad Ludgate
Date: 2024-03-12 11:52:00 +00:00
Committed by: GitHub
Parent: 89cf714890
Commit: 09699d4bd8
4 changed files with 242 additions and 118 deletions


@@ -612,13 +612,6 @@ impl<C: ClientInnerExt> Client<C> {
         let inner = inner.as_mut().expect("client inner should not be removed");
         (&mut inner.inner, Discard { pool, conn_info })
     }
-    pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
-        self.inner().1.check_idle(status)
-    }
-    pub fn discard(&mut self) {
-        self.inner().1.discard()
-    }
 }
 impl<C: ClientInnerExt> Discard<'_, C> {
@@ -739,7 +732,7 @@ mod tests {
         {
             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
             assert_eq!(0, pool.get_global_connections_count());
-            client.discard();
+            client.inner().1.discard();
             // Discard should not add the connection from the pool.
             assert_eq!(0, pool.get_global_connections_count());
         }


@@ -1,6 +1,10 @@
+use std::pin::pin;
 use std::sync::Arc;
 use anyhow::bail;
+use futures::future::select;
+use futures::future::try_join;
+use futures::future::Either;
 use futures::StreamExt;
 use hyper::body::HttpBody;
 use hyper::header;
@@ -11,13 +15,16 @@ use hyper::StatusCode;
 use hyper::{Body, HeaderMap, Request};
 use serde_json::json;
 use serde_json::Value;
-use tokio::try_join;
+use tokio::time;
 use tokio_postgres::error::DbError;
 use tokio_postgres::error::ErrorPosition;
+use tokio_postgres::error::SqlState;
 use tokio_postgres::GenericClient;
 use tokio_postgres::IsolationLevel;
+use tokio_postgres::NoTls;
 use tokio_postgres::ReadyForQueryStatus;
 use tokio_postgres::Transaction;
+use tokio_util::sync::CancellationToken;
 use tracing::error;
 use tracing::info;
 use url::Url;
@@ -194,108 +201,111 @@ pub async fn handle(
     request: Request<Body>,
     backend: Arc<PoolingBackend>,
 ) -> Result<Response<Body>, ApiError> {
-    let result = tokio::time::timeout(
-        config.http_config.request_timeout,
-        handle_inner(config, &mut ctx, request, backend),
-    )
-    .await;
+    let cancel = CancellationToken::new();
+    let cancel2 = cancel.clone();
+    let handle = tokio::spawn(async move {
+        time::sleep(config.http_config.request_timeout).await;
+        cancel2.cancel();
+    });
+    let result = handle_inner(cancel, config, &mut ctx, request, backend).await;
+    handle.abort();
     let mut response = match result {
-        Ok(r) => match r {
-            Ok(r) => {
-                ctx.set_success();
-                r
+        Ok(Ok(r)) => {
+            ctx.set_success();
+            r
+        }
-            Err(e) => {
-                // TODO: ctx.set_error_kind(e.get_error_type());
-                let mut message = format!("{:?}", e);
-                let db_error = e
-                    .downcast_ref::<tokio_postgres::Error>()
-                    .and_then(|e| e.as_db_error());
-                fn get<'a, T: serde::Serialize>(
-                    db: Option<&'a DbError>,
-                    x: impl FnOnce(&'a DbError) -> T,
-                ) -> Value {
-                    db.map(x)
-                        .and_then(|t| serde_json::to_value(t).ok())
-                        .unwrap_or_default()
-                }
+        Err(e) => {
+            // TODO: ctx.set_error_kind(e.get_error_type());
+            let mut message = format!("{:?}", e);
+            let db_error = e
+                .downcast_ref::<tokio_postgres::Error>()
+                .and_then(|e| e.as_db_error());
+            fn get<'a, T: serde::Serialize>(
+                db: Option<&'a DbError>,
+                x: impl FnOnce(&'a DbError) -> T,
+            ) -> Value {
+                db.map(x)
+                    .and_then(|t| serde_json::to_value(t).ok())
+                    .unwrap_or_default()
+            }
-                if let Some(db_error) = db_error {
-                    db_error.message().clone_into(&mut message);
-                }
-                let position = db_error.and_then(|db| db.position());
-                let (position, internal_position, internal_query) = match position {
-                    Some(ErrorPosition::Original(position)) => (
-                        Value::String(position.to_string()),
-                        Value::Null,
-                        Value::Null,
-                    ),
-                    Some(ErrorPosition::Internal { position, query }) => (
-                        Value::Null,
-                        Value::String(position.to_string()),
-                        Value::String(query.clone()),
-                    ),
-                    None => (Value::Null, Value::Null, Value::Null),
-                };
-                let code = get(db_error, |db| db.code().code());
-                let severity = get(db_error, |db| db.severity());
-                let detail = get(db_error, |db| db.detail());
-                let hint = get(db_error, |db| db.hint());
-                let where_ = get(db_error, |db| db.where_());
-                let table = get(db_error, |db| db.table());
-                let column = get(db_error, |db| db.column());
-                let schema = get(db_error, |db| db.schema());
-                let datatype = get(db_error, |db| db.datatype());
-                let constraint = get(db_error, |db| db.constraint());
-                let file = get(db_error, |db| db.file());
-                let line = get(db_error, |db| db.line().map(|l| l.to_string()));
-                let routine = get(db_error, |db| db.routine());
-                error!(
-                    ?code,
-                    "sql-over-http per-client task finished with an error: {e:#}"
-                );
-                // TODO: this shouldn't always be bad request.
-                json_response(
-                    StatusCode::BAD_REQUEST,
-                    json!({
-                        "message": message,
-                        "code": code,
-                        "detail": detail,
-                        "hint": hint,
-                        "position": position,
-                        "internalPosition": internal_position,
-                        "internalQuery": internal_query,
-                        "severity": severity,
-                        "where": where_,
-                        "table": table,
-                        "column": column,
-                        "schema": schema,
-                        "dataType": datatype,
-                        "constraint": constraint,
-                        "file": file,
-                        "line": line,
-                        "routine": routine,
-                    }),
-                )?
+            if let Some(db_error) = db_error {
+                db_error.message().clone_into(&mut message);
+            }
-            }
-        },
-        Err(_) => {
+            let position = db_error.and_then(|db| db.position());
+            let (position, internal_position, internal_query) = match position {
+                Some(ErrorPosition::Original(position)) => (
+                    Value::String(position.to_string()),
+                    Value::Null,
+                    Value::Null,
+                ),
+                Some(ErrorPosition::Internal { position, query }) => (
+                    Value::Null,
+                    Value::String(position.to_string()),
+                    Value::String(query.clone()),
+                ),
+                None => (Value::Null, Value::Null, Value::Null),
+            };
+            let code = get(db_error, |db| db.code().code());
+            let severity = get(db_error, |db| db.severity());
+            let detail = get(db_error, |db| db.detail());
+            let hint = get(db_error, |db| db.hint());
+            let where_ = get(db_error, |db| db.where_());
+            let table = get(db_error, |db| db.table());
+            let column = get(db_error, |db| db.column());
+            let schema = get(db_error, |db| db.schema());
+            let datatype = get(db_error, |db| db.datatype());
+            let constraint = get(db_error, |db| db.constraint());
+            let file = get(db_error, |db| db.file());
+            let line = get(db_error, |db| db.line().map(|l| l.to_string()));
+            let routine = get(db_error, |db| db.routine());
+            error!(
+                ?code,
+                "sql-over-http per-client task finished with an error: {e:#}"
+            );
+            // TODO: this shouldn't always be bad request.
+            json_response(
+                StatusCode::BAD_REQUEST,
+                json!({
+                    "message": message,
+                    "code": code,
+                    "detail": detail,
+                    "hint": hint,
+                    "position": position,
+                    "internalPosition": internal_position,
+                    "internalQuery": internal_query,
+                    "severity": severity,
+                    "where": where_,
+                    "table": table,
+                    "column": column,
+                    "schema": schema,
+                    "dataType": datatype,
+                    "constraint": constraint,
+                    "file": file,
+                    "line": line,
+                    "routine": routine,
+                }),
+            )?
+        }
+        Ok(Err(Cancelled())) => {
             // TODO: when http error classification is done, distinguish between
             // timeout on sql vs timeout in proxy/cplane
             // ctx.set_error_kind(crate::error::ErrorKind::RateLimit);
             let message = format!(
-                "HTTP-Connection timed out, execution time exceeded {} seconds",
-                config.http_config.request_timeout.as_secs()
+                "Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections",
+                config.http_config.request_timeout.as_secs_f64()
             );
             error!(message);
             json_response(
-                StatusCode::GATEWAY_TIMEOUT,
-                json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
+                StatusCode::BAD_REQUEST,
+                json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
             )?
         }
     };
@@ -307,12 +317,15 @@ pub async fn handle(
     Ok(response)
 }
+struct Cancelled();
 async fn handle_inner(
+    cancel: CancellationToken,
     config: &'static ProxyConfig,
     ctx: &mut RequestMonitoring,
     request: Request<Body>,
     backend: Arc<PoolingBackend>,
-) -> anyhow::Result<Response<Body>> {
+) -> Result<Result<Response<Body>, Cancelled>, anyhow::Error> {
     let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
         .with_label_values(&[ctx.protocol])
         .guard();
@@ -389,7 +402,18 @@ async fn handle_inner(
     };
     // Run both operations in parallel
-    let (payload, mut client) = try_join!(fetch_and_process_request, authenticate_and_connect)?;
+    let (payload, mut client) = match select(
+        try_join(
+            pin!(fetch_and_process_request),
+            pin!(authenticate_and_connect),
+        ),
+        pin!(cancel.cancelled()),
+    )
+    .await
+    {
+        Either::Left((result, _cancelled)) => result?,
+        Either::Right((_cancelled, _)) => return Ok(Err(Cancelled())),
+    };
     let mut response = Response::builder()
         .status(StatusCode::OK)
@@ -401,19 +425,60 @@ async fn handle_inner(
     let mut size = 0;
     let result = match payload {
         Payload::Single(stmt) => {
-            let (status, results) =
-                query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode)
-                    .await
-                    .map_err(|e| {
-                        client.discard();
-                        e
-                    })?;
-            client.check_idle(status);
-            results
+            let mut size = 0;
+            let (inner, mut discard) = client.inner();
+            let cancel_token = inner.cancel_token();
+            let query = pin!(query_to_json(
+                &*inner,
+                stmt,
+                &mut size,
+                raw_output,
+                default_array_mode
+            ));
+            let cancelled = pin!(cancel.cancelled());
+            let res = select(query, cancelled).await;
+            match res {
+                Either::Left((Ok((status, results)), _cancelled)) => {
+                    discard.check_idle(status);
+                    results
+                }
+                Either::Left((Err(e), _cancelled)) => {
+                    discard.discard();
+                    return Err(e);
+                }
+                Either::Right((_cancelled, query)) => {
+                    if let Err(err) = cancel_token.cancel_query(NoTls).await {
+                        tracing::error!(?err, "could not cancel query");
+                    }
+                    match time::timeout(time::Duration::from_millis(100), query).await {
+                        Ok(Ok((status, results))) => {
+                            discard.check_idle(status);
+                            results
+                        }
+                        Ok(Err(error)) => {
+                            let db_error = error
+                                .downcast_ref::<tokio_postgres::Error>()
+                                .and_then(|e| e.as_db_error());
+                            // if errored for some other reason, it might not be safe to return
+                            if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
+                                discard.discard();
+                            }
+                            return Ok(Err(Cancelled()));
+                        }
+                        Err(_timeout) => {
+                            discard.discard();
+                            return Ok(Err(Cancelled()));
+                        }
+                    }
+                }
+            }
+        }
         }
         Payload::Batch(statements) => {
             info!("starting transaction");
             let (inner, mut discard) = client.inner();
+            let cancel_token = inner.cancel_token();
             let mut builder = inner.build_transaction();
             if let Some(isolation_level) = txn_isolation_level {
                 builder = builder.isolation_level(isolation_level);
@@ -433,6 +498,7 @@ async fn handle_inner(
             })?;
             let results = match query_batch(
+                cancel.child_token(),
                 &transaction,
                 statements,
                 &mut size,
@@ -441,7 +507,7 @@ async fn handle_inner(
             )
             .await
             {
-                Ok(results) => {
+                Ok(Ok(results)) => {
                     info!("commit");
                     let status = transaction.commit().await.map_err(|e| {
                         // if we cannot commit - for now don't return connection to pool
@@ -452,6 +518,15 @@ async fn handle_inner(
                     discard.check_idle(status);
                     results
                 }
+                Ok(Err(Cancelled())) => {
+                    if let Err(err) = cancel_token.cancel_query(NoTls).await {
+                        tracing::error!(?err, "could not cancel query");
+                    }
+                    // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
+                    discard.discard();
+                    return Ok(Err(Cancelled()));
+                }
                 Err(err) => {
                     info!("rollback");
                     let status = transaction.rollback().await.map_err(|e| {
@@ -499,26 +574,44 @@ async fn handle_inner(
     // moving this later in the stack is going to be a lot of effort and ehhhh
     metrics.record_egress(len as u64);
-    Ok(response)
+    Ok(Ok(response))
 }
 async fn query_batch(
+    cancel: CancellationToken,
     transaction: &Transaction<'_>,
     queries: BatchQueryData,
     total_size: &mut usize,
     raw_output: bool,
     array_mode: bool,
-) -> anyhow::Result<Vec<Value>> {
+) -> anyhow::Result<Result<Vec<Value>, Cancelled>> {
     let mut results = Vec::with_capacity(queries.queries.len());
     let mut current_size = 0;
     for stmt in queries.queries {
-        // TODO: maybe we should check that the transaction bit is set here
-        let (_, values) =
-            query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
-        results.push(values);
+        let query = pin!(query_to_json(
+            transaction,
+            stmt,
+            &mut current_size,
+            raw_output,
+            array_mode
+        ));
+        let cancelled = pin!(cancel.cancelled());
+        let res = select(query, cancelled).await;
+        match res {
+            // TODO: maybe we should check that the transaction bit is set here
+            Either::Left((Ok((_, values)), _cancelled)) => {
+                results.push(values);
+            }
+            Either::Left((Err(e), _cancelled)) => {
+                return Err(e);
+            }
+            Either::Right((_cancelled, _)) => {
+                return Ok(Err(Cancelled()));
+            }
+        }
     }
     *total_size += current_size;
-    Ok(results)
+    Ok(Ok(results))
 }
 async fn query_to_json<T: GenericClient>(


@@ -2859,6 +2859,7 @@ class NeonProxy(PgProtocol):
         self.auth_backend = auth_backend
         self.metric_collection_endpoint = metric_collection_endpoint
         self.metric_collection_interval = metric_collection_interval
+        self.http_timeout_seconds = 15
         self._popen: Optional[subprocess.Popen[bytes]] = None
     def start(self) -> NeonProxy:
@@ -2897,6 +2898,7 @@ class NeonProxy(PgProtocol):
             *["--proxy", f"{self.host}:{self.proxy_port}"],
             *["--mgmt", f"{self.host}:{self.mgmt_port}"],
             *["--wss", f"{self.host}:{self.external_http_port}"],
+            *["--sql-over-http-timeout", f"{self.http_timeout_seconds}s"],
             *["-c", str(crt_path)],
             *["-k", str(key_path)],
             *self.auth_backend.extra_args(),
@@ -2937,6 +2939,8 @@ class NeonProxy(PgProtocol):
         password = quote(kwargs["password"])
         expected_code = kwargs.get("expected_code")
+        log.info(f"Executing http query: {query}")
+        connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
         response = requests.post(
             f"https://{self.domain}:{self.external_http_port}/sql",
@@ -2959,6 +2963,8 @@ class NeonProxy(PgProtocol):
         password = kwargs["password"]
         expected_code = kwargs.get("expected_code")
+        log.info(f"Executing http2 query: {query}")
+        connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
         async with httpx.AsyncClient(
             http2=True, verify=str(self.test_output_dir / "proxy.crt")


@@ -564,3 +564,35 @@ async def test_sql_over_http2(static_proxy: NeonProxy):
         "select 42 as answer", [], user="http", password="http", expected_code=200
     )
     assert resp["rows"] == [{"answer": 42}]
+def test_sql_over_http_timeout_cancel(static_proxy: NeonProxy):
+    static_proxy.safe_psql("create role http with login password 'http' superuser")
+    static_proxy.safe_psql("create table test_table ( id int primary key )")
+    # insert into a table, with a unique constraint, after sleeping for n seconds
+    query = "WITH temp AS ( \
+        SELECT pg_sleep($1) as sleep, $2::int as id \
+    ) INSERT INTO test_table (id) SELECT id FROM temp"
+    # expect to fail with timeout
+    res = static_proxy.http_query(
+        query,
+        [static_proxy.http_timeout_seconds + 1, 1],
+        user="http",
+        password="http",
+        expected_code=400,
+    )
+    assert "Query cancelled, runtime exceeded" in res["message"], "HTTP query should time out"
+    time.sleep(2)
+    res = static_proxy.http_query(query, [1, 1], user="http", password="http", expected_code=200)
+    assert res["command"] == "INSERT", "HTTP query should insert"
+    assert res["rowCount"] == 1, "HTTP query should insert"
+    res = static_proxy.http_query(query, [0, 1], user="http", password="http", expected_code=400)
+    assert (
+        "duplicate key value violates unique constraint" in res["message"]
+    ), "HTTP query should conflict"