diff --git a/Cargo.lock b/Cargo.lock index 5665a9ef88..8c641bd36c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3561,7 +3561,7 @@ dependencies = [ [[package]] name = "postgres" version = "0.19.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=9011f7110db12b5e15afaf98f8ac834501d50ddc#9011f7110db12b5e15afaf98f8ac834501d50ddc" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d#a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" dependencies = [ "bytes", "fallible-iterator", @@ -3574,7 +3574,7 @@ dependencies = [ [[package]] name = "postgres-native-tls" version = "0.5.0" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=9011f7110db12b5e15afaf98f8ac834501d50ddc#9011f7110db12b5e15afaf98f8ac834501d50ddc" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d#a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" dependencies = [ "native-tls", "tokio", @@ -3585,7 +3585,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=9011f7110db12b5e15afaf98f8ac834501d50ddc#9011f7110db12b5e15afaf98f8ac834501d50ddc" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d#a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" dependencies = [ "base64 0.20.0", "byteorder", @@ -3603,7 +3603,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=9011f7110db12b5e15afaf98f8ac834501d50ddc#9011f7110db12b5e15afaf98f8ac834501d50ddc" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d#a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" dependencies = [ "bytes", "fallible-iterator", @@ -5407,7 +5407,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.7" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=9011f7110db12b5e15afaf98f8ac834501d50ddc#9011f7110db12b5e15afaf98f8ac834501d50ddc" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d#a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" dependencies = [ "async-trait", "byteorder", @@ -5422,7 +5422,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "postgres-types", - "socket2 0.4.9", + "socket2 0.5.3", "tokio", "tokio-util", ] @@ -6497,7 +6497,6 @@ dependencies = [ "serde", "serde_json", "smallvec", - "socket2 0.4.9", "standback", "syn 1.0.109", "syn 2.0.28", diff --git a/Cargo.toml b/Cargo.toml index 621b7af564..727fcbd5a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -160,11 +160,11 @@ env_logger = "0.10" log = "0.4" ## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed -postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" } -postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" } -postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" } -postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" } -tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" } +postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" } +postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" } +postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" } +postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" } +tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" } ## Other git libraries heapless = { default-features=false, features=[], git = "https://github.com/japaric/heapless.git", rev = "644653bf3b831c6bb4963be2de24804acf5e5001" } # upstream release pending @@ -200,7 +200,7 @@ tonic-build = "0.9" # This is only needed for proxy's tests. # TODO: we should probably fork `tokio-postgres-rustls` instead. -tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" } +tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="a2d0652ec3f8f710ff8cfc2e7c68f096fb852d9d" } ################# Binary contents sections diff --git a/proxy/src/http/conn_pool.rs b/proxy/src/http/conn_pool.rs index b268c4073e..ed84b0f2bf 100644 --- a/proxy/src/http/conn_pool.rs +++ b/proxy/src/http/conn_pool.rs @@ -8,14 +8,17 @@ use pbkdf2::{ Params, Pbkdf2, }; use pq_proto::StartupMessageParams; -use std::sync::atomic::{self, AtomicUsize}; use std::{collections::HashMap, sync::Arc}; use std::{ fmt, task::{ready, Poll}, }; +use std::{ + ops::Deref, + sync::atomic::{self, AtomicUsize}, +}; use tokio::time; -use tokio_postgres::AsyncMessage; +use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use crate::{ auth, console, @@ -26,13 +29,13 @@ use crate::{compute, config}; use crate::proxy::ConnectMechanism; -use tracing::{error, warn}; +use tracing::{error, warn, Span}; use tracing::{info, info_span, Instrument}; pub const APP_NAME: &str = "sql_over_http"; const MAX_CONNS_PER_ENDPOINT: usize = 20; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ConnInfo { pub username: String, pub dbname: String, @@ -55,7 +58,7 @@ impl fmt::Display for ConnInfo { } struct ConnPoolEntry { - conn: Client, + conn: ClientInner, _last_access: std::time::Instant, } @@ -133,14 +136,20 @@ impl GlobalConnPool { } pub async fn get( - &self, + self: &Arc, conn_info: &ConnInfo, force_new: bool, session_id: uuid::Uuid, ) -> anyhow::Result { - let mut client: Option = None; + let mut client: Option = None; let mut latency_timer = LatencyTimer::new("http"); + let pool = if force_new { + None + } else { + Some((conn_info.clone(), self.clone())) + }; + let mut hash_valid = false; if !force_new { let pool = self.get_or_create_endpoint_pool(&conn_info.hostname); @@ -188,7 +197,11 @@ impl GlobalConnPool { latency_timer.pool_hit(); info!("pool: reusing connection '{conn_info}'"); client.session.send(session_id)?; - return Ok(client); + return Ok(Client { + inner: Some(client), + span: Span::current(), + pool, + }); } } else { info!("pool: opening a new connection '{conn_info}'"); @@ -228,10 +241,14 @@ impl GlobalConnPool { _ => {} } - new_client + new_client.map(|inner| Client { + inner: Some(inner), + span: Span::current(), + pool, + }) } - pub fn put(&self, conn_info: &ConnInfo, client: Client) -> anyhow::Result<()> { + fn put(&self, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> { // We want to hold this open while we return. This ensures that the pool can't close // while we are in the middle of returning the connection. let closed = self.closed.read(); @@ -326,7 +343,7 @@ struct TokioMechanism<'a> { #[async_trait] impl ConnectMechanism for TokioMechanism<'_> { - type Connection = Client; + type Connection = ClientInner; type ConnectError = tokio_postgres::Error; type Error = anyhow::Error; @@ -350,7 +367,7 @@ async fn connect_to_compute( conn_info: &ConnInfo, session_id: uuid::Uuid, latency_timer: LatencyTimer, -) -> anyhow::Result { +) -> anyhow::Result { let tls = config.tls_config.as_ref(); let common_names = tls.and_then(|tls| tls.common_names.clone()); @@ -399,7 +416,7 @@ async fn connect_to_compute_once( conn_info: &ConnInfo, timeout: time::Duration, mut session: uuid::Uuid, -) -> Result { +) -> Result { let mut config = (*node_info.config).clone(); let (client, mut connection) = config @@ -462,21 +479,99 @@ async fn connect_to_compute_once( .instrument(span) ); - Ok(Client { + Ok(ClientInner { inner: client, session: tx, ids, }) } -pub struct Client { - pub inner: tokio_postgres::Client, +struct ClientInner { + inner: tokio_postgres::Client, session: tokio::sync::watch::Sender, ids: Ids, } impl Client { pub fn metrics(&self) -> Arc { - USAGE_METRICS.register(self.ids.clone()) + USAGE_METRICS.register(self.inner.as_ref().unwrap().ids.clone()) + } +} + +pub struct Client { + span: Span, + inner: Option, + pool: Option<(ConnInfo, Arc)>, +} + +pub struct Discard<'a> { + pool: &'a mut Option<(ConnInfo, Arc)>, +} + +impl Client { + pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { + let Self { + inner, + pool, + span: _, + } = self; + ( + &mut inner + .as_mut() + .expect("client inner should not be removed") + .inner, + Discard { pool }, + ) + } + + pub fn check_idle(&mut self, status: ReadyForQueryStatus) { + self.inner().1.check_idle(status) + } + pub fn discard(&mut self) { + self.inner().1.discard() + } +} + +impl Discard<'_> { + pub fn check_idle(&mut self, status: ReadyForQueryStatus) { + if status != ReadyForQueryStatus::Idle { + if let Some((conn_info, _)) = self.pool.take() { + info!("pool: throwing away connection '{conn_info}' because connection is not idle") + } + } + } + pub fn discard(&mut self) { + if let Some((conn_info, _)) = self.pool.take() { + info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state") + } + } +} + +impl Deref for Client { + type Target = tokio_postgres::Client; + + fn deref(&self) -> &Self::Target { + &self + .inner + .as_ref() + .expect("client inner should not be removed") + .inner + } +} + +impl Drop for Client { + fn drop(&mut self) { + let client = self + .inner + .take() + .expect("client inner should not be removed"); + if let Some((conn_info, conn_pool)) = self.pool.take() { + let current_span = self.span.clone(); + // return connection to the pool + tokio::task::spawn_blocking(move || { + let _span = current_span.enter(); + let _ = conn_pool.put(&conn_info, client); + }); + } } } diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index 02ce3aa9dc..93df86cfc4 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -17,7 +17,9 @@ use tokio_postgres::types::Kind; use tokio_postgres::types::Type; use tokio_postgres::GenericClient; use tokio_postgres::IsolationLevel; +use tokio_postgres::ReadyForQueryStatus; use tokio_postgres::Row; +use tokio_postgres::Transaction; use tracing::error; use tracing::instrument; use url::Url; @@ -64,20 +66,18 @@ static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); // Convert json non-string types to strings, so that they can be passed to Postgres // as parameters. // -fn json_to_pg_text(json: Vec) -> Result>, serde_json::Error> { +fn json_to_pg_text(json: Vec) -> Vec> { json.iter() .map(|value| { match value { // special care for nulls - Value::Null => Ok(None), + Value::Null => None, // convert to text with escaping - Value::Bool(_) => serde_json::to_string(value).map(Some), - Value::Number(_) => serde_json::to_string(value).map(Some), - Value::Object(_) => serde_json::to_string(value).map(Some), + v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), // avoid escaping here, as we pass this as a parameter - Value::String(s) => Ok(Some(s.to_string())), + Value::String(s) => Some(s.to_string()), // special care for arrays Value::Array(_) => json_array_to_pg_array(value), @@ -94,29 +94,26 @@ fn json_to_pg_text(json: Vec) -> Result>, serde_json:: // // Example of the same escaping in node-postgres: packages/pg/lib/utils.js // -fn json_array_to_pg_array(value: &Value) -> Result, serde_json::Error> { +fn json_array_to_pg_array(value: &Value) -> Option { match value { // special care for nulls - Value::Null => Ok(None), + Value::Null => None, // convert to text with escaping - Value::Bool(_) => serde_json::to_string(value).map(Some), - Value::Number(_) => serde_json::to_string(value).map(Some), - // here string needs to be escaped, as it is part of the array - Value::Object(_) => json_array_to_pg_array(&Value::String(serde_json::to_string(value)?)), - Value::String(_) => serde_json::to_string(value).map(Some), + v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()), + v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())), // recurse into array Value::Array(arr) => { let vals = arr .iter() .map(json_array_to_pg_array) - .map(|r| r.map(|v| v.unwrap_or_else(|| "NULL".to_string()))) - .collect::, _>>()? + .map(|v| v.unwrap_or_else(|| "NULL".to_string())) + .collect::>() .join(","); - Ok(Some(format!("{{{}}}", vals))) + Some(format!("{{{}}}", vals)) } } } @@ -315,83 +312,119 @@ async fn handle_inner( // Now execute the query and return the result // let mut size = 0; - let result = match payload { - Payload::Single(query) => { - query_to_json(&client.inner, query, &mut size, raw_output, array_mode).await - } - Payload::Batch(batch_query) => { - let mut results = Vec::new(); - let mut builder = client.inner.build_transaction(); - if let Some(isolation_level) = txn_isolation_level { - builder = builder.isolation_level(isolation_level); + let result = + match payload { + Payload::Single(stmt) => { + let (status, results) = + query_to_json(&*client, stmt, &mut 0, raw_output, array_mode) + .await + .map_err(|e| { + client.discard(); + e + })?; + client.check_idle(status); + results } - if txn_read_only { - builder = builder.read_only(true); - } - if txn_deferrable { - builder = builder.deferrable(true); - } - let transaction = builder.start().await?; - for query in batch_query.queries { - let result = - query_to_json(&transaction, query, &mut size, raw_output, array_mode).await; - match result { - Ok(r) => results.push(r), - Err(e) => { - transaction.rollback().await?; - return Err(e); - } + Payload::Batch(statements) => { + let (inner, mut discard) = client.inner(); + let mut builder = inner.build_transaction(); + if let Some(isolation_level) = txn_isolation_level { + builder = builder.isolation_level(isolation_level); } + if txn_read_only { + builder = builder.read_only(true); + } + if txn_deferrable { + builder = builder.deferrable(true); + } + + let transaction = builder.start().await.map_err(|e| { + // if we cannot start a transaction, we should return immediately + // and not return to the pool. connection is clearly broken + discard.discard(); + e + })?; + + let results = + match query_batch(&transaction, statements, &mut size, raw_output, array_mode) + .await + { + Ok(results) => { + let status = transaction.commit().await.map_err(|e| { + // if we cannot commit - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + e + })?; + discard.check_idle(status); + results + } + Err(err) => { + let status = transaction.rollback().await.map_err(|e| { + // if we cannot rollback - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + e + })?; + discard.check_idle(status); + return Err(err); + } + }; + + if txn_read_only { + response = response.header( + TXN_READ_ONLY.clone(), + HeaderValue::try_from(txn_read_only.to_string())?, + ); + } + if txn_deferrable { + response = response.header( + TXN_DEFERRABLE.clone(), + HeaderValue::try_from(txn_deferrable.to_string())?, + ); + } + if let Some(txn_isolation_level) = txn_isolation_level_raw { + response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); + } + json!({ "results": results }) } - transaction.commit().await?; - if txn_read_only { - response = response.header( - TXN_READ_ONLY.clone(), - HeaderValue::try_from(txn_read_only.to_string())?, - ); - } - if txn_deferrable { - response = response.header( - TXN_DEFERRABLE.clone(), - HeaderValue::try_from(txn_deferrable.to_string())?, - ); - } - if let Some(txn_isolation_level) = txn_isolation_level_raw { - response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); - } - Ok(json!({ "results": results })) - } - }; + }; let metrics = client.metrics(); - if allow_pool { - let current_span = tracing::Span::current(); - // return connection to the pool - tokio::task::spawn_blocking(move || { - let _span = current_span.enter(); - let _ = conn_pool.put(&conn_info, client); - }); - } + // how could this possibly fail + let body = serde_json::to_string(&result).expect("json serialization should not fail"); + let len = body.len(); + let response = response + .body(Body::from(body)) + // only fails if invalid status code or invalid header/values are given. + // these are not user configurable so it cannot fail dynamically + .expect("building response payload should not fail"); - match result { - Ok(value) => { - // how could this possibly fail - let body = serde_json::to_string(&value).expect("json serialization should not fail"); - let len = body.len(); - let response = response - .body(Body::from(body)) - // only fails if invalid status code or invalid header/values are given. - // these are not user configurable so it cannot fail dynamically - .expect("building response payload should not fail"); + // count the egress bytes - we miss the TLS and header overhead but oh well... + // moving this later in the stack is going to be a lot of effort and ehhhh + metrics.record_egress(len as u64); - // count the egress bytes - we miss the TLS and header overhead but oh well... - // moving this later in the stack is going to be a lot of effort and ehhhh - metrics.record_egress(len as u64); - Ok(response) - } - Err(e) => Err(e), + Ok(response) +} + +async fn query_batch( + transaction: &Transaction<'_>, + queries: BatchQueryData, + total_size: &mut usize, + raw_output: bool, + array_mode: bool, +) -> anyhow::Result> { + let mut results = Vec::with_capacity(queries.queries.len()); + let mut current_size = 0; + for stmt in queries.queries { + // TODO: maybe we should check that the transaction bit is set here + let (_, values) = + query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?; + results.push(values); } + *total_size += current_size; + Ok(results) } async fn query_to_json( @@ -400,11 +433,9 @@ async fn query_to_json( current_size: &mut usize, raw_output: bool, array_mode: bool, -) -> anyhow::Result { - let query_params = json_to_pg_text(data.params)?; - let row_stream = client - .query_raw_txt::(data.query, query_params) - .await?; +) -> anyhow::Result<(ReadyForQueryStatus, Value)> { + let query_params = json_to_pg_text(data.params); + let row_stream = client.query_raw_txt(&data.query, query_params).await?; // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. Also check that the response is not too @@ -424,6 +455,8 @@ async fn query_to_json( } } + let ready = row_stream.ready_status(); + // grab the command tag and number of rows affected let command_tag = row_stream.command_tag().unwrap_or_default(); let mut command_tag_split = command_tag.split(' '); @@ -464,13 +497,16 @@ async fn query_to_json( .collect::, _>>()?; // resulting JSON format is based on the format of node-postgres result - Ok(json!({ - "command": command_tag_name, - "rowCount": command_tag_count, - "rows": rows, - "fields": fields, - "rowAsArray": array_mode, - })) + Ok(( + ready, + json!({ + "command": command_tag_name, + "rowCount": command_tag_count, + "rows": rows, + "fields": fields, + "rowAsArray": array_mode, + }), + )) } // @@ -655,22 +691,22 @@ mod tests { #[test] fn test_atomic_types_to_pg_params() { let json = vec![Value::Bool(true), Value::Bool(false)]; - let pg_params = json_to_pg_text(json).unwrap(); + let pg_params = json_to_pg_text(json); assert_eq!( pg_params, vec![Some("true".to_owned()), Some("false".to_owned())] ); let json = vec![Value::Number(serde_json::Number::from(42))]; - let pg_params = json_to_pg_text(json).unwrap(); + let pg_params = json_to_pg_text(json); assert_eq!(pg_params, vec![Some("42".to_owned())]); let json = vec![Value::String("foo\"".to_string())]; - let pg_params = json_to_pg_text(json).unwrap(); + let pg_params = json_to_pg_text(json); assert_eq!(pg_params, vec![Some("foo\"".to_owned())]); let json = vec![Value::Null]; - let pg_params = json_to_pg_text(json).unwrap(); + let pg_params = json_to_pg_text(json); assert_eq!(pg_params, vec![None]); } @@ -679,7 +715,7 @@ mod tests { // atoms and escaping let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]"; let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]).unwrap(); + let pg_params = json_to_pg_text(vec![json]); assert_eq!( pg_params, vec![Some( @@ -690,7 +726,7 @@ mod tests { // nested arrays let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]"; let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]).unwrap(); + let pg_params = json_to_pg_text(vec![json]); assert_eq!( pg_params, vec![Some( @@ -700,7 +736,7 @@ mod tests { // array of objects let json = r#"[{"foo": 1},{"bar": 2}]"#; let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]).unwrap(); + let pg_params = json_to_pg_text(vec![json]); assert_eq!( pg_params, vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())] diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index f57b15c9b9..a33b29549c 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -7,6 +7,8 @@ import pytest import requests from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres +GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'" + def test_proxy_select_1(static_proxy: NeonProxy): """ @@ -353,7 +355,7 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): def get_pid(status: int, pw: str) -> Any: return static_proxy.http_query( - "SELECT pid FROM pg_stat_activity WHERE state = 'active'", + GET_CONNECTION_PID_QUERY, [], user="http_auth", password=pw, @@ -387,7 +389,6 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): # Beginning a transaction should not impact the next query, # which might come from a completely different client. -@pytest.mark.xfail(reason="not implemented") def test_http_pool_begin(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth with password 'http' superuser") @@ -403,3 +404,21 @@ def test_http_pool_begin(static_proxy: NeonProxy): query(200, "BEGIN;") query(400, "garbage-lol(&(&(&(&") # Intentional error to break the transaction query(200, "SELECT 1;") # Query that should succeed regardless of the transaction + + +def test_sql_over_http_pool_idle(static_proxy: NeonProxy): + static_proxy.safe_psql("create user http_auth2 with password 'http' superuser") + + def query(status: int, query: str) -> Any: + return static_proxy.http_query( + query, + [], + user="http_auth2", + password="http", + expected_code=status, + ) + + pid1 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"] + query(200, "BEGIN") + pid2 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"] + assert pid1 != pid2 diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 11e583084c..e2a65ad150 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -55,7 +55,6 @@ scopeguard = { version = "1" } serde = { version = "1", features = ["alloc", "derive"] } serde_json = { version = "1", features = ["raw_value"] } smallvec = { version = "1", default-features = false, features = ["write"] } -socket2 = { version = "0.4", default-features = false, features = ["all"] } standback = { version = "0.2", default-features = false, features = ["std"] } time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }