diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs index b1728ef37d..5574599381 100644 --- a/libs/proxy/postgres-protocol2/src/message/backend.rs +++ b/libs/proxy/postgres-protocol2/src/message/backend.rs @@ -600,6 +600,7 @@ impl ParameterStatusBody { } } +#[derive(Clone, Copy)] pub struct ReadyForQueryBody { status: u8, } diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 6f50e1a610..67a68c7dea 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -29,6 +29,9 @@ pub struct Responses { waiting: usize, /// number of ReadyForQuery messages received. received: usize, + + /// The last query status we received. + last_status: ReadyForQueryStatus, } impl Responses { @@ -39,7 +42,8 @@ impl Responses { let received = self.received; // increase the query head if this is the last message. - if let Message::ReadyForQuery(_) = message { + if let Message::ReadyForQuery(ref status) = message { + self.last_status = (*status).into(); self.received += 1; } @@ -68,6 +72,15 @@ impl Responses { pub async fn next(&mut self) -> Result { future::poll_fn(|cx| self.poll_next(cx)).await } + + pub async fn wait_until_ready(&mut self) -> Result { + while self.received < self.waiting { + if let Message::ReadyForQuery(status) = self.next().await? { + return Ok(status.into()); + } + } + Ok(self.last_status) + } } /// A cache of type info and prepared statements for fetching type info @@ -92,13 +105,6 @@ impl InnerClient { Ok(PartialQuery(Some(self))) } - // pub fn send_with_sync(&mut self, f: F) -> Result<&mut Responses, Error> - // where - // F: FnOnce(&mut BytesMut) -> Result<(), Error>, - // { - // self.start()?.send_with_sync(f) - // } - pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> { self.responses.waiting += 1; @@ -197,6 +203,8 @@ impl Client { cur: BackendMessages::empty(), waiting: 0, received: 0, + // new connections are always idle. + last_status: ReadyForQueryStatus::Idle, }, buffer: Default::default(), }, @@ -230,6 +238,10 @@ impl Client { rx } + pub async fn wait_until_ready(&mut self) -> Result { + self.inner_mut().responses.wait_until_ready().await + } + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip pub async fn query_raw_txt( diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index bc68b5e677..89f17ebdd4 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -2,8 +2,8 @@ use std::pin::pin; use std::sync::Arc; use bytes::Bytes; -use futures::future::{Either, select, try_join}; -use futures::{StreamExt, TryFutureExt}; +use futures::future::try_join; +use futures::{FutureExt, TryFutureExt, TryStreamExt}; use http::Method; use http::header::AUTHORIZATION; use http_body_util::combinators::BoxBody; @@ -495,7 +495,7 @@ async fn handle_db_inner( .http_conn_content_length_bytes .observe(HttpDirection::Request, body.len() as f64); - debug!(length = body.len(), "request payload read"); + debug!(length = body.len(), "request payload read "); let payload: Payload = serde_json::from_slice(&body)?; Ok::(payload) // Adjust error type accordingly } @@ -566,29 +566,32 @@ async fn handle_db_inner( .status(StatusCode::OK) .header(header::CONTENT_TYPE, "application/json"); - // Now execute the query and return the result. - let json_output = match payload { - Payload::Single(stmt) => { - stmt.process(&config.http_config, cancel, &mut client, parsed_headers) - .await? + if let Payload::Batch(_) = payload { + if parsed_headers.txn_read_only { + response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE); } - Payload::Batch(statements) => { - if parsed_headers.txn_read_only { - response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE); - } - if parsed_headers.txn_deferrable { - response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE); - } - if let Some(txn_isolation_level) = parsed_headers - .txn_isolation_level - .and_then(map_isolation_level_to_headers) - { - response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); - } + if parsed_headers.txn_deferrable { + response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE); + } + if let Some(txn_isolation_level) = parsed_headers + .txn_isolation_level + .and_then(map_isolation_level_to_headers) + { + response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); + } + } - statements - .process(&config.http_config, cancel, &mut client, parsed_headers) - .await? + // Now execute the query and return the result. + let json_output = match payload + .process(&config.http_config, cancel, &mut client, parsed_headers) + .await + { + Ok(json_output) => json_output, + Err(error) => { + if let SqlOverHttpError::Cancelled(_) = error { + cancel_query(&mut client).await; + } + return Err(error); } }; @@ -673,7 +676,7 @@ async fn handle_auth_broker_inner( .map(|b| b.boxed())) } -impl QueryData { +impl Payload { async fn process( self, config: &'static HttpConfig, @@ -682,85 +685,11 @@ impl QueryData { parsed_headers: HttpHeaders, ) -> Result { let (inner, mut discard) = client.inner(); - let cancel_token = inner.cancel_token(); - let mut json_buf = vec![]; + let needs_tx = matches!(self, Payload::Batch(_)); - let batch_result = match select( - pin!(query_to_json( - config, - &mut *inner, - self, - json::ValueSer::new(&mut json_buf), - parsed_headers - )), - pin!(cancel.cancelled()), - ) - .await - { - Either::Left((res, __not_yet_cancelled)) => res, - Either::Right((_cancelled, query)) => { - tracing::info!("cancelling query"); - if let Err(err) = cancel_token.cancel_query(NoTls).await { - tracing::warn!(?err, "could not cancel query"); - } - // wait for the query cancellation - match time::timeout(time::Duration::from_millis(100), query).await { - // query successed before it was cancelled. - Ok(Ok(status)) => Ok(status), - // query failed or was cancelled. - Ok(Err(error)) => { - let db_error = match &error { - SqlOverHttpError::ConnectCompute( - HttpConnError::PostgresConnectionError(e), - ) - | SqlOverHttpError::Postgres(e) => e.as_db_error(), - _ => None, - }; - - // if errored for some other reason, it might not be safe to return - if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) { - discard.discard(); - } - - return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); - } - Err(_timeout) => { - discard.discard(); - return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); - } - } - } - }; - - match batch_result { - // The query successfully completed. - Ok(_) => { - let json_output = String::from_utf8(json_buf).expect("json should be valid utf8"); - Ok(json_output) - } - // The query failed with an error - Err(e) => { - discard.discard(); - Err(e) - } - } - } -} - -impl BatchQueryData { - async fn process( - self, - config: &'static HttpConfig, - cancel: CancellationToken, - client: &mut Client, - parsed_headers: HttpHeaders, - ) -> Result { - info!("starting transaction"); - let (inner, mut discard) = client.inner(); - let cancel_token = inner.cancel_token(); - - { + if needs_tx { + info!("starting transaction"); let query = TransactionBuilder { isolation_level: parsed_headers.txn_isolation_level, read_only: parsed_headers.txn_read_only.then_some(true), @@ -779,93 +708,74 @@ impl BatchQueryData { .map_err(SqlOverHttpError::Postgres)?; } - let res = - query_batch_to_json(config, cancel.child_token(), inner, self, parsed_headers).await; - - let json_output = match res { - Ok(json_output) => { - info!("commit"); - inner - .commit() - .await - .inspect_err(|_| { - // if we cannot commit - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - }) - .map_err(SqlOverHttpError::Postgres)?; - json_output + let json_output = json::value_to_string!(|value| match self { + Payload::Single(query) => { + query_to_json(config, &cancel, inner, query, value, parsed_headers).await?; } - Err(SqlOverHttpError::Cancelled(_)) => { - if let Err(err) = cancel_token.cancel_query(NoTls).await { - tracing::warn!(?err, "could not cancel query"); + Payload::Batch(batch) => { + let mut obj = value.object(); + let mut results = obj.key("results").list(); + + for query in batch.queries { + let value = results.entry(); + query_to_json(config, &cancel, inner, query, value, parsed_headers).await?; } - // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe. - discard.discard(); - return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); + results.finish(); + obj.finish(); } - Err(err) => { - return Err(err); - } - }; + }); + + if needs_tx { + inner + .commit() + .await + .inspect_err(|_| { + // if we cannot commit - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + }) + .map_err(SqlOverHttpError::Postgres)?; + } Ok(json_output) } } -async fn query_batch( - config: &'static HttpConfig, - cancel: CancellationToken, - client: &mut postgres_client::Client, - queries: BatchQueryData, - parsed_headers: HttpHeaders, - results: &mut json::ListSer<'_>, -) -> Result<(), SqlOverHttpError> { - for stmt in queries.queries { - let query = pin!(query_to_json( - config, - client, - stmt, - results.entry(), - parsed_headers, - )); - let cancelled = pin!(cancel.cancelled()); - let res = select(query, cancelled).await; - match res { - // TODO: maybe we should check that the transaction bit is set here - Either::Left((Ok(_), _cancelled)) => {} - Either::Left((Err(e), _cancelled)) => { - return Err(e); - } - Either::Right((_cancelled, _)) => { - return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); - } - } +async fn cancel_query(client: &mut Client) { + let (inner, mut discard) = client.inner(); + let cancel_token = inner.cancel_token(); + + if let Err(err) = cancel_token.cancel_query(NoTls).await { + tracing::warn!(?err, "could not cancel query"); + + // couldn't reach the server. let's just throw away this conn + discard.discard(); + return; } - Ok(()) -} + // wait for the query cancellation + match time::timeout(time::Duration::from_millis(100), inner.wait_until_ready()).await { + // we managed to cancel the query. + Ok(Ok(_)) => {} + // query failed or was cancelled. + Ok(Err(error)) => { + let db_error = error.as_db_error(); -async fn query_batch_to_json( - config: &'static HttpConfig, - cancel: CancellationToken, - client: &mut postgres_client::Client, - queries: BatchQueryData, - headers: HttpHeaders, -) -> Result { - let json_output = json::value_to_string!(|obj| json::value_as_object!(|obj| { - let results = obj.key("results"); - json::value_as_list!(|results| { - query_batch(config, cancel, client, queries, headers, results).await?; - }); - })); - - Ok(json_output) + // if errored for some other reason, it might not be safe to reuse the connection. + if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) { + discard.discard(); + } + } + Err(_timeout) => { + discard.discard(); + } + } } async fn query_to_json( config: &'static HttpConfig, + cancel: &CancellationToken, client: &mut postgres_client::Client, data: QueryData, output: json::ValueSer<'_>, @@ -874,10 +784,13 @@ async fn query_to_json( let query_start = Instant::now(); let mut output = json::ObjectSer::new(output); - let mut row_stream = client - .query_raw_txt(&data.query, data.params) - .await - .map_err(SqlOverHttpError::Postgres)?; + + let mut row_stream = + run_until_cancelled(client.query_raw_txt(&data.query, data.params), cancel) + .await + .ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))? + .map_err(SqlOverHttpError::Postgres)?; + let query_acknowledged = Instant::now(); let mut json_fields = output.key("fields").list(); @@ -903,8 +816,13 @@ async fn query_to_json( // big. let mut rows = 0; let mut json_rows = output.key("rows").list(); - while let Some(row) = row_stream.next().await { - let row = row.map_err(SqlOverHttpError::Postgres)?; + loop { + let row = run_until_cancelled(row_stream.try_next(), cancel) + .await + .ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))? + .map_err(SqlOverHttpError::Postgres)?; + + let Some(row) = row else { break }; // we don't have a streaming response support yet so this is to prevent OOM // from a malicious query (eg a cross join) diff --git a/proxy/src/util.rs b/proxy/src/util.rs index c89ebab008..aaf9ece8fa 100644 --- a/proxy/src/util.rs +++ b/proxy/src/util.rs @@ -1,23 +1,50 @@ -use std::pin::pin; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; -use futures::future::{Either, select}; +use futures::FutureExt; use tokio_util::sync::CancellationToken; -pub async fn run_until_cancelled( +pub fn run_until_cancelled( f: F, cancellation_token: &CancellationToken, -) -> Option { - run_until(f, cancellation_token.cancelled()).await.ok() +) -> impl Future> { + run_until(f, cancellation_token.cancelled()).map(|r| r.ok()) } /// Runs the future `f` unless interrupted by future `condition`. -pub async fn run_until( +pub fn run_until( f: F1, condition: F2, -) -> Result { - match select(pin!(f), pin!(condition)).await { - Either::Left((f1, _)) => Ok(f1), - Either::Right((f2, _)) => Err(f2), +) -> impl Future> { + RunUntil { a: f, b: condition } +} + +pin_project_lite::pin_project! { + struct RunUntil { + #[pin] a: A, + #[pin] b: B, + } +} + +impl Future for RunUntil +where + A: Future, + B: Future, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Poll::Ready(a) = this.a.poll(cx) { + return Poll::Ready(Ok(a)); + } + if let Poll::Ready(b) = this.b.poll(cx) { + return Poll::Ready(Err(b)); + } + Poll::Pending } }