mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 01:12:56 +00:00
WebSocket URL query data working; pg-protocol POST requests partly working
This commit is contained in:
@@ -6,7 +6,6 @@ use crate::{
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream, StreamExt, TryStreamExt};
|
||||
use tokio_postgres::Row;
|
||||
use tokio_postgres::error::DbError;
|
||||
use std::collections::HashMap;
|
||||
use hyper::{
|
||||
server::{accept, conn::AddrIncoming},
|
||||
@@ -53,10 +52,10 @@ pin_project! {
|
||||
}
|
||||
|
||||
impl WebSocketRw {
|
||||
pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
|
||||
pub fn new(stream: WebSocketStream<Upgraded>, startup_data: Bytes) -> Self {
|
||||
Self {
|
||||
stream: stream.into(),
|
||||
bytes: Bytes::new(),
|
||||
bytes: startup_data,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,6 +92,7 @@ impl AsyncRead for WebSocketRw {
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
|
||||
if buf.remaining() > 0 {
|
||||
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
||||
let len = std::cmp::min(bytes.len(), buf.remaining());
|
||||
@@ -153,19 +153,73 @@ async fn serve_websocket(
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
hostname: Option<String>,
|
||||
startup_data: Vec<u8>,
|
||||
) -> anyhow::Result<()> {
|
||||
|
||||
let websocket = websocket.await?;
|
||||
|
||||
handle_ws_client(
|
||||
config,
|
||||
cancel_map,
|
||||
session_id,
|
||||
WebSocketRw::new(websocket),
|
||||
WebSocketRw::new(websocket, startup_data.into()),
|
||||
hostname,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct MyObject {
|
||||
data: Vec<u8>,
|
||||
recv_data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl MyObject {
|
||||
fn new(data: Vec<u8>) -> Self {
|
||||
MyObject {
|
||||
data,
|
||||
recv_data: Vec::with_capacity(512),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for MyObject {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<(), io::Error>> {
|
||||
let data = &self.get_mut().data;
|
||||
let mut reader = &data[..];
|
||||
Pin::new(&mut reader).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
impl AsyncWrite for MyObject {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
|
||||
eprintln!("{:?}", buf);
|
||||
let recv_data = &mut self.get_mut().recv_data;
|
||||
recv_data.extend(buf);
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async fn ws_handler(
|
||||
mut request: Request<Body>,
|
||||
config: &'static ProxyConfig,
|
||||
@@ -183,12 +237,24 @@ async fn ws_handler(
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||
let startup_data = match request.uri().query() {
|
||||
Some(b64_str) => match base64::decode_config(b64_str, base64::URL_SAFE) {
|
||||
Ok(x) => x,
|
||||
Err(_) => {
|
||||
eprintln!("invalid WebSocket base64 startup data");
|
||||
vec![]
|
||||
}
|
||||
},
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
info!("{} bytes of startup data received via WebSocket URL query", startup_data.len());
|
||||
|
||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await
|
||||
{
|
||||
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host, startup_data).await {
|
||||
error!("error in websocket connection: {e:?}");
|
||||
}
|
||||
});
|
||||
@@ -196,6 +262,35 @@ async fn ws_handler(
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
|
||||
} else if request.uri().path() == "/pg-protocol" && request.method() == Method::POST {
|
||||
let mut body = request.into_body();
|
||||
let mut data = Vec::with_capacity(512);
|
||||
while let Some(chunk) = body.next().await {
|
||||
data.extend(&chunk.map_err(|e| ApiError::InternalServerError(e.into()))?);
|
||||
}
|
||||
|
||||
let mut my_object = MyObject::new(data);
|
||||
let handle = tokio::spawn(async move {
|
||||
let result = handle_ws_client(
|
||||
config,
|
||||
&cancel_map,
|
||||
session_id,
|
||||
&mut my_object,
|
||||
host,
|
||||
).await;
|
||||
my_object
|
||||
});
|
||||
let my_object = handle.await.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
|
||||
let response = Response::builder()
|
||||
.header("Content-Type", "application/octet-stream")
|
||||
.header("Access-Control-Allow-Origin", "*")
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::from(my_object.recv_data))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
|
||||
Ok(response)
|
||||
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
let result = handle_sql(config, request).await;
|
||||
let status_code = match result {
|
||||
|
||||
Reference in New Issue
Block a user