diff --git a/Cargo.lock b/Cargo.lock index 55418473d5..4d63ebd99d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2820,7 +2820,7 @@ dependencies = [ [[package]] name = "postgres" version = "0.19.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=0bc41d8503c092b040142214aac3cf7d11d0c19f#0bc41d8503c092b040142214aac3cf7d11d0c19f" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9#2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" dependencies = [ "bytes", "fallible-iterator", @@ -2833,7 +2833,7 @@ dependencies = [ [[package]] name = "postgres-native-tls" version = "0.5.0" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=0bc41d8503c092b040142214aac3cf7d11d0c19f#0bc41d8503c092b040142214aac3cf7d11d0c19f" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9#2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" dependencies = [ "native-tls", "tokio", @@ -2844,7 +2844,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=0bc41d8503c092b040142214aac3cf7d11d0c19f#0bc41d8503c092b040142214aac3cf7d11d0c19f" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9#2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" dependencies = [ "base64 0.20.0", "byteorder", @@ -2862,7 +2862,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=0bc41d8503c092b040142214aac3cf7d11d0c19f#0bc41d8503c092b040142214aac3cf7d11d0c19f" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9#2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" dependencies = [ "bytes", "fallible-iterator", @@ -4321,7 +4321,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.7" -source = "git+https://github.com/neondatabase/rust-postgres.git?rev=0bc41d8503c092b040142214aac3cf7d11d0c19f#0bc41d8503c092b040142214aac3cf7d11d0c19f" +source = "git+https://github.com/neondatabase/rust-postgres.git?rev=2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9#2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" dependencies = [ "async-trait", "byteorder", diff --git a/Cargo.toml b/Cargo.toml index c901532f86..7895459841 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -126,11 +126,11 @@ env_logger = "0.10" log = "0.4" ## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed -postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="0bc41d8503c092b040142214aac3cf7d11d0c19f" } -postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="0bc41d8503c092b040142214aac3cf7d11d0c19f" } -postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="0bc41d8503c092b040142214aac3cf7d11d0c19f" } -postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="0bc41d8503c092b040142214aac3cf7d11d0c19f" } -tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="0bc41d8503c092b040142214aac3cf7d11d0c19f" } +postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" } +postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" } +postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", 
rev="2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" } +postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" } +tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" } tokio-tar = { git = "https://github.com/neondatabase/tokio-tar.git", rev="404df61437de0feef49ba2ccdbdd94eb8ad6e142" } ## Other git libraries @@ -166,7 +166,7 @@ tonic-build = "0.9" # This is only needed for proxy's tests. # TODO: we should probably fork `tokio-postgres-rustls` instead. -tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="0bc41d8503c092b040142214aac3cf7d11d0c19f" } +tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9" } # Changes the MAX_THREADS limit from 4096 to 32768. # This is a temporary workaround for using tracing from many threads in safekeepers code, diff --git a/proxy/README.md b/proxy/README.md index 4ead098b73..cd76a2443f 100644 --- a/proxy/README.md +++ b/proxy/README.md @@ -1,6 +1,6 @@ # Proxy -Proxy binary accepts `--auth-backend` CLI option, which determines auth scheme and cluster routing method. Following backends are currently implemented: +Proxy binary accepts `--auth-backend` CLI option, which determines auth scheme and cluster routing method. Following routing backends are currently implemented: * console new SCRAM-based console API; uses SNI info to select the destination project (endpoint soon) @@ -9,6 +9,90 @@ Proxy binary accepts `--auth-backend` CLI option, which determines auth scheme a * link sends login link for all usernames +Also proxy can expose following services to the external world: + +* postgres protocol over TCP -- usual postgres endpoint compatible with usual + postgres drivers +* postgres protocol over WebSockets -- same protocol tunneled over websockets + for environments where TCP connection is not available. We have our own + implementation of a client that uses node-postgres and tunnels traffic through + websockets: https://github.com/neondatabase/serverless +* SQL over HTTP -- service that accepts POST requests with SQL text over HTTP + and responds with JSON-serialised results. + + +## SQL over HTTP + +Contrary to the usual postgres proto over TCP and WebSockets using plain +one-shot HTTP request achieves smaller amortized latencies in edge setups due to +fewer round trips and an enhanced open connection reuse by the v8 engine. Also +such endpoint could be used directly without any driver. 
+ +To play with it locally, one may start the proxy over a local postgres installation +(see the end of this page for how to generate certs with openssl): + +``` +./target/debug/proxy -c server.crt -k server.key --auth-backend=postgres --auth-endpoint=postgres://stas@127.0.0.1:5432/stas --wss 0.0.0.0:4444 +``` + +If both postgres and the proxy are running, you may send a SQL query: +``` +curl -k -X POST 'https://proxy.localtest.me:4444/sql' \ + -H 'Neon-Connection-String: postgres://stas:pass@proxy.localtest.me:4444/postgres' \ + -H 'Content-Type: application/json' \ + --data '{ + "query":"SELECT $1::int[] as arr, $2::jsonb as obj, 42 as num", + "params":[ "{{1,2},{\"3\",4}}", {"key":"val", "ikey":4242}] + }' | jq + +{ + "command": "SELECT", + "fields": [ + { "dataTypeID": 1007, "name": "arr" }, + { "dataTypeID": 3802, "name": "obj" }, + { "dataTypeID": 23, "name": "num" } + ], + "rowCount": 1, + "rows": [ + { + "arr": [[1,2],[3,4]], + "num": 42, + "obj": { + "ikey": 4242, + "key": "val" + } + } + ] +} +``` + + +With the current approach we made the following design decisions: + +1. SQL injection protection: We employed the extended query protocol, modifying + the rust-postgres driver to send queries in one roundtrip using a text + protocol rather than binary, bypassing potential issues like those identified + in sfackler/rust-postgres#1030. + +2. Postgres type compatibility: As not all postgres types have binary + representations (e.g., ACLs in pg_class), we adjusted rust-postgres to + respond with the text protocol, simplifying serialization and fixing queries with + text-only types in the response. + +3. Data type conversion: Since JSON supports fewer data types than Postgres, + we perform conversions where possible and pass all other types as + strings. Key conversions include: + - postgres int2, int4, float4, float8 -> json number (NaN and Inf remain + text) + - postgres bool, null, text -> json bool, null, string + - postgres array -> json array + - postgres json and jsonb -> json object + +4. Alignment with node-postgres: To facilitate integration with js libraries, + we've matched the response structure of node-postgres, returning command tags + and column oids. Command tag capturing was added to the rust-postgres + functionality as part of this change. + ## Using SNI-based routing on localhost Now the proxy determines the project name from the subdomain: a request to `round-rice-566201.somedomain.tld` will be routed to the project named `round-rice-566201`. Unfortunately, `/etc/hosts` does not support domain wildcards, so I usually use `*.localtest.me`, which resolves to `127.0.0.1`.
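For example, once `*.localtest.me` resolves to `127.0.0.1`, any TLS connection that presents a matching SNI subdomain is routed by project name. A minimal sketch (the port and credentials here are illustrative; adjust them to your `--proxy` listener and local roles):

```
psql "postgres://stas:pass@round-rice-566201.localtest.me:4432/postgres"
```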
Now we can create a self-signed certificate and play with the proxy: diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 530229b3fd..6a26cea78e 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -100,9 +100,10 @@ impl CertResolver { is_default: bool, ) -> anyhow::Result<()> { let priv_key = { - let key_bytes = std::fs::read(key_path).context("TLS key file")?; - let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) + let key_bytes = std::fs::read(key_path) .context(format!("Failed to read TLS keys at '{key_path}'"))?; + let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) + .context(format!("Failed to parse TLS keys at '{key_path}'"))?; ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); keys.pop().map(rustls::PrivateKey).unwrap() diff --git a/proxy/src/http.rs b/proxy/src/http.rs index a544157800..5cf49b669c 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -3,6 +3,7 @@ //! directly relying on deps like `reqwest` (think loose coupling). pub mod server; +pub mod sql_over_http; pub mod websocket; pub use reqwest::{Request, Response, StatusCode}; diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs new file mode 100644 index 0000000000..0438a82c12 --- /dev/null +++ b/proxy/src/http/sql_over_http.rs @@ -0,0 +1,603 @@ +use futures::pin_mut; +use futures::StreamExt; +use hyper::body::HttpBody; +use hyper::{Body, HeaderMap, Request}; +use pq_proto::StartupMessageParams; +use serde_json::json; +use serde_json::Map; +use serde_json::Value; +use tokio_postgres::types::Kind; +use tokio_postgres::types::Type; +use tokio_postgres::Row; +use url::Url; + +use crate::{auth, config::ProxyConfig, console}; + +#[derive(serde::Deserialize)] +struct QueryData { + query: String, + params: Vec<Value>, +} + +const APP_NAME: &str = "sql_over_http"; +const MAX_RESPONSE_SIZE: usize = 1024 * 1024; // 1 MB +const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB + +// +// Convert json non-string types to strings, so that they can be passed to Postgres +// as parameters. +// +fn json_to_pg_text(json: Vec<Value>) -> Result<Vec<String>, serde_json::Error> { + json.iter() + .map(|value| { + match value { + Value::Null => serde_json::to_string(value), + Value::Bool(_) => serde_json::to_string(value), + Value::Number(_) => serde_json::to_string(value), + Value::Object(_) => serde_json::to_string(value), + + // no need to escape + Value::String(s) => Ok(s.to_string()), + + // special care for arrays + Value::Array(_) => json_array_to_pg_array(value), + } + }) + .collect() +} + +// +// Serialize a JSON array to a Postgres array. Unlike the top-level string params, +// strings inside an array need to be escaped. Postgres is okay with arrays of the form +// '{1,"2",3}'::int[], so we don't check that the array holds values of the same type, +// leaving that for Postgres to check. +// +// Example of the same escaping in node-postgres: packages/pg/lib/utils.js +// +fn json_array_to_pg_array(value: &Value) -> Result<String, serde_json::Error> { + match value { + // same + Value::Null => serde_json::to_string(value), + Value::Bool(_) => serde_json::to_string(value), + Value::Number(_) => serde_json::to_string(value), + Value::Object(_) => serde_json::to_string(value), + + // now needs to be escaped, as it is part of the array + Value::String(_) => serde_json::to_string(value), + + // recurse into array + Value::Array(arr) => { + let vals = arr + .iter() + .map(json_array_to_pg_array) + .collect::<Result<Vec<_>, _>>()?
+ .join(","); + Ok(format!("{{{}}}", vals)) + } + } +} + +fn get_conn_info( + headers: &HeaderMap, + sni_hostname: Option, +) -> Result<(String, String, String, String), anyhow::Error> { + let connection_string = headers + .get("Neon-Connection-String") + .ok_or(anyhow::anyhow!("missing connection string"))? + .to_str()?; + + let connection_url = Url::parse(connection_string)?; + + let protocol = connection_url.scheme(); + if protocol != "postgres" && protocol != "postgresql" { + return Err(anyhow::anyhow!( + "connection string must start with postgres: or postgresql:" + )); + } + + let mut url_path = connection_url + .path_segments() + .ok_or(anyhow::anyhow!("missing database name"))?; + + let dbname = url_path + .next() + .ok_or(anyhow::anyhow!("invalid database name"))?; + + let username = connection_url.username(); + if username.is_empty() { + return Err(anyhow::anyhow!("missing username")); + } + + let password = connection_url + .password() + .ok_or(anyhow::anyhow!("no password"))?; + + // TLS certificate selector now based on SNI hostname, so if we are running here + // we are sure that SNI hostname is set to one of the configured domain names. + let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?; + + let hostname = connection_url + .host_str() + .ok_or(anyhow::anyhow!("no host"))?; + + let host_header = headers + .get("host") + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.split(':').next()); + + if hostname != sni_hostname { + return Err(anyhow::anyhow!("mismatched SNI hostname and hostname")); + } else if let Some(h) = host_header { + if h != hostname { + return Err(anyhow::anyhow!("mismatched host header and hostname")); + } + } + + Ok(( + username.to_owned(), + dbname.to_owned(), + hostname.to_owned(), + password.to_owned(), + )) +} + +// TODO: return different http error codes +pub async fn handle( + config: &'static ProxyConfig, + request: Request, + sni_hostname: Option, +) -> anyhow::Result { + // + // Determine the destination and connection params + // + let headers = request.headers(); + let (username, dbname, hostname, password) = get_conn_info(headers, sni_hostname)?; + let credential_params = StartupMessageParams::new([ + ("user", &username), + ("database", &dbname), + ("application_name", APP_NAME), + ]); + + // + // Wake up the destination if needed. Code here is a bit involved because + // we reuse the code from the usual proxy and we need to prepare few structures + // that this code expects. 
+ // + let tls = config.tls_config.as_ref(); + let common_names = tls.and_then(|tls| tls.common_names.clone()); + let creds = config + .auth_backend + .as_ref() + .map(|_| auth::ClientCredentials::parse(&credential_params, Some(&hostname), common_names)) + .transpose()?; + let extra = console::ConsoleReqExtra { + session_id: uuid::Uuid::new_v4(), + application_name: Some(APP_NAME), + }; + let node = creds.wake_compute(&extra).await?.expect("msg"); + let conf = node.value.config; + let port = *conf.get_ports().first().expect("no port"); + let host = match conf.get_hosts().first().expect("no host") { + tokio_postgres::config::Host::Tcp(host) => host, + tokio_postgres::config::Host::Unix(_) => { + return Err(anyhow::anyhow!("unix socket is not supported")); + } + }; + + let request_content_length = match request.body().size_hint().upper() { + Some(v) => v, + None => MAX_REQUEST_SIZE + 1, + }; + + if request_content_length > MAX_REQUEST_SIZE { + return Err(anyhow::anyhow!( + "request is too large (max {MAX_REQUEST_SIZE} bytes)" + )); + } + + // + // Read the query and query params from the request body + // + let body = hyper::body::to_bytes(request.into_body()).await?; + let QueryData { query, params } = serde_json::from_slice(&body)?; + let query_params = json_to_pg_text(params)?; + + // + // Connect to the destination + // + let (client, connection) = tokio_postgres::Config::new() + .host(host) + .port(port) + .user(&username) + .password(&password) + .dbname(&dbname) + .max_backend_message_size(MAX_RESPONSE_SIZE) + .connect(tokio_postgres::NoTls) + .await?; + + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + // + // Now execute the query and return the result + // + let row_stream = client.query_raw_txt(query, query_params).await?; + + // Manually drain the stream into a vector, so that row_stream stays + // around and we can read the command tag afterwards. Also check that + // the response is not too big.
+ pin_mut!(row_stream); + let mut rows: Vec<Row> = Vec::new(); + let mut current_size = 0; + while let Some(row) = row_stream.next().await { + let row = row?; + current_size += row.body_len(); + rows.push(row); + if current_size > MAX_RESPONSE_SIZE { + return Err(anyhow::anyhow!("response too large")); + } + } + + // grab the command tag and number of rows affected + let command_tag = row_stream.command_tag().unwrap_or_default(); + let mut command_tag_split = command_tag.split(' '); + let command_tag_name = command_tag_split.next().unwrap_or_default(); + let command_tag_count = if command_tag_name == "INSERT" { + // INSERT returns the OID first and then the number of rows, e.g. "INSERT 0 3" + command_tag_split.nth(1) + } else { + // other commands return the number of rows (if any) + command_tag_split.next() + } + .and_then(|s| s.parse::<i64>().ok()); + + let fields = if !rows.is_empty() { + rows[0] + .columns() + .iter() + .map(|c| { + json!({ + "name": Value::String(c.name().to_owned()), + "dataTypeID": Value::Number(c.type_().oid().into()), + }) + }) + .collect::<Vec<_>>() + } else { + Vec::new() + }; + + // convert rows to JSON + let rows = rows + .iter() + .map(pg_text_row_to_json) + .collect::<Result<Vec<_>, _>>()?; + + // resulting JSON format is based on the format of node-postgres result + Ok(json!({ + "command": command_tag_name, + "rowCount": command_tag_count, + "rows": rows, + "fields": fields, + })) +} + +// +// Convert postgres row with text-encoded values to JSON object +// +pub fn pg_text_row_to_json(row: &Row) -> Result<Value, anyhow::Error> { + let res = row + .columns() + .iter() + .enumerate() + .map(|(i, column)| { + let name = column.name(); + let pg_value = row.as_text(i)?; + let json_value = pg_text_to_json(pg_value, column.type_())?; + Ok((name.to_string(), json_value)) + }) + .collect::<Result<Map<String, Value>, anyhow::Error>>()?; + + Ok(Value::Object(res)) +} + +// +// Convert postgres text-encoded value to JSON value +// +pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> { + if let Some(val) = pg_value { + if val == "NULL" { + return Ok(Value::Null); + } + + if let Kind::Array(elem_type) = pg_type.kind() { + return pg_array_parse(val, elem_type); + } + + match *pg_type { + Type::BOOL => Ok(Value::Bool(val == "t")), + Type::INT2 | Type::INT4 => { + let val = val.parse::<i32>()?; + Ok(Value::Number(serde_json::Number::from(val))) + } + Type::FLOAT4 | Type::FLOAT8 => { + let fval = val.parse::<f64>()?; + let num = serde_json::Number::from_f64(fval); + if let Some(num) = num { + Ok(Value::Number(num)) + } else { + // Pass NaN, Inf, -Inf as strings. + // JS JSON.stringify() converts them to null, but we + // want to preserve them, so we pass them as strings + Ok(Value::String(val.to_string())) + } + } + Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?), + _ => Ok(Value::String(val.to_string())), + } + } else { + Ok(Value::Null) + } +} + +// +// Parse postgres array into JSON array. +// +// This is a bit involved because we need to handle nested arrays and quoted +// values. Unlike postgres we don't check that all nested arrays have the same +// dimensions; we just return them as is.
+// +fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> { + _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v) +} + +fn _pg_array_parse( + pg_array: &str, + elem_type: &Type, + nested: bool, +) -> Result<(Value, usize), anyhow::Error> { + let mut pg_array_chr = pg_array.char_indices(); + let mut level = 0; + let mut quote = false; + let mut entries: Vec<Value> = Vec::new(); + let mut entry = String::new(); + + // skip bounds decoration + if let Some('[') = pg_array.chars().next() { + for (_, c) in pg_array_chr.by_ref() { + if c == '=' { + break; + } + } + } + + while let Some((mut i, mut c)) = pg_array_chr.next() { + let mut escaped = false; + + if c == '\\' { + escaped = true; + (i, c) = pg_array_chr.next().unwrap(); + } + + match c { + '{' if !quote => { + level += 1; + if level > 1 { + let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?; + entries.push(res); + for _ in 0..off - 1 { + pg_array_chr.next(); + } + } + } + '}' => { + level -= 1; + if level == 0 { + if !entry.is_empty() { + entries.push(pg_text_to_json(Some(&entry), elem_type)?); + } + if nested { + return Ok((Value::Array(entries), i)); + } + } + } + '"' if !escaped => { + if quote { + // push even if empty + entries.push(pg_text_to_json(Some(&entry), elem_type)?); + entry = String::new(); + } + quote = !quote; + } + ',' if !quote => { + if !entry.is_empty() { + entries.push(pg_text_to_json(Some(&entry), elem_type)?); + entry = String::new(); + } + } + _ => { + entry.push(c); + } + } + } + + if level != 0 { + return Err(anyhow::anyhow!("unbalanced array")); + } + + Ok((Value::Array(entries), 0)) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_atomic_types_to_pg_params() { + let json = vec![Value::Bool(true), Value::Bool(false)]; + let pg_params = json_to_pg_text(json).unwrap(); + assert_eq!(pg_params, vec!["true", "false"]); + + let json = vec![Value::Number(serde_json::Number::from(42))]; + let pg_params = json_to_pg_text(json).unwrap(); + assert_eq!(pg_params, vec!["42"]); + + let json = vec![Value::String("foo\"".to_string())]; + let pg_params = json_to_pg_text(json).unwrap(); + assert_eq!(pg_params, vec!["foo\""]); + + let json = vec![Value::Null]; + let pg_params = json_to_pg_text(json).unwrap(); + assert_eq!(pg_params, vec!["null"]); + } + + #[test] + fn test_json_array_to_pg_array() { + // atoms and escaping + let json = "[true, false, null, 42, \"foo\", \"bar\\\"-\\\\\"]"; + let json: Value = serde_json::from_str(json).unwrap(); + let pg_params = json_to_pg_text(vec![json]).unwrap(); + assert_eq!( + pg_params, + vec!["{true,false,null,42,\"foo\",\"bar\\\"-\\\\\"}"] + ); + + // nested arrays + let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]"; + let json: Value = serde_json::from_str(json).unwrap(); + let pg_params = json_to_pg_text(vec![json]).unwrap(); + assert_eq!( + pg_params, + vec!["{{true,false},{null,42},{\"foo\",\"bar\\\"-\\\\\"}}"] + ); + } + + #[test] + fn test_atomic_types_parse() { + assert_eq!( + pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(), + json!("foo") + ); + assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null)); + assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42)); + assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42)); + assert_eq!( + pg_text_to_json(Some("42"), &Type::INT8).unwrap(), + json!("42") + ); + assert_eq!( + pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(), + json!(42.42) + ); + assert_eq!( + pg_text_to_json(Some("42.42"),
&Type::FLOAT4).unwrap(), + json!(42.42) + ); + assert_eq!( + pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(), + json!("NaN") + ); + assert_eq!( + pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(), + json!("Infinity") + ); + assert_eq!( + pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(), + json!("-Infinity") + ); + + let json: Value = + serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}") + .unwrap(); + assert_eq!( + pg_text_to_json( + Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#), + &Type::JSONB + ) + .unwrap(), + json + ); + } + + #[test] + fn test_pg_array_parse_text() { + fn pt(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::TEXT).unwrap() + } + assert_eq!( + pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#), + json!(["aa\"\\,a", "cha", "bbbb"]) + ); + assert_eq!( + pt(r#"{{"foo","bar"},{"bee","bop"}}"#), + json!([["foo", "bar"], ["bee", "bop"]]) + ); + assert_eq!( + pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#), + json!([[[["foo", null, "bop", "bup"]]]]) + ); + assert_eq!( + pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#), + json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]]) + ); + } + + #[test] + fn test_pg_array_parse_bool() { + fn pb(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::BOOL).unwrap() + } + assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true])); + assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]])); + assert_eq!( + pb(r#"{{t,f},{f,t}}"#), + json!([[true, false], [false, true]]) + ); + assert_eq!( + pb(r#"{{t,NULL},{NULL,f}}"#), + json!([[true, null], [null, false]]) + ); + } + + #[test] + fn test_pg_array_parse_numbers() { + fn pn(pg_arr: &str, ty: &Type) -> Value { + pg_array_parse(pg_arr, ty).unwrap() + } + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0])); + assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0])); + assert_eq!( + pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4), + json!([1.1, 2.2, 3.3]) + ); + assert_eq!( + pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8), + json!([1.1, 2.2, 3.3]) + ); + assert_eq!( + pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4), + json!(["NaN", "Infinity", "-Infinity"]) + ); + assert_eq!( + pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8), + json!(["NaN", "Infinity", "-Infinity"]) + ); + } + + #[test] + fn test_pg_array_with_decoration() { + fn p(pg_arr: &str) -> Value { + pg_array_parse(pg_arr, &Type::INT2).unwrap() + } + assert_eq!( + p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#), + json!([[[1, 2, 3], [4, 5, 6]]]) + ); + } +} diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index c7676e8e14..fbb602e3d2 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -4,12 +4,17 @@ use crate::{ use bytes::{Buf, Bytes}; use futures::{Sink, Stream, StreamExt}; use hyper::{ - server::{accept, conn::AddrIncoming}, + server::{ + accept, + conn::{AddrIncoming, AddrStream}, + }, upgrade::Upgraded, - Body, Request, Response, StatusCode, + Body, Method, Request, Response, StatusCode, }; use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; use pin_project_lite::pin_project; +use serde_json::{json, Value}; + use std::{ convert::Infallible, future::ready, @@ -21,6 +26,7 @@ use tls_listener::TlsListener; use tokio::{ io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}, net::TcpListener, + select, }; 
use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; @@ -30,6 +36,8 @@ use utils::http::{error::ApiError, json::json_response}; // Tracking issue: https://github.com/rust-lang/rust/issues/98407. use sync_wrapper::SyncWrapper; +use super::sql_over_http; + pin_project! { /// This is a wrapper around a [`WebSocketStream`] that /// implements [`AsyncRead`] and [`AsyncWrite`]. @@ -159,6 +167,7 @@ async fn ws_handler( config: &'static ProxyConfig, cancel_map: Arc<CancelMap>, session_id: uuid::Uuid, + sni_hostname: Option<String>, ) -> Result<Response<Body>, ApiError> { let host = request .headers() @@ -181,8 +190,44 @@ async fn ws_handler( // Return the response so the spawned future can continue. Ok(response) + // TODO: this deserves a refactor, as this function now also handles http json clients besides websockets. + // Right now I don't want to blow up the sql-over-http patch with file renames, so that is left as a follow-up. + } else if request.uri().path() == "/sql" && request.method() == Method::POST { + let result = select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => { + Err(anyhow::anyhow!("Query timed out")) + } + response = sql_over_http::handle(config, request, sni_hostname) => { + response + } + }; + let status_code = match result { + Ok(_) => StatusCode::OK, + Err(_) => StatusCode::BAD_REQUEST, + }; + let json = match result { + Ok(r) => r, + Err(e) => { + let message = format!("{:?}", e); + let code = match e.downcast_ref::<tokio_postgres::Error>() { + Some(e) => match e.code() { + Some(e) => serde_json::to_value(e.code()).unwrap(), + None => Value::Null, + }, + None => Value::Null, + }; + json!({ "message": message, "code": code }) + } + }; + json_response(status_code, json).map(|mut r| { + r.headers_mut().insert( + "Access-Control-Allow-Origin", + hyper::http::HeaderValue::from_static("*"), + ); + r + }) } else { - json_response(StatusCode::OK, "Connect with a websocket client") + json_response(StatusCode::BAD_REQUEST, "query is not supported") } } @@ -216,20 +261,27 @@ pub async fn task_main( } }); - let make_svc = hyper::service::make_service_fn(|_stream| async move { - Ok::<_, Infallible>(hyper::service::service_fn( - move |req: Request<Body>| async move { - let cancel_map = Arc::new(CancelMap::default()); - let session_id = uuid::Uuid::new_v4(); - ws_handler(req, config, cancel_map, session_id) - .instrument(info_span!( - "ws-client", - session = format_args!("{session_id}") - )) - .await - }, - )) - }); + let make_svc = + hyper::service::make_service_fn(|stream: &tokio_rustls::server::TlsStream<AddrStream>| { + let sni_name = stream.get_ref().1.sni_hostname().map(|s| s.to_string()); + + async move { + Ok::<_, Infallible>(hyper::service::service_fn(move |req: Request<Body>| { + let sni_name = sni_name.clone(); + async move { + let cancel_map = Arc::new(CancelMap::default()); + let session_id = uuid::Uuid::new_v4(); + + ws_handler(req, config, cancel_map, session_id, sni_name) + .instrument(info_span!( + "ws-client", + session = format_args!("{session_id}") + )) + .await + } + })) + } + }); hyper::Server::builder(accept::from_stream(tls_listener)) .serve(make_svc) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 8ec17834ac..bde91e6783 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2042,15 +2042,19 @@ class NeonProxy(PgProtocol): proxy_port: int, http_port: int, mgmt_port: int, + external_http_port: int, auth_backend: NeonProxy.AuthBackend, metric_collection_endpoint: Optional[str] = None,
metric_collection_interval: Optional[str] = None, ): host = "127.0.0.1" - super().__init__(dsn=auth_backend.default_conn_url, host=host, port=proxy_port) + domain = "proxy.localtest.me" # resolves to 127.0.0.1 + super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port) + self.domain = domain self.host = host self.http_port = http_port + self.external_http_port = external_http_port self.neon_binpath = neon_binpath self.test_output_dir = test_output_dir self.proxy_port = proxy_port @@ -2062,11 +2066,42 @@ class NeonProxy(PgProtocol): def start(self) -> NeonProxy: assert self._popen is None + + # generate a key if it doesn't exist + crt_path = self.test_output_dir / "proxy.crt" + key_path = self.test_output_dir / "proxy.key" + + if not key_path.exists(): + r = subprocess.run( + [ + "openssl", + "req", + "-new", + "-x509", + "-days", + "365", + "-nodes", + "-text", + "-out", + str(crt_path), + "-keyout", + str(key_path), + "-subj", + "/CN=*.localtest.me", + "-addext", + "subjectAltName = DNS:*.localtest.me", + ] + ) + assert r.returncode == 0 + args = [ str(self.neon_binpath / "proxy"), *["--http", f"{self.host}:{self.http_port}"], *["--proxy", f"{self.host}:{self.proxy_port}"], *["--mgmt", f"{self.host}:{self.mgmt_port}"], + *["--wss", f"{self.host}:{self.external_http_port}"], + *["-c", str(crt_path)], + *["-k", str(key_path)], *self.auth_backend.extra_args(), ] @@ -2190,6 +2225,7 @@ def link_proxy( http_port = port_distributor.get_port() proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() with NeonProxy( neon_binpath=neon_binpath, @@ -2197,6 +2233,7 @@ proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, + external_http_port=external_http_port, auth_backend=NeonProxy.Link(), ) as proxy: proxy.start() @@ -2224,6 +2261,7 @@ def static_proxy( proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() http_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() with NeonProxy( neon_binpath=neon_binpath, @@ -2231,6 +2269,7 @@ proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, + external_http_port=external_http_port, auth_backend=NeonProxy.Postgres(auth_endpoint), ) as proxy: proxy.start() diff --git a/test_runner/regress/test_metric_collection.py b/test_runner/regress/test_metric_collection.py index 1231188896..00ea77f2e7 100644 --- a/test_runner/regress/test_metric_collection.py +++ b/test_runner/regress/test_metric_collection.py @@ -204,6 +204,7 @@ def proxy_with_metric_collector( http_port = port_distributor.get_port() proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() (host, port) = httpserver_listen_address metric_collection_endpoint = f"http://{host}:{port}/billing/api/v1/usage_events" @@ -215,6 +216,7 @@ proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, + external_http_port=external_http_port, metric_collection_endpoint=metric_collection_endpoint, metric_collection_interval=metric_collection_interval, auth_backend=NeonProxy.Link(), diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index ae914e384e..6be3995714 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -1,22 +1,32 @@ +import json import subprocess +from typing import Any, List import psycopg2 import pytest +import requests
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres -@pytest.mark.parametrize("option_name", ["project", "endpoint"]) -def test_proxy_select_1(static_proxy: NeonProxy, option_name: str): +def test_proxy_select_1(static_proxy: NeonProxy): """ The simplest smoke test: check the proxy against a local postgres instance. """ - out = static_proxy.safe_psql("select 1", options=f"{option_name}=generic-project-name") + # no SNI, deprecated `options=project` syntax (before we had several endpoints per project) + out = static_proxy.safe_psql("select 1", sslsni=0, options="project=generic-project-name") assert out[0][0] == 1 + # no SNI, new `options=endpoint` syntax + out = static_proxy.safe_psql("select 1", sslsni=0, options="endpoint=generic-project-name") + assert out[0][0] == 1 -@pytest.mark.parametrize("option_name", ["project", "endpoint"]) -def test_password_hack(static_proxy: NeonProxy, option_name: str): + # with SNI + out = static_proxy.safe_psql("select 42", host="generic-project-name.localtest.me") + assert out[0][0] == 42 + + +def test_password_hack(static_proxy: NeonProxy): """ Check the PasswordHack auth flow: an alternative to SCRAM auth for clients which can't provide the project/endpoint name via SNI or `options`. @@ -24,14 +34,16 @@ user = "borat" password = "password" - static_proxy.safe_psql( - f"create role {user} with login password '{password}'", - options=f"{option_name}=irrelevant", - ) + static_proxy.safe_psql(f"create role {user} with login password '{password}'") # Note the format of `magic`! - magic = f"{option_name}=irrelevant;{password}" - static_proxy.safe_psql("select 1", sslsni=0, user=user, password=magic) + magic = f"project=irrelevant;{password}" + out = static_proxy.safe_psql("select 1", sslsni=0, user=user, password=magic) + assert out[0][0] == 1 + + magic = f"endpoint=irrelevant;{password}" + out = static_proxy.safe_psql("select 1", sslsni=0, user=user, password=magic) + assert out[0][0] == 1 # Must also check that invalid magic won't be accepted. with pytest.raises(psycopg2.OperationalError): @@ -69,52 +81,55 @@ def test_proxy_options(static_proxy: NeonProxy, option_name: str): """ options = f"{option_name}=irrelevant -cproxytest.option=value" - out = static_proxy.safe_psql("show proxytest.option", options=options) + out = static_proxy.safe_psql("show proxytest.option", options=options, sslsni=0) assert out[0][0] == "value" options = f"-c proxytest.foo=\\ str {option_name}=irrelevant" + out = static_proxy.safe_psql("show proxytest.foo", options=options, sslsni=0) + assert out[0][0] == " str" + + options = "-cproxytest.option=value" + out = static_proxy.safe_psql("show proxytest.option", options=options) + assert out[0][0] == "value" + + options = "-c proxytest.foo=\\ str" out = static_proxy.safe_psql("show proxytest.foo", options=options) assert out[0][0] == " str" -@pytest.mark.parametrize("option_name", ["project", "endpoint"]) -def test_auth_errors(static_proxy: NeonProxy, option_name: str): +def test_auth_errors(static_proxy: NeonProxy): """ Check that we throw very specific errors in some unsuccessful auth scenarios.
""" # User does not exist with pytest.raises(psycopg2.Error) as exprinfo: - static_proxy.connect(user="pinocchio", options=f"{option_name}=irrelevant") + static_proxy.connect(user="pinocchio") text = str(exprinfo.value).strip() - assert text.endswith("password authentication failed for user 'pinocchio'") + assert text.find("password authentication failed for user 'pinocchio'") != -1 static_proxy.safe_psql( "create role pinocchio with login password 'magic'", - options=f"{option_name}=irrelevant", ) # User exists, but password is missing with pytest.raises(psycopg2.Error) as exprinfo: - static_proxy.connect(user="pinocchio", password=None, options=f"{option_name}=irrelevant") + static_proxy.connect(user="pinocchio", password=None) text = str(exprinfo.value).strip() - assert text.endswith("password authentication failed for user 'pinocchio'") + assert text.find("password authentication failed for user 'pinocchio'") != -1 # User exists, but password is wrong with pytest.raises(psycopg2.Error) as exprinfo: - static_proxy.connect(user="pinocchio", password="bad", options=f"{option_name}=irrelevant") + static_proxy.connect(user="pinocchio", password="bad") text = str(exprinfo.value).strip() - assert text.endswith("password authentication failed for user 'pinocchio'") + assert text.find("password authentication failed for user 'pinocchio'") != -1 # Finally, check that the user can connect - with static_proxy.connect( - user="pinocchio", password="magic", options=f"{option_name}=irrelevant" - ): + with static_proxy.connect(user="pinocchio", password="magic"): pass -@pytest.mark.parametrize("option_name", ["project", "endpoint"]) -def test_forward_params_to_client(static_proxy: NeonProxy, option_name: str): +def test_forward_params_to_client(static_proxy: NeonProxy): """ Check that we forward all necessary PostgreSQL server params to client. """ @@ -140,7 +155,7 @@ def test_forward_params_to_client(static_proxy: NeonProxy, option_name: str): where name = any(%s) """ - with static_proxy.connect(options=f"{option_name}=irrelevant") as conn: + with static_proxy.connect() as conn: with conn.cursor() as cur: cur.execute(query, (reported_params_subset,)) for name, value in cur.fetchall(): @@ -148,18 +163,65 @@ def test_forward_params_to_client(static_proxy: NeonProxy, option_name: str): assert conn.get_parameter_status(name) == value -@pytest.mark.parametrize("option_name", ["project", "endpoint"]) @pytest.mark.timeout(5) -def test_close_on_connections_exit(static_proxy: NeonProxy, option_name: str): +def test_close_on_connections_exit(static_proxy: NeonProxy): # Open two connections, send SIGTERM, then ensure that proxy doesn't exit # until after connections close. 
- with static_proxy.connect(options=f"{option_name}=irrelevant"), static_proxy.connect( - options=f"{option_name}=irrelevant" - ): + with static_proxy.connect(), static_proxy.connect(): static_proxy.terminate() with pytest.raises(subprocess.TimeoutExpired): static_proxy.wait_for_exit(timeout=2) # Ensure we don't accept any more connections with pytest.raises(psycopg2.OperationalError): - static_proxy.connect(options=f"{option_name}=irrelevant") + static_proxy.connect() static_proxy.wait_for_exit() + + +def test_sql_over_http(static_proxy: NeonProxy): + static_proxy.safe_psql("create role http with login password 'http' superuser") + + def q(sql: str, params: List[Any] = []) -> Any: + connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" + response = requests.post( + f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql", + data=json.dumps({"query": sql, "params": params}), + headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr}, + verify=str(static_proxy.test_output_dir / "proxy.crt"), + ) + assert response.status_code == 200 + return response.json() + + rows = q("select 42 as answer")["rows"] + assert rows == [{"answer": 42}] + + rows = q("select $1 as answer", [42])["rows"] + assert rows == [{"answer": "42"}] + + rows = q("select $1 * 1 as answer", [42])["rows"] + assert rows == [{"answer": 42}] + + rows = q("select $1::int[] as answer", [[1, 2, 3]])["rows"] + assert rows == [{"answer": [1, 2, 3]}] + + rows = q("select $1::json->'a' as answer", [{"a": {"b": 42}}])["rows"] + assert rows == [{"answer": {"b": 42}}] + + rows = q("select * from pg_class limit 1")["rows"] + assert len(rows) == 1 + + res = q("create table t(id serial primary key, val int)") + assert res["command"] == "CREATE" + assert res["rowCount"] is None + + res = q("insert into t(val) values (10), (20), (30) returning id") + assert res["command"] == "INSERT" + assert res["rowCount"] == 3 + assert res["rows"] == [{"id": 1}, {"id": 2}, {"id": 3}] + + res = q("select * from t") + assert res["command"] == "SELECT" + assert res["rowCount"] == 3 + + res = q("drop table t") + assert res["command"] == "DROP" + assert res["rowCount"] is None
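A possible follow-up, not part of this patch: the `/sql` handler replies to failures with HTTP 400 and a JSON body of the shape `{"message": ..., "code": ...}` (see `ws_handler` above), so the error path could be exercised with a sketch like the one below. The test name, role, and exact assertions are illustrative only:

```python
def test_sql_over_http_errors(static_proxy: NeonProxy):
    # hypothetical test: exercise the BAD_REQUEST path of the /sql endpoint
    static_proxy.safe_psql("create role http2 with login password 'http' superuser")
    connstr = f"postgresql://http2:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
    response = requests.post(
        f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
        data=json.dumps({"query": "select * from no_such_table", "params": []}),
        headers={"Content-Type": "application/json", "Neon-Connection-String": connstr},
        verify=str(static_proxy.test_output_dir / "proxy.crt"),
    )
    # errors are serialised as {"message": ..., "code": ...} with status 400
    assert response.status_code == 400
    assert "no_such_table" in response.json()["message"]
```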