diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index 1f83fcfc68..a4eb13a10b 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -1,7 +1,9 @@ use std::sync::Arc; +use anyhow::bail; use futures::pin_mut; use futures::StreamExt; +use hashbrown::HashMap; use hyper::body::HttpBody; use hyper::http::HeaderName; use hyper::http::HeaderValue; @@ -12,6 +14,7 @@ use serde_json::Value; use tokio_postgres::types::Kind; use tokio_postgres::types::Type; use tokio_postgres::GenericClient; +use tokio_postgres::IsolationLevel; use tokio_postgres::Row; use url::Url; @@ -37,6 +40,8 @@ const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); +static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level"); +static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only"); static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); @@ -170,7 +175,7 @@ pub async fn handle( request: Request, sni_hostname: Option, conn_pool: Arc, -) -> anyhow::Result { +) -> anyhow::Result<(Value, HashMap)> { // // Determine the destination and connection params // @@ -185,6 +190,23 @@ pub async fn handle( // Allow connection pooling only if explicitly requested let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); + // isolation level and read only + + let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned(); + let txn_isolation_level = match txn_isolation_level_raw { + Some(ref x) => Some(match x.as_bytes() { + b"Serializable" => IsolationLevel::Serializable, + b"ReadUncommitted" => IsolationLevel::ReadUncommitted, + b"ReadCommitted" => IsolationLevel::ReadCommitted, + b"RepeatableRead" => IsolationLevel::RepeatableRead, + _ => bail!("invalid isolation level"), + }), + None => None, + }; + + let txn_read_only_raw = headers.get(&TXN_READ_ONLY).cloned(); + let txn_read_only = txn_read_only_raw.as_ref() == Some(&HEADER_VALUE_TRUE); + let request_content_length = match request.body().size_hint().upper() { Some(v) => v, None => MAX_REQUEST_SIZE + 1, @@ -208,10 +230,19 @@ pub async fn handle( // Now execute the query and return the result // let result = match payload { - Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode).await, + Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode) + .await + .map(|x| (x, HashMap::default())), Payload::Batch(queries) => { let mut results = Vec::new(); - let transaction = client.transaction().await?; + let mut builder = client.build_transaction(); + if let Some(isolation_level) = txn_isolation_level { + builder = builder.isolation_level(isolation_level); + } + if txn_read_only { + builder = builder.read_only(true); + } + let transaction = builder.start().await?; for query in queries { let result = query_to_json(&transaction, query, raw_output, array_mode).await; match result { @@ -223,7 +254,15 @@ pub async fn handle( } } transaction.commit().await?; - Ok(json!({ "results": results })) + let mut headers = HashMap::default(); + headers.insert( + TXN_READ_ONLY.clone(), + HeaderValue::try_from(txn_read_only.to_string())?, + ); + if let Some(txn_isolation_level_raw) = txn_isolation_level_raw { + headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level_raw); + } + Ok((json!({ "results": results }), headers)) } }; diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index 4b6e15dc3a..fec76c74f4 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -6,6 +6,7 @@ use crate::{ }; use bytes::{Buf, Bytes}; use futures::{Sink, Stream, StreamExt}; +use hashbrown::HashMap; use hyper::{ server::{ accept, @@ -205,7 +206,7 @@ async fn ws_handler( Ok(_) => StatusCode::OK, Err(_) => StatusCode::BAD_REQUEST, }; - let json = match result { + let (json, headers) = match result { Ok(r) => r, Err(e) => { let message = format!("{:?}", e); @@ -216,7 +217,10 @@ async fn ws_handler( }, None => Value::Null, }; - json!({ "message": message, "code": code }) + ( + json!({ "message": message, "code": code }), + HashMap::default(), + ) } }; json_response(status_code, json).map(|mut r| { @@ -224,6 +228,9 @@ async fn ws_handler( "Access-Control-Allow-Origin", hyper::http::HeaderValue::from_static("*"), ); + for (k, v) in headers { + r.headers_mut().insert(k, v); + } r }) } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index d5bf98109c..35334ec7b2 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -265,18 +265,23 @@ def test_sql_over_http_output_options(static_proxy: NeonProxy): def test_sql_over_http_batch(static_proxy: NeonProxy): static_proxy.safe_psql("create role http with login password 'http' superuser") - def qq(queries: List[Tuple[str, Optional[List[Any]]]]) -> Any: + def qq(queries: List[Tuple[str, Optional[List[Any]]]], read_only: bool = False) -> Any: connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" response = requests.post( f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql", data=json.dumps(list(map(lambda x: {"query": x[0], "params": x[1] or []}, queries))), - headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr}, + headers={ + "Content-Type": "application/sql", + "Neon-Connection-String": connstr, + "Neon-Batch-Isolation-Level": "Serializable", + "Neon-Batch-Read-Only": "true" if read_only else "false", + }, verify=str(static_proxy.test_output_dir / "proxy.crt"), ) assert response.status_code == 200 - return response.json()["results"] + return response.json()["results"], response.headers - result = qq( + result, headers = qq( [ ("select 42 as answer", None), ("select $1 as answer", [42]), @@ -291,6 +296,9 @@ def test_sql_over_http_batch(static_proxy: NeonProxy): ] ) + assert headers["Neon-Batch-Isolation-Level"] == "Serializable" + assert headers["Neon-Batch-Read-Only"] == "false" + assert result[0]["rows"] == [{"answer": 42}] assert result[1]["rows"] == [{"answer": "42"}] assert result[2]["rows"] == [{"answer": 42}] @@ -311,3 +319,14 @@ def test_sql_over_http_batch(static_proxy: NeonProxy): assert res["command"] == "DROP" assert res["rowCount"] is None assert len(result) == 10 + + result, headers = qq( + [ + ("select 42 as answer", None), + ], + True, + ) + assert headers["Neon-Batch-Isolation-Level"] == "Serializable" + assert headers["Neon-Batch-Read-Only"] == "true" + + assert result[0]["rows"] == [{"answer": 42}]