mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 07:00:38 +00:00
Compare commits
13 Commits
tristan957
...
vlad/safek
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3ef315041 | ||
|
|
f0044b8651 | ||
|
|
b7ff993df6 | ||
|
|
5d096f127e | ||
|
|
70cdd56294 | ||
|
|
4dfa0c221b | ||
|
|
bdd492b1d8 | ||
|
|
5d8284c7fe | ||
|
|
ebc43efebc | ||
|
|
754d2950a3 | ||
|
|
fcde40d600 | ||
|
|
babfeb70ba | ||
|
|
2f1a56c8f9 |
101
Cargo.lock
generated
101
Cargo.lock
generated
@@ -1245,7 +1245,7 @@ dependencies = [
|
||||
"tar",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
@@ -1351,7 +1351,7 @@ dependencies = [
|
||||
"storage_broker",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-util",
|
||||
"toml",
|
||||
"toml_edit",
|
||||
@@ -3620,8 +3620,8 @@ dependencies = [
|
||||
"pageserver_compaction",
|
||||
"pin-project-lite",
|
||||
"postgres",
|
||||
"postgres-protocol",
|
||||
"postgres-types",
|
||||
"postgres-protocol 0.6.4",
|
||||
"postgres-types 0.2.4",
|
||||
"postgres_backend",
|
||||
"postgres_connection",
|
||||
"postgres_ffi",
|
||||
@@ -3649,7 +3649,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-epoll-uring",
|
||||
"tokio-io-timeout",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-stream",
|
||||
"tokio-tar",
|
||||
"tokio-util",
|
||||
@@ -3707,7 +3707,7 @@ dependencies = [
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"utils",
|
||||
@@ -4006,14 +4006,31 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "postgres"
|
||||
version = "0.19.4"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=20031d7a9ee1addeae6e0968e3899ae6bf01cee2#20031d7a9ee1addeae6e0968e3899ae6bf01cee2"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.4"
|
||||
dependencies = [
|
||||
"base64 0.20.0",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"hmac",
|
||||
"lazy_static",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"rand 0.8.5",
|
||||
"sha2",
|
||||
"stringprep",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4035,6 +4052,17 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres-types"
|
||||
version = "0.2.4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol 0.6.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres-types"
|
||||
version = "0.2.4"
|
||||
@@ -4042,7 +4070,7 @@ source = "git+https://github.com/neondatabase/rust-postgres.git?rev=20031d7a9ee1
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol",
|
||||
"postgres-protocol 0.6.4 (git+https://github.com/neondatabase/rust-postgres.git?rev=20031d7a9ee1addeae6e0968e3899ae6bf01cee2)",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
@@ -4060,7 +4088,7 @@ dependencies = [
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls 0.26.0",
|
||||
"tokio-util",
|
||||
@@ -4075,7 +4103,7 @@ dependencies = [
|
||||
"itertools 0.10.5",
|
||||
"once_cell",
|
||||
"postgres",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"url",
|
||||
]
|
||||
|
||||
@@ -4127,7 +4155,7 @@ dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"itertools 0.10.5",
|
||||
"postgres-protocol",
|
||||
"postgres-protocol 0.6.4",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"thiserror",
|
||||
@@ -4313,7 +4341,7 @@ dependencies = [
|
||||
"parquet_derive",
|
||||
"pbkdf2",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
"postgres-protocol 0.6.4",
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
"prometheus",
|
||||
@@ -4348,7 +4376,7 @@ dependencies = [
|
||||
"tikv-jemalloc-ctl",
|
||||
"tikv-jemallocator",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls 0.26.0",
|
||||
"tokio-tungstenite",
|
||||
@@ -4365,6 +4393,7 @@ dependencies = [
|
||||
"walkdir",
|
||||
"workspace_hack",
|
||||
"x509-parser",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5154,9 +5183,10 @@ dependencies = [
|
||||
"hyper 0.14.30",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
"pageserver_api",
|
||||
"parking_lot 0.12.1",
|
||||
"postgres",
|
||||
"postgres-protocol",
|
||||
"postgres-protocol 0.6.4",
|
||||
"postgres_backend",
|
||||
"postgres_ffi",
|
||||
"pq_proto",
|
||||
@@ -5176,7 +5206,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-io-timeout",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-stream",
|
||||
"tokio-tar",
|
||||
"tokio-util",
|
||||
@@ -5184,6 +5214,7 @@ dependencies = [
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"utils",
|
||||
"wal_decoder",
|
||||
"walproposer",
|
||||
"workspace_hack",
|
||||
]
|
||||
@@ -5820,7 +5851,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"storage_controller_client",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
@@ -6217,6 +6248,28 @@ dependencies = [
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-postgres"
|
||||
version = "0.7.7"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"log",
|
||||
"parking_lot 0.12.1",
|
||||
"percent-encoding",
|
||||
"phf",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol 0.6.4",
|
||||
"postgres-types 0.2.4",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-postgres"
|
||||
version = "0.7.7"
|
||||
@@ -6233,8 +6286,8 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
"phf",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
"postgres-types",
|
||||
"postgres-protocol 0.6.4 (git+https://github.com/neondatabase/rust-postgres.git?rev=20031d7a9ee1addeae6e0968e3899ae6bf01cee2)",
|
||||
"postgres-types 0.2.4 (git+https://github.com/neondatabase/rust-postgres.git?rev=20031d7a9ee1addeae6e0968e3899ae6bf01cee2)",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
@@ -6249,7 +6302,7 @@ dependencies = [
|
||||
"ring",
|
||||
"rustls 0.23.16",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-rustls 0.26.0",
|
||||
"x509-certificate",
|
||||
]
|
||||
@@ -6832,7 +6885,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"sysinfo",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
@@ -7339,7 +7392,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"parquet",
|
||||
"postgres-types",
|
||||
"postgres-types 0.2.4 (git+https://github.com/neondatabase/rust-postgres.git?rev=20031d7a9ee1addeae6e0968e3899ae6bf01cee2)",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"prost",
|
||||
@@ -7364,7 +7417,7 @@ dependencies = [
|
||||
"time",
|
||||
"time-macros",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres 0.7.7 (git+https://github.com/neondatabase/rust-postgres.git?rev=20031d7a9ee1addeae6e0968e3899ae6bf01cee2)",
|
||||
"tokio-rustls 0.26.0",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
@@ -7374,6 +7427,7 @@ dependencies = [
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"url",
|
||||
"zerocopy",
|
||||
"zeroize",
|
||||
"zstd",
|
||||
"zstd-safe",
|
||||
@@ -7446,6 +7500,7 @@ version = "0.7.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
|
||||
16
Cargo.toml
16
Cargo.toml
@@ -196,6 +196,7 @@ walkdir = "2.3.2"
|
||||
rustls-native-certs = "0.8"
|
||||
x509-parser = "0.16"
|
||||
whoami = "1.5.1"
|
||||
zerocopy = { version = "0.7", features = ["derive"] }
|
||||
|
||||
## TODO replace this with tracing
|
||||
env_logger = "0.10"
|
||||
@@ -213,10 +214,14 @@ log = "0.4"
|
||||
#
|
||||
# When those proxy changes are re-applied (see PR #8747), we can switch using
|
||||
# the tip of the 'neon' branch again.
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
# postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
# postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
# postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
# tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
postgres = { path = "../../.cargo/git/checkouts/rust-postgres-e2c00088c8e2b112/20031d7/postgres" }
|
||||
postgres-protocol = { path = "../../.cargo/git/checkouts/rust-postgres-e2c00088c8e2b112/20031d7/postgres-protocol" }
|
||||
postgres-types = { path = "../../.cargo/git/checkouts/rust-postgres-e2c00088c8e2b112/20031d7/postgres-types" }
|
||||
tokio-postgres = { path = "../../.cargo/git/checkouts/rust-postgres-e2c00088c8e2b112/20031d7/tokio-postgres" }
|
||||
|
||||
## Local libraries
|
||||
compute_api = { version = "0.1", path = "./libs/compute_api/" }
|
||||
@@ -254,7 +259,8 @@ tonic-build = "0.12"
|
||||
[patch.crates-io]
|
||||
|
||||
# Needed to get `tokio-postgres-rustls` to depend on our fork.
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
tokio-postgres = { path = "../../.cargo/git/checkouts/rust-postgres-e2c00088c8e2b112/20031d7/tokio-postgres" }
|
||||
# tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2" }
|
||||
|
||||
################# Binary contents sections
|
||||
|
||||
|
||||
@@ -84,9 +84,8 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
let status = compute.get_status();
|
||||
if status != ComputeStatus::Running {
|
||||
let msg = format!(
|
||||
"invalid compute status for check_writability request: {}, need {}",
|
||||
status,
|
||||
ComputeStatus::Running
|
||||
"invalid compute status for check_writability request: {:?}",
|
||||
status
|
||||
);
|
||||
error!(msg);
|
||||
return Response::new(Body::from(msg));
|
||||
@@ -107,9 +106,8 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
let status = compute.get_status();
|
||||
if status != ComputeStatus::Running {
|
||||
let msg = format!(
|
||||
"invalid compute status for extensions request: {}, need {}",
|
||||
status,
|
||||
ComputeStatus::Running
|
||||
"invalid compute status for extensions request: {:?}",
|
||||
status
|
||||
);
|
||||
error!(msg);
|
||||
return render_json_error(&msg, StatusCode::PRECONDITION_FAILED);
|
||||
@@ -207,9 +205,8 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
let status = compute.get_status();
|
||||
if status != ComputeStatus::Running {
|
||||
let msg = format!(
|
||||
"invalid compute status for set_role_grants request: {}, need {}",
|
||||
status,
|
||||
ComputeStatus::Running
|
||||
"invalid compute status for set_role_grants request: {:?}",
|
||||
status
|
||||
);
|
||||
error!(msg);
|
||||
return render_json_error(&msg, StatusCode::PRECONDITION_FAILED);
|
||||
@@ -253,9 +250,8 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
let status = compute.get_status();
|
||||
if status != ComputeStatus::Running {
|
||||
let msg = format!(
|
||||
"invalid compute status for extensions request: {}, need {}",
|
||||
status,
|
||||
ComputeStatus::Running
|
||||
"invalid compute status for extensions request: {:?}",
|
||||
status
|
||||
);
|
||||
error!(msg);
|
||||
return Response::new(Body::from(msg));
|
||||
@@ -387,12 +383,10 @@ async fn handle_configure_request(
|
||||
// ```
|
||||
{
|
||||
let mut state = compute.state.lock().unwrap();
|
||||
if !matches!(state.status, ComputeStatus::Empty | ComputeStatus::Running) {
|
||||
if state.status != ComputeStatus::Empty && state.status != ComputeStatus::Running {
|
||||
let msg = format!(
|
||||
"invalid compute status for configuration request: {}, cannot be {} or {}",
|
||||
state.status,
|
||||
ComputeStatus::Empty,
|
||||
ComputeStatus::Running
|
||||
"invalid compute status for configuration request: {:?}",
|
||||
state.status.clone()
|
||||
);
|
||||
return Err((msg, StatusCode::PRECONDITION_FAILED));
|
||||
}
|
||||
@@ -468,12 +462,10 @@ async fn handle_terminate_request(compute: &Arc<ComputeNode>) -> Result<(), (Str
|
||||
if state.status == ComputeStatus::Terminated {
|
||||
return Ok(());
|
||||
}
|
||||
if !matches!(state.status, ComputeStatus::Empty | ComputeStatus::Running) {
|
||||
if state.status != ComputeStatus::Empty && state.status != ComputeStatus::Running {
|
||||
let msg = format!(
|
||||
"invalid compute status for termination request: {}, cannot be {} or {}",
|
||||
state.status,
|
||||
ComputeStatus::Empty,
|
||||
ComputeStatus::Running,
|
||||
"invalid compute status for termination request: {}",
|
||||
state.status
|
||||
);
|
||||
return Err((msg, StatusCode::PRECONDITION_FAILED));
|
||||
}
|
||||
|
||||
@@ -66,15 +66,14 @@ pub enum ComputeStatus {
|
||||
|
||||
impl Display for ComputeStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// snake_case matches the serde implementation
|
||||
match self {
|
||||
ComputeStatus::Empty => f.write_str("empty"),
|
||||
ComputeStatus::ConfigurationPending => f.write_str("configuration_pending"),
|
||||
ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"),
|
||||
ComputeStatus::Init => f.write_str("init"),
|
||||
ComputeStatus::Running => f.write_str("running"),
|
||||
ComputeStatus::Configuration => f.write_str("configuration"),
|
||||
ComputeStatus::Failed => f.write_str("failed"),
|
||||
ComputeStatus::TerminationPending => f.write_str("termination_pending"),
|
||||
ComputeStatus::TerminationPending => f.write_str("termination-pending"),
|
||||
ComputeStatus::Terminated => f.write_str("terminated"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ pub struct Key {
|
||||
|
||||
/// When working with large numbers of Keys in-memory, it is more efficient to handle them as i128 than as
|
||||
/// a struct of fields.
|
||||
#[derive(Clone, Copy, Hash, PartialEq, Eq, Ord, PartialOrd)]
|
||||
#[derive(Clone, Copy, Hash, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize, Debug)]
|
||||
pub struct CompactKey(i128);
|
||||
|
||||
/// The storage key size.
|
||||
|
||||
@@ -24,7 +24,7 @@ use postgres_ffi::Oid;
|
||||
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
|
||||
// Then we could replace the custom Ord and PartialOrd implementations below with
|
||||
// deriving them. This will require changes in walredoproc.c.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize)]
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RelTag {
|
||||
pub forknum: u8,
|
||||
pub spcnode: Oid,
|
||||
|
||||
@@ -16,7 +16,7 @@ use utils::bin_ser::DeserializeError;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct XlMultiXactCreate {
|
||||
pub mid: MultiXactId,
|
||||
/* new MultiXact's ID */
|
||||
@@ -46,7 +46,7 @@ impl XlMultiXactCreate {
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct XlMultiXactTruncate {
|
||||
pub oldest_multi_db: Oid,
|
||||
/* to-be-truncated range of multixact offsets */
|
||||
@@ -72,7 +72,7 @@ impl XlMultiXactTruncate {
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct XlRelmapUpdate {
|
||||
pub dbid: Oid, /* database ID, or 0 for shared map */
|
||||
pub tsid: Oid, /* database's tablespace, or pg_global */
|
||||
@@ -90,7 +90,7 @@ impl XlRelmapUpdate {
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct XlReploriginDrop {
|
||||
pub node_id: RepOriginId,
|
||||
}
|
||||
@@ -104,7 +104,7 @@ impl XlReploriginDrop {
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct XlReploriginSet {
|
||||
pub remote_lsn: Lsn,
|
||||
pub node_id: RepOriginId,
|
||||
@@ -120,7 +120,7 @@ impl XlReploriginSet {
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RelFileNode {
|
||||
pub spcnode: Oid, /* tablespace */
|
||||
pub dbnode: Oid, /* database */
|
||||
@@ -911,7 +911,7 @@ impl XlSmgrCreate {
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct XlSmgrTruncate {
|
||||
pub blkno: BlockNumber,
|
||||
pub rnode: RelFileNode,
|
||||
@@ -984,7 +984,7 @@ impl XlDropDatabase {
|
||||
/// xl_xact_parsed_abort structs in PostgreSQL, but we use the same
|
||||
/// struct for commits and aborts.
|
||||
///
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct XlXactParsedRecord {
|
||||
pub xid: TransactionId,
|
||||
pub info: u8,
|
||||
|
||||
@@ -13,3 +13,4 @@ rand.workspace = true
|
||||
tokio = { workspace = true, features = ["io-util"] }
|
||||
thiserror.workspace = true
|
||||
serde.workspace = true
|
||||
# wal_decoder.workspace = true
|
||||
|
||||
@@ -562,6 +562,7 @@ pub enum BeMessage<'a> {
|
||||
options: &'a [&'a str],
|
||||
},
|
||||
KeepAlive(WalSndKeepAlive),
|
||||
InterpretedWalRecord(InterpretedWalRecordBody<'a>),
|
||||
}
|
||||
|
||||
/// Common shorthands.
|
||||
@@ -665,6 +666,12 @@ pub struct XLogDataBody<'a> {
|
||||
pub data: &'a [u8],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct InterpretedWalRecordBody<'a> {
|
||||
pub wal_end: u64,
|
||||
pub data: &'a [u8],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WalSndKeepAlive {
|
||||
pub wal_end: u64, // current end of WAL on the server
|
||||
@@ -996,6 +1003,15 @@ impl BeMessage<'_> {
|
||||
Ok(())
|
||||
})?
|
||||
}
|
||||
|
||||
BeMessage::InterpretedWalRecord(rec) => {
|
||||
buf.put_u8(b'd'); // arbitrary?
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(b'0');
|
||||
buf.put_u64(rec.wal_end);
|
||||
buf.put_slice(rec.data);
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -7,29 +7,65 @@ use postgres_connection::{parse_host_port, PgConnectionConfig};
|
||||
|
||||
use crate::id::TenantTimelineId;
|
||||
|
||||
/// Protocol used for safekeeper recovery. This sends raw Postgres WAL.
|
||||
pub const POSTGRES_PROTO_VERSION: u8 = 0;
|
||||
/// Protocol used for safekeeper to pageserver communication.
|
||||
/// This sends interpreted WAL records for the pageserver to ingest
|
||||
/// and is shard-aware.
|
||||
pub const PAGESERVER_SAFEKEEPER_PROTO_VERSION: u8 = 1;
|
||||
|
||||
pub struct ConnectionConfigArgs<'a> {
|
||||
pub protocol_version: u8,
|
||||
|
||||
pub ttid: TenantTimelineId,
|
||||
pub shard_number: Option<u8>,
|
||||
pub shard_count: Option<u8>,
|
||||
pub shard_stripe_size: Option<u32>,
|
||||
|
||||
pub listen_pg_addr_str: &'a str,
|
||||
|
||||
pub auth_token: Option<&'a str>,
|
||||
pub availability_zone: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a> ConnectionConfigArgs<'a> {
|
||||
fn options(&'a self) -> Vec<String> {
|
||||
let mut options = vec![
|
||||
"-c".to_owned(),
|
||||
format!("timeline_id={}", self.ttid.timeline_id),
|
||||
format!("tenant_id={}", self.ttid.tenant_id),
|
||||
format!("protocol_version={}", self.protocol_version),
|
||||
];
|
||||
|
||||
if self.shard_number.is_some() {
|
||||
assert!(self.shard_count.is_some());
|
||||
assert!(self.shard_stripe_size.is_some());
|
||||
|
||||
options.push(format!("shard_count={}", self.shard_count.unwrap()));
|
||||
options.push(format!("shard_number={}", self.shard_number.unwrap()));
|
||||
options.push(format!(
|
||||
"shard_stripe_size={}",
|
||||
self.shard_stripe_size.unwrap()
|
||||
));
|
||||
}
|
||||
|
||||
options
|
||||
}
|
||||
}
|
||||
|
||||
/// Create client config for fetching WAL from safekeeper on particular timeline.
|
||||
/// listen_pg_addr_str is in form host:\[port\].
|
||||
pub fn wal_stream_connection_config(
|
||||
TenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
}: TenantTimelineId,
|
||||
listen_pg_addr_str: &str,
|
||||
auth_token: Option<&str>,
|
||||
availability_zone: Option<&str>,
|
||||
args: ConnectionConfigArgs,
|
||||
) -> anyhow::Result<PgConnectionConfig> {
|
||||
let (host, port) =
|
||||
parse_host_port(listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
|
||||
parse_host_port(args.listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
|
||||
let port = port.unwrap_or(5432);
|
||||
let mut connstr = PgConnectionConfig::new_host_port(host, port)
|
||||
.extend_options([
|
||||
"-c".to_owned(),
|
||||
format!("timeline_id={}", timeline_id),
|
||||
format!("tenant_id={}", tenant_id),
|
||||
])
|
||||
.set_password(auth_token.map(|s| s.to_owned()));
|
||||
.extend_options(args.options())
|
||||
.set_password(args.auth_token.map(|s| s.to_owned()));
|
||||
|
||||
if let Some(availability_zone) = availability_zone {
|
||||
if let Some(availability_zone) = args.availability_zone {
|
||||
connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[features]
|
||||
testing = []
|
||||
testing = ["pageserver_api/testing"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
//! raw bytes which represent a raw Postgres WAL record.
|
||||
|
||||
use crate::models::*;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pageserver_api::key::rel_block_to_key;
|
||||
use pageserver_api::record::NeonWalRecord;
|
||||
use crate::serialized_batch::SerializedValueBatch;
|
||||
use bytes::{Buf, Bytes};
|
||||
use pageserver_api::reltag::{RelTag, SlruKind};
|
||||
use pageserver_api::shard::ShardIdentity;
|
||||
use pageserver_api::value::Value;
|
||||
use postgres_ffi::pg_constants;
|
||||
use postgres_ffi::relfile_utils::VISIBILITYMAP_FORKNUM;
|
||||
use postgres_ffi::walrecord::*;
|
||||
use postgres_ffi::{page_is_new, page_set_lsn, pg_constants, BLCKSZ};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
impl InterpretedWalRecord {
|
||||
@@ -21,11 +19,12 @@ impl InterpretedWalRecord {
|
||||
pub fn from_bytes_filtered(
|
||||
buf: Bytes,
|
||||
shard: &ShardIdentity,
|
||||
lsn: Lsn,
|
||||
record_end_lsn: Lsn,
|
||||
pg_version: u32,
|
||||
) -> anyhow::Result<InterpretedWalRecord> {
|
||||
let mut decoded = DecodedWALRecord::default();
|
||||
decode_wal_record(buf, &mut decoded, pg_version)?;
|
||||
let xid = decoded.xl_xid;
|
||||
|
||||
let flush_uncommitted = if decoded.is_dbase_create_copy(pg_version) {
|
||||
FlushUncommittedRecords::Yes
|
||||
@@ -33,96 +32,20 @@ impl InterpretedWalRecord {
|
||||
FlushUncommittedRecords::No
|
||||
};
|
||||
|
||||
let metadata_record = MetadataRecord::from_decoded(&decoded, lsn, pg_version)?;
|
||||
|
||||
let mut blocks = Vec::default();
|
||||
for blk in decoded.blocks.iter() {
|
||||
let rel = RelTag {
|
||||
spcnode: blk.rnode_spcnode,
|
||||
dbnode: blk.rnode_dbnode,
|
||||
relnode: blk.rnode_relnode,
|
||||
forknum: blk.forknum,
|
||||
};
|
||||
|
||||
let key = rel_block_to_key(rel, blk.blkno);
|
||||
|
||||
if !key.is_valid_key_on_write_path() {
|
||||
anyhow::bail!("Unsupported key decoded at LSN {}: {}", lsn, key);
|
||||
}
|
||||
|
||||
let key_is_local = shard.is_key_local(&key);
|
||||
|
||||
tracing::debug!(
|
||||
lsn=%lsn,
|
||||
key=%key,
|
||||
"ingest: shard decision {}",
|
||||
if !key_is_local { "drop" } else { "keep" },
|
||||
);
|
||||
|
||||
if !key_is_local {
|
||||
if shard.is_shard_zero() {
|
||||
// Shard 0 tracks relation sizes. Although we will not store this block, we will observe
|
||||
// its blkno in case it implicitly extends a relation.
|
||||
blocks.push((key.to_compact(), None));
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// Instead of storing full-page-image WAL record,
|
||||
// it is better to store extracted image: we can skip wal-redo
|
||||
// in this case. Also some FPI records may contain multiple (up to 32) pages,
|
||||
// so them have to be copied multiple times.
|
||||
//
|
||||
let value = if blk.apply_image
|
||||
&& blk.has_image
|
||||
&& decoded.xl_rmid == pg_constants::RM_XLOG_ID
|
||||
&& (decoded.xl_info == pg_constants::XLOG_FPI
|
||||
|| decoded.xl_info == pg_constants::XLOG_FPI_FOR_HINT)
|
||||
// compression of WAL is not yet supported: fall back to storing the original WAL record
|
||||
&& !postgres_ffi::bkpimage_is_compressed(blk.bimg_info, pg_version)
|
||||
// do not materialize null pages because them most likely be soon replaced with real data
|
||||
&& blk.bimg_len != 0
|
||||
{
|
||||
// Extract page image from FPI record
|
||||
let img_len = blk.bimg_len as usize;
|
||||
let img_offs = blk.bimg_offset as usize;
|
||||
let mut image = BytesMut::with_capacity(BLCKSZ as usize);
|
||||
// TODO(vlad): skip the copy
|
||||
image.extend_from_slice(&decoded.record[img_offs..img_offs + img_len]);
|
||||
|
||||
if blk.hole_length != 0 {
|
||||
let tail = image.split_off(blk.hole_offset as usize);
|
||||
image.resize(image.len() + blk.hole_length as usize, 0u8);
|
||||
image.unsplit(tail);
|
||||
}
|
||||
//
|
||||
// Match the logic of XLogReadBufferForRedoExtended:
|
||||
// The page may be uninitialized. If so, we can't set the LSN because
|
||||
// that would corrupt the page.
|
||||
//
|
||||
if !page_is_new(&image) {
|
||||
page_set_lsn(&mut image, lsn)
|
||||
}
|
||||
assert_eq!(image.len(), BLCKSZ as usize);
|
||||
|
||||
Value::Image(image.freeze())
|
||||
} else {
|
||||
Value::WalRecord(NeonWalRecord::Postgres {
|
||||
will_init: blk.will_init || blk.apply_image,
|
||||
rec: decoded.record.clone(),
|
||||
})
|
||||
};
|
||||
|
||||
blocks.push((key.to_compact(), Some(value)));
|
||||
}
|
||||
let metadata_record = MetadataRecord::from_decoded(&decoded, record_end_lsn, pg_version)?;
|
||||
let batch = SerializedValueBatch::from_decoded_filtered(
|
||||
decoded,
|
||||
shard,
|
||||
record_end_lsn,
|
||||
pg_version,
|
||||
)?;
|
||||
|
||||
Ok(InterpretedWalRecord {
|
||||
metadata_record,
|
||||
blocks,
|
||||
lsn,
|
||||
batch,
|
||||
end_lsn: record_end_lsn,
|
||||
flush_uncommitted,
|
||||
xid: decoded.xl_xid,
|
||||
xid,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -130,7 +53,7 @@ impl InterpretedWalRecord {
|
||||
impl MetadataRecord {
|
||||
fn from_decoded(
|
||||
decoded: &DecodedWALRecord,
|
||||
lsn: Lsn,
|
||||
record_end_lsn: Lsn,
|
||||
pg_version: u32,
|
||||
) -> anyhow::Result<Option<MetadataRecord>> {
|
||||
// Note: this doesn't actually copy the bytes since
|
||||
@@ -151,7 +74,7 @@ impl MetadataRecord {
|
||||
Ok(None)
|
||||
}
|
||||
pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version),
|
||||
pg_constants::RM_XACT_ID => Self::decode_xact_record(&mut buf, decoded, lsn),
|
||||
pg_constants::RM_XACT_ID => Self::decode_xact_record(&mut buf, decoded, record_end_lsn),
|
||||
pg_constants::RM_MULTIXACT_ID => {
|
||||
Self::decode_multixact_record(&mut buf, decoded, pg_version)
|
||||
}
|
||||
@@ -163,7 +86,7 @@ impl MetadataRecord {
|
||||
//
|
||||
// Alternatively, one can make the checkpoint part of the subscription protocol
|
||||
// to the pageserver. This should work fine, but can be done at a later point.
|
||||
pg_constants::RM_XLOG_ID => Self::decode_xlog_record(&mut buf, decoded, lsn),
|
||||
pg_constants::RM_XLOG_ID => Self::decode_xlog_record(&mut buf, decoded, record_end_lsn),
|
||||
pg_constants::RM_LOGICALMSG_ID => {
|
||||
Self::decode_logical_message_record(&mut buf, decoded)
|
||||
}
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
pub mod decoder;
|
||||
pub mod models;
|
||||
pub mod serialized_batch;
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
//! ready for the pageserver to interpret. They are derived from the original
|
||||
//! WAL records, so that each struct corresponds closely to one WAL record of
|
||||
//! a specific kind. They contain the same information as the original WAL records,
|
||||
//! just decoded into structs and fields for easier access.
|
||||
//! but the values are already serialized in a [`SerializedValueBatch`], which
|
||||
//! is the format that the pageserver is expecting them in.
|
||||
//!
|
||||
//! The ingestion code uses these structs to help with parsing the WAL records,
|
||||
//! and it splits them into a stream of modifications to the key-value pairs that
|
||||
@@ -25,32 +26,34 @@
|
||||
//! |--> write to KV store within the pageserver
|
||||
|
||||
use bytes::Bytes;
|
||||
use pageserver_api::key::CompactKey;
|
||||
use pageserver_api::reltag::{RelTag, SlruKind};
|
||||
use pageserver_api::value::Value;
|
||||
use postgres_ffi::walrecord::{
|
||||
XlMultiXactCreate, XlMultiXactTruncate, XlRelmapUpdate, XlReploriginDrop, XlReploriginSet,
|
||||
XlSmgrTruncate, XlXactParsedRecord,
|
||||
};
|
||||
use postgres_ffi::{Oid, TransactionId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use crate::serialized_batch::SerializedValueBatch;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum FlushUncommittedRecords {
|
||||
Yes,
|
||||
No,
|
||||
}
|
||||
|
||||
/// An interpreted Postgres WAL record, ready to be handled by the pageserver
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct InterpretedWalRecord {
|
||||
/// Optional metadata record - may cause writes to metadata keys
|
||||
/// in the storage engine
|
||||
pub metadata_record: Option<MetadataRecord>,
|
||||
/// Images or deltas for blocks modified in the original WAL record.
|
||||
/// The [`Value`] is optional to avoid sending superfluous data to
|
||||
/// shard 0 for relation size tracking.
|
||||
pub blocks: Vec<(CompactKey, Option<Value>)>,
|
||||
/// A pre-serialized batch along with the required metadata for ingestion
|
||||
/// by the pageserver
|
||||
pub batch: SerializedValueBatch,
|
||||
/// Byte offset within WAL for the end of the original PG WAL record
|
||||
pub lsn: Lsn,
|
||||
pub end_lsn: Lsn,
|
||||
/// Whether to flush all uncommitted modifications to the storage engine
|
||||
/// before ingesting this record. This is currently only used for legacy PG
|
||||
/// database creations which read pages from a template database. Such WAL
|
||||
@@ -62,6 +65,7 @@ pub struct InterpretedWalRecord {
|
||||
|
||||
/// The interpreted part of the Postgres WAL record which requires metadata
|
||||
/// writes to the underlying storage engine.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum MetadataRecord {
|
||||
Heapam(HeapamRecord),
|
||||
Neonrmgr(NeonrmgrRecord),
|
||||
@@ -77,10 +81,12 @@ pub enum MetadataRecord {
|
||||
Replorigin(ReploriginRecord),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum HeapamRecord {
|
||||
ClearVmBits(ClearVmBits),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ClearVmBits {
|
||||
pub new_heap_blkno: Option<u32>,
|
||||
pub old_heap_blkno: Option<u32>,
|
||||
@@ -88,24 +94,29 @@ pub struct ClearVmBits {
|
||||
pub flags: u8,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum NeonrmgrRecord {
|
||||
ClearVmBits(ClearVmBits),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum SmgrRecord {
|
||||
Create(SmgrCreate),
|
||||
Truncate(XlSmgrTruncate),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SmgrCreate {
|
||||
pub rel: RelTag,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum DbaseRecord {
|
||||
Create(DbaseCreate),
|
||||
Drop(DbaseDrop),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct DbaseCreate {
|
||||
pub db_id: Oid,
|
||||
pub tablespace_id: Oid,
|
||||
@@ -113,27 +124,32 @@ pub struct DbaseCreate {
|
||||
pub src_tablespace_id: Oid,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct DbaseDrop {
|
||||
pub db_id: Oid,
|
||||
pub tablespace_ids: Vec<Oid>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum ClogRecord {
|
||||
ZeroPage(ClogZeroPage),
|
||||
Truncate(ClogTruncate),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ClogZeroPage {
|
||||
pub segno: u32,
|
||||
pub rpageno: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ClogTruncate {
|
||||
pub pageno: u32,
|
||||
pub oldest_xid: TransactionId,
|
||||
pub oldest_xid_db: Oid,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum XactRecord {
|
||||
Commit(XactCommon),
|
||||
Abort(XactCommon),
|
||||
@@ -142,6 +158,7 @@ pub enum XactRecord {
|
||||
Prepare(XactPrepare),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct XactCommon {
|
||||
pub parsed: XlXactParsedRecord,
|
||||
pub origin_id: u16,
|
||||
@@ -150,61 +167,73 @@ pub struct XactCommon {
|
||||
pub lsn: Lsn,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct XactPrepare {
|
||||
pub xl_xid: TransactionId,
|
||||
pub data: Bytes,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum MultiXactRecord {
|
||||
ZeroPage(MultiXactZeroPage),
|
||||
Create(XlMultiXactCreate),
|
||||
Truncate(XlMultiXactTruncate),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct MultiXactZeroPage {
|
||||
pub slru_kind: SlruKind,
|
||||
pub segno: u32,
|
||||
pub rpageno: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum RelmapRecord {
|
||||
Update(RelmapUpdate),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct RelmapUpdate {
|
||||
pub update: XlRelmapUpdate,
|
||||
pub buf: Bytes,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum XlogRecord {
|
||||
Raw(RawXlogRecord),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct RawXlogRecord {
|
||||
pub info: u8,
|
||||
pub lsn: Lsn,
|
||||
pub buf: Bytes,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum LogicalMessageRecord {
|
||||
Put(PutLogicalMessage),
|
||||
#[cfg(feature = "testing")]
|
||||
Failpoint,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct PutLogicalMessage {
|
||||
pub path: String,
|
||||
pub buf: Bytes,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum StandbyRecord {
|
||||
RunningXacts(StandbyRunningXacts),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct StandbyRunningXacts {
|
||||
pub oldest_running_xid: TransactionId,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum ReploriginRecord {
|
||||
Set(XlReploriginSet),
|
||||
Drop(XlReploriginDrop),
|
||||
|
||||
867
libs/wal_decoder/src/serialized_batch.rs
Normal file
867
libs/wal_decoder/src/serialized_batch.rs
Normal file
@@ -0,0 +1,867 @@
|
||||
//! This module implements batch type for serialized [`pageserver_api::value::Value`]
|
||||
//! instances. Each batch contains a raw buffer (serialized values)
|
||||
//! and a list of metadata for each (key, LSN) tuple present in the batch.
|
||||
//!
|
||||
//! Such batches are created from decoded PG wal records and ingested
|
||||
//! by the pageserver by writing directly to the ephemeral file.
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pageserver_api::key::rel_block_to_key;
|
||||
use pageserver_api::keyspace::KeySpace;
|
||||
use pageserver_api::record::NeonWalRecord;
|
||||
use pageserver_api::reltag::RelTag;
|
||||
use pageserver_api::shard::ShardIdentity;
|
||||
use pageserver_api::{key::CompactKey, value::Value};
|
||||
use postgres_ffi::walrecord::{DecodedBkpBlock, DecodedWALRecord};
|
||||
use postgres_ffi::{page_is_new, page_set_lsn, pg_constants, BLCKSZ};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utils::bin_ser::BeSer;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use pageserver_api::key::Key;
|
||||
|
||||
static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
|
||||
|
||||
/// Accompanying metadata for the batch
|
||||
/// A value may be serialized and stored into the batch or just "observed".
|
||||
/// Shard 0 currently "observes" all values in order to accurately track
|
||||
/// relation sizes. In the case of "observed" values, we only need to know
|
||||
/// the key and LSN, so two types of metadata are supported to save on network
|
||||
/// bandwidth.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum ValueMeta {
|
||||
Serialized(SerializedValueMeta),
|
||||
Observed(ObservedValueMeta),
|
||||
}
|
||||
|
||||
impl ValueMeta {
|
||||
pub fn key(&self) -> CompactKey {
|
||||
match self {
|
||||
Self::Serialized(ser) => ser.key,
|
||||
Self::Observed(obs) => obs.key,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lsn(&self) -> Lsn {
|
||||
match self {
|
||||
Self::Serialized(ser) => ser.lsn,
|
||||
Self::Observed(obs) => obs.lsn,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around [`ValueMeta`] that implements ordering by
|
||||
/// (key, LSN) tuples
|
||||
struct OrderedValueMeta(ValueMeta);
|
||||
|
||||
impl Ord for OrderedValueMeta {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
(self.0.key(), self.0.lsn()).cmp(&(other.0.key(), other.0.lsn()))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for OrderedValueMeta {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for OrderedValueMeta {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
(self.0.key(), self.0.lsn()) == (other.0.key(), other.0.lsn())
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for OrderedValueMeta {}
|
||||
|
||||
/// Metadata for a [`Value`] serialized into the batch.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SerializedValueMeta {
|
||||
pub key: CompactKey,
|
||||
pub lsn: Lsn,
|
||||
/// Starting offset of the value for the (key, LSN) tuple
|
||||
/// in [`SerializedValueBatch::raw`]
|
||||
pub batch_offset: u64,
|
||||
pub len: usize,
|
||||
pub will_init: bool,
|
||||
}
|
||||
|
||||
/// Metadata for a [`Value`] observed by the batch
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ObservedValueMeta {
|
||||
pub key: CompactKey,
|
||||
pub lsn: Lsn,
|
||||
}
|
||||
|
||||
/// Batch of serialized [`Value`]s.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SerializedValueBatch {
|
||||
/// [`Value`]s serialized in EphemeralFile's native format,
|
||||
/// ready for disk write by the pageserver
|
||||
pub raw: Vec<u8>,
|
||||
|
||||
/// Metadata to make sense of the bytes in [`Self::raw`]
|
||||
/// and represent "observed" values.
|
||||
///
|
||||
/// Invariant: Metadata entries for any given key are ordered
|
||||
/// by LSN. Note that entries for a key do not have to be contiguous.
|
||||
pub metadata: Vec<ValueMeta>,
|
||||
|
||||
/// The highest LSN of any value in the batch
|
||||
pub max_lsn: Lsn,
|
||||
|
||||
/// Number of values encoded by [`Self::raw`]
|
||||
pub len: usize,
|
||||
}
|
||||
|
||||
impl Default for SerializedValueBatch {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
raw: Default::default(),
|
||||
metadata: Default::default(),
|
||||
max_lsn: Lsn(0),
|
||||
len: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializedValueBatch {
|
||||
/// Build a batch of serialized values from a decoded PG WAL record
|
||||
///
|
||||
/// The batch will only contain values for keys targeting the specifiec
|
||||
/// shard. Shard 0 is a special case, where any keys that don't belong to
|
||||
/// it are "observed" by the batch (i.e. present in [`SerializedValueBatch::metadata`],
|
||||
/// but absent from the raw buffer [`SerializedValueBatch::raw`]).
|
||||
pub(crate) fn from_decoded_filtered(
|
||||
decoded: DecodedWALRecord,
|
||||
shard: &ShardIdentity,
|
||||
record_end_lsn: Lsn,
|
||||
pg_version: u32,
|
||||
) -> anyhow::Result<SerializedValueBatch> {
|
||||
// First determine how big the buffer needs to be and allocate it up-front.
|
||||
// This duplicates some of the work below, but it's empirically much faster.
|
||||
let estimated_buffer_size = Self::estimate_buffer_size(&decoded, shard, pg_version);
|
||||
let mut buf = Vec::<u8>::with_capacity(estimated_buffer_size);
|
||||
|
||||
let mut metadata: Vec<ValueMeta> = Vec::with_capacity(decoded.blocks.len());
|
||||
let mut max_lsn: Lsn = Lsn(0);
|
||||
let mut len: usize = 0;
|
||||
for blk in decoded.blocks.iter() {
|
||||
let relative_off = buf.len() as u64;
|
||||
|
||||
let rel = RelTag {
|
||||
spcnode: blk.rnode_spcnode,
|
||||
dbnode: blk.rnode_dbnode,
|
||||
relnode: blk.rnode_relnode,
|
||||
forknum: blk.forknum,
|
||||
};
|
||||
|
||||
let key = rel_block_to_key(rel, blk.blkno);
|
||||
|
||||
if !key.is_valid_key_on_write_path() {
|
||||
anyhow::bail!("Unsupported key decoded at LSN {}: {}", record_end_lsn, key);
|
||||
}
|
||||
|
||||
let key_is_local = shard.is_key_local(&key);
|
||||
|
||||
tracing::debug!(
|
||||
lsn=%record_end_lsn,
|
||||
key=%key,
|
||||
"ingest: shard decision {}",
|
||||
if !key_is_local { "drop" } else { "keep" },
|
||||
);
|
||||
|
||||
if !key_is_local {
|
||||
if shard.is_shard_zero() {
|
||||
// Shard 0 tracks relation sizes. Although we will not store this block, we will observe
|
||||
// its blkno in case it implicitly extends a relation.
|
||||
metadata.push(ValueMeta::Observed(ObservedValueMeta {
|
||||
key: key.to_compact(),
|
||||
lsn: record_end_lsn,
|
||||
}))
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// Instead of storing full-page-image WAL record,
|
||||
// it is better to store extracted image: we can skip wal-redo
|
||||
// in this case. Also some FPI records may contain multiple (up to 32) pages,
|
||||
// so them have to be copied multiple times.
|
||||
//
|
||||
let val = if Self::block_is_image(&decoded, blk, pg_version) {
|
||||
// Extract page image from FPI record
|
||||
let img_len = blk.bimg_len as usize;
|
||||
let img_offs = blk.bimg_offset as usize;
|
||||
let mut image = BytesMut::with_capacity(BLCKSZ as usize);
|
||||
// TODO(vlad): skip the copy
|
||||
image.extend_from_slice(&decoded.record[img_offs..img_offs + img_len]);
|
||||
|
||||
if blk.hole_length != 0 {
|
||||
let tail = image.split_off(blk.hole_offset as usize);
|
||||
image.resize(image.len() + blk.hole_length as usize, 0u8);
|
||||
image.unsplit(tail);
|
||||
}
|
||||
//
|
||||
// Match the logic of XLogReadBufferForRedoExtended:
|
||||
// The page may be uninitialized. If so, we can't set the LSN because
|
||||
// that would corrupt the page.
|
||||
//
|
||||
if !page_is_new(&image) {
|
||||
page_set_lsn(&mut image, record_end_lsn)
|
||||
}
|
||||
assert_eq!(image.len(), BLCKSZ as usize);
|
||||
|
||||
Value::Image(image.freeze())
|
||||
} else {
|
||||
Value::WalRecord(NeonWalRecord::Postgres {
|
||||
will_init: blk.will_init || blk.apply_image,
|
||||
rec: decoded.record.clone(),
|
||||
})
|
||||
};
|
||||
|
||||
val.ser_into(&mut buf)
|
||||
.expect("Writing into in-memory buffer is infallible");
|
||||
|
||||
let val_ser_size = buf.len() - relative_off as usize;
|
||||
|
||||
metadata.push(ValueMeta::Serialized(SerializedValueMeta {
|
||||
key: key.to_compact(),
|
||||
lsn: record_end_lsn,
|
||||
batch_offset: relative_off,
|
||||
len: val_ser_size,
|
||||
will_init: val.will_init(),
|
||||
}));
|
||||
max_lsn = std::cmp::max(max_lsn, record_end_lsn);
|
||||
len += 1;
|
||||
}
|
||||
|
||||
if cfg!(any(debug_assertions, test)) {
|
||||
let batch = Self {
|
||||
raw: buf,
|
||||
metadata,
|
||||
max_lsn,
|
||||
len,
|
||||
};
|
||||
|
||||
batch.validate_lsn_order();
|
||||
|
||||
return Ok(batch);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
raw: buf,
|
||||
metadata,
|
||||
max_lsn,
|
||||
len,
|
||||
})
|
||||
}
|
||||
|
||||
/// Look into the decoded PG WAL record and determine
|
||||
/// roughly how large the buffer for serialized values needs to be.
|
||||
fn estimate_buffer_size(
|
||||
decoded: &DecodedWALRecord,
|
||||
shard: &ShardIdentity,
|
||||
pg_version: u32,
|
||||
) -> usize {
|
||||
let mut estimate: usize = 0;
|
||||
|
||||
for blk in decoded.blocks.iter() {
|
||||
let rel = RelTag {
|
||||
spcnode: blk.rnode_spcnode,
|
||||
dbnode: blk.rnode_dbnode,
|
||||
relnode: blk.rnode_relnode,
|
||||
forknum: blk.forknum,
|
||||
};
|
||||
|
||||
let key = rel_block_to_key(rel, blk.blkno);
|
||||
|
||||
if !shard.is_key_local(&key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if Self::block_is_image(decoded, blk, pg_version) {
|
||||
// 4 bytes for the Value::Image discriminator
|
||||
// 8 bytes for encoding the size of the buffer
|
||||
// BLCKSZ for the raw image
|
||||
estimate += (4 + 8 + BLCKSZ) as usize;
|
||||
} else {
|
||||
// 4 bytes for the Value::WalRecord discriminator
|
||||
// 4 bytes for the NeonWalRecord::Postgres discriminator
|
||||
// 1 bytes for NeonWalRecord::Postgres::will_init
|
||||
// 8 bytes for encoding the size of the buffer
|
||||
// length of the raw record
|
||||
estimate += 8 + 1 + 8 + decoded.record.len();
|
||||
}
|
||||
}
|
||||
|
||||
estimate
|
||||
}
|
||||
|
||||
fn block_is_image(decoded: &DecodedWALRecord, blk: &DecodedBkpBlock, pg_version: u32) -> bool {
|
||||
blk.apply_image
|
||||
&& blk.has_image
|
||||
&& decoded.xl_rmid == pg_constants::RM_XLOG_ID
|
||||
&& (decoded.xl_info == pg_constants::XLOG_FPI
|
||||
|| decoded.xl_info == pg_constants::XLOG_FPI_FOR_HINT)
|
||||
// compression of WAL is not yet supported: fall back to storing the original WAL record
|
||||
&& !postgres_ffi::bkpimage_is_compressed(blk.bimg_info, pg_version)
|
||||
// do not materialize null pages because them most likely be soon replaced with real data
|
||||
&& blk.bimg_len != 0
|
||||
}
|
||||
|
||||
/// Encode a list of values and metadata into a serialized batch
|
||||
///
|
||||
/// This is used by the pageserver ingest code to conveniently generate
|
||||
/// batches for metadata writes.
|
||||
pub fn from_values(batch: Vec<(CompactKey, Lsn, usize, Value)>) -> Self {
|
||||
// Pre-allocate a big flat buffer to write into. This should be large but not huge: it is soft-limited in practice by
|
||||
// [`crate::pgdatadir_mapping::DatadirModification::MAX_PENDING_BYTES`]
|
||||
let buffer_size = batch.iter().map(|i| i.2).sum::<usize>();
|
||||
let mut buf = Vec::<u8>::with_capacity(buffer_size);
|
||||
|
||||
let mut metadata: Vec<ValueMeta> = Vec::with_capacity(batch.len());
|
||||
let mut max_lsn: Lsn = Lsn(0);
|
||||
let len = batch.len();
|
||||
for (key, lsn, val_ser_size, val) in batch {
|
||||
let relative_off = buf.len() as u64;
|
||||
|
||||
val.ser_into(&mut buf)
|
||||
.expect("Writing into in-memory buffer is infallible");
|
||||
|
||||
metadata.push(ValueMeta::Serialized(SerializedValueMeta {
|
||||
key,
|
||||
lsn,
|
||||
batch_offset: relative_off,
|
||||
len: val_ser_size,
|
||||
will_init: val.will_init(),
|
||||
}));
|
||||
max_lsn = std::cmp::max(max_lsn, lsn);
|
||||
}
|
||||
|
||||
// Assert that we didn't do any extra allocations while building buffer.
|
||||
debug_assert!(buf.len() <= buffer_size);
|
||||
|
||||
if cfg!(any(debug_assertions, test)) {
|
||||
let batch = Self {
|
||||
raw: buf,
|
||||
metadata,
|
||||
max_lsn,
|
||||
len,
|
||||
};
|
||||
|
||||
batch.validate_lsn_order();
|
||||
|
||||
return batch;
|
||||
}
|
||||
|
||||
Self {
|
||||
raw: buf,
|
||||
metadata,
|
||||
max_lsn,
|
||||
len,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add one value to the batch
|
||||
///
|
||||
/// This is used by the pageserver ingest code to include metadata block
|
||||
/// updates for a single key.
|
||||
pub fn put(&mut self, key: CompactKey, value: Value, lsn: Lsn) {
|
||||
let relative_off = self.raw.len() as u64;
|
||||
value.ser_into(&mut self.raw).unwrap();
|
||||
|
||||
let val_ser_size = self.raw.len() - relative_off as usize;
|
||||
self.metadata
|
||||
.push(ValueMeta::Serialized(SerializedValueMeta {
|
||||
key,
|
||||
lsn,
|
||||
batch_offset: relative_off,
|
||||
len: val_ser_size,
|
||||
will_init: value.will_init(),
|
||||
}));
|
||||
|
||||
self.max_lsn = std::cmp::max(self.max_lsn, lsn);
|
||||
self.len += 1;
|
||||
|
||||
if cfg!(any(debug_assertions, test)) {
|
||||
self.validate_lsn_order();
|
||||
}
|
||||
}
|
||||
|
||||
/// Extend with the contents of another batch
|
||||
///
|
||||
/// One batch is generated for each decoded PG WAL record.
|
||||
/// They are then merged to accumulate reasonably sized writes.
|
||||
pub fn extend(&mut self, mut other: SerializedValueBatch) {
|
||||
let extend_batch_start_offset = self.raw.len() as u64;
|
||||
|
||||
self.raw.extend(other.raw);
|
||||
|
||||
// Shift the offsets in the batch we are extending with
|
||||
other.metadata.iter_mut().for_each(|meta| match meta {
|
||||
ValueMeta::Serialized(ser) => {
|
||||
ser.batch_offset += extend_batch_start_offset;
|
||||
if cfg!(debug_assertions) {
|
||||
let value_end = ser.batch_offset + ser.len as u64;
|
||||
assert!((value_end as usize) <= self.raw.len());
|
||||
}
|
||||
}
|
||||
ValueMeta::Observed(_) => {}
|
||||
});
|
||||
self.metadata.extend(other.metadata);
|
||||
|
||||
self.max_lsn = std::cmp::max(self.max_lsn, other.max_lsn);
|
||||
|
||||
self.len += other.len;
|
||||
|
||||
if cfg!(any(debug_assertions, test)) {
|
||||
self.validate_lsn_order();
|
||||
}
|
||||
}
|
||||
|
||||
/// Add zero images for the (key, LSN) tuples specified
|
||||
///
|
||||
/// PG versions below 16 do not zero out pages before extending
|
||||
/// a relation and may leave gaps. Such gaps need to be identified
|
||||
/// by the pageserver ingest logic and get patched up here.
|
||||
///
|
||||
/// Note that this function does not validate that the gaps have been
|
||||
/// identified correctly (it does not know relation sizes), so it's up
|
||||
/// to the call-site to do it properly.
|
||||
pub fn zero_gaps(&mut self, gaps: Vec<(KeySpace, Lsn)>) {
|
||||
// Implementation note:
|
||||
//
|
||||
// Values within [`SerializedValueBatch::raw`] do not have any ordering requirements,
|
||||
// but the metadata entries should be ordered properly (see
|
||||
// [`SerializedValueBatch::metadata`]).
|
||||
//
|
||||
// Exploiting this observation we do:
|
||||
// 1. Drain all the metadata entries into an ordered set.
|
||||
// The use of a BTreeSet keyed by (Key, Lsn) relies on the observation that Postgres never
|
||||
// includes more than one update to the same block in the same WAL record.
|
||||
// 2. For each (key, LSN) gap tuple, append a zero image to the raw buffer
|
||||
// and add an index entry to the ordered metadata set.
|
||||
// 3. Drain the ordered set back into a metadata vector
|
||||
|
||||
let mut ordered_metas = self
|
||||
.metadata
|
||||
.drain(..)
|
||||
.map(OrderedValueMeta)
|
||||
.collect::<BTreeSet<_>>();
|
||||
for (keyspace, lsn) in gaps {
|
||||
self.max_lsn = std::cmp::max(self.max_lsn, lsn);
|
||||
|
||||
for gap_range in keyspace.ranges {
|
||||
let mut key = gap_range.start;
|
||||
while key != gap_range.end {
|
||||
let relative_off = self.raw.len() as u64;
|
||||
|
||||
// TODO(vlad): Can we be cheeky and write only one zero image, and
|
||||
// make all index entries requiring a zero page point to it?
|
||||
// Alternatively, we can change the index entry format to represent zero pages
|
||||
// without writing them at all.
|
||||
Value::Image(ZERO_PAGE.clone())
|
||||
.ser_into(&mut self.raw)
|
||||
.unwrap();
|
||||
let val_ser_size = self.raw.len() - relative_off as usize;
|
||||
|
||||
ordered_metas.insert(OrderedValueMeta(ValueMeta::Serialized(
|
||||
SerializedValueMeta {
|
||||
key: key.to_compact(),
|
||||
lsn,
|
||||
batch_offset: relative_off,
|
||||
len: val_ser_size,
|
||||
will_init: true,
|
||||
},
|
||||
)));
|
||||
|
||||
self.len += 1;
|
||||
|
||||
key = key.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.metadata = ordered_metas.into_iter().map(|ord| ord.0).collect();
|
||||
|
||||
if cfg!(any(debug_assertions, test)) {
|
||||
self.validate_lsn_order();
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if the batch is empty
|
||||
///
|
||||
/// A batch is empty when it contains no serialized values.
|
||||
/// Note that it may still contain observed values.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
let empty = self.raw.is_empty();
|
||||
|
||||
if cfg!(debug_assertions) && empty {
|
||||
assert!(self
|
||||
.metadata
|
||||
.iter()
|
||||
.all(|meta| matches!(meta, ValueMeta::Observed(_))));
|
||||
}
|
||||
|
||||
empty
|
||||
}
|
||||
|
||||
/// Returns the number of values serialized in the batch
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Returns the size of the buffer wrapped by the batch
|
||||
pub fn buffer_size(&self) -> usize {
|
||||
self.raw.len()
|
||||
}
|
||||
|
||||
pub fn updates_key(&self, key: &Key) -> bool {
|
||||
self.metadata.iter().any(|meta| match meta {
|
||||
ValueMeta::Serialized(ser) => key.to_compact() == ser.key,
|
||||
ValueMeta::Observed(_) => false,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn validate_lsn_order(&self) {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut last_seen_lsn_per_key: HashMap<CompactKey, Lsn> = HashMap::default();
|
||||
|
||||
for meta in self.metadata.iter() {
|
||||
let lsn = meta.lsn();
|
||||
let key = meta.key();
|
||||
|
||||
if let Some(prev_lsn) = last_seen_lsn_per_key.insert(key, lsn) {
|
||||
assert!(
|
||||
lsn >= prev_lsn,
|
||||
"Ordering violated by {}: {} < {}",
|
||||
Key::from_compact(key),
|
||||
lsn,
|
||||
prev_lsn
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "testing"))]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn validate_batch(
|
||||
batch: &SerializedValueBatch,
|
||||
values: &[(CompactKey, Lsn, usize, Value)],
|
||||
gaps: Option<&Vec<(KeySpace, Lsn)>>,
|
||||
) {
|
||||
// Invariant 1: The metadata for a given entry in the batch
|
||||
// is correct and can be used to deserialize back to the original value.
|
||||
for (key, lsn, size, value) in values.iter() {
|
||||
let meta = batch
|
||||
.metadata
|
||||
.iter()
|
||||
.find(|meta| (meta.key(), meta.lsn()) == (*key, *lsn))
|
||||
.unwrap();
|
||||
let meta = match meta {
|
||||
ValueMeta::Serialized(ser) => ser,
|
||||
ValueMeta::Observed(_) => unreachable!(),
|
||||
};
|
||||
|
||||
assert_eq!(meta.len, *size);
|
||||
assert_eq!(meta.will_init, value.will_init());
|
||||
|
||||
let start = meta.batch_offset as usize;
|
||||
let end = meta.batch_offset as usize + meta.len;
|
||||
let value_from_batch = Value::des(&batch.raw[start..end]).unwrap();
|
||||
assert_eq!(&value_from_batch, value);
|
||||
}
|
||||
|
||||
let mut expected_buffer_size: usize = values.iter().map(|(_, _, size, _)| size).sum();
|
||||
let mut gap_pages_count: usize = 0;
|
||||
|
||||
// Invariant 2: Zero pages were added for identified gaps and their metadata
|
||||
// is correct.
|
||||
if let Some(gaps) = gaps {
|
||||
for (gap_keyspace, lsn) in gaps {
|
||||
for gap_range in &gap_keyspace.ranges {
|
||||
let mut gap_key = gap_range.start;
|
||||
while gap_key != gap_range.end {
|
||||
let meta = batch
|
||||
.metadata
|
||||
.iter()
|
||||
.find(|meta| (meta.key(), meta.lsn()) == (gap_key.to_compact(), *lsn))
|
||||
.unwrap();
|
||||
let meta = match meta {
|
||||
ValueMeta::Serialized(ser) => ser,
|
||||
ValueMeta::Observed(_) => unreachable!(),
|
||||
};
|
||||
|
||||
let zero_value = Value::Image(ZERO_PAGE.clone());
|
||||
let zero_value_size = zero_value.serialized_size().unwrap() as usize;
|
||||
|
||||
assert_eq!(meta.len, zero_value_size);
|
||||
assert_eq!(meta.will_init, zero_value.will_init());
|
||||
|
||||
let start = meta.batch_offset as usize;
|
||||
let end = meta.batch_offset as usize + meta.len;
|
||||
let value_from_batch = Value::des(&batch.raw[start..end]).unwrap();
|
||||
assert_eq!(value_from_batch, zero_value);
|
||||
|
||||
gap_pages_count += 1;
|
||||
expected_buffer_size += zero_value_size;
|
||||
gap_key = gap_key.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Invariant 3: The length of the batch is equal to the number
|
||||
// of values inserted, plus the number of gap pages. This extends
|
||||
// to the raw buffer size.
|
||||
assert_eq!(batch.len(), values.len() + gap_pages_count);
|
||||
assert_eq!(expected_buffer_size, batch.buffer_size());
|
||||
|
||||
// Invariant 4: Metadata entries for any given key are sorted in LSN order.
|
||||
batch.validate_lsn_order();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_creation_from_values() {
|
||||
const LSN: Lsn = Lsn(0x10);
|
||||
let key = Key::from_hex("110000000033333333444444445500000001").unwrap();
|
||||
|
||||
let values = vec![
|
||||
(
|
||||
key.to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo")),
|
||||
),
|
||||
(
|
||||
key.next().to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("bar")),
|
||||
),
|
||||
(
|
||||
key.to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("baz")),
|
||||
),
|
||||
(
|
||||
key.next().next().to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("taz")),
|
||||
),
|
||||
];
|
||||
|
||||
let values = values
|
||||
.into_iter()
|
||||
.map(|(key, lsn, value)| (key, lsn, value.serialized_size().unwrap() as usize, value))
|
||||
.collect::<Vec<_>>();
|
||||
let batch = SerializedValueBatch::from_values(values.clone());
|
||||
|
||||
validate_batch(&batch, &values, None);
|
||||
|
||||
assert!(!batch.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_put() {
|
||||
const LSN: Lsn = Lsn(0x10);
|
||||
let key = Key::from_hex("110000000033333333444444445500000001").unwrap();
|
||||
|
||||
let values = vec![
|
||||
(
|
||||
key.to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo")),
|
||||
),
|
||||
(
|
||||
key.next().to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("bar")),
|
||||
),
|
||||
];
|
||||
|
||||
let mut values = values
|
||||
.into_iter()
|
||||
.map(|(key, lsn, value)| (key, lsn, value.serialized_size().unwrap() as usize, value))
|
||||
.collect::<Vec<_>>();
|
||||
let mut batch = SerializedValueBatch::from_values(values.clone());
|
||||
|
||||
validate_batch(&batch, &values, None);
|
||||
|
||||
let value = (
|
||||
key.to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("baz")),
|
||||
);
|
||||
let serialized_size = value.2.serialized_size().unwrap() as usize;
|
||||
let value = (value.0, value.1, serialized_size, value.2);
|
||||
values.push(value.clone());
|
||||
batch.put(value.0, value.3, value.1);
|
||||
|
||||
validate_batch(&batch, &values, None);
|
||||
|
||||
let value = (
|
||||
key.next().next().to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("taz")),
|
||||
);
|
||||
let serialized_size = value.2.serialized_size().unwrap() as usize;
|
||||
let value = (value.0, value.1, serialized_size, value.2);
|
||||
values.push(value.clone());
|
||||
batch.put(value.0, value.3, value.1);
|
||||
|
||||
validate_batch(&batch, &values, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extension() {
|
||||
const LSN: Lsn = Lsn(0x10);
|
||||
let key = Key::from_hex("110000000033333333444444445500000001").unwrap();
|
||||
|
||||
let values = vec![
|
||||
(
|
||||
key.to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo")),
|
||||
),
|
||||
(
|
||||
key.next().to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("bar")),
|
||||
),
|
||||
(
|
||||
key.next().next().to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("taz")),
|
||||
),
|
||||
];
|
||||
|
||||
let mut values = values
|
||||
.into_iter()
|
||||
.map(|(key, lsn, value)| (key, lsn, value.serialized_size().unwrap() as usize, value))
|
||||
.collect::<Vec<_>>();
|
||||
let mut batch = SerializedValueBatch::from_values(values.clone());
|
||||
|
||||
let other_values = vec![
|
||||
(
|
||||
key.to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo")),
|
||||
),
|
||||
(
|
||||
key.next().to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("bar")),
|
||||
),
|
||||
(
|
||||
key.next().next().to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("taz")),
|
||||
),
|
||||
];
|
||||
|
||||
let other_values = other_values
|
||||
.into_iter()
|
||||
.map(|(key, lsn, value)| (key, lsn, value.serialized_size().unwrap() as usize, value))
|
||||
.collect::<Vec<_>>();
|
||||
let other_batch = SerializedValueBatch::from_values(other_values.clone());
|
||||
|
||||
values.extend(other_values);
|
||||
batch.extend(other_batch);
|
||||
|
||||
validate_batch(&batch, &values, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gap_zeroing() {
|
||||
const LSN: Lsn = Lsn(0x10);
|
||||
let rel_foo_base_key = Key::from_hex("110000000033333333444444445500000001").unwrap();
|
||||
|
||||
let rel_bar_base_key = {
|
||||
let mut key = rel_foo_base_key;
|
||||
key.field4 += 1;
|
||||
key
|
||||
};
|
||||
|
||||
let values = vec![
|
||||
(
|
||||
rel_foo_base_key.to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo1")),
|
||||
),
|
||||
(
|
||||
rel_foo_base_key.add(1).to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo2")),
|
||||
),
|
||||
(
|
||||
rel_foo_base_key.add(5).to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo3")),
|
||||
),
|
||||
(
|
||||
rel_foo_base_key.add(1).to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo4")),
|
||||
),
|
||||
(
|
||||
rel_foo_base_key.add(10).to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo5")),
|
||||
),
|
||||
(
|
||||
rel_foo_base_key.add(11).to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo6")),
|
||||
),
|
||||
(
|
||||
rel_foo_base_key.add(12).to_compact(),
|
||||
Lsn(LSN.0 + 0x10),
|
||||
Value::WalRecord(NeonWalRecord::wal_append("foo7")),
|
||||
),
|
||||
(
|
||||
rel_bar_base_key.to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("bar1")),
|
||||
),
|
||||
(
|
||||
rel_bar_base_key.add(4).to_compact(),
|
||||
LSN,
|
||||
Value::WalRecord(NeonWalRecord::wal_append("bar2")),
|
||||
),
|
||||
];
|
||||
|
||||
let values = values
|
||||
.into_iter()
|
||||
.map(|(key, lsn, value)| (key, lsn, value.serialized_size().unwrap() as usize, value))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut batch = SerializedValueBatch::from_values(values.clone());
|
||||
|
||||
let gaps = vec![
|
||||
(
|
||||
KeySpace {
|
||||
ranges: vec![
|
||||
rel_foo_base_key.add(2)..rel_foo_base_key.add(5),
|
||||
rel_bar_base_key.add(1)..rel_bar_base_key.add(4),
|
||||
],
|
||||
},
|
||||
LSN,
|
||||
),
|
||||
(
|
||||
KeySpace {
|
||||
ranges: vec![rel_foo_base_key.add(6)..rel_foo_base_key.add(10)],
|
||||
},
|
||||
Lsn(LSN.0 + 0x10),
|
||||
),
|
||||
];
|
||||
|
||||
batch.zero_gaps(gaps.clone());
|
||||
validate_batch(&batch, &values, Some(&gaps));
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,6 @@ use pageserver::{
|
||||
l0_flush::{L0FlushConfig, L0FlushGlobalState},
|
||||
page_cache,
|
||||
task_mgr::TaskKind,
|
||||
tenant::storage_layer::inmemory_layer::SerializedBatch,
|
||||
tenant::storage_layer::InMemoryLayer,
|
||||
virtual_file,
|
||||
};
|
||||
@@ -18,6 +17,7 @@ use utils::{
|
||||
bin_ser::BeSer,
|
||||
id::{TenantId, TimelineId},
|
||||
};
|
||||
use wal_decoder::serialized_batch::SerializedValueBatch;
|
||||
|
||||
// A very cheap hash for generating non-sequential keys.
|
||||
fn murmurhash32(mut h: u32) -> u32 {
|
||||
@@ -102,13 +102,13 @@ async fn ingest(
|
||||
batch.push((key.to_compact(), lsn, data_ser_size, data.clone()));
|
||||
if batch.len() >= BATCH_SIZE {
|
||||
let this_batch = std::mem::take(&mut batch);
|
||||
let serialized = SerializedBatch::from_values(this_batch).unwrap();
|
||||
let serialized = SerializedValueBatch::from_values(this_batch);
|
||||
layer.put_batch(serialized, &ctx).await?;
|
||||
}
|
||||
}
|
||||
if !batch.is_empty() {
|
||||
let this_batch = std::mem::take(&mut batch);
|
||||
let serialized = SerializedBatch::from_values(this_batch).unwrap();
|
||||
let serialized = SerializedValueBatch::from_values(this_batch);
|
||||
layer.put_batch(serialized, &ctx).await?;
|
||||
}
|
||||
layer.freeze(lsn + 1).await;
|
||||
|
||||
@@ -24,6 +24,7 @@ use pageserver_api::key::{
|
||||
use pageserver_api::keyspace::SparseKeySpace;
|
||||
use pageserver_api::record::NeonWalRecord;
|
||||
use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
|
||||
use pageserver_api::shard::ShardIdentity;
|
||||
use pageserver_api::value::Value;
|
||||
use postgres_ffi::relfile_utils::{FSM_FORKNUM, VISIBILITYMAP_FORKNUM};
|
||||
use postgres_ffi::BLCKSZ;
|
||||
@@ -38,6 +39,7 @@ use tracing::{debug, trace, warn};
|
||||
use utils::bin_ser::DeserializeError;
|
||||
use utils::pausable_failpoint;
|
||||
use utils::{bin_ser::BeSer, lsn::Lsn};
|
||||
use wal_decoder::serialized_batch::SerializedValueBatch;
|
||||
|
||||
/// Max delta records appended to the AUX_FILES_KEY (for aux v1). The write path will write a full image once this threshold is reached.
|
||||
pub const MAX_AUX_FILE_DELTAS: usize = 1024;
|
||||
@@ -170,12 +172,11 @@ impl Timeline {
|
||||
tline: self,
|
||||
pending_lsns: Vec::new(),
|
||||
pending_metadata_pages: HashMap::new(),
|
||||
pending_data_pages: Vec::new(),
|
||||
pending_zero_data_pages: Default::default(),
|
||||
pending_data_batch: None,
|
||||
pending_deletions: Vec::new(),
|
||||
pending_nblocks: 0,
|
||||
pending_directory_entries: Vec::new(),
|
||||
pending_bytes: 0,
|
||||
pending_metadata_bytes: 0,
|
||||
lsn,
|
||||
}
|
||||
}
|
||||
@@ -1025,21 +1026,14 @@ pub struct DatadirModification<'a> {
|
||||
|
||||
/// Data writes, ready to be flushed into an ephemeral layer. See [`Self::is_data_key`] for
|
||||
/// which keys are stored here.
|
||||
pending_data_pages: Vec<(CompactKey, Lsn, usize, Value)>,
|
||||
|
||||
// Sometimes during ingest, for example when extending a relation, we would like to write a zero page. However,
|
||||
// if we encounter a write from postgres in the same wal record, we will drop this entry.
|
||||
//
|
||||
// Unlike other 'pending' fields, this does not last until the next call to commit(): it is flushed
|
||||
// at the end of each wal record, and all these writes implicitly are at lsn Self::lsn
|
||||
pending_zero_data_pages: HashSet<CompactKey>,
|
||||
pending_data_batch: Option<SerializedValueBatch>,
|
||||
|
||||
/// For special "directory" keys that store key-value maps, track the size of the map
|
||||
/// if it was updated in this modification.
|
||||
pending_directory_entries: Vec<(DirectoryKind, usize)>,
|
||||
|
||||
/// An **approximation** of how large our EphemeralFile write will be when committed.
|
||||
pending_bytes: usize,
|
||||
/// An **approximation** of how many metadata bytes will be written to the EphemeralFile.
|
||||
pending_metadata_bytes: usize,
|
||||
}
|
||||
|
||||
impl<'a> DatadirModification<'a> {
|
||||
@@ -1054,11 +1048,17 @@ impl<'a> DatadirModification<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn approx_pending_bytes(&self) -> usize {
|
||||
self.pending_bytes
|
||||
self.pending_data_batch
|
||||
.as_ref()
|
||||
.map_or(0, |b| b.buffer_size())
|
||||
+ self.pending_metadata_bytes
|
||||
}
|
||||
|
||||
pub(crate) fn has_dirty_data_pages(&self) -> bool {
|
||||
(!self.pending_data_pages.is_empty()) || (!self.pending_zero_data_pages.is_empty())
|
||||
pub(crate) fn has_dirty_data(&self) -> bool {
|
||||
!self
|
||||
.pending_data_batch
|
||||
.as_ref()
|
||||
.map_or(true, |b| b.is_empty())
|
||||
}
|
||||
|
||||
/// Set the current lsn
|
||||
@@ -1070,9 +1070,6 @@ impl<'a> DatadirModification<'a> {
|
||||
self.lsn
|
||||
);
|
||||
|
||||
// If we are advancing LSN, then state from previous wal record should have been flushed.
|
||||
assert!(self.pending_zero_data_pages.is_empty());
|
||||
|
||||
if lsn > self.lsn {
|
||||
self.pending_lsns.push(self.lsn);
|
||||
self.lsn = lsn;
|
||||
@@ -1147,6 +1144,116 @@ impl<'a> DatadirModification<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Creates a relation if it is not already present.
|
||||
/// Returns the current size of the relation
|
||||
pub(crate) async fn create_relation_if_required(
|
||||
&mut self,
|
||||
rel: RelTag,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<u32, PageReconstructError> {
|
||||
// Get current size and put rel creation if rel doesn't exist
|
||||
//
|
||||
// NOTE: we check the cache first even though get_rel_exists and get_rel_size would
|
||||
// check the cache too. This is because eagerly checking the cache results in
|
||||
// less work overall and 10% better performance. It's more work on cache miss
|
||||
// but cache miss is rare.
|
||||
if let Some(nblocks) = self.tline.get_cached_rel_size(&rel, self.get_lsn()) {
|
||||
Ok(nblocks)
|
||||
} else if !self
|
||||
.tline
|
||||
.get_rel_exists(rel, Version::Modified(self), ctx)
|
||||
.await?
|
||||
{
|
||||
tracing::debug!("Creating relation {rel:?} at lsn {}", self.get_lsn());
|
||||
|
||||
// create it with 0 size initially, the logic below will extend it
|
||||
self.put_rel_creation(rel, 0, ctx)
|
||||
.await
|
||||
.context("Relation Error")?;
|
||||
Ok(0)
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"Skipping relation {rel:?} creation at lsn {}",
|
||||
self.get_lsn()
|
||||
);
|
||||
|
||||
self.tline
|
||||
.get_rel_size(rel, Version::Modified(self), ctx)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Given a block number for a relation (which represents a newly written block),
|
||||
/// the previous block count of the relation, and the shard info, find the gaps
|
||||
/// that were created by the newly written block if any.
|
||||
fn find_gaps(
|
||||
rel: RelTag,
|
||||
blkno: u32,
|
||||
previous_nblocks: u32,
|
||||
shard: &ShardIdentity,
|
||||
) -> Option<KeySpace> {
|
||||
let mut key = rel_block_to_key(rel, blkno);
|
||||
let mut gap_accum = None;
|
||||
|
||||
for gap_blkno in previous_nblocks..blkno {
|
||||
key.field6 = gap_blkno;
|
||||
|
||||
if shard.get_shard_number(&key) != shard.number {
|
||||
continue;
|
||||
}
|
||||
|
||||
gap_accum
|
||||
.get_or_insert_with(KeySpaceAccum::new)
|
||||
.add_key(key);
|
||||
}
|
||||
|
||||
gap_accum.map(|accum| accum.to_keyspace())
|
||||
}
|
||||
|
||||
pub async fn ingest_batch(
|
||||
&mut self,
|
||||
mut batch: SerializedValueBatch,
|
||||
// TODO(vlad): remove this argument and replace the shard check with is_key_local
|
||||
shard: &ShardIdentity,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<()> {
|
||||
tracing::debug!("Ingesting batch with metadata: {:?}", batch.metadata);
|
||||
|
||||
let mut gaps_at_lsns = Vec::default();
|
||||
|
||||
for meta in batch.metadata.iter() {
|
||||
let (rel, blkno) = Key::from_compact(meta.key()).to_rel_block()?;
|
||||
let new_nblocks = blkno + 1;
|
||||
|
||||
let old_nblocks = self.create_relation_if_required(rel, ctx).await?;
|
||||
if new_nblocks > old_nblocks {
|
||||
self.put_rel_extend(rel, new_nblocks, ctx).await?;
|
||||
}
|
||||
|
||||
if let Some(gaps) = Self::find_gaps(rel, blkno, old_nblocks, shard) {
|
||||
gaps_at_lsns.push((gaps, meta.lsn()));
|
||||
}
|
||||
}
|
||||
|
||||
if !gaps_at_lsns.is_empty() {
|
||||
batch.zero_gaps(gaps_at_lsns);
|
||||
}
|
||||
|
||||
match self.pending_data_batch.as_mut() {
|
||||
Some(pending_batch) => {
|
||||
pending_batch.extend(batch);
|
||||
}
|
||||
None if !batch.is_empty() => {
|
||||
self.pending_data_batch = Some(batch);
|
||||
}
|
||||
None => {
|
||||
// Nothing to initialize the batch with
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Put a new page version that can be constructed from a WAL record
|
||||
///
|
||||
/// NOTE: this will *not* implicitly extend the relation, if the page is beyond the
|
||||
@@ -1229,8 +1336,13 @@ impl<'a> DatadirModification<'a> {
|
||||
self.lsn
|
||||
);
|
||||
}
|
||||
self.pending_zero_data_pages.insert(key.to_compact());
|
||||
self.pending_bytes += ZERO_PAGE.len();
|
||||
|
||||
let batch = self
|
||||
.pending_data_batch
|
||||
.get_or_insert_with(SerializedValueBatch::default);
|
||||
|
||||
batch.put(key.to_compact(), Value::Image(ZERO_PAGE.clone()), self.lsn);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1248,17 +1360,14 @@ impl<'a> DatadirModification<'a> {
|
||||
self.lsn
|
||||
);
|
||||
}
|
||||
self.pending_zero_data_pages.insert(key.to_compact());
|
||||
self.pending_bytes += ZERO_PAGE.len();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Call this at the end of each WAL record.
|
||||
pub(crate) fn on_record_end(&mut self) {
|
||||
let pending_zero_data_pages = std::mem::take(&mut self.pending_zero_data_pages);
|
||||
for key in pending_zero_data_pages {
|
||||
self.put_data(key, Value::Image(ZERO_PAGE.clone()));
|
||||
}
|
||||
let batch = self
|
||||
.pending_data_batch
|
||||
.get_or_insert_with(SerializedValueBatch::default);
|
||||
|
||||
batch.put(key.to_compact(), Value::Image(ZERO_PAGE.clone()), self.lsn);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store a relmapper file (pg_filenode.map) in the repository
|
||||
@@ -1750,12 +1859,17 @@ impl<'a> DatadirModification<'a> {
|
||||
let mut writer = self.tline.writer().await;
|
||||
|
||||
// Flush relation and SLRU data blocks, keep metadata.
|
||||
let pending_data_pages = std::mem::take(&mut self.pending_data_pages);
|
||||
if let Some(batch) = self.pending_data_batch.take() {
|
||||
tracing::debug!(
|
||||
"Flushing batch with max_lsn={}. Last record LSN is {}",
|
||||
batch.max_lsn,
|
||||
self.tline.get_last_record_lsn()
|
||||
);
|
||||
|
||||
// This bails out on first error without modifying pending_updates.
|
||||
// That's Ok, cf this function's doc comment.
|
||||
writer.put_batch(pending_data_pages, ctx).await?;
|
||||
self.pending_bytes = 0;
|
||||
// This bails out on first error without modifying pending_updates.
|
||||
// That's Ok, cf this function's doc comment.
|
||||
writer.put_batch(batch, ctx).await?;
|
||||
}
|
||||
|
||||
if pending_nblocks != 0 {
|
||||
writer.update_current_logical_size(pending_nblocks * i64::from(BLCKSZ));
|
||||
@@ -1775,9 +1889,6 @@ impl<'a> DatadirModification<'a> {
|
||||
/// All the modifications in this atomic update are stamped by the specified LSN.
|
||||
///
|
||||
pub async fn commit(&mut self, ctx: &RequestContext) -> anyhow::Result<()> {
|
||||
// Commit should never be called mid-wal-record
|
||||
assert!(self.pending_zero_data_pages.is_empty());
|
||||
|
||||
let mut writer = self.tline.writer().await;
|
||||
|
||||
let pending_nblocks = self.pending_nblocks;
|
||||
@@ -1785,21 +1896,49 @@ impl<'a> DatadirModification<'a> {
|
||||
|
||||
// Ordering: the items in this batch do not need to be in any global order, but values for
|
||||
// a particular Key must be in Lsn order relative to one another. InMemoryLayer relies on
|
||||
// this to do efficient updates to its index.
|
||||
let mut write_batch = std::mem::take(&mut self.pending_data_pages);
|
||||
// this to do efficient updates to its index. See [`wal_decoder::serialized_batch`] for
|
||||
// more details.
|
||||
|
||||
write_batch.extend(
|
||||
self.pending_metadata_pages
|
||||
let metadata_batch = {
|
||||
let pending_meta = self
|
||||
.pending_metadata_pages
|
||||
.drain()
|
||||
.flat_map(|(key, values)| {
|
||||
values
|
||||
.into_iter()
|
||||
.map(move |(lsn, value_size, value)| (key, lsn, value_size, value))
|
||||
}),
|
||||
);
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !write_batch.is_empty() {
|
||||
writer.put_batch(write_batch, ctx).await?;
|
||||
if pending_meta.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(SerializedValueBatch::from_values(pending_meta))
|
||||
}
|
||||
};
|
||||
|
||||
let data_batch = self.pending_data_batch.take();
|
||||
|
||||
let maybe_batch = match (data_batch, metadata_batch) {
|
||||
(Some(mut data), Some(metadata)) => {
|
||||
data.extend(metadata);
|
||||
Some(data)
|
||||
}
|
||||
(Some(data), None) => Some(data),
|
||||
(None, Some(metadata)) => Some(metadata),
|
||||
(None, None) => None,
|
||||
};
|
||||
|
||||
if let Some(batch) = maybe_batch {
|
||||
tracing::debug!(
|
||||
"Flushing batch with max_lsn={}. Last record LSN is {}",
|
||||
batch.max_lsn,
|
||||
self.tline.get_last_record_lsn()
|
||||
);
|
||||
|
||||
// This bails out on first error without modifying pending_updates.
|
||||
// That's Ok, cf this function's doc comment.
|
||||
writer.put_batch(batch, ctx).await?;
|
||||
}
|
||||
|
||||
if !self.pending_deletions.is_empty() {
|
||||
@@ -1809,6 +1948,9 @@ impl<'a> DatadirModification<'a> {
|
||||
|
||||
self.pending_lsns.push(self.lsn);
|
||||
for pending_lsn in self.pending_lsns.drain(..) {
|
||||
// TODO(vlad): pretty sure the comment below is not valid anymore
|
||||
// and we can call finish write with the latest LSN
|
||||
//
|
||||
// Ideally, we should be able to call writer.finish_write() only once
|
||||
// with the highest LSN. However, the last_record_lsn variable in the
|
||||
// timeline keeps track of the latest LSN and the immediate previous LSN
|
||||
@@ -1824,14 +1966,14 @@ impl<'a> DatadirModification<'a> {
|
||||
writer.update_directory_entries_count(kind, count as u64);
|
||||
}
|
||||
|
||||
self.pending_bytes = 0;
|
||||
self.pending_metadata_bytes = 0;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.pending_metadata_pages.len()
|
||||
+ self.pending_data_pages.len()
|
||||
+ self.pending_data_batch.as_ref().map_or(0, |b| b.len())
|
||||
+ self.pending_deletions.len()
|
||||
}
|
||||
|
||||
@@ -1873,11 +2015,10 @@ impl<'a> DatadirModification<'a> {
|
||||
// modifications before ingesting DB create operations, which are the only kind that reads
|
||||
// data pages during ingest.
|
||||
if cfg!(debug_assertions) {
|
||||
for (dirty_key, _, _, _) in &self.pending_data_pages {
|
||||
debug_assert!(&key.to_compact() != dirty_key);
|
||||
}
|
||||
|
||||
debug_assert!(!self.pending_zero_data_pages.contains(&key.to_compact()))
|
||||
assert!(!self
|
||||
.pending_data_batch
|
||||
.as_ref()
|
||||
.map_or(false, |b| b.updates_key(&key)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1895,18 +2036,10 @@ impl<'a> DatadirModification<'a> {
|
||||
}
|
||||
|
||||
fn put_data(&mut self, key: CompactKey, val: Value) {
|
||||
let val_serialized_size = val.serialized_size().unwrap() as usize;
|
||||
|
||||
// If this page was previously zero'd in the same WalRecord, then drop the previous zero page write. This
|
||||
// is an optimization that avoids persisting both the zero page generated by us (e.g. during a relation extend),
|
||||
// and the subsequent postgres-originating write
|
||||
if self.pending_zero_data_pages.remove(&key) {
|
||||
self.pending_bytes -= ZERO_PAGE.len();
|
||||
}
|
||||
|
||||
self.pending_bytes += val_serialized_size;
|
||||
self.pending_data_pages
|
||||
.push((key, self.lsn, val_serialized_size, val))
|
||||
let batch = self
|
||||
.pending_data_batch
|
||||
.get_or_insert_with(SerializedValueBatch::default);
|
||||
batch.put(key, val, self.lsn);
|
||||
}
|
||||
|
||||
fn put_metadata(&mut self, key: CompactKey, val: Value) {
|
||||
@@ -1914,10 +2047,10 @@ impl<'a> DatadirModification<'a> {
|
||||
// Replace the previous value if it exists at the same lsn
|
||||
if let Some((last_lsn, last_value_ser_size, last_value)) = values.last_mut() {
|
||||
if *last_lsn == self.lsn {
|
||||
// Update the pending_bytes contribution from this entry, and update the serialized size in place
|
||||
self.pending_bytes -= *last_value_ser_size;
|
||||
// Update the pending_metadata_bytes contribution from this entry, and update the serialized size in place
|
||||
self.pending_metadata_bytes -= *last_value_ser_size;
|
||||
*last_value_ser_size = val.serialized_size().unwrap() as usize;
|
||||
self.pending_bytes += *last_value_ser_size;
|
||||
self.pending_metadata_bytes += *last_value_ser_size;
|
||||
|
||||
// Use the latest value, this replaces any earlier write to the same (key,lsn), such as much
|
||||
// have been generated by synthesized zero page writes prior to the first real write to a page.
|
||||
@@ -1927,8 +2060,12 @@ impl<'a> DatadirModification<'a> {
|
||||
}
|
||||
|
||||
let val_serialized_size = val.serialized_size().unwrap() as usize;
|
||||
self.pending_bytes += val_serialized_size;
|
||||
self.pending_metadata_bytes += val_serialized_size;
|
||||
values.push((self.lsn, val_serialized_size, val));
|
||||
|
||||
if key == CHECKPOINT_KEY.to_compact() {
|
||||
tracing::debug!("Checkpoint key added to pending with size {val_serialized_size}");
|
||||
}
|
||||
}
|
||||
|
||||
fn delete(&mut self, key_range: Range<Key>) {
|
||||
@@ -2037,7 +2174,11 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hex_literal::hex;
|
||||
use utils::id::TimelineId;
|
||||
use pageserver_api::{models::ShardParameters, shard::ShardStripeSize};
|
||||
use utils::{
|
||||
id::TimelineId,
|
||||
shard::{ShardCount, ShardNumber},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -2091,6 +2232,93 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gap_finding() {
|
||||
let rel = RelTag {
|
||||
spcnode: 1663,
|
||||
dbnode: 208101,
|
||||
relnode: 2620,
|
||||
forknum: 0,
|
||||
};
|
||||
let base_blkno = 1;
|
||||
|
||||
let base_key = rel_block_to_key(rel, base_blkno);
|
||||
let before_base_key = rel_block_to_key(rel, base_blkno - 1);
|
||||
|
||||
let shard = ShardIdentity::unsharded();
|
||||
|
||||
let mut previous_nblocks = 0;
|
||||
for i in 0..10 {
|
||||
let crnt_blkno = base_blkno + i;
|
||||
let gaps = DatadirModification::find_gaps(rel, crnt_blkno, previous_nblocks, &shard);
|
||||
|
||||
previous_nblocks = crnt_blkno + 1;
|
||||
|
||||
if i == 0 {
|
||||
// The first block we write is 1, so we should find the gap.
|
||||
assert_eq!(gaps.unwrap(), KeySpace::single(before_base_key..base_key));
|
||||
} else {
|
||||
assert!(gaps.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
// This is an update to an already existing block. No gaps here.
|
||||
let update_blkno = 5;
|
||||
let gaps = DatadirModification::find_gaps(rel, update_blkno, previous_nblocks, &shard);
|
||||
assert!(gaps.is_none());
|
||||
|
||||
// This is an update past the current end block.
|
||||
let after_gap_blkno = 20;
|
||||
let gaps = DatadirModification::find_gaps(rel, after_gap_blkno, previous_nblocks, &shard);
|
||||
|
||||
let gap_start_key = rel_block_to_key(rel, previous_nblocks);
|
||||
let after_gap_key = rel_block_to_key(rel, after_gap_blkno);
|
||||
assert_eq!(
|
||||
gaps.unwrap(),
|
||||
KeySpace::single(gap_start_key..after_gap_key)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sharded_gap_finding() {
|
||||
let rel = RelTag {
|
||||
spcnode: 1663,
|
||||
dbnode: 208101,
|
||||
relnode: 2620,
|
||||
forknum: 0,
|
||||
};
|
||||
|
||||
let first_blkno = 6;
|
||||
|
||||
// This shard will get the even blocks
|
||||
let shard = ShardIdentity::from_params(
|
||||
ShardNumber(0),
|
||||
&ShardParameters {
|
||||
count: ShardCount(2),
|
||||
stripe_size: ShardStripeSize(1),
|
||||
},
|
||||
);
|
||||
|
||||
// Only keys belonging to this shard are considered as gaps.
|
||||
let mut previous_nblocks = 0;
|
||||
let gaps =
|
||||
DatadirModification::find_gaps(rel, first_blkno, previous_nblocks, &shard).unwrap();
|
||||
assert!(!gaps.ranges.is_empty());
|
||||
for gap_range in gaps.ranges {
|
||||
let mut k = gap_range.start;
|
||||
while k != gap_range.end {
|
||||
assert_eq!(shard.get_shard_number(&k), shard.number);
|
||||
k = k.next();
|
||||
}
|
||||
}
|
||||
|
||||
previous_nblocks = first_blkno;
|
||||
|
||||
let update_blkno = 2;
|
||||
let gaps = DatadirModification::find_gaps(rel, update_blkno, previous_nblocks, &shard);
|
||||
assert!(gaps.is_none());
|
||||
}
|
||||
|
||||
/*
|
||||
fn assert_current_logical_size<R: Repository>(timeline: &DatadirTimeline<R>, lsn: Lsn) {
|
||||
let incremental = timeline.get_current_logical_size();
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::tenant::timeline::GetVectoredError;
|
||||
use crate::tenant::PageReconstructError;
|
||||
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
|
||||
use crate::{l0_flush, page_cache};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{anyhow, Result};
|
||||
use camino::Utf8PathBuf;
|
||||
use pageserver_api::key::CompactKey;
|
||||
use pageserver_api::key::Key;
|
||||
@@ -25,6 +25,7 @@ use std::sync::{Arc, OnceLock};
|
||||
use std::time::Instant;
|
||||
use tracing::*;
|
||||
use utils::{bin_ser::BeSer, id::TimelineId, lsn::Lsn, vec_map::VecMap};
|
||||
use wal_decoder::serialized_batch::{SerializedValueBatch, SerializedValueMeta, ValueMeta};
|
||||
// avoid binding to Write (conflicts with std::io::Write)
|
||||
// while being able to use std::fmt::Write's methods
|
||||
use crate::metrics::TIMELINE_EPHEMERAL_BYTES;
|
||||
@@ -452,6 +453,7 @@ impl InMemoryLayer {
|
||||
len,
|
||||
will_init,
|
||||
} = index_entry.unpack();
|
||||
|
||||
reads.entry(key).or_default().push(ValueRead {
|
||||
entry_lsn: *entry_lsn,
|
||||
read: vectored_dio_read::LogicalRead::new(
|
||||
@@ -513,68 +515,6 @@ impl InMemoryLayer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Offset of a particular Value within a serialized batch.
|
||||
struct SerializedBatchOffset {
|
||||
key: CompactKey,
|
||||
lsn: Lsn,
|
||||
// TODO: separate type when we start serde-serializing this value, to avoid coupling
|
||||
// in-memory representation to serialization format.
|
||||
index_entry: IndexEntry,
|
||||
}
|
||||
|
||||
pub struct SerializedBatch {
|
||||
/// Blobs serialized in EphemeralFile's native format, ready for passing to [`EphemeralFile::write_raw`].
|
||||
pub(crate) raw: Vec<u8>,
|
||||
|
||||
/// Index of values in [`Self::raw`], using offsets relative to the start of the buffer.
|
||||
offsets: Vec<SerializedBatchOffset>,
|
||||
|
||||
/// The highest LSN of any value in the batch
|
||||
pub(crate) max_lsn: Lsn,
|
||||
}
|
||||
|
||||
impl SerializedBatch {
|
||||
pub fn from_values(batch: Vec<(CompactKey, Lsn, usize, Value)>) -> anyhow::Result<Self> {
|
||||
// Pre-allocate a big flat buffer to write into. This should be large but not huge: it is soft-limited in practice by
|
||||
// [`crate::pgdatadir_mapping::DatadirModification::MAX_PENDING_BYTES`]
|
||||
let buffer_size = batch.iter().map(|i| i.2).sum::<usize>();
|
||||
let mut cursor = std::io::Cursor::new(Vec::<u8>::with_capacity(buffer_size));
|
||||
|
||||
let mut offsets: Vec<SerializedBatchOffset> = Vec::with_capacity(batch.len());
|
||||
let mut max_lsn: Lsn = Lsn(0);
|
||||
for (key, lsn, val_ser_size, val) in batch {
|
||||
let relative_off = cursor.position();
|
||||
|
||||
val.ser_into(&mut cursor)
|
||||
.expect("Writing into in-memory buffer is infallible");
|
||||
|
||||
offsets.push(SerializedBatchOffset {
|
||||
key,
|
||||
lsn,
|
||||
index_entry: IndexEntry::new(IndexEntryNewArgs {
|
||||
base_offset: 0,
|
||||
batch_offset: relative_off,
|
||||
len: val_ser_size,
|
||||
will_init: val.will_init(),
|
||||
})
|
||||
.context("higher-level code ensures that values are within supported ranges")?,
|
||||
});
|
||||
max_lsn = std::cmp::max(max_lsn, lsn);
|
||||
}
|
||||
|
||||
let buffer = cursor.into_inner();
|
||||
|
||||
// Assert that we didn't do any extra allocations while building buffer.
|
||||
debug_assert!(buffer.len() <= buffer_size);
|
||||
|
||||
Ok(Self {
|
||||
raw: buffer,
|
||||
offsets,
|
||||
max_lsn,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn inmem_layer_display(mut f: impl Write, start_lsn: Lsn, end_lsn: Lsn) -> std::fmt::Result {
|
||||
write!(f, "inmem-{:016X}-{:016X}", start_lsn.0, end_lsn.0)
|
||||
}
|
||||
@@ -642,7 +582,7 @@ impl InMemoryLayer {
|
||||
/// TODO: it can be made retryable if we aborted the process on EphemeralFile write errors.
|
||||
pub async fn put_batch(
|
||||
&self,
|
||||
serialized_batch: SerializedBatch,
|
||||
serialized_batch: SerializedValueBatch,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut inner = self.inner.write().await;
|
||||
@@ -650,27 +590,13 @@ impl InMemoryLayer {
|
||||
|
||||
let base_offset = inner.file.len();
|
||||
|
||||
let SerializedBatch {
|
||||
let SerializedValueBatch {
|
||||
raw,
|
||||
mut offsets,
|
||||
metadata,
|
||||
max_lsn: _,
|
||||
len: _,
|
||||
} = serialized_batch;
|
||||
|
||||
// Add the base_offset to the batch's index entries which are relative to the batch start.
|
||||
for offset in &mut offsets {
|
||||
let IndexEntryUnpacked {
|
||||
will_init,
|
||||
len,
|
||||
pos,
|
||||
} = offset.index_entry.unpack();
|
||||
offset.index_entry = IndexEntry::new(IndexEntryNewArgs {
|
||||
base_offset,
|
||||
batch_offset: pos,
|
||||
len: len.into_usize(),
|
||||
will_init,
|
||||
})?;
|
||||
}
|
||||
|
||||
// Write the batch to the file
|
||||
inner.file.write_raw(&raw, ctx).await?;
|
||||
let new_size = inner.file.len();
|
||||
@@ -683,12 +609,28 @@ impl InMemoryLayer {
|
||||
assert_eq!(new_size, expected_new_len);
|
||||
|
||||
// Update the index with the new entries
|
||||
for SerializedBatchOffset {
|
||||
key,
|
||||
lsn,
|
||||
index_entry,
|
||||
} in offsets
|
||||
{
|
||||
for meta in metadata {
|
||||
let SerializedValueMeta {
|
||||
key,
|
||||
lsn,
|
||||
batch_offset,
|
||||
len,
|
||||
will_init,
|
||||
} = match meta {
|
||||
ValueMeta::Serialized(ser) => ser,
|
||||
ValueMeta::Observed(_) => {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Add the base_offset to the batch's index entries which are relative to the batch start.
|
||||
let index_entry = IndexEntry::new(IndexEntryNewArgs {
|
||||
base_offset,
|
||||
batch_offset,
|
||||
len,
|
||||
will_init,
|
||||
})?;
|
||||
|
||||
let vec_map = inner.index.entry(key).or_default();
|
||||
let old = vec_map.append_or_update_last(lsn, index_entry).unwrap().0;
|
||||
if old.is_some() {
|
||||
|
||||
@@ -24,8 +24,8 @@ use offload::OffloadError;
|
||||
use once_cell::sync::Lazy;
|
||||
use pageserver_api::{
|
||||
key::{
|
||||
CompactKey, KEY_SIZE, METADATA_KEY_BEGIN_PREFIX, METADATA_KEY_END_PREFIX,
|
||||
NON_INHERITED_RANGE, NON_INHERITED_SPARSE_RANGE,
|
||||
KEY_SIZE, METADATA_KEY_BEGIN_PREFIX, METADATA_KEY_END_PREFIX, NON_INHERITED_RANGE,
|
||||
NON_INHERITED_SPARSE_RANGE,
|
||||
},
|
||||
keyspace::{KeySpaceAccum, KeySpaceRandomAccum, SparseKeyPartitioning},
|
||||
models::{
|
||||
@@ -49,6 +49,7 @@ use utils::{
|
||||
fs_ext, pausable_failpoint,
|
||||
sync::gate::{Gate, GateGuard},
|
||||
};
|
||||
use wal_decoder::serialized_batch::SerializedValueBatch;
|
||||
|
||||
use std::sync::atomic::Ordering as AtomicOrdering;
|
||||
use std::sync::{Arc, Mutex, RwLock, Weak};
|
||||
@@ -131,7 +132,6 @@ use crate::task_mgr::TaskKind;
|
||||
use crate::tenant::gc_result::GcResult;
|
||||
use crate::ZERO_PAGE;
|
||||
use pageserver_api::key::Key;
|
||||
use pageserver_api::value::Value;
|
||||
|
||||
use self::delete::DeleteTimelineFlow;
|
||||
pub(super) use self::eviction_task::EvictionTaskTenantState;
|
||||
@@ -141,9 +141,7 @@ use self::logical_size::LogicalSize;
|
||||
use self::walreceiver::{WalReceiver, WalReceiverConf};
|
||||
|
||||
use super::{
|
||||
config::TenantConf,
|
||||
storage_layer::{inmemory_layer, LayerVisibilityHint},
|
||||
upload_queue::NotInitialized,
|
||||
config::TenantConf, storage_layer::LayerVisibilityHint, upload_queue::NotInitialized,
|
||||
MaybeOffloaded,
|
||||
};
|
||||
use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf};
|
||||
@@ -157,6 +155,9 @@ use super::{
|
||||
GcError,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
use pageserver_api::value::Value;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub(crate) enum FlushLoopState {
|
||||
NotStarted,
|
||||
@@ -5736,23 +5737,22 @@ impl<'a> TimelineWriter<'a> {
|
||||
/// Put a batch of keys at the specified Lsns.
|
||||
pub(crate) async fn put_batch(
|
||||
&mut self,
|
||||
batch: Vec<(CompactKey, Lsn, usize, Value)>,
|
||||
batch: SerializedValueBatch,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<()> {
|
||||
if batch.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let serialized_batch = inmemory_layer::SerializedBatch::from_values(batch)?;
|
||||
let batch_max_lsn = serialized_batch.max_lsn;
|
||||
let buf_size: u64 = serialized_batch.raw.len() as u64;
|
||||
let batch_max_lsn = batch.max_lsn;
|
||||
let buf_size: u64 = batch.buffer_size() as u64;
|
||||
|
||||
let action = self.get_open_layer_action(batch_max_lsn, buf_size);
|
||||
let layer = self
|
||||
.handle_open_layer_action(batch_max_lsn, action, ctx)
|
||||
.await?;
|
||||
|
||||
let res = layer.put_batch(serialized_batch, ctx).await;
|
||||
let res = layer.put_batch(batch, ctx).await;
|
||||
|
||||
if res.is_ok() {
|
||||
// Update the current size only when the entire write was ok.
|
||||
@@ -5787,11 +5787,14 @@ impl<'a> TimelineWriter<'a> {
|
||||
);
|
||||
}
|
||||
let val_ser_size = value.serialized_size().unwrap() as usize;
|
||||
self.put_batch(
|
||||
vec![(key.to_compact(), lsn, val_ser_size, value.clone())],
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
let batch = SerializedValueBatch::from_values(vec![(
|
||||
key.to_compact(),
|
||||
lsn,
|
||||
val_ser_size,
|
||||
value.clone(),
|
||||
)]);
|
||||
|
||||
self.put_batch(batch, ctx).await
|
||||
}
|
||||
|
||||
pub(crate) async fn delete_batch(
|
||||
|
||||
@@ -36,7 +36,9 @@ use postgres_connection::PgConnectionConfig;
|
||||
use utils::backoff::{
|
||||
exponential_backoff, DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS,
|
||||
};
|
||||
use utils::postgres_client::wal_stream_connection_config;
|
||||
use utils::postgres_client::{
|
||||
wal_stream_connection_config, ConnectionConfigArgs, PAGESERVER_SAFEKEEPER_PROTO_VERSION, POSTGRES_PROTO_VERSION,
|
||||
};
|
||||
use utils::{
|
||||
id::{NodeId, TenantTimelineId},
|
||||
lsn::Lsn,
|
||||
@@ -984,15 +986,29 @@ impl ConnectionManagerState {
|
||||
if info.safekeeper_connstr.is_empty() {
|
||||
return None; // no connection string, ignore sk
|
||||
}
|
||||
match wal_stream_connection_config(
|
||||
self.id,
|
||||
info.safekeeper_connstr.as_ref(),
|
||||
match &self.conf.auth_token {
|
||||
None => None,
|
||||
Some(x) => Some(x),
|
||||
},
|
||||
self.conf.availability_zone.as_deref(),
|
||||
) {
|
||||
|
||||
let shard_identity = self.timeline.get_shard_identity();
|
||||
let connection_conf_args = ConnectionConfigArgs {
|
||||
protocol_version: PAGESERVER_SAFEKEEPER_PROTO_VERSION,
|
||||
ttid: self.id,
|
||||
shard_number: Some(shard_identity.number.0),
|
||||
shard_count: Some(shard_identity.count.0),
|
||||
shard_stripe_size: Some(shard_identity.stripe_size.0),
|
||||
listen_pg_addr_str: info.safekeeper_connstr.as_ref(),
|
||||
auth_token: self.conf.auth_token.as_ref().map(|t| t.as_str()),
|
||||
availability_zone: self.conf.availability_zone.as_deref()
|
||||
};
|
||||
// let connection_conf_args = ConnectionConfigArgs {
|
||||
// protocol_version: POSTGRES_PROTO_VERSION,
|
||||
// ttid: self.id,
|
||||
// shard_number: None,
|
||||
// shard_count: None,
|
||||
// shard_stripe_size: None,
|
||||
// listen_pg_addr_str: info.safekeeper_connstr.as_ref(),
|
||||
// auth_token: self.conf.auth_token.as_ref().map(|t| t.as_str()),
|
||||
// availability_zone: self.conf.availability_zone.as_deref()
|
||||
// };
|
||||
match wal_stream_connection_config(connection_conf_args) {
|
||||
Ok(connstr) => Some((*sk_id, info, connstr)),
|
||||
Err(e) => {
|
||||
error!("Failed to create wal receiver connection string from broker data of safekeeper node {}: {e:#}", sk_id);
|
||||
|
||||
@@ -36,7 +36,7 @@ use crate::{
|
||||
use postgres_backend::is_expected_io_error;
|
||||
use postgres_connection::PgConnectionConfig;
|
||||
use postgres_ffi::waldecoder::WalStreamDecoder;
|
||||
use utils::{id::NodeId, lsn::Lsn};
|
||||
use utils::{bin_ser::BeSer, id::NodeId, lsn::Lsn};
|
||||
use utils::{pageserver_feedback::PageserverFeedback, sync::gate::GateError};
|
||||
|
||||
/// Status of the connection.
|
||||
@@ -278,6 +278,7 @@ pub(super) async fn handle_walreceiver_connection(
|
||||
// fails (e.g. in walingest), we still want to know latests LSNs from the safekeeper.
|
||||
match &replication_message {
|
||||
ReplicationMessage::XLogData(xlog_data) => {
|
||||
// TODO(vlad) Is this crap needed?
|
||||
connection_status.latest_connection_update = now;
|
||||
connection_status.commit_lsn = Some(Lsn::from(xlog_data.wal_end()));
|
||||
connection_status.streaming_lsn = Some(Lsn::from(
|
||||
@@ -299,6 +300,24 @@ pub(super) async fn handle_walreceiver_connection(
|
||||
}
|
||||
|
||||
let status_update = match replication_message {
|
||||
ReplicationMessage::RawInterpretedWalRecord(raw) => {
|
||||
connection_status.latest_connection_update = now;
|
||||
connection_status.latest_wal_update = now;
|
||||
connection_status.commit_lsn = Some(Lsn::from(raw.wal_end()));
|
||||
|
||||
let interpreted = InterpretedWalRecord::des(raw.data()).unwrap();
|
||||
let end_lsn = interpreted.end_lsn;
|
||||
|
||||
let mut modification = timeline.begin_modification(end_lsn);
|
||||
walingest
|
||||
.ingest_record(interpreted, &mut modification, &ctx)
|
||||
.await
|
||||
.with_context(|| format!("could not ingest record at {}", end_lsn))?;
|
||||
modification.commit(&ctx).await?;
|
||||
|
||||
Some(end_lsn)
|
||||
}
|
||||
|
||||
ReplicationMessage::XLogData(xlog_data) => {
|
||||
// Pass the WAL data to the decoder, and see if we can decode
|
||||
// more records as a result.
|
||||
@@ -331,11 +350,11 @@ pub(super) async fn handle_walreceiver_connection(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
while let Some((lsn, recdata)) = waldecoder.poll_decode()? {
|
||||
while let Some((record_end_lsn, recdata)) = waldecoder.poll_decode()? {
|
||||
// It is important to deal with the aligned records as lsn in getPage@LSN is
|
||||
// aligned and can be several bytes bigger. Without this alignment we are
|
||||
// at risk of hitting a deadlock.
|
||||
if !lsn.is_aligned() {
|
||||
if !record_end_lsn.is_aligned() {
|
||||
return Err(WalReceiverError::Other(anyhow!("LSN not aligned")));
|
||||
}
|
||||
|
||||
@@ -343,7 +362,7 @@ pub(super) async fn handle_walreceiver_connection(
|
||||
let interpreted = InterpretedWalRecord::from_bytes_filtered(
|
||||
recdata,
|
||||
modification.tline.get_shard_identity(),
|
||||
lsn,
|
||||
record_end_lsn,
|
||||
modification.tline.pg_version,
|
||||
)?;
|
||||
|
||||
@@ -366,9 +385,11 @@ pub(super) async fn handle_walreceiver_connection(
|
||||
let ingested = walingest
|
||||
.ingest_record(interpreted, &mut modification, &ctx)
|
||||
.await
|
||||
.with_context(|| format!("could not ingest record at {lsn}"))?;
|
||||
.with_context(|| {
|
||||
format!("could not ingest record at {record_end_lsn}")
|
||||
})?;
|
||||
if !ingested {
|
||||
tracing::debug!("ingest: filtered out record @ LSN {lsn}");
|
||||
tracing::debug!("ingest: filtered out record @ LSN {record_end_lsn}");
|
||||
WAL_INGEST.records_filtered.inc();
|
||||
filtered_records += 1;
|
||||
}
|
||||
@@ -378,7 +399,7 @@ pub(super) async fn handle_walreceiver_connection(
|
||||
// to timeout the tests.
|
||||
fail_point!("walreceiver-after-ingest");
|
||||
|
||||
last_rec_lsn = lsn;
|
||||
last_rec_lsn = record_end_lsn;
|
||||
|
||||
// Commit every ingest_batch_size records. Even if we filtered out
|
||||
// all records, we still need to call commit to advance the LSN.
|
||||
|
||||
@@ -28,14 +28,13 @@ use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use pageserver_api::key::Key;
|
||||
use pageserver_api::shard::ShardIdentity;
|
||||
use postgres_ffi::fsm_logical_to_physical;
|
||||
use postgres_ffi::walrecord::*;
|
||||
use postgres_ffi::{dispatch_pgversion, enum_pgversion, enum_pgversion_dispatch, TimestampTz};
|
||||
use wal_decoder::models::*;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use anyhow::{bail, Result};
|
||||
use bytes::{Buf, Bytes};
|
||||
use tracing::*;
|
||||
use utils::failpoint_support;
|
||||
@@ -51,7 +50,6 @@ use crate::ZERO_PAGE;
|
||||
use pageserver_api::key::rel_block_to_key;
|
||||
use pageserver_api::record::NeonWalRecord;
|
||||
use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
|
||||
use pageserver_api::value::Value;
|
||||
use postgres_ffi::pg_constants;
|
||||
use postgres_ffi::relfile_utils::{FSM_FORKNUM, INIT_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM};
|
||||
use postgres_ffi::TransactionId;
|
||||
@@ -156,12 +154,12 @@ impl WalIngest {
|
||||
WAL_INGEST.records_received.inc();
|
||||
let prev_len = modification.len();
|
||||
|
||||
modification.set_lsn(interpreted.lsn)?;
|
||||
modification.set_lsn(interpreted.end_lsn)?;
|
||||
|
||||
if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes) {
|
||||
// Records of this type should always be preceded by a commit(), as they
|
||||
// rely on reading data pages back from the Timeline.
|
||||
assert!(!modification.has_dirty_data_pages());
|
||||
assert!(!modification.has_dirty_data());
|
||||
}
|
||||
|
||||
assert!(!self.checkpoint_modified);
|
||||
@@ -275,28 +273,9 @@ impl WalIngest {
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate through all the key value pairs provided in the interpreted block
|
||||
// and update the modification currently in-flight to include them.
|
||||
for (compact_key, maybe_value) in interpreted.blocks.into_iter() {
|
||||
let (rel, blk) = Key::from_compact(compact_key).to_rel_block()?;
|
||||
match maybe_value {
|
||||
Some(Value::Image(img)) => {
|
||||
self.put_rel_page_image(modification, rel, blk, img, ctx)
|
||||
.await?;
|
||||
}
|
||||
Some(Value::WalRecord(rec)) => {
|
||||
self.put_rel_wal_record(modification, rel, blk, rec, ctx)
|
||||
.await?;
|
||||
}
|
||||
None => {
|
||||
// Shard 0 tracks relation sizes. We will observe
|
||||
// its blkno in case it implicitly extends a relation.
|
||||
assert!(self.shard.is_shard_zero());
|
||||
self.observe_decoded_block(modification, rel, blk, ctx)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
modification
|
||||
.ingest_batch(interpreted.batch, &self.shard, ctx)
|
||||
.await?;
|
||||
|
||||
// If checkpoint data was updated, store the new version in the repository
|
||||
if self.checkpoint_modified {
|
||||
@@ -310,8 +289,6 @@ impl WalIngest {
|
||||
// until commit() is called to flush the data into the repository and update
|
||||
// the latest LSN.
|
||||
|
||||
modification.on_record_end();
|
||||
|
||||
Ok(modification.len() > prev_len)
|
||||
}
|
||||
|
||||
@@ -334,17 +311,6 @@ impl WalIngest {
|
||||
Ok((epoch as u64) << 32 | xid as u64)
|
||||
}
|
||||
|
||||
/// Do not store this block, but observe it for the purposes of updating our relation size state.
|
||||
async fn observe_decoded_block(
|
||||
&mut self,
|
||||
modification: &mut DatadirModification<'_>,
|
||||
rel: RelTag,
|
||||
blkno: BlockNumber,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<(), PageReconstructError> {
|
||||
self.handle_rel_extend(modification, rel, blkno, ctx).await
|
||||
}
|
||||
|
||||
async fn ingest_clear_vm_bits(
|
||||
&mut self,
|
||||
clear_vm_bits: ClearVmBits,
|
||||
@@ -1248,6 +1214,7 @@ impl WalIngest {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn put_rel_page_image(
|
||||
&mut self,
|
||||
modification: &mut DatadirModification<'_>,
|
||||
@@ -1297,36 +1264,7 @@ impl WalIngest {
|
||||
let new_nblocks = blknum + 1;
|
||||
// Check if the relation exists. We implicitly create relations on first
|
||||
// record.
|
||||
// TODO: would be nice if to be more explicit about it
|
||||
|
||||
// Get current size and put rel creation if rel doesn't exist
|
||||
//
|
||||
// NOTE: we check the cache first even though get_rel_exists and get_rel_size would
|
||||
// check the cache too. This is because eagerly checking the cache results in
|
||||
// less work overall and 10% better performance. It's more work on cache miss
|
||||
// but cache miss is rare.
|
||||
let old_nblocks = if let Some(nblocks) = modification
|
||||
.tline
|
||||
.get_cached_rel_size(&rel, modification.get_lsn())
|
||||
{
|
||||
nblocks
|
||||
} else if !modification
|
||||
.tline
|
||||
.get_rel_exists(rel, Version::Modified(modification), ctx)
|
||||
.await?
|
||||
{
|
||||
// create it with 0 size initially, the logic below will extend it
|
||||
modification
|
||||
.put_rel_creation(rel, 0, ctx)
|
||||
.await
|
||||
.context("Relation Error")?;
|
||||
0
|
||||
} else {
|
||||
modification
|
||||
.tline
|
||||
.get_rel_size(rel, Version::Modified(modification), ctx)
|
||||
.await?
|
||||
};
|
||||
let old_nblocks = modification.create_relation_if_required(rel, ctx).await?;
|
||||
|
||||
if new_nblocks > old_nblocks {
|
||||
//info!("extending {} {} to {}", rel, old_nblocks, new_nblocks);
|
||||
@@ -1553,25 +1491,21 @@ mod tests {
|
||||
walingest
|
||||
.put_rel_page_image(&mut m, TESTREL_A, 0, test_img("foo blk 0 at 2"), &ctx)
|
||||
.await?;
|
||||
m.on_record_end();
|
||||
m.commit(&ctx).await?;
|
||||
let mut m = tline.begin_modification(Lsn(0x30));
|
||||
walingest
|
||||
.put_rel_page_image(&mut m, TESTREL_A, 0, test_img("foo blk 0 at 3"), &ctx)
|
||||
.await?;
|
||||
m.on_record_end();
|
||||
m.commit(&ctx).await?;
|
||||
let mut m = tline.begin_modification(Lsn(0x40));
|
||||
walingest
|
||||
.put_rel_page_image(&mut m, TESTREL_A, 1, test_img("foo blk 1 at 4"), &ctx)
|
||||
.await?;
|
||||
m.on_record_end();
|
||||
m.commit(&ctx).await?;
|
||||
let mut m = tline.begin_modification(Lsn(0x50));
|
||||
walingest
|
||||
.put_rel_page_image(&mut m, TESTREL_A, 2, test_img("foo blk 2 at 5"), &ctx)
|
||||
.await?;
|
||||
m.on_record_end();
|
||||
m.commit(&ctx).await?;
|
||||
|
||||
assert_current_logical_size(&tline, Lsn(0x50));
|
||||
@@ -1713,7 +1647,6 @@ mod tests {
|
||||
walingest
|
||||
.put_rel_page_image(&mut m, TESTREL_A, 1, test_img("foo blk 1"), &ctx)
|
||||
.await?;
|
||||
m.on_record_end();
|
||||
m.commit(&ctx).await?;
|
||||
assert_eq!(
|
||||
tline
|
||||
@@ -1739,7 +1672,6 @@ mod tests {
|
||||
walingest
|
||||
.put_rel_page_image(&mut m, TESTREL_A, 1500, test_img("foo blk 1500"), &ctx)
|
||||
.await?;
|
||||
m.on_record_end();
|
||||
m.commit(&ctx).await?;
|
||||
assert_eq!(
|
||||
tline
|
||||
|
||||
@@ -23,7 +23,7 @@ bstr.workspace = true
|
||||
bytes = { workspace = true, features = ["serde"] }
|
||||
camino.workspace = true
|
||||
chrono.workspace = true
|
||||
clap.workspace = true
|
||||
clap = { workspace = true, features = ["derive", "env"] }
|
||||
compute_api.workspace = true
|
||||
consumption_metrics.workspace = true
|
||||
dashmap.workspace = true
|
||||
@@ -98,6 +98,7 @@ rustls-native-certs.workspace = true
|
||||
x509-parser.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
redis.workspace = true
|
||||
zerocopy.workspace = true
|
||||
|
||||
# jwt stuff
|
||||
jose-jwa = "0.1.2"
|
||||
|
||||
@@ -9,15 +9,14 @@ use super::ComputeCredentialKeys;
|
||||
use crate::cache::Cached;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::provider::NodeInfo;
|
||||
use crate::control_plane::{self, CachedNodeInfo};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::stream::PqStream;
|
||||
use crate::{auth, compute, waiters};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum WebAuthError {
|
||||
pub(crate) enum ConsoleRedirectError {
|
||||
#[error(transparent)]
|
||||
WaiterRegister(#[from] waiters::RegisterError),
|
||||
|
||||
@@ -33,13 +32,13 @@ pub struct ConsoleRedirectBackend {
|
||||
console_uri: reqwest::Url,
|
||||
}
|
||||
|
||||
impl UserFacingError for WebAuthError {
|
||||
impl UserFacingError for ConsoleRedirectError {
|
||||
fn to_string_client(&self) -> String {
|
||||
"Internal error".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for WebAuthError {
|
||||
impl ReportableError for ConsoleRedirectError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Self::WaiterRegister(_) => crate::error::ErrorKind::Service,
|
||||
@@ -104,7 +103,7 @@ async fn authenticate(
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<NodeInfo> {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Web);
|
||||
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
|
||||
|
||||
// registering waiter can fail if we get unlucky with rng.
|
||||
// just try again.
|
||||
@@ -117,7 +116,7 @@ async fn authenticate(
|
||||
}
|
||||
};
|
||||
|
||||
let span = info_span!("web", psql_session_id = &psql_session_id);
|
||||
let span = info_span!("console_redirect", psql_session_id = &psql_session_id);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
|
||||
// Give user a URL to spawn a new database.
|
||||
@@ -128,14 +127,16 @@ async fn authenticate(
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
|
||||
// Wait for web console response (see `mgmt`).
|
||||
// Wait for console response via control plane (see `mgmt`).
|
||||
info!(parent: &span, "waiting for console's reply...");
|
||||
let db_info = tokio::time::timeout(auth_config.webauth_confirmation_timeout, waiter)
|
||||
let db_info = tokio::time::timeout(auth_config.console_redirect_confirmation_timeout, waiter)
|
||||
.await
|
||||
.map_err(|_elapsed| {
|
||||
auth::AuthError::confirmation_timeout(auth_config.webauth_confirmation_timeout.into())
|
||||
auth::AuthError::confirmation_timeout(
|
||||
auth_config.console_redirect_confirmation_timeout.into(),
|
||||
)
|
||||
})?
|
||||
.map_err(WebAuthError::from)?;
|
||||
.map_err(ConsoleRedirectError::from)?;
|
||||
|
||||
if auth_config.ip_allowlist_check_enabled {
|
||||
if let Some(allowed_ips) = &db_info.allowed_ips {
|
||||
|
||||
@@ -9,7 +9,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
pub use console_redirect::ConsoleRedirectBackend;
|
||||
pub(crate) use console_redirect::WebAuthError;
|
||||
pub(crate) use console_redirect::ConsoleRedirectError;
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
use local::LocalBackend;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
@@ -21,11 +21,11 @@ use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserIn
|
||||
use crate::cache::Cached;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::client::ControlPlaneClient;
|
||||
use crate::control_plane::errors::GetAuthInfoError;
|
||||
use crate::control_plane::provider::{
|
||||
CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneBackend,
|
||||
use crate::control_plane::{
|
||||
self, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
|
||||
};
|
||||
use crate::control_plane::{self, Api, AuthSecret};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
@@ -62,42 +62,26 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
|
||||
/// backends which require them for the authentication process.
|
||||
pub enum Backend<'a, T> {
|
||||
/// Cloud API (V2).
|
||||
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
|
||||
ControlPlane(MaybeOwned<'a, ControlPlaneClient>, T),
|
||||
/// Local proxy uses configured auth credentials and does not wake compute
|
||||
Local(MaybeOwned<'a, LocalBackend>),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) trait TestBackend: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
|
||||
fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), control_plane::errors::GetAuthInfoError>;
|
||||
fn dyn_clone(&self) -> Box<dyn TestBackend>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Clone for Box<dyn TestBackend> {
|
||||
fn clone(&self) -> Self {
|
||||
TestBackend::dyn_clone(&**self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Backend<'_, ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ControlPlane(api, ()) => match &**api {
|
||||
ControlPlaneBackend::Management(endpoint) => fmt
|
||||
.debug_tuple("ControlPlane::Management")
|
||||
ControlPlaneClient::Neon(endpoint) => fmt
|
||||
.debug_tuple("ControlPlane::Neon")
|
||||
.field(&endpoint.url())
|
||||
.finish(),
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ControlPlaneBackend::PostgresMock(endpoint) => fmt
|
||||
ControlPlaneClient::PostgresMock(endpoint) => fmt
|
||||
.debug_tuple("ControlPlane::PostgresMock")
|
||||
.field(&endpoint.url())
|
||||
.finish(),
|
||||
#[cfg(test)]
|
||||
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
|
||||
ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
|
||||
},
|
||||
Self::Local(_) => fmt.debug_tuple("Local").finish(),
|
||||
}
|
||||
@@ -282,7 +266,7 @@ impl AuthenticationConfig {
|
||||
/// All authentication flows will emit an AuthenticationOk message if successful.
|
||||
async fn auth_quirks(
|
||||
ctx: &RequestMonitoring,
|
||||
api: &impl control_plane::Api,
|
||||
api: &impl control_plane::ControlPlaneApi,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
@@ -499,12 +483,12 @@ mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use control_plane::AuthSecret;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
|
||||
use postgres_protocol::message::backend::Message as PgMessage;
|
||||
use postgres_protocol::message::frontend;
|
||||
use provider::AuthSecret;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use super::jwt::JwkCache;
|
||||
@@ -513,8 +497,7 @@ mod tests {
|
||||
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::provider::{self, CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::control_plane::{self, CachedNodeInfo};
|
||||
use crate::control_plane::{self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret};
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
@@ -526,7 +509,7 @@ mod tests {
|
||||
secret: AuthSecret,
|
||||
}
|
||||
|
||||
impl control_plane::Api for Auth {
|
||||
impl control_plane::ControlPlaneApi for Auth {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
@@ -577,7 +560,7 @@ mod tests {
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: false,
|
||||
webauth_confirmation_timeout: std::time::Duration::from_secs(5),
|
||||
console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
|
||||
});
|
||||
|
||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||
|
||||
@@ -32,7 +32,7 @@ pub(crate) type Result<T> = std::result::Result<T, AuthError>;
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum AuthError {
|
||||
#[error(transparent)]
|
||||
Web(#[from] backend::WebAuthError),
|
||||
ConsoleRedirect(#[from] backend::ConsoleRedirectError),
|
||||
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] control_plane::errors::GetAuthInfoError),
|
||||
@@ -115,7 +115,7 @@ impl AuthError {
|
||||
impl UserFacingError for AuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
Self::Web(e) => e.to_string_client(),
|
||||
Self::ConsoleRedirect(e) => e.to_string_client(),
|
||||
Self::GetAuthInfo(e) => e.to_string_client(),
|
||||
Self::Sasl(e) => e.to_string_client(),
|
||||
Self::PasswordFailed(_) => self.to_string(),
|
||||
@@ -135,7 +135,7 @@ impl UserFacingError for AuthError {
|
||||
impl ReportableError for AuthError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Self::Web(e) => e.get_error_kind(),
|
||||
Self::ConsoleRedirect(e) => e.get_error_kind(),
|
||||
Self::GetAuthInfo(e) => e.get_error_kind(),
|
||||
Self::Sasl(e) => e.get_error_kind(),
|
||||
Self::PasswordFailed(_) => crate::error::ErrorKind::User,
|
||||
|
||||
@@ -281,7 +281,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: true,
|
||||
webauth_confirmation_timeout: Duration::ZERO,
|
||||
console_redirect_confirmation_timeout: Duration::ZERO,
|
||||
},
|
||||
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
|
||||
handshake_timeout: Duration::from_secs(10),
|
||||
|
||||
@@ -51,11 +51,11 @@ static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
||||
|
||||
#[derive(Clone, Debug, ValueEnum)]
|
||||
enum AuthBackendType {
|
||||
Console,
|
||||
// clap only shows the name, not the alias, in usage text.
|
||||
// TODO: swap name/alias and deprecate "link"
|
||||
#[value(name("link"), alias("web"))]
|
||||
Web,
|
||||
#[value(name("console"), alias("cplane"))]
|
||||
ControlPlane,
|
||||
|
||||
#[value(name("link"), alias("control-redirect"))]
|
||||
ConsoleRedirect,
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres,
|
||||
@@ -71,7 +71,7 @@ struct ProxyCliArgs {
|
||||
/// listen for incoming client connections on ip:port
|
||||
#[clap(short, long, default_value = "127.0.0.1:4432")]
|
||||
proxy: String,
|
||||
#[clap(value_enum, long, default_value_t = AuthBackendType::Web)]
|
||||
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
|
||||
auth_backend: AuthBackendType,
|
||||
/// listen for management callback connection on ip:port
|
||||
#[clap(short, long, default_value = "127.0.0.1:7000")]
|
||||
@@ -82,7 +82,7 @@ struct ProxyCliArgs {
|
||||
/// listen for incoming wss connections on ip:port
|
||||
#[clap(long)]
|
||||
wss: Option<String>,
|
||||
/// redirect unauthenticated users to the given uri in case of web auth
|
||||
/// redirect unauthenticated users to the given uri in case of console redirect auth
|
||||
#[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
|
||||
uri: String,
|
||||
/// cloud API endpoint for authenticating users
|
||||
@@ -92,6 +92,14 @@ struct ProxyCliArgs {
|
||||
default_value = "http://localhost:3000/authenticate_proxy_request/"
|
||||
)]
|
||||
auth_endpoint: String,
|
||||
/// JWT used to connect to control plane.
|
||||
#[clap(
|
||||
long,
|
||||
value_name = "JWT",
|
||||
default_value = "",
|
||||
env = "NEON_PROXY_TO_CONTROLPLANE_TOKEN"
|
||||
)]
|
||||
control_plane_token: Arc<str>,
|
||||
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
is_auth_broker: bool,
|
||||
@@ -223,6 +231,7 @@ struct ProxyCliArgs {
|
||||
proxy_protocol_v2: ProxyProtocolV2,
|
||||
|
||||
/// Time the proxy waits for the webauth session to be confirmed by the control plane.
|
||||
// TODO: rename to `console_redirect_confirmation_timeout`.
|
||||
#[clap(long, default_value = "2m", value_parser = humantime::parse_duration)]
|
||||
webauth_confirmation_timeout: std::time::Duration,
|
||||
}
|
||||
@@ -513,7 +522,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
if let Either::Left(auth::Backend::ControlPlane(api, _)) = &auth_backend {
|
||||
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
|
||||
if let proxy::control_plane::client::ControlPlaneClient::Neon(api) = &**api {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
(client1, client2) => {
|
||||
@@ -659,7 +668,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_auth_broker: args.is_auth_broker,
|
||||
accept_jwts: args.is_auth_broker,
|
||||
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
|
||||
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
|
||||
};
|
||||
|
||||
let config = ProxyConfig {
|
||||
@@ -690,7 +699,7 @@ fn build_auth_backend(
|
||||
args: &ProxyCliArgs,
|
||||
) -> anyhow::Result<Either<&'static auth::Backend<'static, ()>, &'static ConsoleRedirectBackend>> {
|
||||
match &args.auth_backend {
|
||||
AuthBackendType::Console => {
|
||||
AuthBackendType::ControlPlane => {
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
|
||||
let project_info_cache_config: ProjectInfoCacheOptions =
|
||||
args.project_info_cache.parse()?;
|
||||
@@ -732,13 +741,14 @@ fn build_auth_backend(
|
||||
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
|
||||
let wake_compute_endpoint_rate_limiter =
|
||||
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
|
||||
let api = control_plane::provider::neon::Api::new(
|
||||
let api = control_plane::client::neon::NeonControlPlaneClient::new(
|
||||
endpoint,
|
||||
args.control_plane_token.clone(),
|
||||
caches,
|
||||
locks,
|
||||
wake_compute_endpoint_rate_limiter,
|
||||
);
|
||||
let api = control_plane::provider::ControlPlaneBackend::Management(api);
|
||||
let api = control_plane::client::ControlPlaneClient::Neon(api);
|
||||
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
|
||||
|
||||
let config = Box::leak(Box::new(auth_backend));
|
||||
@@ -749,8 +759,11 @@ fn build_auth_backend(
|
||||
#[cfg(feature = "testing")]
|
||||
AuthBackendType::Postgres => {
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
|
||||
let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
|
||||
let api = control_plane::client::mock::MockControlPlane::new(
|
||||
url,
|
||||
!args.is_private_access_proxy,
|
||||
);
|
||||
let api = control_plane::client::ControlPlaneClient::PostgresMock(api);
|
||||
|
||||
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
|
||||
|
||||
@@ -759,7 +772,7 @@ fn build_auth_backend(
|
||||
Ok(Either::Left(config))
|
||||
}
|
||||
|
||||
AuthBackendType::Web => {
|
||||
AuthBackendType::ConsoleRedirect => {
|
||||
let url = args.uri.parse()?;
|
||||
let backend = ConsoleRedirectBackend::new(url);
|
||||
|
||||
|
||||
77
proxy/src/cache/endpoints.rs
vendored
77
proxy/src/cache/endpoints.rs
vendored
@@ -19,26 +19,28 @@ use crate::rate_limiter::GlobalRateLimiter;
|
||||
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::types::EndpointId;
|
||||
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
#[serde(tag = "type", rename_all(deserialize = "snake_case"))]
|
||||
enum ControlPlaneEvent {
|
||||
EndpointCreated { endpoint_created: EndpointCreated },
|
||||
BranchCreated { branch_created: BranchCreated },
|
||||
ProjectCreated { project_created: ProjectCreated },
|
||||
// TODO: this could be an enum, but events in Redis need to be fixed first.
|
||||
// ProjectCreated was sent with type:branch_created. So we ignore type.
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
struct ControlPlaneEvent {
|
||||
endpoint_created: Option<EndpointCreated>,
|
||||
branch_created: Option<BranchCreated>,
|
||||
project_created: Option<ProjectCreated>,
|
||||
#[serde(rename = "type")]
|
||||
_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
struct EndpointCreated {
|
||||
endpoint_id: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
struct BranchCreated {
|
||||
branch_id: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
struct ProjectCreated {
|
||||
project_id: String,
|
||||
}
|
||||
@@ -104,24 +106,28 @@ impl EndpointsCache {
|
||||
}
|
||||
|
||||
fn insert_event(&self, event: ControlPlaneEvent) {
|
||||
let counter = match event {
|
||||
ControlPlaneEvent::EndpointCreated { endpoint_created } => {
|
||||
self.endpoints
|
||||
.insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
|
||||
RedisEventsCount::EndpointCreated
|
||||
}
|
||||
ControlPlaneEvent::BranchCreated { branch_created } => {
|
||||
self.branches
|
||||
.insert(BranchIdInt::from(&branch_created.branch_id.into()));
|
||||
RedisEventsCount::BranchCreated
|
||||
}
|
||||
ControlPlaneEvent::ProjectCreated { project_created } => {
|
||||
self.projects
|
||||
.insert(ProjectIdInt::from(&project_created.project_id.into()));
|
||||
RedisEventsCount::ProjectCreated
|
||||
}
|
||||
};
|
||||
Metrics::get().proxy.redis_events_count.inc(counter);
|
||||
if let Some(endpoint_created) = event.endpoint_created {
|
||||
self.endpoints
|
||||
.insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::EndpointCreated);
|
||||
} else if let Some(branch_created) = event.branch_created {
|
||||
self.branches
|
||||
.insert(BranchIdInt::from(&branch_created.branch_id.into()));
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::BranchCreated);
|
||||
} else if let Some(project_created) = event.project_created {
|
||||
self.projects
|
||||
.insert(ProjectIdInt::from(&project_created.project_id.into()));
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::ProjectCreated);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_read(
|
||||
@@ -235,11 +241,22 @@ impl EndpointsCache {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ControlPlaneEvent;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_control_plane_event() {
|
||||
let s = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#;
|
||||
serde_json::from_str::<ControlPlaneEvent>(s).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
serde_json::from_str::<ControlPlaneEvent>(s).unwrap(),
|
||||
ControlPlaneEvent {
|
||||
endpoint_created: Some(EndpointCreated {
|
||||
endpoint_id: "ep-rapid-thunder-w0qqw2q9".into()
|
||||
}),
|
||||
branch_created: None,
|
||||
project_created: None,
|
||||
_type: Some("endpoint_created".into()),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,9 +19,9 @@ use tracing::{error, info, warn};
|
||||
use crate::auth::parse_endpoint_param;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::control_plane::provider::ApiLockError;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::proxy::neon_option;
|
||||
@@ -135,13 +135,13 @@ impl ConnCfg {
|
||||
/// Apply startup message params to the connection config.
|
||||
pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) {
|
||||
// Only set `user` if it's not present in the config.
|
||||
// Web auth flow takes username from the console's response.
|
||||
// Console redirect auth flow takes username from the console's response.
|
||||
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
|
||||
self.user(user);
|
||||
}
|
||||
|
||||
// Only set `dbname` if it's not present in the config.
|
||||
// Web auth flow takes dbname from the console's response.
|
||||
// Console redirect auth flow takes dbname from the console's response.
|
||||
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
|
||||
self.dbname(dbname);
|
||||
}
|
||||
@@ -316,6 +316,7 @@ impl ConnCfg {
|
||||
let client_config = client_config.with_no_client_auth();
|
||||
|
||||
let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
// TODO(vlad): que?
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
host,
|
||||
|
||||
@@ -78,7 +78,7 @@ pub struct AuthenticationConfig {
|
||||
pub jwks_cache: JwkCache,
|
||||
pub is_auth_broker: bool,
|
||||
pub accept_jwts: bool,
|
||||
pub webauth_confirmation_timeout: tokio::time::Duration,
|
||||
pub console_redirect_confirmation_timeout: tokio::time::Duration,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
@@ -271,7 +271,7 @@ impl CertResolver {
|
||||
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||
//
|
||||
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// validation, so let's we can continue with any common-name
|
||||
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||
s.to_string()
|
||||
@@ -366,7 +366,7 @@ pub struct EndpointCacheConfig {
|
||||
}
|
||||
|
||||
impl EndpointCacheConfig {
|
||||
/// Default options for [`crate::control_plane::provider::NodeInfoCache`].
|
||||
/// Default options for [`crate::control_plane::NodeInfoCache`].
|
||||
/// Notice that by default the limiter is empty, which means that cache is disabled.
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str =
|
||||
"initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
|
||||
@@ -441,7 +441,7 @@ pub struct CacheOptions {
|
||||
}
|
||||
|
||||
impl CacheOptions {
|
||||
/// Default options for [`crate::control_plane::provider::NodeInfoCache`].
|
||||
/// Default options for [`crate::control_plane::NodeInfoCache`].
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
|
||||
|
||||
/// Parse cache options passed via cmdline.
|
||||
@@ -497,7 +497,7 @@ pub struct ProjectInfoCacheOptions {
|
||||
}
|
||||
|
||||
impl ProjectInfoCacheOptions {
|
||||
/// Default options for [`crate::control_plane::provider::NodeInfoCache`].
|
||||
/// Default options for [`crate::control_plane::NodeInfoCache`].
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str =
|
||||
"size=10000,ttl=4m,max_roles=10,gc_interval=60m";
|
||||
|
||||
@@ -616,9 +616,9 @@ pub struct ConcurrencyLockOptions {
|
||||
}
|
||||
|
||||
impl ConcurrencyLockOptions {
|
||||
/// Default options for [`crate::control_plane::provider::ApiLocks`].
|
||||
/// Default options for [`crate::control_plane::client::ApiLocks`].
|
||||
pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
|
||||
/// Default options for [`crate::control_plane::provider::ApiLocks`].
|
||||
/// Default options for [`crate::control_plane::client::ApiLocks`].
|
||||
pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
|
||||
"shards=64,permits=100,epoch=10m,timeout=10ms";
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use futures::TryFutureExt;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, Instrument};
|
||||
use tracing::{debug, error, info, Instrument};
|
||||
|
||||
use crate::auth::backend::ConsoleRedirectBackend;
|
||||
use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
|
||||
@@ -11,7 +11,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
|
||||
use crate::protocol2::{read_proxy_protocol, ConnectHeader, ConnectionInfo};
|
||||
use crate::proxy::connect_compute::{connect_to_compute, TcpMechanism};
|
||||
use crate::proxy::handshake::{handshake, HandshakeData};
|
||||
use crate::proxy::passthrough::ProxyPassthrough;
|
||||
@@ -49,7 +49,7 @@ pub async fn task_main(
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
|
||||
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
|
||||
@@ -57,16 +57,21 @@ pub async fn task_main(
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, ConnectHeader::Missing)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
error!("missing required proxy protocol header");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
Ok((_socket, ConnectHeader::Proxy(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
error!("proxy protocol header not supported");
|
||||
return;
|
||||
}
|
||||
Ok((socket, Some(info))) => (socket, info),
|
||||
Ok((socket, None)) => (socket, ConnectionInfo{ addr: peer_addr, extra: None }),
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
Ok((socket, ConnectHeader::Missing)) => (socket, ConnectionInfo{ addr: peer_addr, extra: None }),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
|
||||
@@ -75,7 +75,7 @@ struct RequestMonitoringInner {
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) enum AuthMethod {
|
||||
// aka passwordless, fka link
|
||||
Web,
|
||||
ConsoleRedirect,
|
||||
ScramSha256,
|
||||
ScramSha256Plus,
|
||||
Cleartext,
|
||||
|
||||
@@ -134,7 +134,7 @@ impl From<&RequestMonitoringInner> for RequestData {
|
||||
.as_ref()
|
||||
.and_then(|options| serde_json::to_string(&Options { options }).ok()),
|
||||
auth_method: value.auth_method.as_ref().map(|x| match x {
|
||||
super::AuthMethod::Web => "web",
|
||||
super::AuthMethod::ConsoleRedirect => "console_redirect",
|
||||
super::AuthMethod::ScramSha256 => "scram_sha_256",
|
||||
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
|
||||
super::AuthMethod::Cleartext => "cleartext",
|
||||
|
||||
@@ -9,16 +9,17 @@ use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::Client;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
use super::errors::{ApiError, GetAuthInfoError, WakeComputeError};
|
||||
use super::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::IpPattern;
|
||||
use crate::cache::Cached;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::errors::GetEndpointJwksError;
|
||||
use crate::control_plane::client::{CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
};
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::control_plane::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::io_error;
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
@@ -31,25 +32,25 @@ enum MockApiError {
|
||||
PasswordNotSet(tokio_postgres::Error),
|
||||
}
|
||||
|
||||
impl From<MockApiError> for ApiError {
|
||||
impl From<MockApiError> for ControlPlaneError {
|
||||
fn from(e: MockApiError) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Error> for ApiError {
|
||||
impl From<tokio_postgres::Error> for ControlPlaneError {
|
||||
fn from(e: tokio_postgres::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
pub struct MockControlPlane {
|
||||
endpoint: ApiUrl,
|
||||
ip_allowlist_check_enabled: bool,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
impl MockControlPlane {
|
||||
pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
@@ -201,7 +202,7 @@ async fn get_execute_postgres_query(
|
||||
Ok(Some(entry))
|
||||
}
|
||||
|
||||
impl super::Api for Api {
|
||||
impl super::ControlPlaneApi for MockControlPlane {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
281
proxy/src/control_plane/client/mod.rs
Normal file
281
proxy/src/control_plane/client/mod.rs
Normal file
@@ -0,0 +1,281 @@
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use std::hash::Hash;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::cache::endpoints::EndpointsCache;
|
||||
use crate::cache::project_info::ProjectInfoCacheImpl;
|
||||
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::{
|
||||
errors, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache,
|
||||
};
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::ApiLockMetrics;
|
||||
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
|
||||
use crate::types::EndpointId;
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone)]
|
||||
pub enum ControlPlaneClient {
|
||||
/// Current Management API (V2).
|
||||
Neon(neon::NeonControlPlaneClient),
|
||||
/// Local mock control plane.
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
PostgresMock(mock::MockControlPlane),
|
||||
/// Internal testing
|
||||
#[cfg(test)]
|
||||
#[allow(private_interfaces)]
|
||||
Test(Box<dyn TestControlPlaneClient>),
|
||||
}
|
||||
|
||||
impl ControlPlaneApi for ControlPlaneClient {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::Neon(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_) => {
|
||||
unreachable!("this function should never be called in the test backend")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::Neon(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_allowed_ips_and_secret(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
|
||||
match self {
|
||||
Self::Neon(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_api) => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
||||
match self {
|
||||
Self::Neon(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.wake_compute(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
|
||||
fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
|
||||
fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Clone for Box<dyn TestControlPlaneClient> {
|
||||
fn clone(&self) -> Self {
|
||||
TestControlPlaneClient::dyn_clone(&**self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`control_plane`](super).
|
||||
pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub(crate) node_info: NodeInfoCache,
|
||||
/// Cache which stores project_id -> endpoint_ids mapping.
|
||||
pub project_info: Arc<ProjectInfoCacheImpl>,
|
||||
/// List of all valid endpoints.
|
||||
pub endpoints_cache: Arc<EndpointsCache>,
|
||||
}
|
||||
|
||||
impl ApiCaches {
|
||||
pub fn new(
|
||||
wake_compute_cache_config: CacheOptions,
|
||||
project_info_cache_config: ProjectInfoCacheOptions,
|
||||
endpoint_cache_config: EndpointCacheConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
node_info: NodeInfoCache::new(
|
||||
"node_info_cache",
|
||||
wake_compute_cache_config.size,
|
||||
wake_compute_cache_config.ttl,
|
||||
true,
|
||||
),
|
||||
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
|
||||
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`control_plane`](super).
|
||||
pub struct ApiLocks<K> {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<K, Arc<DynamicLimiter>>,
|
||||
config: RateLimiterConfig,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ApiLockError {
|
||||
#[error("timeout acquiring resource permit")]
|
||||
TimeoutError(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
impl ReportableError for ApiLockError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
config: RateLimiterConfig,
|
||||
shards: usize,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
) -> prometheus::Result<Self> {
|
||||
Ok(Self {
|
||||
name,
|
||||
node_locks: DashMap::with_shard_amount(shards),
|
||||
config,
|
||||
timeout,
|
||||
epoch,
|
||||
metrics,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
|
||||
if self.config.initial_limit == 0 {
|
||||
return Ok(WakeComputePermit {
|
||||
permit: Token::disabled(),
|
||||
});
|
||||
}
|
||||
let now = Instant::now();
|
||||
let semaphore = {
|
||||
// get fast path
|
||||
if let Some(semaphore) = self.node_locks.get(key) {
|
||||
semaphore.clone()
|
||||
} else {
|
||||
self.node_locks
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| {
|
||||
self.metrics.semaphores_registered.inc();
|
||||
DynamicLimiter::new(self.config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
};
|
||||
let permit = semaphore.acquire_timeout(self.timeout).await;
|
||||
|
||||
self.metrics
|
||||
.semaphore_acquire_seconds
|
||||
.observe(now.elapsed().as_secs_f64());
|
||||
info!("acquired permit {:?}", now.elapsed().as_secs_f64());
|
||||
Ok(WakeComputePermit { permit: permit? })
|
||||
}
|
||||
|
||||
pub async fn garbage_collect_worker(&self) {
|
||||
if self.config.initial_limit == 0 {
|
||||
return;
|
||||
}
|
||||
let mut interval =
|
||||
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
|
||||
loop {
|
||||
for (i, shard) in self.node_locks.shards().iter().enumerate() {
|
||||
interval.tick().await;
|
||||
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
|
||||
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
|
||||
// therefore releasing it is safe from race conditions
|
||||
info!(
|
||||
name = self.name,
|
||||
shard = i,
|
||||
"performing epoch reclamation on api lock"
|
||||
);
|
||||
let mut lock = shard.write();
|
||||
let timer = self.metrics.reclamation_lag_seconds.start_timer();
|
||||
let count = lock
|
||||
.extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
|
||||
.count();
|
||||
drop(lock);
|
||||
self.metrics.semaphores_unregistered.inc_by(count as u64);
|
||||
timer.observe();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct WakeComputePermit {
|
||||
permit: Token,
|
||||
}
|
||||
|
||||
impl WakeComputePermit {
|
||||
pub(crate) fn should_check_cache(&self) -> bool {
|
||||
!self.permit.is_disabled()
|
||||
}
|
||||
pub(crate) fn release(self, outcome: Outcome) {
|
||||
self.permit.release(outcome);
|
||||
}
|
||||
pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
|
||||
match res {
|
||||
Ok(_) => self.release(Outcome::Success),
|
||||
Err(_) => self.release(Outcome::Overload),
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl FetchAuthRules for ControlPlaneClient {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
|
||||
self.get_endpoint_jwks(ctx, endpoint)
|
||||
.await
|
||||
.map_err(FetchAuthRulesError::GetEndpointJwks)
|
||||
}
|
||||
}
|
||||
@@ -10,18 +10,20 @@ use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute};
|
||||
use super::errors::{ApiError, GetAuthInfoError, WakeComputeError};
|
||||
use super::{
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
|
||||
NodeInfo,
|
||||
};
|
||||
use super::super::messages::{ControlPlaneErrorMessage, GetRoleSecret, WakeCompute};
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::cache::Cached;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::errors::GetEndpointJwksError;
|
||||
use crate::control_plane::caches::ApiCaches;
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
|
||||
use crate::control_plane::{
|
||||
AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo,
|
||||
};
|
||||
use crate::metrics::{CacheOutcome, Metrics};
|
||||
use crate::rate_limiter::WakeComputeRateLimiter;
|
||||
use crate::types::{EndpointCacheKey, EndpointId};
|
||||
@@ -30,7 +32,7 @@ use crate::{compute, http, scram};
|
||||
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
pub struct NeonControlPlaneClient {
|
||||
endpoint: http::Endpoint,
|
||||
pub caches: &'static ApiCaches,
|
||||
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
@@ -39,17 +41,15 @@ pub struct Api {
|
||||
jwt: Arc<str>,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
impl NeonControlPlaneClient {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub fn new(
|
||||
endpoint: http::Endpoint,
|
||||
jwt: Arc<str>,
|
||||
caches: &'static ApiCaches,
|
||||
locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
) -> Self {
|
||||
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
|
||||
.unwrap_or_default()
|
||||
.into();
|
||||
Self {
|
||||
endpoint,
|
||||
caches,
|
||||
@@ -256,7 +256,7 @@ impl Api {
|
||||
}
|
||||
}
|
||||
|
||||
impl super::Api for Api {
|
||||
impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
@@ -356,7 +356,7 @@ impl super::Api for Api {
|
||||
let (cached, info) = cached.take_value();
|
||||
let info = info.map_err(|c| {
|
||||
info!(key = &*key, "found cached wake_compute error");
|
||||
WakeComputeError::ApiError(ApiError::ControlPlane(Box::new(*c)))
|
||||
WakeComputeError::ControlPlane(ControlPlaneError::Message(Box::new(*c)))
|
||||
})?;
|
||||
|
||||
debug!(key = &*key, "found cached compute node info");
|
||||
@@ -403,9 +403,11 @@ impl super::Api for Api {
|
||||
Ok(cached.map(|()| node))
|
||||
}
|
||||
Err(err) => match err {
|
||||
WakeComputeError::ApiError(ApiError::ControlPlane(err)) => {
|
||||
WakeComputeError::ControlPlane(ControlPlaneError::Message(err)) => {
|
||||
let Some(status) = &err.status else {
|
||||
return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)));
|
||||
return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
|
||||
err,
|
||||
)));
|
||||
};
|
||||
|
||||
let reason = status
|
||||
@@ -415,7 +417,9 @@ impl super::Api for Api {
|
||||
|
||||
// if we can retry this error, do not cache it.
|
||||
if reason.can_retry() {
|
||||
return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)));
|
||||
return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
|
||||
err,
|
||||
)));
|
||||
}
|
||||
|
||||
// at this point, we should only have quota errors.
|
||||
@@ -430,7 +434,9 @@ impl super::Api for Api {
|
||||
Duration::from_secs(30),
|
||||
);
|
||||
|
||||
Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)))
|
||||
Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
|
||||
err,
|
||||
)))
|
||||
}
|
||||
err => return Err(err),
|
||||
},
|
||||
@@ -441,7 +447,7 @@ impl super::Api for Api {
|
||||
/// Parse http response body, taking status code into account.
|
||||
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
response: http::Response,
|
||||
) -> Result<T, ApiError> {
|
||||
) -> Result<T, ControlPlaneError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
// We shouldn't log raw body because it may contain secrets.
|
||||
@@ -456,7 +462,7 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
// as the fact that the request itself has failed.
|
||||
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
|
||||
warn!("failed to parse error body: {e}");
|
||||
ControlPlaneError {
|
||||
ControlPlaneErrorMessage {
|
||||
error: "reason unclear (malformed error message)".into(),
|
||||
http_status_code: status,
|
||||
status: None,
|
||||
@@ -465,7 +471,7 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
body.http_status_code = status;
|
||||
|
||||
warn!("console responded with an error ({status}): {body:?}");
|
||||
Err(ApiError::ControlPlane(Box::new(body)))
|
||||
Err(ControlPlaneError::Message(Box::new(body)))
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
216
proxy/src/control_plane/errors.rs
Normal file
216
proxy/src/control_plane/errors.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::messages::{self, ControlPlaneErrorMessage, Reason};
|
||||
use crate::error::{io_error, ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
|
||||
/// A go-to error message which doesn't leak any detail.
|
||||
pub(crate) const REQUEST_FAILED: &str = "Console request failed";
|
||||
|
||||
/// Common console API error.
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum ControlPlaneError {
|
||||
/// Error returned by the console itself.
|
||||
#[error("{REQUEST_FAILED} with {0}")]
|
||||
Message(Box<ControlPlaneErrorMessage>),
|
||||
|
||||
/// Various IO errors like broken pipe or malformed payload.
|
||||
#[error("{REQUEST_FAILED}: {0}")]
|
||||
Transport(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl ControlPlaneError {
|
||||
/// Returns HTTP status code if it's the reason for failure.
|
||||
pub(crate) fn get_reason(&self) -> messages::Reason {
|
||||
match self {
|
||||
ControlPlaneError::Message(e) => e.get_reason(),
|
||||
ControlPlaneError::Transport(_) => messages::Reason::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ControlPlaneError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// To minimize risks, only select errors are forwarded to users.
|
||||
ControlPlaneError::Message(c) => c.get_user_facing_message(),
|
||||
ControlPlaneError::Transport(_) => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ControlPlaneError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ControlPlaneError::Message(e) => match e.get_reason() {
|
||||
Reason::RoleProtected => ErrorKind::User,
|
||||
Reason::ResourceNotFound => ErrorKind::User,
|
||||
Reason::ProjectNotFound => ErrorKind::User,
|
||||
Reason::EndpointNotFound => ErrorKind::User,
|
||||
Reason::BranchNotFound => ErrorKind::User,
|
||||
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota,
|
||||
Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::WrittenDataQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::DataTransferQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
|
||||
Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
|
||||
Reason::RunningOperations => ErrorKind::ControlPlane,
|
||||
Reason::ActiveEndpointsLimitExceeded => ErrorKind::ControlPlane,
|
||||
Reason::Unknown => ErrorKind::ControlPlane,
|
||||
},
|
||||
ControlPlaneError::Transport(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for ControlPlaneError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
// retry some transport errors
|
||||
Self::Transport(io) => io.could_retry(),
|
||||
Self::Message(e) => e.could_retry(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for ControlPlaneError {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest_middleware::Error> for ControlPlaneError {
|
||||
fn from(e: reqwest_middleware::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Console responded with a malformed auth secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ControlPlaneError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ControlPlaneError>> From<E> for GetAuthInfoError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// We absolutely should not leak any secrets!
|
||||
Self::BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
Self::ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for GetAuthInfoError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(Box<str>),
|
||||
|
||||
#[error(transparent)]
|
||||
ControlPlane(ControlPlaneError),
|
||||
|
||||
#[error("Too many connections attempts")]
|
||||
TooManyConnections,
|
||||
|
||||
#[error("error acquiring resource permit: {0}")]
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ControlPlaneError>> From<E> for WakeComputeError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ControlPlane(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
Self::BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, control plane might return a meaningful error.
|
||||
Self::ControlPlane(e) => e.to_string_client(),
|
||||
|
||||
Self::TooManyConnections => self.to_string(),
|
||||
|
||||
Self::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for WakeComputeError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Self::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ControlPlane(e) => e.get_error_kind(),
|
||||
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
Self::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for WakeComputeError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
Self::BadComputeAddress(_) => false,
|
||||
Self::ControlPlane(e) => e.could_retry(),
|
||||
Self::TooManyConnections => false,
|
||||
Self::TooManyConnectionAttempts(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetEndpointJwksError {
|
||||
#[error("endpoint not found")]
|
||||
EndpointNotFound,
|
||||
|
||||
#[error("failed to build control plane request: {0}")]
|
||||
RequestBuild(#[source] reqwest::Error),
|
||||
|
||||
#[error("failed to send control plane request: {0}")]
|
||||
RequestExecute(#[source] reqwest_middleware::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
ControlPlane(#[from] ControlPlaneError),
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
#[error(transparent)]
|
||||
TokioPostgres(#[from] tokio_postgres::Error),
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
#[error(transparent)]
|
||||
ParseUrl(#[from] url::ParseError),
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
#[error(transparent)]
|
||||
TaskJoin(#[from] tokio::task::JoinError),
|
||||
}
|
||||
@@ -10,14 +10,14 @@ use crate::proxy::retry::CouldRetry;
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub(crate) struct ControlPlaneError {
|
||||
pub(crate) struct ControlPlaneErrorMessage {
|
||||
pub(crate) error: Box<str>,
|
||||
#[serde(skip)]
|
||||
pub(crate) http_status_code: http::StatusCode,
|
||||
pub(crate) status: Option<Status>,
|
||||
}
|
||||
|
||||
impl ControlPlaneError {
|
||||
impl ControlPlaneErrorMessage {
|
||||
pub(crate) fn get_reason(&self) -> Reason {
|
||||
self.status
|
||||
.as_ref()
|
||||
@@ -26,7 +26,7 @@ impl ControlPlaneError {
|
||||
}
|
||||
|
||||
pub(crate) fn get_user_facing_message(&self) -> String {
|
||||
use super::provider::errors::REQUEST_FAILED;
|
||||
use super::errors::REQUEST_FAILED;
|
||||
self.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.user_facing_message.as_ref())
|
||||
@@ -51,7 +51,7 @@ impl ControlPlaneError {
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ControlPlaneError {
|
||||
impl Display for ControlPlaneErrorMessage {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let msg: &str = self
|
||||
.status
|
||||
@@ -62,7 +62,7 @@ impl Display for ControlPlaneError {
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for ControlPlaneError {
|
||||
impl CouldRetry for ControlPlaneErrorMessage {
|
||||
fn could_retry(&self) -> bool {
|
||||
// If the error message does not have a status,
|
||||
// the error is unknown and probably should not retry automatically
|
||||
@@ -245,7 +245,7 @@ pub(crate) struct WakeCompute {
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
/// Async response which concludes the web auth flow.
|
||||
/// Async response which concludes the console redirect auth flow.
|
||||
/// Also known as `kickResponse` in the console.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct KickSession<'a> {
|
||||
|
||||
@@ -24,8 +24,8 @@ pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), wai
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Console management API listener task.
|
||||
/// It spawns console response handlers needed for the web auth.
|
||||
/// Management API listener task.
|
||||
/// It spawns management response handlers needed for the console redirect auth flow.
|
||||
pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
|
||||
scopeguard::defer! {
|
||||
info!("mgmt has shut down");
|
||||
@@ -43,13 +43,13 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
|
||||
|
||||
tokio::task::spawn(
|
||||
async move {
|
||||
info!("serving a new console management API connection");
|
||||
info!("serving a new management API connection");
|
||||
|
||||
// these might be long running connections, have a separate logging for cancelling
|
||||
// on shutdown and other ways of stopping.
|
||||
let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
|
||||
let _e = span.entered();
|
||||
info!("console management API task cancelled");
|
||||
info!("management API task cancelled");
|
||||
});
|
||||
|
||||
if let Err(e) = handle_connection(socket).await {
|
||||
|
||||
@@ -5,18 +5,137 @@
|
||||
pub mod messages;
|
||||
|
||||
/// Wrappers for console APIs and their mocks.
|
||||
pub mod provider;
|
||||
pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
pub mod client;
|
||||
|
||||
pub(crate) mod errors;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::IpPattern;
|
||||
use crate::cache::project_info::ProjectInfoCacheImpl;
|
||||
use crate::cache::{Cached, TimedLru};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
|
||||
use crate::intern::ProjectIdInt;
|
||||
use crate::types::{EndpointCacheKey, EndpointId};
|
||||
use crate::{compute, scram};
|
||||
|
||||
/// Various cache-related types.
|
||||
pub mod caches {
|
||||
pub use super::provider::ApiCaches;
|
||||
pub use super::client::ApiCaches;
|
||||
}
|
||||
|
||||
/// Various cache-related types.
|
||||
pub mod locks {
|
||||
pub use super::provider::ApiLocks;
|
||||
pub use super::client::ApiLocks;
|
||||
}
|
||||
|
||||
/// Console's management API.
|
||||
pub mod mgmt;
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub(crate) enum AuthSecret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct AuthInfo {
|
||||
pub(crate) secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub(crate) allowed_ips: Vec<IpPattern>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub(crate) config: compute::ConnCfg,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
|
||||
/// Whether we should accept self-signed certificates (for testing)
|
||||
pub(crate) allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
timeout: Duration,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(
|
||||
ctx,
|
||||
self.allow_self_signed_compute,
|
||||
self.aux.clone(),
|
||||
timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
pub(crate) fn reuse_settings(&mut self, other: Self) {
|
||||
self.allow_self_signed_compute = other.allow_self_signed_compute;
|
||||
self.config.reuse_password(other.config);
|
||||
}
|
||||
|
||||
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||
match keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type NodeInfoCache =
|
||||
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
|
||||
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
pub(crate) trait ControlPlaneApi {
|
||||
/// Get the client's auth secret for authentication.
|
||||
/// Returns option because user not found situation is special.
|
||||
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
@@ -1,588 +0,0 @@
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use std::hash::Hash;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
use super::messages::{ControlPlaneError, MetricsAuxInfo};
|
||||
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::IpPattern;
|
||||
use crate::cache::endpoints::EndpointsCache;
|
||||
use crate::cache::project_info::ProjectInfoCacheImpl;
|
||||
use crate::cache::{Cached, TimedLru};
|
||||
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::intern::ProjectIdInt;
|
||||
use crate::metrics::ApiLockMetrics;
|
||||
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
|
||||
use crate::types::{EndpointCacheKey, EndpointId};
|
||||
use crate::{compute, scram};
|
||||
|
||||
pub(crate) mod errors {
|
||||
use thiserror::Error;
|
||||
|
||||
use super::ApiLockError;
|
||||
use crate::control_plane::messages::{self, ControlPlaneError, Reason};
|
||||
use crate::error::{io_error, ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
|
||||
/// A go-to error message which doesn't leak any detail.
|
||||
pub(crate) const REQUEST_FAILED: &str = "Console request failed";
|
||||
|
||||
/// Common console API error.
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum ApiError {
|
||||
/// Error returned by the console itself.
|
||||
#[error("{REQUEST_FAILED} with {0}")]
|
||||
ControlPlane(Box<ControlPlaneError>),
|
||||
|
||||
/// Various IO errors like broken pipe or malformed payload.
|
||||
#[error("{REQUEST_FAILED}: {0}")]
|
||||
Transport(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Returns HTTP status code if it's the reason for failure.
|
||||
pub(crate) fn get_reason(&self) -> messages::Reason {
|
||||
match self {
|
||||
ApiError::ControlPlane(e) => e.get_reason(),
|
||||
ApiError::Transport(_) => messages::Reason::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ApiError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// To minimize risks, only select errors are forwarded to users.
|
||||
ApiError::ControlPlane(c) => c.get_user_facing_message(),
|
||||
ApiError::Transport(_) => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ApiError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiError::ControlPlane(e) => match e.get_reason() {
|
||||
Reason::RoleProtected => ErrorKind::User,
|
||||
Reason::ResourceNotFound => ErrorKind::User,
|
||||
Reason::ProjectNotFound => ErrorKind::User,
|
||||
Reason::EndpointNotFound => ErrorKind::User,
|
||||
Reason::BranchNotFound => ErrorKind::User,
|
||||
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota,
|
||||
Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::WrittenDataQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::DataTransferQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
|
||||
Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
|
||||
Reason::RunningOperations => ErrorKind::ControlPlane,
|
||||
Reason::ActiveEndpointsLimitExceeded => ErrorKind::ControlPlane,
|
||||
Reason::Unknown => ErrorKind::ControlPlane,
|
||||
},
|
||||
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for ApiError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
// retry some transport errors
|
||||
Self::Transport(io) => io.could_retry(),
|
||||
Self::ControlPlane(e) => e.could_retry(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for ApiError {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest_middleware::Error> for ApiError {
|
||||
fn from(e: reqwest_middleware::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Console responded with a malformed auth secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// We absolutely should not leak any secrets!
|
||||
Self::BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
Self::ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for GetAuthInfoError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(Box<str>),
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
|
||||
#[error("Too many connections attempts")]
|
||||
TooManyConnections,
|
||||
|
||||
#[error("error acquiring resource permit: {0}")]
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for WakeComputeError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
Self::BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
Self::ApiError(e) => e.to_string_client(),
|
||||
|
||||
Self::TooManyConnections => self.to_string(),
|
||||
|
||||
Self::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for WakeComputeError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Self::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ApiError(e) => e.get_error_kind(),
|
||||
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
Self::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for WakeComputeError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
Self::BadComputeAddress(_) => false,
|
||||
Self::ApiError(e) => e.could_retry(),
|
||||
Self::TooManyConnections => false,
|
||||
Self::TooManyConnectionAttempts(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetEndpointJwksError {
|
||||
#[error("endpoint not found")]
|
||||
EndpointNotFound,
|
||||
|
||||
#[error("failed to build control plane request: {0}")]
|
||||
RequestBuild(#[source] reqwest::Error),
|
||||
|
||||
#[error("failed to send control plane request: {0}")]
|
||||
RequestExecute(#[source] reqwest_middleware::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
ControlPlane(#[from] ApiError),
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
#[error(transparent)]
|
||||
TokioPostgres(#[from] tokio_postgres::Error),
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
#[error(transparent)]
|
||||
ParseUrl(#[from] url::ParseError),
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
#[error(transparent)]
|
||||
TaskJoin(#[from] tokio::task::JoinError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub(crate) enum AuthSecret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct AuthInfo {
|
||||
pub(crate) secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub(crate) allowed_ips: Vec<IpPattern>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub(crate) config: compute::ConnCfg,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
|
||||
/// Whether we should accept self-signed certificates (for testing)
|
||||
pub(crate) allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
timeout: Duration,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(
|
||||
ctx,
|
||||
self.allow_self_signed_compute,
|
||||
self.aux.clone(),
|
||||
timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
pub(crate) fn reuse_settings(&mut self, other: Self) {
|
||||
self.allow_self_signed_compute = other.allow_self_signed_compute;
|
||||
self.config.reuse_password(other.config);
|
||||
}
|
||||
|
||||
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||
match keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type NodeInfoCache =
|
||||
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneError>>>;
|
||||
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
pub(crate) trait Api {
|
||||
/// Get the client's auth secret for authentication.
|
||||
/// Returns option because user not found situation is special.
|
||||
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone)]
|
||||
pub enum ControlPlaneBackend {
|
||||
/// Current Management API (V2).
|
||||
Management(neon::Api),
|
||||
/// Local mock control plane.
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
PostgresMock(mock::Api),
|
||||
/// Internal testing
|
||||
#[cfg(test)]
|
||||
#[allow(private_interfaces)]
|
||||
Test(Box<dyn crate::auth::backend::TestBackend>),
|
||||
}
|
||||
|
||||
impl Api for ControlPlaneBackend {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::Management(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_) => {
|
||||
unreachable!("this function should never be called in the test backend")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::Management(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_allowed_ips_and_secret(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
|
||||
match self {
|
||||
Self::Management(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_api) => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
||||
match self {
|
||||
Self::Management(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.wake_compute(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`control_plane`](super).
|
||||
pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub(crate) node_info: NodeInfoCache,
|
||||
/// Cache which stores project_id -> endpoint_ids mapping.
|
||||
pub project_info: Arc<ProjectInfoCacheImpl>,
|
||||
/// List of all valid endpoints.
|
||||
pub endpoints_cache: Arc<EndpointsCache>,
|
||||
}
|
||||
|
||||
impl ApiCaches {
|
||||
pub fn new(
|
||||
wake_compute_cache_config: CacheOptions,
|
||||
project_info_cache_config: ProjectInfoCacheOptions,
|
||||
endpoint_cache_config: EndpointCacheConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
node_info: NodeInfoCache::new(
|
||||
"node_info_cache",
|
||||
wake_compute_cache_config.size,
|
||||
wake_compute_cache_config.ttl,
|
||||
true,
|
||||
),
|
||||
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
|
||||
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`control_plane`](super).
|
||||
pub struct ApiLocks<K> {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<K, Arc<DynamicLimiter>>,
|
||||
config: RateLimiterConfig,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ApiLockError {
|
||||
#[error("timeout acquiring resource permit")]
|
||||
TimeoutError(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
impl ReportableError for ApiLockError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
config: RateLimiterConfig,
|
||||
shards: usize,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
) -> prometheus::Result<Self> {
|
||||
Ok(Self {
|
||||
name,
|
||||
node_locks: DashMap::with_shard_amount(shards),
|
||||
config,
|
||||
timeout,
|
||||
epoch,
|
||||
metrics,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
|
||||
if self.config.initial_limit == 0 {
|
||||
return Ok(WakeComputePermit {
|
||||
permit: Token::disabled(),
|
||||
});
|
||||
}
|
||||
let now = Instant::now();
|
||||
let semaphore = {
|
||||
// get fast path
|
||||
if let Some(semaphore) = self.node_locks.get(key) {
|
||||
semaphore.clone()
|
||||
} else {
|
||||
self.node_locks
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| {
|
||||
self.metrics.semaphores_registered.inc();
|
||||
DynamicLimiter::new(self.config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
};
|
||||
let permit = semaphore.acquire_timeout(self.timeout).await;
|
||||
|
||||
self.metrics
|
||||
.semaphore_acquire_seconds
|
||||
.observe(now.elapsed().as_secs_f64());
|
||||
info!("acquired permit {:?}", now.elapsed().as_secs_f64());
|
||||
Ok(WakeComputePermit { permit: permit? })
|
||||
}
|
||||
|
||||
pub async fn garbage_collect_worker(&self) {
|
||||
if self.config.initial_limit == 0 {
|
||||
return;
|
||||
}
|
||||
let mut interval =
|
||||
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
|
||||
loop {
|
||||
for (i, shard) in self.node_locks.shards().iter().enumerate() {
|
||||
interval.tick().await;
|
||||
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
|
||||
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
|
||||
// therefore releasing it is safe from race conditions
|
||||
info!(
|
||||
name = self.name,
|
||||
shard = i,
|
||||
"performing epoch reclamation on api lock"
|
||||
);
|
||||
let mut lock = shard.write();
|
||||
let timer = self.metrics.reclamation_lag_seconds.start_timer();
|
||||
let count = lock
|
||||
.extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
|
||||
.count();
|
||||
drop(lock);
|
||||
self.metrics.semaphores_unregistered.inc_by(count as u64);
|
||||
timer.observe();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct WakeComputePermit {
|
||||
permit: Token,
|
||||
}
|
||||
|
||||
impl WakeComputePermit {
|
||||
pub(crate) fn should_check_cache(&self) -> bool {
|
||||
!self.permit.is_disabled()
|
||||
}
|
||||
pub(crate) fn release(self, outcome: Outcome) {
|
||||
self.permit.release(outcome);
|
||||
}
|
||||
pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
|
||||
match res {
|
||||
Ok(_) => self.release(Outcome::Success),
|
||||
Err(_) => self.release(Outcome::Overload),
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl FetchAuthRules for ControlPlaneBackend {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
|
||||
self.get_endpoint_jwks(ctx, endpoint)
|
||||
.await
|
||||
.map_err(FetchAuthRulesError::GetEndpointJwks)
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ use bytes::{Buf, Bytes, BytesMut};
|
||||
use pin_project_lite::pin_project;
|
||||
use strum_macros::FromRepr;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use zerocopy::{FromBytes, FromZeroes};
|
||||
|
||||
pin_project! {
|
||||
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
|
||||
@@ -57,16 +58,31 @@ impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
|
||||
}
|
||||
|
||||
/// Proxy Protocol Version 2 Header
|
||||
const HEADER: [u8; 12] = [
|
||||
const SIGNATURE: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
const LOCAL_V2: u8 = 0x20;
|
||||
const PROXY_V2: u8 = 0x21;
|
||||
|
||||
const TCP_OVER_IPV4: u8 = 0x11;
|
||||
const UDP_OVER_IPV4: u8 = 0x12;
|
||||
const TCP_OVER_IPV6: u8 = 0x21;
|
||||
const UDP_OVER_IPV6: u8 = 0x22;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub struct ConnectionInfo {
|
||||
pub addr: SocketAddr,
|
||||
pub extra: Option<ConnectionInfoExtra>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub enum ConnectHeader {
|
||||
Missing,
|
||||
Local,
|
||||
Proxy(ConnectionInfo),
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnectionInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.extra {
|
||||
@@ -89,96 +105,31 @@ pub enum ConnectionInfoExtra {
|
||||
|
||||
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
mut read: T,
|
||||
) -> std::io::Result<(ChainRW<T>, Option<ConnectionInfo>)> {
|
||||
) -> std::io::Result<(ChainRW<T>, ConnectHeader)> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
while buf.len() < 16 {
|
||||
let header = loop {
|
||||
let bytes_read = read.read_buf(&mut buf).await?;
|
||||
|
||||
// exit for bad header
|
||||
let len = usize::min(buf.len(), HEADER.len());
|
||||
if buf[..len] != HEADER[..len] {
|
||||
return Ok((ChainRW { inner: read, buf }, None));
|
||||
// exit for bad header signature
|
||||
let len = usize::min(buf.len(), SIGNATURE.len());
|
||||
if buf[..len] != SIGNATURE[..len] {
|
||||
return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
|
||||
}
|
||||
|
||||
// if no more bytes available then exit
|
||||
if bytes_read == 0 {
|
||||
return Ok((ChainRW { inner: read, buf }, None));
|
||||
return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
|
||||
};
|
||||
}
|
||||
|
||||
let header = buf.split_to(16);
|
||||
|
||||
// The next byte (the 13th one) is the protocol version and command.
|
||||
// The highest four bits contains the version. As of this specification, it must
|
||||
// always be sent as \x2 and the receiver must only accept this value.
|
||||
let vc = header[12];
|
||||
let version = vc >> 4;
|
||||
let command = vc & 0b1111;
|
||||
if version != 2 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol version. expected version 2",
|
||||
));
|
||||
}
|
||||
match command {
|
||||
// the connection was established on purpose by the proxy
|
||||
// without being relayed. The connection endpoints are the sender and the
|
||||
// receiver. Such connections exist when the proxy sends health-checks to the
|
||||
// server. The receiver must accept this connection as valid and must use the
|
||||
// real connection endpoints and discard the protocol block including the
|
||||
// family which is ignored.
|
||||
0 => {}
|
||||
// the connection was established on behalf of another node,
|
||||
// and reflects the original connection endpoints. The receiver must then use
|
||||
// the information provided in the protocol block to get original the address.
|
||||
1 => {}
|
||||
// other values are unassigned and must not be emitted by senders. Receivers
|
||||
// must drop connections presenting unexpected values here.
|
||||
_ => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol command. expected local (0) or proxy (1)",
|
||||
))
|
||||
// check if we have enough bytes to continue
|
||||
if let Some(header) = buf.try_get::<ProxyProtocolV2Header>() {
|
||||
break header;
|
||||
}
|
||||
};
|
||||
|
||||
// The 14th byte contains the transport protocol and address family. The highest 4
|
||||
// bits contain the address family, the lowest 4 bits contain the protocol.
|
||||
let ft = header[13];
|
||||
let address_length = match ft {
|
||||
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
|
||||
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
|
||||
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
|
||||
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
|
||||
0x11 | 0x12 => 12,
|
||||
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
|
||||
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
|
||||
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
|
||||
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
|
||||
0x21 | 0x22 => 36,
|
||||
// unspecified or unix stream. ignore the addresses
|
||||
_ => 0,
|
||||
};
|
||||
let remaining_length = usize::from(header.len.get());
|
||||
|
||||
// The 15th and 16th bytes is the address length in bytes in network endian order.
|
||||
// It is used so that the receiver knows how many address bytes to skip even when
|
||||
// it does not implement the presented protocol. Thus the length of the protocol
|
||||
// header in bytes is always exactly 16 + this value. When a sender presents a
|
||||
// LOCAL connection, it should not present any address so it sets this field to
|
||||
// zero. Receivers MUST always consider this field to skip the appropriate number
|
||||
// of bytes and must not assume zero is presented for LOCAL connections. When a
|
||||
// receiver accepts an incoming connection showing an UNSPEC address family or
|
||||
// protocol, it may or may not decide to log the address information if present.
|
||||
let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
|
||||
if remaining_length < address_length {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol length. not enough to fit requested IP addresses",
|
||||
));
|
||||
}
|
||||
drop(header);
|
||||
|
||||
while buf.len() < remaining_length as usize {
|
||||
while buf.len() < remaining_length {
|
||||
if read.read_buf(&mut buf).await? == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
@@ -186,36 +137,69 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
));
|
||||
}
|
||||
}
|
||||
let payload = buf.split_to(remaining_length);
|
||||
|
||||
// Starting from the 17th byte, addresses are presented in network byte order.
|
||||
// The address order is always the same :
|
||||
// - source layer 3 address in network byte order
|
||||
// - destination layer 3 address in network byte order
|
||||
// - source layer 4 address if any, in network byte order (port)
|
||||
// - destination layer 4 address if any, in network byte order (port)
|
||||
let mut header = buf.split_to(usize::from(remaining_length));
|
||||
let mut addr = header.split_to(usize::from(address_length));
|
||||
let socket = match addr.len() {
|
||||
12 => {
|
||||
let src_addr = Ipv4Addr::from_bits(addr.get_u32());
|
||||
let _dst_addr = Ipv4Addr::from_bits(addr.get_u32());
|
||||
let src_port = addr.get_u16();
|
||||
let _dst_port = addr.get_u16();
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
let res = process_proxy_payload(header, payload)?;
|
||||
Ok((ChainRW { inner: read, buf }, res))
|
||||
}
|
||||
|
||||
fn process_proxy_payload(
|
||||
header: ProxyProtocolV2Header,
|
||||
mut payload: BytesMut,
|
||||
) -> std::io::Result<ConnectHeader> {
|
||||
match header.version_and_command {
|
||||
// the connection was established on purpose by the proxy
|
||||
// without being relayed. The connection endpoints are the sender and the
|
||||
// receiver. Such connections exist when the proxy sends health-checks to the
|
||||
// server. The receiver must accept this connection as valid and must use the
|
||||
// real connection endpoints and discard the protocol block including the
|
||||
// family which is ignored.
|
||||
LOCAL_V2 => return Ok(ConnectHeader::Local),
|
||||
// the connection was established on behalf of another node,
|
||||
// and reflects the original connection endpoints. The receiver must then use
|
||||
// the information provided in the protocol block to get original the address.
|
||||
PROXY_V2 => {}
|
||||
// other values are unassigned and must not be emitted by senders. Receivers
|
||||
// must drop connections presenting unexpected values here.
|
||||
#[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384
|
||||
_ => return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!(
|
||||
"invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
|
||||
header.version_and_command
|
||||
),
|
||||
)),
|
||||
};
|
||||
|
||||
let size_err =
|
||||
"invalid proxy protocol length. payload not large enough to fit requested IP addresses";
|
||||
let addr = match header.protocol_and_family {
|
||||
TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
|
||||
let addr = payload
|
||||
.try_get::<ProxyProtocolV2HeaderV4>()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
|
||||
|
||||
SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
|
||||
}
|
||||
36 => {
|
||||
let src_addr = Ipv6Addr::from_bits(addr.get_u128());
|
||||
let _dst_addr = Ipv6Addr::from_bits(addr.get_u128());
|
||||
let src_port = addr.get_u16();
|
||||
let _dst_port = addr.get_u16();
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
|
||||
let addr = payload
|
||||
.try_get::<ProxyProtocolV2HeaderV6>()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
|
||||
|
||||
SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
|
||||
}
|
||||
// unspecified or unix stream. ignore the addresses
|
||||
_ => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol address family/transport protocol.",
|
||||
))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let mut extra = None;
|
||||
|
||||
while let Some(mut tlv) = read_tlv(&mut header) {
|
||||
while let Some(mut tlv) = read_tlv(&mut payload) {
|
||||
match Pp2Kind::from_repr(tlv.kind) {
|
||||
Some(Pp2Kind::Aws) => {
|
||||
if tlv.value.is_empty() {
|
||||
@@ -259,9 +243,7 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
}
|
||||
}
|
||||
|
||||
let conn_info = socket.map(|addr| ConnectionInfo { addr, extra });
|
||||
|
||||
Ok((ChainRW { inner: read, buf }, conn_info))
|
||||
Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
|
||||
}
|
||||
|
||||
#[derive(FromRepr, Debug, Copy, Clone)]
|
||||
@@ -337,27 +319,93 @@ struct Tlv {
|
||||
}
|
||||
|
||||
fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
|
||||
if b.len() < 3 {
|
||||
return None;
|
||||
}
|
||||
let kind = b.get_u8();
|
||||
let len = usize::from(b.get_u16());
|
||||
let tlv_header = b.try_get::<TlvHeader>()?;
|
||||
let len = usize::from(tlv_header.len.get());
|
||||
if b.len() < len {
|
||||
return None;
|
||||
}
|
||||
let value = b.split_to(len).freeze();
|
||||
Some(Tlv { kind, value })
|
||||
Some(Tlv {
|
||||
kind: tlv_header.kind,
|
||||
value: b.split_to(len).freeze(),
|
||||
})
|
||||
}
|
||||
|
||||
trait BufExt: Sized {
|
||||
fn try_get<T: FromBytes>(&mut self) -> Option<T>;
|
||||
}
|
||||
impl BufExt for BytesMut {
|
||||
fn try_get<T: FromBytes>(&mut self) -> Option<T> {
|
||||
let res = T::read_from_prefix(self)?;
|
||||
self.advance(size_of::<T>());
|
||||
Some(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromBytes, FromZeroes, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2Header {
|
||||
signature: [u8; 12],
|
||||
version_and_command: u8,
|
||||
protocol_and_family: u8,
|
||||
len: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, FromZeroes, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2HeaderV4 {
|
||||
src_addr: NetworkEndianIpv4,
|
||||
dst_addr: NetworkEndianIpv4,
|
||||
src_port: zerocopy::byteorder::network_endian::U16,
|
||||
dst_port: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, FromZeroes, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2HeaderV6 {
|
||||
src_addr: NetworkEndianIpv6,
|
||||
dst_addr: NetworkEndianIpv6,
|
||||
src_port: zerocopy::byteorder::network_endian::U16,
|
||||
dst_port: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, FromZeroes, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct TlvHeader {
|
||||
kind: u8,
|
||||
len: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, FromZeroes, Copy, Clone)]
|
||||
#[repr(transparent)]
|
||||
struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
|
||||
impl NetworkEndianIpv4 {
|
||||
#[inline]
|
||||
fn get(self) -> Ipv4Addr {
|
||||
Ipv4Addr::from_bits(self.0.get())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromBytes, FromZeroes, Copy, Clone)]
|
||||
#[repr(transparent)]
|
||||
struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
|
||||
impl NetworkEndianIpv6 {
|
||||
#[inline]
|
||||
fn get(self) -> Ipv6Addr {
|
||||
Ipv6Addr::from_bits(self.0.get())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use crate::protocol2::read_proxy_protocol;
|
||||
use crate::protocol2::{
|
||||
read_proxy_protocol, ConnectHeader, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ipv4() {
|
||||
let header = super::HEADER
|
||||
let header = super::SIGNATURE
|
||||
// Proxy command, IPV4 | TCP
|
||||
.chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
|
||||
// 12 + 3 bytes
|
||||
@@ -384,15 +432,17 @@ mod tests {
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
|
||||
let info = info.unwrap();
|
||||
let ConnectHeader::Proxy(info) = info else {
|
||||
panic!()
|
||||
};
|
||||
assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ipv6() {
|
||||
let header = super::HEADER
|
||||
let header = super::SIGNATURE
|
||||
// Proxy command, IPV6 | UDP
|
||||
.chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
|
||||
.chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
|
||||
// 36 + 3 bytes
|
||||
.chain([0, 39].as_slice())
|
||||
// src ip
|
||||
@@ -417,7 +467,9 @@ mod tests {
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
|
||||
let info = info.unwrap();
|
||||
let ConnectHeader::Proxy(info) = info else {
|
||||
panic!()
|
||||
};
|
||||
assert_eq!(
|
||||
info.addr,
|
||||
([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
|
||||
@@ -433,7 +485,7 @@ mod tests {
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(info, None);
|
||||
assert_eq!(info, ConnectHeader::Missing);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -445,7 +497,7 @@ mod tests {
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(info, None);
|
||||
assert_eq!(info, ConnectHeader::Missing);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -454,9 +506,9 @@ mod tests {
|
||||
let tlv_len = (tlv.len() as u16).to_be_bytes();
|
||||
let len = (12 + 3 + tlv.len() as u16).to_be_bytes();
|
||||
|
||||
let header = super::HEADER
|
||||
let header = super::SIGNATURE
|
||||
// Proxy command, Inet << 4 | Stream
|
||||
.chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
|
||||
.chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
|
||||
// 12 + 3 bytes
|
||||
.chain(len.as_slice())
|
||||
// src ip
|
||||
@@ -483,7 +535,30 @@ mod tests {
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
|
||||
let info = info.unwrap();
|
||||
let ConnectHeader::Proxy(info) = info else {
|
||||
panic!()
|
||||
};
|
||||
assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_local() {
|
||||
let len = 0u16.to_be_bytes();
|
||||
let header = super::SIGNATURE
|
||||
.chain([LOCAL_V2, 0x00].as_slice())
|
||||
.chain(len.as_slice());
|
||||
|
||||
let extra_data = [0xaa; 256];
|
||||
|
||||
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
|
||||
let ConnectHeader::Local = info else { panic!() };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ use smol_str::{format_smolstr, SmolStr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, warn, Instrument};
|
||||
use tracing::{debug, error, info, warn, Instrument};
|
||||
|
||||
use self::connect_compute::{connect_to_compute, TcpMechanism};
|
||||
use self::passthrough::ProxyPassthrough;
|
||||
@@ -28,7 +28,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
|
||||
use crate::protocol2::{read_proxy_protocol, ConnectHeader, ConnectionInfo};
|
||||
use crate::proxy::handshake::{handshake, HandshakeData};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
@@ -83,7 +83,7 @@ pub async fn task_main(
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
|
||||
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
@@ -92,16 +92,21 @@ pub async fn task_main(
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, ConnectHeader::Missing)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
warn!("missing required proxy protocol header");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
Ok((_socket, ConnectHeader::Proxy(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
warn!("proxy protocol header not supported");
|
||||
return;
|
||||
}
|
||||
Ok((socket, Some(info))) => (socket, info),
|
||||
Ok((socket, None)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
Ok((socket, ConnectHeader::Missing)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
|
||||
@@ -20,14 +20,14 @@ use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::CouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
|
||||
};
|
||||
use crate::config::{CertResolver, RetryConfig};
|
||||
use crate::control_plane::messages::{ControlPlaneError, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::provider::{
|
||||
CachedAllowedIps, CachedRoleSecret, ControlPlaneBackend, NodeInfoCache,
|
||||
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::{
|
||||
self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache,
|
||||
};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
use crate::{sasl, scram};
|
||||
@@ -490,7 +490,7 @@ impl ConnectMechanism for TestConnectMechanism {
|
||||
fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
impl TestBackend for TestConnectMechanism {
|
||||
impl TestControlPlaneClient for TestConnectMechanism {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
let mut counter = self.counter.lock().unwrap();
|
||||
let action = self.sequence[*counter];
|
||||
@@ -498,18 +498,19 @@ impl TestBackend for TestConnectMechanism {
|
||||
match action {
|
||||
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
|
||||
ConnectAction::WakeFail => {
|
||||
let err =
|
||||
control_plane::errors::ApiError::ControlPlane(Box::new(ControlPlaneError {
|
||||
let err = control_plane::errors::ControlPlaneError::Message(Box::new(
|
||||
ControlPlaneErrorMessage {
|
||||
http_status_code: StatusCode::BAD_REQUEST,
|
||||
error: "TEST".into(),
|
||||
status: None,
|
||||
}));
|
||||
},
|
||||
));
|
||||
assert!(!err.could_retry());
|
||||
Err(control_plane::errors::WakeComputeError::ApiError(err))
|
||||
Err(control_plane::errors::WakeComputeError::ControlPlane(err))
|
||||
}
|
||||
ConnectAction::WakeRetry => {
|
||||
let err =
|
||||
control_plane::errors::ApiError::ControlPlane(Box::new(ControlPlaneError {
|
||||
let err = control_plane::errors::ControlPlaneError::Message(Box::new(
|
||||
ControlPlaneErrorMessage {
|
||||
http_status_code: StatusCode::BAD_REQUEST,
|
||||
error: "TEST".into(),
|
||||
status: Some(Status {
|
||||
@@ -523,9 +524,10 @@ impl TestBackend for TestConnectMechanism {
|
||||
user_facing_message: None,
|
||||
},
|
||||
}),
|
||||
}));
|
||||
},
|
||||
));
|
||||
assert!(err.could_retry());
|
||||
Err(control_plane::errors::WakeComputeError::ApiError(err))
|
||||
Err(control_plane::errors::WakeComputeError::ControlPlane(err))
|
||||
}
|
||||
x => panic!("expecting action {x:?}, wake_compute is called instead"),
|
||||
}
|
||||
@@ -538,7 +540,7 @@ impl TestBackend for TestConnectMechanism {
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
|
||||
fn dyn_clone(&self) -> Box<dyn TestBackend> {
|
||||
fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
@@ -562,7 +564,7 @@ fn helper_create_connect_info(
|
||||
mechanism: &TestConnectMechanism,
|
||||
) -> auth::Backend<'static, ComputeCredentials> {
|
||||
let user_info = auth::Backend::ControlPlane(
|
||||
MaybeOwned::Owned(ControlPlaneBackend::Test(Box::new(mechanism.clone()))),
|
||||
MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
|
||||
ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::connect_compute::ComputeConnectBackend;
|
||||
use crate::config::RetryConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::provider::CachedNodeInfo;
|
||||
use crate::control_plane::CachedNodeInfo;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
|
||||
|
||||
@@ -14,7 +14,7 @@ use tracing::{debug, info};
|
||||
use super::conn_pool::poll_client;
|
||||
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
|
||||
use super::http_conn_pool::{self, poll_http2_client, Send};
|
||||
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
|
||||
use super::local_conn_pool::{self, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
|
||||
@@ -24,9 +24,9 @@ use crate::compute_ctl::{
|
||||
};
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::provider::ApiLockError;
|
||||
use crate::control_plane::CachedNodeInfo;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
@@ -205,7 +205,7 @@ impl PoolingBackend {
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
|
||||
info!("pool: looking for an existing connection");
|
||||
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||
if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ use {
|
||||
std::{sync::atomic, time::Duration},
|
||||
};
|
||||
|
||||
use super::conn_pool_lib::{Client, ClientInnerExt, ConnInfo, GlobalConnPool};
|
||||
use super::conn_pool_lib::{
|
||||
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, GlobalConnPool,
|
||||
};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::Metrics;
|
||||
@@ -152,53 +154,30 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
|
||||
}
|
||||
.instrument(span));
|
||||
let inner = ClientInnerRemote {
|
||||
let inner = ClientInnerCommon {
|
||||
inner: client,
|
||||
session: tx,
|
||||
cancel,
|
||||
aux,
|
||||
conn_id,
|
||||
data: ClientDataEnum::Remote(ClientDataRemote {
|
||||
session: tx,
|
||||
cancel,
|
||||
}),
|
||||
};
|
||||
|
||||
Client::new(inner, conn_info, pool_clone)
|
||||
}
|
||||
|
||||
pub(crate) struct ClientInnerRemote<C: ClientInnerExt> {
|
||||
inner: C,
|
||||
pub(crate) struct ClientDataRemote {
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
cancel: CancellationToken,
|
||||
aux: MetricsAuxInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> ClientInnerRemote<C> {
|
||||
pub(crate) fn inner_mut(&mut self) -> &mut C {
|
||||
&mut self.inner
|
||||
}
|
||||
|
||||
pub(crate) fn inner(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
pub(crate) fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
|
||||
impl ClientDataRemote {
|
||||
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
|
||||
&mut self.session
|
||||
}
|
||||
|
||||
pub(crate) fn aux(&self) -> &MetricsAuxInfo {
|
||||
&self.aux
|
||||
}
|
||||
|
||||
pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
pub(crate) fn is_closed(&self) -> bool {
|
||||
self.inner.is_closed()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for ClientInnerRemote<C> {
|
||||
fn drop(&mut self) {
|
||||
// on client drop, tell the conn to shut down
|
||||
pub fn cancel(&mut self) {
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
@@ -228,15 +207,13 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn create_inner() -> ClientInnerRemote<MockClient> {
|
||||
fn create_inner() -> ClientInnerCommon<MockClient> {
|
||||
create_inner_with(MockClient::new(false))
|
||||
}
|
||||
|
||||
fn create_inner_with(client: MockClient) -> ClientInnerRemote<MockClient> {
|
||||
ClientInnerRemote {
|
||||
fn create_inner_with(client: MockClient) -> ClientInnerCommon<MockClient> {
|
||||
ClientInnerCommon {
|
||||
inner: client,
|
||||
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
|
||||
cancel: CancellationToken::new(),
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
@@ -244,6 +221,10 @@ mod tests {
|
||||
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
conn_id: uuid::Uuid::new_v4(),
|
||||
data: ClientDataEnum::Remote(ClientDataRemote {
|
||||
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
|
||||
cancel: CancellationToken::new(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +261,7 @@ mod tests {
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
assert_eq!(0, pool.get_global_connections_count());
|
||||
client.inner_mut().1.discard();
|
||||
client.inner().1.discard();
|
||||
// Discard should not add the connection from the pool.
|
||||
assert_eq!(0, pool.get_global_connections_count());
|
||||
}
|
||||
|
||||
@@ -11,10 +11,13 @@ use tokio_postgres::ReadyForQueryStatus;
|
||||
use tracing::{debug, info, Span};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
use super::conn_pool::ClientInnerRemote;
|
||||
use super::conn_pool::ClientDataRemote;
|
||||
use super::http_conn_pool::ClientDataHttp;
|
||||
use super::local_conn_pool::ClientDataLocal;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::ColdStartInfo;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::types::{DbName, EndpointCacheKey, RoleName};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
@@ -41,8 +44,46 @@ impl ConnInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum ClientDataEnum {
|
||||
Remote(ClientDataRemote),
|
||||
Local(ClientDataLocal),
|
||||
#[allow(dead_code)]
|
||||
Http(ClientDataHttp),
|
||||
}
|
||||
|
||||
pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
|
||||
pub(crate) inner: C,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) conn_id: uuid::Uuid,
|
||||
pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
|
||||
fn drop(&mut self) {
|
||||
match &mut self.data {
|
||||
ClientDataEnum::Remote(remote_data) => {
|
||||
remote_data.cancel();
|
||||
}
|
||||
ClientDataEnum::Local(local_data) => {
|
||||
local_data.cancel();
|
||||
}
|
||||
ClientDataEnum::Http(_http_data) => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> ClientInnerCommon<C> {
|
||||
pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
|
||||
pub(crate) conn: ClientInnerRemote<C>,
|
||||
pub(crate) conn: ClientInnerCommon<C>,
|
||||
pub(crate) _last_access: std::time::Instant,
|
||||
}
|
||||
|
||||
@@ -55,10 +96,33 @@ pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
|
||||
_guard: HttpEndpointPoolsGuard<'static>,
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
global_pool_size_max_conns: usize,
|
||||
pool_name: String,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
|
||||
pub(crate) fn new(
|
||||
hmap: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
|
||||
tconns: usize,
|
||||
max_conns_per_endpoint: usize,
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
max_total_conns: usize,
|
||||
pname: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
pools: hmap,
|
||||
total_conns: tconns,
|
||||
max_conns: max_conns_per_endpoint,
|
||||
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
|
||||
global_connections_count,
|
||||
global_pool_size_max_conns: max_total_conns,
|
||||
pool_name: pname,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_conn_entry(
|
||||
&mut self,
|
||||
db_user: (DbName, RoleName),
|
||||
) -> Option<ConnPoolEntry<C>> {
|
||||
let Self {
|
||||
pools,
|
||||
total_conns,
|
||||
@@ -84,9 +148,10 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
..
|
||||
} = self;
|
||||
if let Some(pool) = pools.get_mut(&db_user) {
|
||||
let old_len = pool.conns.len();
|
||||
pool.conns.retain(|conn| conn.conn.get_conn_id() != conn_id);
|
||||
let new_len = pool.conns.len();
|
||||
let old_len = pool.get_conns().len();
|
||||
pool.get_conns()
|
||||
.retain(|conn| conn.conn.get_conn_id() != conn_id);
|
||||
let new_len = pool.get_conns().len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
|
||||
@@ -103,11 +168,26 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerRemote<C>) {
|
||||
let conn_id = client.get_conn_id();
|
||||
pub(crate) fn get_name(&self) -> &str {
|
||||
&self.pool_name
|
||||
}
|
||||
|
||||
if client.is_closed() {
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
|
||||
pub(crate) fn get_pool(&self, db_user: (DbName, RoleName)) -> Option<&DbUserConnPool<C>> {
|
||||
self.pools.get(&db_user)
|
||||
}
|
||||
|
||||
pub(crate) fn get_pool_mut(
|
||||
&mut self,
|
||||
db_user: (DbName, RoleName),
|
||||
) -> Option<&mut DbUserConnPool<C>> {
|
||||
self.pools.get_mut(&db_user)
|
||||
}
|
||||
|
||||
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
|
||||
let conn_id = client.get_conn_id();
|
||||
let pool_name = pool.read().get_name().to_string();
|
||||
if client.inner.is_closed() {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -118,7 +198,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
.load(atomic::Ordering::Relaxed)
|
||||
>= global_max_conn
|
||||
{
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -130,13 +210,13 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
|
||||
if pool.total_conns < pool.max_conns {
|
||||
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
|
||||
pool_entries.conns.push(ConnPoolEntry {
|
||||
pool_entries.get_conns().push(ConnPoolEntry {
|
||||
conn: client,
|
||||
_last_access: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
returned = true;
|
||||
per_db_size = pool_entries.conns.len();
|
||||
per_db_size = pool_entries.get_conns().len();
|
||||
|
||||
pool.total_conns += 1;
|
||||
pool.global_connections_count
|
||||
@@ -153,9 +233,9 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
|
||||
// do logging outside of the mutex
|
||||
if returned {
|
||||
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
|
||||
info!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
|
||||
} else {
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -176,19 +256,39 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
|
||||
|
||||
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
|
||||
pub(crate) conns: Vec<ConnPoolEntry<C>>,
|
||||
pub(crate) initialized: Option<bool>, // a bit ugly, exists only for local pools
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
|
||||
fn default() -> Self {
|
||||
Self { conns: Vec::new() }
|
||||
Self {
|
||||
conns: Vec::new(),
|
||||
initialized: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> DbUserConnPool<C> {
|
||||
pub(crate) trait DbUserConn<C: ClientInnerExt>: Default {
|
||||
fn set_initialized(&mut self);
|
||||
fn is_initialized(&self) -> bool;
|
||||
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize;
|
||||
fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize);
|
||||
fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>>;
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
|
||||
fn set_initialized(&mut self) {
|
||||
self.initialized = Some(true);
|
||||
}
|
||||
|
||||
fn is_initialized(&self) -> bool {
|
||||
self.initialized.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
|
||||
let old_len = self.conns.len();
|
||||
|
||||
self.conns.retain(|conn| !conn.conn.is_closed());
|
||||
self.conns.retain(|conn| !conn.conn.inner.is_closed());
|
||||
|
||||
let new_len = self.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
@@ -196,10 +296,7 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
|
||||
removed
|
||||
}
|
||||
|
||||
pub(crate) fn get_conn_entry(
|
||||
&mut self,
|
||||
conns: &mut usize,
|
||||
) -> (Option<ConnPoolEntry<C>>, usize) {
|
||||
fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize) {
|
||||
let mut removed = self.clear_closed_clients(conns);
|
||||
let conn = self.conns.pop();
|
||||
if conn.is_some() {
|
||||
@@ -215,6 +312,10 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
|
||||
|
||||
(conn, removed)
|
||||
}
|
||||
|
||||
fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>> {
|
||||
&mut self.conns
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
|
||||
@@ -278,6 +379,60 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
self.config.pool_options.idle_timeout
|
||||
}
|
||||
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<Option<Client<C>>, HttpConnError> {
|
||||
let mut client: Option<ClientInnerCommon<C>> = None;
|
||||
let Some(endpoint) = conn_info.endpoint_cache_key() else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
|
||||
if let Some(entry) = endpoint_pool
|
||||
.write()
|
||||
.get_conn_entry(conn_info.db_and_user())
|
||||
{
|
||||
client = Some(entry.conn);
|
||||
}
|
||||
let endpoint_pool = Arc::downgrade(&endpoint_pool);
|
||||
|
||||
// ok return cached connection if found and establish a new one otherwise
|
||||
if let Some(mut client) = client {
|
||||
if client.inner.is_closed() {
|
||||
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return Ok(None);
|
||||
}
|
||||
tracing::Span::current()
|
||||
.record("conn_id", tracing::field::display(client.get_conn_id()));
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
tracing::field::display(client.inner.get_process_id()),
|
||||
);
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
|
||||
match client.get_data() {
|
||||
ClientDataEnum::Local(data) => {
|
||||
data.session().send(ctx.session_id())?;
|
||||
}
|
||||
|
||||
ClientDataEnum::Remote(data) => {
|
||||
data.session().send(ctx.session_id())?;
|
||||
}
|
||||
ClientDataEnum::Http(_) => (),
|
||||
}
|
||||
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown(&self) {
|
||||
// drops all strong references to endpoint-pools
|
||||
self.global_pool.clear();
|
||||
@@ -374,6 +529,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
|
||||
global_connections_count: self.global_connections_count.clone(),
|
||||
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
|
||||
pool_name: String::from("remote"),
|
||||
}));
|
||||
|
||||
// find or create a pool for this endpoint
|
||||
@@ -400,55 +556,23 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
|
||||
pool
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<Option<Client<C>>, HttpConnError> {
|
||||
let mut client: Option<ClientInnerRemote<C>> = None;
|
||||
let Some(endpoint) = conn_info.endpoint_cache_key() else {
|
||||
return Ok(None);
|
||||
};
|
||||
pub(crate) struct Client<C: ClientInnerExt> {
|
||||
span: Span,
|
||||
inner: Option<ClientInnerCommon<C>>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<RwLock<EndpointConnPool<C>>>,
|
||||
}
|
||||
|
||||
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
|
||||
if let Some(entry) = endpoint_pool
|
||||
.write()
|
||||
.get_conn_entry(conn_info.db_and_user())
|
||||
{
|
||||
client = Some(entry.conn);
|
||||
}
|
||||
let endpoint_pool = Arc::downgrade(&endpoint_pool);
|
||||
|
||||
// ok return cached connection if found and establish a new one otherwise
|
||||
if let Some(mut client) = client {
|
||||
if client.is_closed() {
|
||||
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return Ok(None);
|
||||
}
|
||||
tracing::Span::current()
|
||||
.record("conn_id", tracing::field::display(client.get_conn_id()));
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
tracing::field::display(client.inner().get_process_id()),
|
||||
);
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
|
||||
client.session().send(ctx.session_id())?;
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
pub(crate) struct Discard<'a, C: ClientInnerExt> {
|
||||
conn_info: &'a ConnInfo,
|
||||
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Client<C> {
|
||||
pub(crate) fn new(
|
||||
inner: ClientInnerRemote<C>,
|
||||
inner: ClientInnerCommon<C>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<RwLock<EndpointConnPool<C>>>,
|
||||
) -> Self {
|
||||
@@ -460,7 +584,18 @@ impl<C: ClientInnerExt> Client<C> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn inner_mut(&mut self) -> (&mut C, Discard<'_, C>) {
|
||||
pub(crate) fn client_inner(&mut self) -> (&mut ClientInnerCommon<C>, Discard<'_, C>) {
|
||||
let Self {
|
||||
inner,
|
||||
pool,
|
||||
conn_info,
|
||||
span: _,
|
||||
} = self;
|
||||
let inner_m = inner.as_mut().expect("client inner should not be removed");
|
||||
(inner_m, Discard { conn_info, pool })
|
||||
}
|
||||
|
||||
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
|
||||
let Self {
|
||||
inner,
|
||||
pool,
|
||||
@@ -468,12 +603,11 @@ impl<C: ClientInnerExt> Client<C> {
|
||||
span: _,
|
||||
} = self;
|
||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||
let inner_ref = inner.inner_mut();
|
||||
(inner_ref, Discard { conn_info, pool })
|
||||
(&mut inner.inner, Discard { conn_info, pool })
|
||||
}
|
||||
|
||||
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||
let aux = &self.inner.as_ref().unwrap().aux();
|
||||
let aux = &self.inner.as_ref().unwrap().aux;
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id,
|
||||
branch_id: aux.branch_id,
|
||||
@@ -498,13 +632,6 @@ impl<C: ClientInnerExt> Client<C> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Client<C: ClientInnerExt> {
|
||||
span: Span,
|
||||
inner: Option<ClientInnerRemote<C>>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<RwLock<EndpointConnPool<C>>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for Client<C> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(drop) = self.do_drop() {
|
||||
@@ -517,10 +644,11 @@ impl<C: ClientInnerExt> Deref for Client<C> {
|
||||
type Target = C;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.inner
|
||||
&self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed")
|
||||
.inner()
|
||||
.inner
|
||||
}
|
||||
}
|
||||
|
||||
@@ -539,11 +667,6 @@ impl ClientInnerExt for tokio_postgres::Client {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Discard<'a, C: ClientInnerExt> {
|
||||
conn_info: &'a ConnInfo,
|
||||
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
|
||||
@@ -7,9 +7,11 @@ use hyper::client::conn::http2;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use parking_lot::RwLock;
|
||||
use rand::Rng;
|
||||
use std::result::Result::Ok;
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::{debug, error, info, info_span, Instrument};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
@@ -28,6 +30,8 @@ pub(crate) struct ConnPoolEntry<C: ClientInnerExt + Clone> {
|
||||
aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
pub(crate) struct ClientDataHttp();
|
||||
|
||||
// Per-endpoint connection pool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub(crate) struct EndpointConnPool<C: ClientInnerExt + Clone> {
|
||||
@@ -206,14 +210,22 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
|
||||
}
|
||||
}
|
||||
|
||||
#[expect(unused_results)]
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Option<Client<C>> {
|
||||
let endpoint = conn_info.endpoint_cache_key()?;
|
||||
) -> Result<Option<Client<C>>, HttpConnError> {
|
||||
let result: Result<Option<Client<C>>, HttpConnError>;
|
||||
let Some(endpoint) = conn_info.endpoint_cache_key() else {
|
||||
result = Ok(None);
|
||||
return result;
|
||||
};
|
||||
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
|
||||
let client = endpoint_pool.write().get_conn_entry()?;
|
||||
let Some(client) = endpoint_pool.write().get_conn_entry() else {
|
||||
result = Ok(None);
|
||||
return result;
|
||||
};
|
||||
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
info!(
|
||||
@@ -222,7 +234,7 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
|
||||
);
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
Some(Client::new(client.conn, client.aux))
|
||||
Ok(Some(Client::new(client.conn, client.aux)))
|
||||
}
|
||||
|
||||
fn get_or_create_endpoint_pool(
|
||||
|
||||
@@ -11,7 +11,8 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::pin;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::Arc;
|
||||
use std::task::{ready, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -26,177 +27,42 @@ use signature::Signer;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
|
||||
use tokio_postgres::{AsyncMessage, Socket};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument, Span};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
|
||||
use super::conn_pool_lib::{
|
||||
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, DbUserConn,
|
||||
EndpointConnPool,
|
||||
};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::types::{DbName, RoleName};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
|
||||
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
|
||||
pub(crate) const EXT_VERSION: &str = "0.1.2";
|
||||
pub(crate) const EXT_SCHEMA: &str = "auth";
|
||||
|
||||
struct ConnPoolEntry<C: ClientInnerExt> {
|
||||
conn: ClientInner<C>,
|
||||
_last_access: std::time::Instant,
|
||||
pub(crate) struct ClientDataLocal {
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
cancel: CancellationToken,
|
||||
key: SigningKey,
|
||||
jti: u64,
|
||||
}
|
||||
|
||||
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
|
||||
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
|
||||
total_conns: usize,
|
||||
max_conns: usize,
|
||||
global_pool_size_max_conns: usize,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
|
||||
let Self {
|
||||
pools, total_conns, ..
|
||||
} = self;
|
||||
pools
|
||||
.get_mut(&db_user)
|
||||
.and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
|
||||
impl ClientDataLocal {
|
||||
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
|
||||
&mut self.session
|
||||
}
|
||||
|
||||
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
|
||||
let Self {
|
||||
pools, total_conns, ..
|
||||
} = self;
|
||||
if let Some(pool) = pools.get_mut(&db_user) {
|
||||
let old_len = pool.conns.len();
|
||||
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
|
||||
let new_len = pool.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
}
|
||||
*total_conns -= removed;
|
||||
removed > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
|
||||
let conn_id = client.conn_id;
|
||||
|
||||
if client.is_closed() {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
|
||||
return;
|
||||
}
|
||||
let global_max_conn = pool.read().global_pool_size_max_conns;
|
||||
if pool.read().total_conns >= global_max_conn {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
|
||||
return;
|
||||
}
|
||||
|
||||
// return connection to the pool
|
||||
let mut returned = false;
|
||||
let mut per_db_size = 0;
|
||||
let total_conns = {
|
||||
let mut pool = pool.write();
|
||||
|
||||
if pool.total_conns < pool.max_conns {
|
||||
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
|
||||
pool_entries.conns.push(ConnPoolEntry {
|
||||
conn: client,
|
||||
_last_access: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
returned = true;
|
||||
per_db_size = pool_entries.conns.len();
|
||||
|
||||
pool.total_conns += 1;
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.inc();
|
||||
}
|
||||
|
||||
pool.total_conns
|
||||
};
|
||||
|
||||
// do logging outside of the mutex
|
||||
if returned {
|
||||
info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
|
||||
} else {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
|
||||
fn drop(&mut self) {
|
||||
if self.total_conns > 0 {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(self.total_conns as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
|
||||
conns: Vec<ConnPoolEntry<C>>,
|
||||
|
||||
// true if we have definitely installed the extension and
|
||||
// granted the role access to the auth schema.
|
||||
initialized: bool,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
conns: Vec::new(),
|
||||
initialized: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> DbUserConnPool<C> {
|
||||
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
|
||||
let old_len = self.conns.len();
|
||||
|
||||
self.conns.retain(|conn| !conn.conn.is_closed());
|
||||
|
||||
let new_len = self.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
*conns -= removed;
|
||||
removed
|
||||
}
|
||||
|
||||
fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
|
||||
let mut removed = self.clear_closed_clients(conns);
|
||||
let conn = self.conns.pop();
|
||||
if conn.is_some() {
|
||||
*conns -= 1;
|
||||
removed += 1;
|
||||
}
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
conn
|
||||
pub fn cancel(&mut self) {
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LocalConnPool<C: ClientInnerExt> {
|
||||
global_pool: RwLock<EndpointConnPool<C>>,
|
||||
global_pool: Arc<RwLock<EndpointConnPool<C>>>,
|
||||
|
||||
config: &'static crate::config::HttpConfig,
|
||||
}
|
||||
@@ -204,12 +70,14 @@ pub(crate) struct LocalConnPool<C: ClientInnerExt> {
|
||||
impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
global_pool: RwLock::new(EndpointConnPool {
|
||||
pools: HashMap::new(),
|
||||
total_conns: 0,
|
||||
max_conns: config.pool_options.max_conns_per_endpoint,
|
||||
global_pool_size_max_conns: config.pool_options.max_total_conns,
|
||||
}),
|
||||
global_pool: Arc::new(RwLock::new(EndpointConnPool::new(
|
||||
HashMap::new(),
|
||||
0,
|
||||
config.pool_options.max_conns_per_endpoint,
|
||||
Arc::new(AtomicUsize::new(0)),
|
||||
config.pool_options.max_total_conns,
|
||||
String::from("local_pool"),
|
||||
))),
|
||||
config,
|
||||
})
|
||||
}
|
||||
@@ -222,7 +90,7 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<Option<LocalClient<C>>, HttpConnError> {
|
||||
) -> Result<Option<Client<C>>, HttpConnError> {
|
||||
let client = self
|
||||
.global_pool
|
||||
.write()
|
||||
@@ -230,12 +98,14 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
.map(|entry| entry.conn);
|
||||
|
||||
// ok return cached connection if found and establish a new one otherwise
|
||||
if let Some(client) = client {
|
||||
if client.is_closed() {
|
||||
if let Some(mut client) = client {
|
||||
if client.inner.is_closed() {
|
||||
info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return Ok(None);
|
||||
}
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
|
||||
tracing::Span::current()
|
||||
.record("conn_id", tracing::field::display(client.get_conn_id()));
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
tracing::field::display(client.inner.get_process_id()),
|
||||
@@ -244,47 +114,59 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"local_pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
client.session.send(ctx.session_id())?;
|
||||
|
||||
match client.get_data() {
|
||||
ClientDataEnum::Local(data) => {
|
||||
data.session().send(ctx.session_id())?;
|
||||
}
|
||||
|
||||
ClientDataEnum::Remote(data) => {
|
||||
data.session().send(ctx.session_id())?;
|
||||
}
|
||||
ClientDataEnum::Http(_) => (),
|
||||
}
|
||||
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
return Ok(Some(LocalClient::new(
|
||||
|
||||
return Ok(Some(Client::new(
|
||||
client,
|
||||
conn_info.clone(),
|
||||
Arc::downgrade(self),
|
||||
Arc::downgrade(&self.global_pool),
|
||||
)));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub(crate) fn initialized(self: &Arc<Self>, conn_info: &ConnInfo) -> bool {
|
||||
self.global_pool
|
||||
.read()
|
||||
.pools
|
||||
.get(&conn_info.db_and_user())
|
||||
.map_or(false, |pool| pool.initialized)
|
||||
if let Some(pool) = self.global_pool.read().get_pool(conn_info.db_and_user()) {
|
||||
return pool.is_initialized();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub(crate) fn set_initialized(self: &Arc<Self>, conn_info: &ConnInfo) {
|
||||
self.global_pool
|
||||
if let Some(pool) = self
|
||||
.global_pool
|
||||
.write()
|
||||
.pools
|
||||
.entry(conn_info.db_and_user())
|
||||
.or_default()
|
||||
.initialized = true;
|
||||
.get_pool_mut(conn_info.db_and_user())
|
||||
{
|
||||
pool.set_initialized();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn poll_client(
|
||||
global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
global_pool: Arc<LocalConnPool<C>>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
client: tokio_postgres::Client,
|
||||
client: C,
|
||||
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
|
||||
key: SigningKey,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> LocalClient<tokio_postgres::Client> {
|
||||
) -> Client<C> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let mut session_id = ctx.session_id();
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
|
||||
@@ -377,111 +259,47 @@ pub(crate) fn poll_client(
|
||||
}
|
||||
.instrument(span));
|
||||
|
||||
let inner = ClientInner {
|
||||
let inner = ClientInnerCommon {
|
||||
inner: client,
|
||||
session: tx,
|
||||
cancel,
|
||||
aux,
|
||||
conn_id,
|
||||
key,
|
||||
jti: 0,
|
||||
data: ClientDataEnum::Local(ClientDataLocal {
|
||||
session: tx,
|
||||
cancel,
|
||||
key,
|
||||
jti: 0,
|
||||
}),
|
||||
};
|
||||
LocalClient::new(inner, conn_info, pool_clone)
|
||||
|
||||
Client::new(
|
||||
inner,
|
||||
conn_info,
|
||||
Arc::downgrade(&pool_clone.upgrade().unwrap().global_pool),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) struct ClientInner<C: ClientInnerExt> {
|
||||
inner: C,
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
cancel: CancellationToken,
|
||||
aux: MetricsAuxInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
// needed for pg_session_jwt state
|
||||
key: SigningKey,
|
||||
jti: u64,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for ClientInner<C> {
|
||||
fn drop(&mut self) {
|
||||
// on client drop, tell the conn to shut down
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> ClientInner<C> {
|
||||
pub(crate) fn is_closed(&self) -> bool {
|
||||
self.inner.is_closed()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientInner<tokio_postgres::Client> {
|
||||
impl ClientInnerCommon<tokio_postgres::Client> {
|
||||
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
|
||||
self.jti += 1;
|
||||
let token = resign_jwt(&self.key, payload, self.jti)?;
|
||||
if let ClientDataEnum::Local(local_data) = &mut self.data {
|
||||
local_data.jti += 1;
|
||||
let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
|
||||
|
||||
// initiates the auth session
|
||||
self.inner.simple_query("discard all").await?;
|
||||
self.inner
|
||||
.query(
|
||||
"select auth.jwt_session_init($1)",
|
||||
&[&token as &(dyn ToSql + Sync)],
|
||||
)
|
||||
.await?;
|
||||
// initiates the auth session
|
||||
self.inner.simple_query("discard all").await?;
|
||||
self.inner
|
||||
.query(
|
||||
"select auth.jwt_session_init($1)",
|
||||
&[&token as &(dyn ToSql + Sync)],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let pid = self.inner.get_process_id();
|
||||
info!(pid, jti = self.jti, "user session state init");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LocalClient<C: ClientInnerExt> {
|
||||
span: Span,
|
||||
inner: Option<ClientInner<C>>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<LocalConnPool<C>>,
|
||||
}
|
||||
|
||||
pub(crate) struct Discard<'a, C: ClientInnerExt> {
|
||||
conn_info: &'a ConnInfo,
|
||||
pool: &'a mut Weak<LocalConnPool<C>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub(self) fn new(
|
||||
inner: ClientInner<C>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<LocalConnPool<C>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Some(inner),
|
||||
span: Span::current(),
|
||||
conn_info,
|
||||
pool,
|
||||
let pid = self.inner.get_process_id();
|
||||
info!(pid, jti = local_data.jti, "user session state init");
|
||||
Ok(())
|
||||
} else {
|
||||
panic!("unexpected client data type");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn client_inner(&mut self) -> (&mut ClientInner<C>, Discard<'_, C>) {
|
||||
let Self {
|
||||
inner,
|
||||
pool,
|
||||
conn_info,
|
||||
span: _,
|
||||
} = self;
|
||||
let inner_m = inner.as_mut().expect("client inner should not be removed");
|
||||
(inner_m, Discard { conn_info, pool })
|
||||
}
|
||||
|
||||
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
|
||||
let Self {
|
||||
inner,
|
||||
pool,
|
||||
conn_info,
|
||||
span: _,
|
||||
} = self;
|
||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||
(&mut inner.inner, Discard { conn_info, pool })
|
||||
}
|
||||
}
|
||||
|
||||
/// implements relatively efficient in-place json object key upserting
|
||||
@@ -547,58 +365,6 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
|
||||
jwt
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||
let aux = &self.inner.as_ref().unwrap().aux;
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id,
|
||||
branch_id: aux.branch_id,
|
||||
})
|
||||
}
|
||||
|
||||
fn do_drop(&mut self) -> Option<impl FnOnce() + use<C>> {
|
||||
let conn_info = self.conn_info.clone();
|
||||
let client = self
|
||||
.inner
|
||||
.take()
|
||||
.expect("client inner should not be removed");
|
||||
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
|
||||
let current_span = self.span.clone();
|
||||
// return connection to the pool
|
||||
return Some(move || {
|
||||
let _span = current_span.enter();
|
||||
EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for LocalClient<C> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(drop) = self.do_drop() {
|
||||
tokio::task::spawn_blocking(drop);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!(
|
||||
"local_pool: throwing away connection '{conn_info}' because connection is not idle"
|
||||
);
|
||||
}
|
||||
}
|
||||
pub(crate) fn discard(&mut self) {
|
||||
let conn_info = &self.conn_info;
|
||||
if std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use p256::ecdsa::SigningKey;
|
||||
|
||||
@@ -47,7 +47,7 @@ use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::{read_proxy_protocol, ChainRW, ConnectionInfo};
|
||||
use crate::protocol2::{read_proxy_protocol, ChainRW, ConnectHeader, ConnectionInfo};
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
@@ -251,16 +251,21 @@ async fn connection_startup(
|
||||
};
|
||||
|
||||
let conn_info = match peer {
|
||||
None if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
ConnectHeader::Local => {
|
||||
tracing::debug!("healthcheck received");
|
||||
return None;
|
||||
}
|
||||
ConnectHeader::Missing if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
tracing::warn!("missing required proxy protocol header");
|
||||
return None;
|
||||
}
|
||||
Some(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
ConnectHeader::Proxy(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
tracing::warn!("proxy protocol header not supported");
|
||||
return None;
|
||||
}
|
||||
Some(info) => info,
|
||||
None => ConnectionInfo {
|
||||
ConnectHeader::Proxy(info) => info,
|
||||
ConnectHeader::Missing => ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
|
||||
@@ -31,7 +31,6 @@ use super::conn_pool_lib::{self, ConnInfo};
|
||||
use super::error::HttpCodeError;
|
||||
use super::http_util::json_response;
|
||||
use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
|
||||
use super::local_conn_pool;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
|
||||
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
|
||||
@@ -1052,12 +1051,12 @@ async fn query_to_json<T: GenericClient>(
|
||||
|
||||
enum Client {
|
||||
Remote(conn_pool_lib::Client<tokio_postgres::Client>),
|
||||
Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
|
||||
Local(conn_pool_lib::Client<tokio_postgres::Client>),
|
||||
}
|
||||
|
||||
enum Discard<'a> {
|
||||
Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
|
||||
Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
|
||||
Local(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
|
||||
}
|
||||
|
||||
impl Client {
|
||||
@@ -1071,7 +1070,7 @@ impl Client {
|
||||
fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
|
||||
match self {
|
||||
Client::Remote(client) => {
|
||||
let (c, d) = client.inner_mut();
|
||||
let (c, d) = client.inner();
|
||||
(c, Discard::Remote(d))
|
||||
}
|
||||
Client::Local(local_client) => {
|
||||
|
||||
@@ -28,6 +28,7 @@ hyper0.workspace = true
|
||||
futures.workspace = true
|
||||
once_cell.workspace = true
|
||||
parking_lot.workspace = true
|
||||
pageserver_api.workspace = true
|
||||
postgres.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
rand.workspace = true
|
||||
@@ -57,6 +58,7 @@ sd-notify.workspace = true
|
||||
storage_broker.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
utils.workspace = true
|
||||
wal_decoder.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ async fn copy_disk_segments(
|
||||
) -> Result<()> {
|
||||
let mut wal_reader = tli.get_walreader(start_lsn).await?;
|
||||
|
||||
let mut buf = [0u8; MAX_SEND_SIZE];
|
||||
let mut buf = vec![0u8; MAX_SEND_SIZE];
|
||||
|
||||
let first_segment = start_lsn.segment_number(wal_seg_size);
|
||||
let last_segment = end_lsn.segment_number(wal_seg_size);
|
||||
|
||||
@@ -383,7 +383,7 @@ pub async fn calculate_digest(
|
||||
let mut wal_reader = tli.get_walreader(request.from_lsn).await?;
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buf = [0u8; MAX_SEND_SIZE];
|
||||
let mut buf = vec![0u8; MAX_SEND_SIZE];
|
||||
|
||||
let mut bytes_left = (request.until_lsn.0 - request.from_lsn.0) as usize;
|
||||
while bytes_left > 0 {
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
//! protocol commands.
|
||||
|
||||
use anyhow::Context;
|
||||
use pageserver_api::shard::{ShardIdentity, ShardStripeSize};
|
||||
use std::future::Future;
|
||||
use std::str::{self, FromStr};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, info_span, Instrument};
|
||||
use utils::postgres_client::PAGESERVER_SAFEKEEPER_PROTO_VERSION;
|
||||
use utils::shard::{ShardCount, ShardNumber};
|
||||
|
||||
use crate::auth::check_permission;
|
||||
use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
|
||||
@@ -35,6 +38,8 @@ pub struct SafekeeperPostgresHandler {
|
||||
pub tenant_id: Option<TenantId>,
|
||||
pub timeline_id: Option<TimelineId>,
|
||||
pub ttid: TenantTimelineId,
|
||||
pub shard: Option<ShardIdentity>,
|
||||
pub protocol_version: Option<u8>,
|
||||
/// Unique connection id is logged in spans for observability.
|
||||
pub conn_id: ConnectionId,
|
||||
/// Auth scope allowed on the connections and public key used to check auth tokens. None if auth is not configured.
|
||||
@@ -107,11 +112,21 @@ impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
|
||||
) -> Result<(), QueryError> {
|
||||
if let FeStartupPacket::StartupMessage { params, .. } = sm {
|
||||
if let Some(options) = params.options_raw() {
|
||||
let mut shard_count: Option<u8> = None;
|
||||
let mut shard_number: Option<u8> = None;
|
||||
let mut shard_stripe_size: Option<u32> = None;
|
||||
|
||||
for opt in options {
|
||||
// FIXME `ztenantid` and `ztimelineid` left for compatibility during deploy,
|
||||
// remove these after the PR gets deployed:
|
||||
// https://github.com/neondatabase/neon/pull/2433#discussion_r970005064
|
||||
match opt.split_once('=') {
|
||||
Some(("protocol_version", value)) => {
|
||||
self.protocol_version =
|
||||
Some(value.parse::<u8>().with_context(|| {
|
||||
format!("Failed to parse {value} as protocol_version")
|
||||
})?);
|
||||
}
|
||||
Some(("ztenantid", value)) | Some(("tenant_id", value)) => {
|
||||
self.tenant_id = Some(value.parse().with_context(|| {
|
||||
format!("Failed to parse {value} as tenant id")
|
||||
@@ -127,9 +142,44 @@ impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
|
||||
metrics.set_client_az(client_az)
|
||||
}
|
||||
}
|
||||
Some(("shard_count", value)) => {
|
||||
shard_count = Some(value.parse::<u8>().with_context(|| {
|
||||
format!("Failed to parse {value} as shard count")
|
||||
})?);
|
||||
}
|
||||
Some(("shard_number", value)) => {
|
||||
shard_number = Some(value.parse::<u8>().with_context(|| {
|
||||
format!("Failed to parse {value} as shard number")
|
||||
})?);
|
||||
}
|
||||
Some(("shard_stripe_size", value)) => {
|
||||
shard_stripe_size = Some(value.parse::<u32>().with_context(|| {
|
||||
format!("Failed to parse {value} as shard stripe size")
|
||||
})?);
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
if self.protocol_version == Some(PAGESERVER_SAFEKEEPER_PROTO_VERSION) {
|
||||
match (shard_count, shard_number, shard_stripe_size) {
|
||||
(Some(count), Some(number), Some(stripe_size)) => {
|
||||
self.shard = Some(
|
||||
ShardIdentity::new(
|
||||
ShardNumber(number),
|
||||
ShardCount(count),
|
||||
ShardStripeSize(stripe_size),
|
||||
)
|
||||
.with_context(|| "Failed to create shard identity")?,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"Shard params were not specified"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
@@ -150,6 +200,11 @@ impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
|
||||
tracing::field::debug(self.appname.clone()),
|
||||
);
|
||||
|
||||
if let Some(shard) = self.shard.as_ref() {
|
||||
tracing::Span::current()
|
||||
.record("shard", tracing::field::display(shard.shard_slug()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
Err(QueryError::Other(anyhow::anyhow!(
|
||||
@@ -258,6 +313,8 @@ impl SafekeeperPostgresHandler {
|
||||
tenant_id: None,
|
||||
timeline_id: None,
|
||||
ttid: TenantTimelineId::empty(),
|
||||
shard: None,
|
||||
protocol_version: None,
|
||||
conn_id,
|
||||
claims: None,
|
||||
auth,
|
||||
|
||||
@@ -17,6 +17,7 @@ use tokio::{
|
||||
use tokio_postgres::replication::ReplicationStream;
|
||||
use tokio_postgres::types::PgLsn;
|
||||
use tracing::*;
|
||||
use utils::postgres_client::{ConnectionConfigArgs, POSTGRES_PROTO_VERSION};
|
||||
use utils::{id::NodeId, lsn::Lsn, postgres_client::wal_stream_connection_config};
|
||||
|
||||
use crate::receive_wal::{WalAcceptor, REPLY_QUEUE_SIZE};
|
||||
@@ -325,7 +326,17 @@ async fn recovery_stream(
|
||||
conf: &SafeKeeperConf,
|
||||
) -> anyhow::Result<String> {
|
||||
// TODO: pass auth token
|
||||
let cfg = wal_stream_connection_config(tli.ttid, &donor.pg_connstr, None, None)?;
|
||||
let connection_conf_args = ConnectionConfigArgs {
|
||||
protocol_version: POSTGRES_PROTO_VERSION,
|
||||
ttid: tli.ttid,
|
||||
shard_number: None,
|
||||
shard_count: None,
|
||||
shard_stripe_size: None,
|
||||
listen_pg_addr_str: &donor.pg_connstr,
|
||||
auth_token: None,
|
||||
availability_zone: None,
|
||||
};
|
||||
let cfg = wal_stream_connection_config(connection_conf_args)?;
|
||||
let mut cfg = cfg.to_tokio_postgres_config();
|
||||
// It will make safekeeper give out not committed WAL (up to flush_lsn).
|
||||
cfg.application_name(&format!("safekeeper_{}", conf.my_id));
|
||||
|
||||
@@ -11,17 +11,21 @@ use crate::wal_storage::WalReader;
|
||||
use crate::GlobalTimelines;
|
||||
use anyhow::{bail, Context as AnyhowContext};
|
||||
use bytes::Bytes;
|
||||
use pageserver_api::shard::ShardIdentity;
|
||||
use parking_lot::Mutex;
|
||||
use postgres_backend::PostgresBackend;
|
||||
use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError};
|
||||
use postgres_ffi::get_current_timestamp;
|
||||
use postgres_ffi::waldecoder::WalStreamDecoder;
|
||||
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
|
||||
use pq_proto::{BeMessage, WalSndKeepAlive, XLogDataBody};
|
||||
use pq_proto::{BeMessage, InterpretedWalRecordBody, WalSndKeepAlive, XLogDataBody};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::failpoint_support;
|
||||
use utils::id::TenantTimelineId;
|
||||
use utils::pageserver_feedback::PageserverFeedback;
|
||||
use utils::postgres_client::{PAGESERVER_SAFEKEEPER_PROTO_VERSION, POSTGRES_PROTO_VERSION};
|
||||
use wal_decoder::models::InterpretedWalRecord;
|
||||
|
||||
use std::cmp::{max, min};
|
||||
use std::net::SocketAddr;
|
||||
@@ -377,6 +381,10 @@ impl Drop for WalSenderGuard {
|
||||
}
|
||||
|
||||
impl SafekeeperPostgresHandler {
|
||||
pub fn protocol_version(&self) -> u8 {
|
||||
self.protocol_version.unwrap_or(POSTGRES_PROTO_VERSION)
|
||||
}
|
||||
|
||||
/// Wrapper around handle_start_replication_guts handling result. Error is
|
||||
/// handled here while we're still in walsender ttid span; with API
|
||||
/// extension, this can probably be moved into postgres_backend.
|
||||
@@ -412,6 +420,7 @@ impl SafekeeperPostgresHandler {
|
||||
let appname = self.appname.clone();
|
||||
|
||||
// Use a guard object to remove our entry from the timeline when we are done.
|
||||
// TODO(vlad): maybe thread shard stuff into here
|
||||
let ws_guard = Arc::new(tli.get_walsenders().register(
|
||||
self.ttid,
|
||||
*pgb.get_peer_addr(),
|
||||
@@ -467,7 +476,7 @@ impl SafekeeperPostgresHandler {
|
||||
end_watch,
|
||||
ws_guard: ws_guard.clone(),
|
||||
wal_reader,
|
||||
send_buf: [0; MAX_SEND_SIZE],
|
||||
send_buf: vec![0u8; MAX_SEND_SIZE],
|
||||
};
|
||||
let mut reply_reader = ReplyReader {
|
||||
reader,
|
||||
@@ -475,9 +484,10 @@ impl SafekeeperPostgresHandler {
|
||||
tli,
|
||||
};
|
||||
|
||||
let protocol_version = self.protocol_version();
|
||||
let res = tokio::select! {
|
||||
// todo: add read|write .context to these errors
|
||||
r = sender.run() => r,
|
||||
r = sender.run(protocol_version, self.shard.as_ref()) => r,
|
||||
r = reply_reader.run() => r,
|
||||
};
|
||||
|
||||
@@ -548,7 +558,7 @@ struct WalSender<'a, IO> {
|
||||
ws_guard: Arc<WalSenderGuard>,
|
||||
wal_reader: WalReader,
|
||||
// buffer for readling WAL into to send it
|
||||
send_buf: [u8; MAX_SEND_SIZE],
|
||||
send_buf: Vec<u8>,
|
||||
}
|
||||
|
||||
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
|
||||
@@ -560,7 +570,35 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> WalSender<'_, IO> {
|
||||
///
|
||||
/// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ?
|
||||
/// convenience.
|
||||
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
|
||||
/// TODO(vlad): add a run variant which accumulates a full wall record
|
||||
/// and interprets it.
|
||||
async fn run(
|
||||
&mut self,
|
||||
protocol_version: u8,
|
||||
shard: Option<&ShardIdentity>,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
match protocol_version {
|
||||
POSTGRES_PROTO_VERSION => self.run_wal_sender().await,
|
||||
PAGESERVER_SAFEKEEPER_PROTO_VERSION => {
|
||||
self.run_interpreted_record_sender(shard.unwrap()).await
|
||||
}
|
||||
// TODO: make the proto version an enum
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_interpreted_record_sender(
|
||||
&mut self,
|
||||
shard: &ShardIdentity,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
let mut last_logged_at = std::time::Instant::now();
|
||||
let mut interpreted_records = 0;
|
||||
let mut interpreted_bytes = 0;
|
||||
let mut useful_bytes = 0;
|
||||
|
||||
let pg_version = self.tli.tli.get_state().await.1.server.pg_version / 10000;
|
||||
let mut wal_decoder = WalStreamDecoder::new(self.start_pos, pg_version);
|
||||
|
||||
loop {
|
||||
// Wait for the next portion if it is not there yet, or just
|
||||
// update our end of WAL available for sending value, we
|
||||
@@ -601,6 +639,141 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> WalSender<'_, IO> {
|
||||
};
|
||||
let send_buf = &send_buf[..send_size];
|
||||
|
||||
wal_decoder.feed_bytes(send_buf);
|
||||
|
||||
// How fast or slow is this. Write a little benchmark
|
||||
// to see how quiclky we can decode 1GiB of WAL.
|
||||
// If this is slow, then we have a problem since it bottlenecks
|
||||
// the whole afair. SK can send about 60-70MiB of raw WAL and
|
||||
// about 13-17MiB of useful interpreted WAL per second (these
|
||||
// number are for one shard).
|
||||
while let Some((record_end_lsn, recdata)) = wal_decoder
|
||||
.poll_decode()
|
||||
.with_context(|| "Failed to decode WAL")?
|
||||
{
|
||||
assert!(record_end_lsn.is_aligned());
|
||||
|
||||
// Deserialize and interpret WAL record
|
||||
let interpreted = InterpretedWalRecord::from_bytes_filtered(
|
||||
recdata,
|
||||
shard,
|
||||
record_end_lsn,
|
||||
pg_version,
|
||||
)
|
||||
.with_context(|| "Failed to interpret WAL")?;
|
||||
|
||||
let useful_size = interpreted.batch.buffer_size();
|
||||
|
||||
let mut buf = Vec::new();
|
||||
interpreted
|
||||
.ser_into(&mut buf)
|
||||
.with_context(|| "Failed to serialize interpreted WAL")?;
|
||||
|
||||
let size = buf.len();
|
||||
|
||||
self.pgb
|
||||
.write_message(&BeMessage::InterpretedWalRecord(InterpretedWalRecordBody {
|
||||
wal_end: self.end_pos.0,
|
||||
data: buf.as_slice(),
|
||||
}))
|
||||
.await?;
|
||||
|
||||
interpreted_records += 1;
|
||||
interpreted_bytes += size;
|
||||
useful_bytes += useful_size;
|
||||
}
|
||||
|
||||
// and send it
|
||||
// self.pgb
|
||||
// .write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
// wal_start: self.start_pos.0,
|
||||
// wal_end: self.end_pos.0,
|
||||
// timestamp: get_current_timestamp(),
|
||||
// data: send_buf,
|
||||
// }))
|
||||
// .await?;
|
||||
|
||||
// if let Some(appname) = &self.appname {
|
||||
// if appname == "replica" {
|
||||
// failpoint_support::sleep_millis_async!("sk-send-wal-replica-sleep");
|
||||
// }
|
||||
// }
|
||||
// trace!(
|
||||
// "sent {} bytes of WAL {}-{}",
|
||||
// send_size,
|
||||
// self.start_pos,
|
||||
// self.start_pos + send_size as u64
|
||||
// );
|
||||
|
||||
self.start_pos += send_size as u64;
|
||||
|
||||
let elapsed = last_logged_at.elapsed();
|
||||
if elapsed >= Duration::from_secs(5) {
|
||||
let records_rate = interpreted_records / elapsed.as_millis() * 1000;
|
||||
let bytes_rate = interpreted_bytes / elapsed.as_millis() as usize * 1000;
|
||||
let useful_bytes_rate = useful_bytes / elapsed.as_millis() as usize * 1000;
|
||||
tracing::info!(
|
||||
"Shard {} sender rate: rps={} bps={} ubps={}",
|
||||
shard.number.0,
|
||||
records_rate,
|
||||
bytes_rate,
|
||||
useful_bytes_rate
|
||||
);
|
||||
|
||||
last_logged_at = std::time::Instant::now();
|
||||
interpreted_records = 0;
|
||||
interpreted_bytes = 0;
|
||||
useful_bytes = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_wal_sender(&mut self) -> Result<(), CopyStreamHandlerEnd> {
|
||||
let mut useful_bytes = 0;
|
||||
let mut last_logged_at = std::time::Instant::now();
|
||||
|
||||
loop {
|
||||
// Wait for the next portion if it is not there yet, or just
|
||||
// update our end of WAL available for sending value, we
|
||||
// communicate it to the receiver.
|
||||
self.wait_wal().await?;
|
||||
assert!(
|
||||
self.end_pos > self.start_pos,
|
||||
"nothing to send after waiting for WAL"
|
||||
);
|
||||
|
||||
// try to send as much as available, capped by MAX_SEND_SIZE
|
||||
let mut chunk_end_pos = self.start_pos + MAX_SEND_SIZE as u64;
|
||||
// if we went behind available WAL, back off
|
||||
if chunk_end_pos >= self.end_pos {
|
||||
chunk_end_pos = self.end_pos;
|
||||
} else {
|
||||
// If sending not up to end pos, round down to page boundary to
|
||||
// avoid breaking WAL record not at page boundary, as protocol
|
||||
// demands. See walsender.c (XLogSendPhysical).
|
||||
chunk_end_pos = chunk_end_pos
|
||||
.checked_sub(chunk_end_pos.block_offset())
|
||||
.unwrap();
|
||||
}
|
||||
let send_size = (chunk_end_pos.0 - self.start_pos.0) as usize;
|
||||
let send_buf = &mut self.send_buf[..send_size];
|
||||
let send_size: usize;
|
||||
{
|
||||
// If uncommitted part is being pulled, check that the term is
|
||||
// still the expected one.
|
||||
let _term_guard = if let Some(t) = self.term {
|
||||
Some(self.tli.acquire_term(t).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Read WAL into buffer. send_size can be additionally capped to
|
||||
// segment boundary here.
|
||||
send_size = self.wal_reader.read(send_buf).await?
|
||||
};
|
||||
let send_buf = &send_buf[..send_size];
|
||||
|
||||
useful_bytes += send_buf.len();
|
||||
|
||||
// and send it
|
||||
self.pgb
|
||||
.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
@@ -623,6 +796,18 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> WalSender<'_, IO> {
|
||||
self.start_pos + send_size as u64
|
||||
);
|
||||
self.start_pos += send_size as u64;
|
||||
|
||||
let elapsed = last_logged_at.elapsed();
|
||||
if elapsed >= Duration::from_secs(5) {
|
||||
let useful_bytes_rate = useful_bytes / elapsed.as_millis() as usize * 1000;
|
||||
tracing::info!(
|
||||
"Sender rate: ubps={}",
|
||||
useful_bytes_rate
|
||||
);
|
||||
|
||||
last_logged_at = std::time::Instant::now();
|
||||
useful_bytes = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ base64ct = { version = "1", default-features = false, features = ["std"] }
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
camino = { version = "1", default-features = false, features = ["serde1"] }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
|
||||
clap = { version = "4", features = ["derive", "string"] }
|
||||
clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] }
|
||||
clap = { version = "4", features = ["derive", "env", "string"] }
|
||||
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
|
||||
crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] }
|
||||
der = { version = "0.7", default-features = false, features = ["oid", "pem", "std"] }
|
||||
deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] }
|
||||
@@ -88,6 +88,7 @@ tower = { version = "0.4", default-features = false, features = ["balance", "buf
|
||||
tracing = { version = "0.1", features = ["log"] }
|
||||
tracing-core = { version = "0.1" }
|
||||
url = { version = "2", features = ["serde"] }
|
||||
zerocopy = { version = "0.7", features = ["derive", "simd"] }
|
||||
zeroize = { version = "1", features = ["derive", "serde"] }
|
||||
zstd = { version = "0.13" }
|
||||
zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }
|
||||
@@ -126,6 +127,7 @@ serde = { version = "1", features = ["alloc", "derive"] }
|
||||
syn = { version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] }
|
||||
time-macros = { version = "0.2", default-features = false, features = ["formatting", "parsing", "serde"] }
|
||||
toml_edit = { version = "0.22", features = ["serde"] }
|
||||
zerocopy = { version = "0.7", features = ["derive", "simd"] }
|
||||
zstd = { version = "0.13" }
|
||||
zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }
|
||||
zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] }
|
||||
|
||||
Reference in New Issue
Block a user