Files
neon/proxy/src/serverless/sql_over_http.rs
Conrad Ludgate ea089dc977 proxy: add per query array mode flag (#6678)
## Problem

Drizzle needs to be able to configure the array_mode flag per query.

## Summary of changes

Adds an array_mode flag to the query data json that will otherwise
default to the header flag.
2024-02-09 10:29:20 +00:00

608 lines
21 KiB
Rust

use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::join;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Transaction;
use tracing::error;
use tracing::info;
use tracing::instrument;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::RoleName;
use super::backend::PoolingBackend;
use super::conn_pool::ConnInfo;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::SERVERLESS_DRIVER_SNI;
/// A single SQL statement as submitted in the JSON request body.
///
/// Field names are read in camelCase (`query`, `params`, `arrayMode`) because of
/// `rename_all = "camelCase"`.
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
/// SQL text; forwarded as-is to `query_raw_txt`.
query: String,
/// Positional parameters. The JSON array is converted to text form by
/// `bytes_to_pg_text`; a `None` entry presumably stands for SQL NULL
/// (NOTE(review): confirm against `json_to_pg_text`).
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
/// Optional per-query override of the row output shape. When absent, the value
/// derived from the `neon-array-mode` request header is used instead.
#[serde(default)]
array_mode: Option<bool>,
}
/// A batch request body: `{ "queries": [ ... ] }`.
///
/// The statements are executed in order inside one transaction by `query_batch`.
#[derive(serde::Deserialize)]
struct BatchQueryData {
/// Statements to run, in submission order.
queries: Vec<QueryData>,
}
/// The request body: either one statement or a batch of statements.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// the body is first parsed as a single `QueryData` and only then as a
/// `BatchQueryData`.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
/// One statement, executed directly on the client connection.
Single(QueryData),
/// Several statements, executed inside a single transaction.
Batch(BatchQueryData),
}
/// Upper bound on the accumulated row body bytes collected by `query_to_json`;
/// exceeding it aborts the query with a "response is too large" error.
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
/// Upper bound on the request body size accepted by `handle_inner`. A body whose
/// size hint has no upper bound is treated as exceeding this limit.
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
/// Request header: when exactly `true`, enables the `raw_output` flag passed to
/// the row-to-JSON conversion.
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
/// Request header: when exactly `true`, array mode is the default for every query
/// that does not carry its own `arrayMode` field.
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
/// Request header: when exactly `true`, the client opts in to connection pooling
/// (only consulted while pooling is configured as opt-in).
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
/// Request header selecting the isolation level of a batch transaction; echoed
/// back on the batch response when it was supplied.
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
/// Request header: when exactly `true`, the batch transaction is read-only; echoed
/// back on the batch response when set.
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
/// Request header: when exactly `true`, the batch transaction is deferrable; echoed
/// back on the batch response when set.
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
/// The only header value treated as "enabled" for the boolean headers above.
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
/// Serde `deserialize_with` adapter for the `params` field of `QueryData`.
///
/// Reads the incoming value as a JSON array and hands it to `json_to_pg_text`,
/// which produces the textual parameter list. Any deserialization error is
/// returned unchanged.
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    // TODO: consider avoiding the allocation here.
    <Vec<Value> as serde::de::Deserialize<'de>>::deserialize(deserializer)
        .map(|values| json_to_pg_text(values))
}
fn get_conn_info(
ctx: &mut RequestMonitoring,
headers: &HeaderMap,
sni_hostname: Option<String>,
tls: &TlsConfig,
) -> Result<ConnInfo, anyhow::Error> {
let connection_string = headers
.get("Neon-Connection-String")
.ok_or(anyhow::anyhow!("missing connection string"))?
.to_str()?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(anyhow::anyhow!(
"connection string must start with postgres: or postgresql:"
));
}
let mut url_path = connection_url
.path_segments()
.ok_or(anyhow::anyhow!("missing database name"))?;
let dbname = url_path
.next()
.ok_or(anyhow::anyhow!("invalid database name"))?;
let username = RoleName::from(connection_url.username());
if username.is_empty() {
return Err(anyhow::anyhow!("missing username"));
}
ctx.set_user(username.clone());
let password = connection_url
.password()
.ok_or(anyhow::anyhow!("no password"))?;
// TLS certificate selector now based on SNI hostname, so if we are running here
// we are sure that SNI hostname is set to one of the configured domain names.
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
let hostname = connection_url
.host_str()
.ok_or(anyhow::anyhow!("no host"))?;
let host_header = headers
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next());
// sni_hostname has to be either the same as hostname or the one used in serverless driver.
if !check_matches(&sni_hostname, hostname)? {
return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
} else if let Some(h) = host_header {
if h != sni_hostname {
return Err(anyhow::anyhow!("mismatched host header and hostname"));
}
}
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
for (key, value) in pairs {
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
break;
}
}
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
Ok(ConnInfo {
user_info,
dbname: dbname.into(),
password: password.into(),
})
}
fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
if sni_hostname == hostname {
return Ok(true);
}
let (sni_hostname_first, sni_hostname_rest) = sni_hostname
.split_once('.')
.ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?;
let (_, hostname_rest) = hostname
.split_once('.')
.ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
}
/// Entry point for a SQL-over-HTTP request.
///
/// Runs `handle_inner` under `config.http_config.request_timeout` and turns every
/// outcome into an HTTP response:
///
/// * success — the response produced by `handle_inner` is passed through;
/// * failure — `400 Bad Request` with a JSON body holding `message` and, when the
///   error is a Postgres `DbError`, its individual fields (all other fields are
///   `null`);
/// * timeout — `504 Gateway Timeout` with a JSON `message` and `code`.
///
/// `Access-Control-Allow-Origin: *` is added to every response.
///
/// # Errors
///
/// Returns `ApiError` only when `json_response` fails to build an error response.
// TODO: return different http error codes
pub async fn handle(
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Body>,
    sni_hostname: Option<String>,
    backend: Arc<PoolingBackend>,
) -> Result<Response<Body>, ApiError> {
    let result = tokio::time::timeout(
        config.http_config.request_timeout,
        handle_inner(config, ctx, request, sni_hostname, backend),
    )
    .await;

    let mut response = match result {
        Ok(Ok(response)) => response,
        Ok(Err(e)) => {
            // Projects one property out of the (optional) database error into a
            // JSON value; yields `null` when there is no database error or the
            // property cannot be serialized.
            fn get<'a, T: serde::Serialize>(
                db: Option<&'a DbError>,
                x: impl FnOnce(&'a DbError) -> T,
            ) -> Value {
                db.map(x)
                    .and_then(|t| serde_json::to_value(t).ok())
                    .unwrap_or_default()
            }

            // Default to the debug rendering of the error; a database error
            // replaces it with the server-provided message.
            let mut message = format!("{:?}", e);
            let db_error = e
                .downcast_ref::<tokio_postgres::Error>()
                .and_then(|e| e.as_db_error());
            if let Some(db_error) = db_error {
                db_error.message().clone_into(&mut message);
            }

            // Exactly one of `position` / (`internalPosition`, `internalQuery`)
            // is populated, depending on which kind of position was reported.
            let position = db_error.and_then(|db| db.position());
            let (position, internal_position, internal_query) = match position {
                Some(ErrorPosition::Original(position)) => (
                    Value::String(position.to_string()),
                    Value::Null,
                    Value::Null,
                ),
                Some(ErrorPosition::Internal { position, query }) => (
                    Value::Null,
                    Value::String(position.to_string()),
                    Value::String(query.clone()),
                ),
                None => (Value::Null, Value::Null, Value::Null),
            };

            let code = get(db_error, |db| db.code().code());
            let severity = get(db_error, |db| db.severity());
            let detail = get(db_error, |db| db.detail());
            let hint = get(db_error, |db| db.hint());
            let where_ = get(db_error, |db| db.where_());
            let table = get(db_error, |db| db.table());
            let column = get(db_error, |db| db.column());
            let schema = get(db_error, |db| db.schema());
            let datatype = get(db_error, |db| db.datatype());
            let constraint = get(db_error, |db| db.constraint());
            let file = get(db_error, |db| db.file());
            let line = get(db_error, |db| db.line().map(|l| l.to_string()));
            let routine = get(db_error, |db| db.routine());

            error!(
                ?code,
                "sql-over-http per-client task finished with an error: {e:#}"
            );

            // TODO: this shouldn't always be bad request.
            json_response(
                StatusCode::BAD_REQUEST,
                json!({
                    "message": message,
                    "code": code,
                    "detail": detail,
                    "hint": hint,
                    "position": position,
                    "internalPosition": internal_position,
                    "internalQuery": internal_query,
                    "severity": severity,
                    "where": where_,
                    "table": table,
                    "column": column,
                    "schema": schema,
                    "dataType": datatype,
                    "constraint": constraint,
                    "file": file,
                    "line": line,
                    "routine": routine,
                }),
            )?
        }
        // The timeout elapsed before `handle_inner` completed.
        Err(_) => {
            let message = format!(
                "HTTP-Connection timed out, execution time exceeded {} seconds",
                config.http_config.request_timeout.as_secs()
            );
            error!(message);
            json_response(
                StatusCode::GATEWAY_TIMEOUT,
                json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
            )?
        }
    };

    response.headers_mut().insert(
        "Access-Control-Allow-Origin",
        hyper::http::HeaderValue::from_static("*"),
    );
    Ok(response)
}
/// Executes one SQL-over-HTTP request and builds the success response.
///
/// Steps, in order:
/// 1. derive the connection info from the request headers (`get_conn_info`);
/// 2. read the output and transaction options from the `neon-*` headers — a
///    boolean header counts as set only when its value is exactly `true`;
/// 3. reject bodies larger than `MAX_REQUEST_SIZE` (or of unknown size);
/// 4. concurrently read + parse the JSON body and authenticate + connect;
/// 5. run the single query directly, or the batch inside one transaction that is
///    committed on success and rolled back on failure;
/// 6. serialize the result as JSON and record the body length as egress.
///
/// # Errors
///
/// Any failure in the steps above is returned as `anyhow::Error`; the caller
/// (`handle`) converts it into an HTTP error response.
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
async fn handle_inner(
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Body>,
    sni_hostname: Option<String>,
    backend: Arc<PoolingBackend>,
) -> anyhow::Result<Response<Body>> {
    let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
        .with_label_values(&[ctx.protocol])
        .guard();
    info!(
        protocol = ctx.protocol,
        "handling interactive connection from client"
    );

    //
    // Determine the destination and connection params
    //
    let headers = request.headers();

    // TLS config should be there. If it is not, fail this request with an error
    // instead of panicking inside the request handler.
    let tls = config
        .tls_config
        .as_ref()
        .context("TLS config is missing")?;
    let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?;
    info!(
        user = conn_info.user_info.user.as_str(),
        project = conn_info.user_info.endpoint.as_str(),
        "credentials"
    );

    // Determine the output options. Default behaviour is 'false'. Anything that is not
    // strictly 'true' assumed to be false.
    let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
    let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);

    // Allow connection pooling only if explicitly requested
    // or if we have decided that http pool is no longer opt-in
    let allow_pool = !config.http_config.pool_options.opt_in
        || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);

    // isolation level, read only and deferrable
    let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
    let txn_isolation_level = match txn_isolation_level_raw {
        Some(ref x) => Some(match x.as_bytes() {
            b"Serializable" => IsolationLevel::Serializable,
            b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
            b"ReadCommitted" => IsolationLevel::ReadCommitted,
            b"RepeatableRead" => IsolationLevel::RepeatableRead,
            _ => bail!("invalid isolation level"),
        }),
        None => None,
    };
    let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
    let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);

    // The latency timer is paused while the body size hint is inspected.
    let paused = ctx.latency_timer.pause();
    // A body without a known upper bound is treated as too large.
    let request_content_length = request
        .body()
        .size_hint()
        .upper()
        .unwrap_or(MAX_REQUEST_SIZE + 1);
    drop(paused);
    info!(request_content_length, "request size in bytes");
    HTTP_CONTENT_LENGTH.observe(request_content_length as f64);

    // we don't have a streaming request support yet so this is to prevent OOM
    // from a malicious user sending an extremely large request body
    if request_content_length > MAX_REQUEST_SIZE {
        return Err(anyhow::anyhow!(
            "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
        ));
    }

    let fetch_and_process_request = async {
        let body = hyper::body::to_bytes(request.into_body())
            .await
            .map_err(anyhow::Error::from)?;
        let payload: Payload = serde_json::from_slice(&body)?;
        Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
    };

    let authenticate_and_connect = async {
        let keys = backend.authenticate(ctx, &conn_info).await?;
        // NOTE(review): the last argument is `!allow_pool`; presumably it forces a
        // fresh, non-pooled connection — confirm against `connect_to_compute`.
        backend
            .connect_to_compute(ctx, conn_info, keys, !allow_pool)
            .await
    };

    // Run both operations in parallel
    let (payload_result, auth_and_connect_result) =
        join!(fetch_and_process_request, authenticate_and_connect,);

    // Handle the results
    let payload = payload_result?; // Handle errors appropriately
    let mut client = auth_and_connect_result?; // Handle errors appropriately

    let mut response = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "application/json");

    //
    // Now execute the query and return the result
    //
    // Running total of row bytes, checked against MAX_RESPONSE_SIZE while rows
    // are collected.
    let mut size = 0;
    let result = match payload {
        Payload::Single(stmt) => {
            let (status, results) =
                query_to_json(&*client, stmt, &mut size, raw_output, default_array_mode)
                    .await
                    .map_err(|e| {
                        // A failed query discards the connection.
                        client.discard();
                        e
                    })?;
            client.check_idle(status);
            results
        }
        Payload::Batch(statements) => {
            let (inner, mut discard) = client.inner();
            let mut builder = inner.build_transaction();
            if let Some(isolation_level) = txn_isolation_level {
                builder = builder.isolation_level(isolation_level);
            }
            if txn_read_only {
                builder = builder.read_only(true);
            }
            if txn_deferrable {
                builder = builder.deferrable(true);
            }

            let transaction = builder.start().await.map_err(|e| {
                // if we cannot start a transaction, we should return immediately
                // and not return to the pool. connection is clearly broken
                discard.discard();
                e
            })?;

            let results = match query_batch(
                &transaction,
                statements,
                &mut size,
                raw_output,
                default_array_mode,
            )
            .await
            {
                Ok(results) => {
                    let status = transaction.commit().await.map_err(|e| {
                        // if we cannot commit - for now don't return connection to pool
                        // TODO: get a query status from the error
                        discard.discard();
                        e
                    })?;
                    discard.check_idle(status);
                    results
                }
                Err(err) => {
                    let status = transaction.rollback().await.map_err(|e| {
                        // if we cannot rollback - for now don't return connection to pool
                        // TODO: get a query status from the error
                        discard.discard();
                        e
                    })?;
                    discard.check_idle(status);
                    return Err(err);
                }
            };

            // Echo the effective transaction settings back to the client. The
            // boolean headers are only emitted when set, so their value is
            // always the constant `true`.
            if txn_read_only {
                response = response.header(TXN_READ_ONLY.clone(), HEADER_VALUE_TRUE.clone());
            }
            if txn_deferrable {
                response = response.header(TXN_DEFERRABLE.clone(), HEADER_VALUE_TRUE.clone());
            }
            if let Some(txn_isolation_level) = txn_isolation_level_raw {
                response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
            }

            json!({ "results": results })
        }
    };

    ctx.set_success();
    ctx.log();

    let metrics = client.metrics();

    // how could this possibly fail
    let body = serde_json::to_string(&result).expect("json serialization should not fail");
    let len = body.len();
    let response = response
        .body(Body::from(body))
        // only fails if invalid status code or invalid header/values are given.
        // these are not user configurable so it cannot fail dynamically
        .expect("building response payload should not fail");

    // count the egress bytes - we miss the TLS and header overhead but oh well...
    // moving this later in the stack is going to be a lot of effort and ehhhh
    metrics.record_egress(len as u64);

    Ok(response)
}
/// Runs every statement of a batch, in order, on the given transaction.
///
/// Returns one JSON result per statement. The row bytes of the whole batch are
/// counted against a single budget that starts at zero; the total is added to
/// `total_size` only when every statement succeeded. The first failing
/// statement aborts the batch and its error is returned.
async fn query_batch(
    transaction: &Transaction<'_>,
    queries: BatchQueryData,
    total_size: &mut usize,
    raw_output: bool,
    array_mode: bool,
) -> anyhow::Result<Vec<Value>> {
    let statements = queries.queries;
    let mut outputs = Vec::with_capacity(statements.len());
    let mut batch_bytes = 0;

    for statement in statements {
        // TODO: maybe we should check that the transaction bit is set here
        let (_, output) = query_to_json(
            transaction,
            statement,
            &mut batch_bytes,
            raw_output,
            array_mode,
        )
        .await?;
        outputs.push(output);
    }

    *total_size += batch_bytes;
    Ok(outputs)
}
async fn query_to_json<T: GenericClient>(
client: &T,
data: QueryData,
current_size: &mut usize,
raw_output: bool,
default_array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
let query_params = data.params;
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
// Manually drain the stream into a vector to leave row_stream hanging
// around to get a command tag. Also check that the response is not too
// big.
pin_mut!(row_stream);
let mut rows: Vec<tokio_postgres::Row> = Vec::new();
while let Some(row) = row_stream.next().await {
let row = row?;
*current_size += row.body_len();
rows.push(row);
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
if *current_size > MAX_RESPONSE_SIZE {
return Err(anyhow::anyhow!(
"response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
));
}
}
let ready = row_stream.ready_status();
// grab the command tag and number of rows affected
let command_tag = row_stream.command_tag().unwrap_or_default();
let mut command_tag_split = command_tag.split(' ');
let command_tag_name = command_tag_split.next().unwrap_or_default();
let command_tag_count = if command_tag_name == "INSERT" {
// INSERT returns OID first and then number of rows
command_tag_split.nth(1)
} else {
// other commands return number of rows (if any)
command_tag_split.next()
}
.and_then(|s| s.parse::<i64>().ok());
let mut fields = vec![];
let mut columns = vec![];
for c in row_stream.columns() {
fields.push(json!({
"name": Value::String(c.name().to_owned()),
"dataTypeID": Value::Number(c.type_().oid().into()),
"tableID": c.table_oid(),
"columnID": c.column_id(),
"dataTypeSize": c.type_size(),
"dataTypeModifier": c.type_modifier(),
"format": "text",
}));
columns.push(client.get_type(c.type_oid()).await?);
}
let array_mode = data.array_mode.unwrap_or(default_array_mode);
// convert rows to JSON
let rows = rows
.iter()
.map(|row| pg_text_row_to_json(row, &columns, raw_output, array_mode))
.collect::<Result<Vec<_>, _>>()?;
// resulting JSON format is based on the format of node-postgres result
Ok((
ready,
json!({
"command": command_tag_name,
"rowCount": command_tag_count,
"rows": rows,
"fields": fields,
"rowAsArray": array_mode,
}),
))
}