diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 67a68c7dea..10fb853175 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -238,6 +238,7 @@ impl Client { rx } + /// Wait until this connection has no more active queries. pub async fn wait_until_ready(&mut self) -> Result { self.inner_mut().responses.wait_until_ready().await } @@ -329,11 +330,6 @@ impl Client { Ok(()) } - /// Commit the transaction. - pub async fn commit(&mut self) -> Result { - self.batch_execute("COMMIT").await - } - /// Constructs a cancellation token that can later be used to request cancellation of a query running on the /// connection associated with this client. pub fn cancel_token(&self) -> CancelToken { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 89f17ebdd4..66ee26b781 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,9 +1,8 @@ -use std::pin::pin; use std::sync::Arc; use bytes::Bytes; use futures::future::try_join; -use futures::{FutureExt, TryFutureExt, TryStreamExt}; +use futures::{TryFutureExt, TryStreamExt}; use http::Method; use http::header::AUTHORIZATION; use http_body_util::combinators::BoxBody; @@ -530,13 +529,13 @@ async fn handle_db_inner( let (cli_inner, _dsc) = client.client_inner(); cli_inner.set_jwt_session(&payload).await?; } - Client::Local(client) + Box::new(Client::Local(client)) } _ => { let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) .await?; - Client::Remote(client) + Box::new(Client::Remote(client)) } }; @@ -550,10 +549,7 @@ async fn handle_db_inner( let (payload, mut client) = match run_until_cancelled( // Run both operations in parallel - try_join( - pin!(fetch_and_process_request), - pin!(authenticate_and_connect), - ), + try_join(fetch_and_process_request, authenticate_and_connect), &cancel, ) .await @@ -562,6 +558,20 @@ async fn handle_db_inner( None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)), }; + // Now execute the query and return the result. + let json_output = match payload + .process(&config.http_config, cancel, &mut client, parsed_headers) + .await + { + Ok(json_output) => json_output, + Err(error) => { + if let SqlOverHttpError::Cancelled(_) = error { + cancel_query(&mut client).await; + } + return Err(error); + } + }; + let mut response = Response::builder() .status(StatusCode::OK) .header(header::CONTENT_TYPE, "application/json"); @@ -581,22 +591,6 @@ async fn handle_db_inner( } } - // Now execute the query and return the result. - let json_output = match payload - .process(&config.http_config, cancel, &mut client, parsed_headers) - .await - { - Ok(json_output) => json_output, - Err(error) => { - if let SqlOverHttpError::Cancelled(_) = error { - cancel_query(&mut client).await; - } - return Err(error); - } - }; - - let metrics = client.metrics(ctx); - let len = json_output.len(); let response = response .body( @@ -610,6 +604,7 @@ async fn handle_db_inner( // count the egress bytes - we miss the TLS and header overhead but oh well... // moving this later in the stack is going to be a lot of effort and ehhhh + let metrics = client.metrics(ctx); metrics.record_egress(len as u64); metrics.record_ingress(request_len as u64); @@ -678,7 +673,7 @@ async fn handle_auth_broker_inner( impl Payload { async fn process( - self, + &self, config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, @@ -708,7 +703,7 @@ impl Payload { .map_err(SqlOverHttpError::Postgres)?; } - let json_output = json::value_to_string!(|value| match self { + let json_output = json::value_to_string!(|value| match &self { Payload::Single(query) => { query_to_json(config, &cancel, inner, query, value, parsed_headers).await?; } @@ -716,7 +711,7 @@ impl Payload { let mut obj = value.object(); let mut results = obj.key("results").list(); - for query in batch.queries { + for query in &batch.queries { let value = results.entry(); query_to_json(config, &cancel, inner, query, value, parsed_headers).await?; } @@ -728,7 +723,7 @@ impl Payload { if needs_tx { inner - .commit() + .batch_execute("COMMIT") .await .inspect_err(|_| { // if we cannot commit - for now don't return connection to pool @@ -777,22 +772,22 @@ async fn query_to_json( config: &'static HttpConfig, cancel: &CancellationToken, client: &mut postgres_client::Client, - data: QueryData, + data: &QueryData, output: json::ValueSer<'_>, parsed_headers: HttpHeaders, ) -> Result { let query_start = Instant::now(); - let mut output = json::ObjectSer::new(output); - - let mut row_stream = - run_until_cancelled(client.query_raw_txt(&data.query, data.params), cancel) - .await - .ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))? - .map_err(SqlOverHttpError::Postgres)?; + let params = data.params.iter().map(Option::as_deref); + let mut row_stream = run_until_cancelled(client.query_raw_txt(&data.query, params), cancel) + .await + .ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))? + .map_err(SqlOverHttpError::Postgres)?; let query_acknowledged = Instant::now(); + let mut output = json::ObjectSer::new(output); + let mut json_fields = output.key("fields").list(); for c in row_stream.statement.columns() { let json_field = json_fields.entry();