use std::pin::pin; use std::sync::Arc; use bytes::Bytes; use futures::future::{select, try_join, Either}; use futures::{StreamExt, TryFutureExt}; use http::header::AUTHORIZATION; use http::Method; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full}; use hyper::body::{Body, Incoming}; use hyper::http::{HeaderName, HeaderValue}; use hyper::{header, HeaderMap, Request, Response, StatusCode}; use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; use tokio::time; use tokio_postgres::error::{DbError, ErrorPosition, SqlState}; use tokio_postgres::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction}; use tokio_util::sync::CancellationToken; use tracing::{error, info}; use typed_json::json; use url::Url; use urlencoding; use utils::http::error::ApiError; use super::backend::{LocalProxyConnError, PoolingBackend}; use super::conn_pool::{AuthData, ConnInfo, ConnInfoWithAuth}; use super::http_util::json_response; use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError}; use super::{conn_pool, local_conn_pool}; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::{endpoint_sni, ComputeUserInfoParseError}; use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig}; use crate::context::RequestMonitoring; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::{HttpDirection, Metrics}; use crate::proxy::{run_until_cancelled, NeonOptions}; use crate::serverless::backend::HttpConnError; use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; use crate::{DbName, RoleName}; #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] struct QueryData { query: String, #[serde(deserialize_with = "bytes_to_pg_text")] params: Vec>, #[serde(default)] array_mode: Option, } #[derive(serde::Deserialize)] struct BatchQueryData { queries: Vec, } #[derive(serde::Deserialize)] #[serde(untagged)] enum Payload { Single(QueryData), Batch(BatchQueryData), } static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string"); static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level"); static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only"); static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable"); static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> where D: serde::de::Deserializer<'de>, { // TODO: consider avoiding the allocation here. let json: Vec = serde::de::Deserialize::deserialize(deserializer)?; Ok(json_to_pg_text(json)) } #[derive(Debug, thiserror::Error)] pub(crate) enum ConnInfoError { #[error("invalid header: {0}")] InvalidHeader(&'static HeaderName), #[error("invalid connection string: {0}")] UrlParseError(#[from] url::ParseError), #[error("incorrect scheme")] IncorrectScheme, #[error("missing database name")] MissingDbName, #[error("invalid database name")] InvalidDbName, #[error("missing username")] MissingUsername, #[error("invalid username: {0}")] InvalidUsername(#[from] std::string::FromUtf8Error), #[error("missing authentication credentials: {0}")] MissingCredentials(Credentials), #[error("missing hostname")] MissingHostname, #[error("invalid hostname: {0}")] InvalidEndpoint(#[from] ComputeUserInfoParseError), #[error("malformed endpoint")] MalformedEndpoint, } #[derive(Debug, thiserror::Error)] pub(crate) enum Credentials { #[error("required password")] Password, #[error("required authorization bearer token in JWT format")] BearerJwt, } impl ReportableError for ConnInfoError { fn get_error_kind(&self) -> ErrorKind { ErrorKind::User } } impl UserFacingError for ConnInfoError { fn to_string_client(&self) -> String { self.to_string() } } fn get_conn_info( config: &'static AuthenticationConfig, ctx: &RequestMonitoring, headers: &HeaderMap, tls: Option<&TlsConfig>, ) -> Result { // HTTP only uses cleartext (for now and likely always) ctx.set_auth_method(crate::context::AuthMethod::Cleartext); let connection_string = headers .get(&CONN_STRING) .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))? .to_str() .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; let connection_url = Url::parse(connection_string)?; let protocol = connection_url.scheme(); if protocol != "postgres" && protocol != "postgresql" { return Err(ConnInfoError::IncorrectScheme); } let mut url_path = connection_url .path_segments() .ok_or(ConnInfoError::MissingDbName)?; let dbname: DbName = urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into(); ctx.set_dbname(dbname.clone()); let username = RoleName::from(urlencoding::decode(connection_url.username())?); if username.is_empty() { return Err(ConnInfoError::MissingUsername); } ctx.set_user(username.clone()); let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { if !config.accept_jwts { return Err(ConnInfoError::MissingCredentials(Credentials::Password)); } let auth = auth .to_str() .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; AuthData::Jwt( auth.strip_prefix("Bearer ") .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))? .into(), ) } else if let Some(pass) = connection_url.password() { // wrong credentials provided if config.accept_jwts { return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); } AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { std::borrow::Cow::Borrowed(b) => b.into(), std::borrow::Cow::Owned(b) => b.into(), }) } else if config.accept_jwts { return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); } else { return Err(ConnInfoError::MissingCredentials(Credentials::Password)); }; let endpoint = match connection_url.host() { Some(url::Host::Domain(hostname)) => { if let Some(tls) = tls { endpoint_sni(hostname, &tls.common_names)? .ok_or(ConnInfoError::MalformedEndpoint)? } else { hostname .split_once('.') .map_or(hostname, |(prefix, _)| prefix) .into() } } Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => { return Err(ConnInfoError::MissingHostname) } }; ctx.set_endpoint_id(endpoint.clone()); let pairs = connection_url.query_pairs(); let mut options = Option::None; let mut params = StartupMessageParamsBuilder::default(); params.insert("user", &username); params.insert("database", &dbname); for (key, value) in pairs { params.insert(&key, &value); if key == "options" { options = Some(NeonOptions::parse_options_raw(&value)); } } let user_info = ComputeUserInfo { endpoint, user: username, options: options.unwrap_or_default(), }; let conn_info = ConnInfo { user_info, dbname }; Ok(ConnInfoWithAuth { conn_info, auth }) } // TODO: return different http error codes pub(crate) async fn handle( config: &'static ProxyConfig, ctx: RequestMonitoring, request: Request, backend: Arc, cancel: CancellationToken, ) -> Result>, ApiError> { let result = handle_inner(cancel, config, &ctx, request, backend).await; let mut response = match result { Ok(r) => { ctx.set_success(); r } Err(e @ SqlOverHttpError::Cancelled(_)) => { let error_kind = e.get_error_kind(); ctx.set_error_kind(error_kind); let message = "Query cancelled, connection was terminated"; tracing::info!( kind=error_kind.to_metric_label(), error=%e, msg=message, "forwarding error to user" ); json_response( StatusCode::BAD_REQUEST, json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }), )? } Err(e) => { let error_kind = e.get_error_kind(); ctx.set_error_kind(error_kind); let mut message = e.to_string_client(); let db_error = match &e { SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e)) | SqlOverHttpError::Postgres(e) => e.as_db_error(), _ => None, }; fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T { db.map(x).unwrap_or_default() } if let Some(db_error) = db_error { db_error.message().clone_into(&mut message); } let position = db_error.and_then(|db| db.position()); let (position, internal_position, internal_query) = match position { Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None), Some(ErrorPosition::Internal { position, query }) => { (None, Some(position.to_string()), Some(query.clone())) } None => (None, None, None), }; let code = get(db_error, |db| db.code().code()); let severity = get(db_error, |db| db.severity()); let detail = get(db_error, |db| db.detail()); let hint = get(db_error, |db| db.hint()); let where_ = get(db_error, |db| db.where_()); let table = get(db_error, |db| db.table()); let column = get(db_error, |db| db.column()); let schema = get(db_error, |db| db.schema()); let datatype = get(db_error, |db| db.datatype()); let constraint = get(db_error, |db| db.constraint()); let file = get(db_error, |db| db.file()); let line = get(db_error, |db| db.line().map(|l| l.to_string())); let routine = get(db_error, |db| db.routine()); tracing::info!( kind=error_kind.to_metric_label(), error=%e, msg=message, "forwarding error to user" ); // TODO: this shouldn't always be bad request. json_response( StatusCode::BAD_REQUEST, json!({ "message": message, "code": code, "detail": detail, "hint": hint, "position": position, "internalPosition": internal_position, "internalQuery": internal_query, "severity": severity, "where": where_, "table": table, "column": column, "schema": schema, "dataType": datatype, "constraint": constraint, "file": file, "line": line, "routine": routine, }), )? } }; response .headers_mut() .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); Ok(response) } #[derive(Debug, thiserror::Error)] pub(crate) enum SqlOverHttpError { #[error("{0}")] ReadPayload(#[from] ReadPayloadError), #[error("{0}")] ConnectCompute(#[from] HttpConnError), #[error("{0}")] ConnInfo(#[from] ConnInfoError), #[error("request is too large (max is {0} bytes)")] RequestTooLarge(u64), #[error("response is too large (max is {0} bytes)")] ResponseTooLarge(usize), #[error("invalid isolation level")] InvalidIsolationLevel, #[error("{0}")] Postgres(#[from] tokio_postgres::Error), #[error("{0}")] JsonConversion(#[from] JsonConversionError), #[error("{0}")] Cancelled(SqlOverHttpCancel), } impl ReportableError for SqlOverHttpError { fn get_error_kind(&self) -> ErrorKind { match self { SqlOverHttpError::ReadPayload(e) => e.get_error_kind(), SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(), SqlOverHttpError::ConnInfo(e) => e.get_error_kind(), SqlOverHttpError::RequestTooLarge(_) => ErrorKind::User, SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User, SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User, SqlOverHttpError::Postgres(p) => p.get_error_kind(), SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres, SqlOverHttpError::Cancelled(c) => c.get_error_kind(), } } } impl UserFacingError for SqlOverHttpError { fn to_string_client(&self) -> String { match self { SqlOverHttpError::ReadPayload(p) => p.to_string(), SqlOverHttpError::ConnectCompute(c) => c.to_string_client(), SqlOverHttpError::ConnInfo(c) => c.to_string_client(), SqlOverHttpError::RequestTooLarge(_) => self.to_string(), SqlOverHttpError::ResponseTooLarge(_) => self.to_string(), SqlOverHttpError::InvalidIsolationLevel => self.to_string(), SqlOverHttpError::Postgres(p) => p.to_string(), SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(), SqlOverHttpError::Cancelled(_) => self.to_string(), } } } #[derive(Debug, thiserror::Error)] pub(crate) enum ReadPayloadError { #[error("could not read the HTTP request body: {0}")] Read(#[from] hyper::Error), #[error("could not parse the HTTP request body: {0}")] Parse(#[from] serde_json::Error), } impl ReportableError for ReadPayloadError { fn get_error_kind(&self) -> ErrorKind { match self { ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect, ReadPayloadError::Parse(_) => ErrorKind::User, } } } #[derive(Debug, thiserror::Error)] pub(crate) enum SqlOverHttpCancel { #[error("query was cancelled")] Postgres, #[error("query was cancelled while stuck trying to connect to the database")] Connect, } impl ReportableError for SqlOverHttpCancel { fn get_error_kind(&self) -> ErrorKind { match self { SqlOverHttpCancel::Postgres => ErrorKind::ClientDisconnect, SqlOverHttpCancel::Connect => ErrorKind::ClientDisconnect, } } } #[derive(Clone, Copy, Debug)] struct HttpHeaders { raw_output: bool, default_array_mode: bool, txn_isolation_level: Option, txn_read_only: bool, txn_deferrable: bool, } impl HttpHeaders { fn try_parse(headers: &hyper::http::HeaderMap) -> Result { // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE); let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE); // isolation level, read only and deferrable let txn_isolation_level = match headers.get(&TXN_ISOLATION_LEVEL) { Some(x) => Some( map_header_to_isolation_level(x).ok_or(SqlOverHttpError::InvalidIsolationLevel)?, ), None => None, }; let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE); let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE); Ok(Self { raw_output, default_array_mode, txn_isolation_level, txn_read_only, txn_deferrable, }) } } fn map_header_to_isolation_level(level: &HeaderValue) -> Option { match level.as_bytes() { b"Serializable" => Some(IsolationLevel::Serializable), b"ReadUncommitted" => Some(IsolationLevel::ReadUncommitted), b"ReadCommitted" => Some(IsolationLevel::ReadCommitted), b"RepeatableRead" => Some(IsolationLevel::RepeatableRead), _ => None, } } fn map_isolation_level_to_headers(level: IsolationLevel) -> Option { match level { IsolationLevel::ReadUncommitted => Some(HeaderValue::from_static("ReadUncommitted")), IsolationLevel::ReadCommitted => Some(HeaderValue::from_static("ReadCommitted")), IsolationLevel::RepeatableRead => Some(HeaderValue::from_static("RepeatableRead")), IsolationLevel::Serializable => Some(HeaderValue::from_static("Serializable")), _ => None, } } async fn handle_inner( cancel: CancellationToken, config: &'static ProxyConfig, ctx: &RequestMonitoring, request: Request, backend: Arc, ) -> Result>, SqlOverHttpError> { let _requeset_gauge = Metrics::get() .proxy .connection_requests .guard(ctx.protocol()); info!( protocol = %ctx.protocol(), "handling interactive connection from client" ); let conn_info = get_conn_info( &config.authentication_config, ctx, request.headers(), config.tls_config.as_ref(), )?; info!( user = conn_info.conn_info.user_info.user.as_str(), "credentials" ); match conn_info.auth { AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => { handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await } auth => { handle_db_inner( cancel, config, ctx, request, conn_info.conn_info, auth, backend, ) .await } } } async fn handle_db_inner( cancel: CancellationToken, config: &'static ProxyConfig, ctx: &RequestMonitoring, request: Request, conn_info: ConnInfo, auth: AuthData, backend: Arc, ) -> Result>, SqlOverHttpError> { // // Determine the destination and connection params // let headers = request.headers(); // Allow connection pooling only if explicitly requested // or if we have decided that http pool is no longer opt-in let allow_pool = !config.http_config.pool_options.opt_in || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); let parsed_headers = HttpHeaders::try_parse(headers)?; let request_content_length = match request.body().size_hint().upper() { Some(v) => v, None => config.http_config.max_request_size_bytes + 1, }; info!(request_content_length, "request size in bytes"); Metrics::get() .proxy .http_conn_content_length_bytes .observe(HttpDirection::Request, request_content_length as f64); // we don't have a streaming request support yet so this is to prevent OOM // from a malicious user sending an extremely large request body if request_content_length > config.http_config.max_request_size_bytes { return Err(SqlOverHttpError::RequestTooLarge( config.http_config.max_request_size_bytes, )); } let fetch_and_process_request = Box::pin( async { let body = request.into_body().collect().await?.to_bytes(); info!(length = body.len(), "request payload read"); let payload: Payload = serde_json::from_slice(&body)?; Ok::(payload) // Adjust error type accordingly } .map_err(SqlOverHttpError::from), ); let authenticate_and_connect = Box::pin( async { let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_)); let keys = match auth { AuthData::Password(pw) => { backend .authenticate_with_password(ctx, &conn_info.user_info, &pw) .await? } AuthData::Jwt(jwt) => { backend .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await? } }; let client = match keys.keys { ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => { let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?; client.set_jwt_session(&payload).await?; Client::Local(client) } _ => { let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) .await?; Client::Remote(client) } }; // not strictly necessary to mark success here, // but it's just insurance for if we forget it somewhere else ctx.success(); Ok::<_, HttpConnError>(client) } .map_err(SqlOverHttpError::from), ); let (payload, mut client) = match run_until_cancelled( // Run both operations in parallel try_join( pin!(fetch_and_process_request), pin!(authenticate_and_connect), ), &cancel, ) .await { Some(result) => result?, None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)), }; let mut response = Response::builder() .status(StatusCode::OK) .header(header::CONTENT_TYPE, "application/json"); // Now execute the query and return the result. let json_output = match payload { Payload::Single(stmt) => { stmt.process(&config.http_config, cancel, &mut client, parsed_headers) .await? } Payload::Batch(statements) => { if parsed_headers.txn_read_only { response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE); } if parsed_headers.txn_deferrable { response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE); } if let Some(txn_isolation_level) = parsed_headers .txn_isolation_level .and_then(map_isolation_level_to_headers) { response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); } statements .process(&config.http_config, cancel, &mut client, parsed_headers) .await? } }; let metrics = client.metrics(); let len = json_output.len(); let response = response .body( Full::new(Bytes::from(json_output)) .map_err(|x| match x {}) .boxed(), ) // only fails if invalid status code or invalid header/values are given. // these are not user configurable so it cannot fail dynamically .expect("building response payload should not fail"); // count the egress bytes - we miss the TLS and header overhead but oh well... // moving this later in the stack is going to be a lot of effort and ehhhh metrics.record_egress(len as u64); Metrics::get() .proxy .http_conn_content_length_bytes .observe(HttpDirection::Response, len as f64); Ok(response) } static HEADERS_TO_FORWARD: &[&HeaderName] = &[ &AUTHORIZATION, &CONN_STRING, &RAW_TEXT_OUTPUT, &ARRAY_MODE, &TXN_ISOLATION_LEVEL, &TXN_READ_ONLY, &TXN_DEFERRABLE, ]; async fn handle_auth_broker_inner( ctx: &RequestMonitoring, request: Request, conn_info: ConnInfo, jwt: String, backend: Arc, ) -> Result>, SqlOverHttpError> { backend .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await .map_err(HttpConnError::from)?; let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?; let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql"); let (mut parts, body) = request.into_parts(); let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri); // todo(conradludgate): maybe auth-broker should parse these and re-serialize // these instead just to ensure they remain normalised. for &h in HEADERS_TO_FORWARD { if let Some(hv) = parts.headers.remove(h) { req = req.header(h, hv); } } let req = req .body(body) .expect("all headers and params received via hyper should be valid for request"); // todo: map body to count egress let _metrics = client.metrics(); Ok(client .inner .send_request(req) .await .map_err(LocalProxyConnError::from) .map_err(HttpConnError::from)? .map(|b| b.boxed())) } impl QueryData { async fn process( self, config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, ) -> Result { let (inner, mut discard) = client.inner(); let cancel_token = inner.cancel_token(); let res = match select( pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)), pin!(cancel.cancelled()), ) .await { // The query successfully completed. Either::Left((Ok((status, results)), __not_yet_cancelled)) => { discard.check_idle(status); let json_output = serde_json::to_string(&results).expect("json serialization should not fail"); Ok(json_output) } // The query failed with an error Either::Left((Err(e), __not_yet_cancelled)) => { discard.discard(); return Err(e); } // The query was cancelled. Either::Right((_cancelled, query)) => { tracing::info!("cancelling query"); if let Err(err) = cancel_token.cancel_query(NoTls).await { tracing::warn!(?err, "could not cancel query"); } // wait for the query cancellation match time::timeout(time::Duration::from_millis(100), query).await { // query successed before it was cancelled. Ok(Ok((status, results))) => { discard.check_idle(status); let json_output = serde_json::to_string(&results) .expect("json serialization should not fail"); Ok(json_output) } // query failed or was cancelled. Ok(Err(error)) => { let db_error = match &error { SqlOverHttpError::ConnectCompute( HttpConnError::PostgresConnectionError(e), ) | SqlOverHttpError::Postgres(e) => e.as_db_error(), _ => None, }; // if errored for some other reason, it might not be safe to return if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) { discard.discard(); } Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)) } Err(_timeout) => { discard.discard(); Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)) } } } }; res } } impl BatchQueryData { async fn process( self, config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, ) -> Result { info!("starting transaction"); let (inner, mut discard) = client.inner(); let cancel_token = inner.cancel_token(); let mut builder = inner.build_transaction(); if let Some(isolation_level) = parsed_headers.txn_isolation_level { builder = builder.isolation_level(isolation_level); } if parsed_headers.txn_read_only { builder = builder.read_only(true); } if parsed_headers.txn_deferrable { builder = builder.deferrable(true); } let transaction = builder.start().await.inspect_err(|_| { // if we cannot start a transaction, we should return immediately // and not return to the pool. connection is clearly broken discard.discard(); })?; let json_output = match query_batch( config, cancel.child_token(), &transaction, self, parsed_headers, ) .await { Ok(json_output) => { info!("commit"); let status = transaction.commit().await.inspect_err(|_| { // if we cannot commit - for now don't return connection to pool // TODO: get a query status from the error discard.discard(); })?; discard.check_idle(status); json_output } Err(SqlOverHttpError::Cancelled(_)) => { if let Err(err) = cancel_token.cancel_query(NoTls).await { tracing::warn!(?err, "could not cancel query"); } // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe. discard.discard(); return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); } Err(err) => { info!("rollback"); let status = transaction.rollback().await.inspect_err(|_| { // if we cannot rollback - for now don't return connection to pool // TODO: get a query status from the error discard.discard(); })?; discard.check_idle(status); return Err(err); } }; Ok(json_output) } } async fn query_batch( config: &'static HttpConfig, cancel: CancellationToken, transaction: &Transaction<'_>, queries: BatchQueryData, parsed_headers: HttpHeaders, ) -> Result { let mut results = Vec::with_capacity(queries.queries.len()); let mut current_size = 0; for stmt in queries.queries { let query = pin!(query_to_json( config, transaction, stmt, &mut current_size, parsed_headers, )); let cancelled = pin!(cancel.cancelled()); let res = select(query, cancelled).await; match res { // TODO: maybe we should check that the transaction bit is set here Either::Left((Ok((_, values)), _cancelled)) => { results.push(values); } Either::Left((Err(e), _cancelled)) => { return Err(e); } Either::Right((_cancelled, _)) => { return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); } } } let results = json!({ "results": results }); let json_output = serde_json::to_string(&results).expect("json serialization should not fail"); Ok(json_output) } async fn query_to_json( config: &'static HttpConfig, client: &T, data: QueryData, current_size: &mut usize, parsed_headers: HttpHeaders, ) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> { info!("executing query"); let query_params = data.params; let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?); info!("finished executing query"); // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. Also check that the response is not too // big. let mut rows: Vec = Vec::new(); while let Some(row) = row_stream.next().await { let row = row?; *current_size += row.body_len(); rows.push(row); // we don't have a streaming response support yet so this is to prevent OOM // from a malicious query (eg a cross join) if *current_size > config.max_response_size_bytes { return Err(SqlOverHttpError::ResponseTooLarge( config.max_response_size_bytes, )); } } let ready = row_stream.ready_status(); // grab the command tag and number of rows affected let command_tag = row_stream.command_tag().unwrap_or_default(); let mut command_tag_split = command_tag.split(' '); let command_tag_name = command_tag_split.next().unwrap_or_default(); let command_tag_count = if command_tag_name == "INSERT" { // INSERT returns OID first and then number of rows command_tag_split.nth(1) } else { // other commands return number of rows (if any) command_tag_split.next() } .and_then(|s| s.parse::().ok()); info!( rows = rows.len(), ?ready, command_tag, "finished reading rows" ); let columns_len = row_stream.columns().len(); let mut fields = Vec::with_capacity(columns_len); let mut columns = Vec::with_capacity(columns_len); for c in row_stream.columns() { fields.push(json!({ "name": c.name().to_owned(), "dataTypeID": c.type_().oid(), "tableID": c.table_oid(), "columnID": c.column_id(), "dataTypeSize": c.type_size(), "dataTypeModifier": c.type_modifier(), "format": "text", })); columns.push(client.get_type(c.type_oid()).await?); } let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode); // convert rows to JSON let rows = rows .iter() .map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode)) .collect::, _>>()?; // Resulting JSON format is based on the format of node-postgres result. let results = json!({ "command": command_tag_name.to_string(), "rowCount": command_tag_count, "rows": rows, "fields": fields, "rowAsArray": array_mode, }); Ok((ready, results)) } enum Client { Remote(conn_pool::Client), Local(local_conn_pool::LocalClient), } enum Discard<'a> { Remote(conn_pool::Discard<'a, tokio_postgres::Client>), Local(local_conn_pool::Discard<'a, tokio_postgres::Client>), } impl Client { fn metrics(&self) -> Arc { match self { Client::Remote(client) => client.metrics(), Client::Local(local_client) => local_client.metrics(), } } fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { match self { Client::Remote(client) => { let (c, d) = client.inner(); (c, Discard::Remote(d)) } Client::Local(local_client) => { let (c, d) = local_client.inner(); (c, Discard::Local(d)) } } } } impl Discard<'_> { fn check_idle(&mut self, status: ReadyForQueryStatus) { match self { Discard::Remote(discard) => discard.check_idle(status), Discard::Local(discard) => discard.check_idle(status), } } fn discard(&mut self) { match self { Discard::Remote(discard) => discard.discard(), Discard::Local(discard) => discard.discard(), } } }