mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-28 18:40:38 +00:00
Merge pull request #8351 from neondatabase/rc/proxy/2024-07-11
Proxy release 2024-07-11
This commit is contained in:
@@ -216,10 +216,11 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
use pq_proto::FeStartupPacket::*;
|
||||
|
||||
match msg {
|
||||
SslRequest => {
|
||||
SslRequest { direct: false } => {
|
||||
stream
|
||||
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
|
||||
.await?;
|
||||
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ use proxy::usage_metrics;
|
||||
use anyhow::bail;
|
||||
use proxy::config::{self, ProxyConfig};
|
||||
use proxy::serverless;
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
@@ -205,8 +206,8 @@ struct ProxyCliArgs {
|
||||
/// remote storage configuration for backup metric collection
|
||||
/// Encoded as toml (same format as pageservers), eg
|
||||
/// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
|
||||
#[clap(long, default_value = "{}")]
|
||||
metric_backup_collection_remote_storage: String,
|
||||
#[clap(long, value_parser = remote_storage_from_toml)]
|
||||
metric_backup_collection_remote_storage: Option<RemoteStorageConfig>,
|
||||
/// chunk size for backup metric collection
|
||||
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
|
||||
#[clap(long, default_value = "4194304")]
|
||||
@@ -511,9 +512,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
}
|
||||
let backup_metric_collection_config = config::MetricBackupCollectionConfig {
|
||||
interval: args.metric_backup_collection_interval,
|
||||
remote_storage_config: remote_storage_from_toml(
|
||||
&args.metric_backup_collection_remote_storage,
|
||||
)?,
|
||||
remote_storage_config: args.metric_backup_collection_remote_storage.clone(),
|
||||
chunk_size: args.metric_backup_collection_chunk_size,
|
||||
};
|
||||
|
||||
|
||||
7
proxy/src/cache/common.rs
vendored
7
proxy/src/cache/common.rs
vendored
@@ -53,6 +53,13 @@ impl<C: Cache, V> Cached<C, V> {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn map<U>(self, f: impl FnOnce(V) -> U) -> Cached<C, U> {
|
||||
Cached {
|
||||
token: self.token,
|
||||
value: f(self.value),
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop this entry from a cache if it's still there.
|
||||
pub fn invalidate(self) -> V {
|
||||
if let Some((cache, info)) = &self.token {
|
||||
|
||||
38
proxy/src/cache/timed_lru.rs
vendored
38
proxy/src/cache/timed_lru.rs
vendored
@@ -65,6 +65,8 @@ impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
expires_at: Instant,
|
||||
ttl: Duration,
|
||||
update_ttl_on_retrieval: bool,
|
||||
value: T,
|
||||
}
|
||||
|
||||
@@ -122,7 +124,6 @@ impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let now = Instant::now();
|
||||
let deadline = now.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
@@ -142,7 +143,8 @@ impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
|
||||
// Update the deadline and the entry's position in the LRU list.
|
||||
if self.update_ttl_on_retrieval {
|
||||
let deadline = now.checked_add(raw_entry.get().ttl).expect("time overflow");
|
||||
if raw_entry.get().update_ttl_on_retrieval {
|
||||
raw_entry.get_mut().expires_at = deadline;
|
||||
}
|
||||
raw_entry.to_back();
|
||||
@@ -162,12 +164,27 @@ impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
|
||||
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval)
|
||||
}
|
||||
|
||||
/// Insert an entry to the cache. If an entry with the same key already
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw_ttl(
|
||||
&self,
|
||||
key: K,
|
||||
value: V,
|
||||
ttl: Duration,
|
||||
update: bool,
|
||||
) -> (Instant, Option<V>) {
|
||||
let created_at = Instant::now();
|
||||
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
|
||||
let expires_at = created_at.checked_add(ttl).expect("time overflow");
|
||||
|
||||
let entry = Entry {
|
||||
created_at,
|
||||
expires_at,
|
||||
ttl,
|
||||
update_ttl_on_retrieval: update,
|
||||
value,
|
||||
};
|
||||
|
||||
@@ -190,6 +207,21 @@ impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
|
||||
pub fn insert_ttl(&self, key: K, value: V, ttl: Duration) {
|
||||
self.insert_raw_ttl(key, value, ttl, false);
|
||||
}
|
||||
|
||||
pub fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value);
|
||||
|
||||
let cached = Cached {
|
||||
token: Some((self, LookupInfo { created_at, key })),
|
||||
value: (),
|
||||
};
|
||||
|
||||
(old, cached)
|
||||
}
|
||||
|
||||
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
|
||||
|
||||
|
||||
@@ -75,6 +75,9 @@ impl TlsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
|
||||
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
|
||||
|
||||
/// Configure TLS for the main endpoint.
|
||||
pub fn configure_tls(
|
||||
key_path: &str,
|
||||
@@ -111,16 +114,17 @@ pub fn configure_tls(
|
||||
let cert_resolver = Arc::new(cert_resolver);
|
||||
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
let config = rustls::ServerConfig::builder_with_protocol_versions(&[
|
||||
let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
|
||||
&rustls::version::TLS13,
|
||||
&rustls::version::TLS12,
|
||||
])
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver.clone())
|
||||
.into();
|
||||
.with_cert_resolver(cert_resolver.clone());
|
||||
|
||||
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
config: Arc::new(config),
|
||||
common_names,
|
||||
cert_resolver,
|
||||
})
|
||||
@@ -399,15 +403,11 @@ impl FromStr for EndpointCacheConfig {
|
||||
#[derive(Debug)]
|
||||
pub struct MetricBackupCollectionConfig {
|
||||
pub interval: Duration,
|
||||
pub remote_storage_config: OptRemoteStorageConfig,
|
||||
pub remote_storage_config: Option<RemoteStorageConfig>,
|
||||
pub chunk_size: usize,
|
||||
}
|
||||
|
||||
/// Hack to avoid clap being smarter. If you don't use this type alias, clap assumes more about the optional state and you get
|
||||
/// runtime type errors from the value parser we use.
|
||||
pub type OptRemoteStorageConfig = Option<RemoteStorageConfig>;
|
||||
|
||||
pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<OptRemoteStorageConfig> {
|
||||
pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
|
||||
RemoteStorageConfig::from_toml(&s.parse()?)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::proxy::retry::CouldRetry;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct ConsoleError {
|
||||
pub error: Box<str>,
|
||||
#[serde(skip)]
|
||||
@@ -82,41 +82,19 @@ impl CouldRetry for ConsoleError {
|
||||
.details
|
||||
.error_info
|
||||
.map_or(Reason::Unknown, |e| e.reason);
|
||||
match reason {
|
||||
// not a transitive error
|
||||
Reason::RoleProtected => false,
|
||||
// on retry, it will still not be found
|
||||
Reason::ResourceNotFound
|
||||
| Reason::ProjectNotFound
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::BranchNotFound => false,
|
||||
// we were asked to go away
|
||||
Reason::RateLimitExceeded
|
||||
| Reason::NonDefaultBranchComputeTimeExceeded
|
||||
| Reason::ActiveTimeQuotaExceeded
|
||||
| Reason::ComputeTimeQuotaExceeded
|
||||
| Reason::WrittenDataQuotaExceeded
|
||||
| Reason::DataTransferQuotaExceeded
|
||||
| Reason::LogicalSizeQuotaExceeded => false,
|
||||
// transitive error. control plane is currently busy
|
||||
// but might be ready soon
|
||||
Reason::RunningOperations => true,
|
||||
Reason::ConcurrencyLimitReached => true,
|
||||
Reason::LockAlreadyTaken => true,
|
||||
// unknown error. better not retry it.
|
||||
Reason::Unknown => false,
|
||||
}
|
||||
|
||||
reason.can_retry()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct Status {
|
||||
pub code: Box<str>,
|
||||
pub message: Box<str>,
|
||||
pub details: Details,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct Details {
|
||||
pub error_info: Option<ErrorInfo>,
|
||||
pub retry_info: Option<RetryInfo>,
|
||||
@@ -199,6 +177,34 @@ impl Reason {
|
||||
| Reason::BranchNotFound
|
||||
)
|
||||
}
|
||||
|
||||
pub fn can_retry(&self) -> bool {
|
||||
match self {
|
||||
// do not retry role protected errors
|
||||
// not a transitive error
|
||||
Reason::RoleProtected => false,
|
||||
// on retry, it will still not be found
|
||||
Reason::ResourceNotFound
|
||||
| Reason::ProjectNotFound
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::BranchNotFound => false,
|
||||
// we were asked to go away
|
||||
Reason::RateLimitExceeded
|
||||
| Reason::NonDefaultBranchComputeTimeExceeded
|
||||
| Reason::ActiveTimeQuotaExceeded
|
||||
| Reason::ComputeTimeQuotaExceeded
|
||||
| Reason::WrittenDataQuotaExceeded
|
||||
| Reason::DataTransferQuotaExceeded
|
||||
| Reason::LogicalSizeQuotaExceeded => false,
|
||||
// transitive error. control plane is currently busy
|
||||
// but might be ready soon
|
||||
Reason::RunningOperations
|
||||
| Reason::ConcurrencyLimitReached
|
||||
| Reason::LockAlreadyTaken => true,
|
||||
// unknown error. better not retry it.
|
||||
Reason::Unknown => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize)]
|
||||
@@ -206,7 +212,7 @@ pub struct RetryInfo {
|
||||
pub retry_delay_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct UserFacingMessage {
|
||||
pub message: Box<str>,
|
||||
}
|
||||
|
||||
@@ -6,8 +6,9 @@ use anyhow::Context;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
|
||||
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
|
||||
use std::{convert::Infallible, future};
|
||||
use std::convert::Infallible;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, Instrument};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
|
||||
@@ -67,7 +68,9 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
|
||||
|
||||
async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
|
||||
pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
|
||||
pgbackend
|
||||
.run(&mut MgmtHandler, &CancellationToken::new())
|
||||
.await
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use super::messages::{ConsoleError, MetricsAuxInfo};
|
||||
use crate::{
|
||||
auth::{
|
||||
backend::{ComputeCredentialKeys, ComputeUserInfo},
|
||||
@@ -317,8 +317,8 @@ impl NodeInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||
pub type NodeInfoCache = TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ConsoleError>>>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ use super::{
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
compute,
|
||||
console::messages::ColdStartInfo,
|
||||
console::messages::{ColdStartInfo, Reason},
|
||||
http,
|
||||
metrics::{CacheOutcome, Metrics},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
@@ -17,10 +17,10 @@ use crate::{
|
||||
};
|
||||
use crate::{cache::Cached, context::RequestMonitoring};
|
||||
use futures::TryFutureExt;
|
||||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
@@ -273,26 +273,34 @@ impl super::Api for Api {
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key = user_info.endpoint_cache_key();
|
||||
|
||||
macro_rules! check_cache {
|
||||
() => {
|
||||
if let Some(cached) = self.caches.node_info.get(&key) {
|
||||
let (cached, info) = cached.take_value();
|
||||
let info = info.map_err(|c| {
|
||||
info!(key = &*key, "found cached wake_compute error");
|
||||
WakeComputeError::ApiError(ApiError::Console(*c))
|
||||
})?;
|
||||
|
||||
debug!(key = &*key, "found cached compute node info");
|
||||
ctx.set_project(info.aux.clone());
|
||||
return Ok(cached.map(|()| info));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Every time we do a wakeup http request, the compute node will stay up
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
if let Some(cached) = self.caches.node_info.get(&key) {
|
||||
info!(key = &*key, "found cached compute node info");
|
||||
ctx.set_project(cached.aux.clone());
|
||||
return Ok(cached);
|
||||
}
|
||||
check_cache!();
|
||||
|
||||
let permit = self.locks.get_permit(&key).await?;
|
||||
|
||||
// after getting back a permit - it's possible the cache was filled
|
||||
// double check
|
||||
if permit.should_check_cache() {
|
||||
if let Some(cached) = self.caches.node_info.get(&key) {
|
||||
info!(key = &*key, "found cached compute node info");
|
||||
ctx.set_project(cached.aux.clone());
|
||||
return Ok(cached);
|
||||
}
|
||||
check_cache!();
|
||||
}
|
||||
|
||||
// check rate limit
|
||||
@@ -300,23 +308,56 @@ impl super::Api for Api {
|
||||
.wake_compute_endpoint_rate_limiter
|
||||
.check(user_info.endpoint.normalize_intern(), 1)
|
||||
{
|
||||
info!(key = &*key, "found cached compute node info");
|
||||
return Err(WakeComputeError::TooManyConnections);
|
||||
}
|
||||
|
||||
let mut node = permit.release_result(self.do_wake_compute(ctx, user_info).await)?;
|
||||
ctx.set_project(node.aux.clone());
|
||||
let cold_start_info = node.aux.cold_start_info;
|
||||
info!("woken up a compute node");
|
||||
let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
|
||||
match node {
|
||||
Ok(node) => {
|
||||
ctx.set_project(node.aux.clone());
|
||||
debug!(key = &*key, "created a cache entry for woken compute node");
|
||||
|
||||
// store the cached node as 'warm'
|
||||
node.aux.cold_start_info = ColdStartInfo::WarmCached;
|
||||
let (_, mut cached) = self.caches.node_info.insert(key.clone(), node);
|
||||
cached.aux.cold_start_info = cold_start_info;
|
||||
let mut stored_node = node.clone();
|
||||
// store the cached node as 'warm_cached'
|
||||
stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
|
||||
|
||||
info!(key = &*key, "created a cache entry for compute node info");
|
||||
let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
|
||||
|
||||
Ok(cached)
|
||||
Ok(cached.map(|()| node))
|
||||
}
|
||||
Err(err) => match err {
|
||||
WakeComputeError::ApiError(ApiError::Console(err)) => {
|
||||
let Some(status) = &err.status else {
|
||||
return Err(WakeComputeError::ApiError(ApiError::Console(err)));
|
||||
};
|
||||
|
||||
let reason = status
|
||||
.details
|
||||
.error_info
|
||||
.map_or(Reason::Unknown, |x| x.reason);
|
||||
|
||||
// if we can retry this error, do not cache it.
|
||||
if reason.can_retry() {
|
||||
return Err(WakeComputeError::ApiError(ApiError::Console(err)));
|
||||
}
|
||||
|
||||
// at this point, we should only have quota errors.
|
||||
debug!(
|
||||
key = &*key,
|
||||
"created a cache entry for the wake compute error"
|
||||
);
|
||||
|
||||
self.caches.node_info.insert_ttl(
|
||||
key,
|
||||
Err(Box::new(err.clone())),
|
||||
Duration::from_secs(30),
|
||||
);
|
||||
|
||||
Err(WakeComputeError::ApiError(ApiError::Console(err)))
|
||||
}
|
||||
err => return Err(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,17 +14,14 @@ use parquet::{
|
||||
record::RecordWriter,
|
||||
};
|
||||
use pq_proto::StartupMessageParams;
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel};
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel};
|
||||
use serde::ser::SerializeMap;
|
||||
use tokio::{sync::mpsc, time};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, Span};
|
||||
use utils::backoff;
|
||||
|
||||
use crate::{
|
||||
config::{remote_storage_from_toml, OptRemoteStorageConfig},
|
||||
context::LOG_CHAN_DISCONNECT,
|
||||
};
|
||||
use crate::{config::remote_storage_from_toml, context::LOG_CHAN_DISCONNECT};
|
||||
|
||||
use super::{RequestMonitoring, LOG_CHAN};
|
||||
|
||||
@@ -33,11 +30,11 @@ pub struct ParquetUploadArgs {
|
||||
/// Storage location to upload the parquet files to.
|
||||
/// Encoded as toml (same format as pageservers), eg
|
||||
/// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
|
||||
#[clap(long, default_value = "{}", value_parser = remote_storage_from_toml)]
|
||||
parquet_upload_remote_storage: OptRemoteStorageConfig,
|
||||
#[clap(long, value_parser = remote_storage_from_toml)]
|
||||
parquet_upload_remote_storage: Option<RemoteStorageConfig>,
|
||||
|
||||
#[clap(long, default_value = "{}", value_parser = remote_storage_from_toml)]
|
||||
parquet_upload_disconnect_events_remote_storage: OptRemoteStorageConfig,
|
||||
#[clap(long, value_parser = remote_storage_from_toml)]
|
||||
parquet_upload_disconnect_events_remote_storage: Option<RemoteStorageConfig>,
|
||||
|
||||
/// How many rows to include in a row group
|
||||
#[clap(long, default_value_t = 8192)]
|
||||
|
||||
@@ -4,14 +4,11 @@
|
||||
|
||||
pub mod health_server;
|
||||
|
||||
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::FutureExt;
|
||||
pub use reqwest::{Request, Response, StatusCode};
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::time::Instant;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::{
|
||||
metrics::{ConsoleRequest, Metrics},
|
||||
@@ -24,8 +21,6 @@ use reqwest_middleware::RequestBuilder;
|
||||
/// We deliberately don't want to replace this with a public static.
|
||||
pub fn new_client() -> ClientWithMiddleware {
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.dns_resolver(Arc::new(GaiResolver::default()))
|
||||
.connection_verbose(true)
|
||||
.build()
|
||||
.expect("Failed to create http client");
|
||||
|
||||
@@ -36,8 +31,6 @@ pub fn new_client() -> ClientWithMiddleware {
|
||||
|
||||
pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
|
||||
let timeout_client = reqwest::ClientBuilder::new()
|
||||
.dns_resolver(Arc::new(GaiResolver::default()))
|
||||
.connection_verbose(true)
|
||||
.timeout(default_timout)
|
||||
.build()
|
||||
.expect("Failed to create http client with timeout");
|
||||
@@ -103,38 +96,6 @@ impl Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
use hyper_util::client::legacy::connect::dns::{
|
||||
GaiResolver as HyperGaiResolver, Name as HyperName,
|
||||
};
|
||||
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
|
||||
/// https://docs.rs/reqwest/0.11.18/src/reqwest/dns/gai.rs.html
|
||||
use tower_service::Service;
|
||||
#[derive(Debug)]
|
||||
pub struct GaiResolver(HyperGaiResolver);
|
||||
|
||||
impl Default for GaiResolver {
|
||||
fn default() -> Self {
|
||||
Self(HyperGaiResolver::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl Resolve for GaiResolver {
|
||||
fn resolve(&self, name: Name) -> Resolving {
|
||||
let this = &mut self.0.clone();
|
||||
let hyper_name = HyperName::from_str(name.as_str()).expect("name should be valid");
|
||||
let start = Instant::now();
|
||||
Box::pin(
|
||||
Service::<HyperName>::call(this, hyper_name).map(move |result| {
|
||||
let resolve_duration = start.elapsed();
|
||||
trace!(duration = ?resolve_duration, addr = %name.as_str(), "resolve host complete");
|
||||
result
|
||||
.map(|addrs| -> Addrs { Box::new(addrs) })
|
||||
.map_err(|err| -> Box<dyn std::error::Error + Send + Sync> { Box::new(err) })
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -3,8 +3,8 @@ use std::marker::PhantomData;
|
||||
use measured::{
|
||||
label::NoLabels,
|
||||
metric::{
|
||||
gauge::GaugeState, group::Encoding, group::MetricValue, name::MetricNameEncoder,
|
||||
MetricEncoding, MetricFamilyEncoding, MetricType,
|
||||
gauge::GaugeState, group::Encoding, name::MetricNameEncoder, MetricEncoding,
|
||||
MetricFamilyEncoding, MetricType,
|
||||
},
|
||||
text::TextEncoder,
|
||||
LabelGroup, MetricGroup,
|
||||
@@ -100,7 +100,7 @@ macro_rules! jemalloc_gauge {
|
||||
enc: &mut TextEncoder<W>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
if let Ok(v) = mib.read() {
|
||||
enc.write_metric_value(name, labels, MetricValue::Int(v as i64))?;
|
||||
GaugeState::new(v as i64).collect_into(&(), labels, name, enc)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@ use tracing_subscriber::{
|
||||
pub async fn init() -> anyhow::Result<LoggingGuard> {
|
||||
let env_filter = EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.from_env_lossy();
|
||||
.from_env_lossy()
|
||||
.add_directive("azure_core::policies::transport=off".parse().unwrap());
|
||||
|
||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(false)
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::{Arc, OnceLock};
|
||||
|
||||
use lasso::ThreadedRodeo;
|
||||
use measured::{
|
||||
label::{FixedCardinalitySet, LabelName, LabelSet, LabelValue, StaticLabelSet},
|
||||
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet},
|
||||
metric::{histogram::Thresholds, name::MetricName},
|
||||
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
|
||||
LabelGroup, MetricGroup,
|
||||
@@ -577,6 +577,32 @@ impl LabelGroup for ThreadPoolWorkerId {
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGroupSet for ThreadPoolWorkers {
|
||||
type Group<'a> = ThreadPoolWorkerId;
|
||||
|
||||
fn cardinality(&self) -> Option<usize> {
|
||||
Some(self.0)
|
||||
}
|
||||
|
||||
fn encode_dense(&self, value: Self::Unique) -> Option<usize> {
|
||||
Some(value)
|
||||
}
|
||||
|
||||
fn decode_dense(&self, value: usize) -> Self::Group<'_> {
|
||||
ThreadPoolWorkerId(value)
|
||||
}
|
||||
|
||||
type Unique = usize;
|
||||
|
||||
fn encode(&self, value: Self::Group<'_>) -> Option<Self::Unique> {
|
||||
Some(value.0)
|
||||
}
|
||||
|
||||
fn decode(&self, value: &Self::Unique) -> Self::Group<'_> {
|
||||
ThreadPoolWorkerId(*value)
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelSet for ThreadPoolWorkers {
|
||||
type Value<'a> = ThreadPoolWorkerId;
|
||||
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
|
||||
use bytes::Buf;
|
||||
use pq_proto::{
|
||||
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
|
||||
StartupMessageParams,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
config::TlsConfig,
|
||||
auth::endpoint_sni,
|
||||
config::{TlsConfig, PG_ALPN_PROTOCOL},
|
||||
error::ReportableError,
|
||||
metrics::Metrics,
|
||||
proxy::ERR_INSECURE_CONNECTION,
|
||||
stream::{PqStream, Stream, StreamUpgradeError},
|
||||
};
|
||||
@@ -68,6 +74,9 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
@@ -75,40 +84,96 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
use FeStartupPacket::*;
|
||||
match msg {
|
||||
SslRequest => match stream.get_ref() {
|
||||
SslRequest { direct } => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_ssl => {
|
||||
tried_ssl = true;
|
||||
|
||||
// We can't perform TLS handshake without a config
|
||||
let enc = tls.is_some();
|
||||
stream.write_message(&Be::EncryptionResponse(enc)).await?;
|
||||
let have_tls = tls.is_some();
|
||||
if !direct {
|
||||
stream
|
||||
.write_message(&Be::EncryptionResponse(have_tls))
|
||||
.await?;
|
||||
} else if !have_tls {
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empy.
|
||||
// However, you could imagine pipelining of postgres
|
||||
// SSLRequest + TLS ClientHello in one hunk similar to
|
||||
// pipelining in our node js driver. We should probably
|
||||
// support that by chaining read_buf with the stream.
|
||||
let Framed {
|
||||
stream: raw,
|
||||
read_buf,
|
||||
write_buf,
|
||||
} = stream.framed;
|
||||
|
||||
let Stream::Raw { raw } = raw else {
|
||||
return Err(HandshakeError::StreamUpgradeError(
|
||||
StreamUpgradeError::AlreadyTls,
|
||||
));
|
||||
};
|
||||
|
||||
let mut read_buf = read_buf.reader();
|
||||
let mut res = Ok(());
|
||||
let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config())
|
||||
.accept_with(raw, |session| {
|
||||
// push the early data to the tls session
|
||||
while !read_buf.get_ref().is_empty() {
|
||||
match session.read_tls(&mut read_buf) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
res = Err(e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
res?;
|
||||
|
||||
let read_buf = read_buf.into_inner();
|
||||
if !read_buf.is_empty() {
|
||||
return Err(HandshakeError::EarlyData);
|
||||
}
|
||||
let tls_stream = raw
|
||||
.upgrade(tls.to_server_config(), record_handshake_error)
|
||||
.await?;
|
||||
|
||||
let tls_stream = accept.await.inspect_err(|_| {
|
||||
if record_handshake_error {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc()
|
||||
}
|
||||
})?;
|
||||
|
||||
let conn_info = tls_stream.get_ref().1;
|
||||
|
||||
// check the ALPN, if exists, as required.
|
||||
match conn_info.alpn_protocol() {
|
||||
None | Some(PG_ALPN_PROTOCOL) => {}
|
||||
Some(other) => {
|
||||
// try parse ep for better error
|
||||
let ep = conn_info.server_name().and_then(|sni| {
|
||||
endpoint_sni(sni, &tls.common_names).ok().flatten()
|
||||
});
|
||||
let alpn = String::from_utf8_lossy(other);
|
||||
warn!(?ep, %alpn, "unexpected ALPN");
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
}
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(tls_stream.get_ref().1.server_name())
|
||||
.resolve(conn_info.server_name())
|
||||
.ok_or(HandshakeError::MissingCertificate)?;
|
||||
|
||||
stream = PqStream::new(Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
});
|
||||
stream = PqStream {
|
||||
framed: Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
@@ -122,7 +187,9 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
StartupMessage { params, .. } => {
|
||||
StartupMessage { params, version }
|
||||
if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
|
||||
{
|
||||
// Check that the config has been consumed during upgrade
|
||||
// OR we didn't provide it at all (for dev purposes).
|
||||
if tls.is_some() {
|
||||
@@ -131,9 +198,48 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!(session_type = "normal", "successful handshake");
|
||||
info!(?version, session_type = "normal", "successful handshake");
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
}
|
||||
// downgrade protocol version
|
||||
StartupMessage { params, version }
|
||||
if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
|
||||
{
|
||||
warn!(?version, "unsupported minor version");
|
||||
|
||||
// no protocol extensions are supported.
|
||||
// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
|
||||
let mut unsupported = vec![];
|
||||
for (k, _) in params.iter() {
|
||||
if k.starts_with("_pq_.") {
|
||||
unsupported.push(k);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove unsupported options so we don't send them to compute.
|
||||
|
||||
stream
|
||||
.write_message(&Be::NegotiateProtocolVersion {
|
||||
version: PG_PROTOCOL_LATEST,
|
||||
options: &unsupported,
|
||||
})
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
?version,
|
||||
session_type = "normal",
|
||||
"successful handshake; unsupported minor version requested"
|
||||
);
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
}
|
||||
StartupMessage { version, .. } => {
|
||||
warn!(
|
||||
?version,
|
||||
session_type = "normal",
|
||||
"unsuccessful handshake; unsupported version"
|
||||
);
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
info!(session_type = "cancellation", "successful handshake");
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data));
|
||||
|
||||
@@ -540,8 +540,8 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
|
||||
},
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
let (_, node) = cache.insert("key".into(), node);
|
||||
node
|
||||
let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
|
||||
node2.map(|()| node)
|
||||
}
|
||||
|
||||
fn helper_create_connect_info(
|
||||
|
||||
@@ -837,8 +837,9 @@ async fn query_to_json<T: GenericClient>(
|
||||
"finished reading rows"
|
||||
);
|
||||
|
||||
let mut fields = vec![];
|
||||
let mut columns = vec![];
|
||||
let columns_len = row_stream.columns().len();
|
||||
let mut fields = Vec::with_capacity(columns_len);
|
||||
let mut columns = Vec::with_capacity(columns_len);
|
||||
|
||||
for c in row_stream.columns() {
|
||||
fields.push(json!({
|
||||
|
||||
Reference in New Issue
Block a user