## Problem

Drizzle needs to be able to configure the array_mode flag per query.

## Summary of changes

Adds an optional `array_mode` flag (serialized as `arrayMode`) to each query's JSON payload; when it is absent, the value falls back to the `neon-array-mode` header flag.
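For illustration, a hypothetical request body under the new scheme (the SQL and values are invented; the field names come from the `QueryData` struct below, serialized in camelCase):

```rust
// A single-query payload: `arrayMode` applies to this query only.
// When the key is omitted, the `neon-array-mode` header still decides.
let body = r#"{"query": "SELECT $1::int AS one", "params": [1], "arrayMode": true}"#;
let payload: Payload = serde_json::from_str(body)?;
```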
use std::sync::Arc;

use anyhow::bail;
use anyhow::Context;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::join;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Transaction;
use tracing::error;
use tracing::info;
use tracing::instrument;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;

use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::RoleName;

use super::backend::PoolingBackend;
use super::conn_pool::ConnInfo;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::SERVERLESS_DRIVER_SNI;

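/// One query of the sql-over-http payload. `array_mode` arrives as `arrayMode`
/// on the wire and, when set, overrides the `neon-array-mode` header default
/// for this query only (see `query_to_json`).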
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
    query: String,
    #[serde(deserialize_with = "bytes_to_pg_text")]
    params: Vec<Option<String>>,
    #[serde(default)]
    array_mode: Option<bool>,
}

#[derive(serde::Deserialize)]
struct BatchQueryData {
    queries: Vec<QueryData>,
}

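/// The request body is either a single statement or a batch; a batch is
/// executed inside a single transaction (see `handle_inner`).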
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
    Single(QueryData),
    Batch(BatchQueryData),
}

const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB

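// Control headers recognized on sql-over-http requests.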
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");

static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");

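/// Deserializes the `params` array from arbitrary JSON values into the
/// Postgres text-format strings expected by `query_raw_txt` (used via
/// `#[serde(deserialize_with)]` on `QueryData::params`).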
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    // TODO: consider avoiding the allocation here.
    let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
    Ok(json_to_pg_text(json))
}

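/// Extracts connection parameters from the `Neon-Connection-String` header:
/// validates the URL scheme and the SNI/host agreement, then pulls out the
/// role, password, database name, endpoint and options.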
fn get_conn_info(
    ctx: &mut RequestMonitoring,
    headers: &HeaderMap,
    sni_hostname: Option<String>,
    tls: &TlsConfig,
) -> Result<ConnInfo, anyhow::Error> {
    let connection_string = headers
        .get("Neon-Connection-String")
        .ok_or(anyhow::anyhow!("missing connection string"))?
        .to_str()?;

    let connection_url = Url::parse(connection_string)?;

    let protocol = connection_url.scheme();
    if protocol != "postgres" && protocol != "postgresql" {
        return Err(anyhow::anyhow!(
            "connection string must start with postgres: or postgresql:"
        ));
    }

    let mut url_path = connection_url
        .path_segments()
        .ok_or(anyhow::anyhow!("missing database name"))?;

    let dbname = url_path
        .next()
        .ok_or(anyhow::anyhow!("invalid database name"))?;

    let username = RoleName::from(connection_url.username());
    if username.is_empty() {
        return Err(anyhow::anyhow!("missing username"));
    }
    ctx.set_user(username.clone());

    let password = connection_url
        .password()
        .ok_or(anyhow::anyhow!("no password"))?;

    // The TLS certificate is selected based on the SNI hostname, so if we are
    // running here, the SNI hostname is guaranteed to be one of the configured
    // domain names.
    let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;

    let hostname = connection_url
        .host_str()
        .ok_or(anyhow::anyhow!("no host"))?;

    let host_header = headers
        .get("host")
        .and_then(|h| h.to_str().ok())
        .and_then(|h| h.split(':').next());

    // sni_hostname has to be either the same as hostname or the one used by the serverless driver.
    if !check_matches(&sni_hostname, hostname)? {
        return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
    } else if let Some(h) = host_header {
        if h != sni_hostname {
            return Err(anyhow::anyhow!("mismatched host header and hostname"));
        }
    }

    let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
    ctx.set_endpoint_id(endpoint.clone());

    let pairs = connection_url.query_pairs();

    let mut options = Option::None;

    for (key, value) in pairs {
        if key == "options" {
            options = Some(NeonOptions::parse_options_raw(&value));
            break;
        }
    }

    let user_info = ComputeUserInfo {
        endpoint,
        user: username,
        options: options.unwrap_or_default(),
    };

    Ok(ConnInfo {
        user_info,
        dbname: dbname.into(),
        password: password.into(),
    })
}

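/// The SNI hostname matches when it equals the hostname exactly, or when its
/// first label is the serverless driver SNI and the remaining labels match the
/// hostname's.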
fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
    if sni_hostname == hostname {
        return Ok(true);
    }
    let (sni_hostname_first, sni_hostname_rest) = sni_hostname
        .split_once('.')
        .ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?;
    let (_, hostname_rest) = hostname
        .split_once('.')
        .ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
    Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
}

// TODO: return different http error codes
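/// Entry point for one sql-over-http request: runs `handle_inner` under the
/// configured request timeout and converts errors into a JSON body exposing
/// the Postgres error fields (code, detail, hint, position, ...).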
pub async fn handle(
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Body>,
    sni_hostname: Option<String>,
    backend: Arc<PoolingBackend>,
) -> Result<Response<Body>, ApiError> {
    let result = tokio::time::timeout(
        config.http_config.request_timeout,
        handle_inner(config, ctx, request, sni_hostname, backend),
    )
    .await;
    let mut response = match result {
        Ok(r) => match r {
            Ok(r) => r,
            Err(e) => {
                let mut message = format!("{:?}", e);
                let db_error = e
                    .downcast_ref::<tokio_postgres::Error>()
                    .and_then(|e| e.as_db_error());
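                // Serialize an optional `DbError` field, falling back to JSON null.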
                fn get<'a, T: serde::Serialize>(
                    db: Option<&'a DbError>,
                    x: impl FnOnce(&'a DbError) -> T,
                ) -> Value {
                    db.map(x)
                        .and_then(|t| serde_json::to_value(t).ok())
                        .unwrap_or_default()
                }

                if let Some(db_error) = db_error {
                    db_error.message().clone_into(&mut message);
                }

                let position = db_error.and_then(|db| db.position());
                let (position, internal_position, internal_query) = match position {
                    Some(ErrorPosition::Original(position)) => (
                        Value::String(position.to_string()),
                        Value::Null,
                        Value::Null,
                    ),
                    Some(ErrorPosition::Internal { position, query }) => (
                        Value::Null,
                        Value::String(position.to_string()),
                        Value::String(query.clone()),
                    ),
                    None => (Value::Null, Value::Null, Value::Null),
                };

                let code = get(db_error, |db| db.code().code());
                let severity = get(db_error, |db| db.severity());
                let detail = get(db_error, |db| db.detail());
                let hint = get(db_error, |db| db.hint());
                let where_ = get(db_error, |db| db.where_());
                let table = get(db_error, |db| db.table());
                let column = get(db_error, |db| db.column());
                let schema = get(db_error, |db| db.schema());
                let datatype = get(db_error, |db| db.datatype());
                let constraint = get(db_error, |db| db.constraint());
                let file = get(db_error, |db| db.file());
                let line = get(db_error, |db| db.line().map(|l| l.to_string()));
                let routine = get(db_error, |db| db.routine());

                error!(
                    ?code,
                    "sql-over-http per-client task finished with an error: {e:#}"
                );
                // TODO: this shouldn't always be bad request.
                json_response(
                    StatusCode::BAD_REQUEST,
                    json!({
                        "message": message,
                        "code": code,
                        "detail": detail,
                        "hint": hint,
                        "position": position,
                        "internalPosition": internal_position,
                        "internalQuery": internal_query,
                        "severity": severity,
                        "where": where_,
                        "table": table,
                        "column": column,
                        "schema": schema,
                        "dataType": datatype,
                        "constraint": constraint,
                        "file": file,
                        "line": line,
                        "routine": routine,
                    }),
                )?
            }
        },
        Err(_) => {
            let message = format!(
                "HTTP-Connection timed out, execution time exceeded {} seconds",
                config.http_config.request_timeout.as_secs()
            );
            error!(message);
            json_response(
                StatusCode::GATEWAY_TIMEOUT,
                json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
            )?
        }
    };
    response.headers_mut().insert(
        "Access-Control-Allow-Origin",
        hyper::http::HeaderValue::from_static("*"),
    );
    Ok(response)
}

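/// Parses connection info and the JSON payload (concurrently with
/// authenticating and connecting to compute), executes it as a single query or
/// a batch transaction, and serializes the result to a JSON response.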
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
async fn handle_inner(
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Body>,
    sni_hostname: Option<String>,
    backend: Arc<PoolingBackend>,
) -> anyhow::Result<Response<Body>> {
    let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
        .with_label_values(&[ctx.protocol])
        .guard();
    info!(
        protocol = ctx.protocol,
        "handling interactive connection from client"
    );

    //
    // Determine the destination and connection params
    //
    let headers = request.headers();
    // TLS config should be there.
    let conn_info = get_conn_info(
        ctx,
        headers,
        sni_hostname,
        config.tls_config.as_ref().unwrap(),
    )?;
    info!(
        user = conn_info.user_info.user.as_str(),
        project = conn_info.user_info.endpoint.as_str(),
        "credentials"
    );

    // Determine the output options. The default behaviour is 'false'; anything
    // that is not strictly 'true' is assumed to be false.
    let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
    let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);

    // Allow connection pooling only if explicitly requested
    // or if we have decided that the http pool is no longer opt-in.
    let allow_pool = !config.http_config.pool_options.opt_in
        || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);

    // isolation level, read only and deferrable

    let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
    let txn_isolation_level = match txn_isolation_level_raw {
        Some(ref x) => Some(match x.as_bytes() {
            b"Serializable" => IsolationLevel::Serializable,
            b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
            b"ReadCommitted" => IsolationLevel::ReadCommitted,
            b"RepeatableRead" => IsolationLevel::RepeatableRead,
            _ => bail!("invalid isolation level"),
        }),
        None => None,
    };

    let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
    let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);

    let paused = ctx.latency_timer.pause();
    let request_content_length = match request.body().size_hint().upper() {
        Some(v) => v,
        None => MAX_REQUEST_SIZE + 1,
    };
    drop(paused);
    info!(request_content_length, "request size in bytes");
    HTTP_CONTENT_LENGTH.observe(request_content_length as f64);

    // we don't yet have streaming request support, so this is to prevent OOM
    // from a malicious user sending an extremely large request body
    if request_content_length > MAX_REQUEST_SIZE {
        return Err(anyhow::anyhow!(
            "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
        ));
    }

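    // Read and parse the body concurrently with authenticating and connecting
    // to compute; the `join!` below awaits both.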
    let fetch_and_process_request = async {
        let body = hyper::body::to_bytes(request.into_body())
            .await
            .map_err(anyhow::Error::from)?;
        let payload: Payload = serde_json::from_slice(&body)?;
        Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
    };

    let authenticate_and_connect = async {
        let keys = backend.authenticate(ctx, &conn_info).await?;
        backend
            .connect_to_compute(ctx, conn_info, keys, !allow_pool)
            .await
    };

    // Run both operations in parallel
    let (payload_result, auth_and_connect_result) =
        join!(fetch_and_process_request, authenticate_and_connect,);

    // Handle the results
    let payload = payload_result?; // Handle errors appropriately
    let mut client = auth_and_connect_result?; // Handle errors appropriately

    let mut response = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "application/json");

    //
    // Now execute the query and return the result
    //
    let mut size = 0;
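    // A single statement executes directly on the (possibly pooled) connection;
    // a batch executes inside one explicit transaction built below.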
    let result = match payload {
        Payload::Single(stmt) => {
            let (status, results) =
                query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode)
                    .await
                    .map_err(|e| {
                        client.discard();
                        e
                    })?;
            client.check_idle(status);
            results
        }
        Payload::Batch(statements) => {
            let (inner, mut discard) = client.inner();
            let mut builder = inner.build_transaction();
            if let Some(isolation_level) = txn_isolation_level {
                builder = builder.isolation_level(isolation_level);
            }
            if txn_read_only {
                builder = builder.read_only(true);
            }
            if txn_deferrable {
                builder = builder.deferrable(true);
            }

            let transaction = builder.start().await.map_err(|e| {
                // if we cannot start a transaction, we should return immediately
                // and not return to the pool. connection is clearly broken
                discard.discard();
                e
            })?;

            let results = match query_batch(
                &transaction,
                statements,
                &mut size,
                raw_output,
                default_array_mode,
            )
            .await
            {
                Ok(results) => {
                    let status = transaction.commit().await.map_err(|e| {
                        // if we cannot commit - for now don't return connection to pool
                        // TODO: get a query status from the error
                        discard.discard();
                        e
                    })?;
                    discard.check_idle(status);
                    results
                }
                Err(err) => {
                    let status = transaction.rollback().await.map_err(|e| {
                        // if we cannot rollback - for now don't return connection to pool
                        // TODO: get a query status from the error
                        discard.discard();
                        e
                    })?;
                    discard.check_idle(status);
                    return Err(err);
                }
            };

            if txn_read_only {
                response = response.header(
                    TXN_READ_ONLY.clone(),
                    HeaderValue::try_from(txn_read_only.to_string())?,
                );
            }
            if txn_deferrable {
                response = response.header(
                    TXN_DEFERRABLE.clone(),
                    HeaderValue::try_from(txn_deferrable.to_string())?,
                );
            }
            if let Some(txn_isolation_level) = txn_isolation_level_raw {
                response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
            }
            json!({ "results": results })
        }
    };

    ctx.set_success();
    ctx.log();
    let metrics = client.metrics();

    // how could this possibly fail
    let body = serde_json::to_string(&result).expect("json serialization should not fail");
    let len = body.len();
    let response = response
        .body(Body::from(body))
        // only fails if invalid status code or invalid header/values are given.
        // these are not user configurable so it cannot fail dynamically
        .expect("building response payload should not fail");

    // count the egress bytes - we miss the TLS and header overhead but oh well...
    // moving this later in the stack is going to be a lot of effort and ehhhh
    metrics.record_egress(len as u64);

    Ok(response)
}

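/// Runs each query of a batch on the given transaction, accumulating the
/// response size across statements so the batch as a whole stays under
/// `MAX_RESPONSE_SIZE`.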
async fn query_batch(
    transaction: &Transaction<'_>,
    queries: BatchQueryData,
    total_size: &mut usize,
    raw_output: bool,
    array_mode: bool,
) -> anyhow::Result<Vec<Value>> {
    let mut results = Vec::with_capacity(queries.queries.len());
    let mut current_size = 0;
    for stmt in queries.queries {
        // TODO: maybe we should check that the transaction bit is set here
        let (_, values) =
            query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
        results.push(values);
    }
    *total_size += current_size;
    Ok(results)
}

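/// Executes a single statement and renders the rows as a node-postgres-style
/// JSON result, enforcing `MAX_RESPONSE_SIZE` while draining the row stream.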
async fn query_to_json<T: GenericClient>(
    client: &T,
    data: QueryData,
    current_size: &mut usize,
    raw_output: bool,
    default_array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
    let query_params = data.params;
    let row_stream = client.query_raw_txt(&data.query, query_params).await?;

    // Manually drain the stream into a vector to leave row_stream hanging
    // around to get a command tag. Also check that the response is not too
    // big.
    pin_mut!(row_stream);
    let mut rows: Vec<tokio_postgres::Row> = Vec::new();
    while let Some(row) = row_stream.next().await {
        let row = row?;
        *current_size += row.body_len();
        rows.push(row);
        // we don't yet have streaming response support, so this is to prevent OOM
        // from a malicious query (eg a cross join)
        if *current_size > MAX_RESPONSE_SIZE {
            return Err(anyhow::anyhow!(
                "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
            ));
        }
    }

    let ready = row_stream.ready_status();

    // grab the command tag and number of rows affected
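    // e.g. a tag of "SELECT 5" yields command "SELECT" and count 5, while
    // INSERT's tag has the form "INSERT <oid> <rows>", hence the `nth(1)` below.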
    let command_tag = row_stream.command_tag().unwrap_or_default();
    let mut command_tag_split = command_tag.split(' ');
    let command_tag_name = command_tag_split.next().unwrap_or_default();
    let command_tag_count = if command_tag_name == "INSERT" {
        // INSERT returns OID first and then number of rows
        command_tag_split.nth(1)
    } else {
        // other commands return number of rows (if any)
        command_tag_split.next()
    }
    .and_then(|s| s.parse::<i64>().ok());

    let mut fields = vec![];
    let mut columns = vec![];

    for c in row_stream.columns() {
        fields.push(json!({
            "name": Value::String(c.name().to_owned()),
            "dataTypeID": Value::Number(c.type_().oid().into()),
            "tableID": c.table_oid(),
            "columnID": c.column_id(),
            "dataTypeSize": c.type_size(),
            "dataTypeModifier": c.type_modifier(),
            "format": "text",
        }));
        columns.push(client.get_type(c.type_oid()).await?);
    }

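    // The per-query flag, when present, wins over the header-derived default.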
    let array_mode = data.array_mode.unwrap_or(default_array_mode);

    // convert rows to JSON
    let rows = rows
        .iter()
        .map(|row| pg_text_row_to_json(row, &columns, raw_output, array_mode))
        .collect::<Result<Vec<_>, _>>()?;

    // resulting JSON format is based on the format of node-postgres result
    Ok((
        ready,
        json!({
            "command": command_tag_name,
            "rowCount": command_tag_count,
            "rows": rows,
            "fields": fields,
            "rowAsArray": array_mode,
        }),
    ))
}