diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs
index 1f83fcfc68..a4eb13a10b 100644
--- a/proxy/src/http/sql_over_http.rs
+++ b/proxy/src/http/sql_over_http.rs
@@ -1,7 +1,9 @@
use std::sync::Arc;
+use anyhow::bail;
use futures::pin_mut;
use futures::StreamExt;
+use hashbrown::HashMap;
use hyper::body::HttpBody;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
@@ -12,6 +14,7 @@ use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::GenericClient;
+use tokio_postgres::IsolationLevel;
use tokio_postgres::Row;
use url::Url;
@@ -37,6 +40,8 @@ const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
+static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
+static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
@@ -170,7 +175,7 @@ pub async fn handle(
request: Request
,
sni_hostname: Option,
conn_pool: Arc,
-) -> anyhow::Result {
+) -> anyhow::Result<(Value, HashMap)> {
//
// Determine the destination and connection params
//
@@ -185,6 +190,23 @@ pub async fn handle(
// Allow connection pooling only if explicitly requested
let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
+ // isolation level and read only
+
+ let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
+ let txn_isolation_level = match txn_isolation_level_raw {
+ Some(ref x) => Some(match x.as_bytes() {
+ b"Serializable" => IsolationLevel::Serializable,
+ b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
+ b"ReadCommitted" => IsolationLevel::ReadCommitted,
+ b"RepeatableRead" => IsolationLevel::RepeatableRead,
+ _ => bail!("invalid isolation level"),
+ }),
+ None => None,
+ };
+
+ let txn_read_only_raw = headers.get(&TXN_READ_ONLY).cloned();
+ let txn_read_only = txn_read_only_raw.as_ref() == Some(&HEADER_VALUE_TRUE);
+
let request_content_length = match request.body().size_hint().upper() {
Some(v) => v,
None => MAX_REQUEST_SIZE + 1,
@@ -208,10 +230,19 @@ pub async fn handle(
// Now execute the query and return the result
//
let result = match payload {
- Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode).await,
+ Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode)
+ .await
+ .map(|x| (x, HashMap::default())),
Payload::Batch(queries) => {
let mut results = Vec::new();
- let transaction = client.transaction().await?;
+ let mut builder = client.build_transaction();
+ if let Some(isolation_level) = txn_isolation_level {
+ builder = builder.isolation_level(isolation_level);
+ }
+ if txn_read_only {
+ builder = builder.read_only(true);
+ }
+ let transaction = builder.start().await?;
for query in queries {
let result = query_to_json(&transaction, query, raw_output, array_mode).await;
match result {
@@ -223,7 +254,15 @@ pub async fn handle(
}
}
transaction.commit().await?;
- Ok(json!({ "results": results }))
+ let mut headers = HashMap::default();
+ headers.insert(
+ TXN_READ_ONLY.clone(),
+ HeaderValue::try_from(txn_read_only.to_string())?,
+ );
+ if let Some(txn_isolation_level_raw) = txn_isolation_level_raw {
+ headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level_raw);
+ }
+ Ok((json!({ "results": results }), headers))
}
};
diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs
index 4b6e15dc3a..fec76c74f4 100644
--- a/proxy/src/http/websocket.rs
+++ b/proxy/src/http/websocket.rs
@@ -6,6 +6,7 @@ use crate::{
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream, StreamExt};
+use hashbrown::HashMap;
use hyper::{
server::{
accept,
@@ -205,7 +206,7 @@ async fn ws_handler(
Ok(_) => StatusCode::OK,
Err(_) => StatusCode::BAD_REQUEST,
};
- let json = match result {
+ let (json, headers) = match result {
Ok(r) => r,
Err(e) => {
let message = format!("{:?}", e);
@@ -216,7 +217,10 @@ async fn ws_handler(
},
None => Value::Null,
};
- json!({ "message": message, "code": code })
+ (
+ json!({ "message": message, "code": code }),
+ HashMap::default(),
+ )
}
};
json_response(status_code, json).map(|mut r| {
@@ -224,6 +228,9 @@ async fn ws_handler(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
);
+ for (k, v) in headers {
+ r.headers_mut().insert(k, v);
+ }
r
})
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py
index d5bf98109c..35334ec7b2 100644
--- a/test_runner/regress/test_proxy.py
+++ b/test_runner/regress/test_proxy.py
@@ -265,18 +265,23 @@ def test_sql_over_http_output_options(static_proxy: NeonProxy):
def test_sql_over_http_batch(static_proxy: NeonProxy):
static_proxy.safe_psql("create role http with login password 'http' superuser")
- def qq(queries: List[Tuple[str, Optional[List[Any]]]]) -> Any:
+ def qq(queries: List[Tuple[str, Optional[List[Any]]]], read_only: bool = False) -> Any:
connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
response = requests.post(
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
data=json.dumps(list(map(lambda x: {"query": x[0], "params": x[1] or []}, queries))),
- headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
+ headers={
+ "Content-Type": "application/sql",
+ "Neon-Connection-String": connstr,
+ "Neon-Batch-Isolation-Level": "Serializable",
+ "Neon-Batch-Read-Only": "true" if read_only else "false",
+ },
verify=str(static_proxy.test_output_dir / "proxy.crt"),
)
assert response.status_code == 200
- return response.json()["results"]
+ return response.json()["results"], response.headers
- result = qq(
+ result, headers = qq(
[
("select 42 as answer", None),
("select $1 as answer", [42]),
@@ -291,6 +296,9 @@ def test_sql_over_http_batch(static_proxy: NeonProxy):
]
)
+ assert headers["Neon-Batch-Isolation-Level"] == "Serializable"
+ assert headers["Neon-Batch-Read-Only"] == "false"
+
assert result[0]["rows"] == [{"answer": 42}]
assert result[1]["rows"] == [{"answer": "42"}]
assert result[2]["rows"] == [{"answer": 42}]
@@ -311,3 +319,14 @@ def test_sql_over_http_batch(static_proxy: NeonProxy):
assert res["command"] == "DROP"
assert res["rowCount"] is None
assert len(result) == 10
+
+ result, headers = qq(
+ [
+ ("select 42 as answer", None),
+ ],
+ True,
+ )
+ assert headers["Neon-Batch-Isolation-Level"] == "Serializable"
+ assert headers["Neon-Batch-Read-Only"] == "true"
+
+ assert result[0]["rows"] == [{"answer": 42}]