WebSocket URL query data working; pg-protocol POST requests partly working

This commit is contained in:
George MacKerron
2023-05-15 17:47:33 +01:00
parent b808160aff
commit 2e05a9b652

View File

@@ -6,7 +6,6 @@ use crate::{
use bytes::{Buf, Bytes};
use futures::{Sink, Stream, StreamExt, TryStreamExt};
use tokio_postgres::Row;
use tokio_postgres::error::DbError;
use std::collections::HashMap;
use hyper::{
server::{accept, conn::AddrIncoming},
@@ -53,10 +52,10 @@ pin_project! {
}
impl WebSocketRw {
pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
pub fn new(stream: WebSocketStream<Upgraded>, startup_data: Bytes) -> Self {
Self {
stream: stream.into(),
bytes: Bytes::new(),
bytes: startup_data,
}
}
}
@@ -93,6 +92,7 @@ impl AsyncRead for WebSocketRw {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if buf.remaining() > 0 {
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = std::cmp::min(bytes.len(), buf.remaining());
@@ -153,19 +153,73 @@ async fn serve_websocket(
cancel_map: &CancelMap,
session_id: uuid::Uuid,
hostname: Option<String>,
startup_data: Vec<u8>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
handle_ws_client(
config,
cancel_map,
session_id,
WebSocketRw::new(websocket),
WebSocketRw::new(websocket, startup_data.into()),
hostname,
)
.await?;
Ok(())
}
struct MyObject {
data: Vec<u8>,
recv_data: Vec<u8>,
}
impl MyObject {
fn new(data: Vec<u8>) -> Self {
MyObject {
data,
recv_data: Vec::with_capacity(512),
}
}
}
impl AsyncRead for MyObject {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), io::Error>> {
let data = &self.get_mut().data;
let mut reader = &data[..];
Pin::new(&mut reader).poll_read(cx, buf)
}
}
impl AsyncWrite for MyObject {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
eprintln!("{:?}", buf);
let recv_data = &mut self.get_mut().recv_data;
recv_data.extend(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}
async fn ws_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
@@ -183,12 +237,24 @@ async fn ws_handler(
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
let startup_data = match request.uri().query() {
Some(b64_str) => match base64::decode_config(b64_str, base64::URL_SAFE) {
Ok(x) => x,
Err(_) => {
eprintln!("invalid WebSocket base64 startup data");
vec![]
}
},
None => vec![],
};
info!("{} bytes of startup data received via WebSocket URL query", startup_data.len());
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
.map_err(|e| ApiError::BadRequest(e.into()))?;
tokio::spawn(async move {
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await
{
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host, startup_data).await {
error!("error in websocket connection: {e:?}");
}
});
@@ -196,6 +262,35 @@ async fn ws_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/pg-protocol" && request.method() == Method::POST {
let mut body = request.into_body();
let mut data = Vec::with_capacity(512);
while let Some(chunk) = body.next().await {
data.extend(&chunk.map_err(|e| ApiError::InternalServerError(e.into()))?);
}
let mut my_object = MyObject::new(data);
let handle = tokio::spawn(async move {
let result = handle_ws_client(
config,
&cancel_map,
session_id,
&mut my_object,
host,
).await;
my_object
});
let my_object = handle.await.map_err(|e| ApiError::InternalServerError(e.into()))?;
let response = Response::builder()
.header("Content-Type", "application/octet-stream")
.header("Access-Control-Allow-Origin", "*")
.status(StatusCode::OK)
.body(Body::from(my_object.recv_data))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let result = handle_sql(config, request).await;
let status_code = match result {