IP allowlist on the proxy side (#5906)

## Problem

Per-project IP allowlist:
https://github.com/neondatabase/cloud/issues/8116

## Summary of changes

Implemented IP filtering on the proxy side. 

To retrieve ip allowlist for all scenarios, added `get_auth_info` call
to the control plane for:
* sql-over-http
* password_hack
* cleartext_hack

Added cache with ttl for sql-over-http path

This might slow down a bit, consider using redis in the future.

---------

Co-authored-by: Conrad Ludgate <conrad@neon.tech>
This commit is contained in:
Anna Khanova
2023-11-30 14:14:33 +01:00
committed by GitHub
parent 1e57ddaabc
commit e12e2681e9
23 changed files with 601 additions and 115 deletions

5
Cargo.lock generated
View File

@@ -2382,9 +2382,9 @@ dependencies = [
[[package]]
name = "ipnet"
version = "2.7.2"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
[[package]]
name = "is-terminal"
@@ -3612,6 +3612,7 @@ dependencies = [
"humantime",
"hyper",
"hyper-tungstenite",
"ipnet",
"itertools",
"md5",
"metrics",

View File

@@ -88,6 +88,7 @@ humantime-serde = "1.1.1"
hyper = "0.14"
hyper-tungstenite = "0.11"
inotify = "0.10.2"
ipnet = "2.9.0"
itertools = "0.10"
jsonwebtoken = "8"
libc = "0.2"

View File

@@ -24,6 +24,7 @@ hostname.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
ipnet.workspace = true
itertools.workspace = true
md5.workspace = true
metrics.workspace = true

View File

@@ -4,7 +4,7 @@ pub mod backend;
pub use backend::BackendType;
mod credentials;
pub use credentials::ClientCredentials;
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials};
mod password_hack;
pub use password_hack::parse_endpoint_param;
@@ -56,6 +56,12 @@ pub enum AuthErrorImpl {
/// Errors produced by e.g. [`crate::stream::PqStream`].
#[error(transparent)]
Io(#[from] io::Error),
#[error(
"This IP address is not allowed to connect to this endpoint. \
Please add it to the allowed list in the Neon console."
)]
IpAddressNotAllowed,
}
#[derive(Debug, Error)]
@@ -70,6 +76,10 @@ impl AuthError {
pub fn auth_failed(user: impl Into<Box<str>>) -> Self {
AuthErrorImpl::AuthFailed(user.into()).into()
}
pub fn ip_address_not_allowed() -> Self {
AuthErrorImpl::IpAddressNotAllowed.into()
}
}
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
@@ -91,6 +101,7 @@ impl UserFacingError for AuthError {
MalformedPassword(_) => self.to_string(),
MissingEndpointName => self.to_string(),
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed => self.to_string(),
}
}
}

View File

