mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-01 17:50:38 +00:00
Compare commits
7 Commits
refactor-c
...
proxy-http
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
081e794878 | ||
|
|
07ccaa7575 | ||
|
|
a318213e72 | ||
|
|
520171f17a | ||
|
|
85e17bc550 | ||
|
|
76fe42aae0 | ||
|
|
4d37f89189 |
16
Cargo.lock
generated
16
Cargo.lock
generated
@@ -2389,19 +2389,6 @@ dependencies = [
|
||||
"tokio-native-tls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tungstenite"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
|
||||
dependencies = [
|
||||
"hyper",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.56"
|
||||
@@ -3893,7 +3880,6 @@ dependencies = [
|
||||
"hostname",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"hyper-tungstenite",
|
||||
"ipnet",
|
||||
"itertools",
|
||||
"md5",
|
||||
@@ -3939,11 +3925,13 @@ dependencies = [
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"tracing-utils",
|
||||
"tungstenite",
|
||||
"url",
|
||||
"utils",
|
||||
"uuid",
|
||||
|
||||
@@ -89,7 +89,6 @@ http-types = { version = "2", default-features = false }
|
||||
humantime = "2.1"
|
||||
humantime-serde = "1.1.1"
|
||||
hyper = "0.14"
|
||||
hyper-tungstenite = "0.11"
|
||||
inotify = "0.10.2"
|
||||
ipnet = "2.9.0"
|
||||
itertools = "0.10"
|
||||
@@ -156,6 +155,7 @@ tokio-rustls = "0.24"
|
||||
tokio-stream = "0.1"
|
||||
tokio-tar = "0.3"
|
||||
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
|
||||
tokio-tungstenite = "0.20"
|
||||
toml = "0.7"
|
||||
toml_edit = "0.19"
|
||||
tonic = {version = "0.9", features = ["tls", "tls-roots"]}
|
||||
@@ -163,6 +163,7 @@ tracing = "0.1"
|
||||
tracing-error = "0.2.0"
|
||||
tracing-opentelemetry = "0.19.0"
|
||||
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
|
||||
tungstenite = "0.20"
|
||||
url = "2.2"
|
||||
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
|
||||
walkdir = "2.3.2"
|
||||
|
||||
@@ -27,7 +27,6 @@ hex.workspace = true
|
||||
hmac.workspace = true
|
||||
hostname.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper-tungstenite.workspace = true
|
||||
hyper.workspace = true
|
||||
ipnet.workspace = true
|
||||
itertools.workspace = true
|
||||
@@ -66,11 +65,13 @@ tls-listener.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tokio-tungstenite.workspace = true
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
tracing-opentelemetry.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing-utils.workspace = true
|
||||
tracing.workspace = true
|
||||
tungstenite.workspace = true
|
||||
url.workspace = true
|
||||
utils.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -56,16 +56,16 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
pub fn endpoint_sni<'a>(
|
||||
sni: &'a str,
|
||||
common_names: &HashSet<String>,
|
||||
) -> Result<&'a str, ComputeUserInfoParseError> {
|
||||
) -> Result<(&'a str, &'a str), ComputeUserInfoParseError> {
|
||||
let Some((subdomain, common_name)) = sni.split_once('.') else {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
|
||||
return Ok((sni, ""));
|
||||
};
|
||||
if !common_names.contains(common_name) {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName {
|
||||
cn: common_name.into(),
|
||||
});
|
||||
}
|
||||
Ok(subdomain)
|
||||
Ok((subdomain, common_name))
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
@@ -102,7 +102,7 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
|
||||
let project_from_domain = if let Some(sni_str) = sni {
|
||||
if let Some(cn) = common_names {
|
||||
Some(SmolStr::from(endpoint_sni(sni_str, cn)?))
|
||||
Some(SmolStr::from(endpoint_sni(sni_str, cn)?.0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
@@ -77,7 +77,11 @@ pub async fn task_main(
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
|
||||
|
||||
let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
|
||||
// prefer http2, but support http/1.1
|
||||
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
|
||||
|
||||
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
let _ = addr_incoming.set_nodelay(true);
|
||||
@@ -150,6 +154,7 @@ pub async fn task_main(
|
||||
);
|
||||
|
||||
hyper::Server::builder(accept::from_stream(tls_listener))
|
||||
.http2_enable_connect_protocol()
|
||||
.serve(make_svc)
|
||||
.with_graceful_shutdown(cancellation_token.cancelled())
|
||||
.await?;
|
||||
@@ -213,11 +218,13 @@ async fn request_handler(
|
||||
.and_then(|h| h.split(':').next())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let ws_config = None;
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||
if websocket::is_upgrade_request(&request) {
|
||||
info!(session_id = ?session_id, "performing websocket upgrade");
|
||||
|
||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||
let (response, websocket) = websocket::upgrade(&mut request, ws_config)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
ws_connections.spawn(
|
||||
@@ -240,6 +247,34 @@ async fn request_handler(
|
||||
.in_current_span(),
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else if websocket::is_connect_request(&request) {
|
||||
info!(session_id = ?session_id, "performing http2 websocket upgrade");
|
||||
|
||||
let (response, websocket) = websocket::connect(&mut request, ws_config)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws2", &config.region);
|
||||
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
&mut ctx,
|
||||
websocket,
|
||||
&cancel_map,
|
||||
host,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!(session_id = ?session_id, "error in http2 websocket connection: {e:#}");
|
||||
}
|
||||
}
|
||||
.in_current_span(),
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
@@ -256,7 +291,7 @@ async fn request_handler(
|
||||
.await
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
|
||||
Response::builder()
|
||||
.header("Allow", "OPTIONS, POST")
|
||||
.header("Allow", "OPTIONS, POST, CONNECT")
|
||||
.header("Access-Control-Allow-Origin", "*")
|
||||
.header(
|
||||
"Access-Control-Allow-Headers",
|
||||
|
||||
@@ -170,22 +170,21 @@ fn get_conn_info(
|
||||
let hostname = connection_url
|
||||
.host_str()
|
||||
.ok_or(anyhow::anyhow!("no host"))?;
|
||||
let (endpoint, common_name) = endpoint_sni(hostname, &tls.common_names)?;
|
||||
|
||||
let host_header = headers
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next());
|
||||
|
||||
if hostname != sni_hostname {
|
||||
if !sni_hostname.ends_with(common_name) {
|
||||
return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
|
||||
} else if let Some(h) = host_header {
|
||||
if h != hostname {
|
||||
if !h.ends_with(common_name) {
|
||||
return Err(anyhow::anyhow!("mismatched host header and hostname"));
|
||||
}
|
||||
}
|
||||
|
||||
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
|
||||
|
||||
let endpoint: SmolStr = endpoint.into();
|
||||
ctx.set_endpoint_id(Some(endpoint.clone()));
|
||||
|
||||
|
||||
@@ -8,9 +8,15 @@ use crate::{
|
||||
};
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use hyper::{ext::Protocol, upgrade::Upgraded, Body, Method, Request, Response};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tungstenite::{
|
||||
error::{Error as WSError, ProtocolError},
|
||||
handshake::derive_accept_key,
|
||||
protocol::{Role, WebSocketConfig},
|
||||
Message,
|
||||
};
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
@@ -150,19 +156,202 @@ pub async fn serve_websocket(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Try to upgrade a received `hyper::Request` to a websocket connection.
|
||||
///
|
||||
/// The function returns a HTTP response and a future that resolves to the websocket stream.
|
||||
/// The response body *MUST* be sent to the client before the future can be resolved.
|
||||
///
|
||||
/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers.
|
||||
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
|
||||
/// You can inspect the headers manually before calling this function,
|
||||
/// and modify the response headers appropriately.
|
||||
///
|
||||
/// This function also does not look at the `Connection` or `Upgrade` headers.
|
||||
/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`].
|
||||
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
|
||||
///
|
||||
pub fn upgrade<B>(
|
||||
mut request: impl std::borrow::BorrowMut<Request<B>>,
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Result<(Response<Body>, HyperWebsocket), ProtocolError> {
|
||||
let request = request.borrow_mut();
|
||||
|
||||
let key = request
|
||||
.headers()
|
||||
.get("Sec-WebSocket-Key")
|
||||
.ok_or(ProtocolError::MissingSecWebSocketKey)?;
|
||||
if request
|
||||
.headers()
|
||||
.get("Sec-WebSocket-Version")
|
||||
.map(|v| v.as_bytes())
|
||||
!= Some(b"13")
|
||||
{
|
||||
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
|
||||
}
|
||||
|
||||
let response = Response::builder()
|
||||
.status(hyper::StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(hyper::header::CONNECTION, "upgrade")
|
||||
.header(hyper::header::UPGRADE, "websocket")
|
||||
.header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes()))
|
||||
.body(Body::from("switching to websocket protocol"))
|
||||
.expect("bug: failed to build response");
|
||||
|
||||
let stream = HyperWebsocket {
|
||||
inner: hyper::upgrade::on(request),
|
||||
config,
|
||||
};
|
||||
|
||||
Ok((response, stream))
|
||||
}
|
||||
|
||||
/// Check if a request is a websocket upgrade request.
|
||||
///
|
||||
/// If the `Upgrade` header lists multiple protocols,
|
||||
/// this function returns true if of them are `"websocket"`,
|
||||
/// If the server supports multiple upgrade protocols,
|
||||
/// it would be more appropriate to try each listed protocol in order.
|
||||
pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
|
||||
header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
|
||||
&& header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket")
|
||||
}
|
||||
|
||||
/// Check if there is a header of the given name containing the wanted value.
|
||||
fn header_contains_value(
|
||||
headers: &hyper::HeaderMap,
|
||||
header: impl hyper::header::AsHeaderName,
|
||||
value: impl AsRef<[u8]>,
|
||||
) -> bool {
|
||||
let value = value.as_ref();
|
||||
for header in headers.get_all(header) {
|
||||
if header
|
||||
.as_bytes()
|
||||
.split(|&c| c == b',')
|
||||
.any(|x| trim(x).eq_ignore_ascii_case(value))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn trim(data: &[u8]) -> &[u8] {
|
||||
trim_end(trim_start(data))
|
||||
}
|
||||
|
||||
fn trim_start(data: &[u8]) -> &[u8] {
|
||||
if let Some(start) = data.iter().position(|x| !x.is_ascii_whitespace()) {
|
||||
&data[start..]
|
||||
} else {
|
||||
b""
|
||||
}
|
||||
}
|
||||
|
||||
fn trim_end(data: &[u8]) -> &[u8] {
|
||||
if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) {
|
||||
&data[..last + 1]
|
||||
} else {
|
||||
b""
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to upgrade a received `hyper::Request` to a websocket connection.
|
||||
///
|
||||
/// The function returns a HTTP response and a future that resolves to the websocket stream.
|
||||
/// The response body *MUST* be sent to the client before the future can be resolved.
|
||||
///
|
||||
/// This functions checks `Sec-WebSocket-Version` header.
|
||||
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
|
||||
/// You can inspect the headers manually before calling this function,
|
||||
/// and modify the response headers appropriately.
|
||||
///
|
||||
/// This function also does not look at the `Connection` or `Upgrade` headers.
|
||||
/// To check if a request is a websocket upgrade request, you can use [`is_upgrade2_request`].
|
||||
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
|
||||
///
|
||||
pub fn connect<B>(
|
||||
mut request: impl std::borrow::BorrowMut<Request<B>>,
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Result<(Response<Body>, HyperWebsocket), ProtocolError> {
|
||||
let request = request.borrow_mut();
|
||||
|
||||
if request
|
||||
.headers()
|
||||
.get("Sec-WebSocket-Version")
|
||||
.map(|v| v.as_bytes())
|
||||
!= Some(b"13")
|
||||
{
|
||||
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
|
||||
}
|
||||
|
||||
let response = Response::builder()
|
||||
.status(hyper::StatusCode::OK)
|
||||
.body(Body::from("switching to websocket protocol"))
|
||||
.expect("bug: failed to build response");
|
||||
|
||||
let stream = HyperWebsocket {
|
||||
inner: hyper::upgrade::on(request),
|
||||
config,
|
||||
};
|
||||
|
||||
Ok((response, stream))
|
||||
}
|
||||
|
||||
/// Check if a request is a websocket connect request.
|
||||
pub fn is_connect_request<B>(request: &hyper::Request<B>) -> bool {
|
||||
request.method() == Method::CONNECT
|
||||
&& request
|
||||
.extensions()
|
||||
.get::<Protocol>()
|
||||
.is_some_and(|protocol| protocol.as_str() == "websocket")
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
/// A future that resolves to a websocket stream when the associated connection completes.
|
||||
#[derive(Debug)]
|
||||
pub struct HyperWebsocket {
|
||||
#[pin]
|
||||
inner: hyper::upgrade::OnUpgrade,
|
||||
config: Option<WebSocketConfig>
|
||||
}
|
||||
}
|
||||
|
||||
impl std::future::Future for HyperWebsocket {
|
||||
type Output = Result<WebSocketStream<hyper::upgrade::Upgraded>, WSError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
|
||||
let this = self.project();
|
||||
let upgraded = match this.inner.poll(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(x) => x,
|
||||
};
|
||||
|
||||
let upgraded =
|
||||
upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?;
|
||||
|
||||
let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None);
|
||||
tokio::pin!(stream);
|
||||
|
||||
// The future returned by `from_raw_socket` is always ready.
|
||||
// Not sure why it is a future in the first place.
|
||||
match stream.as_mut().poll(cx) {
|
||||
Poll::Pending => unreachable!("from_raw_socket should always be created ready"),
|
||||
Poll::Ready(x) => Poll::Ready(Ok(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::pin;
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use hyper_tungstenite::{
|
||||
tungstenite::{protocol::Role, Message},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tokio::{
|
||||
io::{duplex, AsyncReadExt, AsyncWriteExt},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tungstenite::{protocol::Role, Message};
|
||||
|
||||
use super::WebSocketRw;
|
||||
|
||||
|
||||
@@ -500,3 +500,22 @@ def test_sql_over_http_pool_custom_types(static_proxy: NeonProxy):
|
||||
"select array['foo'::foo, 'bar'::foo, 'baz'::foo] as data",
|
||||
)
|
||||
assert response["rows"][0]["data"] == ["foo", "bar", "baz"]
|
||||
|
||||
|
||||
def test_sql_over_http_different_endpoint(static_proxy: NeonProxy):
|
||||
static_proxy.safe_psql("create role http with login password 'http' superuser")
|
||||
|
||||
def q(sql: str, params: Optional[List[Any]] = None) -> Any:
|
||||
params = params or []
|
||||
connstr = f"postgresql://http:http@my-endpoint.{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
|
||||
response = requests.post(
|
||||
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
|
||||
data=json.dumps({"query": sql, "params": params}),
|
||||
headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
|
||||
verify=str(static_proxy.test_output_dir / "proxy.crt"),
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
return response.json()
|
||||
|
||||
rows = q("select 42 as answer")["rows"]
|
||||
assert rows == [{"answer": 42}]
|
||||
|
||||
Reference in New Issue
Block a user