mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-20 22:50:38 +00:00
Compare commits
3 Commits
erik/safek
...
zerocopy-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bea8269a17 | ||
|
|
3500a758af | ||
|
|
3b3c2da57f |
35
Cargo.lock
generated
35
Cargo.lock
generated
@@ -34,7 +34,7 @@ dependencies = [
|
||||
"getrandom 0.2.11",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
"zerocopy 0.7.31",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -951,7 +951,7 @@ dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"log",
|
||||
@@ -974,7 +974,7 @@ dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"log",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
@@ -4314,7 +4314,7 @@ checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"log",
|
||||
"multimap",
|
||||
"once_cell",
|
||||
@@ -4334,7 +4334,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.52",
|
||||
@@ -4433,6 +4433,8 @@ dependencies = [
|
||||
"smallvec",
|
||||
"smol_str",
|
||||
"socket2",
|
||||
"strum",
|
||||
"strum_macros",
|
||||
"subtle",
|
||||
"thiserror",
|
||||
"tikv-jemalloc-ctl",
|
||||
@@ -4455,6 +4457,7 @@ dependencies = [
|
||||
"walkdir",
|
||||
"workspace_hack",
|
||||
"x509-parser",
|
||||
"zerocopy 0.8.8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7550,7 +7553,16 @@ version = "0.7.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
"zerocopy-derive 0.7.31",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a4e33e6dce36f2adba29746927f8e848ba70989fdb61c772773bbdda8b5d6a7"
|
||||
dependencies = [
|
||||
"zerocopy-derive 0.8.8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7564,6 +7576,17 @@ dependencies = [
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3cd137b4cc21bde6ecce3bbbb3350130872cda0be2c6888874279ea76e17d4c1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize"
|
||||
version = "1.7.0"
|
||||
|
||||
@@ -562,7 +562,6 @@ pub enum BeMessage<'a> {
|
||||
options: &'a [&'a str],
|
||||
},
|
||||
KeepAlive(WalSndKeepAlive),
|
||||
NeonInterpretedWalRecord(&'a [u8]), // TODO: use appropriate fields
|
||||
}
|
||||
|
||||
/// Common shorthands.
|
||||
@@ -997,17 +996,6 @@ impl BeMessage<'_> {
|
||||
Ok(())
|
||||
})?
|
||||
}
|
||||
|
||||
// Neon extension: send interpreted WAL records to relevant pageservers. This is
|
||||
// temporary until we move to a different protocol for Safekeeper->Pageserver WAL
|
||||
// (possibly gRPC).
|
||||
BeMessage::NeonInterpretedWalRecord(data) => {
|
||||
buf.put_u8(b'z'); // arbitrary unused value
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u64(data.len() as u64);
|
||||
buf.put_slice(data);
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ use crate::{
|
||||
CreateTimelineCause, DeleteTimelineError, MaybeDeletedIndexPart, Tenant,
|
||||
TimelineOrOffloaded,
|
||||
},
|
||||
virtual_file::MaybeFatalIo,
|
||||
};
|
||||
|
||||
use super::{Timeline, TimelineResources};
|
||||
@@ -63,10 +62,10 @@ pub(super) async fn delete_local_timeline_directory(
|
||||
conf: &PageServerConf,
|
||||
tenant_shard_id: TenantShardId,
|
||||
timeline: &Timeline,
|
||||
) {
|
||||
) -> anyhow::Result<()> {
|
||||
// Always ensure the lock order is compaction -> gc.
|
||||
let compaction_lock = timeline.compaction_lock.lock();
|
||||
let _compaction_lock = crate::timed(
|
||||
let compaction_lock = crate::timed(
|
||||
compaction_lock,
|
||||
"acquires compaction lock",
|
||||
std::time::Duration::from_secs(5),
|
||||
@@ -74,7 +73,7 @@ pub(super) async fn delete_local_timeline_directory(
|
||||
.await;
|
||||
|
||||
let gc_lock = timeline.gc_lock.lock();
|
||||
let _gc_lock = crate::timed(
|
||||
let gc_lock = crate::timed(
|
||||
gc_lock,
|
||||
"acquires gc lock",
|
||||
std::time::Duration::from_secs(5),
|
||||
@@ -86,15 +85,24 @@ pub(super) async fn delete_local_timeline_directory(
|
||||
|
||||
let local_timeline_directory = conf.timeline_path(&tenant_shard_id, &timeline.timeline_id);
|
||||
|
||||
fail::fail_point!("timeline-delete-before-rm", |_| {
|
||||
Err(anyhow::anyhow!("failpoint: timeline-delete-before-rm"))?
|
||||
});
|
||||
|
||||
// NB: This need not be atomic because the deleted flag in the IndexPart
|
||||
// will be observed during tenant/timeline load. The deletion will be resumed there.
|
||||
//
|
||||
// ErrorKind::NotFound can happen e.g. if we race with tenant detach, because,
|
||||
// Note that here we do not bail out on std::io::ErrorKind::NotFound.
|
||||
// This can happen if we're called a second time, e.g.,
|
||||
// because of a previous failure/cancellation at/after
|
||||
// failpoint timeline-delete-after-rm.
|
||||
//
|
||||
// ErrorKind::NotFound can also happen if we race with tenant detach, because,
|
||||
// no locks are shared.
|
||||
tokio::fs::remove_dir_all(local_timeline_directory)
|
||||
.await
|
||||
.or_else(fs_ext::ignore_not_found)
|
||||
.fatal_err("removing timeline directory");
|
||||
.context("remove local timeline directory")?;
|
||||
|
||||
// Make sure previous deletions are ordered before mark removal.
|
||||
// Otherwise there is no guarantee that they reach the disk before mark deletion.
|
||||
@@ -105,9 +113,17 @@ pub(super) async fn delete_local_timeline_directory(
|
||||
let timeline_path = conf.timelines_path(&tenant_shard_id);
|
||||
crashsafe::fsync_async(timeline_path)
|
||||
.await
|
||||
.fatal_err("fsync after removing timeline directory");
|
||||
.context("fsync_pre_mark_remove")?;
|
||||
|
||||
info!("finished deleting layer files, releasing locks");
|
||||
drop(gc_lock);
|
||||
drop(compaction_lock);
|
||||
|
||||
fail::fail_point!("timeline-delete-after-rm", |_| {
|
||||
Err(anyhow::anyhow!("failpoint: timeline-delete-after-rm"))?
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes remote layers and an index file after them.
|
||||
@@ -424,20 +440,12 @@ impl DeleteTimelineFlow {
|
||||
timeline: &TimelineOrOffloaded,
|
||||
remote_client: Arc<RemoteTimelineClient>,
|
||||
) -> Result<(), DeleteTimelineError> {
|
||||
fail::fail_point!("timeline-delete-before-rm", |_| {
|
||||
Err(anyhow::anyhow!("failpoint: timeline-delete-before-rm"))?
|
||||
});
|
||||
|
||||
// Offloaded timelines have no local state
|
||||
// TODO: once we persist offloaded information, delete the timeline from there, too
|
||||
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await?;
|
||||
}
|
||||
|
||||
fail::fail_point!("timeline-delete-after-rm", |_| {
|
||||
Err(anyhow::anyhow!("failpoint: timeline-delete-after-rm"))?
|
||||
});
|
||||
|
||||
delete_remote_layers_and_index(&remote_client).await?;
|
||||
|
||||
pausable_failpoint!("in_progress_delete");
|
||||
|
||||
@@ -67,7 +67,9 @@ pub(crate) async fn offload_timeline(
|
||||
// to make deletions possible while offloading is in progress
|
||||
|
||||
let conf = &tenant.conf;
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, &timeline).await;
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, &timeline)
|
||||
.await
|
||||
.map_err(OffloadError::Other)?;
|
||||
|
||||
remove_timeline_from_tenant(tenant, &timeline, &guard);
|
||||
|
||||
|
||||
@@ -67,10 +67,7 @@ pub(crate) fn apply_in_neon(
|
||||
let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..];
|
||||
|
||||
map[map_byte as usize] &= !(flags << map_offset);
|
||||
// The page should never be empty, but we're checking it anyway as a precaution, so that if it is empty for some reason anyway, we don't make matters worse by setting the LSN on it.
|
||||
if !postgres_ffi::page_is_new(page) {
|
||||
postgres_ffi::page_set_lsn(page, lsn);
|
||||
}
|
||||
postgres_ffi::page_set_lsn(page, lsn);
|
||||
}
|
||||
|
||||
// Repeat for 'old_heap_blkno', if any
|
||||
@@ -84,10 +81,7 @@ pub(crate) fn apply_in_neon(
|
||||
let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..];
|
||||
|
||||
map[map_byte as usize] &= !(flags << map_offset);
|
||||
// The page should never be empty, but we're checking it anyway as a precaution, so that if it is empty for some reason anyway, we don't make matters worse by setting the LSN on it.
|
||||
if !postgres_ffi::page_is_new(page) {
|
||||
postgres_ffi::page_set_lsn(page, lsn);
|
||||
}
|
||||
postgres_ffi::page_set_lsn(page, lsn);
|
||||
}
|
||||
}
|
||||
// Non-relational WAL records are handled here, with custom code that has the
|
||||
|
||||
@@ -74,6 +74,8 @@ sha2 = { workspace = true, features = ["asm", "oid"] }
|
||||
smol_str.workspace = true
|
||||
smallvec.workspace = true
|
||||
socket2.workspace = true
|
||||
strum.workspace = true
|
||||
strum_macros.workspace = true
|
||||
subtle.workspace = true
|
||||
thiserror.workspace = true
|
||||
tikv-jemallocator.workspace = true
|
||||
@@ -96,6 +98,7 @@ rustls-native-certs.workspace = true
|
||||
x509-parser.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
redis.workspace = true
|
||||
zerocopy = { version = "0.8", features = ["derive"] }
|
||||
|
||||
# jwt stuff
|
||||
jose-jwa = "0.1.2"
|
||||
|
||||
@@ -13,6 +13,7 @@ use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use proxy::context::RequestMonitoring;
|
||||
use proxy::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use proxy::protocol2::ConnectionInfo;
|
||||
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
|
||||
use proxy::stream::{PqStream, Stream};
|
||||
use rustls::crypto::aws_lc_rs;
|
||||
@@ -179,7 +180,10 @@ async fn task_main(
|
||||
info!(%peer_addr, "serving");
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr.ip(),
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
proxy::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::read_proxy_protocol;
|
||||
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
|
||||
use crate::proxy::connect_compute::{connect_to_compute, TcpMechanism};
|
||||
use crate::proxy::handshake::{handshake, HandshakeData};
|
||||
use crate::proxy::passthrough::ProxyPassthrough;
|
||||
@@ -65,8 +65,8 @@ pub async fn task_main(
|
||||
error!("proxy protocol header not supported");
|
||||
return;
|
||||
}
|
||||
Ok((socket, Some(addr))) => (socket, addr.ip()),
|
||||
Ok((socket, None)) => (socket, peer_addr.ip()),
|
||||
Ok((socket, Some(info))) => (socket, info),
|
||||
Ok((socket, None)) => (socket, ConnectionInfo{ addr: peer_addr, extra: None }),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
|
||||
@@ -19,6 +19,7 @@ use crate::intern::{BranchIdInt, ProjectIdInt};
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
|
||||
};
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
|
||||
pub mod parquet;
|
||||
@@ -40,7 +41,7 @@ pub struct RequestMonitoring(
|
||||
);
|
||||
|
||||
struct RequestMonitoringInner {
|
||||
pub(crate) peer_addr: IpAddr,
|
||||
pub(crate) conn_info: ConnectionInfo,
|
||||
pub(crate) session_id: Uuid,
|
||||
pub(crate) protocol: Protocol,
|
||||
first_packet: chrono::DateTime<Utc>,
|
||||
@@ -84,7 +85,7 @@ impl Clone for RequestMonitoring {
|
||||
fn clone(&self) -> Self {
|
||||
let inner = self.0.try_lock().expect("should not deadlock");
|
||||
let new = RequestMonitoringInner {
|
||||
peer_addr: inner.peer_addr,
|
||||
conn_info: inner.conn_info.clone(),
|
||||
session_id: inner.session_id,
|
||||
protocol: inner.protocol,
|
||||
first_packet: inner.first_packet,
|
||||
@@ -117,7 +118,7 @@ impl Clone for RequestMonitoring {
|
||||
impl RequestMonitoring {
|
||||
pub fn new(
|
||||
session_id: Uuid,
|
||||
peer_addr: IpAddr,
|
||||
conn_info: ConnectionInfo,
|
||||
protocol: Protocol,
|
||||
region: &'static str,
|
||||
) -> Self {
|
||||
@@ -125,13 +126,13 @@ impl RequestMonitoring {
|
||||
"connect_request",
|
||||
%protocol,
|
||||
?session_id,
|
||||
%peer_addr,
|
||||
%conn_info,
|
||||
ep = tracing::field::Empty,
|
||||
role = tracing::field::Empty,
|
||||
);
|
||||
|
||||
let inner = RequestMonitoringInner {
|
||||
peer_addr,
|
||||
conn_info,
|
||||
session_id,
|
||||
protocol,
|
||||
first_packet: Utc::now(),
|
||||
@@ -162,7 +163,11 @@ impl RequestMonitoring {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn test() -> Self {
|
||||
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
|
||||
use std::net::SocketAddr;
|
||||
let ip = IpAddr::from([127, 0, 0, 1]);
|
||||
let addr = SocketAddr::new(ip, 5432);
|
||||
let conn_info = ConnectionInfo { addr, extra: None };
|
||||
RequestMonitoring::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
|
||||
}
|
||||
|
||||
pub(crate) fn console_application_name(&self) -> String {
|
||||
@@ -286,7 +291,12 @@ impl RequestMonitoring {
|
||||
}
|
||||
|
||||
pub(crate) fn peer_addr(&self) -> IpAddr {
|
||||
self.0.try_lock().expect("should not deadlock").peer_addr
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.conn_info
|
||||
.addr
|
||||
.ip()
|
||||
}
|
||||
|
||||
pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
|
||||
@@ -362,7 +372,7 @@ impl RequestMonitoringInner {
|
||||
}
|
||||
|
||||
fn has_private_peer_addr(&self) -> bool {
|
||||
match self.peer_addr {
|
||||
match self.conn_info.addr.ip() {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
IpAddr::V6(_) => false,
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ impl From<&RequestMonitoringInner> for RequestData {
|
||||
fn from(value: &RequestMonitoringInner) -> Self {
|
||||
Self {
|
||||
session_id: value.session_id,
|
||||
peer_addr: value.peer_addr.to_string(),
|
||||
peer_addr: value.conn_info.addr.ip().to_string(),
|
||||
timestamp: value.first_packet.naive_utc(),
|
||||
username: value.user.as_deref().map(String::from),
|
||||
application_name: value.application.as_deref().map(String::from),
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
//! Proxy Protocol V2 implementation
|
||||
//! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
|
||||
|
||||
use core::fmt;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pin_project_lite::pin_project;
|
||||
use strum_macros::FromRepr;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned};
|
||||
|
||||
pin_project! {
|
||||
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
|
||||
@@ -58,9 +62,84 @@ const HEADER: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub struct ConnectionInfo {
|
||||
pub addr: SocketAddr,
|
||||
pub extra: Option<ConnectionInfoExtra>,
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnectionInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.extra {
|
||||
None => self.addr.ip().fmt(f),
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
|
||||
write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
|
||||
}
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => {
|
||||
write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub enum ConnectionInfoExtra {
|
||||
Aws { vpce_id: Bytes },
|
||||
Azure { link_id: u32 },
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2Header {
|
||||
identifier: [u8; 12],
|
||||
version_and_command: u8,
|
||||
protocol_and_family: u8,
|
||||
len: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2HeaderV4 {
|
||||
src_addr: NetworkEndianIpv4,
|
||||
dst_addr: NetworkEndianIpv4,
|
||||
src_port: zerocopy::byteorder::network_endian::U16,
|
||||
dst_port: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2HeaderV6 {
|
||||
src_addr: NetworkEndianIpv6,
|
||||
dst_addr: NetworkEndianIpv6,
|
||||
src_port: zerocopy::byteorder::network_endian::U16,
|
||||
dst_port: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(transparent)]
|
||||
struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
|
||||
|
||||
impl NetworkEndianIpv4 {
|
||||
#[inline]
|
||||
fn get(self) -> Ipv4Addr {
|
||||
Ipv4Addr::from_bits(self.0.get())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(transparent)]
|
||||
struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
|
||||
|
||||
impl NetworkEndianIpv6 {
|
||||
#[inline]
|
||||
fn get(self) -> Ipv6Addr {
|
||||
Ipv6Addr::from_bits(self.0.get())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
mut read: T,
|
||||
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
|
||||
) -> std::io::Result<(ChainRW<T>, Option<ConnectionInfo>)> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
while buf.len() < 16 {
|
||||
let bytes_read = read.read_buf(&mut buf).await?;
|
||||
@@ -77,14 +156,15 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
};
|
||||
}
|
||||
|
||||
let header = buf.split_to(16);
|
||||
let header = buf
|
||||
.try_get::<ProxyProtocolV2Header>()
|
||||
.expect("we have checked the length already, so this should not panic");
|
||||
|
||||
// The next byte (the 13th one) is the protocol version and command.
|
||||
// The highest four bits contains the version. As of this specification, it must
|
||||
// always be sent as \x2 and the receiver must only accept this value.
|
||||
let vc = header[12];
|
||||
let version = vc >> 4;
|
||||
let command = vc & 0b1111;
|
||||
let version = header.version_and_command >> 4;
|
||||
let command = header.version_and_command & 0b1111;
|
||||
if version != 2 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
@@ -115,18 +195,13 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
|
||||
// The 14th byte contains the transport protocol and address family. The highest 4
|
||||
// bits contain the address family, the lowest 4 bits contain the protocol.
|
||||
let ft = header[13];
|
||||
let address_length = match ft {
|
||||
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
|
||||
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
|
||||
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
|
||||
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
|
||||
0x11 | 0x12 => 12,
|
||||
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
|
||||
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
|
||||
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
|
||||
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
|
||||
0x21 | 0x22 => 36,
|
||||
let address_length = match header.protocol_and_family {
|
||||
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET protocol family.
|
||||
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET protocol family.
|
||||
0x11 | 0x12 => size_of::<ProxyProtocolV2HeaderV4>(),
|
||||
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 protocol family.
|
||||
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 protocol family.
|
||||
0x21 | 0x22 => size_of::<ProxyProtocolV2HeaderV6>(),
|
||||
// unspecified or unix stream. ignore the addresses
|
||||
_ => 0,
|
||||
};
|
||||
@@ -140,16 +215,15 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
// of bytes and must not assume zero is presented for LOCAL connections. When a
|
||||
// receiver accepts an incoming connection showing an UNSPEC address family or
|
||||
// protocol, it may or may not decide to log the address information if present.
|
||||
let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
|
||||
let remaining_length = usize::from(header.len.get());
|
||||
if remaining_length < address_length {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol length. not enough to fit requested IP addresses",
|
||||
));
|
||||
}
|
||||
drop(header);
|
||||
|
||||
while buf.len() < remaining_length as usize {
|
||||
while buf.len() < remaining_length {
|
||||
if read.read_buf(&mut buf).await? == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
@@ -164,22 +238,106 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
// - destination layer 3 address in network byte order
|
||||
// - source layer 4 address if any, in network byte order (port)
|
||||
// - destination layer 4 address if any, in network byte order (port)
|
||||
let addresses = buf.split_to(remaining_length as usize);
|
||||
let mut header = buf.split_to(remaining_length);
|
||||
let socket = match address_length {
|
||||
12 => {
|
||||
let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
|
||||
let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
let addr = header
|
||||
.try_get::<ProxyProtocolV2HeaderV4>()
|
||||
.expect("we have verified that 12 bytes are in the buf");
|
||||
|
||||
Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get())))
|
||||
}
|
||||
36 => {
|
||||
let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
|
||||
let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
let addr = header
|
||||
.try_get::<ProxyProtocolV2HeaderV6>()
|
||||
.expect("we have verified that 36 bytes are in the buf");
|
||||
|
||||
Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get())))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Ok((ChainRW { inner: read, buf }, socket))
|
||||
let mut extra = None;
|
||||
|
||||
while let Some(mut tlv) = read_tlv(&mut header) {
|
||||
match Pp2Kind::from_repr(tlv.kind) {
|
||||
Some(Pp2Kind::Aws) => {
|
||||
if tlv.value.is_empty() {
|
||||
tracing::warn!("invalid aws tlv: no subtype");
|
||||
}
|
||||
let subtype = tlv.value.get_u8();
|
||||
match Pp2AwsType::from_repr(subtype) {
|
||||
Some(Pp2AwsType::VpceId) => {
|
||||
extra = Some(ConnectionInfoExtra::Aws { vpce_id: tlv.value });
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("unknown aws tlv: subtype={subtype}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Pp2Kind::Azure) => {
|
||||
if tlv.value.is_empty() {
|
||||
tracing::warn!("invalid azure tlv: no subtype");
|
||||
}
|
||||
let subtype = tlv.value.get_u8();
|
||||
match Pp2AzureType::from_repr(subtype) {
|
||||
Some(Pp2AzureType::PrivateEndpointLinkId) => {
|
||||
if tlv.value.len() != 4 {
|
||||
tracing::warn!("invalid azure link_id: {:?}", tlv.value);
|
||||
}
|
||||
extra = Some(ConnectionInfoExtra::Azure {
|
||||
link_id: tlv.value.get_u32_le(),
|
||||
});
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("unknown azure tlv: subtype={subtype}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(kind) => {
|
||||
tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
|
||||
}
|
||||
None => {
|
||||
tracing::debug!("unknown tlv: {tlv:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let conn_info = socket.map(|addr| ConnectionInfo { addr, extra });
|
||||
|
||||
Ok((ChainRW { inner: read, buf }, conn_info))
|
||||
}
|
||||
|
||||
#[derive(FromRepr, Debug, Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
enum Pp2Kind {
|
||||
// The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
|
||||
// we don't use these but it would be interesting to know what's available
|
||||
Alpn = 0x01,
|
||||
Authority = 0x02,
|
||||
Crc32C = 0x03,
|
||||
Noop = 0x04,
|
||||
UniqueId = 0x05,
|
||||
Ssl = 0x20,
|
||||
NetNs = 0x30,
|
||||
|
||||
/// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
|
||||
Aws = 0xEA,
|
||||
|
||||
/// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
|
||||
Azure = 0xEE,
|
||||
}
|
||||
|
||||
#[derive(FromRepr, Debug, Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
enum Pp2AwsType {
|
||||
VpceId = 0x01,
|
||||
}
|
||||
|
||||
#[derive(FromRepr, Debug, Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
enum Pp2AzureType {
|
||||
PrivateEndpointLinkId = 0x01,
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> AsyncRead for ChainRW<T> {
|
||||
@@ -216,6 +374,60 @@ impl<T: AsyncRead> ChainRW<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct TlvHeader {
|
||||
kind: u8,
|
||||
len: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Tlv {
|
||||
kind: u8,
|
||||
value: Bytes,
|
||||
}
|
||||
|
||||
fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
|
||||
let tlv_header = b.try_get::<TlvHeader>().ok()?;
|
||||
let len = usize::from(tlv_header.len.get());
|
||||
if b.len() < len {
|
||||
return None;
|
||||
}
|
||||
let value = b.split_to(len).freeze();
|
||||
Some(Tlv {
|
||||
kind: tlv_header.kind,
|
||||
value,
|
||||
})
|
||||
}
|
||||
|
||||
trait BufExt: Sized {
|
||||
fn try_get<T: zerocopy::FromBytes>(&mut self)
|
||||
-> Result<T, zerocopy::error::SizeError<Self, T>>;
|
||||
|
||||
// fn peek<T: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable>(
|
||||
// &self,
|
||||
// ) -> Option<&T>;
|
||||
}
|
||||
|
||||
impl BufExt for BytesMut {
|
||||
fn try_get<T: zerocopy::FromBytes>(
|
||||
&mut self,
|
||||
) -> Result<T, zerocopy::error::SizeError<Self, T>> {
|
||||
let len = size_of::<T>();
|
||||
// this will error in the read_from_bytes if the buf is too small
|
||||
let len = usize::min(len, self.len());
|
||||
let buf = self.split_to(len);
|
||||
|
||||
T::read_from_bytes(&buf).map_err(|e| e.map_src(|_| buf.clone()))
|
||||
}
|
||||
|
||||
// fn peek<T: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable>(
|
||||
// &self,
|
||||
// ) -> Option<&T> {
|
||||
// T::ref_from_prefix(self).ok().map(|(t, _)| t)
|
||||
// }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio::io::AsyncReadExt;
|
||||
@@ -242,7 +454,7 @@ mod tests {
|
||||
|
||||
let extra_data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -250,7 +462,9 @@ mod tests {
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
|
||||
|
||||
let info = info.unwrap();
|
||||
assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -273,7 +487,7 @@ mod tests {
|
||||
|
||||
let extra_data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -281,9 +495,11 @@ mod tests {
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
|
||||
let info = info.unwrap();
|
||||
assert_eq!(
|
||||
addr,
|
||||
Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
|
||||
info.addr,
|
||||
([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -291,30 +507,31 @@ mod tests {
|
||||
async fn test_invalid() {
|
||||
let data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(addr, None);
|
||||
assert_eq!(info, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_short() {
|
||||
let data = [0x55; 10];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(addr, None);
|
||||
assert_eq!(info, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_large_tlv() {
|
||||
let tlv = vec![0x55; 32768];
|
||||
let len = (12 + tlv.len() as u16).to_be_bytes();
|
||||
let tlv_len = (tlv.len() as u16).to_be_bytes();
|
||||
let len = (12 + 3 + tlv.len() as u16).to_be_bytes();
|
||||
|
||||
let header = super::HEADER
|
||||
// Proxy command, Inet << 4 | Stream
|
||||
@@ -330,11 +547,13 @@ mod tests {
|
||||
// dst port
|
||||
.chain([1, 1].as_slice())
|
||||
// TLV
|
||||
.chain([255].as_slice())
|
||||
.chain(tlv_len.as_slice())
|
||||
.chain(tlv.as_slice());
|
||||
|
||||
let extra_data = [0xaa; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -342,6 +561,8 @@ mod tests {
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
|
||||
|
||||
let info = info.unwrap();
|
||||
assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::read_proxy_protocol;
|
||||
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
|
||||
use crate::proxy::handshake::{handshake, HandshakeData};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
@@ -87,7 +87,7 @@ pub async fn task_main(
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
|
||||
let (socket, conn_info) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
@@ -100,8 +100,8 @@ pub async fn task_main(
|
||||
warn!("proxy protocol header not supported");
|
||||
return;
|
||||
}
|
||||
Ok((socket, Some(addr))) => (socket, addr.ip()),
|
||||
Ok((socket, None)) => (socket, peer_addr.ip()),
|
||||
Ok((socket, Some(info))) => (socket, info),
|
||||
Ok((socket, None)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
@@ -114,7 +114,7 @@ pub async fn task_main(
|
||||
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
@@ -44,10 +44,10 @@ use tracing::{info, warn, Instrument};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::{read_proxy_protocol, ChainRW};
|
||||
use crate::protocol2::{read_proxy_protocol, ChainRW, ConnectionInfo};
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
@@ -180,7 +180,7 @@ pub async fn task_main(
|
||||
peer_addr,
|
||||
))
|
||||
.await;
|
||||
let Some((conn, peer_addr)) = startup_result else {
|
||||
let Some((conn, conn_info)) = startup_result else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -192,7 +192,7 @@ pub async fn task_main(
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
conn,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
session_id,
|
||||
))
|
||||
.await;
|
||||
@@ -240,7 +240,7 @@ async fn connection_startup(
|
||||
session_id: uuid::Uuid,
|
||||
conn: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> Option<(AsyncRW, IpAddr)> {
|
||||
) -> Option<(AsyncRW, ConnectionInfo)> {
|
||||
// handle PROXY protocol
|
||||
let (conn, peer) = match read_proxy_protocol(conn).await {
|
||||
Ok(c) => c,
|
||||
@@ -250,17 +250,32 @@ async fn connection_startup(
|
||||
}
|
||||
};
|
||||
|
||||
let peer_addr = peer.unwrap_or(peer_addr).ip();
|
||||
let has_private_peer_addr = match peer_addr {
|
||||
let conn_info = match peer {
|
||||
None if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
tracing::warn!("missing required proxy protocol header");
|
||||
return None;
|
||||
}
|
||||
Some(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
tracing::warn!("proxy protocol header not supported");
|
||||
return None;
|
||||
}
|
||||
Some(info) => info,
|
||||
None => ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
};
|
||||
|
||||
let has_private_peer_addr = match conn_info.addr.ip() {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
IpAddr::V6(_) => false,
|
||||
};
|
||||
info!(?session_id, %peer_addr, "accepted new TCP connection");
|
||||
info!(?session_id, %conn_info, "accepted new TCP connection");
|
||||
|
||||
// try upgrade to TLS, but with a timeout.
|
||||
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
|
||||
Ok(Ok(conn)) => {
|
||||
info!(?session_id, %peer_addr, "accepted new TLS connection");
|
||||
info!(?session_id, %conn_info, "accepted new TLS connection");
|
||||
conn
|
||||
}
|
||||
// The handshake failed
|
||||
@@ -268,7 +283,7 @@ async fn connection_startup(
|
||||
if !has_private_peer_addr {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
|
||||
return None;
|
||||
}
|
||||
// The handshake timed out
|
||||
@@ -276,12 +291,12 @@ async fn connection_startup(
|
||||
if !has_private_peer_addr {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
Some((conn, peer_addr))
|
||||
Some((conn, conn_info))
|
||||
}
|
||||
|
||||
/// Handles HTTP connection
|
||||
@@ -297,7 +312,7 @@ async fn connection_handler(
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
conn: AsyncRW,
|
||||
peer_addr: IpAddr,
|
||||
conn_info: ConnectionInfo,
|
||||
session_id: uuid::Uuid,
|
||||
) {
|
||||
let session_id = AtomicTake::new(session_id);
|
||||
@@ -306,6 +321,7 @@ async fn connection_handler(
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let conn_info2 = conn_info.clone();
|
||||
let server = Builder::new(TokioExecutor::new());
|
||||
let conn = server.serve_connection_with_upgrades(
|
||||
hyper_util::rt::TokioIo::new(conn),
|
||||
@@ -340,7 +356,7 @@ async fn connection_handler(
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info2.clone(),
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
)
|
||||
@@ -365,7 +381,7 @@ async fn connection_handler(
|
||||
// On cancellation, trigger the HTTP connection handler to shut down.
|
||||
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
|
||||
Either::Left((_cancelled, mut conn)) => {
|
||||
tracing::debug!(%peer_addr, "cancelling connection");
|
||||
tracing::debug!(%conn_info, "cancelling connection");
|
||||
conn.as_mut().graceful_shutdown();
|
||||
conn.await
|
||||
}
|
||||
@@ -373,8 +389,8 @@ async fn connection_handler(
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
|
||||
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
|
||||
Ok(()) => tracing::info!(%conn_info, "HTTP connection closed"),
|
||||
Err(e) => tracing::warn!(%conn_info, "HTTP connection error {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -386,7 +402,7 @@ async fn request_handler(
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
conn_info: ConnectionInfo,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -404,7 +420,7 @@ async fn request_handler(
|
||||
{
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Ws,
|
||||
&config.region,
|
||||
);
|
||||
@@ -439,7 +455,7 @@ async fn request_handler(
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Http,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
@@ -46,16 +46,10 @@ pub struct SafekeeperPostgresHandler {
|
||||
/// Parsed Postgres command.
|
||||
enum SafekeeperPostgresCommand {
|
||||
StartWalPush,
|
||||
StartReplication {
|
||||
start_lsn: Lsn,
|
||||
term: Option<Term>,
|
||||
interpret_wal: bool,
|
||||
},
|
||||
StartReplication { start_lsn: Lsn, term: Option<Term> },
|
||||
IdentifySystem,
|
||||
TimelineStatus,
|
||||
JSONCtrl {
|
||||
cmd: AppendLogicalMessage,
|
||||
},
|
||||
JSONCtrl { cmd: AppendLogicalMessage },
|
||||
}
|
||||
|
||||
fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
@@ -64,7 +58,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
} else if cmd.starts_with("START_REPLICATION") {
|
||||
let re = Regex::new(
|
||||
// We follow postgres START_REPLICATION LOGICAL options to pass term.
|
||||
r"START_REPLICATION(?: SLOT [^ ]+)?(?: PHYSICAL)? ([[:xdigit:]]+/[[:xdigit:]]+)(?: \(term='(\d+)'\))?( interpret_wal)",
|
||||
r"START_REPLICATION(?: SLOT [^ ]+)?(?: PHYSICAL)? ([[:xdigit:]]+/[[:xdigit:]]+)(?: \(term='(\d+)'\))?",
|
||||
)
|
||||
.unwrap();
|
||||
let caps = re
|
||||
@@ -77,12 +71,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let interpret_wal = caps.get(3).is_some();
|
||||
Ok(SafekeeperPostgresCommand::StartReplication {
|
||||
start_lsn,
|
||||
term,
|
||||
interpret_wal,
|
||||
})
|
||||
Ok(SafekeeperPostgresCommand::StartReplication { start_lsn, term })
|
||||
} else if cmd.starts_with("IDENTIFY_SYSTEM") {
|
||||
Ok(SafekeeperPostgresCommand::IdentifySystem)
|
||||
} else if cmd.starts_with("TIMELINE_STATUS") {
|
||||
@@ -241,12 +230,8 @@ impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
|
||||
.instrument(info_span!("WAL receiver"))
|
||||
.await
|
||||
}
|
||||
SafekeeperPostgresCommand::StartReplication {
|
||||
start_lsn,
|
||||
term,
|
||||
interpret_wal,
|
||||
} => {
|
||||
self.handle_start_replication(pgb, start_lsn, term, interpret_wal)
|
||||
SafekeeperPostgresCommand::StartReplication { start_lsn, term } => {
|
||||
self.handle_start_replication(pgb, start_lsn, term)
|
||||
.instrument(info_span!("WAL sender"))
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -380,21 +380,17 @@ impl SafekeeperPostgresHandler {
|
||||
/// Wrapper around handle_start_replication_guts handling result. Error is
|
||||
/// handled here while we're still in walsender ttid span; with API
|
||||
/// extension, this can probably be moved into postgres_backend.
|
||||
///
|
||||
/// If interpret_wal is true, change the protocol to send custom Neon InterpretedWalRecord
|
||||
/// instead of XLogData, for ingestion by Pageservers.
|
||||
pub async fn handle_start_replication<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
start_pos: Lsn,
|
||||
term: Option<Term>,
|
||||
interpret_wal: bool,
|
||||
) -> Result<(), QueryError> {
|
||||
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
|
||||
let residence_guard = tli.wal_residence_guard().await?;
|
||||
|
||||
if let Err(end) = self
|
||||
.handle_start_replication_guts(pgb, start_pos, term, interpret_wal, residence_guard)
|
||||
.handle_start_replication_guts(pgb, start_pos, term, residence_guard)
|
||||
.await
|
||||
{
|
||||
let info = tli.get_safekeeper_info(&self.conf).await;
|
||||
@@ -411,7 +407,6 @@ impl SafekeeperPostgresHandler {
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
start_pos: Lsn,
|
||||
term: Option<Term>,
|
||||
interpret_wal: bool,
|
||||
tli: WalResidentTimeline,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
let appname = self.appname.clone();
|
||||
@@ -469,7 +464,6 @@ impl SafekeeperPostgresHandler {
|
||||
start_pos,
|
||||
end_pos,
|
||||
term,
|
||||
interpret_wal,
|
||||
end_watch,
|
||||
ws_guard: ws_guard.clone(),
|
||||
wal_reader,
|
||||
@@ -549,8 +543,6 @@ struct WalSender<'a, IO> {
|
||||
/// in. Streaming is stopped if local term changes to a different (higher)
|
||||
/// value.
|
||||
term: Option<Term>,
|
||||
/// If true, decode and filter WAL records and send InterpretedWalRecord instead of XLogRecord.
|
||||
interpret_wal: bool,
|
||||
/// Watch channel receiver to learn end of available WAL (and wait for its advancement).
|
||||
end_watch: EndWatch,
|
||||
ws_guard: Arc<WalSenderGuard>,
|
||||
@@ -579,49 +571,45 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> WalSender<'_, IO> {
|
||||
"nothing to send after waiting for WAL"
|
||||
);
|
||||
|
||||
let (msg, send_size) = if self.interpret_wal {
|
||||
(BeMessage::NeonInterpretedWalRecord(&[]), 0) // TODO
|
||||
// try to send as much as available, capped by MAX_SEND_SIZE
|
||||
let mut chunk_end_pos = self.start_pos + MAX_SEND_SIZE as u64;
|
||||
// if we went behind available WAL, back off
|
||||
if chunk_end_pos >= self.end_pos {
|
||||
chunk_end_pos = self.end_pos;
|
||||
} else {
|
||||
// try to send as much as available, capped by MAX_SEND_SIZE
|
||||
let mut chunk_end_pos = self.start_pos + MAX_SEND_SIZE as u64;
|
||||
// if we went behind available WAL, back off
|
||||
if chunk_end_pos >= self.end_pos {
|
||||
chunk_end_pos = self.end_pos;
|
||||
// If sending not up to end pos, round down to page boundary to
|
||||
// avoid breaking WAL record not at page boundary, as protocol
|
||||
// demands. See walsender.c (XLogSendPhysical).
|
||||
chunk_end_pos = chunk_end_pos
|
||||
.checked_sub(chunk_end_pos.block_offset())
|
||||
.unwrap();
|
||||
}
|
||||
let send_size = (chunk_end_pos.0 - self.start_pos.0) as usize;
|
||||
let send_buf = &mut self.send_buf[..send_size];
|
||||
let send_size: usize;
|
||||
{
|
||||
// If uncommitted part is being pulled, check that the term is
|
||||
// still the expected one.
|
||||
let _term_guard = if let Some(t) = self.term {
|
||||
Some(self.tli.acquire_term(t).await?)
|
||||
} else {
|
||||
// If sending not up to end pos, round down to page boundary to
|
||||
// avoid breaking WAL record not at page boundary, as protocol
|
||||
// demands. See walsender.c (XLogSendPhysical).
|
||||
chunk_end_pos = chunk_end_pos
|
||||
.checked_sub(chunk_end_pos.block_offset())
|
||||
.unwrap();
|
||||
}
|
||||
let send_size = (chunk_end_pos.0 - self.start_pos.0) as usize;
|
||||
let send_buf = &mut self.send_buf[..send_size];
|
||||
let send_size: usize;
|
||||
{
|
||||
// If uncommitted part is being pulled, check that the term is
|
||||
// still the expected one.
|
||||
let _term_guard = if let Some(t) = self.term {
|
||||
Some(self.tli.acquire_term(t).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Read WAL into buffer. send_size can be additionally capped to
|
||||
// segment boundary here.
|
||||
send_size = self.wal_reader.read(send_buf).await?
|
||||
None
|
||||
};
|
||||
let send_buf = &send_buf[..send_size];
|
||||
let msg = BeMessage::XLogData(XLogDataBody {
|
||||
// Read WAL into buffer. send_size can be additionally capped to
|
||||
// segment boundary here.
|
||||
send_size = self.wal_reader.read(send_buf).await?
|
||||
};
|
||||
let send_buf = &send_buf[..send_size];
|
||||
|
||||
// and send it
|
||||
self.pgb
|
||||
.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: self.start_pos.0,
|
||||
wal_end: self.end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
data: send_buf,
|
||||
});
|
||||
(msg, send_size)
|
||||
};
|
||||
|
||||
// and send it
|
||||
self.pgb.write_message(&msg).await?;
|
||||
}))
|
||||
.await?;
|
||||
|
||||
if let Some(appname) = &self.appname {
|
||||
if appname == "replica" {
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import closing
|
||||
|
||||
import pytest
|
||||
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
|
||||
from fixtures.common_types import Lsn, TenantShardId
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import (
|
||||
NeonEnvBuilder,
|
||||
tenant_get_shards,
|
||||
wait_for_last_flush_lsn,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.timeout(600)
|
||||
@pytest.mark.parametrize("shard_count", [1, 8, 32])
|
||||
def test_sharded_ingest(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
zenbenchmark: NeonBenchmarker,
|
||||
shard_count: int,
|
||||
):
|
||||
"""
|
||||
Benchmarks sharded ingestion throughput, by ingesting a large amount of WAL into a Safekeeper
|
||||
and fanning out to a large number of shards on dedicated Pageservers. Comparing the base case
|
||||
(shard_count=1) to the sharded case indicates the overhead of sharding.
|
||||
"""
|
||||
|
||||
ROW_COUNT = 100_000_000 # about 7 GB of WAL
|
||||
|
||||
neon_env_builder.num_pageservers = shard_count
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
# Create a sharded tenant and timeline, and migrate it to the respective pageservers. Ensure
|
||||
# the storage controller doesn't mess with shard placements.
|
||||
#
|
||||
# TODO: there should be a way to disable storage controller background reconciliations.
|
||||
# Currently, disabling reconciliation also disables foreground operations.
|
||||
tenant_id, timeline_id = env.create_tenant(shard_count=shard_count)
|
||||
|
||||
for shard_number in range(0, shard_count):
|
||||
tenant_shard_id = TenantShardId(tenant_id, shard_number, shard_count)
|
||||
pageserver_id = shard_number + 1
|
||||
env.storage_controller.tenant_shard_migrate(tenant_shard_id, pageserver_id)
|
||||
|
||||
shards = tenant_get_shards(env, tenant_id)
|
||||
env.storage_controller.reconcile_until_idle()
|
||||
assert tenant_get_shards(env, tenant_id) == shards, "shards moved"
|
||||
|
||||
# Start the endpoint.
|
||||
endpoint = env.endpoints.create_start("main", tenant_id=tenant_id)
|
||||
start_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0])
|
||||
|
||||
# Ingest data and measure WAL volume and duration.
|
||||
with closing(endpoint.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
log.info("Ingesting data")
|
||||
cur.execute("set statement_timeout = 0")
|
||||
cur.execute("create table huge (i int, j int)")
|
||||
|
||||
with zenbenchmark.record_duration("pageserver_ingest"):
|
||||
with zenbenchmark.record_duration("wal_ingest"):
|
||||
cur.execute(f"insert into huge values (generate_series(1, {ROW_COUNT}), 0)")
|
||||
|
||||
wait_for_last_flush_lsn(env, endpoint, tenant_id, timeline_id)
|
||||
|
||||
end_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0])
|
||||
wal_written_mb = round((end_lsn - start_lsn) / (1024 * 1024))
|
||||
zenbenchmark.record("wal_written", wal_written_mb, "MB", MetricReport.TEST_PARAM)
|
||||
|
||||
assert tenant_get_shards(env, tenant_id) == shards, "shards moved"
|
||||
Reference in New Issue
Block a user