Compare commits

...

20 Commits

Author SHA1 Message Date
George MacKerron
6abe34bfa1 Dead-end websocket experiments 2023-05-26 11:35:36 +01:00
George MacKerron
b932c14f94 Pg protocol over SQL vaguely-semi-working 2023-05-16 07:42:51 +01:00
George MacKerron
2e05a9b652 WebSocket URL query data working; pg-protocol POST requests partly working 2023-05-15 17:47:33 +01:00
George MacKerron
b808160aff Improved error handling 2023-05-09 17:24:00 +01:00
George MacKerron
2997018005 Added support to return timestamptz, timestamp, date and time for http queries 2023-05-08 11:02:12 +01:00
George MacKerron
8f98cc29fa Stop double-serializing JSON response 2023-05-05 21:30:28 +01:00
Arthur Petukhovsky
b4ee8e3b73 Box query values
SELECT 1.0 is still not working
2023-05-05 17:27:26 +00:00
George MacKerron
68ea916bf1 More tweaks, now compiling 2023-05-05 14:49:15 +01:00
George MacKerron
eea3dd54da Tweaks 2023-05-05 14:32:01 +01:00
George MacKerron
4071b22519 Further work on sql over http 2023-05-05 13:25:50 +01:00
George MacKerron
c58edec63a Small progress on sql-over-http 2023-05-05 11:16:07 +01:00
George MacKerron
fdb9d4373d More work on mapping params 2023-05-05 10:53:32 +01:00
George MacKerron
cb88df7ffa Broken/incomplete attempts to implement new http API for serverless driver 2023-05-04 22:00:47 +01:00
Arthur Petukhovsky
595248532c Add global connection cache for ws 2023-04-18 15:29:00 +00:00
Stas Kelvich
158c051c9a add sleep endpoint to simulate proxy response 2023-04-18 15:39:10 +03:00
Stas Kelvich
143c4954df use simple query 2023-04-18 13:40:49 +03:00
Eduard Dyckman
373ae7672b Allow only post connection to sql url (#4035) 2023-04-15 20:06:57 +09:00
Stas Kelvich
1233923c3a works with browser 2023-04-15 02:18:52 +03:00
Stas Kelvich
1a3a7f14dd works now, few fixes 2023-04-15 01:15:22 +03:00
Stas Kelvich
4e9067a8c2 wip: now compiles 2023-04-14 19:41:02 +03:00
9 changed files with 581 additions and 30 deletions

41 Cargo.lock generated
View File

@@ -130,6 +130,12 @@ dependencies = [
"static_assertions",
]
[[package]]
name = "array-init"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc"
[[package]]
name = "asn1-rs"
version = "0.5.2"
@@ -297,7 +303,7 @@ dependencies = [
"http",
"http-body",
"lazy_static",
"percent-encoding",
"percent-encoding 2.2.0",
"pin-project-lite",
"tracing",
]
@@ -401,7 +407,7 @@ dependencies = [
"hex",
"http",
"once_cell",
"percent-encoding",
"percent-encoding 2.2.0",
"regex",
"ring",
"time",
@@ -522,7 +528,7 @@ dependencies = [
"http-body",
"hyper",
"once_cell",
"percent-encoding",
"percent-encoding 2.2.0",
"pin-project-lite",
"pin-utils",
"tokio",
@@ -544,7 +550,7 @@ dependencies = [
"http-body",
"hyper",
"once_cell",
"percent-encoding",
"percent-encoding 2.2.0",
"pin-project-lite",
"pin-utils",
"tracing",
@@ -684,7 +690,7 @@ dependencies = [
"matchit",
"memchr",
"mime",
"percent-encoding",
"percent-encoding 2.2.0",
"pin-project-lite",
"rustversion",
"serde",
@@ -1580,7 +1586,7 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
dependencies = [
"percent-encoding",
"percent-encoding 2.2.0",
]
[[package]]
@@ -2584,7 +2590,7 @@ dependencies = [
"futures-util",
"once_cell",
"opentelemetry_api",
"percent-encoding",
"percent-encoding 2.2.0",
"rand",
"thiserror",
"tokio",
@@ -2749,6 +2755,12 @@ dependencies = [
"base64 0.13.1",
]
[[package]]
name = "percent-encoding"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831"
[[package]]
name = "percent-encoding"
version = "2.2.0"
@@ -2879,7 +2891,9 @@ name = "postgres-types"
version = "0.2.4"
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=43e6db254a97fdecbce33d8bc0890accfd74495e#43e6db254a97fdecbce33d8bc0890accfd74495e"
dependencies = [
"array-init",
"bytes",
"chrono",
"fallible-iterator",
"postgres-protocol",
]
@@ -3112,6 +3126,7 @@ dependencies = [
"once_cell",
"opentelemetry",
"parking_lot",
"percent-encoding 1.0.1",
"pin-project-lite",
"postgres_backend",
"pq_proto",
@@ -3316,7 +3331,7 @@ dependencies = [
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"percent-encoding 2.2.0",
"pin-project-lite",
"rustls 0.20.8",
"rustls-pemfile",
@@ -3389,7 +3404,7 @@ dependencies = [
"http",
"hyper",
"lazy_static",
"percent-encoding",
"percent-encoding 2.2.0",
"regex",
]
@@ -4332,7 +4347,7 @@ dependencies = [
"futures-util",
"log",
"parking_lot",
"percent-encoding",
"percent-encoding 2.2.0",
"phf",
"pin-project-lite",
"postgres-protocol",
@@ -4480,7 +4495,7 @@ dependencies = [
"http-body",
"hyper",
"hyper-timeout",
"percent-encoding",
"percent-encoding 2.2.0",
"pin-project",
"prost",
"prost-derive",
@@ -4512,7 +4527,7 @@ dependencies = [
"http-body",
"hyper",
"hyper-timeout",
"percent-encoding",
"percent-encoding 2.2.0",
"pin-project",
"prost",
"rustls-native-certs",
@@ -4822,7 +4837,7 @@ checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
"percent-encoding 2.2.0",
"serde",
]

View File

@@ -55,7 +55,7 @@ hmac = "0.12.1"
hostname = "0.3.1"
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
hyper = { version = "0.14", features = ["http2", "tcp", "runtime", "http1"]}
hyper-tungstenite = "0.9"
itertools = "0.10"
jsonwebtoken = "8"
@@ -117,16 +117,17 @@ uuid = { version = "1.2", features = ["v4", "serde"] }
walkdir = "2.3.2"
webpki-roots = "0.23"
x509-parser = "0.15"
percent-encoding = "1.0"
## TODO replace this with tracing
env_logger = "0.10"
log = "0.4"
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["with-chrono-0_4", "array-impls"] }
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["array-impls"] }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["array-impls"] }
tokio-tar = { git = "https://github.com/neondatabase/tokio-tar.git", rev="404df61437de0feef49ba2ccdbdd94eb8ad6e142" }
## Other git libraries

View File

@@ -23,7 +23,7 @@ hmac.workspace = true
hostname.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
hyper = { workspace = true, features = ["http2", "http1", "tcp", "runtime"] }
itertools.workspace = true
md5.workspace = true
metrics.workspace = true
@@ -65,6 +65,7 @@ x509-parser.workspace = true
workspace_hack.workspace = true
tokio-util.workspace = true
percent-encoding.workspace = true
[dev-dependencies]
rcgen.workspace = true

View File

@@ -2,6 +2,7 @@
//! Other modules should use stuff from this module instead of
//! directly relying on deps like `reqwest` (think loose coupling).
pub mod pg_to_json;
pub mod server;
pub mod websocket;

View File

@@ -0,0 +1,154 @@
use anyhow::{anyhow, Context};
use chrono::{Utc, DateTime, NaiveDateTime, NaiveTime, NaiveDate};
use serde_json::Map;
use tokio_postgres::{
types::{FromSql, Type},
Column, Row,
};
// some type-aliases I use in my project
pub type JSONValue = serde_json::Value;
pub type RowData = Map<String, JSONValue>;
pub type Error = anyhow::Error; // from: https://github.com/dtolnay/anyhow
pub fn postgres_row_to_json_value(row: &Row) -> Result<JSONValue, Error> {
let row_data = postgres_row_to_row_data(row)?;
Ok(JSONValue::Object(row_data))
}
pub fn postgres_row_to_row_data(row: &Row) -> Result<RowData, Error> {
let mut result: Map<String, JSONValue> = Map::new();
for (i, column) in row.columns().iter().enumerate() {
let name = column.name();
let json_value = pg_cell_to_json_value(&row, column, i)?;
result.insert(name.to_string(), json_value);
}
Ok(result)
}
pub fn pg_cell_to_json_value(
row: &Row,
column: &Column,
column_i: usize,
) -> Result<JSONValue, Error> {
let f64_to_json_number = |raw_val: f64| -> Result<JSONValue, Error> {
let temp = serde_json::Number::from_f64(raw_val).ok_or(anyhow!("invalid json-float"))?;
Ok(JSONValue::Number(temp))
};
Ok(match *column.type_() {
// for rust-postgres <> postgres type-mappings: https://docs.rs/postgres/latest/postgres/types/trait.FromSql.html#types
// for postgres types: https://www.postgresql.org/docs/7.4/datatype.html#DATATYPE-TABLE
Type::TIMESTAMPTZ => get_basic(row, column, column_i, |a: DateTime<Utc>| {
Ok(JSONValue::String(a.to_rfc3339()))
})?,
Type::TIMESTAMP => get_basic(row, column, column_i, |a: NaiveDateTime| {
Ok(JSONValue::String(a.format("%Y-%m-%dT%H:%M:%S%.6f").to_string()))
})?,
Type::TIME => get_basic(row, column, column_i, |a: NaiveTime| {
Ok(JSONValue::String(a.format("%H:%M:%S%.6f").to_string()))
})?,
Type::DATE => get_basic(row, column, column_i, |a: NaiveDate| {
Ok(JSONValue::String(a.format("%Y-%m-%d").to_string()))
})?,
// no TIMETZ support?
// single types
Type::BOOL => get_basic(row, column, column_i, |a: bool| Ok(JSONValue::Bool(a)))?,
Type::INT2 => get_basic(row, column, column_i, |a: i16| {
Ok(JSONValue::Number(serde_json::Number::from(a)))
})?,
Type::INT4 => get_basic(row, column, column_i, |a: i32| {
Ok(JSONValue::Number(serde_json::Number::from(a)))
})?,
Type::INT8 => get_basic(row, column, column_i, |a: i64| {
Ok(JSONValue::Number(serde_json::Number::from(a)))
})?,
Type::TEXT | Type::VARCHAR => {
get_basic(row, column, column_i, |a: String| Ok(JSONValue::String(a)))?
}
// Type::JSON | Type::JSONB => get_basic(row, column, column_i, |a: JSONValue| Ok(a))?,
Type::FLOAT4 => get_basic(row, column, column_i, |a: f32| f64_to_json_number(a.into()))?,
Type::FLOAT8 => get_basic(row, column, column_i, f64_to_json_number)?,
// these types require a custom StringCollector struct as an intermediary (see struct at bottom)
Type::TS_VECTOR => get_basic(row, column, column_i, |a: StringCollector| {
Ok(JSONValue::String(a.0))
})?,
// array types
Type::BOOL_ARRAY => get_array(row, column, column_i, |a: bool| Ok(JSONValue::Bool(a)))?,
Type::INT2_ARRAY => get_array(row, column, column_i, |a: i16| {
Ok(JSONValue::Number(serde_json::Number::from(a)))
})?,
Type::INT4_ARRAY => get_array(row, column, column_i, |a: i32| {
Ok(JSONValue::Number(serde_json::Number::from(a)))
})?,
Type::INT8_ARRAY => get_array(row, column, column_i, |a: i64| {
Ok(JSONValue::Number(serde_json::Number::from(a)))
})?,
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY => {
get_array(row, column, column_i, |a: String| Ok(JSONValue::String(a)))?
}
// Type::JSON_ARRAY | Type::JSONB_ARRAY => get_array(row, column, column_i, |a: JSONValue| Ok(a))?,
Type::FLOAT4_ARRAY => get_array(row, column, column_i, |a: f32| f64_to_json_number(a.into()))?,
Type::FLOAT8_ARRAY => get_array(row, column, column_i, f64_to_json_number)?,
// these types require a custom StringCollector struct as an intermediary (see struct at bottom)
Type::TS_VECTOR_ARRAY => get_array(row, column, column_i, |a: StringCollector| {
Ok(JSONValue::String(a.0))
})?,
_ => anyhow::bail!(
"Cannot convert pg-cell \"{}\" of type \"{}\" to a JSONValue.",
column.name(),
column.type_().name()
),
})
}
fn get_basic<'a, T: FromSql<'a>>(
row: &'a Row,
column: &Column,
column_i: usize,
val_to_json_val: impl Fn(T) -> Result<JSONValue, Error>,
) -> Result<JSONValue, Error> {
let raw_val = row
.try_get::<_, Option<T>>(column_i)
.with_context(|| format!("column_name:{}", column.name()))?;
raw_val.map_or(Ok(JSONValue::Null), val_to_json_val)
}
fn get_array<'a, T: FromSql<'a>>(
row: &'a Row,
column: &Column,
column_i: usize,
val_to_json_val: impl Fn(T) -> Result<JSONValue, Error>,
) -> Result<JSONValue, Error> {
let raw_val_array = row
.try_get::<_, Option<Vec<T>>>(column_i)
.with_context(|| format!("column_name:{}", column.name()))?;
Ok(match raw_val_array {
Some(val_array) => {
let mut result = vec![];
for val in val_array {
result.push(val_to_json_val(val)?);
}
JSONValue::Array(result)
}
None => JSONValue::Null,
})
}
struct StringCollector(String);
impl FromSql<'_> for StringCollector {
fn from_sql(
_: &Type,
raw: &[u8],
) -> Result<StringCollector, Box<dyn std::error::Error + Sync + Send>> {
let result = std::str::from_utf8(raw)?;
Ok(StringCollector(result.to_owned()))
}
fn accepts(_ty: &Type) -> bool {
true
}
}
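
A minimal usage sketch for this new module (hypothetical caller; the query is a placeholder, and only the types handled above will convert): it mirrors what handle_sql does further down, mapping each tokio_postgres Row through postgres_row_to_json_value.

use crate::http::pg_to_json::postgres_row_to_json_value;

// Hypothetical helper: run a query and serialize the rows to JSON values.
async fn rows_as_json(client: &tokio_postgres::Client) -> anyhow::Result<Vec<serde_json::Value>> {
    let rows = client.query("SELECT now() AS ts, 42 AS answer", &[]).await?;
    rows.iter().map(postgres_row_to_json_value).collect()
}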

View File

@@ -1,15 +1,25 @@
use crate::http::pg_to_json::postgres_row_to_json_value;
use crate::{
cancellation::CancelMap, config::ProxyConfig, error::io_error, proxy::handle_ws_client,
auth, cancellation::CancelMap, config::ProxyConfig, console, error::io_error,
proxy::handle_ws_client,
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream, StreamExt};
use futures::{Sink, Stream, StreamExt, TryStreamExt};
use tokio_postgres::Row;
use std::collections::HashMap;
use hyper::{
server::{accept, conn::AddrIncoming},
upgrade::Upgraded,
Body, Request, Response, StatusCode,
Body, Method, Request, Response, StatusCode,
};
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use percent_encoding::percent_decode;
use pin_project_lite::pin_project;
use pq_proto::StartupMessageParams;
use serde_json::{Value, json};
use tokio::sync::Mutex;
use tokio_postgres::types::{ToSql};
use std::{
convert::Infallible,
future::ready,
@@ -24,6 +34,7 @@ use tokio::{
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use url::{Url};
use utils::http::{error::ApiError, json::json_response};
// TODO: use `std::sync::Exclusive` once it's stabilized.
@@ -41,10 +52,10 @@ pin_project! {
}
impl WebSocketRw {
pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
pub fn new(stream: WebSocketStream<Upgraded>, startup_data: Bytes) -> Self {
Self {
stream: stream.into(),
bytes: Bytes::new(),
bytes: startup_data,
}
}
}
@@ -81,6 +92,7 @@ impl AsyncRead for WebSocketRw {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if buf.remaining() > 0 {
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = std::cmp::min(bytes.len(), buf.remaining());
@@ -141,25 +153,101 @@ async fn serve_websocket(
cancel_map: &CancelMap,
session_id: uuid::Uuid,
hostname: Option<String>,
startup_data: Vec<u8>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
handle_ws_client(
config,
cancel_map,
session_id,
WebSocketRw::new(websocket),
WebSocketRw::new(websocket, startup_data.into()),
hostname,
)
.await?;
Ok(())
}
struct MyObject {
data: Vec<u8>,
recv_data: Vec<u8>,
}
impl MyObject {
fn new(data: Vec<u8>) -> Self {
MyObject {
data,
recv_data: Vec::with_capacity(512),
}
}
}
impl AsyncRead for MyObject {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), io::Error>> {
let data = &self.get_mut().data;
// let len: usize = recv_data.len();
// info!("len: {}", len);
// let mut i: usize = 0;
// let mut zcount = 0;
// while len >= i + 5 {
// let cmd = recv_data[i];
// if cmd == 0x5a { zcount += 1; }
// let size = u32::from_be_bytes(recv_data[(i + 1)..(i + 5)].try_into().unwrap());
// info!("cmd: {} size: {} buf: {:?}", cmd, size, recv_data);
// i += usize::try_from(size).unwrap() + 1;
// }
// if zcount < 2 {
// Poll::Pending
// } else {
// let mut reader = &recv_data[..];
// Pin::new(&mut reader).poll_read(cx, buf)
// }
let mut reader = &data[..];
Pin::new(&mut reader).poll_read(cx, buf)
}
}
impl AsyncWrite for MyObject {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
// eprintln!("{:?}", buf);
let recv_data = &mut self.get_mut().recv_data;
recv_data.extend(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
// ...
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
// ...
Poll::Ready(Ok(()))
}
}
async fn ws_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
cache: Arc<Mutex<ConnectionCache>>,
) -> Result<Response<Body>, ApiError> {
let host = request
.headers()
.get("host")
@@ -169,25 +257,311 @@ async fn ws_handler(
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
let startup_data = match request.uri().query() {
Some(b64_str) => match base64::decode_config(b64_str, base64::URL_SAFE) {
Ok(x) => x,
Err(_) => {
eprintln!("invalid WebSocket base64 startup data");
vec![]
}
},
None => vec![],
};
info!("{} bytes of startup data received via WebSocket URL query", startup_data.len());
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
.map_err(|e| ApiError::BadRequest(e.into()))?;
tokio::spawn(async move {
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await
{
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host, startup_data).await {
error!("error in websocket connection: {e:?}");
}
});
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/pg-protocol" && request.method() == Method::POST {
let mut body = request.into_body();
let mut data = Vec::with_capacity(512);
while let Some(chunk) = body.next().await {
data.extend(&chunk.map_err(|e| ApiError::InternalServerError(e.into()))?);
}
let mut my_object = MyObject::new(data);
let my_object = tokio::spawn(async move {
let result = handle_ws_client(
config,
&cancel_map,
session_id,
&mut my_object,
host,
).await;
my_object
})
.await
.map_err(|e| ApiError::InternalServerError(e.into()))?;
let response = Response::builder()
.header("Content-Type", "application/octet-stream")
.header("Access-Control-Allow-Origin", "*")
.status(StatusCode::OK)
.body(Body::from(my_object.recv_data))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let result = handle_sql(config, request).await;
let status_code = match result {
Ok(_) => StatusCode::OK,
Err(_) => StatusCode::BAD_REQUEST
};
let json = match result {
Ok(r) => serde_json::to_value(r).unwrap(),
Err(e) => {
let message = format!("{:?}", e);
let code = match e.downcast_ref::<tokio_postgres::Error>() {
Some(e) => match e.code() {
Some(e) => serde_json::to_value(e.code()).unwrap(),
None => Value::Null
},
None => Value::Null
};
json!({ "message": message, "code": code })
}
};
json_response(status_code, json).map(|mut r| {
r.headers_mut().insert(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
);
r
})
} else if request.uri().path() == "/sleep" {
// sleep 15ms
tokio::time::sleep(std::time::Duration::from_millis(15)).await;
json_response(StatusCode::OK, "done")
} else {
json_response(StatusCode::OK, "Connect with a websocket client")
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
fn boxed_from_json(json: Vec<Value>) -> Result<Vec<Box<dyn ToSql + Sync + Send>>, anyhow::Error> {
json.iter().map(|value| {
let boxed: Result<Box<dyn ToSql + Sync + Send>, anyhow::Error> = match value {
Value::Bool(b) => Ok(Box::new(b.clone())),
Value::Number(n) => Ok(Box::new(n.as_f64().unwrap())),
Value::String(s) => Ok(Box::new(s.clone())),
// TODO: support null (not like this: `Value::Null => Ok(Box::new(None::<bool>)),`)
// TODO: support arrays
x => Err(anyhow::anyhow!("unsupported param {:?}", x))
};
boxed
})
.collect::<Result<Vec<_>, anyhow::Error>>()
}
// XXX: return different error codes
async fn handle_sql(
config: &'static ProxyConfig,
request: Request<Body>
) -> anyhow::Result<Vec<Value>> {
let headers = request.headers();
let connection_string = headers
.get("X-Neon-ConnectionString")
.ok_or(anyhow::anyhow!("missing connection string"))?
.to_str()?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(anyhow::anyhow!("connection string must start with postgres: or postgresql:"))
}
let mut url_path = connection_url
.path_segments()
.ok_or(anyhow::anyhow!("missing database name"))?;
let dbname = url_path
.next()
.ok_or(anyhow::anyhow!("invalid database name"))?;
let username = connection_url.username();
let password = connection_url
.password()
.ok_or(anyhow::anyhow!("no password"))?;
let hostname = connection_url
.host_str()
.ok_or(anyhow::anyhow!("no host"))?;
let host_header = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next());
match host_header {
Some(h) if h == hostname => h,
Some(_) => return Err(anyhow::anyhow!("mismatched host header and hostname")),
None => return Err(anyhow::anyhow!("no host header"))
};
let mut body = request.into_body();
let mut data = Vec::with_capacity(512);
while let Some(chunk) = body.next().await {
data.extend(&chunk?);
}
#[derive(serde::Deserialize)]
struct QueryData {
query: String,
params: Vec<serde_json::Value>
}
let query_data: QueryData = serde_json::from_slice(&data)?;
let credential_params = StartupMessageParams::new([
("user", username),
("database", dbname),
("application_name", "proxy_http_sql"),
]);
let tls = config.tls_config.as_ref();
let common_names = tls.and_then(|tls| tls.common_names.clone());
let creds = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&credential_params, Some(hostname), common_names))
.transpose()?;
let extra = console::ConsoleReqExtra {
session_id: uuid::Uuid::new_v4(),
application_name: Some("proxy_http_sql"),
};
let node = creds.wake_compute(&extra).await?.expect("msg");
let conf = node.value.config;
let host = match conf.get_hosts().first().expect("no host") {
tokio_postgres::config::Host::Tcp(host) => host,
tokio_postgres::config::Host::Unix(_) => {
return Err(anyhow::anyhow!("unix socket is not supported"));
}
};
let conn_string = &format!(
"host={} port={} user={} password={} dbname={}",
host,
conf.get_ports().first().expect("no port"),
username,
password,
dbname
);
let (client, connection) = tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let query = &query_data.query;
let query_params = boxed_from_json(query_data.params)?;
// TODO: find a way to catch the panic and return an error if the number of params is wrong
let pg_rows: Vec<Row> = client
.query_raw(query, query_params)
.await?
.try_collect::<Vec<Row>>()
.await?;
let rows: Result<Vec<serde_json::Value>, anyhow::Error> = pg_rows
.iter()
.map(postgres_row_to_json_value)
.collect();
let rows = rows?;
Ok(rows)
}
pub struct ConnectionCache {
connections: HashMap<String, tokio_postgres::Client>,
}
impl ConnectionCache {
pub fn new() -> Arc<Mutex<Self>> {
Arc::new(Mutex::new(Self {
connections: HashMap::new(),
}))
}
pub async fn execute(
cache: &Arc<Mutex<ConnectionCache>>,
conn_string: &str,
hostname: &str,
sql: &str,
) -> anyhow::Result<String> {
// TODO: let go mutex when establishing connection
let mut cache = cache.lock().await;
let cache_key = format!("connstr={}, hostname={}", conn_string, hostname);
let client = if let Some(client) = cache.connections.get(&cache_key) {
info!("using cached connection {}", conn_string);
client
} else {
info!("!!!! connecting to: {}", conn_string);
let (client, connection) =
tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
tokio::spawn(async move {
// TODO: remove connection from cache
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
cache.connections.insert(cache_key.clone(), client);
cache.connections.get(&cache_key).unwrap()
};
let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string();
let rows: Vec<HashMap<_, _>> = client
.simple_query(&sql)
.await?
.into_iter()
.filter_map(|el| {
if let tokio_postgres::SimpleQueryMessage::Row(row) = el {
let mut serialized_row: HashMap<String, String> = HashMap::new();
for i in 0..row.len() {
let col = row.columns().get(i).map_or("?", |c| c.name());
let val = row.get(i).unwrap_or("?");
serialized_row.insert(col.into(), val.into());
}
Some(serialized_row)
} else {
None
}
})
.collect();
Ok(serde_json::to_string(&rows)?)
}
}
pub async fn task_main(
config: &'static ProxyConfig,
cache: &'static Arc<Mutex<ConnectionCache>>,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -221,7 +595,7 @@ pub async fn task_main(
move |req: Request<Body>| async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
ws_handler(req, config, cancel_map, session_id)
ws_handler(req, config, cancel_map, session_id, cache.clone())
.instrument(info_span!(
"ws-client",
session = format_args!("{session_id}")

View File

@@ -32,6 +32,8 @@ use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use utils::{project_git_version, sentry_init::init_sentry};
use crate::http::websocket::ConnectionCache;
project_git_version!(GIT_VERSION);
/// Flattens `Result<Result<T>>` into `Result<T>`.
@@ -53,6 +55,8 @@ async fn main() -> anyhow::Result<()> {
let args = cli().get_matches();
let config = build_config(&args)?;
let wsconn_cache = Box::leak(Box::new(ConnectionCache::new()));
info!("Authentication backend: {}", config.auth_backend);
// Check that we can bind to address before further initialization
@@ -82,6 +86,7 @@ async fn main() -> anyhow::Result<()> {
client_tasks.push(tokio::spawn(http::websocket::task_main(
config,
wsconn_cache,
wss_listener,
cancellation_token.clone(),
)));
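
For reference, a minimal sketch of a client call against the new /sql endpoint. This is not part of the diff: the URL, credentials and the reqwest dependency (with its json feature) are assumptions. The request shape follows handle_sql above: the connection string goes in the X-Neon-ConnectionString header, the JSON body carries "query" and "params", and the request's Host header must match the hostname in the connection string, which is why both use the same host here.

// Hypothetical client for the /sql endpoint; URL, credentials and query are placeholders.
async fn sql_over_http() -> anyhow::Result<()> {
    let resp = reqwest::Client::new()
        .post("https://proxy.example.invalid/sql")
        .header(
            "X-Neon-ConnectionString",
            "postgres://user:password@proxy.example.invalid/dbname",
        )
        .json(&serde_json::json!({
            "query": "SELECT $1::float8 AS n",
            // boxed_from_json binds JSON numbers as f64, so float params are the safe choice here
            "params": [1.5]
        }))
        .send()
        .await?;
    // On success the proxy returns a JSON array of row objects; on error, { "message", "code" }.
    let rows: serde_json::Value = resp.json().await?;
    println!("{rows}");
    Ok(())
}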