Implement wss support in proxy (#3247)

This is a hacky implementation of WebSocket server, embedded into our
postgres proxy. The server is used to allow https://github.com/neondatabase/serverless 
to connect to our postgres from browser and serverless javascript functions.

How it will work (general schema):
- browser opens a websocket connection to
`wss://ep-abc-xyz-123.xx-central-1.aws.neon.tech/`
- proxy accepts this connection and terminates TLS (https)
- inside encrypted tunnel (HTTPS), browser initiates plain
(non-encrypted) postgres connection
- proxy performs auth as in usual plain pg connection and forwards
connection to the compute

Related issue: #3225
This commit is contained in:
Arthur Petukhovsky
2023-01-06 19:34:18 +04:00
committed by GitHub
parent df42213dbb
commit debd134b15
9 changed files with 476 additions and 14 deletions

79
Cargo.lock generated
View File

@@ -1700,6 +1700,19 @@ dependencies = [
"tokio-io-timeout",
]
[[package]]
name = "hyper-tungstenite"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d62004bcd4f6f85d9e2aa4206f1466ee67031f5ededcb6c6e62d48f9306ad879"
dependencies = [
"hyper",
"pin-project",
"tokio",
"tokio-tungstenite",
"tungstenite",
]
[[package]]
name = "iana-time-zone"
version = "0.1.53"
@@ -2658,6 +2671,7 @@ dependencies = [
"hex",
"hmac",
"hyper",
"hyper-tungstenite",
"itertools",
"md5",
"metrics",
@@ -2667,6 +2681,7 @@ dependencies = [
"pq_proto",
"rand",
"rcgen",
"regex",
"reqwest",
"routerify",
"rstest",
@@ -2678,6 +2693,7 @@ dependencies = [
"sha2",
"socket2",
"thiserror",
"tls-listener",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
@@ -2687,6 +2703,7 @@ dependencies = [
"url",
"utils",
"uuid",
"webpki-roots",
"workspace_hack",
"x509-parser",
]
@@ -3324,6 +3341,17 @@ dependencies = [
"syn",
]
[[package]]
name = "sha-1"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sha1"
version = "0.10.5"
@@ -3687,6 +3715,20 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tls-listener"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9d4ff21187d434ac7709bfc7441ca88f63681247e5ad99f0f08c8c91ddc103d"
dependencies = [
"futures-util",
"hyper",
"pin-project-lite",
"thiserror",
"tokio",
"tokio-rustls",
]
[[package]]
name = "tokio"
version = "1.21.1"
@@ -3801,6 +3843,18 @@ dependencies = [
"xattr",
]
[[package]]
name = "tokio-tungstenite"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.4"
@@ -4027,6 +4081,25 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "tungstenite"
version = "0.17.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0"
dependencies = [
"base64 0.13.1",
"byteorder",
"bytes",
"http",
"httparse",
"log",
"rand",
"sha-1",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.16.0"
@@ -4115,6 +4188,12 @@ version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utils"
version = "0.1.0"

View File

@@ -17,12 +17,14 @@ hashbrown = "0.12"
hex = "0.4.3"
hmac = "0.12.1"
hyper = "0.14"
hyper-tungstenite = "0.8.1"
itertools = "0.10.3"
md5 = "0.7.0"
once_cell = "1.13.0"
parking_lot = "0.12"
pin-project-lite = "0.2.7"
rand = "0.8.3"
regex = "1.4.5"
reqwest = { version = "0.11", default-features = false, features = [ "json", "rustls-tls" ] }
routerify = "3"
rustls = "0.20.0"
@@ -36,10 +38,12 @@ thiserror = "1.0.30"
tokio = { version = "1.17", features = ["macros"] }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
tokio-rustls = "0.23.0"
tls-listener = { version = "0.5.1", features = ["rustls", "hyper-h1"] }
tracing = "0.1.36"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
url = "2.2.2"
uuid = { version = "1.2", features = ["v4", "serde"] }
webpki-roots = "0.22.5"
x509-parser = "0.14"
metrics = { path = "../libs/metrics" }

View File

@@ -149,7 +149,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the project name.
// We now expect to see a very specific payload in the place of password.
let fetch_magic_payload = async {
let fetch_magic_payload = |client| async {
warn!("project name not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
@@ -161,10 +161,26 @@ impl BackendType<'_, ClientCredentials<'_>> {
auth::Result::Ok(payload)
};
// If we want to use cleartext password flow, we can read the password
// from the client and pretend that it's a magic payload (PasswordHack hack).
let fetch_plaintext_password = |client| async {
info!("using cleartext password flow");
let payload = AuthFlow::new(client)
.begin(auth::CleartextPassword)
.await?
.authenticate()
.await?;
auth::Result::Ok(auth::password_hack::PasswordHackPayload {
project: String::new(),
password: payload,
})
};
// TODO: find a proper way to merge those very similar blocks.
let (mut node, payload) = match self {
Console(endpoint, creds) if creds.project.is_none() => {
let payload = fetch_magic_payload.await?;
let payload = fetch_magic_payload(client).await?;
let mut creds = creds.as_ref();
creds.project = Some(payload.project.as_str().into());
@@ -174,8 +190,18 @@ impl BackendType<'_, ClientCredentials<'_>> {
(node, payload)
}
Console(endpoint, creds) if creds.use_cleartext_password_flow => {
// This is a hack to allow cleartext password in secure connections (wss).
let payload = fetch_plaintext_password(client).await?;
let creds = creds.as_ref();
let node = console::Api::new(endpoint, extra, &creds)
.wake_compute()
.await?;
(node, payload)
}
Postgres(endpoint, creds) if creds.project.is_none() => {
let payload = fetch_magic_payload.await?;
let payload = fetch_magic_payload(client).await?;
let mut creds = creds.as_ref();
creds.project = Some(payload.project.as_str().into());

View File

@@ -34,6 +34,9 @@ pub struct ClientCredentials<'a> {
pub user: &'a str,
pub dbname: &'a str,
pub project: Option<Cow<'a, str>>,
/// If `True`, we'll use the old cleartext password flow. This is used for
/// websocket connections, which want to minimize the number of round trips.
pub use_cleartext_password_flow: bool,
}
impl ClientCredentials<'_> {
@@ -50,6 +53,7 @@ impl<'a> ClientCredentials<'a> {
user: self.user,
dbname: self.dbname,
project: self.project().map(Cow::Borrowed),
use_cleartext_password_flow: self.use_cleartext_password_flow,
}
}
}
@@ -59,6 +63,7 @@ impl<'a> ClientCredentials<'a> {
params: &'a StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
use_cleartext_password_flow: bool,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
@@ -108,6 +113,7 @@ impl<'a> ClientCredentials<'a> {
user = user,
dbname = dbname,
project = project.as_deref(),
use_cleartext_password_flow = use_cleartext_password_flow,
"credentials"
);
@@ -115,6 +121,7 @@ impl<'a> ClientCredentials<'a> {
user,
dbname,
project,
use_cleartext_password_flow,
})
}
}
@@ -141,7 +148,7 @@ mod tests {
let options = StartupMessageParams::new([("user", "john_doe")]);
// TODO: check that `creds.dbname` is None.
let creds = ClientCredentials::parse(&options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None, false)?;
assert_eq!(creds.user, "john_doe");
Ok(())
@@ -151,7 +158,7 @@ mod tests {
fn parse_missing_project() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
let creds = ClientCredentials::parse(&options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None, false)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project, None);
@@ -166,7 +173,7 @@ mod tests {
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(&options, sni, common_name)?;
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("foo"));
@@ -182,7 +189,7 @@ mod tests {
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let creds = ClientCredentials::parse(&options, None, None, false)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -201,7 +208,7 @@ mod tests {
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(&options, sni, common_name)?;
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -220,7 +227,8 @@ mod tests {
let sni = Some("second.localhost");
let common_name = Some("localhost");
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
let err =
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -237,7 +245,8 @@ mod tests {
let sni = Some("project.localhost");
let common_name = Some("example.com");
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
let err =
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
match err {
InconsistentSni { sni, cn } => {
assert_eq!(sni, "project.localhost");

View File

@@ -37,6 +37,17 @@ impl AuthMethod for PasswordHack {
}
}
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub struct CleartextPassword;
impl AuthMethod for CleartextPassword {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, Stream, State> {
@@ -86,6 +97,18 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<Vec<u8>> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
Ok(password.to_vec())
}
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.

View File

@@ -1,4 +1,5 @@
pub mod server;
pub mod websocket;
use crate::url::ApiUrl;

263
proxy/src/http/websocket.rs Normal file
View File

@@ -0,0 +1,263 @@
use bytes::{Buf, Bytes};
use futures::{Sink, Stream, StreamExt};
use hyper::server::accept::{self};
use hyper::server::conn::AddrIncoming;
use hyper::upgrade::Upgraded;
use hyper::{Body, Request, Response, StatusCode};
use hyper_tungstenite::{tungstenite, WebSocketStream};
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket};
use pin_project_lite::pin_project;
use tokio::net::TcpListener;
use std::convert::Infallible;
use std::future::ready;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tls_listener::TlsListener;
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::{error, info, info_span, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
use crate::cancellation::CancelMap;
use crate::config::ProxyConfig;
use crate::proxy::handle_ws_client;
pin_project! {
/// This is a wrapper around a WebSocketStream that implements AsyncRead and AsyncWrite.
pub struct WebSocketRW {
#[pin]
stream: WebSocketStream<Upgraded>,
chunk: Option<bytes::Bytes>,
}
}
// FIXME: explain why this is safe or try to remove `unsafe impl`.
unsafe impl Sync for WebSocketRW {}
impl WebSocketRW {
pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
Self {
stream,
chunk: None,
}
}
fn has_chunk(&self) -> bool {
if let Some(ref chunk) = self.chunk {
chunk.remaining() > 0
} else {
false
}
}
}
fn ws_err_into(e: tungstenite::Error) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string())
}
impl AsyncWrite for WebSocketRW {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let mut this = self.project();
match this.stream.as_mut().poll_ready(cx) {
Poll::Ready(Ok(())) => {
if let Err(e) = this
.stream
.as_mut()
.start_send(Message::Binary(buf.to_vec()))
{
Poll::Ready(Err(ws_err_into(e)))
} else {
Poll::Ready(Ok(buf.len()))
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(ws_err_into(e))),
Poll::Pending => {
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_flush(cx).map_err(ws_err_into)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_close(cx).map_err(ws_err_into)
}
}
impl AsyncRead for WebSocketRW {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if buf.remaining() == 0 {
return Poll::Ready(Ok(()));
}
let inner_buf = match self.as_mut().poll_fill_buf(cx) {
Poll::Ready(Ok(buf)) => buf,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
};
let len = std::cmp::min(inner_buf.len(), buf.remaining());
buf.put_slice(&inner_buf[..len]);
self.consume(len);
Poll::Ready(Ok(()))
}
}
impl AsyncBufRead for WebSocketRW {
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
loop {
if self.as_mut().has_chunk() {
let buf = self.project().chunk.as_ref().unwrap().chunk();
return Poll::Ready(Ok(buf));
} else {
match self.as_mut().project().stream.poll_next(cx) {
Poll::Ready(Some(Ok(message))) => match message {
Message::Text(_) => {}
Message::Binary(chunk) => {
*self.as_mut().project().chunk = Some(Bytes::from(chunk));
}
Message::Ping(_) => {
// No need to send a reply: tungstenite takes care of this for you.
}
Message::Pong(_) => {}
Message::Close(_) => {
// No need to send a reply: tungstenite takes care of this for you.
return Poll::Ready(Ok(&[]));
}
Message::Frame(_) => {
unreachable!();
}
},
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(ws_err_into(err))),
Poll::Ready(None) => return Poll::Ready(Ok(&[])),
Poll::Pending => return Poll::Pending,
}
}
}
}
fn consume(self: Pin<&mut Self>, amt: usize) {
if amt > 0 {
self.project()
.chunk
.as_mut()
.expect("No chunk present")
.advance(amt);
}
}
}
async fn serve_websocket(
websocket: HyperWebsocket,
config: &ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
hostname: Option<String>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
handle_ws_client(
config,
cancel_map,
session_id,
WebSocketRW::new(websocket),
hostname,
)
.await?;
Ok(())
}
async fn ws_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
) -> Result<Response<Body>, ApiError> {
let host = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
.map_err(|e| ApiError::BadRequest(e.into()))?;
tokio::spawn(async move {
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await
{
error!("error in websocket connection: {:?}", e);
}
});
// Return the response so the spawned future can continue.
Ok(response)
} else {
json_response(StatusCode::OK, "Connect with a websocket client")
}
}
pub async fn task_main(
ws_listener: TcpListener,
config: &'static ProxyConfig,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
}
let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
Some(config) => config.into(),
None => {
warn!("TLS config is missing, WebSocket Secure server will not be started");
return Ok(());
}
};
let addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {:?}", err);
ready(false)
} else {
ready(true)
}
});
let make_svc = hyper::service::make_service_fn(|_stream| async move {
Ok::<_, Infallible>(hyper::service::service_fn(
move |req: Request<Body>| async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
ws_handler(req, config, cancel_map, session_id)
.instrument(info_span!(
"ws-client",
session = format_args!("{session_id}")
))
.await
},
))
});
hyper::Server::builder(accept::from_stream(tls_listener))
.serve(make_svc)
.await?;
Ok(())
}

View File

@@ -110,12 +110,23 @@ async fn main() -> anyhow::Result<()> {
info!("Starting proxy on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let tasks = [
let mut tasks = vec![
tokio::spawn(http::server::task_main(http_listener)),
tokio::spawn(proxy::task_main(config, proxy_listener)),
tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)),
]
.map(flatten_err);
];
if let Some(wss_address) = arg_matches.get_one::<String>("wss") {
let wss_address: SocketAddr = wss_address.parse()?;
info!("Starting wss on {}", wss_address);
let wss_listener = TcpListener::bind(wss_address).await?;
tasks.push(tokio::spawn(http::websocket::task_main(
wss_listener,
config,
)));
}
let tasks = tasks.into_iter().map(flatten_err);
set_build_info_metric(GIT_VERSION);
// This will block until all tasks have completed.
@@ -155,6 +166,11 @@ fn cli() -> clap::Command {
.help("listen for incoming http connections (metrics, etc) on ip:port")
.default_value("127.0.0.1:7001"),
)
.arg(
Arg::new("wss")
.long("wss")
.help("listen for incoming wss connections on ip:port"),
)
.arg(
Arg::new("uri")
.short('u')

View File

@@ -82,6 +82,47 @@ pub async fn task_main(
}
}
pub async fn handle_ws_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
hostname: Option<String>,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
scopeguard::defer! {
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
let tls = config.tls_config.as_ref();
let hostname = hostname.as_deref();
// TLS is None here, because the connection is already encrypted.
let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake"));
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
// Extract credentials which we're going to use for auth.
let creds = {
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name, true))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session))
.await
}
async fn handle_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
@@ -108,7 +149,7 @@ async fn handle_client(
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name, false))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?