From 09699d4bd883d9e1753dcff22406d8455b5f133e Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 12 Mar 2024 11:52:00 +0000 Subject: [PATCH] proxy: cancel http queries on timeout (#7031) ## Problem On HTTP query timeout, we should try and cancel the current in-flight SQL query. ## Summary of changes Trigger a cancellation command in postgres once the timeout is reach --- proxy/src/serverless/conn_pool.rs | 9 +- proxy/src/serverless/sql_over_http.rs | 313 +++++++++++++++++--------- test_runner/fixtures/neon_fixtures.py | 6 + test_runner/regress/test_proxy.py | 32 +++ 4 files changed, 242 insertions(+), 118 deletions(-) diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 73f213d074..901e30224b 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -612,13 +612,6 @@ impl Client { let inner = inner.as_mut().expect("client inner should not be removed"); (&mut inner.inner, Discard { pool, conn_info }) } - - pub fn check_idle(&mut self, status: ReadyForQueryStatus) { - self.inner().1.check_idle(status) - } - pub fn discard(&mut self) { - self.inner().1.discard() - } } impl Discard<'_, C> { @@ -739,7 +732,7 @@ mod tests { { let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); assert_eq!(0, pool.get_global_connections_count()); - client.discard(); + client.inner().1.discard(); // Discard should not add the connection from the pool. assert_eq!(0, pool.get_global_connections_count()); } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 74af985211..20d9795b47 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,6 +1,10 @@ +use std::pin::pin; use std::sync::Arc; use anyhow::bail; +use futures::future::select; +use futures::future::try_join; +use futures::future::Either; use futures::StreamExt; use hyper::body::HttpBody; use hyper::header; @@ -11,13 +15,16 @@ use hyper::StatusCode; use hyper::{Body, HeaderMap, Request}; use serde_json::json; use serde_json::Value; -use tokio::try_join; +use tokio::time; use tokio_postgres::error::DbError; use tokio_postgres::error::ErrorPosition; +use tokio_postgres::error::SqlState; use tokio_postgres::GenericClient; use tokio_postgres::IsolationLevel; +use tokio_postgres::NoTls; use tokio_postgres::ReadyForQueryStatus; use tokio_postgres::Transaction; +use tokio_util::sync::CancellationToken; use tracing::error; use tracing::info; use url::Url; @@ -194,108 +201,111 @@ pub async fn handle( request: Request, backend: Arc, ) -> Result, ApiError> { - let result = tokio::time::timeout( - config.http_config.request_timeout, - handle_inner(config, &mut ctx, request, backend), - ) - .await; + let cancel = CancellationToken::new(); + let cancel2 = cancel.clone(); + let handle = tokio::spawn(async move { + time::sleep(config.http_config.request_timeout).await; + cancel2.cancel(); + }); + + let result = handle_inner(cancel, config, &mut ctx, request, backend).await; + handle.abort(); + let mut response = match result { - Ok(r) => match r { - Ok(r) => { - ctx.set_success(); - r + Ok(Ok(r)) => { + ctx.set_success(); + r + } + Err(e) => { + // TODO: ctx.set_error_kind(e.get_error_type()); + + let mut message = format!("{:?}", e); + let db_error = e + .downcast_ref::() + .and_then(|e| e.as_db_error()); + fn get<'a, T: serde::Serialize>( + db: Option<&'a DbError>, + x: impl FnOnce(&'a DbError) -> T, + ) -> Value { + db.map(x) + .and_then(|t| serde_json::to_value(t).ok()) + .unwrap_or_default() } - Err(e) => { - // TODO: ctx.set_error_kind(e.get_error_type()); - let mut message = format!("{:?}", e); - let db_error = e - .downcast_ref::() - .and_then(|e| e.as_db_error()); - fn get<'a, T: serde::Serialize>( - db: Option<&'a DbError>, - x: impl FnOnce(&'a DbError) -> T, - ) -> Value { - db.map(x) - .and_then(|t| serde_json::to_value(t).ok()) - .unwrap_or_default() - } - - if let Some(db_error) = db_error { - db_error.message().clone_into(&mut message); - } - - let position = db_error.and_then(|db| db.position()); - let (position, internal_position, internal_query) = match position { - Some(ErrorPosition::Original(position)) => ( - Value::String(position.to_string()), - Value::Null, - Value::Null, - ), - Some(ErrorPosition::Internal { position, query }) => ( - Value::Null, - Value::String(position.to_string()), - Value::String(query.clone()), - ), - None => (Value::Null, Value::Null, Value::Null), - }; - - let code = get(db_error, |db| db.code().code()); - let severity = get(db_error, |db| db.severity()); - let detail = get(db_error, |db| db.detail()); - let hint = get(db_error, |db| db.hint()); - let where_ = get(db_error, |db| db.where_()); - let table = get(db_error, |db| db.table()); - let column = get(db_error, |db| db.column()); - let schema = get(db_error, |db| db.schema()); - let datatype = get(db_error, |db| db.datatype()); - let constraint = get(db_error, |db| db.constraint()); - let file = get(db_error, |db| db.file()); - let line = get(db_error, |db| db.line().map(|l| l.to_string())); - let routine = get(db_error, |db| db.routine()); - - error!( - ?code, - "sql-over-http per-client task finished with an error: {e:#}" - ); - // TODO: this shouldn't always be bad request. - json_response( - StatusCode::BAD_REQUEST, - json!({ - "message": message, - "code": code, - "detail": detail, - "hint": hint, - "position": position, - "internalPosition": internal_position, - "internalQuery": internal_query, - "severity": severity, - "where": where_, - "table": table, - "column": column, - "schema": schema, - "dataType": datatype, - "constraint": constraint, - "file": file, - "line": line, - "routine": routine, - }), - )? + if let Some(db_error) = db_error { + db_error.message().clone_into(&mut message); } - }, - Err(_) => { + + let position = db_error.and_then(|db| db.position()); + let (position, internal_position, internal_query) = match position { + Some(ErrorPosition::Original(position)) => ( + Value::String(position.to_string()), + Value::Null, + Value::Null, + ), + Some(ErrorPosition::Internal { position, query }) => ( + Value::Null, + Value::String(position.to_string()), + Value::String(query.clone()), + ), + None => (Value::Null, Value::Null, Value::Null), + }; + + let code = get(db_error, |db| db.code().code()); + let severity = get(db_error, |db| db.severity()); + let detail = get(db_error, |db| db.detail()); + let hint = get(db_error, |db| db.hint()); + let where_ = get(db_error, |db| db.where_()); + let table = get(db_error, |db| db.table()); + let column = get(db_error, |db| db.column()); + let schema = get(db_error, |db| db.schema()); + let datatype = get(db_error, |db| db.datatype()); + let constraint = get(db_error, |db| db.constraint()); + let file = get(db_error, |db| db.file()); + let line = get(db_error, |db| db.line().map(|l| l.to_string())); + let routine = get(db_error, |db| db.routine()); + + error!( + ?code, + "sql-over-http per-client task finished with an error: {e:#}" + ); + // TODO: this shouldn't always be bad request. + json_response( + StatusCode::BAD_REQUEST, + json!({ + "message": message, + "code": code, + "detail": detail, + "hint": hint, + "position": position, + "internalPosition": internal_position, + "internalQuery": internal_query, + "severity": severity, + "where": where_, + "table": table, + "column": column, + "schema": schema, + "dataType": datatype, + "constraint": constraint, + "file": file, + "line": line, + "routine": routine, + }), + )? + } + Ok(Err(Cancelled())) => { // TODO: when http error classification is done, distinguish between // timeout on sql vs timeout in proxy/cplane // ctx.set_error_kind(crate::error::ErrorKind::RateLimit); let message = format!( - "HTTP-Connection timed out, execution time exceeded {} seconds", - config.http_config.request_timeout.as_secs() + "Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections", + config.http_config.request_timeout.as_secs_f64() ); error!(message); json_response( - StatusCode::GATEWAY_TIMEOUT, - json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }), + StatusCode::BAD_REQUEST, + json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }), )? } }; @@ -307,12 +317,15 @@ pub async fn handle( Ok(response) } +struct Cancelled(); + async fn handle_inner( + cancel: CancellationToken, config: &'static ProxyConfig, ctx: &mut RequestMonitoring, request: Request, backend: Arc, -) -> anyhow::Result> { +) -> Result, Cancelled>, anyhow::Error> { let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE .with_label_values(&[ctx.protocol]) .guard(); @@ -389,7 +402,18 @@ async fn handle_inner( }; // Run both operations in parallel - let (payload, mut client) = try_join!(fetch_and_process_request, authenticate_and_connect)?; + let (payload, mut client) = match select( + try_join( + pin!(fetch_and_process_request), + pin!(authenticate_and_connect), + ), + pin!(cancel.cancelled()), + ) + .await + { + Either::Left((result, _cancelled)) => result?, + Either::Right((_cancelled, _)) => return Ok(Err(Cancelled())), + }; let mut response = Response::builder() .status(StatusCode::OK) @@ -401,19 +425,60 @@ async fn handle_inner( let mut size = 0; let result = match payload { Payload::Single(stmt) => { - let (status, results) = - query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode) - .await - .map_err(|e| { - client.discard(); - e - })?; - client.check_idle(status); - results + let mut size = 0; + let (inner, mut discard) = client.inner(); + let cancel_token = inner.cancel_token(); + let query = pin!(query_to_json( + &*inner, + stmt, + &mut size, + raw_output, + default_array_mode + )); + let cancelled = pin!(cancel.cancelled()); + let res = select(query, cancelled).await; + match res { + Either::Left((Ok((status, results)), _cancelled)) => { + discard.check_idle(status); + results + } + Either::Left((Err(e), _cancelled)) => { + discard.discard(); + return Err(e); + } + Either::Right((_cancelled, query)) => { + if let Err(err) = cancel_token.cancel_query(NoTls).await { + tracing::error!(?err, "could not cancel query"); + } + match time::timeout(time::Duration::from_millis(100), query).await { + Ok(Ok((status, results))) => { + discard.check_idle(status); + results + } + Ok(Err(error)) => { + let db_error = error + .downcast_ref::() + .and_then(|e| e.as_db_error()); + + // if errored for some other reason, it might not be safe to return + if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) { + discard.discard(); + } + + return Ok(Err(Cancelled())); + } + Err(_timeout) => { + discard.discard(); + return Ok(Err(Cancelled())); + } + } + } + } } Payload::Batch(statements) => { info!("starting transaction"); let (inner, mut discard) = client.inner(); + let cancel_token = inner.cancel_token(); let mut builder = inner.build_transaction(); if let Some(isolation_level) = txn_isolation_level { builder = builder.isolation_level(isolation_level); @@ -433,6 +498,7 @@ async fn handle_inner( })?; let results = match query_batch( + cancel.child_token(), &transaction, statements, &mut size, @@ -441,7 +507,7 @@ async fn handle_inner( ) .await { - Ok(results) => { + Ok(Ok(results)) => { info!("commit"); let status = transaction.commit().await.map_err(|e| { // if we cannot commit - for now don't return connection to pool @@ -452,6 +518,15 @@ async fn handle_inner( discard.check_idle(status); results } + Ok(Err(Cancelled())) => { + if let Err(err) = cancel_token.cancel_query(NoTls).await { + tracing::error!(?err, "could not cancel query"); + } + // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe. + discard.discard(); + + return Ok(Err(Cancelled())); + } Err(err) => { info!("rollback"); let status = transaction.rollback().await.map_err(|e| { @@ -499,26 +574,44 @@ async fn handle_inner( // moving this later in the stack is going to be a lot of effort and ehhhh metrics.record_egress(len as u64); - Ok(response) + Ok(Ok(response)) } async fn query_batch( + cancel: CancellationToken, transaction: &Transaction<'_>, queries: BatchQueryData, total_size: &mut usize, raw_output: bool, array_mode: bool, -) -> anyhow::Result> { +) -> anyhow::Result, Cancelled>> { let mut results = Vec::with_capacity(queries.queries.len()); let mut current_size = 0; for stmt in queries.queries { - // TODO: maybe we should check that the transaction bit is set here - let (_, values) = - query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?; - results.push(values); + let query = pin!(query_to_json( + transaction, + stmt, + &mut current_size, + raw_output, + array_mode + )); + let cancelled = pin!(cancel.cancelled()); + let res = select(query, cancelled).await; + match res { + // TODO: maybe we should check that the transaction bit is set here + Either::Left((Ok((_, values)), _cancelled)) => { + results.push(values); + } + Either::Left((Err(e), _cancelled)) => { + return Err(e); + } + Either::Right((_cancelled, _)) => { + return Ok(Err(Cancelled())); + } + } } *total_size += current_size; - Ok(results) + Ok(Ok(results)) } async fn query_to_json( diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 234bfa8bf9..b7196a2556 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2859,6 +2859,7 @@ class NeonProxy(PgProtocol): self.auth_backend = auth_backend self.metric_collection_endpoint = metric_collection_endpoint self.metric_collection_interval = metric_collection_interval + self.http_timeout_seconds = 15 self._popen: Optional[subprocess.Popen[bytes]] = None def start(self) -> NeonProxy: @@ -2897,6 +2898,7 @@ class NeonProxy(PgProtocol): *["--proxy", f"{self.host}:{self.proxy_port}"], *["--mgmt", f"{self.host}:{self.mgmt_port}"], *["--wss", f"{self.host}:{self.external_http_port}"], + *["--sql-over-http-timeout", f"{self.http_timeout_seconds}s"], *["-c", str(crt_path)], *["-k", str(key_path)], *self.auth_backend.extra_args(), @@ -2937,6 +2939,8 @@ class NeonProxy(PgProtocol): password = quote(kwargs["password"]) expected_code = kwargs.get("expected_code") + log.info(f"Executing http query: {query}") + connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres" response = requests.post( f"https://{self.domain}:{self.external_http_port}/sql", @@ -2959,6 +2963,8 @@ class NeonProxy(PgProtocol): password = kwargs["password"] expected_code = kwargs.get("expected_code") + log.info(f"Executing http2 query: {query}") + connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres" async with httpx.AsyncClient( http2=True, verify=str(self.test_output_dir / "proxy.crt") diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 9905f120e1..078589d8eb 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -564,3 +564,35 @@ async def test_sql_over_http2(static_proxy: NeonProxy): "select 42 as answer", [], user="http", password="http", expected_code=200 ) assert resp["rows"] == [{"answer": 42}] + + +def test_sql_over_http_timeout_cancel(static_proxy: NeonProxy): + static_proxy.safe_psql("create role http with login password 'http' superuser") + + static_proxy.safe_psql("create table test_table ( id int primary key )") + + # insert into a table, with a unique constraint, after sleeping for n seconds + query = "WITH temp AS ( \ + SELECT pg_sleep($1) as sleep, $2::int as id \ + ) INSERT INTO test_table (id) SELECT id FROM temp" + + # expect to fail with timeout + res = static_proxy.http_query( + query, + [static_proxy.http_timeout_seconds + 1, 1], + user="http", + password="http", + expected_code=400, + ) + assert "Query cancelled, runtime exceeded" in res["message"], "HTTP query should time out" + + time.sleep(2) + + res = static_proxy.http_query(query, [1, 1], user="http", password="http", expected_code=200) + assert res["command"] == "INSERT", "HTTP query should insert" + assert res["rowCount"] == 1, "HTTP query should insert" + + res = static_proxy.http_query(query, [0, 1], user="http", password="http", expected_code=400) + assert ( + "duplicate key value violates unique constraint" in res["message"] + ), "HTTP query should conflict"