proxy: format imports with nightly rustfmt (#9414)

```shell
cargo +nightly fmt -p proxy -- -l --config imports_granularity=Module,group_imports=StdExternalCrate,reorder_imports=true
```
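
To make the effect of these options concrete, here is a small illustrative sketch (not part of the commit): `imports_granularity=Module` splits nested `use` trees into one `use` per module, `group_imports=StdExternalCrate` groups imports as std, then external crates, then `crate`/`super`/`self`, and `reorder_imports=true` sorts each group. The `auth`/`stream` modules and their structs below are hypothetical stand-ins defined inline so the snippet compiles on its own; in the real crate these would be `proxy`'s own modules, as seen in the diffs below.

```rust
// Before (nested `use` trees, the style this commit removes):
//
//     use std::{collections::HashSet, io, sync::Arc};
//     use crate::{auth::AuthFlow, stream::PqStream};
//
// After rustfmt with the options above: one `use` per module, sorted,
// std first and `crate` imports last. External crates such as `tokio`
// would form a middle group, as seen throughout the diffs below.
use std::collections::HashSet;
use std::io;
use std::sync::Arc;

use crate::auth::AuthFlow;
use crate::stream::PqStream;

// Stand-in modules so this sketch is self-contained.
mod auth {
    pub struct AuthFlow;
}

mod stream {
    pub struct PqStream;
}

fn main() {
    // Touch every import so the example compiles without unused-import warnings.
    let _ = (HashSet::<u8>::new(), Arc::new(io::empty()), AuthFlow, PqStream);
}
```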

These rust-analyzer settings for VSCode should help retain this style:
```json
  "rust-analyzer.imports.group.enable": true,
  "rust-analyzer.imports.prefix": "crate",
  "rust-analyzer.imports.merge.glob": false,
  "rust-analyzer.imports.granularity.group": "module",
  "rust-analyzer.imports.granularity.enforce": true,
```
Author: Folke Behrens (committed by GitHub)
Date: 2024-10-16 15:01:56 +02:00
Parent: 89a65a9e5a
Commit: f14e45f0ce
73 changed files with 726 additions and 835 deletions

View File

@@ -1,16 +1,15 @@
use super::{ComputeCredentials, ComputeUserInfo};
use crate::{
auth::{self, backend::ComputeCredentialKeys, AuthFlow},
compute,
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::AuthSecret,
sasl,
stream::{PqStream, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
use super::{ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::AuthSecret;
use crate::stream::{PqStream, Stream};
use crate::{compute, sasl};
pub(super) async fn authenticate(
ctx: &RequestMonitoring,
creds: ComputeUserInfo,

View File

@@ -1,15 +1,3 @@
use crate::{
auth,
cache::Cached,
compute,
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::{self, provider::NodeInfo, CachedNodeInfo},
error::{ReportableError, UserFacingError},
proxy::connect_compute::ComputeConnectBackend,
stream::PqStream,
waiters,
};
use async_trait::async_trait;
use pq_proto::BeMessage as Be;
use thiserror::Error;
@@ -18,6 +6,15 @@ use tokio_postgres::config::SslMode;
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::provider::NodeInfo;
use crate::control_plane::{self, CachedNodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::stream::PqStream;
use crate::{auth, compute, waiters};
#[derive(Debug, Error)]
pub(crate) enum WebAuthError {

View File

@@ -1,16 +1,15 @@
use super::{ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint};
use crate::{
auth::{self, AuthFlow},
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::AuthSecret,
intern::EndpointIdInt,
sasl,
stream::{self, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
use super::{ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint};
use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::sasl;
use crate::stream::{self, Stream};
/// Compared to [SCRAM](crate::scram), cleartext password auth saves
/// one round trip and *expensive* computations (>= 4096 HMAC iterations).
/// These properties are beneficial for serverless JS workers, so we

View File

@@ -1,22 +1,22 @@
use std::{
future::Future,
sync::Arc,
time::{Duration, SystemTime},
};
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use serde::{de::Visitor, Deserialize, Deserializer};
use serde::de::Visitor;
use serde::{Deserialize, Deserializer};
use signature::Verifier;
use thiserror::Error;
use tokio::time::Instant;
use crate::{
auth::backend::ComputeCredentialKeys, context::RequestMonitoring,
control_plane::errors::GetEndpointJwksError, http::parse_json_body_with_limit,
intern::RoleNameInt, EndpointId, RoleName,
};
use crate::auth::backend::ComputeCredentialKeys;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetEndpointJwksError;
use crate::http::parse_json_body_with_limit;
use crate::intern::RoleNameInt;
use crate::{EndpointId, RoleName};
// TODO(conrad): make these configurable.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
@@ -381,10 +381,8 @@ fn verify_rsa_signature(
alg: &jose_jwa::Algorithm,
) -> Result<(), JwtError> {
use jose_jwa::{Algorithm, Signing};
use rsa::{
pkcs1v15::{Signature, VerifyingKey},
RsaPublicKey,
};
use rsa::pkcs1v15::{Signature, VerifyingKey};
use rsa::RsaPublicKey;
let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
@@ -655,11 +653,9 @@ impl From<&jose_jwk::Key> for KeyType {
#[cfg(test)]
mod tests {
use crate::RoleName;
use super::*;
use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};
use std::future::IntoFuture;
use std::net::SocketAddr;
use std::time::SystemTime;
use base64::URL_SAFE_NO_PAD;
use bytes::Bytes;
@@ -672,6 +668,9 @@ mod tests {
use signature::Signer;
use tokio::net::TcpListener;
use super::*;
use crate::RoleName;
fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
let sk = p256::SecretKey::random(&mut OsRng);
let pk = sk.public_key().into();

View File

@@ -2,19 +2,14 @@ use std::net::SocketAddr;
use arc_swap::ArcSwapOption;
use crate::{
auth::backend::jwt::FetchAuthRulesError,
compute::ConnCfg,
context::RequestMonitoring,
control_plane::{
messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo},
NodeInfo,
},
intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag},
EndpointId,
};
use super::jwt::{AuthRule, FetchAuthRules};
use crate::auth::backend::jwt::FetchAuthRulesError;
use crate::compute::ConnCfg;
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo};
use crate::control_plane::NodeInfo;
use crate::intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag};
use crate::EndpointId;
pub struct LocalBackend {
pub(crate) node_info: NodeInfo,

View File

@@ -17,29 +17,22 @@ use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{validate_password_and_exchange, AuthError};
use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::provider::{CachedRoleSecret, ControlPlaneBackend};
use crate::control_plane::AuthSecret;
use crate::control_plane::provider::{
CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneBackend,
};
use crate::control_plane::{self, Api, AuthSecret};
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::stream::Stream;
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
config::AuthenticationConfig,
control_plane::{
self,
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
stream,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use crate::{scram, stream, EndpointCacheKey, EndpointId, RoleName};
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
@@ -500,34 +493,32 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
#[cfg(test)]
mod tests {
use std::{net::IpAddr, sync::Arc, time::Duration};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use once_cell::sync::Lazy;
use postgres_protocol::{
authentication::sasl::{ChannelBinding, ScramSha256},
message::{backend::Message as PgMessage, frontend},
};
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use provider::AuthSecret;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use crate::{
auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::{
self,
provider::{self, CachedAllowedIps, CachedRoleSecret},
CachedNodeInfo,
},
proxy::NeonOptions,
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
};
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
use super::jwt::JwkCache;
use super::{auth_quirks, AuthRateLimiter};
use crate::auth::backend::MaskedIp;
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::provider::{self, CachedAllowedIps, CachedRoleSecret};
use crate::control_plane::{self, CachedNodeInfo};
use crate::proxy::NeonOptions;
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::scram::ServerSecret;
use crate::stream::{PqStream, Stream};
struct Auth {
ips: Vec<IpPattern>,

View File

@@ -1,20 +1,22 @@
//! User credentials used in authentication.
use crate::{
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::{Metrics, SniKind},
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
};
use std::collections::HashSet;
use std::net::IpAddr;
use std::str::FromStr;
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use std::{collections::HashSet, net::IpAddr, str::FromStr};
use thiserror::Error;
use tracing::{info, warn};
use crate::auth::password_hack::parse_endpoint_param;
use crate::context::RequestMonitoring;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, SniKind};
use crate::proxy::NeonOptions;
use crate::serverless::SERVERLESS_DRIVER_SNI;
use crate::{EndpointId, RoleName};
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub(crate) enum ComputeUserInfoParseError {
#[error("Parameter '{0}' is missing in startup packet.")]
@@ -249,10 +251,11 @@ fn project_name_valid(name: &str) -> bool {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use ComputeUserInfoParseError::*;
use super::*;
#[test]
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.

View File

@@ -1,21 +1,24 @@
//! Main authentication flow.
use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
context::RequestMonitoring,
control_plane::AuthSecret,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use std::io;
use std::sync::Arc;
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::{io, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use super::backend::ComputeCredentialKeys;
use super::{AuthErrorImpl, PasswordHackPayload};
use crate::config::TlsServerEndPoint;
use crate::context::RequestMonitoring;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
/// Any authentication selector should provide initial backend message

View File

@@ -14,15 +14,15 @@ pub(crate) use password_hack::parse_endpoint_param;
use password_hack::PasswordHackPayload;
mod flow;
use std::io;
use std::net::IpAddr;
pub(crate) use flow::*;
use thiserror::Error;
use tokio::time::error::Elapsed;
use crate::{
control_plane,
error::{ReportableError, UserFacingError},
};
use std::{io, net::IpAddr};
use thiserror::Error;
use crate::control_plane;
use crate::error::{ReportableError, UserFacingError};
/// Convenience wrapper for the authentication error.
pub(crate) type Result<T> = std::result::Result<T, AuthError>;

View File

@@ -1,41 +1,43 @@
use std::{net::SocketAddr, pin::pin, str::FromStr, sync::Arc, time::Duration};
use std::net::SocketAddr;
use std::pin::pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, ensure, Context};
use camino::{Utf8Path, Utf8PathBuf};
use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::future::Either;
use proxy::{
auth::{
self,
backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
control_plane::{
locks::ApiLocks,
messages::{EndpointJwksResponse, JwksSettings},
},
http::health_server::AppMetrics,
intern::RoleNameInt,
metrics::{Metrics, ThreadPoolMetrics},
rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo},
scram::threadpool::ThreadPool,
serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions},
RoleName,
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
use proxy::intern::RoleNameInt;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::rate_limiter::{
BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
};
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::{self, GlobalConnPoolOptions};
use proxy::RoleName;
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use clap::Parser;
use tokio::{net::TcpListener, sync::Notify, task::JoinSet};
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::{pid_file, project_build_tag, project_git_version, sentry_init::init_sentry};
use utils::sentry_init::init_sentry;
use utils::{pid_file, project_build_tag, project_git_version};
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;

View File

@@ -5,25 +5,23 @@
/// the outside. Similar to an ingress controller for HTTPS.
use std::{net::SocketAddr, sync::Arc};
use anyhow::{anyhow, bail, ensure, Context};
use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;
use anyhow::{anyhow, bail, ensure, Context};
use clap::Arg;
use futures::TryFutureExt;
use proxy::stream::{PqStream, Stream};
use rustls::pki_types::PrivateKeyDer;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use utils::{project_git_version, sentry_init::init_sentry};
use tracing::{error, info, Instrument};
use utils::project_git_version;
use utils::sentry_init::init_sentry;
project_git_version!(GIT_VERSION);

View File

@@ -1,3 +1,8 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use anyhow::bail;
use aws_config::environment::EnvironmentVariableCredentialsProvider;
use aws_config::imds::credentials::ImdsCredentialsProvider;
use aws_config::meta::credentials::CredentialsProviderChain;
@@ -7,52 +12,34 @@ use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::ConsoleRedirectBackend;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::remote_storage_from_toml;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
use proxy::config::ProjectInfoCacheOptions;
use proxy::config::ProxyProtocolV2;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::context::parquet::ParquetUploadArgs;
use proxy::control_plane;
use proxy::http;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::LeakyBucketConfig;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::WakeComputeRateLimiter;
use proxy::rate_limiter::{
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
};
use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::redis::{elasticache, notifications};
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
use anyhow::bail;
use proxy::config::{self, ProxyConfig};
use proxy::serverless;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
use tracing::Instrument;
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
use tracing::{info, warn, Instrument};
use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);

View File

@@ -1,31 +1,23 @@
use std::{
convert::Infallible,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use std::convert::Infallible;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use dashmap::DashSet;
use redis::{
streams::{StreamReadOptions, StreamReadReply},
AsyncCommands, FromRedisValue, Value,
};
use redis::streams::{StreamReadOptions, StreamReadReply};
use redis::{AsyncCommands, FromRedisValue, Value};
use serde::Deserialize;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::info;
use crate::{
config::EndpointCacheConfig,
context::RequestMonitoring,
intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
metrics::{Metrics, RedisErrors, RedisEventsCount},
rate_limiter::GlobalRateLimiter,
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
EndpointId,
};
use crate::config::EndpointCacheConfig;
use crate::context::RequestMonitoring;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
use crate::rate_limiter::GlobalRateLimiter;
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::EndpointId;
#[derive(Deserialize, Debug, Clone)]
pub(crate) struct ControlPlaneEventKey {

View File

@@ -1,9 +1,8 @@
use std::{
collections::HashSet,
convert::Infallible,
sync::{atomic::AtomicU64, Arc},
time::Duration,
};
use std::collections::HashSet;
use std::convert::Infallible;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use dashmap::DashMap;
@@ -13,15 +12,12 @@ use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
use crate::{
auth::IpPattern,
config::ProjectInfoCacheOptions,
control_plane::AuthSecret,
intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
EndpointId, RoleName,
};
use super::{Cache, Cached};
use crate::auth::IpPattern;
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::AuthSecret;
use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
@@ -371,7 +367,8 @@ impl Cache for ProjectInfoCacheImpl {
#[cfg(test)]
mod tests {
use super::*;
use crate::{scram::ServerSecret, ProjectId};
use crate::scram::ServerSecret;
use crate::ProjectId;
#[tokio::test]
async fn test_project_info_cache_settings() {

View File

@@ -1,9 +1,6 @@
use std::{
borrow::Borrow,
hash::Hash,
time::{Duration, Instant},
};
use tracing::debug;
use std::borrow::Borrow;
use std::hash::Hash;
use std::time::{Duration, Instant};
// This seems to make more sense than `lru` or `cached`:
//
@@ -15,8 +12,10 @@ use tracing::debug;
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
use tracing::debug;
use super::{common::Cached, timed_lru, Cache};
use super::common::Cached;
use super::{timed_lru, Cache};
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:

View File

@@ -1,6 +1,8 @@
use std::net::SocketAddr;
use std::sync::Arc;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
@@ -8,12 +10,10 @@ use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;
use crate::{
error::ReportableError,
metrics::{CancellationRequest, CancellationSource, Metrics},
redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
},
use crate::error::ReportableError;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
};
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;

View File

@@ -1,25 +1,31 @@
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
context::RequestMonitoring,
control_plane::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
error::{ReportableError, UserFacingError},
metrics::{Metrics, NumDbConnectionsGuard},
proxy::neon_option,
Host,
};
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::StartupMessageParams;
use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use rustls::client::danger::ServerCertVerifier;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres_rustls::MakeRustlsConnect;
use tracing::{error, info, warn};
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::provider::ApiLockError;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::proxy::neon_option;
use crate::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
#[derive(Debug, Error)]

View File

@@ -1,29 +1,27 @@
use crate::{
auth::backend::{jwt::JwkCache, AuthRateLimiter},
control_plane::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
Host,
};
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, ensure, Context, Ok};
use clap::ValueEnum;
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use rustls::crypto::ring::sign;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
time::Duration,
};
use tracing::{error, info};
use x509_parser::oid_registry;
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::AuthRateLimiter;
use crate::control_plane::locks::ApiLocks;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::GlobalConnPoolOptions;
use crate::Host;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub metric_collection: Option<MetricCollectionConfig>,
@@ -692,9 +690,8 @@ impl FromStr for ConcurrencyLockOptions {
#[cfg(test)]
mod tests {
use crate::rate_limiter::Aimd;
use super::*;
use crate::rate_limiter::Aimd;
#[test]
fn test_parse_cache_options() -> anyhow::Result<()> {

View File

@@ -1,25 +1,22 @@
use crate::auth::backend::ConsoleRedirectBackend;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::proxy::{
prepare_client_connection, run_until_cancelled, ClientRequestError, ErrorSource,
};
use crate::{
cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal},
context::RequestMonitoring,
error::ReportableError,
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
};
use futures::TryFutureExt;
use std::sync::Arc;
use futures::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, Instrument};
use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
use crate::proxy::connect_compute::{connect_to_compute, TcpMechanism};
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
prepare_client_connection, run_until_cancelled, ClientRequestError, ErrorSource,
};
pub async fn task_main(

View File

@@ -1,24 +1,25 @@
//! Connection request monitoring contexts
use std::net::IpAddr;
use chrono::Utc;
use once_cell::sync::OnceCell;
use pq_proto::StartupMessageParams;
use smol_str::SmolStr;
use std::net::IpAddr;
use tokio::sync::mpsc;
use tracing::{debug, field::display, info, info_span, Span};
use tracing::field::display;
use tracing::{debug, info, info_span, Span};
use try_lock::TryLock;
use uuid::Uuid;
use crate::{
control_plane::messages::{ColdStartInfo, MetricsAuxInfo},
error::ErrorKind,
intern::{BranchIdInt, ProjectIdInt},
metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting},
DbName, EndpointId, RoleName,
};
use self::parquet::RequestData;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::error::ErrorKind;
use crate::intern::{BranchIdInt, ProjectIdInt};
use crate::metrics::{
ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
};
use crate::{DbName, EndpointId, RoleName};
pub mod parquet;

View File

@@ -1,29 +1,28 @@
use std::{sync::Arc, time::SystemTime};
use std::sync::Arc;
use std::time::SystemTime;
use anyhow::Context;
use bytes::{buf::Writer, BufMut, BytesMut};
use bytes::buf::Writer;
use bytes::{BufMut, BytesMut};
use chrono::{Datelike, Timelike};
use futures::{Stream, StreamExt};
use parquet::{
basic::Compression,
file::{
metadata::RowGroupMetaDataPtr,
properties::{WriterProperties, WriterPropertiesPtr, DEFAULT_PAGE_SIZE},
writer::SerializedFileWriter,
},
record::RecordWriter,
};
use parquet::basic::Compression;
use parquet::file::metadata::RowGroupMetaDataPtr;
use parquet::file::properties::{WriterProperties, WriterPropertiesPtr, DEFAULT_PAGE_SIZE};
use parquet::file::writer::SerializedFileWriter;
use parquet::record::RecordWriter;
use pq_proto::StartupMessageParams;
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel};
use serde::ser::SerializeMap;
use tokio::{sync::mpsc, time};
use tokio::sync::mpsc;
use tokio::time;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, Span};
use utils::backoff;
use crate::{config::remote_storage_from_toml, context::LOG_CHAN_DISCONNECT};
use super::{RequestMonitoringInner, LOG_CHAN};
use crate::config::remote_storage_from_toml;
use crate::context::LOG_CHAN_DISCONNECT;
#[derive(clap::Args, Clone, Debug)]
pub struct ParquetUploadArgs {
@@ -407,26 +406,26 @@ async fn upload_parquet(
#[cfg(test)]
mod tests {
use std::{net::Ipv4Addr, num::NonZeroUsize, sync::Arc};
use std::net::Ipv4Addr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use camino::Utf8Path;
use clap::Parser;
use futures::{Stream, StreamExt};
use itertools::Itertools;
use parquet::{
basic::{Compression, ZstdLevel},
file::{
properties::{WriterProperties, DEFAULT_PAGE_SIZE},
reader::FileReader,
serialized_reader::SerializedFileReader,
},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use parquet::basic::{Compression, ZstdLevel};
use parquet::file::properties::{WriterProperties, DEFAULT_PAGE_SIZE};
use parquet::file::reader::FileReader;
use parquet::file::serialized_reader::SerializedFileReader;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use remote_storage::{
GenericRemoteStorage, RemoteStorageConfig, RemoteStorageKind, S3Config,
DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
};
use tokio::{sync::mpsc, time};
use tokio::sync::mpsc;
use tokio::time;
use walkdir::WalkDir;
use super::{worker_inner, ParquetConfig, ParquetUploadArgs, RequestData};

View File

@@ -1,9 +1,9 @@
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};
use crate::auth::IpPattern;
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use crate::auth::IpPattern;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::proxy::retry::CouldRetry;
@@ -362,9 +362,10 @@ pub struct JwksSettings {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use super::*;
fn dummy_aux() -> serde_json::Value {
json!({
"endpoint_id": "endpoint",

View File

@@ -1,16 +1,16 @@
use crate::{
control_plane::messages::{DatabaseInfo, KickSession},
waiters::{self, Waiter, Waiters},
};
use std::convert::Infallible;
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::convert::Infallible;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use crate::control_plane::messages::{DatabaseInfo, KickSession};
use crate::waiters::{self, Waiter, Waiters};
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
/// Give caller an opportunity to wait for the cloud's reply.

View File

@@ -1,28 +1,29 @@
//! Mock console backend which relies on a user-provided postgres instance.
use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::{
auth::backend::jwt::AuthRule, context::RequestMonitoring,
control_plane::errors::GetEndpointJwksError, intern::RoleNameInt, RoleName,
};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
control_plane::{
messages::MetricsAuxInfo,
provider::{CachedAllowedIps, CachedRoleSecret},
},
BranchId, EndpointId, ProjectId,
};
use std::str::FromStr;
use std::sync::Arc;
use futures::TryFutureExt;
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use tokio_postgres::{config::SslMode, Client};
use tokio_postgres::config::SslMode;
use tokio_postgres::Client;
use tracing::{error, info, info_span, warn, Instrument};
use super::errors::{ApiError, GetAuthInfoError, WakeComputeError};
use super::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::IpPattern;
use crate::cache::Cached;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetEndpointJwksError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
use crate::error::io_error;
use crate::intern::RoleNameInt;
use crate::url::ApiUrl;
use crate::{compute, scram, BranchId, EndpointId, ProjectId, RoleName};
#[derive(Debug, Error)]
enum MockApiError {
#[error("Failed to read password: {0}")]

View File

@@ -2,39 +2,36 @@
pub mod mock;
pub mod neon;
use super::messages::{ControlPlaneError, MetricsAuxInfo};
use crate::{
auth::{
backend::{
jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError},
ComputeCredentialKeys, ComputeUserInfo,
},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
context::RequestMonitoring,
error::ReportableError,
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey, EndpointId,
};
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration};
use tokio::time::Instant;
use tracing::info;
use super::messages::{ControlPlaneError, MetricsAuxInfo};
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::IpPattern;
use crate::cache::endpoints::EndpointsCache;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::{Cached, TimedLru};
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::intern::ProjectIdInt;
use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
use crate::{compute, scram, EndpointCacheKey, EndpointId};
pub(crate) mod errors {
use crate::{
control_plane::messages::{self, ControlPlaneError, Reason},
error::{io_error, ErrorKind, ReportableError, UserFacingError},
proxy::retry::CouldRetry,
};
use thiserror::Error;
use super::ApiLockError;
use crate::control_plane::messages::{self, ControlPlaneError, Reason};
use crate::error::{io_error, ErrorKind, ReportableError, UserFacingError};
use crate::proxy::retry::CouldRetry;
/// A go-to error message which doesn't leak any detail.
pub(crate) const REQUEST_FAILED: &str = "Console request failed";

View File

@@ -1,31 +1,31 @@
//! Production console backend.
use super::{
super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute},
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
NodeInfo,
};
use crate::{
auth::backend::{jwt::AuthRule, ComputeUserInfo},
compute,
control_plane::{
errors::GetEndpointJwksError,
messages::{ColdStartInfo, EndpointJwksResponse, Reason},
},
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey, EndpointId,
};
use crate::{cache::Cached, context::RequestMonitoring};
use ::http::{header::AUTHORIZATION, HeaderName};
use std::sync::Arc;
use std::time::Duration;
use ::http::header::AUTHORIZATION;
use ::http::HeaderName;
use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, info, info_span, warn, Instrument};
use super::super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute};
use super::errors::{ApiError, GetAuthInfoError, WakeComputeError};
use super::{
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
NodeInfo,
};
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetEndpointJwksError;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
use crate::metrics::{CacheOutcome, Metrics};
use crate::rate_limiter::WakeComputeRateLimiter;
use crate::{compute, http, scram, EndpointCacheKey, EndpointId};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]

View File

@@ -1,4 +1,5 @@
use std::{error::Error as StdError, fmt, io};
use std::error::Error as StdError;
use std::{fmt, io};
use measured::FixedCardinalityLabel;

View File

@@ -1,19 +1,18 @@
use std::convert::Infallible;
use std::net::TcpListener;
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, bail};
use hyper0::{header::CONTENT_TYPE, Body, Request, Response, StatusCode};
use measured::{text::BufferedTextEncoder, MetricGroup};
use hyper0::header::CONTENT_TYPE;
use hyper0::{Body, Request, Response, StatusCode};
use measured::text::BufferedTextEncoder;
use measured::MetricGroup;
use metrics::NeonMetrics;
use std::{
convert::Infallible,
net::TcpListener,
sync::{Arc, Mutex},
};
use tracing::{info, info_span};
use utils::http::{
endpoint::{self, request_span},
error::ApiError,
json::json_response,
RouterBuilder, RouterService,
};
use utils::http::endpoint::{self, request_span};
use utils::http::error::ApiError;
use utils::http::json::json_response;
use utils::http::{RouterBuilder, RouterService};
use crate::jemalloc;

View File

@@ -10,17 +10,15 @@ use anyhow::bail;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper::body::Body;
pub(crate) use reqwest::{Request, Response};
use reqwest_middleware::RequestBuilder;
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
pub(crate) use reqwest_retry::policies::ExponentialBackoff;
pub(crate) use reqwest_retry::RetryTransientMiddleware;
use serde::de::DeserializeOwned;
pub(crate) use reqwest::{Request, Response};
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
pub(crate) use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use crate::{
metrics::{ConsoleRequest, Metrics},
url::ApiUrl,
};
use reqwest_middleware::RequestBuilder;
use crate::metrics::{ConsoleRequest, Metrics};
use crate::url::ApiUrl;
/// This is the preferred way to create new http clients,
/// because it takes care of observability (OpenTelemetry).
@@ -142,9 +140,10 @@ pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
#[cfg(test)]
mod tests {
use super::*;
use reqwest::Client;
use super::*;
#[test]
fn optional_query_params() -> anyhow::Result<()> {
let url = "http://example.com".parse()?;

View File

@@ -1,6 +1,8 @@
use std::{
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
};
use std::hash::BuildHasherDefault;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::ops::Index;
use std::sync::OnceLock;
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
use rustc_hash::FxHasher;
@@ -208,9 +210,8 @@ impl From<ProjectId> for ProjectIdInt {
mod tests {
use std::sync::OnceLock;
use crate::intern::StringInterner;
use super::InternId;
use crate::intern::StringInterner;
struct MyId;
impl InternId for MyId {
@@ -222,7 +223,8 @@ mod tests {
#[test]
fn push_many_strings() {
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::Zipf;
let endpoint_dist = Zipf::new(500000, 0.8).unwrap();

View File

@@ -1,14 +1,12 @@
use std::marker::PhantomData;
use measured::{
label::NoLabels,
metric::{
gauge::GaugeState, group::Encoding, name::MetricNameEncoder, MetricEncoding,
MetricFamilyEncoding, MetricType,
},
text::TextEncoder,
LabelGroup, MetricGroup,
};
use measured::label::NoLabels;
use measured::metric::gauge::GaugeState;
use measured::metric::group::Encoding;
use measured::metric::name::MetricNameEncoder;
use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType};
use measured::text::TextEncoder;
use measured::{LabelGroup, MetricGroup};
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version};
pub struct MetricRecorder {

View File

@@ -1,14 +1,10 @@
use tracing::Subscriber;
use tracing_subscriber::{
filter::{EnvFilter, LevelFilter},
fmt::{
format::{Format, Full},
time::SystemTime,
FormatEvent, FormatFields,
},
prelude::*,
registry::LookupSpan,
};
use tracing_subscriber::filter::{EnvFilter, LevelFilter};
use tracing_subscriber::fmt::format::{Format, Full};
use tracing_subscriber::fmt::time::SystemTime;
use tracing_subscriber::fmt::{FormatEvent, FormatFields};
use tracing_subscriber::prelude::*;
use tracing_subscriber::registry::LookupSpan;
/// Initialize logging and OpenTelemetry tracing and exporter.
///

View File

@@ -1,14 +1,16 @@
use std::sync::{Arc, OnceLock};
use lasso::ThreadedRodeo;
use measured::label::{
FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet,
};
use measured::metric::histogram::Thresholds;
use measured::metric::name::MetricName;
use measured::{
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet},
metric::{histogram::Thresholds, name::MetricName},
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
use tokio::time::{self, Instant};
use crate::control_plane::messages::ColdStartInfo;

View File

@@ -1,11 +1,9 @@
//! Proxy Protocol V2 implementation
use std::{
io,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use pin_project_lite::pin_project;

View File

@@ -1,24 +1,23 @@
use crate::{
auth::backend::ComputeCredentialKeys,
compute::COULD_NOT_CONNECT,
compute::{self, PostgresConnection},
config::RetryConfig,
context::RequestMonitoring,
control_plane::{self, errors::WakeComputeError, locks::ApiLocks, CachedNodeInfo, NodeInfo},
error::ReportableError,
metrics::{ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType},
proxy::{
retry::{retry_after, should_retry, CouldRetry},
wake_compute::wake_compute,
},
Host,
};
use async_trait::async_trait;
use pq_proto::StartupMessageParams;
use tokio::time;
use tracing::{debug, info, warn};
use super::retry::ShouldRetryWakeCompute;
use crate::auth::backend::ComputeCredentialKeys;
use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT};
use crate::config::RetryConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
use crate::proxy::wake_compute::wake_compute;
use crate::Host;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);

View File

@@ -1,11 +1,11 @@
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::info;
use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::info;
#[derive(Debug)]
enum TransferState {
Running(CopyBuffer),
@@ -256,9 +256,10 @@ impl CopyBuffer {
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::AsyncWriteExt;
use super::*;
#[tokio::test]
async fn test_client_to_compute() {
let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream

View File

@@ -1,21 +1,19 @@
use bytes::Buf;
use pq_proto::framed::Framed;
use pq_proto::{
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
StartupMessageParams,
BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
use crate::{
auth::endpoint_sni,
config::{TlsConfig, PG_ALPN_PROTOCOL},
context::RequestMonitoring,
error::ReportableError,
metrics::Metrics,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
};
use crate::auth::endpoint_sni;
use crate::config::{TlsConfig, PG_ALPN_PROTOCOL};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::ERR_INSECURE_CONNECTION;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
#[derive(Error, Debug)]
pub(crate) enum HandshakeError {

View File

@@ -7,40 +7,32 @@ pub(crate) mod handshake;
pub(crate) mod passthrough;
pub(crate) mod retry;
pub(crate) mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use std::sync::Arc;
use crate::config::ProxyProtocolV2;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey,
};
pub use copy_bidirectional::{copy_bidirectional_client_compute, ErrorSource};
use futures::TryFutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use self::{
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
};
use self::connect_compute::{connect_to_compute, TcpMechanism};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::{auth, compute, EndpointCacheKey};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";

View File

@@ -1,16 +1,14 @@
use crate::{
cancellation,
compute::PostgresConnection,
control_plane::messages::MetricsAuxInfo,
metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard},
stream::Stream,
usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]

View File

@@ -1,7 +1,11 @@
use crate::{compute, config::RetryConfig};
use std::{error::Error, io};
use std::error::Error;
use std::io;
use tokio::time;
use crate::compute;
use crate::config::RetryConfig;
pub(crate) trait CouldRetry {
/// Returns true if the error could be retried
fn could_retry(&self) -> bool;

View File

@@ -6,7 +6,6 @@
use std::fmt::Debug;
use super::*;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_protocol::message::frontend;
@@ -14,6 +13,8 @@ use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_postgres::tls::TlsConnect;
use tokio_util::codec::{Decoder, Encoder};
use super::*;
enum Intercept {
None,
Methods,

View File

@@ -4,6 +4,16 @@ mod mitm;
use std::time::Duration;
use anyhow::{bail, Context};
use async_trait::async_trait;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
use super::*;
@@ -18,15 +28,6 @@ use crate::control_plane::provider::{
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::{sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
@@ -336,7 +337,8 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
use rand::{distributions::Alphanumeric, Rng};
use rand::distributions::Alphanumeric;
use rand::Rng;
let password: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(rand::random::<u8>() as usize)

View File

@@ -1,16 +1,17 @@
use hyper::StatusCode;
use tracing::{error, info, warn};
use super::connect_compute::ComputeConnectBackend;
use crate::config::RetryConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::{ControlPlaneError, Reason};
use crate::control_plane::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::control_plane::provider::CachedNodeInfo;
use crate::metrics::{
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
WakeupFailureKind,
};
use crate::proxy::retry::{retry_after, should_retry};
use hyper::StatusCode;
use tracing::{error, info, warn};
use super::connect_compute::ComputeConnectBackend;
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,

View File

@@ -1,7 +1,5 @@
use std::{
hash::Hash,
sync::atomic::{AtomicUsize, Ordering},
};
use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
use ahash::RandomState;
use dashmap::DashMap;

View File

@@ -1,10 +1,12 @@
//! Algorithms for controlling concurrency limits.
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::Mutex;
use std::{pin::pin, sync::Arc, time::Duration};
use tokio::{
sync::Notify,
time::{error::Elapsed, Instant},
};
use tokio::sync::Notify;
use tokio::time::error::Elapsed;
use tokio::time::Instant;
use self::aimd::Aimd;

View File

@@ -60,12 +60,11 @@ impl LimitAlgorithm for Aimd {
mod tests {
use std::time::Duration;
use super::*;
use crate::rate_limiter::limit_algorithm::{
DynamicLimiter, RateLimitAlgorithm, RateLimiterConfig,
};
use super::*;
#[tokio::test(start_paused = true)]
async fn increase_decrease() {
let config = RateLimiterConfig {

View File

@@ -1,17 +1,14 @@
use std::{
borrow::Cow,
collections::hash_map::RandomState,
hash::{BuildHasher, Hash},
sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
},
};
use std::borrow::Cow;
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use anyhow::bail;
use dashmap::DashMap;
use itertools::Itertools;
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tokio::time::{Duration, Instant};
use tracing::info;
@@ -243,14 +240,17 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
#[cfg(test)]
mod tests {
use std::{hash::BuildHasherDefault, time::Duration};
use std::hash::BuildHasherDefault;
use std::time::Duration;
use rand::SeedableRng;
use rustc_hash::FxHasher;
use tokio::time;
use super::{BucketRateLimiter, WakeComputeRateLimiter};
use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId};
use crate::intern::EndpointIdInt;
use crate::rate_limiter::RateBucketInfo;
use crate::EndpointId;
#[test]
fn rate_bucket_rpi() {

View File

@@ -2,13 +2,11 @@ mod leaky_bucket;
mod limit_algorithm;
mod limiter;
pub use leaky_bucket::{EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter};
#[cfg(test)]
pub(crate) use limit_algorithm::aimd::Aimd;
pub(crate) use limit_algorithm::{
DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
};
pub(crate) use limiter::GlobalRateLimiter;
pub use leaky_bucket::{EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter};
pub use limiter::{BucketRateLimiter, RateBucketInfo, WakeComputeRateLimiter};

View File

@@ -5,13 +5,10 @@ use redis::AsyncCommands;
use tokio::sync::Mutex;
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
use super::{
connection_with_credentials_provider::ConnectionWithCredentialsProvider,
notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME},
};
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(

View File

@@ -1,10 +1,9 @@
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use std::time::Duration;
use futures::FutureExt;
use redis::{
aio::{ConnectionLike, MultiplexedConnection},
ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult,
};
use redis::aio::{ConnectionLike, MultiplexedConnection};
use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};

View File

@@ -1,4 +1,5 @@
use std::{convert::Infallible, sync::Arc};
use std::convert::Infallible;
use std::sync::Arc;
use futures::StreamExt;
use pq_proto::CancelKeyData;
@@ -8,12 +9,10 @@ use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::{
cache::project_info::ProjectInfoCache,
cancellation::{CancelMap, CancellationHandler},
intern::{ProjectIdInt, RoleNameInt},
metrics::{Metrics, RedisErrors, RedisEventsCount},
};
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
@@ -269,10 +268,10 @@ where
#[cfg(test)]
mod tests {
use crate::{ProjectId, RoleName};
use serde_json::json;
use super::*;
use serde_json::json;
use crate::{ProjectId, RoleName};
#[test]
fn parse_allowed_ips() -> anyhow::Result<()> {

View File

@@ -1,8 +1,9 @@
//! Definitions for SASL messages.
use crate::parse::{split_at_const, split_cstr};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
use crate::parse::{split_at_const, split_cstr};
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
#[derive(Debug)]
pub(crate) struct FirstMessage<'a> {

View File

@@ -10,13 +10,14 @@ mod channel_binding;
mod messages;
mod stream;
use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;
pub(crate) use channel_binding::ChannelBinding;
pub(crate) use messages::FirstMessage;
pub(crate) use stream::{Outcome, SaslStream};
use thiserror::Error;
use crate::error::{ReportableError, UserFacingError};
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]

@@ -1,11 +1,14 @@
//! Abstraction for the string-oriented SASL protocols.
use super::{messages::ServerMessage, Mechanism};
use crate::stream::PqStream;
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use super::messages::ServerMessage;
use super::Mechanism;
use crate::stream::PqStream;
/// Abstracts away all peculiarities of the libpq's protocol.
pub(crate) struct SaslStream<'a, S> {
/// The underlying stream.

@@ -69,7 +69,9 @@ impl CountMinSketch {
#[cfg(test)]
mod tests {
use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::{Rng, SeedableRng};
use super::CountMinSketch;

@@ -209,7 +209,8 @@ impl sasl::Mechanism for Exchange<'_> {
type Output = super::ScramKey;
fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
use {sasl::Step, ExchangeState};
use sasl::Step;
use ExchangeState;
match &self.state {
ExchangeState::Initial(init) => {
match init.transition(self.secret, &self.tls_server_end_point, input)? {

@@ -1,11 +1,12 @@
//! Definitions for SCRAM messages.
use std::fmt;
use std::ops::Range;
use super::base64_decode_array;
use super::key::{ScramKey, SCRAM_KEY_LEN};
use super::signature::SignatureBuilder;
use crate::sasl::ChannelBinding;
use std::fmt;
use std::ops::Range;
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;

@@ -16,10 +16,9 @@ mod signature;
pub mod threadpool;
pub(crate) use exchange::{exchange, Exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
@@ -59,13 +58,11 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
#[cfg(test)]
mod tests {
use crate::{
intern::EndpointIdInt,
sasl::{Mechanism, Step},
EndpointId,
};
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
use super::threadpool::ThreadPool;
use super::{Exchange, ServerSecret};
use crate::intern::EndpointIdInt;
use crate::sasl::{Mechanism, Step};
use crate::EndpointId;
#[test]
fn snapshot() {

@@ -1,7 +1,6 @@
use hmac::{
digest::{consts::U32, generic_array::GenericArray},
Hmac, Mac,
};
use hmac::digest::consts::U32;
use hmac::digest::generic_array::GenericArray;
use hmac::{Hmac, Mac};
use sha2::Sha256;
pub(crate) struct Pbkdf2 {
@@ -66,10 +65,11 @@ impl Pbkdf2 {
#[cfg(test)]
mod tests {
use super::Pbkdf2;
use pbkdf2::pbkdf2_hmac_array;
use sha2::Sha256;
use super::Pbkdf2;
#[test]
fn works() {
let salt = b"sodium chloride";

@@ -4,28 +4,21 @@
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.
use std::{
cell::RefCell,
future::Future,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Weak,
},
task::{Context, Poll},
};
use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
use futures::FutureExt;
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
use crate::{
intern::EndpointIdInt,
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
scram::countmin::CountMinSketch,
};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
use crate::scram::countmin::CountMinSketch;
pub struct ThreadPool {
runtime: Option<tokio::runtime::Runtime>,
@@ -195,9 +188,8 @@ impl Drop for JobHandle {
#[cfg(test)]
mod tests {
use crate::EndpointId;
use super::*;
use crate::EndpointId;
#[tokio::test]
async fn hash_is_correct() {
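
The larger `std` block above shows the same flattening applied to deeper nesting: inner braces such as `sync::{atomic::{...}, Arc, Weak}` come apart into separate per-module items. A small self-contained illustration, again with hypothetical std-only code:

```rust
// Formerly something like:
// use std::{
//     sync::{
//         atomic::{AtomicUsize, Ordering},
//         Arc, Weak,
//     },
// };

// After flattening to one `use` per module path:
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};

fn main() {
    let strong = Arc::new(AtomicUsize::new(0));
    let weak: Weak<AtomicUsize> = Arc::downgrade(&strong);
    strong.fetch_add(1, Ordering::Relaxed);
    if let Some(counter) = weak.upgrade() {
        println!("{}", counter.load(Ordering::Relaxed));
    }
}
```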

@@ -1,42 +1,34 @@
use std::{io, sync::Arc, time::Duration};
use std::io;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use p256::{ecdsa::SigningKey, elliptic_curve::JwkEcKey};
use p256::ecdsa::SigningKey;
use p256::elliptic_curve::JwkEcKey;
use rand::rngs::OsRng;
use tokio::net::{lookup_host, TcpStream};
use tracing::{debug, field::display, info};
use tracing::field::display;
use tracing::{debug, info};
use crate::{
auth::{
self,
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError,
},
compute,
config::ProxyConfig,
context::RequestMonitoring,
control_plane::{
errors::{GetAuthInfoError, WakeComputeError},
locks::ApiLocks,
provider::ApiLockError,
CachedNodeInfo,
},
error::{ErrorKind, ReportableError, UserFacingError},
intern::EndpointIdInt,
proxy::{
connect_compute::ConnectMechanism,
retry::{CouldRetry, ShouldRetryWakeCompute},
},
rate_limiter::EndpointRateLimiter,
EndpointId, Host,
};
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
local_conn_pool::{self, LocalClient, LocalConnPool},
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client};
use super::local_conn_pool::{self, LocalClient, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::provider::ApiLockError;
use crate::control_plane::CachedNodeInfo;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::{compute, EndpointId, Host};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
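
The reordered list above settles into three blocks: `std` paths first, external crates next, and `super`/`crate` paths last. Here is a compact, self-contained sketch of that layout; the module and type names are hypothetical, and `tracing` is assumed to be an available dependency.

```rust
// Standard library group.
use std::sync::Arc;
use std::time::Duration;

// External crate group (assumes `tracing` is listed in Cargo.toml).
use tracing::info;

// Local group: `crate`/`super` paths come last.
use crate::config::ProxyLimits;

mod config {
    pub struct ProxyLimits {
        pub max_conns: usize,
        pub idle_timeout: std::time::Duration,
    }
}

fn main() {
    let limits = Arc::new(ProxyLimits {
        max_conns: 8,
        idle_timeout: Duration::from_secs(30),
    });
    info!("loaded hypothetical limits: max_conns={}", limits.max_conns);
}
```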

@@ -1,10 +1,8 @@
//! A set for cancelling random http connections
use std::{
hash::{BuildHasher, BuildHasherDefault},
num::NonZeroUsize,
time::Duration,
};
use std::hash::{BuildHasher, BuildHasherDefault};
use std::num::NonZeroUsize;
use std::time::Duration;
use indexmap::IndexMap;
use parking_lot::Mutex;

@@ -1,33 +1,31 @@
use std::collections::HashMap;
use std::fmt;
use std::ops::Deref;
use std::pin::pin;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};
use std::time::Duration;
use dashmap::DashMap;
use futures::{future::poll_fn, Future};
use futures::future::poll_fn;
use futures::Future;
use parking_lot::RwLock;
use rand::Rng;
use smallvec::SmallVec;
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use std::{
fmt,
task::{ready, Poll},
};
use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, info_span, warn, Instrument, Span};
use super::backend::HttpConnError;
use crate::auth::backend::ComputeUserInfo;
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{
auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
};
use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
use crate::{DbName, EndpointCacheKey, RoleName};
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
@@ -724,13 +722,13 @@ impl<C: ClientInnerExt> Drop for Client<C> {
#[cfg(test)]
mod tests {
use std::{mem, sync::atomic::AtomicBool};
use crate::{
proxy::NeonOptions, serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId,
};
use std::mem;
use std::sync::atomic::AtomicBool;
use super::*;
use crate::proxy::NeonOptions;
use crate::serverless::cancel_set::CancelSet;
use crate::{BranchId, EndpointId, ProjectId};
struct MockClient(Arc<AtomicBool>);
impl MockClient {
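
The test modules above follow the same shape: `std` and external imports first, with `use super::*;` and other local paths grouped at the end. A minimal, self-contained sketch of that arrangement, using a hypothetical function and test:

```rust
pub fn clamp_pool_size(requested: usize, max: usize) -> usize {
    requested.min(max)
}

#[cfg(test)]
mod tests {
    // Standard library imports first...
    use std::collections::HashSet;

    // ...then the local group with the parent-module glob.
    use super::*;

    #[test]
    fn clamps_to_max() {
        let mut sizes = HashSet::new();
        sizes.insert(clamp_pool_size(64, 16));
        assert!(sizes.contains(&16));
    }
}
```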

@@ -1,22 +1,21 @@
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use dashmap::DashMap;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::{sync::Arc, sync::Weak};
use tokio::net::TcpStream;
use tracing::{debug, error, info, info_span, Instrument};
use super::conn_pool::ConnInfo;
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use tracing::{debug, error};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
use crate::EndpointCacheKey;
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect =

@@ -1,12 +1,11 @@
//! Things stolen from `libs/utils/src/http` to add hyper 1.0 compatibility
//! Will merge back in at some point in the future.
use bytes::Bytes;
use anyhow::Context;
use bytes::Bytes;
use http::{Response, StatusCode};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use serde::Serialize;
use utils::http::error::ApiError;

@@ -1,7 +1,5 @@
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use serde_json::{Map, Value};
use tokio_postgres::types::{Kind, Type};
use tokio_postgres::Row;
//
@@ -256,9 +254,10 @@ fn _pg_array_parse(
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use super::*;
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];

@@ -1,28 +1,31 @@
use futures::{future::poll_fn, Future};
use std::collections::HashMap;
use std::pin::pin;
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};
use std::time::Duration;
use futures::future::poll_fn;
use futures::Future;
use indexmap::IndexMap;
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
use p256::ecdsa::{Signature, SigningKey};
use parking_lot::RwLock;
use serde_json::value::RawValue;
use signature::Signer;
use std::task::{ready, Poll};
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::types::ToSql;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, DbName, RoleName};
use tracing::{error, warn, Span};
use tracing::{info, info_span, Instrument};
use tracing::{error, info, info_span, warn, Instrument, Span};
use super::backend::HttpConnError;
use super::conn_pool::{ClientInnerExt, ConnInfo};
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{DbName, RoleName};
struct ConnPoolEntry<C: ClientInnerExt> {
conn: ClientInner<C>,

@@ -12,12 +12,15 @@ mod local_conn_pool;
mod sql_over_http;
mod websocket;
use std::net::{IpAddr, SocketAddr};
use std::pin::{pin, Pin};
use std::sync::Arc;
use anyhow::Context;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
@@ -29,9 +32,13 @@ use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
@@ -43,14 +50,6 @@ use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use std::net::{IpAddr, SocketAddr};
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(

@@ -2,77 +2,43 @@ use std::pin::pin;
use std::sync::Arc;
use bytes::Bytes;
use futures::future::select;
use futures::future::try_join;
use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use futures::future::{select, try_join, Either};
use futures::{StreamExt, TryFutureExt};
use http::header::AUTHORIZATION;
use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::body::Body;
use hyper::body::Incoming;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{HeaderMap, Request};
use http_body_util::{BodyExt, Full};
use hyper::body::{Body, Incoming};
use hyper::http::{HeaderName, HeaderValue};
use hyper::{header, HeaderMap, Request, Response, StatusCode};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
use tokio::time;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
use tokio_postgres::error::SqlState;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::NoTls;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Transaction;
use tokio_postgres::error::{DbError, ErrorPosition, SqlState};
use tokio_postgres::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
use tokio_util::sync::CancellationToken;
use tracing::error;
use tracing::info;
use tracing::{error, info};
use typed_json::json;
use url::Url;
use urlencoding;
use utils::http::error::ApiError;
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::HttpConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::error::ErrorKind;
use crate::error::ReportableError;
use crate::error::UserFacingError;
use crate::metrics::HttpDirection;
use crate::metrics::Metrics;
use crate::proxy::run_until_cancelled;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::usage_metrics::MetricCounter;
use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend;
use super::conn_pool;
use super::conn_pool::AuthData;
use super::conn_pool::ConnInfo;
use super::conn_pool::ConnInfoWithAuth;
use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::{AuthData, ConnInfo, ConnInfoWithAuth};
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
use super::local_conn_pool;
use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
use super::{conn_pool, local_conn_pool};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
use crate::context::RequestMonitoring;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::{HttpDirection, Metrics};
use crate::proxy::{run_until_cancelled, NeonOptions};
use crate::serverless::backend::HttpConnError;
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::{DbName, RoleName};
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
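
The hunk above also shows the opposite motion: long runs of single-item imports from the same module collapse into one braced list per module. A std-only sketch of that merge with hypothetical code:

```rust
// Previously, one item per line even within a single module:
// use std::io::BufWriter;
// use std::io::Read;
// use std::io::Write;

// Merged into one `use` per module path:
use std::io::{BufWriter, Read, Write};

fn main() -> std::io::Result<()> {
    let mut out = BufWriter::new(Vec::new());
    out.write_all(b"hello")?;

    // Read the buffered bytes back to exercise the `Read` import.
    let bytes = out.into_inner()?;
    let mut reader = bytes.as_slice();
    let mut buf = String::new();
    reader.read_to_string(&mut buf)?;
    println!("{buf}");
    Ok(())
}
```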

@@ -1,13 +1,7 @@
use crate::proxy::ErrorSource;
use crate::{
cancellation::CancellationHandlerMain,
config::ProxyConfig,
context::RequestMonitoring,
error::{io_error, ReportableError},
metrics::Metrics,
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
@@ -15,15 +9,17 @@ use futures::{Sink, Stream};
use hyper::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use std::{
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::warn;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::error::{io_error, ReportableError};
use crate::metrics::Metrics;
use crate::proxy::{handle_client, ClientMode, ErrorSource};
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
/// This is a wrapper around a [`WebSocketStream`] that
/// implements [`AsyncRead`] and [`AsyncWrite`].
@@ -184,14 +180,11 @@ mod tests {
use framed_websockets::WebSocketServer;
use futures::{SinkExt, StreamExt};
use tokio::{
io::{duplex, AsyncReadExt, AsyncWriteExt},
task::JoinSet,
};
use tokio_tungstenite::{
tungstenite::{protocol::Role, Message},
WebSocketStream,
};
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::task::JoinSet;
use tokio_tungstenite::tungstenite::protocol::Role;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::WebSocketStream;
use super::WebSocketRw;

@@ -1,19 +1,20 @@
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
/// Stream wrapper which implements libpq's protocol.
///
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]

@@ -1,36 +1,33 @@
//! Periodically collect proxy consumption metrics
//! and push them to a HTTP endpoint.
use crate::{
config::{MetricBackupCollectionConfig, MetricCollectionConfig},
context::parquet::{FAILED_UPLOAD_MAX_RETRIES, FAILED_UPLOAD_WARN_THRESHOLD},
http,
intern::{BranchIdInt, EndpointIdInt},
};
use std::convert::Infallible;
use std::pin::pin;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use async_compression::tokio::write::GzipEncoder;
use bytes::Bytes;
use chrono::{DateTime, Datelike, Timelike, Utc};
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
use dashmap::{mapref::entry::Entry, DashMap};
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use futures::future::select;
use once_cell::sync::Lazy;
use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel};
use serde::{Deserialize, Serialize};
use std::{
convert::Infallible,
pin::pin,
sync::{
atomic::{AtomicU64, AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, instrument, trace, warn};
use utils::backoff;
use uuid::{NoContext, Timestamp};
use crate::config::{MetricBackupCollectionConfig, MetricCollectionConfig};
use crate::context::parquet::{FAILED_UPLOAD_MAX_RETRIES, FAILED_UPLOAD_WARN_THRESHOLD};
use crate::http;
use crate::intern::{BranchIdInt, EndpointIdInt};
const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
const HTTP_REPORTING_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
@@ -485,19 +482,23 @@ async fn upload_events_chunk(
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
use crate::{http, BranchId, EndpointId};
use anyhow::Error;
use chrono::Utc;
use consumption_metrics::{Event, EventChunk};
use http_body_util::BodyExt;
use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use std::sync::{Arc, Mutex};
use tokio::net::TcpListener;
use url::Url;
use super::*;
use crate::{http, BranchId, EndpointId};
#[tokio::test]
async fn metrics() {
type Report = EventChunk<'static, Event<Ids, String>>;

@@ -1,8 +1,9 @@
use std::pin::Pin;
use std::task;
use hashbrown::HashMap;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task;
use thiserror::Error;
use tokio::sync::oneshot;
@@ -99,9 +100,10 @@ impl<T> std::future::Future for Waiter<'_, T> {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use super::*;
#[tokio::test]
async fn test_waiter() -> anyhow::Result<()> {
let waiters = Arc::new(Waiters::default());