Compare commits

...

9 Commits

Author SHA1 Message Date
Conrad Ludgate
bea8269a17 zerocopy parsing 2024-11-04 13:07:27 +00:00
Conrad Ludgate
3500a758af parse ip addr for consistency 2024-11-04 11:19:36 +00:00
Conrad Ludgate
3b3c2da57f [proxy]: parse proxy protocol TLVs with aws/azure support 2024-11-01 18:11:08 +00:00
John Spray
3c16bd6e0b storcon: skip non-active projects in chaos injection (#9606)
## Problem

We may sometimes use scheduling modes like `Pause` to pin a tenant in
its current location for operational reasons. It is undesirable for the
chaos task to make any changes to such projects.

## Summary of changes

- Add a check for scheduling mode
- Add a log line when we do choose to do a chaos action for a tenant:
this will help us understand which operations originate from the chaos
task.
2024-11-01 16:47:20 +00:00
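A minimal sketch of the guard this PR adds; the real check is in the chaos injector diff near the end of this compare, and the enum here is pared down to two illustrative variants:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ShardSchedulingPolicy {
    Active,
    Pause, // e.g. a tenant pinned in place for operational reasons
}

// Only shards with an Active scheduling policy are eligible for chaos
// migrations; anything else is skipped (and logged, in the real code).
fn eligible_for_chaos(policy: ShardSchedulingPolicy) -> bool {
    matches!(policy, ShardSchedulingPolicy::Active)
}

fn main() {
    for policy in [ShardSchedulingPolicy::Active, ShardSchedulingPolicy::Pause] {
        println!("{policy:?}: eligible = {}", eligible_for_chaos(policy));
    }
}
```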
Erik Grinaker
123816e99a safekeeper: log slow WalAcceptor sends (#9564)
## Problem

We don't have any observability into full WalAcceptor queues per
timeline.

## Summary of changes

Logs a message when a WalAcceptor send has blocked for 5 seconds, and
another message when the send completes. This implies that the log
frequency is at most once every 5 seconds per timeline, so we don't need
further throttling.
2024-11-01 13:47:03 +01:00
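A condensed sketch of the pattern described above, assuming a tokio mpsc channel and substituting `eprintln!` for the `tracing` macros (the actual implementation is in the safekeeper diff below):

```rust
use std::time::{Duration, Instant};
use tokio::sync::mpsc::{error::SendTimeoutError, Sender};

const SLOW_THRESHOLD: Duration = Duration::from_secs(5);

// Try a bounded-time send first; if it blocks past the threshold, log
// once, finish the send without a timeout, then log completion. This
// caps logging at one pair of messages per SLOW_THRESHOLD per sender.
async fn send_logging_slow<T>(tx: &Sender<T>, msg: T) -> Result<(), ()> {
    let started = Instant::now();
    match tx.send_timeout(msg, SLOW_THRESHOLD).await {
        Ok(()) => Ok(()),
        Err(SendTimeoutError::Timeout(msg)) => {
            eprintln!("slow send blocked for {:.3}s", started.elapsed().as_secs_f64());
            let res = tx.send(msg).await.map_err(|_| ()); // receiver may have gone away
            eprintln!("slow send completed after {:.3}s", started.elapsed().as_secs_f64());
            res
        }
        // Channel closed: the receiving task terminated.
        Err(SendTimeoutError::Closed(_)) => Err(()),
    }
}
```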
Peter Bendel
8b3bcf71ee revert higher token expiration (#9605)
## Problem

The IAM role associated with our github action runner supports a max
token expiration which is lower than the value we tried.

## Summary of changes

Since we believe we have understood the performance regression (by
ensuring availability zone affinity of compute and pageserver), the job
should again run in under 5 hours, so we revert this change instead
of increasing the max session token expiration in the IAM role, which
would reduce our security.
2024-11-01 12:46:02 +01:00
Erik Grinaker
4c2c8d6708 test_runner: fix tenant_get_shards with one pageserver (#9603)
## Problem

`tenant_get_shards()` does not work with a sharded tenant on 1
pageserver, as it assumes an unsharded tenant in this case. This special
case appears to have been added to handle e.g. `test_emergency_mode`,
where the storage controller is stopped. This breaks e.g. the sharded
ingest benchmark in #9591 when run with a single shard.

## Summary of changes

Correctly look up shards even with a single pageserver, but add a
special case that assumes an unsharded tenant if the storage controller
is stopped and the caller provides an explicit pageserver, in order to
accommodate `test_emergency_mode`.
2024-11-01 11:25:04 +00:00
Conrad Ludgate
2d1366c8ee fix pre-commit hook with python stubs (#9602)
fix #9601
2024-11-01 11:22:38 +00:00
Vlad Lazar
e589c2e5ec storage_controller: allow deployment infra to use infra token (#9596)
## Problem

We wish for the deployment orchestrator to use infra scoped tokens,
but the storcon endpoints it uses require admin scoped tokens.

## Summary of Changes

Switch over all endpoints that are used by the deployment orchestrator
to use an infra scoped token. This causes no breakage during mixed
version scenarios because admin scoped tokens allow access to all
endpoints. The deployment orchestrator can cut over to the infra token
after this commit touches down in prod.

Once this commit is released we should also update the test code to use
infra scoped tokens where appropriate. Currently it would fail on the
[compat tests](9761b6a64e/test_runner/regress/test_storage_controller.py (L69-L71)).
2024-10-31 18:29:16 +00:00
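A hypothetical sketch of the compatibility argument above; `Scope` and `check_permissions` here are illustrative stand-ins, not the actual storcon API:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Scope {
    Admin,
    Infra,
}

// Admin tokens are accepted on every endpoint, so lowering an endpoint's
// required scope from Admin to Infra cannot break existing admin callers.
fn check_permissions(token: Scope, required: Scope) -> Result<(), String> {
    if token == Scope::Admin || token == required {
        Ok(())
    } else {
        Err(format!("{token:?} token cannot access a {required:?} endpoint"))
    }
}

fn main() {
    assert!(check_permissions(Scope::Admin, Scope::Infra).is_ok()); // mixed-version safe
    assert!(check_permissions(Scope::Infra, Scope::Infra).is_ok()); // post-cutover
    assert!(check_permissions(Scope::Infra, Scope::Admin).is_err());
}
```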
15 changed files with 421 additions and 109 deletions

View File

@@ -683,7 +683,7 @@ jobs:
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 43200 # 12 hours
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download

Cargo.lock generated
View File

@@ -34,7 +34,7 @@ dependencies = [
"getrandom 0.2.11",
"once_cell",
"version_check",
"zerocopy",
"zerocopy 0.7.31",
]
[[package]]
@@ -951,7 +951,7 @@ dependencies = [
"bitflags 2.4.1",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
@@ -974,7 +974,7 @@ dependencies = [
"bitflags 2.4.1",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.12.1",
"log",
"prettyplease",
"proc-macro2",
@@ -4314,7 +4314,7 @@ checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck 0.5.0",
"itertools 0.10.5",
"itertools 0.12.1",
"log",
"multimap",
"once_cell",
@@ -4334,7 +4334,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5"
dependencies = [
"anyhow",
"itertools 0.10.5",
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.52",
@@ -4433,6 +4433,8 @@ dependencies = [
"smallvec",
"smol_str",
"socket2",
"strum",
"strum_macros",
"subtle",
"thiserror",
"tikv-jemalloc-ctl",
@@ -4455,6 +4457,7 @@ dependencies = [
"walkdir",
"workspace_hack",
"x509-parser",
"zerocopy 0.8.8",
]
[[package]]
@@ -7550,7 +7553,16 @@ version = "0.7.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d"
dependencies = [
"zerocopy-derive",
"zerocopy-derive 0.7.31",
]
[[package]]
name = "zerocopy"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a4e33e6dce36f2adba29746927f8e848ba70989fdb61c772773bbdda8b5d6a7"
dependencies = [
"zerocopy-derive 0.8.8",
]
[[package]]
@@ -7564,6 +7576,17 @@ dependencies = [
"syn 2.0.52",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cd137b4cc21bde6ecce3bbbb3350130872cda0be2c6888874279ea76e17d4c1"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "zeroize"
version = "1.7.0"

View File

@@ -74,6 +74,8 @@ sha2 = { workspace = true, features = ["asm", "oid"] }
smol_str.workspace = true
smallvec.workspace = true
socket2.workspace = true
strum.workspace = true
strum_macros.workspace = true
subtle.workspace = true
thiserror.workspace = true
tikv-jemallocator.workspace = true
@@ -96,6 +98,7 @@ rustls-native-certs.workspace = true
x509-parser.workspace = true
postgres-protocol.workspace = true
redis.workspace = true
zerocopy = { version = "0.8", features = ["derive"] }
# jwt stuff
jose-jwa = "0.1.2"

View File

@@ -13,6 +13,7 @@ use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy::stream::{PqStream, Stream};
use rustls::crypto::aws_lc_rs;
@@ -179,7 +180,10 @@ async fn task_main(
info!(%peer_addr, "serving");
let ctx = RequestMonitoring::new(
session_id,
peer_addr.ip(),
ConnectionInfo {
addr: peer_addr,
extra: None,
},
proxy::metrics::Protocol::SniRouter,
"sni",
);

View File

@@ -11,7 +11,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
use crate::proxy::connect_compute::{connect_to_compute, TcpMechanism};
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::proxy::passthrough::ProxyPassthrough;
@@ -65,8 +65,8 @@ pub async fn task_main(
error!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
Ok((socket, None)) => (socket, peer_addr.ip()),
Ok((socket, Some(info))) => (socket, info),
Ok((socket, None)) => (socket, ConnectionInfo{ addr: peer_addr, extra: None }),
};
match socket.inner.set_nodelay(true) {

View File

@@ -19,6 +19,7 @@ use crate::intern::{BranchIdInt, ProjectIdInt};
use crate::metrics::{
ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
};
use crate::protocol2::ConnectionInfo;
use crate::types::{DbName, EndpointId, RoleName};
pub mod parquet;
@@ -40,7 +41,7 @@ pub struct RequestMonitoring(
);
struct RequestMonitoringInner {
pub(crate) peer_addr: IpAddr,
pub(crate) conn_info: ConnectionInfo,
pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
@@ -84,7 +85,7 @@ impl Clone for RequestMonitoring {
fn clone(&self) -> Self {
let inner = self.0.try_lock().expect("should not deadlock");
let new = RequestMonitoringInner {
peer_addr: inner.peer_addr,
conn_info: inner.conn_info.clone(),
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
@@ -117,7 +118,7 @@ impl Clone for RequestMonitoring {
impl RequestMonitoring {
pub fn new(
session_id: Uuid,
peer_addr: IpAddr,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
@@ -125,13 +126,13 @@ impl RequestMonitoring {
"connect_request",
%protocol,
?session_id,
%peer_addr,
%conn_info,
ep = tracing::field::Empty,
role = tracing::field::Empty,
);
let inner = RequestMonitoringInner {
peer_addr,
conn_info,
session_id,
protocol,
first_packet: Utc::now(),
@@ -162,7 +163,11 @@ impl RequestMonitoring {
#[cfg(test)]
pub(crate) fn test() -> Self {
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
use std::net::SocketAddr;
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestMonitoring::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
}
pub(crate) fn console_application_name(&self) -> String {
@@ -286,7 +291,12 @@ impl RequestMonitoring {
}
pub(crate) fn peer_addr(&self) -> IpAddr {
self.0.try_lock().expect("should not deadlock").peer_addr
self.0
.try_lock()
.expect("should not deadlock")
.conn_info
.addr
.ip()
}
pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
@@ -362,7 +372,7 @@ impl RequestMonitoringInner {
}
fn has_private_peer_addr(&self) -> bool {
match self.peer_addr {
match self.conn_info.addr.ip() {
IpAddr::V4(ip) => ip.is_private(),
IpAddr::V6(_) => false,
}

View File

@@ -121,7 +121,7 @@ impl From<&RequestMonitoringInner> for RequestData {
fn from(value: &RequestMonitoringInner) -> Self {
Self {
session_id: value.session_id,
peer_addr: value.peer_addr.to_string(),
peer_addr: value.conn_info.addr.ip().to_string(),
timestamp: value.first_packet.naive_utc(),
username: value.user.as_deref().map(String::from),
application_name: value.application.as_deref().map(String::from),

View File

@@ -1,13 +1,17 @@
//! Proxy Protocol V2 implementation
//! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
use core::fmt;
use std::io;
use std::net::SocketAddr;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use bytes::{Buf, Bytes, BytesMut};
use pin_project_lite::pin_project;
use strum_macros::FromRepr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned};
pin_project! {
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
@@ -58,9 +62,84 @@ const HEADER: [u8; 12] = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ConnectionInfo {
pub addr: SocketAddr,
pub extra: Option<ConnectionInfoExtra>,
}
impl fmt::Display for ConnectionInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.extra {
None => self.addr.ip().fmt(f),
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
}
Some(ConnectionInfoExtra::Azure { link_id }) => {
write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
}
}
}
}
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectionInfoExtra {
Aws { vpce_id: Bytes },
Azure { link_id: u32 },
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2Header {
identifier: [u8; 12],
version_and_command: u8,
protocol_and_family: u8,
len: zerocopy::byteorder::network_endian::U16,
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV4 {
src_addr: NetworkEndianIpv4,
dst_addr: NetworkEndianIpv4,
src_port: zerocopy::byteorder::network_endian::U16,
dst_port: zerocopy::byteorder::network_endian::U16,
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV6 {
src_addr: NetworkEndianIpv6,
dst_addr: NetworkEndianIpv6,
src_port: zerocopy::byteorder::network_endian::U16,
dst_port: zerocopy::byteorder::network_endian::U16,
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(transparent)]
struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
impl NetworkEndianIpv4 {
#[inline]
fn get(self) -> Ipv4Addr {
Ipv4Addr::from_bits(self.0.get())
}
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(transparent)]
struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
impl NetworkEndianIpv6 {
#[inline]
fn get(self) -> Ipv6Addr {
Ipv6Addr::from_bits(self.0.get())
}
}
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
mut read: T,
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
) -> std::io::Result<(ChainRW<T>, Option<ConnectionInfo>)> {
let mut buf = BytesMut::with_capacity(128);
while buf.len() < 16 {
let bytes_read = read.read_buf(&mut buf).await?;
@@ -77,14 +156,15 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
};
}
let header = buf.split_to(16);
let header = buf
.try_get::<ProxyProtocolV2Header>()
.expect("we have checked the length already, so this should not panic");
// The next byte (the 13th one) is the protocol version and command.
// The highest four bits contains the version. As of this specification, it must
// always be sent as \x2 and the receiver must only accept this value.
let vc = header[12];
let version = vc >> 4;
let command = vc & 0b1111;
let version = header.version_and_command >> 4;
let command = header.version_and_command & 0b1111;
if version != 2 {
return Err(io::Error::new(
io::ErrorKind::Other,
@@ -115,18 +195,13 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
// The 14th byte contains the transport protocol and address family. The highest 4
// bits contain the address family, the lowest 4 bits contain the protocol.
let ft = header[13];
let address_length = match ft {
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
0x11 | 0x12 => 12,
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
0x21 | 0x22 => 36,
let address_length = match header.protocol_and_family {
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET protocol family.
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET protocol family.
0x11 | 0x12 => size_of::<ProxyProtocolV2HeaderV4>(),
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 protocol family.
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 protocol family.
0x21 | 0x22 => size_of::<ProxyProtocolV2HeaderV6>(),
// unspecified or unix stream. ignore the addresses
_ => 0,
};
@@ -140,16 +215,15 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
// of bytes and must not assume zero is presented for LOCAL connections. When a
// receiver accepts an incoming connection showing an UNSPEC address family or
// protocol, it may or may not decide to log the address information if present.
let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
let remaining_length = usize::from(header.len.get());
if remaining_length < address_length {
return Err(io::Error::new(
io::ErrorKind::Other,
"invalid proxy protocol length. not enough to fit requested IP addresses",
));
}
drop(header);
while buf.len() < remaining_length as usize {
while buf.len() < remaining_length {
if read.read_buf(&mut buf).await? == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
@@ -164,22 +238,106 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
// - destination layer 3 address in network byte order
// - source layer 4 address if any, in network byte order (port)
// - destination layer 4 address if any, in network byte order (port)
let addresses = buf.split_to(remaining_length as usize);
let mut header = buf.split_to(remaining_length);
let socket = match address_length {
12 => {
let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
Some(SocketAddr::from((src_addr, src_port)))
let addr = header
.try_get::<ProxyProtocolV2HeaderV4>()
.expect("we have verified that 12 bytes are in the buf");
Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get())))
}
36 => {
let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
Some(SocketAddr::from((src_addr, src_port)))
let addr = header
.try_get::<ProxyProtocolV2HeaderV6>()
.expect("we have verified that 36 bytes are in the buf");
Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get())))
}
_ => None,
};
Ok((ChainRW { inner: read, buf }, socket))
let mut extra = None;
while let Some(mut tlv) = read_tlv(&mut header) {
match Pp2Kind::from_repr(tlv.kind) {
Some(Pp2Kind::Aws) => {
if tlv.value.is_empty() {
tracing::warn!("invalid aws tlv: no subtype");
}
let subtype = tlv.value.get_u8();
match Pp2AwsType::from_repr(subtype) {
Some(Pp2AwsType::VpceId) => {
extra = Some(ConnectionInfoExtra::Aws { vpce_id: tlv.value });
}
None => {
tracing::warn!("unknown aws tlv: subtype={subtype}");
}
}
}
Some(Pp2Kind::Azure) => {
if tlv.value.is_empty() {
tracing::warn!("invalid azure tlv: no subtype");
}
let subtype = tlv.value.get_u8();
match Pp2AzureType::from_repr(subtype) {
Some(Pp2AzureType::PrivateEndpointLinkId) => {
if tlv.value.len() != 4 {
tracing::warn!("invalid azure link_id: {:?}", tlv.value);
}
extra = Some(ConnectionInfoExtra::Azure {
link_id: tlv.value.get_u32_le(),
});
}
None => {
tracing::warn!("unknown azure tlv: subtype={subtype}");
}
}
}
Some(kind) => {
tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
}
None => {
tracing::debug!("unknown tlv: {tlv:?}");
}
}
}
let conn_info = socket.map(|addr| ConnectionInfo { addr, extra });
Ok((ChainRW { inner: read, buf }, conn_info))
}
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2Kind {
// The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
// we don't use these but it would be interesting to know what's available
Alpn = 0x01,
Authority = 0x02,
Crc32C = 0x03,
Noop = 0x04,
UniqueId = 0x05,
Ssl = 0x20,
NetNs = 0x30,
/// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
Aws = 0xEA,
/// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
Azure = 0xEE,
}
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AwsType {
VpceId = 0x01,
}
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AzureType {
PrivateEndpointLinkId = 0x01,
}
impl<T: AsyncRead> AsyncRead for ChainRW<T> {
@@ -216,6 +374,60 @@ impl<T: AsyncRead> ChainRW<T> {
}
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct TlvHeader {
kind: u8,
len: zerocopy::byteorder::network_endian::U16,
}
#[derive(Debug)]
struct Tlv {
kind: u8,
value: Bytes,
}
fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
let tlv_header = b.try_get::<TlvHeader>().ok()?;
let len = usize::from(tlv_header.len.get());
if b.len() < len {
return None;
}
let value = b.split_to(len).freeze();
Some(Tlv {
kind: tlv_header.kind,
value,
})
}
trait BufExt: Sized {
fn try_get<T: zerocopy::FromBytes>(&mut self)
-> Result<T, zerocopy::error::SizeError<Self, T>>;
// fn peek<T: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable>(
// &self,
// ) -> Option<&T>;
}
impl BufExt for BytesMut {
fn try_get<T: zerocopy::FromBytes>(
&mut self,
) -> Result<T, zerocopy::error::SizeError<Self, T>> {
let len = size_of::<T>();
// this will error in the read_from_bytes if the buf is too small
let len = usize::min(len, self.len());
let buf = self.split_to(len);
T::read_from_bytes(&buf).map_err(|e| e.map_src(|_| buf.clone()))
}
// fn peek<T: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable>(
// &self,
// ) -> Option<&T> {
// T::ref_from_prefix(self).ok().map(|(t, _)| t)
// }
}
#[cfg(test)]
mod tests {
use tokio::io::AsyncReadExt;
@@ -242,7 +454,7 @@ mod tests {
let extra_data = [0x55; 256];
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
.await
.unwrap();
@@ -250,7 +462,9 @@ mod tests {
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, extra_data);
assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
let info = info.unwrap();
assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
}
#[tokio::test]
@@ -273,7 +487,7 @@ mod tests {
let extra_data = [0x55; 256];
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
.await
.unwrap();
@@ -281,9 +495,11 @@ mod tests {
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, extra_data);
let info = info.unwrap();
assert_eq!(
addr,
Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
info.addr,
([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
);
}
@@ -291,30 +507,31 @@ mod tests {
async fn test_invalid() {
let data = [0x55; 256];
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
let mut bytes = vec![];
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, data);
assert_eq!(addr, None);
assert_eq!(info, None);
}
#[tokio::test]
async fn test_short() {
let data = [0x55; 10];
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
let mut bytes = vec![];
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, data);
assert_eq!(addr, None);
assert_eq!(info, None);
}
#[tokio::test]
async fn test_large_tlv() {
let tlv = vec![0x55; 32768];
let len = (12 + tlv.len() as u16).to_be_bytes();
let tlv_len = (tlv.len() as u16).to_be_bytes();
let len = (12 + 3 + tlv.len() as u16).to_be_bytes();
let header = super::HEADER
// Proxy command, Inet << 4 | Stream
@@ -330,11 +547,13 @@ mod tests {
// dst port
.chain([1, 1].as_slice())
// TLV
.chain([255].as_slice())
.chain(tlv_len.as_slice())
.chain(tlv.as_slice());
let extra_data = [0xaa; 256];
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
.await
.unwrap();
@@ -342,6 +561,8 @@ mod tests {
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, extra_data);
assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
let info = info.unwrap();
assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
}
}
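As a worked example of the zerocopy approach in this file: the fixed part of a PROXY v2 header is exactly 16 bytes, and the struct below mirrors the `ProxyProtocolV2Header` layout above. A self-contained sketch, assuming zerocopy 0.8 with the `derive` feature (per the Cargo.toml hunk earlier):

```rust
use zerocopy::byteorder::network_endian::U16;
use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned};

// 12-byte magic, one version/command byte, one protocol/family byte,
// then the big-endian length of the address + TLV payload that follows.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct Header {
    identifier: [u8; 12],
    version_and_command: u8,
    protocol_and_family: u8,
    len: U16,
}

fn main() {
    let mut raw = vec![
        // PROXY v2 signature
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
        0x21, // version 2, PROXY command
        0x11, // TCP over IPv4
    ];
    raw.extend_from_slice(&12u16.to_be_bytes()); // 2*4 addr bytes + 2*2 port bytes follow

    let header = Header::read_from_bytes(&raw).unwrap();
    assert_eq!(&header.identifier[..], &raw[..12]);
    assert_eq!(header.version_and_command >> 4, 2); // version
    assert_eq!(header.version_and_command & 0b1111, 1); // PROXY command
    assert_eq!(header.protocol_and_family, 0x11);
    assert_eq!(header.len.get(), 12); // bytes of addresses/ports still to read
    println!("parsed fixed header; {} payload bytes expected", header.len.get());
}
```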

View File

@@ -28,7 +28,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -87,7 +87,7 @@ pub async fn task_main(
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
let (socket, conn_info) = match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
@@ -100,8 +100,8 @@ pub async fn task_main(
warn!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
Ok((socket, None)) => (socket, peer_addr.ip()),
Ok((socket, Some(info))) => (socket, info),
Ok((socket, None)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
};
match socket.inner.set_nodelay(true) {
@@ -114,7 +114,7 @@ pub async fn task_main(
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);

View File

@@ -44,10 +44,10 @@ use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::{read_proxy_protocol, ChainRW};
use crate::protocol2::{read_proxy_protocol, ChainRW, ConnectionInfo};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
@@ -180,7 +180,7 @@ pub async fn task_main(
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
let Some((conn, conn_info)) = startup_result else {
return;
};
@@ -192,7 +192,7 @@ pub async fn task_main(
endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
conn_info,
session_id,
))
.await;
@@ -240,7 +240,7 @@ async fn connection_startup(
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
) -> Option<(AsyncRW, IpAddr)> {
) -> Option<(AsyncRW, ConnectionInfo)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
@@ -250,17 +250,32 @@ async fn connection_startup(
}
};
let peer_addr = peer.unwrap_or(peer_addr).ip();
let has_private_peer_addr = match peer_addr {
let conn_info = match peer {
None if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
tracing::warn!("missing required proxy protocol header");
return None;
}
Some(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
tracing::warn!("proxy protocol header not supported");
return None;
}
Some(info) => info,
None => ConnectionInfo {
addr: peer_addr,
extra: None,
},
};
let has_private_peer_addr = match conn_info.addr.ip() {
IpAddr::V4(ip) => ip.is_private(),
IpAddr::V6(_) => false,
};
info!(?session_id, %peer_addr, "accepted new TCP connection");
info!(?session_id, %conn_info, "accepted new TCP connection");
// try upgrade to TLS, but with a timeout.
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
Ok(Ok(conn)) => {
info!(?session_id, %peer_addr, "accepted new TLS connection");
info!(?session_id, %conn_info, "accepted new TLS connection");
conn
}
// The handshake failed
@@ -268,7 +283,7 @@ async fn connection_startup(
if !has_private_peer_addr {
Metrics::get().proxy.tls_handshake_failures.inc();
}
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
return None;
}
// The handshake timed out
@@ -276,12 +291,12 @@ async fn connection_startup(
if !has_private_peer_addr {
Metrics::get().proxy.tls_handshake_failures.inc();
}
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
return None;
}
};
Some((conn, peer_addr))
Some((conn, conn_info))
}
/// Handles HTTP connection
@@ -297,7 +312,7 @@ async fn connection_handler(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conn: AsyncRW,
peer_addr: IpAddr,
conn_info: ConnectionInfo,
session_id: uuid::Uuid,
) {
let session_id = AtomicTake::new(session_id);
@@ -306,6 +321,7 @@ async fn connection_handler(
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let conn_info2 = conn_info.clone();
let server = Builder::new(TokioExecutor::new());
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
@@ -340,7 +356,7 @@ async fn connection_handler(
connections.clone(),
cancellation_handler.clone(),
session_id,
peer_addr,
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
)
@@ -365,7 +381,7 @@ async fn connection_handler(
// On cancellation, trigger the HTTP connection handler to shut down.
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
Either::Left((_cancelled, mut conn)) => {
tracing::debug!(%peer_addr, "cancelling connection");
tracing::debug!(%conn_info, "cancelling connection");
conn.as_mut().graceful_shutdown();
conn.await
}
@@ -373,8 +389,8 @@ async fn connection_handler(
};
match res {
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
Ok(()) => tracing::info!(%conn_info, "HTTP connection closed"),
Err(e) => tracing::warn!(%conn_info, "HTTP connection error {e}"),
}
}
@@ -386,7 +402,7 @@ async fn request_handler(
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
conn_info: ConnectionInfo,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -404,7 +420,7 @@ async fn request_handler(
{
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
@@ -439,7 +455,7 @@ async fn request_handler(
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);

View File

@@ -67,7 +67,7 @@ exclude = [
check_untyped_defs = true
# Help mypy find imports when running against list of individual files.
# Without this line it would behave differently when executed on the entire project.
mypy_path = "$MYPY_CONFIG_FILE_DIR:$MYPY_CONFIG_FILE_DIR/test_runner"
mypy_path = "$MYPY_CONFIG_FILE_DIR:$MYPY_CONFIG_FILE_DIR/test_runner:$MYPY_CONFIG_FILE_DIR/test_runner/stubs"
disallow_incomplete_defs = false
disallow_untyped_calls = false

View File

@@ -26,10 +26,11 @@ use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::sync::mpsc::error::SendTimeoutError;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task;
use tokio::task::JoinHandle;
use tokio::time::{Duration, MissedTickBehavior};
use tokio::time::{Duration, Instant, MissedTickBehavior};
use tracing::*;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
@@ -384,9 +385,29 @@ async fn read_network_loop<IO: AsyncRead + AsyncWrite + Unpin>(
msg_tx: Sender<ProposerAcceptorMessage>,
mut next_msg: ProposerAcceptorMessage,
) -> Result<(), CopyStreamHandlerEnd> {
/// Threshold for logging slow WalAcceptor sends.
const SLOW_THRESHOLD: Duration = Duration::from_secs(5);
loop {
if msg_tx.send(next_msg).await.is_err() {
return Ok(()); // chan closed, WalAcceptor terminated
let started = Instant::now();
match msg_tx.send_timeout(next_msg, SLOW_THRESHOLD).await {
Ok(()) => {}
// Slow send, log a message and keep trying. Log context has timeline ID.
Err(SendTimeoutError::Timeout(next_msg)) => {
warn!(
"slow WalAcceptor send blocked for {:.3}s",
Instant::now().duration_since(started).as_secs_f64()
);
if msg_tx.send(next_msg).await.is_err() {
return Ok(()); // WalAcceptor terminated
}
warn!(
"slow WalAcceptor send completed after {:.3}s",
Instant::now().duration_since(started).as_secs_f64()
)
}
// WalAcceptor terminated.
Err(SendTimeoutError::Closed(_)) => return Ok(()),
}
next_msg = read_message(pgb_reader).await?;
}

View File

@@ -658,7 +658,7 @@ async fn handle_node_register(req: Request<Body>) -> Result<Response<Body>, ApiE
}
async fn handle_node_list(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
check_permissions(&req, Scope::Infra)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
@@ -737,7 +737,7 @@ async fn handle_node_configure(req: Request<Body>) -> Result<Response<Body>, Api
}
async fn handle_node_status(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
check_permissions(&req, Scope::Infra)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
@@ -786,7 +786,7 @@ async fn handle_get_leader(req: Request<Body>) -> Result<Response<Body>, ApiErro
}
async fn handle_node_drain(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
check_permissions(&req, Scope::Infra)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
@@ -804,7 +804,7 @@ async fn handle_node_drain(req: Request<Body>) -> Result<Response<Body>, ApiErro
}
async fn handle_cancel_node_drain(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
check_permissions(&req, Scope::Infra)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
@@ -822,7 +822,7 @@ async fn handle_cancel_node_drain(req: Request<Body>) -> Result<Response<Body>,
}
async fn handle_node_fill(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
check_permissions(&req, Scope::Infra)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
@@ -840,7 +840,7 @@ async fn handle_node_fill(req: Request<Body>) -> Result<Response<Body>, ApiError
}
async fn handle_cancel_node_fill(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
check_permissions(&req, Scope::Infra)?;
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {

View File

@@ -1,5 +1,6 @@
use std::{sync::Arc, time::Duration};
use pageserver_api::controller_api::ShardSchedulingPolicy;
use rand::seq::SliceRandom;
use rand::thread_rng;
use tokio_util::sync::CancellationToken;
@@ -47,6 +48,16 @@ impl ChaosInjector {
.get_mut(victim)
.expect("Held lock between choosing ID and this get");
if !matches!(shard.get_scheduling_policy(), ShardSchedulingPolicy::Active) {
// Skip non-active scheduling policies, so that a shard with a policy like Pause can
// be pinned without being disrupted by us.
tracing::info!(
"Skipping shard {victim}: scheduling policy is {:?}",
shard.get_scheduling_policy()
);
continue;
}
// Pick a secondary to promote
let Some(new_location) = shard
.intent
@@ -63,6 +74,8 @@ impl ChaosInjector {
continue;
};
tracing::info!("Injecting chaos: migrate {victim} {old_location}->{new_location}");
shard.intent.demote_attached(scheduler, old_location);
shard.intent.promote_attached(scheduler, new_location);
self.service.maybe_reconcile_shard(shard, nodes);

View File

@@ -1397,7 +1397,7 @@ def neon_simple_env(
pageserver_virtual_file_io_mode: Optional[str],
) -> Iterator[NeonEnv]:
"""
Simple Neon environment, with no authentication and no safekeepers.
Simple Neon environment, with 1 safekeeper and 1 pageserver. No authentication, no fsync.
This fixture will use RemoteStorageKind.LOCAL_FS with pageserver.
"""
@@ -4701,6 +4701,7 @@ def tenant_get_shards(
If the caller provides `pageserver_id`, it will be used for all shards, even
if the shard is indicated by storage controller to be on some other pageserver.
If the storage controller is not running, assume an unsharded tenant.
Caller should iterate over the response to apply their per-pageserver action to
each shard.
@@ -4710,17 +4711,17 @@ def tenant_get_shards(
else:
override_pageserver = None
if len(env.pageservers) > 1:
return [
(
TenantShardId.parse(s["shard_id"]),
override_pageserver or env.get_pageserver(s["node_id"]),
)
for s in env.storage_controller.locate(tenant_id)
]
else:
# Assume an unsharded tenant
return [(TenantShardId(tenant_id, 0, 0), override_pageserver or env.pageserver)]
if not env.storage_controller.running and override_pageserver is not None:
log.warning(f"storage controller not running, assuming unsharded tenant {tenant_id}")
return [(TenantShardId(tenant_id, 0, 0), override_pageserver)]
return [
(
TenantShardId.parse(s["shard_id"]),
override_pageserver or env.get_pageserver(s["node_id"]),
)
for s in env.storage_controller.locate(tenant_id)
]
def wait_replica_caughtup(primary: Endpoint, secondary: Endpoint):