@@ -5,7 +5,12 @@ mod link;
pub use link::LinkAuthError;
use tokio_postgres::config::AuthKeys;
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::AuthInfo;
use crate::console::AuthSecret;
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
use crate::scram;
use crate::stream::Stream;
use crate::{
auth::{self, ClientCredentials},
@@ -20,6 +25,7 @@ use crate::{
use futures::TryFutureExt;
use std::borrow::Cow;
use std::ops::ControlFlow;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
@@ -64,6 +70,7 @@ pub enum BackendType<'a, T> {
pub trait TestBackend: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, ()> {
@@ -140,14 +147,38 @@ async fn auth_quirks_creds(
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
if creds.project.is_none() {
let maybe_success = if creds.project.is_none() {
// Password will be checked by the compute node later.
return hacks::password_hack(creds, client, latency_timer).await;
}
Some(hacks::password_hack(creds, client, latency_timer).await?)
} else {
None
};
// Password hack should set the project name.
// TODO: make `creds.project` more type-safe.
assert!(creds.project.is_some());
info!("fetching user's authentication info");
// TODO(anna): this will slow down both "hacks" below; we probably need a cache.
let AuthInfo {
secret,
allowed_ips,
} = api.get_auth_info(extra, creds).await?;
// check allowed list
if !check_peer_addr_is_in_list(&creds.peer_addr.ip(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed());
}
let secret = secret.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
if let Some(success) = maybe_success {
return Ok(success);
}
// Perform cleartext auth if we're allowed to do that.
// Currently, we use it for websocket connections (latency).
@@ -157,7 +188,7 @@ async fn auth_quirks_creds(
}
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(api, extra, creds, client, config, latency_timer).await
classic::authenticate(creds, client, config, latency_timer, secret).await
}
/// True to its name, this function encapsulates our current auth trade-offs.
@@ -305,6 +336,19 @@ impl BackendType<'_, ClientCredentials<'_>> {
Ok(res)
}
pub async fn get_allowed_ips(
&self,
extra: &ConsoleReqExtra<'_>,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, creds) => api.get_allowed_ips(extra, creds).await,
Postgres(api, creds) => api.get_allowed_ips(extra, creds).await,
Link(_) => Ok(Arc::new(vec![])),
Test(x) => x.get_allowed_ips(),
}
}
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(

View File

@@ -3,38 +3,28 @@ use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
config::AuthenticationConfig,
console::{self, AuthInfo, ConsoleReqExtra},
console::AuthSecret,
proxy::LatencyTimer,
sasl, scram,
sasl,
stream::{PqStream, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
pub(super) async fn authenticate(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
let flow = AuthFlow::new(client);
let scram_keys = match info {
AuthInfo::Md5(_) => {
let scram_keys = match secret {
AuthSecret::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthInfo::Scram(secret) => {
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret);

View File

@@ -7,9 +7,12 @@ use crate::{
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use std::collections::HashSet;
use std::{
collections::HashSet,
net::{IpAddr, SocketAddr},
};
use thiserror::Error;
use tracing::info;
use tracing::{info, warn};
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ClientCredsParseError {
@@ -44,6 +47,7 @@ pub struct ClientCredentials<'a> {
pub project: Option<String>,
pub cache_key: String,
pub peer_addr: SocketAddr,
}
impl ClientCredentials<'_> {
@@ -54,19 +58,11 @@ impl ClientCredentials<'_> {
}
impl<'a> ClientCredentials<'a> {
#[cfg(test)]
pub fn new_noop() -> Self {
ClientCredentials {
user: "",
project: None,
cache_key: "".to_string(),
}
}
pub fn parse(
params: &'a StartupMessageParams,
sni: Option<&str>,
common_names: Option<HashSet<String>>,
peer_addr: SocketAddr,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
@@ -153,10 +149,59 @@ impl<'a> ClientCredentials<'a> {
user,
project,
cache_key,
peer_addr,
})
}
}
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<String>) -> bool {
if ip_list.is_empty() {
return true;
}
for ip in ip_list {
// We expect that all ip addresses from control plane are correct.
// However, if some of them are broken, we still can check the others.
match parse_ip_pattern(ip) {
Ok(pattern) => {
if check_ip(peer_addr, &pattern) {
return true;
}
}
Err(err) => warn!("Cannot parse ip: {}; err: {}", ip, err),
}
}
false
}
#[derive(Debug, Clone, Eq, PartialEq)]
enum IpPattern {
Subnet(ipnet::IpNet),
Range(IpAddr, IpAddr),
Single(IpAddr),
}
fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
if pattern.contains('/') {
let subnet: ipnet::IpNet = pattern.parse()?;
return Ok(IpPattern::Subnet(subnet));
}
if let Some((start, end)) = pattern.split_once('-') {
let start: IpAddr = start.parse()?;
let end: IpAddr = end.parse()?;
return Ok(IpPattern::Range(start, end));
}
let addr: IpAddr = pattern.parse()?;
Ok(IpPattern::Single(addr))
}
fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
match pattern {
IpPattern::Subnet(subnet) => subnet.contains(ip),
IpPattern::Range(start, end) => start <= ip && ip <= end,
IpPattern::Single(addr) => addr == ip,
}
}
fn project_name_valid(name: &str) -> bool {
name.chars().all(|c| c.is_alphanumeric() || c == '-')
}
@@ -176,8 +221,8 @@ mod tests {
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let creds = ClientCredentials::parse(&options, None, None)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -191,8 +236,8 @@ mod tests {
("database", "world"), // should be ignored
("foo", "bar"), // should be ignored
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -206,7 +251,8 @@ mod tests {
let sni = Some("foo.localhost");
let common_names = Some(["localhost".into()].into());
let creds = ClientCredentials::parse(&options, sni, common_names)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("foo"));
assert_eq!(creds.cache_key, "foo");
@@ -221,7 +267,8 @@ mod tests {
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -235,7 +282,8 @@ mod tests {
("options", "-ckey=1 endpoint=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -252,7 +300,8 @@ mod tests {
),
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert!(creds.project.is_none());
@@ -266,7 +315,8 @@ mod tests {
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert!(creds.project.is_none());
@@ -280,7 +330,8 @@ mod tests {
let sni = Some("baz.localhost");
let common_names = Some(["localhost".into()].into());
let creds = ClientCredentials::parse(&options, sni, common_names)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -293,12 +344,14 @@ mod tests {
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let creds = ClientCredentials::parse(&options, sni, common_names)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.project.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let creds = ClientCredentials::parse(&options, sni, common_names)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.project.as_deref(), Some("p1"));
Ok(())
@@ -312,7 +365,9 @@ mod tests {
let sni = Some("second.localhost");
let common_names = Some(["localhost".into()].into());
let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail");
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let err = ClientCredentials::parse(&options, sni, common_names, peer_addr)
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -329,7 +384,9 @@ mod tests {
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail");
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let err = ClientCredentials::parse(&options, sni, common_names, peer_addr)
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
@@ -347,7 +404,8 @@ mod tests {
let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let creds = ClientCredentials::parse(&options, sni, common_names)?;
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.project.as_deref(), Some("project"));
assert_eq!(
creds.cache_key,
@@ -356,4 +414,91 @@ mod tests {
Ok(())
}
#[test]
fn test_check_peer_addr_is_in_list() {
let peer_addr = IpAddr::from([127, 0, 0, 1]);
assert!(check_peer_addr_is_in_list(&peer_addr, &vec![]));
assert!(check_peer_addr_is_in_list(
&peer_addr,
&vec!["127.0.0.1".into()]
));
assert!(!check_peer_addr_is_in_list(
&peer_addr,
&vec!["8.8.8.8".into()]
));
// If there is an incorrect address, it will be skipped.
assert!(check_peer_addr_is_in_list(
&peer_addr,
&vec!["88.8.8".into(), "127.0.0.1".into()]
));
}
#[test]
fn test_parse_ip_v4() -> anyhow::Result<()> {
let peer_addr = IpAddr::from([127, 0, 0, 1]);
// Ok
assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
assert_eq!(
parse_ip_pattern("127.0.0.1/31")?,
IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
);
assert_eq!(
parse_ip_pattern("0.0.0.0-200.0.1.2")?,
IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
);
// Error
assert!(parse_ip_pattern("300.0.1.2").is_err());
assert!(parse_ip_pattern("30.1.2").is_err());
assert!(parse_ip_pattern("127.0.0.1/33").is_err());
assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
Ok(())
}
#[test]
fn test_check_ipv4() -> anyhow::Result<()> {
let peer_addr = IpAddr::from([127, 0, 0, 1]);
let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
// Success
assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
assert!(check_ip(
&peer_addr,
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
));
assert!(check_ip(
&peer_addr,
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
));
assert!(check_ip(
&peer_addr,
&IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
));
assert!(check_ip(
&peer_addr,
&IpPattern::Range(peer_addr, peer_addr)
));
// Not success
assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
assert!(!check_ip(
&peer_addr,
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
));
assert!(!check_ip(
&peer_addr,
&IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
));
assert!(!check_ip(
&peer_addr,
&IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
));
// There is no check that for range start <= end. But it's fine as long as for all this cases the result is false.
assert!(!check_ip(
&peer_addr,
&IpPattern::Range(peer_addr, peer_addr_prev)
));
Ok(())
}
}

View File

@@ -1,8 +1,11 @@
use futures::future::Either;
use proxy::auth;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
use proxy::console;
use proxy::console::provider::AllowedIpsCache;
use proxy::console::provider::NodeInfoCache;
use proxy::http;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::usage_metrics;
@@ -113,6 +116,12 @@ struct ProxyCliArgs {
initial_limit: usize,
#[clap(flatten)]
aimd_config: proxy::rate_limiter::AimdConfig,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)]
allowed_ips_cache: String,
/// disable ip check for http requests. If it is too time consuming, it could be turned off.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_ip_check_for_http: bool,
}
#[tokio::main]
@@ -241,11 +250,24 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let auth_backend = match &args.auth_backend {
AuthBackend::Console => {
let config::CacheOptions { size, ttl } = args.wake_compute_cache.parse()?;
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let allowed_ips_cache_config: CacheOptions = args.allowed_ips_cache.parse()?;
info!("Using NodeInfoCache (wake_compute) with size={size} ttl={ttl:?}");
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!("Using AllowedIpsCache (wake_compute) with options={allowed_ips_cache_config:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches {
node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
node_info: NodeInfoCache::new(
"node_info_cache",
wake_compute_cache_config.size,
wake_compute_cache_config.ttl,
true,
),
allowed_ips: AllowedIpsCache::new(
"allowed_ips_cache",
allowed_ips_cache_config.size,
allowed_ips_cache_config.ttl,
false,
),
}));
let config::WakeComputeLockOptions {
@@ -292,6 +314,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
http_config,
authentication_config,
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
}));
Ok(config)

View File

@@ -55,7 +55,7 @@ pub mod timed_lru {
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * When the entry is about to be retrieved, we check its expiration timestamp.
/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
@@ -79,6 +79,8 @@ pub mod timed_lru {
/// Default time-to-live of a single entry.
ttl: Duration,
update_ttl_on_retrieval: bool,
}
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
@@ -99,11 +101,17 @@ pub mod timed_lru {
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub fn new(name: &'static str, capacity: usize, ttl: Duration) -> Self {
pub fn new(
name: &'static str,
capacity: usize,
ttl: Duration,
update_ttl_on_retrieval: bool,
) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
update_ttl_on_retrieval,
}
}
@@ -165,7 +173,9 @@ pub mod timed_lru {
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
raw_entry.get_mut().expires_at = deadline;
if self.update_ttl_on_retrieval {
raw_entry.get_mut().expires_at = deadline;
}
raw_entry.to_back();
drop(cache); // drop lock before logging

View File

@@ -19,6 +19,7 @@ pub struct ProxyConfig {
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
}
#[derive(Debug)]
@@ -298,6 +299,7 @@ impl CertResolver {
}
/// Helper for cmdline cache options parsing.
#[derive(Debug)]
pub struct CacheOptions {
/// Max number of entries.
pub size: usize,

View File

@@ -6,7 +6,7 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
/// Various cache-related types.
pub mod caches {

View File

@@ -204,7 +204,7 @@ pub struct ConsoleReqExtra<'a> {
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
pub enum AuthSecret {
/// Md5 hash of user's password.
Md5([u8; 16]),
@@ -212,6 +212,13 @@ pub enum AuthInfo {
Scram(scram::ServerSecret),
}
#[derive(Default)]
pub struct AuthInfo {
pub secret: Option<AuthSecret>,
/// List of IP addresses allowed for the autorization.
pub allowed_ips: Vec<String>,
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
@@ -230,6 +237,7 @@ pub struct NodeInfo {
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
pub type AllowedIpsCache = TimedLru<Arc<str>, Arc<Vec<String>>>;
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
@@ -240,7 +248,13 @@ pub trait Api {
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
) -> Result<Option<AuthInfo>, errors::GetAuthInfoError>;
) -> Result<AuthInfo, errors::GetAuthInfoError>;
async fn get_allowed_ips(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
@@ -254,6 +268,8 @@ pub trait Api {
pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache,
/// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
pub allowed_ips: TimedLru<Arc<str>, Arc<Vec<String>>>,
}
/// Various caches for [`console`](super).

View File

@@ -1,14 +1,16 @@
//! Mock console backend which relies on a user-provided postgres instance.
use std::sync::Arc;
use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
use tokio_postgres::config::SslMode;
use tokio_postgres::{config::SslMode, Client};
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Debug, Error)]
@@ -46,8 +48,8 @@ impl Api {
async fn do_get_auth_info(
&self,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
async {
) -> Result<AuthInfo, GetAuthInfoError> {
let (secret, allowed_ips) = async {
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
@@ -55,32 +57,48 @@ impl Api {
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client.query(query, &[&creds.user]).await?;
// We can get at most one row, because `rolname` is unique.
let row = match rows.first() {
Some(row) => row,
// This means that the user doesn't exist, so there can be no secret.
// However, this is still a *valid* outcome which is very similar
// to getting `404 Not found` from the Neon console.
let secret = match get_execute_postgres_query(
&client,
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
&[&creds.user],
"rolpassword",
)
.await?
{
Some(entry) => {
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
}
None => {
warn!("user '{}' does not exist", creds.user);
return Ok(None);
None
}
};
let allowed_ips = match get_execute_postgres_query(
&client,
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
&[&creds.project.clone().unwrap_or_default().as_str()],
"allowed_ips",
)
.await?
{
Some(s) => {
info!("got allowed_ips: {s}");
s.split(',').map(String::from).collect()
}
None => vec![],
};
let entry = row
.try_get("rolpassword")
.map_err(MockApiError::PasswordNotSet)?;
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram);
Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5)))
Ok((secret, allowed_ips))
}
.map_err(crate::error::log_error)
.map_err(crate::error::log_error::<GetAuthInfoError>)
.instrument(info_span!("postgres", url = self.endpoint.as_str()))
.await
.await?;
Ok(AuthInfo {
secret,
allowed_ips,
})
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
@@ -100,6 +118,27 @@ impl Api {
}
}
async fn get_execute_postgres_query(
client: &Client,
query: &str,
params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
idx: &str,
) -> Result<Option<String>, GetAuthInfoError> {
let rows = client.query(query, params).await?;
// We can get at most one row, because `rolname` is unique.
let row = match rows.first() {
Some(row) => row,
// This means that the user doesn't exist, so there can be no secret.
// However, this is still a *valid* outcome which is very similar
// to getting `404 Not found` from the Neon console.
None => return Ok(None),
};
let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
Ok(Some(entry))
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
@@ -107,10 +146,18 @@ impl super::Api for Api {
&self,
_extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
) -> Result<AuthInfo, GetAuthInfoError> {
self.do_get_auth_info(creds).await
}
async fn get_allowed_ips(
&self,
_extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -3,11 +3,17 @@
use super::{
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{
auth::ClientCredentials,
compute, http,
proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
scram,
};
use crate::{auth::ClientCredentials, compute, http, scram};
use async_trait::async_trait;
use futures::TryFutureExt;
use itertools::Itertools;
use std::{net::SocketAddr, sync::Arc};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
@@ -48,7 +54,7 @@ impl Api {
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
) -> Result<AuthInfo, GetAuthInfoError> {
let request_id = uuid::Uuid::new_v4().to_string();
async {
let request = self
@@ -72,16 +78,25 @@ impl Api {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(http::StatusCode::NOT_FOUND) => return Ok(None),
Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
_otherwise => return Err(e.into()),
},
};
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthInfo::Scram)
.map(AuthSecret::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
Ok(Some(secret))
let allowed_ips = body
.allowed_ips
.into_iter()
.flatten()
.map(String::from)
.collect_vec();
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
Ok(AuthInfo {
secret: Some(secret),
allowed_ips,
})
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
@@ -148,10 +163,32 @@ impl super::Api for Api {
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
) -> Result<AuthInfo, GetAuthInfoError> {
self.do_get_auth_info(extra, creds).await
}
async fn get_allowed_ips(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
let key: &str = creds.project().expect("impossible");
if let Some(allowed_ips) = self.caches.allowed_ips.get(key) {
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["hit"])
.inc();
return Ok(Arc::new(allowed_ips.to_vec()));
}
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["miss"])
.inc();
let allowed_ips = Arc::new(self.do_get_auth_info(extra, creds).await?.allowed_ips);
self.caches
.allowed_ips
.insert(key.into(), allowed_ips.clone());
Ok(allowed_ips)
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -13,7 +13,7 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::time::Instant;
use tracing::trace;
use crate::{rate_limiter, url::ApiUrl};
use crate::{proxy::CONSOLE_REQUEST_LATENCY, rate_limiter, url::ApiUrl};
use reqwest_middleware::RequestBuilder;
/// This is the preferred way to create new http clients,
@@ -90,7 +90,13 @@ impl Endpoint {
/// Execute a [request](reqwest::Request).
pub async fn execute(&self, request: Request) -> Result<Response, Error> {
self.client.execute(request).await
let path = request.url().path().to_string();
let start = Instant::now();
let res = self.client.execute(request).await;
CONSOLE_REQUEST_LATENCY
.with_label_values(&[&path])
.observe(start.elapsed().as_micros() as f64);
res
}
}

View File

@@ -24,7 +24,7 @@ use prometheus::{
IntGaugeVec,
};
use regex::Regex;
use std::{error::Error, io, ops::ControlFlow, sync::Arc, time::Instant};
use std::{error::Error, io, net::SocketAddr, ops::ControlFlow, sync::Arc, time::Instant};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
time,
@@ -110,12 +110,34 @@ static COMPUTE_CONNECTION_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
.unwrap()
});
pub static CONSOLE_REQUEST_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"proxy_console_request_latency",
"Time it took for proxy to establish a connection to the compute endpoint",
// proxy_wake_compute/proxy_get_role_info
&["request"],
// largest bucket = 2^16 * 0.2ms = 13s
exponential_buckets(0.2, 2.0, 16).unwrap(),
)
.unwrap()
});
pub static ALLOWED_IPS_BY_CACHE_OUTCOME: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_allowed_ips_cache_misses",
"Number of cache hits/misses for allowed ips",
// hit/miss
&["outcome"],
)
.unwrap()
});
pub static RATE_LIMITER_ACQUIRE_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"semaphore_control_plane_token_acquire_seconds",
"Time it took for proxy to establish a connection to the compute endpoint",
// largest bucket = 2^16 * 0.5ms = 32s
exponential_buckets(0.0005, 2.0, 16).unwrap(),
// largest bucket = 3^16 * 0.00005s = 3.28s
exponential_buckets(0.00005, 3.0, 16).unwrap(),
)
.unwrap()
});
@@ -138,6 +160,15 @@ pub static NUM_CONNECTION_ACCEPTED_BY_SNI: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_allowed_ips_number",
"Number of allowed ips",
vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0],
)
.unwrap()
});
pub struct LatencyTimer {
// time since the stopwatch was started
start: Option<Instant>,
@@ -265,7 +296,7 @@ pub async fn task_main(
loop {
tokio::select! {
accept_result = listener.accept() => {
let (socket, _) = accept_result?;
let (socket, peer_addr) = accept_result?;
let session_id = uuid::Uuid::new_v4();
let cancel_map = Arc::clone(&cancel_map);
@@ -274,7 +305,9 @@ pub async fn task_main(
info!("accepted postgres client connection");
let mut socket = WithClientIp::new(socket);
let mut peer_addr = peer_addr;
if let Some(ip) = socket.wait_for_addr().await? {
peer_addr = ip;
tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
} else if config.require_client_ip {
bail!("missing required client IP");
@@ -285,7 +318,7 @@ pub async fn task_main(
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp).await
handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr).await
}
.instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty))
.unwrap_or_else(move |e| {
@@ -375,6 +408,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
session_id: uuid::Uuid,
stream: S,
mode: ClientMode,
peer_addr: SocketAddr,
) -> anyhow::Result<()> {
info!(
protocol = mode.protocol_label(),
@@ -408,7 +442,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_names))
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_names, peer_addr))
.transpose();
match result {

View File

@@ -466,6 +466,10 @@ impl TestBackend for TestConnectMechanism {
x => panic!("expecting action {:?}, wake_compute is called instead", x),
}
}
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
}
fn helper_create_cached_node_info() -> CachedNodeInfo {

View File

@@ -23,6 +23,7 @@ use hyper::{
Body, Method, Request, Response,
};
use std::net::SocketAddr;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
@@ -102,7 +103,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
request_handler(
req, config, conn_pool, cancel_map, session_id, sni_name,
req, config, conn_pool, cancel_map, session_id, sni_name, peer_addr,
)
.instrument(info_span!(
"serverless",
@@ -170,6 +171,7 @@ async fn request_handler(
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
sni_hostname: Option<String>,
peer_addr: SocketAddr,
) -> Result<Response<Body>, ApiError> {
let host = request
.headers()
@@ -187,9 +189,15 @@ async fn request_handler(
tokio::spawn(
async move {
if let Err(e) =
websocket::serve_websocket(websocket, config, &cancel_map, session_id, host)
.await
if let Err(e) = websocket::serve_websocket(
websocket,
config,
&cancel_map,
session_id,
host,
peer_addr,
)
.await
{
error!(session_id = ?session_id, "error in websocket connection: {e:#}");
}
@@ -205,6 +213,7 @@ async fn request_handler(
sni_hostname,
conn_pool,
session_id,
peer_addr,
&config.http_config,
)
.await

View File

@@ -8,7 +8,7 @@ use pbkdf2::{
Params, Pbkdf2,
};
use pq_proto::StartupMessageParams;
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use std::{
fmt,
task::{ready, Poll},
@@ -21,7 +21,8 @@ use tokio::time;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
use crate::{
auth, console,
auth::{self, check_peer_addr_is_in_list},
console,
proxy::{
neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER,
NUM_DB_CONNECTIONS_OPENED_COUNTER,
@@ -144,6 +145,7 @@ impl GlobalConnPool {
conn_info: &ConnInfo,
force_new: bool,
session_id: uuid::Uuid,
peer_addr: SocketAddr,
) -> anyhow::Result<Client> {
let mut client: Option<ClientInner> = None;
let mut latency_timer = LatencyTimer::new("http");
@@ -203,6 +205,7 @@ impl GlobalConnPool {
conn_id,
session_id,
latency_timer,
peer_addr,
)
.await
} else {
@@ -225,6 +228,7 @@ impl GlobalConnPool {
conn_id,
session_id,
latency_timer,
peer_addr,
)
.await
};
@@ -401,6 +405,7 @@ async fn connect_to_compute(
conn_id: uuid::Uuid,
session_id: uuid::Uuid,
latency_timer: LatencyTimer,
peer_addr: SocketAddr,
) -> anyhow::Result<ClientInner> {
let tls = config.tls_config.as_ref();
let common_names = tls.and_then(|tls| tls.common_names.clone());
@@ -411,12 +416,13 @@ async fn connect_to_compute(
("application_name", APP_NAME),
("options", conn_info.options.as_deref().unwrap_or("")),
]);
let creds = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, Some(&conn_info.hostname), common_names))
.transpose()?;
let creds = auth::ClientCredentials::parse(
&params,
Some(&conn_info.hostname),
common_names,
peer_addr,
)?;
let backend = config.auth_backend.as_ref().map(|_| creds);
let console_options = neon_options(&params);
@@ -425,8 +431,14 @@ async fn connect_to_compute(
application_name: Some(APP_NAME),
options: console_options.as_deref(),
};
let node_info = creds
// TODO(anna): this is a bit hacky way, consider using console notification listener.
if !config.disable_ip_check_for_http {
let allowed_ips = backend.get_allowed_ips(&extra).await?;
if !check_peer_addr_is_in_list(&peer_addr.ip(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed().into());
}
}
let node_info = backend
.wake_compute(&extra)
.await?
.context("missing cache entry from wake_compute")?;
@@ -439,7 +451,7 @@ async fn connect_to_compute(
},
node_info,
&extra,
&creds,
&backend,
latency_timer,
)
.await

View File

@@ -1,3 +1,4 @@
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::bail;
@@ -201,11 +202,19 @@ pub async fn handle(
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
session_id: uuid::Uuid,
peer_addr: SocketAddr,
config: &'static HttpConfig,
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.timeout,
handle_inner(config, request, sni_hostname, conn_pool, session_id),
handle_inner(
config,
request,
sni_hostname,
conn_pool,
session_id,
peer_addr,
),
)
.await;
let mut response = match result {
@@ -292,6 +301,7 @@ async fn handle_inner(
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
session_id: uuid::Uuid,
peer_addr: SocketAddr,
) -> anyhow::Result<Response<Body>> {
NUM_CONNECTIONS_ACCEPTED_COUNTER
.with_label_values(&["http"])
@@ -351,7 +361,9 @@ async fn handle_inner(
let body = hyper::body::to_bytes(request.into_body()).await?;
let payload: Payload = serde_json::from_slice(&body)?;
let mut client = conn_pool.get(&conn_info, !allow_pool, session_id).await?;
let mut client = conn_pool
.get(&conn_info, !allow_pool, session_id, peer_addr)
.await?;
let mut response = Response::builder()
.status(StatusCode::OK)

View File

@@ -11,6 +11,7 @@ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use pin_project_lite::pin_project;
use std::{
net::SocketAddr,
pin::Pin,
task::{ready, Context, Poll},
};
@@ -132,6 +133,7 @@ pub async fn serve_websocket(
cancel_map: &CancelMap,
session_id: uuid::Uuid,
hostname: Option<String>,
peer_addr: SocketAddr,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
handle_client(
@@ -140,6 +142,7 @@ pub async fn serve_websocket(
session_id,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
peer_addr,
)
.await?;
Ok(())

View File

@@ -2390,6 +2390,10 @@ def static_proxy(
# For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql`
vanilla_pg.start()
vanilla_pg.safe_psql("create user proxy with login superuser password 'password'")
vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS neon_control_plane")
vanilla_pg.safe_psql(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
)
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()

View File

@@ -0,0 +1,74 @@
import psycopg2
import pytest
from fixtures.neon_fixtures import (
NeonProxy,
VanillaPostgres,
)
TABLE_NAME = "neon_control_plane.endpoints"
# Proxy uses the same logic for psql and websockets.
@pytest.mark.asyncio
async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
)
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
)
def check_cannot_connect(**kwargs):
with pytest.raises(psycopg2.Error) as exprinfo:
static_proxy.safe_psql(**kwargs)
text = str(exprinfo.value).strip()
assert "This IP address is not allowed to connect" in text
# no SNI, deprecated `options=project` syntax (before we had several endpoint in project)
check_cannot_connect(query="select 1", sslsni=0, options="project=private-project")
# no SNI, new `options=endpoint` syntax
check_cannot_connect(query="select 1", sslsni=0, options="endpoint=private-project")
# with SNI
check_cannot_connect(query="select 1", host="private-project.localtest.me")
# no SNI, deprecated `options=project` syntax (before we had several endpoint in project)
out = static_proxy.safe_psql(query="select 1", sslsni=0, options="project=generic-project")
assert out[0][0] == 1
# no SNI, new `options=endpoint` syntax
out = static_proxy.safe_psql(query="select 1", sslsni=0, options="endpoint=generic-project")
assert out[0][0] == 1
# with SNI
out = static_proxy.safe_psql(query="select 1", host="generic-project.localtest.me")
assert out[0][0] == 1
@pytest.mark.asyncio
async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
)
def query(status: int, query: str, *args):
static_proxy.http_query(
query,
args,
user="http_auth",
password="http",
expected_code=status,
)
query(400, "select 1;") # ip address is not allowed
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
)
query(200, "select 1;") # should work now