From fcc6d5222181b81d4f8775de233136a2908d82fa Mon Sep 17 00:00:00 2001 From: Arseny Sher Date: Thu, 29 Sep 2022 18:12:11 +0300 Subject: [PATCH] wss --- Cargo.lock | 68 ++++++++++++++++++++++++- proxy/Cargo.toml | 3 ++ proxy/src/http/server.rs | 104 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 170 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ddb10352b8..622fbc5a52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1448,6 +1448,19 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "hyper-tungstenite" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36692e7f740cd10fbe3f84f7cb7bfec2a71f929e72f97c19824d3f7f45aeec9b" +dependencies = [ + "hyper", + "pin-project", + "tokio", + "tokio-tungstenite", + "tungstenite", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -2381,6 +2394,7 @@ dependencies = [ "hex", "hmac 0.12.1", "hyper", + "hyper-tungstenite", "itertools", "md5", "metrics", @@ -2389,6 +2403,7 @@ dependencies = [ "pin-project-lite", "rand", "rcgen", + "regex", "reqwest", "routerify", "rstest", @@ -2407,6 +2422,7 @@ dependencies = [ "url", "utils", "uuid", + "webpki-roots", "workspace_hack", "x509-parser", ] @@ -3060,6 +3076,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sha-1" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.3", +] + [[package]] name = "sha2" version = "0.9.9" @@ -3534,6 +3561,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.3" @@ -3745,6 +3784,25 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "tungstenite" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0" +dependencies = [ + "base64", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand", + "sha-1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.15.0" @@ -3808,6 +3866,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utils" version = "0.1.0" @@ -4032,9 +4096,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.22.4" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1c760f0d366a6c24a02ed7816e23e691f5d92291f94d15e836006fd11b04daf" +checksum = "368bfe657969fb01238bb756d351dcade285e0f6fcbd36dcb23359a5169975be" dependencies = [ "webpki", ] diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 7d0449cd1a..e9a853191b 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -16,12 +16,14 @@ hashbrown = "0.12" hex = "0.4.3" hmac = "0.12.1" hyper = "0.14" +hyper-tungstenite = "0.8.1" itertools = "0.10.3" md5 = "0.7.0" once_cell = "1.13.0" parking_lot = "0.12" pin-project-lite = "0.2.7" rand = "0.8.3" +regex = "1.4.5" reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } routerify = "3" rustls = "0.20.0" @@ -37,6 +39,7 @@ tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", re tokio-rustls = "0.23.0" url = "2.2.2" uuid = { version = "0.8.2", features = ["v4", "serde"]} +webpki-roots = "0.22.5" x509-parser = "0.13.2" utils = { path = "../libs/utils" } diff --git a/proxy/src/http/server.rs b/proxy/src/http/server.rs index 5a75718742..06b7c73789 100644 --- a/proxy/src/http/server.rs +++ b/proxy/src/http/server.rs @@ -1,15 +1,113 @@ +use regex::Regex; +use rustls::{OwnedTrustAnchor, RootCertStore}; +use std::{net::TcpListener, sync::Arc}; + use anyhow::anyhow; +use futures::{sink::SinkExt, stream::StreamExt}; use hyper::{Body, Request, Response, StatusCode}; -use std::net::TcpListener; -use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService}; +use hyper_tungstenite::tungstenite::Message; +use hyper_tungstenite::HyperWebsocket; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpStream; +use utils::http::{ + endpoint, error::ApiError, json::json_response, request::parse_request_param, RouterBuilder, + RouterService, +}; async fn status_handler(_: Request) -> Result, ApiError> { json_response(StatusCode::OK, "") } +async fn ws_handler(mut request: Request) -> Result, ApiError> { + // Check if the request is a websocket upgrade request. + if hyper_tungstenite::is_upgrade_request(&request) { + let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) + .map_err(|e| ApiError::BadRequest(e.into()))?; + + let connstr: String = parse_request_param(&request, "connstr")?; + let r = Regex::new(r"@(.*:(\d+)").unwrap(); + let caps = r.captures(&connstr).unwrap(); + let endpoint: String = caps.get(1).unwrap().as_str().to_owned(); + + // connect to itself, but through remote endpoint, lol + let neon_sock = TcpStream::connect(endpoint) + .await + .map_err(|e| ApiError::InternalServerError(e.into()))?; + + // Spawn a task to handle the websocket connection. + tokio::spawn(async move { + if let Err(e) = serve_websocket(websocket, neon_sock).await { + eprintln!("Error in websocket connection: {}", e); + } + }); + + // Return the response so the spawned future can continue. + Ok(response) + } else { + json_response(StatusCode::OK, "hi") + } +} + +async fn serve_websocket( + websocket: HyperWebsocket, + mut neon_sock: TcpStream, +) -> anyhow::Result<()> { + let mut websocket = websocket.await?; + let mut buf = [0u8; 8192]; + + tokio::select! { + Some(message) = websocket.next() => { + match message? { + Message::Text(msg) => { + println!("Received text message: {}", msg); + } + Message::Binary(msg) => { + println!("Received binary message: {:02X?}", msg); + neon_sock.write_all(&msg).await?; + websocket + .send(Message::binary(b"Thank you, come again.".to_vec())) + .await?; + } + Message::Ping(msg) => { + // No need to send a reply: tungstenite takes care of this for you. + println!("Received ping message: {:02X?}", msg); + } + Message::Pong(msg) => { + println!("Received pong message: {:02X?}", msg); + } + Message::Close(msg) => { + // No need to send a reply: tungstenite takes care of this for you. + if let Some(msg) = &msg { + println!( + "Received close message with code {} and message: {}", + msg.code, msg.reason + ); + } else { + println!("Received close message"); + } + } + Message::Frame(_msg) => { + unreachable!(); + } + } + }, + res = neon_sock.read(&mut buf) => { + let res = res?; + websocket + .send(Message::binary(&buf[0..res])) + .await?; + } + } + + Ok(()) +} + fn make_router() -> RouterBuilder { let router = endpoint::make_router(); - router.get("/v1/status", status_handler) + router + .get("/v1/status", status_handler) + .get("/v1/ws", ws_handler) } pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> {