proxy: add per query array mode flag (#6678)

## Problem

Drizzle needs to be able to configure the array_mode flag per query.

## Summary of changes

Adds an array_mode flag to the query data json that will otherwise
default to the header flag.
This commit is contained in:
Conrad Ludgate
2024-02-09 10:29:20 +00:00
committed by GitHub
parent 951c9bf4ca
commit ea089dc977
2 changed files with 119 additions and 77 deletions

View File

@@ -44,10 +44,13 @@ use super::json::pg_text_row_to_json;
use super::SERVERLESS_DRIVER_SNI;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
#[serde(default)]
array_mode: Option<bool>,
}
#[derive(serde::Deserialize)]
@@ -330,7 +333,7 @@ async fn handle_inner(
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
@@ -402,83 +405,87 @@ async fn handle_inner(
// Now execute the query and return the result
//
let mut size = 0;
let result =
match payload {
Payload::Single(stmt) => {
let (status, results) =
query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
.await
.map_err(|e| {
client.discard();
e
})?;
client.check_idle(status);
results
let result = match payload {
Payload::Single(stmt) => {
let (status, results) =
query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode)
.await
.map_err(|e| {
client.discard();
e
})?;
client.check_idle(status);
results
}
Payload::Batch(statements) => {
let (inner, mut discard) = client.inner();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
Payload::Batch(statements) => {
let (inner, mut discard) = client.inner();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
if txn_read_only {
builder = builder.read_only(true);
}
if txn_deferrable {
builder = builder.deferrable(true);
}
let transaction = builder.start().await.map_err(|e| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
e
})?;
let results =
match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
.await
{
Ok(results) => {
let status = transaction.commit().await.map_err(|e| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
results
}
Err(err) => {
let status = transaction.rollback().await.map_err(|e| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
return Err(err);
}
};
if txn_read_only {
response = response.header(
TXN_READ_ONLY.clone(),
HeaderValue::try_from(txn_read_only.to_string())?,
);
}
if txn_deferrable {
response = response.header(
TXN_DEFERRABLE.clone(),
HeaderValue::try_from(txn_deferrable.to_string())?,
);
}
if let Some(txn_isolation_level) = txn_isolation_level_raw {
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
json!({ "results": results })
if txn_read_only {
builder = builder.read_only(true);
}
};
if txn_deferrable {
builder = builder.deferrable(true);
}
let transaction = builder.start().await.map_err(|e| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
e
})?;
let results = match query_batch(
&transaction,
statements,
&mut size,
raw_output,
default_array_mode,
)
.await
{
Ok(results) => {
let status = transaction.commit().await.map_err(|e| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
results
}
Err(err) => {
let status = transaction.rollback().await.map_err(|e| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
return Err(err);
}
};
if txn_read_only {
response = response.header(
TXN_READ_ONLY.clone(),
HeaderValue::try_from(txn_read_only.to_string())?,
);
}
if txn_deferrable {
response = response.header(
TXN_DEFERRABLE.clone(),
HeaderValue::try_from(txn_deferrable.to_string())?,
);
}
if let Some(txn_isolation_level) = txn_isolation_level_raw {
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
json!({ "results": results })
}
};
ctx.set_success();
ctx.log();
@@ -524,7 +531,7 @@ async fn query_to_json<T: GenericClient>(
data: QueryData,
current_size: &mut usize,
raw_output: bool,
array_mode: bool,
default_array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
let query_params = data.params;
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
@@ -578,6 +585,8 @@ async fn query_to_json<T: GenericClient>(
columns.push(client.get_type(c.type_oid()).await?);
}
let array_mode = data.array_mode.unwrap_or(default_array_mode);
// convert rows to JSON
let rows = rows
.iter()

View File

@@ -390,6 +390,39 @@ def test_sql_over_http_batch(static_proxy: NeonProxy):
assert result[0]["rows"] == [{"answer": 42}]
def test_sql_over_http_batch_output_options(static_proxy: NeonProxy):
static_proxy.safe_psql("create role http with login password 'http' superuser")
connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
response = requests.post(
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
data=json.dumps(
{
"queries": [
{"query": "select $1 as answer", "params": [42], "arrayMode": True},
{"query": "select $1 as answer", "params": [42], "arrayMode": False},
]
}
),
headers={
"Content-Type": "application/sql",
"Neon-Connection-String": connstr,
"Neon-Batch-Isolation-Level": "Serializable",
"Neon-Batch-Read-Only": "false",
"Neon-Batch-Deferrable": "false",
},
verify=str(static_proxy.test_output_dir / "proxy.crt"),
)
assert response.status_code == 200
results = response.json()["results"]
assert results[0]["rowAsArray"]
assert results[0]["rows"] == [["42"]]
assert not results[1]["rowAsArray"]
assert results[1]["rows"] == [{"answer": "42"}]
def test_sql_over_http_pool(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")