diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index 072a88dc46..e956657c2c 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -6,7 +6,6 @@ use crate::{ use bytes::{Buf, Bytes}; use futures::{Sink, Stream, StreamExt, TryStreamExt}; use tokio_postgres::Row; -use tokio_postgres::error::DbError; use std::collections::HashMap; use hyper::{ server::{accept, conn::AddrIncoming}, @@ -53,10 +52,10 @@ pin_project! { } impl WebSocketRw { - pub fn new(stream: WebSocketStream) -> Self { + pub fn new(stream: WebSocketStream, startup_data: Bytes) -> Self { Self { stream: stream.into(), - bytes: Bytes::new(), + bytes: startup_data, } } } @@ -93,6 +92,7 @@ impl AsyncRead for WebSocketRw { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { + if buf.remaining() > 0 { let bytes = ready!(self.as_mut().poll_fill_buf(cx))?; let len = std::cmp::min(bytes.len(), buf.remaining()); @@ -153,19 +153,73 @@ async fn serve_websocket( cancel_map: &CancelMap, session_id: uuid::Uuid, hostname: Option, + startup_data: Vec, ) -> anyhow::Result<()> { + let websocket = websocket.await?; + handle_ws_client( config, cancel_map, session_id, - WebSocketRw::new(websocket), + WebSocketRw::new(websocket, startup_data.into()), hostname, ) .await?; Ok(()) } +struct MyObject { + data: Vec, + recv_data: Vec, +} + +impl MyObject { + fn new(data: Vec) -> Self { + MyObject { + data, + recv_data: Vec::with_capacity(512), + } + } +} +impl AsyncRead for MyObject { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let data = &self.get_mut().data; + let mut reader = &data[..]; + Pin::new(&mut reader).poll_read(cx, buf) + } +} +impl AsyncWrite for MyObject { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + + eprintln!("{:?}", buf); + let recv_data = &mut self.get_mut().recv_data; + recv_data.extend(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + + + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + + + Poll::Ready(Ok(())) + } +} + + async fn ws_handler( mut request: Request, config: &'static ProxyConfig, @@ -183,12 +237,24 @@ async fn ws_handler( // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&request) { + let startup_data = match request.uri().query() { + Some(b64_str) => match base64::decode_config(b64_str, base64::URL_SAFE) { + Ok(x) => x, + Err(_) => { + eprintln!("invalid WebSocket base64 startup data"); + vec![] + } + }, + None => vec![], + }; + + info!("{} bytes of startup data received via WebSocket URL query", startup_data.len()); + let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) .map_err(|e| ApiError::BadRequest(e.into()))?; tokio::spawn(async move { - if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await - { + if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host, startup_data).await { error!("error in websocket connection: {e:?}"); } }); @@ -196,6 +262,35 @@ async fn ws_handler( // Return the response so the spawned future can continue. Ok(response) + } else if request.uri().path() == "/pg-protocol" && request.method() == Method::POST { + let mut body = request.into_body(); + let mut data = Vec::with_capacity(512); + while let Some(chunk) = body.next().await { + data.extend(&chunk.map_err(|e| ApiError::InternalServerError(e.into()))?); + } + + let mut my_object = MyObject::new(data); + let handle = tokio::spawn(async move { + let result = handle_ws_client( + config, + &cancel_map, + session_id, + &mut my_object, + host, + ).await; + my_object + }); + let my_object = handle.await.map_err(|e| ApiError::InternalServerError(e.into()))?; + + let response = Response::builder() + .header("Content-Type", "application/octet-stream") + .header("Access-Control-Allow-Origin", "*") + .status(StatusCode::OK) + .body(Body::from(my_object.recv_data)) + .map_err(|e| ApiError::InternalServerError(e.into()))?; + + Ok(response) + } else if request.uri().path() == "/sql" && request.method() == Method::POST { let result = handle_sql(config, request).await; let status_code = match result {