mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-27 18:10:37 +00:00
Implement wss support in proxy (#3247)
This is a hacky implementation of WebSocket server, embedded into our postgres proxy. The server is used to allow https://github.com/neondatabase/serverless to connect to our postgres from browser and serverless javascript functions. How it will work (general schema): - browser opens a websocket connection to `wss://ep-abc-xyz-123.xx-central-1.aws.neon.tech/` - proxy accepts this connection and terminates TLS (https) - inside encrypted tunnel (HTTPS), browser initiates plain (non-encrypted) postgres connection - proxy performs auth as in usual plain pg connection and forwards connection to the compute Related issue: #3225
This commit is contained in:
committed by
GitHub
parent
df42213dbb
commit
debd134b15
79
Cargo.lock
generated
79
Cargo.lock
generated
@@ -1700,6 +1700,19 @@ dependencies = [
|
||||
"tokio-io-timeout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tungstenite"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d62004bcd4f6f85d9e2aa4206f1466ee67031f5ededcb6c6e62d48f9306ad879"
|
||||
dependencies = [
|
||||
"hyper",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.53"
|
||||
@@ -2658,6 +2671,7 @@ dependencies = [
|
||||
"hex",
|
||||
"hmac",
|
||||
"hyper",
|
||||
"hyper-tungstenite",
|
||||
"itertools",
|
||||
"md5",
|
||||
"metrics",
|
||||
@@ -2667,6 +2681,7 @@ dependencies = [
|
||||
"pq_proto",
|
||||
"rand",
|
||||
"rcgen",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"routerify",
|
||||
"rstest",
|
||||
@@ -2678,6 +2693,7 @@ dependencies = [
|
||||
"sha2",
|
||||
"socket2",
|
||||
"thiserror",
|
||||
"tls-listener",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
@@ -2687,6 +2703,7 @@ dependencies = [
|
||||
"url",
|
||||
"utils",
|
||||
"uuid",
|
||||
"webpki-roots",
|
||||
"workspace_hack",
|
||||
"x509-parser",
|
||||
]
|
||||
@@ -3324,6 +3341,17 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.5"
|
||||
@@ -3687,6 +3715,20 @@ version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
|
||||
|
||||
[[package]]
|
||||
name = "tls-listener"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9d4ff21187d434ac7709bfc7441ca88f63681247e5ad99f0f08c8c91ddc103d"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"hyper",
|
||||
"pin-project-lite",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.21.1"
|
||||
@@ -3801,6 +3843,18 @@ dependencies = [
|
||||
"xattr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.17.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.4"
|
||||
@@ -4027,6 +4081,25 @@ version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.17.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"sha-1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.16.0"
|
||||
@@ -4115,6 +4188,12 @@ version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utils"
|
||||
version = "0.1.0"
|
||||
|
||||
@@ -17,12 +17,14 @@ hashbrown = "0.12"
|
||||
hex = "0.4.3"
|
||||
hmac = "0.12.1"
|
||||
hyper = "0.14"
|
||||
hyper-tungstenite = "0.8.1"
|
||||
itertools = "0.10.3"
|
||||
md5 = "0.7.0"
|
||||
once_cell = "1.13.0"
|
||||
parking_lot = "0.12"
|
||||
pin-project-lite = "0.2.7"
|
||||
rand = "0.8.3"
|
||||
regex = "1.4.5"
|
||||
reqwest = { version = "0.11", default-features = false, features = [ "json", "rustls-tls" ] }
|
||||
routerify = "3"
|
||||
rustls = "0.20.0"
|
||||
@@ -36,10 +38,12 @@ thiserror = "1.0.30"
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
tokio-rustls = "0.23.0"
|
||||
tls-listener = { version = "0.5.1", features = ["rustls", "hyper-h1"] }
|
||||
tracing = "0.1.36"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
url = "2.2.2"
|
||||
uuid = { version = "1.2", features = ["v4", "serde"] }
|
||||
webpki-roots = "0.22.5"
|
||||
x509-parser = "0.14"
|
||||
|
||||
metrics = { path = "../libs/metrics" }
|
||||
|
||||
@@ -149,7 +149,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the project name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let fetch_magic_payload = async {
|
||||
let fetch_magic_payload = |client| async {
|
||||
warn!("project name not specified, resorting to the password hack auth flow");
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
@@ -161,10 +161,26 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
auth::Result::Ok(payload)
|
||||
};
|
||||
|
||||
// If we want to use cleartext password flow, we can read the password
|
||||
// from the client and pretend that it's a magic payload (PasswordHack hack).
|
||||
let fetch_plaintext_password = |client| async {
|
||||
info!("using cleartext password flow");
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword)
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
auth::Result::Ok(auth::password_hack::PasswordHackPayload {
|
||||
project: String::new(),
|
||||
password: payload,
|
||||
})
|
||||
};
|
||||
|
||||
// TODO: find a proper way to merge those very similar blocks.
|
||||
let (mut node, payload) = match self {
|
||||
Console(endpoint, creds) if creds.project.is_none() => {
|
||||
let payload = fetch_magic_payload.await?;
|
||||
let payload = fetch_magic_payload(client).await?;
|
||||
|
||||
let mut creds = creds.as_ref();
|
||||
creds.project = Some(payload.project.as_str().into());
|
||||
@@ -174,8 +190,18 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
|
||||
(node, payload)
|
||||
}
|
||||
Console(endpoint, creds) if creds.use_cleartext_password_flow => {
|
||||
// This is a hack to allow cleartext password in secure connections (wss).
|
||||
let payload = fetch_plaintext_password(client).await?;
|
||||
let creds = creds.as_ref();
|
||||
let node = console::Api::new(endpoint, extra, &creds)
|
||||
.wake_compute()
|
||||
.await?;
|
||||
|
||||
(node, payload)
|
||||
}
|
||||
Postgres(endpoint, creds) if creds.project.is_none() => {
|
||||
let payload = fetch_magic_payload.await?;
|
||||
let payload = fetch_magic_payload(client).await?;
|
||||
|
||||
let mut creds = creds.as_ref();
|
||||
creds.project = Some(payload.project.as_str().into());
|
||||
|
||||
@@ -34,6 +34,9 @@ pub struct ClientCredentials<'a> {
|
||||
pub user: &'a str,
|
||||
pub dbname: &'a str,
|
||||
pub project: Option<Cow<'a, str>>,
|
||||
/// If `True`, we'll use the old cleartext password flow. This is used for
|
||||
/// websocket connections, which want to minimize the number of round trips.
|
||||
pub use_cleartext_password_flow: bool,
|
||||
}
|
||||
|
||||
impl ClientCredentials<'_> {
|
||||
@@ -50,6 +53,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
user: self.user,
|
||||
dbname: self.dbname,
|
||||
project: self.project().map(Cow::Borrowed),
|
||||
use_cleartext_password_flow: self.use_cleartext_password_flow,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -59,6 +63,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
params: &'a StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
use_cleartext_password_flow: bool,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
@@ -108,6 +113,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
user = user,
|
||||
dbname = dbname,
|
||||
project = project.as_deref(),
|
||||
use_cleartext_password_flow = use_cleartext_password_flow,
|
||||
"credentials"
|
||||
);
|
||||
|
||||
@@ -115,6 +121,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
user,
|
||||
dbname,
|
||||
project,
|
||||
use_cleartext_password_flow,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -141,7 +148,7 @@ mod tests {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
// TODO: check that `creds.dbname` is None.
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
|
||||
Ok(())
|
||||
@@ -151,7 +158,7 @@ mod tests {
|
||||
fn parse_missing_project() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project, None);
|
||||
@@ -166,7 +173,7 @@ mod tests {
|
||||
let sni = Some("foo.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
@@ -182,7 +189,7 @@ mod tests {
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
@@ -201,7 +208,7 @@ mod tests {
|
||||
let sni = Some("baz.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
@@ -220,7 +227,8 @@ mod tests {
|
||||
let sni = Some("second.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||
let err =
|
||||
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
|
||||
match err {
|
||||
InconsistentProjectNames { domain, option } => {
|
||||
assert_eq!(option, "first");
|
||||
@@ -237,7 +245,8 @@ mod tests {
|
||||
let sni = Some("project.localhost");
|
||||
let common_name = Some("example.com");
|
||||
|
||||
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||
let err =
|
||||
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
|
||||
match err {
|
||||
InconsistentSni { sni, cn } => {
|
||||
assert_eq!(sni, "project.localhost");
|
||||
|
||||
@@ -37,6 +37,17 @@ impl AuthMethod for PasswordHack {
|
||||
}
|
||||
}
|
||||
|
||||
/// Use clear-text password auth called `password` in docs
|
||||
/// <https://www.postgresql.org/docs/current/auth-password.html>
|
||||
pub struct CleartextPassword;
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
@@ -86,6 +97,18 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<Vec<u8>> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
Ok(password.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod server;
|
||||
pub mod websocket;
|
||||
|
||||
use crate::url::ApiUrl;
|
||||
|
||||
|
||||
263
proxy/src/http/websocket.rs
Normal file
263
proxy/src/http/websocket.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream, StreamExt};
|
||||
use hyper::server::accept::{self};
|
||||
use hyper::server::conn::AddrIncoming;
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use hyper_tungstenite::{tungstenite, WebSocketStream};
|
||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::future::ready;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tls_listener::TlsListener;
|
||||
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
|
||||
use crate::cancellation::CancelMap;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::proxy::handle_ws_client;
|
||||
|
||||
pin_project! {
|
||||
/// This is a wrapper around a WebSocketStream that implements AsyncRead and AsyncWrite.
|
||||
pub struct WebSocketRW {
|
||||
#[pin]
|
||||
stream: WebSocketStream<Upgraded>,
|
||||
chunk: Option<bytes::Bytes>,
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: explain why this is safe or try to remove `unsafe impl`.
|
||||
unsafe impl Sync for WebSocketRW {}
|
||||
|
||||
impl WebSocketRW {
|
||||
pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
chunk: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn has_chunk(&self) -> bool {
|
||||
if let Some(ref chunk) = self.chunk {
|
||||
chunk.remaining() > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ws_err_into(e: tungstenite::Error) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, e.to_string())
|
||||
}
|
||||
|
||||
impl AsyncWrite for WebSocketRW {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
let mut this = self.project();
|
||||
match this.stream.as_mut().poll_ready(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
if let Err(e) = this
|
||||
.stream
|
||||
.as_mut()
|
||||
.start_send(Message::Binary(buf.to_vec()))
|
||||
{
|
||||
Poll::Ready(Err(ws_err_into(e)))
|
||||
} else {
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(ws_err_into(e))),
|
||||
Poll::Pending => {
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().stream.poll_flush(cx).map_err(ws_err_into)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().stream.poll_close(cx).map_err(ws_err_into)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for WebSocketRW {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if buf.remaining() == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let inner_buf = match self.as_mut().poll_fill_buf(cx) {
|
||||
Poll::Ready(Ok(buf)) => buf,
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
};
|
||||
let len = std::cmp::min(inner_buf.len(), buf.remaining());
|
||||
buf.put_slice(&inner_buf[..len]);
|
||||
|
||||
self.consume(len);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncBufRead for WebSocketRW {
|
||||
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
|
||||
loop {
|
||||
if self.as_mut().has_chunk() {
|
||||
let buf = self.project().chunk.as_ref().unwrap().chunk();
|
||||
return Poll::Ready(Ok(buf));
|
||||
} else {
|
||||
match self.as_mut().project().stream.poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(message))) => match message {
|
||||
Message::Text(_) => {}
|
||||
Message::Binary(chunk) => {
|
||||
*self.as_mut().project().chunk = Some(Bytes::from(chunk));
|
||||
}
|
||||
Message::Ping(_) => {
|
||||
// No need to send a reply: tungstenite takes care of this for you.
|
||||
}
|
||||
Message::Pong(_) => {}
|
||||
Message::Close(_) => {
|
||||
// No need to send a reply: tungstenite takes care of this for you.
|
||||
return Poll::Ready(Ok(&[]));
|
||||
}
|
||||
Message::Frame(_) => {
|
||||
unreachable!();
|
||||
}
|
||||
},
|
||||
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(ws_err_into(err))),
|
||||
Poll::Ready(None) => return Poll::Ready(Ok(&[])),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(self: Pin<&mut Self>, amt: usize) {
|
||||
if amt > 0 {
|
||||
self.project()
|
||||
.chunk
|
||||
.as_mut()
|
||||
.expect("No chunk present")
|
||||
.advance(amt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_websocket(
|
||||
websocket: HyperWebsocket,
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
handle_ws_client(
|
||||
config,
|
||||
cancel_map,
|
||||
session_id,
|
||||
WebSocketRW::new(websocket),
|
||||
hostname,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
mut request: Request<Body>,
|
||||
config: &'static ProxyConfig,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
session_id: uuid::Uuid,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await
|
||||
{
|
||||
error!("error in websocket connection: {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else {
|
||||
json_response(StatusCode::OK, "Connect with a websocket client")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
ws_listener: TcpListener,
|
||||
config: &'static ProxyConfig,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("websocket server has shut down");
|
||||
}
|
||||
|
||||
let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
|
||||
Some(config) => config.into(),
|
||||
None => {
|
||||
warn!("TLS config is missing, WebSocket Secure server will not be started");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
error!("failed to accept TLS connection for websockets: {:?}", err);
|
||||
ready(false)
|
||||
} else {
|
||||
ready(true)
|
||||
}
|
||||
});
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(|_stream| async move {
|
||||
Ok::<_, Infallible>(hyper::service::service_fn(
|
||||
move |req: Request<Body>| async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
ws_handler(req, config, cancel_map, session_id)
|
||||
.instrument(info_span!(
|
||||
"ws-client",
|
||||
session = format_args!("{session_id}")
|
||||
))
|
||||
.await
|
||||
},
|
||||
))
|
||||
});
|
||||
|
||||
hyper::Server::builder(accept::from_stream(tls_listener))
|
||||
.serve(make_svc)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -110,12 +110,23 @@ async fn main() -> anyhow::Result<()> {
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
|
||||
let tasks = [
|
||||
let mut tasks = vec![
|
||||
tokio::spawn(http::server::task_main(http_listener)),
|
||||
tokio::spawn(proxy::task_main(config, proxy_listener)),
|
||||
tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)),
|
||||
]
|
||||
.map(flatten_err);
|
||||
];
|
||||
|
||||
if let Some(wss_address) = arg_matches.get_one::<String>("wss") {
|
||||
let wss_address: SocketAddr = wss_address.parse()?;
|
||||
info!("Starting wss on {}", wss_address);
|
||||
let wss_listener = TcpListener::bind(wss_address).await?;
|
||||
tasks.push(tokio::spawn(http::websocket::task_main(
|
||||
wss_listener,
|
||||
config,
|
||||
)));
|
||||
}
|
||||
|
||||
let tasks = tasks.into_iter().map(flatten_err);
|
||||
|
||||
set_build_info_metric(GIT_VERSION);
|
||||
// This will block until all tasks have completed.
|
||||
@@ -155,6 +166,11 @@ fn cli() -> clap::Command {
|
||||
.help("listen for incoming http connections (metrics, etc) on ip:port")
|
||||
.default_value("127.0.0.1:7001"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("wss")
|
||||
.long("wss")
|
||||
.help("listen for incoming wss connections on ip:port"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("uri")
|
||||
.short('u')
|
||||
|
||||
@@ -82,6 +82,47 @@ pub async fn task_main(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_ws_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
scopeguard::defer! {
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
}
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
let hostname = hostname.as_deref();
|
||||
|
||||
// TLS is None here, because the connection is already encrypted.
|
||||
let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake"));
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name, true))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(stream, creds, ¶ms, session_id);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(session))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
@@ -108,7 +149,7 @@ async fn handle_client(
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name))
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name, false))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
|
||||
Reference in New Issue
Block a user