mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 14:02:55 +00:00
proxy: cancel http queries on timeout (#7031)
## Problem On HTTP query timeout, we should try to cancel the current in-flight SQL query. ## Summary of changes Trigger a cancellation command in Postgres once the timeout is reached.
This commit is contained in:
@@ -612,13 +612,6 @@ impl<C: ClientInnerExt> Client<C> {
|
||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||
(&mut inner.inner, Discard { pool, conn_info })
|
||||
}
|
||||
|
||||
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
self.inner().1.check_idle(status)
|
||||
}
|
||||
pub fn discard(&mut self) {
|
||||
self.inner().1.discard()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
@@ -739,7 +732,7 @@ mod tests {
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
assert_eq!(0, pool.get_global_connections_count());
|
||||
client.discard();
|
||||
client.inner().1.discard();
|
||||
// Discard should not add the connection from the pool.
|
||||
assert_eq!(0, pool.get_global_connections_count());
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use futures::future::select;
|
||||
use futures::future::try_join;
|
||||
use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use hyper::body::HttpBody;
|
||||
use hyper::header;
|
||||
@@ -11,13 +15,16 @@ use hyper::StatusCode;
|
||||
use hyper::{Body, HeaderMap, Request};
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
use tokio::try_join;
|
||||
use tokio::time;
|
||||
use tokio_postgres::error::DbError;
|
||||
use tokio_postgres::error::ErrorPosition;
|
||||
use tokio_postgres::error::SqlState;
|
||||
use tokio_postgres::GenericClient;
|
||||
use tokio_postgres::IsolationLevel;
|
||||
use tokio_postgres::NoTls;
|
||||
use tokio_postgres::ReadyForQueryStatus;
|
||||
use tokio_postgres::Transaction;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use url::Url;
|
||||
@@ -194,108 +201,111 @@ pub async fn handle(
|
||||
request: Request<Body>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
config.http_config.request_timeout,
|
||||
handle_inner(config, &mut ctx, request, backend),
|
||||
)
|
||||
.await;
|
||||
let cancel = CancellationToken::new();
|
||||
let cancel2 = cancel.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
time::sleep(config.http_config.request_timeout).await;
|
||||
cancel2.cancel();
|
||||
});
|
||||
|
||||
let result = handle_inner(cancel, config, &mut ctx, request, backend).await;
|
||||
handle.abort();
|
||||
|
||||
let mut response = match result {
|
||||
Ok(r) => match r {
|
||||
Ok(r) => {
|
||||
ctx.set_success();
|
||||
r
|
||||
Ok(Ok(r)) => {
|
||||
ctx.set_success();
|
||||
r
|
||||
}
|
||||
Err(e) => {
|
||||
// TODO: ctx.set_error_kind(e.get_error_type());
|
||||
|
||||
let mut message = format!("{:?}", e);
|
||||
let db_error = e
|
||||
.downcast_ref::<tokio_postgres::Error>()
|
||||
.and_then(|e| e.as_db_error());
|
||||
fn get<'a, T: serde::Serialize>(
|
||||
db: Option<&'a DbError>,
|
||||
x: impl FnOnce(&'a DbError) -> T,
|
||||
) -> Value {
|
||||
db.map(x)
|
||||
.and_then(|t| serde_json::to_value(t).ok())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
// TODO: ctx.set_error_kind(e.get_error_type());
|
||||
|
||||
let mut message = format!("{:?}", e);
|
||||
let db_error = e
|
||||
.downcast_ref::<tokio_postgres::Error>()
|
||||
.and_then(|e| e.as_db_error());
|
||||
fn get<'a, T: serde::Serialize>(
|
||||
db: Option<&'a DbError>,
|
||||
x: impl FnOnce(&'a DbError) -> T,
|
||||
) -> Value {
|
||||
db.map(x)
|
||||
.and_then(|t| serde_json::to_value(t).ok())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
if let Some(db_error) = db_error {
|
||||
db_error.message().clone_into(&mut message);
|
||||
}
|
||||
|
||||
let position = db_error.and_then(|db| db.position());
|
||||
let (position, internal_position, internal_query) = match position {
|
||||
Some(ErrorPosition::Original(position)) => (
|
||||
Value::String(position.to_string()),
|
||||
Value::Null,
|
||||
Value::Null,
|
||||
),
|
||||
Some(ErrorPosition::Internal { position, query }) => (
|
||||
Value::Null,
|
||||
Value::String(position.to_string()),
|
||||
Value::String(query.clone()),
|
||||
),
|
||||
None => (Value::Null, Value::Null, Value::Null),
|
||||
};
|
||||
|
||||
let code = get(db_error, |db| db.code().code());
|
||||
let severity = get(db_error, |db| db.severity());
|
||||
let detail = get(db_error, |db| db.detail());
|
||||
let hint = get(db_error, |db| db.hint());
|
||||
let where_ = get(db_error, |db| db.where_());
|
||||
let table = get(db_error, |db| db.table());
|
||||
let column = get(db_error, |db| db.column());
|
||||
let schema = get(db_error, |db| db.schema());
|
||||
let datatype = get(db_error, |db| db.datatype());
|
||||
let constraint = get(db_error, |db| db.constraint());
|
||||
let file = get(db_error, |db| db.file());
|
||||
let line = get(db_error, |db| db.line().map(|l| l.to_string()));
|
||||
let routine = get(db_error, |db| db.routine());
|
||||
|
||||
error!(
|
||||
?code,
|
||||
"sql-over-http per-client task finished with an error: {e:#}"
|
||||
);
|
||||
// TODO: this shouldn't always be bad request.
|
||||
json_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({
|
||||
"message": message,
|
||||
"code": code,
|
||||
"detail": detail,
|
||||
"hint": hint,
|
||||
"position": position,
|
||||
"internalPosition": internal_position,
|
||||
"internalQuery": internal_query,
|
||||
"severity": severity,
|
||||
"where": where_,
|
||||
"table": table,
|
||||
"column": column,
|
||||
"schema": schema,
|
||||
"dataType": datatype,
|
||||
"constraint": constraint,
|
||||
"file": file,
|
||||
"line": line,
|
||||
"routine": routine,
|
||||
}),
|
||||
)?
|
||||
if let Some(db_error) = db_error {
|
||||
db_error.message().clone_into(&mut message);
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
|
||||
let position = db_error.and_then(|db| db.position());
|
||||
let (position, internal_position, internal_query) = match position {
|
||||
Some(ErrorPosition::Original(position)) => (
|
||||
Value::String(position.to_string()),
|
||||
Value::Null,
|
||||
Value::Null,
|
||||
),
|
||||
Some(ErrorPosition::Internal { position, query }) => (
|
||||
Value::Null,
|
||||
Value::String(position.to_string()),
|
||||
Value::String(query.clone()),
|
||||
),
|
||||
None => (Value::Null, Value::Null, Value::Null),
|
||||
};
|
||||
|
||||
let code = get(db_error, |db| db.code().code());
|
||||
let severity = get(db_error, |db| db.severity());
|
||||
let detail = get(db_error, |db| db.detail());
|
||||
let hint = get(db_error, |db| db.hint());
|
||||
let where_ = get(db_error, |db| db.where_());
|
||||
let table = get(db_error, |db| db.table());
|
||||
let column = get(db_error, |db| db.column());
|
||||
let schema = get(db_error, |db| db.schema());
|
||||
let datatype = get(db_error, |db| db.datatype());
|
||||
let constraint = get(db_error, |db| db.constraint());
|
||||
let file = get(db_error, |db| db.file());
|
||||
let line = get(db_error, |db| db.line().map(|l| l.to_string()));
|
||||
let routine = get(db_error, |db| db.routine());
|
||||
|
||||
error!(
|
||||
?code,
|
||||
"sql-over-http per-client task finished with an error: {e:#}"
|
||||
);
|
||||
// TODO: this shouldn't always be bad request.
|
||||
json_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({
|
||||
"message": message,
|
||||
"code": code,
|
||||
"detail": detail,
|
||||
"hint": hint,
|
||||
"position": position,
|
||||
"internalPosition": internal_position,
|
||||
"internalQuery": internal_query,
|
||||
"severity": severity,
|
||||
"where": where_,
|
||||
"table": table,
|
||||
"column": column,
|
||||
"schema": schema,
|
||||
"dataType": datatype,
|
||||
"constraint": constraint,
|
||||
"file": file,
|
||||
"line": line,
|
||||
"routine": routine,
|
||||
}),
|
||||
)?
|
||||
}
|
||||
Ok(Err(Cancelled())) => {
|
||||
// TODO: when http error classification is done, distinguish between
|
||||
// timeout on sql vs timeout in proxy/cplane
|
||||
// ctx.set_error_kind(crate::error::ErrorKind::RateLimit);
|
||||
|
||||
let message = format!(
|
||||
"HTTP-Connection timed out, execution time exceeded {} seconds",
|
||||
config.http_config.request_timeout.as_secs()
|
||||
"Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections",
|
||||
config.http_config.request_timeout.as_secs_f64()
|
||||
);
|
||||
error!(message);
|
||||
json_response(
|
||||
StatusCode::GATEWAY_TIMEOUT,
|
||||
json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
|
||||
)?
|
||||
}
|
||||
};
|
||||
@@ -307,12 +317,15 @@ pub async fn handle(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
struct Cancelled();
|
||||
|
||||
async fn handle_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> anyhow::Result<Response<Body>> {
|
||||
) -> Result<Result<Response<Body>, Cancelled>, anyhow::Error> {
|
||||
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
|
||||
.with_label_values(&[ctx.protocol])
|
||||
.guard();
|
||||
@@ -389,7 +402,18 @@ async fn handle_inner(
|
||||
};
|
||||
|
||||
// Run both operations in parallel
|
||||
let (payload, mut client) = try_join!(fetch_and_process_request, authenticate_and_connect)?;
|
||||
let (payload, mut client) = match select(
|
||||
try_join(
|
||||
pin!(fetch_and_process_request),
|
||||
pin!(authenticate_and_connect),
|
||||
),
|
||||
pin!(cancel.cancelled()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Either::Left((result, _cancelled)) => result?,
|
||||
Either::Right((_cancelled, _)) => return Ok(Err(Cancelled())),
|
||||
};
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
@@ -401,19 +425,60 @@ async fn handle_inner(
|
||||
let mut size = 0;
|
||||
let result = match payload {
|
||||
Payload::Single(stmt) => {
|
||||
let (status, results) =
|
||||
query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
client.discard();
|
||||
e
|
||||
})?;
|
||||
client.check_idle(status);
|
||||
results
|
||||
let mut size = 0;
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
let query = pin!(query_to_json(
|
||||
&*inner,
|
||||
stmt,
|
||||
&mut size,
|
||||
raw_output,
|
||||
default_array_mode
|
||||
));
|
||||
let cancelled = pin!(cancel.cancelled());
|
||||
let res = select(query, cancelled).await;
|
||||
match res {
|
||||
Either::Left((Ok((status, results)), _cancelled)) => {
|
||||
discard.check_idle(status);
|
||||
results
|
||||
}
|
||||
Either::Left((Err(e), _cancelled)) => {
|
||||
discard.discard();
|
||||
return Err(e);
|
||||
}
|
||||
Either::Right((_cancelled, query)) => {
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::error!(?err, "could not cancel query");
|
||||
}
|
||||
match time::timeout(time::Duration::from_millis(100), query).await {
|
||||
Ok(Ok((status, results))) => {
|
||||
discard.check_idle(status);
|
||||
results
|
||||
}
|
||||
Ok(Err(error)) => {
|
||||
let db_error = error
|
||||
.downcast_ref::<tokio_postgres::Error>()
|
||||
.and_then(|e| e.as_db_error());
|
||||
|
||||
// if errored for some other reason, it might not be safe to return
|
||||
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
|
||||
discard.discard();
|
||||
}
|
||||
|
||||
return Ok(Err(Cancelled()));
|
||||
}
|
||||
Err(_timeout) => {
|
||||
discard.discard();
|
||||
return Ok(Err(Cancelled()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Payload::Batch(statements) => {
|
||||
info!("starting transaction");
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
let mut builder = inner.build_transaction();
|
||||
if let Some(isolation_level) = txn_isolation_level {
|
||||
builder = builder.isolation_level(isolation_level);
|
||||
@@ -433,6 +498,7 @@ async fn handle_inner(
|
||||
})?;
|
||||
|
||||
let results = match query_batch(
|
||||
cancel.child_token(),
|
||||
&transaction,
|
||||
statements,
|
||||
&mut size,
|
||||
@@ -441,7 +507,7 @@ async fn handle_inner(
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(results) => {
|
||||
Ok(Ok(results)) => {
|
||||
info!("commit");
|
||||
let status = transaction.commit().await.map_err(|e| {
|
||||
// if we cannot commit - for now don't return connection to pool
|
||||
@@ -452,6 +518,15 @@ async fn handle_inner(
|
||||
discard.check_idle(status);
|
||||
results
|
||||
}
|
||||
Ok(Err(Cancelled())) => {
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::error!(?err, "could not cancel query");
|
||||
}
|
||||
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
|
||||
discard.discard();
|
||||
|
||||
return Ok(Err(Cancelled()));
|
||||
}
|
||||
Err(err) => {
|
||||
info!("rollback");
|
||||
let status = transaction.rollback().await.map_err(|e| {
|
||||
@@ -499,26 +574,44 @@ async fn handle_inner(
|
||||
// moving this later in the stack is going to be a lot of effort and ehhhh
|
||||
metrics.record_egress(len as u64);
|
||||
|
||||
Ok(response)
|
||||
Ok(Ok(response))
|
||||
}
|
||||
|
||||
async fn query_batch(
|
||||
cancel: CancellationToken,
|
||||
transaction: &Transaction<'_>,
|
||||
queries: BatchQueryData,
|
||||
total_size: &mut usize,
|
||||
raw_output: bool,
|
||||
array_mode: bool,
|
||||
) -> anyhow::Result<Vec<Value>> {
|
||||
) -> anyhow::Result<Result<Vec<Value>, Cancelled>> {
|
||||
let mut results = Vec::with_capacity(queries.queries.len());
|
||||
let mut current_size = 0;
|
||||
for stmt in queries.queries {
|
||||
// TODO: maybe we should check that the transaction bit is set here
|
||||
let (_, values) =
|
||||
query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
|
||||
results.push(values);
|
||||
let query = pin!(query_to_json(
|
||||
transaction,
|
||||
stmt,
|
||||
&mut current_size,
|
||||
raw_output,
|
||||
array_mode
|
||||
));
|
||||
let cancelled = pin!(cancel.cancelled());
|
||||
let res = select(query, cancelled).await;
|
||||
match res {
|
||||
// TODO: maybe we should check that the transaction bit is set here
|
||||
Either::Left((Ok((_, values)), _cancelled)) => {
|
||||
results.push(values);
|
||||
}
|
||||
Either::Left((Err(e), _cancelled)) => {
|
||||
return Err(e);
|
||||
}
|
||||
Either::Right((_cancelled, _)) => {
|
||||
return Ok(Err(Cancelled()));
|
||||
}
|
||||
}
|
||||
}
|
||||
*total_size += current_size;
|
||||
Ok(results)
|
||||
Ok(Ok(results))
|
||||
}
|
||||
|
||||
async fn query_to_json<T: GenericClient>(
|
||||
|
||||
@@ -2859,6 +2859,7 @@ class NeonProxy(PgProtocol):
|
||||
self.auth_backend = auth_backend
|
||||
self.metric_collection_endpoint = metric_collection_endpoint
|
||||
self.metric_collection_interval = metric_collection_interval
|
||||
self.http_timeout_seconds = 15
|
||||
self._popen: Optional[subprocess.Popen[bytes]] = None
|
||||
|
||||
def start(self) -> NeonProxy:
|
||||
@@ -2897,6 +2898,7 @@ class NeonProxy(PgProtocol):
|
||||
*["--proxy", f"{self.host}:{self.proxy_port}"],
|
||||
*["--mgmt", f"{self.host}:{self.mgmt_port}"],
|
||||
*["--wss", f"{self.host}:{self.external_http_port}"],
|
||||
*["--sql-over-http-timeout", f"{self.http_timeout_seconds}s"],
|
||||
*["-c", str(crt_path)],
|
||||
*["-k", str(key_path)],
|
||||
*self.auth_backend.extra_args(),
|
||||
@@ -2937,6 +2939,8 @@ class NeonProxy(PgProtocol):
|
||||
password = quote(kwargs["password"])
|
||||
expected_code = kwargs.get("expected_code")
|
||||
|
||||
log.info(f"Executing http query: {query}")
|
||||
|
||||
connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
|
||||
response = requests.post(
|
||||
f"https://{self.domain}:{self.external_http_port}/sql",
|
||||
@@ -2959,6 +2963,8 @@ class NeonProxy(PgProtocol):
|
||||
password = kwargs["password"]
|
||||
expected_code = kwargs.get("expected_code")
|
||||
|
||||
log.info(f"Executing http2 query: {query}")
|
||||
|
||||
connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
|
||||
async with httpx.AsyncClient(
|
||||
http2=True, verify=str(self.test_output_dir / "proxy.crt")
|
||||
|
||||
@@ -564,3 +564,35 @@ async def test_sql_over_http2(static_proxy: NeonProxy):
|
||||
"select 42 as answer", [], user="http", password="http", expected_code=200
|
||||
)
|
||||
assert resp["rows"] == [{"answer": 42}]
|
||||
|
||||
|
||||
def test_sql_over_http_timeout_cancel(static_proxy: NeonProxy):
|
||||
static_proxy.safe_psql("create role http with login password 'http' superuser")
|
||||
|
||||
static_proxy.safe_psql("create table test_table ( id int primary key )")
|
||||
|
||||
# insert into a table, with a unique constraint, after sleeping for n seconds
|
||||
query = "WITH temp AS ( \
|
||||
SELECT pg_sleep($1) as sleep, $2::int as id \
|
||||
) INSERT INTO test_table (id) SELECT id FROM temp"
|
||||
|
||||
# expect to fail with timeout
|
||||
res = static_proxy.http_query(
|
||||
query,
|
||||
[static_proxy.http_timeout_seconds + 1, 1],
|
||||
user="http",
|
||||
password="http",
|
||||
expected_code=400,
|
||||
)
|
||||
assert "Query cancelled, runtime exceeded" in res["message"], "HTTP query should time out"
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
res = static_proxy.http_query(query, [1, 1], user="http", password="http", expected_code=200)
|
||||
assert res["command"] == "INSERT", "HTTP query should insert"
|
||||
assert res["rowCount"] == 1, "HTTP query should insert"
|
||||
|
||||
res = static_proxy.http_query(query, [0, 1], user="http", password="http", expected_code=400)
|
||||
assert (
|
||||
"duplicate key value violates unique constraint" in res["message"]
|
||||
), "HTTP query should conflict"
|
||||
|
||||
Reference in New Issue
Block a user