mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 12:32:54 +00:00
## Problem We want to see how many users of the legacy serverless driver are still using the old URL for SQL-over-HTTP traffic. ## Summary of changes Adds a protocol field to the connections_by_sni metric. Ensures it's incremented for sql-over-http.
555 lines
18 KiB
Rust
555 lines
18 KiB
Rust
//! User credentials used in authentication.
|
|
|
|
use std::collections::HashSet;
|
|
use std::net::IpAddr;
|
|
use std::str::FromStr;
|
|
|
|
use itertools::Itertools;
|
|
use pq_proto::StartupMessageParams;
|
|
use thiserror::Error;
|
|
use tracing::{debug, warn};
|
|
|
|
use crate::auth::password_hack::parse_endpoint_param;
|
|
use crate::context::RequestContext;
|
|
use crate::error::{ReportableError, UserFacingError};
|
|
use crate::metrics::{Metrics, SniGroup, SniKind};
|
|
use crate::proxy::NeonOptions;
|
|
use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
|
|
use crate::types::{EndpointId, RoleName};
|
|
|
|
#[derive(Debug, Error, PartialEq, Eq, Clone)]
|
|
pub(crate) enum ComputeUserInfoParseError {
|
|
#[error("Parameter '{0}' is missing in startup packet.")]
|
|
MissingKey(&'static str),
|
|
|
|
#[error(
|
|
"Inconsistent project name inferred from \
|
|
SNI ('{}') and project option ('{}').",
|
|
.domain, .option,
|
|
)]
|
|
InconsistentProjectNames {
|
|
domain: EndpointId,
|
|
option: EndpointId,
|
|
},
|
|
|
|
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
|
|
MalformedProjectName(EndpointId),
|
|
}
|
|
|
|
impl UserFacingError for ComputeUserInfoParseError {}
|
|
|
|
impl ReportableError for ComputeUserInfoParseError {
|
|
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
|
crate::error::ErrorKind::User
|
|
}
|
|
}
|
|
|
|
/// Various client credentials which we use for authentication.
|
|
/// Note that we don't store any kind of client key or password here.
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub(crate) struct ComputeUserInfoMaybeEndpoint {
|
|
pub(crate) user: RoleName,
|
|
pub(crate) endpoint_id: Option<EndpointId>,
|
|
pub(crate) options: NeonOptions,
|
|
}
|
|
|
|
impl ComputeUserInfoMaybeEndpoint {
|
|
#[inline]
|
|
pub(crate) fn endpoint(&self) -> Option<&str> {
|
|
self.endpoint_id.as_deref()
|
|
}
|
|
}
|
|
|
|
pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
|
|
let (subdomain, common_name) = sni.split_once('.')?;
|
|
if !common_names.contains(common_name) {
|
|
return None;
|
|
}
|
|
if subdomain == SERVERLESS_DRIVER_SNI || subdomain == AUTH_BROKER_SNI {
|
|
return None;
|
|
}
|
|
Some(EndpointId::from(subdomain))
|
|
}
|
|
|
|
impl ComputeUserInfoMaybeEndpoint {
|
|
pub(crate) fn parse(
|
|
ctx: &RequestContext,
|
|
params: &StartupMessageParams,
|
|
sni: Option<&str>,
|
|
common_names: Option<&HashSet<String>>,
|
|
) -> Result<Self, ComputeUserInfoParseError> {
|
|
// Some parameters are stored in the startup message.
|
|
let get_param = |key| {
|
|
params
|
|
.get(key)
|
|
.ok_or(ComputeUserInfoParseError::MissingKey(key))
|
|
};
|
|
let user: RoleName = get_param("user")?.into();
|
|
|
|
// Project name might be passed via PG's command-line options.
|
|
let endpoint_option = params
|
|
.options_raw()
|
|
.and_then(|options| {
|
|
// We support both `project` (deprecated) and `endpoint` options for backward compatibility.
|
|
// However, if both are present, we don't exactly know which one to use.
|
|
// Therefore we require that only one of them is present.
|
|
options
|
|
.filter_map(parse_endpoint_param)
|
|
.at_most_one()
|
|
.ok()?
|
|
})
|
|
.map(|name| name.into());
|
|
|
|
let endpoint_from_domain =
|
|
sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn)));
|
|
|
|
let endpoint = match (endpoint_option, endpoint_from_domain) {
|
|
// Invariant: if we have both project name variants, they should match.
|
|
(Some(option), Some(domain)) if option != domain => {
|
|
Some(Err(ComputeUserInfoParseError::InconsistentProjectNames {
|
|
domain,
|
|
option,
|
|
}))
|
|
}
|
|
// Invariant: project name may not contain certain characters.
|
|
(a, b) => a.or(b).map(|name| {
|
|
if project_name_valid(name.as_ref()) {
|
|
Ok(name)
|
|
} else {
|
|
Err(ComputeUserInfoParseError::MalformedProjectName(name))
|
|
}
|
|
}),
|
|
}
|
|
.transpose()?;
|
|
|
|
if let Some(ep) = &endpoint {
|
|
ctx.set_endpoint_id(ep.clone());
|
|
}
|
|
|
|
let metrics = Metrics::get();
|
|
debug!(%user, "credentials");
|
|
|
|
let protocol = ctx.protocol();
|
|
let kind = if sni.is_some() {
|
|
debug!("Connection with sni");
|
|
SniKind::Sni
|
|
} else if endpoint.is_some() {
|
|
debug!("Connection without sni");
|
|
SniKind::NoSni
|
|
} else {
|
|
debug!("Connection with password hack");
|
|
SniKind::PasswordHack
|
|
};
|
|
|
|
metrics
|
|
.proxy
|
|
.accepted_connections_by_sni
|
|
.inc(SniGroup { protocol, kind });
|
|
|
|
let options = NeonOptions::parse_params(params);
|
|
|
|
Ok(Self {
|
|
user,
|
|
endpoint_id: endpoint,
|
|
options,
|
|
})
|
|
}
|
|
}
|
|
|
|
pub(crate) fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
|
|
ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
|
|
}
|
|
|
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
|
pub(crate) enum IpPattern {
|
|
Subnet(ipnet::IpNet),
|
|
Range(IpAddr, IpAddr),
|
|
Single(IpAddr),
|
|
None,
|
|
}
|
|
|
|
impl<'de> serde::de::Deserialize<'de> for IpPattern {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
struct StrVisitor;
|
|
impl serde::de::Visitor<'_> for StrVisitor {
|
|
type Value = IpPattern;
|
|
|
|
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(
|
|
formatter,
|
|
"comma separated list with ip address, ip address range, or ip address subnet mask"
|
|
)
|
|
}
|
|
|
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
|
|
warn!("Cannot parse ip pattern {v}: {e}");
|
|
IpPattern::None
|
|
}))
|
|
}
|
|
}
|
|
deserializer.deserialize_str(StrVisitor)
|
|
}
|
|
}
|
|
|
|
impl FromStr for IpPattern {
|
|
type Err = anyhow::Error;
|
|
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
parse_ip_pattern(s)
|
|
}
|
|
}
|
|
|
|
fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
|
|
if pattern.contains('/') {
|
|
let subnet: ipnet::IpNet = pattern.parse()?;
|
|
return Ok(IpPattern::Subnet(subnet));
|
|
}
|
|
if let Some((start, end)) = pattern.split_once('-') {
|
|
let start: IpAddr = start.parse()?;
|
|
let end: IpAddr = end.parse()?;
|
|
return Ok(IpPattern::Range(start, end));
|
|
}
|
|
let addr: IpAddr = pattern.parse()?;
|
|
Ok(IpPattern::Single(addr))
|
|
}
|
|
|
|
fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
|
|
match pattern {
|
|
IpPattern::Subnet(subnet) => subnet.contains(ip),
|
|
IpPattern::Range(start, end) => start <= ip && ip <= end,
|
|
IpPattern::Single(addr) => addr == ip,
|
|
IpPattern::None => false,
|
|
}
|
|
}
|
|
|
|
fn project_name_valid(name: &str) -> bool {
|
|
name.chars().all(|c| c.is_alphanumeric() || c == '-')
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use ComputeUserInfoParseError::*;
|
|
use serde_json::json;
|
|
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn parse_bare_minimum() -> anyhow::Result<()> {
|
|
// According to postgresql, only `user` should be required.
|
|
let options = StartupMessageParams::new([("user", "john_doe")]);
|
|
let ctx = RequestContext::test();
|
|
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
|
assert_eq!(user_info.user, "john_doe");
|
|
assert_eq!(user_info.endpoint_id, None);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_excessive() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([
|
|
("user", "john_doe"),
|
|
("database", "world"), // should be ignored
|
|
("foo", "bar"), // should be ignored
|
|
]);
|
|
let ctx = RequestContext::test();
|
|
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
|
assert_eq!(user_info.user, "john_doe");
|
|
assert_eq!(user_info.endpoint_id, None);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_project_from_sni() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([("user", "john_doe")]);
|
|
|
|
let sni = Some("foo.localhost");
|
|
let common_names = Some(["localhost".into()].into());
|
|
|
|
let ctx = RequestContext::test();
|
|
let user_info =
|
|
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
|
assert_eq!(user_info.user, "john_doe");
|
|
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
|
|
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_project_from_options() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([
|
|
("user", "john_doe"),
|
|
("options", "-ckey=1 project=bar -c geqo=off"),
|
|
]);
|
|
|
|
let ctx = RequestContext::test();
|
|
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
|
assert_eq!(user_info.user, "john_doe");
|
|
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_endpoint_from_options() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([
|
|
("user", "john_doe"),
|
|
("options", "-ckey=1 endpoint=bar -c geqo=off"),
|
|
]);
|
|
|
|
let ctx = RequestContext::test();
|
|
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
|
assert_eq!(user_info.user, "john_doe");
|
|
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([
|
|
("user", "john_doe"),
|
|
(
|
|
"options",
|
|
"-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
|
|
),
|
|
]);
|
|
|
|
let ctx = RequestContext::test();
|
|
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
|
assert_eq!(user_info.user, "john_doe");
|
|
assert!(user_info.endpoint_id.is_none());
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([
|
|
("user", "john_doe"),
|
|
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
|
|
]);
|
|
|
|
let ctx = RequestContext::test();
|
|
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
|
assert_eq!(user_info.user, "john_doe");
|
|
assert!(user_info.endpoint_id.is_none());
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_projects_identical() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
|
|
|
|
let sni = Some("baz.localhost");
|
|
let common_names = Some(["localhost".into()].into());
|
|
|
|
let ctx = RequestContext::test();
|
|
let user_info =
|
|
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
|
assert_eq!(user_info.user, "john_doe");
|
|
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_multi_common_names() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([("user", "john_doe")]);
|
|
|
|
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
|
let sni = Some("p1.a.com");
|
|
let ctx = RequestContext::test();
|
|
let user_info =
|
|
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
|
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
|
|
|
|
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
|
let sni = Some("p1.b.com");
|
|
let ctx = RequestContext::test();
|
|
let user_info =
|
|
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
|
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn parse_projects_different() {
|
|
let options =
|
|
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
|
|
|
|
let sni = Some("second.localhost");
|
|
let common_names = Some(["localhost".into()].into());
|
|
|
|
let ctx = RequestContext::test();
|
|
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
|
.expect_err("should fail");
|
|
match err {
|
|
InconsistentProjectNames { domain, option } => {
|
|
assert_eq!(option, "first");
|
|
assert_eq!(domain, "second");
|
|
}
|
|
_ => panic!("bad error: {err:?}"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn parse_unknown_sni() {
|
|
let options = StartupMessageParams::new([("user", "john_doe")]);
|
|
|
|
let sni = Some("project.localhost");
|
|
let common_names = Some(["example.com".into()].into());
|
|
|
|
let ctx = RequestContext::test();
|
|
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
|
.unwrap();
|
|
|
|
assert!(info.endpoint_id.is_none());
|
|
}
|
|
|
|
#[test]
|
|
fn parse_unknown_sni_with_options() {
|
|
let options = StartupMessageParams::new([
|
|
("user", "john_doe"),
|
|
("options", "endpoint=foo-bar-baz-1234"),
|
|
]);
|
|
|
|
let sni = Some("project.localhost");
|
|
let common_names = Some(["example.com".into()].into());
|
|
|
|
let ctx = RequestContext::test();
|
|
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
|
.unwrap();
|
|
|
|
assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234"));
|
|
}
|
|
|
|
#[test]
|
|
fn parse_neon_options() -> anyhow::Result<()> {
|
|
let options = StartupMessageParams::new([
|
|
("user", "john_doe"),
|
|
("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
|
|
]);
|
|
|
|
let sni = Some("project.localhost");
|
|
let common_names = Some(["localhost".into()].into());
|
|
let ctx = RequestContext::test();
|
|
let user_info =
|
|
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
|
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
|
|
assert_eq!(
|
|
user_info.options.get_cache_key("project"),
|
|
"project endpoint_type:read_write lsn:0/2"
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn test_check_peer_addr_is_in_list() {
|
|
fn check(v: serde_json::Value) -> bool {
|
|
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
|
let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
|
|
check_peer_addr_is_in_list(&peer_addr, &ip_list)
|
|
}
|
|
|
|
assert!(check(json!([])));
|
|
assert!(check(json!(["127.0.0.1"])));
|
|
assert!(!check(json!(["8.8.8.8"])));
|
|
// If there is an incorrect address, it will be skipped.
|
|
assert!(check(json!(["88.8.8", "127.0.0.1"])));
|
|
}
|
|
#[test]
|
|
fn test_parse_ip_v4() -> anyhow::Result<()> {
|
|
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
|
// Ok
|
|
assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
|
|
assert_eq!(
|
|
parse_ip_pattern("127.0.0.1/31")?,
|
|
IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
|
|
);
|
|
assert_eq!(
|
|
parse_ip_pattern("0.0.0.0-200.0.1.2")?,
|
|
IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
|
|
);
|
|
|
|
// Error
|
|
assert!(parse_ip_pattern("300.0.1.2").is_err());
|
|
assert!(parse_ip_pattern("30.1.2").is_err());
|
|
assert!(parse_ip_pattern("127.0.0.1/33").is_err());
|
|
assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
|
|
assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn test_check_ipv4() -> anyhow::Result<()> {
|
|
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
|
let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
|
|
let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
|
|
// Success
|
|
assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
|
|
assert!(check_ip(
|
|
&peer_addr,
|
|
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
|
|
));
|
|
assert!(check_ip(
|
|
&peer_addr,
|
|
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
|
|
));
|
|
assert!(check_ip(
|
|
&peer_addr,
|
|
&IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
|
|
));
|
|
assert!(check_ip(
|
|
&peer_addr,
|
|
&IpPattern::Range(peer_addr, peer_addr)
|
|
));
|
|
|
|
// Not success
|
|
assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
|
|
assert!(!check_ip(
|
|
&peer_addr,
|
|
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
|
|
));
|
|
assert!(!check_ip(
|
|
&peer_addr,
|
|
&IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
|
|
));
|
|
assert!(!check_ip(
|
|
&peer_addr,
|
|
&IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
|
|
));
|
|
// There is no check that for range start <= end. But it's fine as long as for all this cases the result is false.
|
|
assert!(!check_ip(
|
|
&peer_addr,
|
|
&IpPattern::Range(peer_addr, peer_addr_prev)
|
|
));
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn test_connection_blocker() {
|
|
fn check(v: serde_json::Value) -> bool {
|
|
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
|
let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
|
|
check_peer_addr_is_in_list(&peer_addr, &ip_list)
|
|
}
|
|
|
|
assert!(check(json!([])));
|
|
assert!(check(json!(["127.0.0.1"])));
|
|
assert!(!check(json!(["255.255.255.255"])));
|
|
}
|
|
}
|