mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 05:52:55 +00:00
proxy: local auth backend (#8806)
Adds a Local authentication backend. Updates http to extract JWT bearer tokens and passes them to the local backend to validate.
This commit is contained in:
@@ -2,6 +2,7 @@ mod classic;
|
||||
mod hacks;
|
||||
pub mod jwt;
|
||||
mod link;
|
||||
pub mod local;
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
@@ -9,6 +10,7 @@ use std::time::Duration;
|
||||
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
pub use link::LinkAuthError;
|
||||
use local::LocalBackend;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
use tracing::{info, warn};
|
||||
@@ -68,6 +70,8 @@ pub enum BackendType<'a, T, D> {
|
||||
Console(MaybeOwned<'a, ConsoleBackend>, T),
|
||||
/// Authentication via a web browser.
|
||||
Link(MaybeOwned<'a, url::ApiUrl>, D),
|
||||
/// Local proxy uses configured auth credentials and does not wake compute
|
||||
Local(MaybeOwned<'a, LocalBackend>),
|
||||
}
|
||||
|
||||
pub trait TestBackend: Send + Sync + 'static {
|
||||
@@ -93,6 +97,7 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
|
||||
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
|
||||
},
|
||||
Self::Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
Self::Local(_) => fmt.debug_tuple("Local").finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,6 +109,7 @@ impl<T, D> BackendType<'_, T, D> {
|
||||
match self {
|
||||
Self::Console(c, x) => BackendType::Console(MaybeOwned::Borrowed(c), x),
|
||||
Self::Link(c, x) => BackendType::Link(MaybeOwned::Borrowed(c), x),
|
||||
Self::Local(l) => BackendType::Local(MaybeOwned::Borrowed(l)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,6 +122,7 @@ impl<'a, T, D> BackendType<'a, T, D> {
|
||||
match self {
|
||||
Self::Console(c, x) => BackendType::Console(c, f(x)),
|
||||
Self::Link(c, x) => BackendType::Link(c, x),
|
||||
Self::Local(l) => BackendType::Local(l),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -126,6 +133,7 @@ impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
|
||||
match self {
|
||||
Self::Console(c, x) => x.map(|x| BackendType::Console(c, x)),
|
||||
Self::Link(c, x) => Ok(BackendType::Link(c, x)),
|
||||
Self::Local(l) => Ok(BackendType::Local(l)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -157,6 +165,7 @@ impl ComputeUserInfo {
|
||||
pub enum ComputeCredentialKeys {
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
None,
|
||||
}
|
||||
|
||||
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
|
||||
@@ -289,7 +298,7 @@ async fn auth_quirks(
|
||||
ctx.set_endpoint_id(res.info.endpoint.clone());
|
||||
let password = match res.keys {
|
||||
ComputeCredentialKeys::Password(p) => p,
|
||||
ComputeCredentialKeys::AuthKeys(_) => {
|
||||
ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
|
||||
unreachable!("password hack should return a password")
|
||||
}
|
||||
};
|
||||
@@ -401,6 +410,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
match self {
|
||||
Self::Console(_, user_info) => user_info.endpoint_id.clone(),
|
||||
Self::Link(_, _) => Some("link".into()),
|
||||
Self::Local(_) => Some("local".into()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,6 +419,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
match self {
|
||||
Self::Console(_, user_info) => &user_info.user,
|
||||
Self::Link(_, _) => "link",
|
||||
Self::Local(_) => "local",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -450,6 +461,9 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
|
||||
BackendType::Link(url, info)
|
||||
}
|
||||
Self::Local(_) => {
|
||||
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
|
||||
}
|
||||
};
|
||||
|
||||
info!("user successfully authenticated");
|
||||
@@ -465,6 +479,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
match self {
|
||||
Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
||||
Self::Link(_, _) => Ok(Cached::new_uncached(None)),
|
||||
Self::Local(_) => Ok(Cached::new_uncached(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -475,6 +490,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
match self {
|
||||
Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Self::Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -488,13 +504,15 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
|
||||
match self {
|
||||
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Self::Link(_, info) => Ok(Cached::new_uncached(info.clone())),
|
||||
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
match self {
|
||||
Self::Console(_, creds) => Some(&creds.keys),
|
||||
Self::Link(_, _) => None,
|
||||
Self::Console(_, creds) => &creds.keys,
|
||||
Self::Link(_, _) => &ComputeCredentialKeys::None,
|
||||
Self::Local(_) => &ComputeCredentialKeys::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -508,13 +526,15 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
|
||||
match self {
|
||||
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Self::Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
|
||||
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
match self {
|
||||
Self::Console(_, creds) => Some(&creds.keys),
|
||||
Self::Link(_, _) => None,
|
||||
Self::Console(_, creds) => &creds.keys,
|
||||
Self::Link(_, _) => &ComputeCredentialKeys::None,
|
||||
Self::Local(_) => &ComputeCredentialKeys::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
79
proxy/src/auth/backend/local.rs
Normal file
79
proxy/src/auth/backend/local.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use std::{collections::HashMap, net::SocketAddr};
|
||||
|
||||
use anyhow::Context;
|
||||
use arc_swap::ArcSwapOption;
|
||||
|
||||
use crate::{
|
||||
compute::ConnCfg,
|
||||
console::{
|
||||
messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo},
|
||||
NodeInfo,
|
||||
},
|
||||
intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag},
|
||||
RoleName,
|
||||
};
|
||||
|
||||
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
|
||||
|
||||
pub struct LocalBackend {
|
||||
pub jwks_cache: JwkCache,
|
||||
pub postgres_addr: SocketAddr,
|
||||
pub node_info: NodeInfo,
|
||||
}
|
||||
|
||||
impl LocalBackend {
|
||||
pub fn new(postgres_addr: SocketAddr) -> Self {
|
||||
LocalBackend {
|
||||
jwks_cache: JwkCache::default(),
|
||||
postgres_addr,
|
||||
node_info: NodeInfo {
|
||||
config: {
|
||||
let mut cfg = ConnCfg::new();
|
||||
cfg.host(&postgres_addr.ip().to_string());
|
||||
cfg.port(postgres_addr.port());
|
||||
cfg
|
||||
},
|
||||
// TODO(conrad): make this better reflect compute info rather than endpoint info.
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
|
||||
project_id: ProjectIdTag::get_interner().get_or_intern("local"),
|
||||
branch_id: BranchIdTag::get_interner().get_or_intern("local"),
|
||||
cold_start_info: ColdStartInfo::WarmCached,
|
||||
},
|
||||
allow_self_signed_compute: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct StaticAuthRules;
|
||||
|
||||
pub static JWKS_ROLE_MAP: ArcSwapOption<JwksRoleSettings> = ArcSwapOption::const_empty();
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JwksRoleSettings {
|
||||
pub roles: HashMap<RoleName, EndpointJwksResponse>,
|
||||
pub project_id: ProjectIdInt,
|
||||
pub branch_id: BranchIdInt,
|
||||
}
|
||||
|
||||
impl FetchAuthRules for StaticAuthRules {
|
||||
async fn fetch_auth_rules(&self, role_name: RoleName) -> anyhow::Result<Vec<AuthRule>> {
|
||||
let mappings = JWKS_ROLE_MAP.load();
|
||||
let role_mappings = mappings
|
||||
.as_deref()
|
||||
.and_then(|m| m.roles.get(&role_name))
|
||||
.context("JWKs settings for this role were not configured")?;
|
||||
let mut rules = vec![];
|
||||
for setting in &role_mappings.jwks {
|
||||
rules.push(AuthRule {
|
||||
id: setting.id.clone(),
|
||||
jwks_url: setting.jwks_url.clone(),
|
||||
audience: setting.jwt_audience.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(rules)
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
use measured::FixedCardinalityLabel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{self, Display};
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
|
||||
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
use crate::RoleName;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
@@ -341,6 +343,26 @@ impl ColdStartInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct JwksRoleMapping {
|
||||
pub roles: HashMap<RoleName, EndpointJwksResponse>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct EndpointJwksResponse {
|
||||
pub jwks: Vec<JwksSettings>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct JwksSettings {
|
||||
pub id: String,
|
||||
pub project_id: ProjectIdInt,
|
||||
pub branch_id: BranchIdInt,
|
||||
pub jwks_url: url::Url,
|
||||
pub provider_name: String,
|
||||
pub jwt_audience: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -305,6 +305,7 @@ impl NodeInfo {
|
||||
match keys {
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
ComputeCredentialKeys::None => &mut self.config,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ pub trait ComputeConnectBackend {
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys;
|
||||
}
|
||||
|
||||
pub struct TcpMechanism<'a> {
|
||||
@@ -112,9 +112,8 @@ where
|
||||
let mut num_retries = 0;
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
if let Some(keys) = user_info.get_keys() {
|
||||
node_info.set_keys(keys);
|
||||
}
|
||||
|
||||
node_info.set_keys(user_info.get_keys());
|
||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
@@ -407,7 +407,7 @@ async fn request_handler(
|
||||
.header("Access-Control-Allow-Origin", "*")
|
||||
.header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
|
||||
"Authorization, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
|
||||
@@ -4,7 +4,10 @@ use async_trait::async_trait;
|
||||
use tracing::{field::display, info};
|
||||
|
||||
use crate::{
|
||||
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
|
||||
auth::{
|
||||
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
|
||||
check_peer_addr_is_in_list, AuthError,
|
||||
},
|
||||
compute,
|
||||
config::{AuthenticationConfig, ProxyConfig},
|
||||
console::{
|
||||
@@ -24,7 +27,7 @@ use crate::{
|
||||
Host,
|
||||
};
|
||||
|
||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||
use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool};
|
||||
|
||||
pub struct PoolingBackend {
|
||||
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
@@ -33,13 +36,14 @@ pub struct PoolingBackend {
|
||||
}
|
||||
|
||||
impl PoolingBackend {
|
||||
pub async fn authenticate(
|
||||
pub async fn authenticate_with_password(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
conn_info: &ConnInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
password: &[u8],
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let user_info = conn_info.user_info.clone();
|
||||
let user_info = user_info.clone();
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
@@ -47,7 +51,7 @@ impl PoolingBackend {
|
||||
}
|
||||
if !self
|
||||
.endpoint_rate_limiter
|
||||
.check(conn_info.user_info.endpoint.clone().into(), 1)
|
||||
.check(user_info.endpoint.clone().into(), 1)
|
||||
{
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
@@ -70,14 +74,10 @@ impl PoolingBackend {
|
||||
return Err(AuthError::auth_failed(&*user_info.user));
|
||||
}
|
||||
};
|
||||
let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
|
||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||
&config.thread_pool,
|
||||
ep,
|
||||
&conn_info.password,
|
||||
secret,
|
||||
)
|
||||
.await?;
|
||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||
let auth_outcome =
|
||||
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
|
||||
.await?;
|
||||
let res = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => {
|
||||
info!("user successfully authenticated");
|
||||
@@ -85,7 +85,7 @@ impl PoolingBackend {
|
||||
}
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
Err(AuthError::auth_failed(&*conn_info.user_info.user))
|
||||
Err(AuthError::auth_failed(&*user_info.user))
|
||||
}
|
||||
};
|
||||
res.map(|key| ComputeCredentials {
|
||||
@@ -94,6 +94,39 @@ impl PoolingBackend {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn authenticate_with_jwt(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
jwt: &str,
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
match &self.config.auth_backend {
|
||||
crate::auth::BackendType::Console(_, _) => {
|
||||
Err(AuthError::auth_failed("JWT login is not yet supported"))
|
||||
}
|
||||
crate::auth::BackendType::Link(_, _) => Err(AuthError::auth_failed(
|
||||
"JWT login over link proxy is not supported",
|
||||
)),
|
||||
crate::auth::BackendType::Local(cache) => {
|
||||
cache
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
user_info.user.clone(),
|
||||
&StaticAuthRules,
|
||||
jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wake up the destination if needed. Code here is a bit involved because
|
||||
// we reuse the code from the usual proxy and we need to prepare few structures
|
||||
// that this code expects.
|
||||
@@ -232,10 +265,16 @@ impl ConnectMechanism for TokioMechanism {
|
||||
let mut config = (*node_info.config).clone();
|
||||
let config = config
|
||||
.user(&self.conn_info.user_info.user)
|
||||
.password(&*self.conn_info.password)
|
||||
.dbname(&self.conn_info.dbname)
|
||||
.connect_timeout(timeout);
|
||||
|
||||
match &self.conn_info.auth {
|
||||
AuthData::Jwt(_) => {}
|
||||
AuthData::Password(pw) => {
|
||||
config.password(pw);
|
||||
}
|
||||
}
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let res = config.connect(tokio_postgres::NoTls).await;
|
||||
drop(pause);
|
||||
|
||||
@@ -33,7 +33,13 @@ use super::backend::HttpConnError;
|
||||
pub struct ConnInfo {
|
||||
pub user_info: ComputeUserInfo,
|
||||
pub dbname: DbName,
|
||||
pub password: SmallVec<[u8; 16]>,
|
||||
pub auth: AuthData,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AuthData {
|
||||
Password(SmallVec<[u8; 16]>),
|
||||
Jwt(String),
|
||||
}
|
||||
|
||||
impl ConnInfo {
|
||||
@@ -778,7 +784,7 @@ mod tests {
|
||||
options: Default::default(),
|
||||
},
|
||||
dbname: "dbname".into(),
|
||||
password: "password".as_bytes().into(),
|
||||
auth: AuthData::Password("password".as_bytes().into()),
|
||||
};
|
||||
let ep_pool = Arc::downgrade(
|
||||
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
|
||||
@@ -836,7 +842,7 @@ mod tests {
|
||||
options: Default::default(),
|
||||
},
|
||||
dbname: "dbname".into(),
|
||||
password: "password".as_bytes().into(),
|
||||
auth: AuthData::Password("password".as_bytes().into()),
|
||||
};
|
||||
let ep_pool = Arc::downgrade(
|
||||
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
|
||||
|
||||
@@ -7,6 +7,7 @@ use futures::future::try_join;
|
||||
use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use futures::TryFutureExt;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Body;
|
||||
@@ -56,6 +57,7 @@ use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::backend::PoolingBackend;
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool::Client;
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::http_util::json_response;
|
||||
@@ -88,6 +90,7 @@ enum Payload {
|
||||
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
|
||||
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
|
||||
|
||||
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
@@ -109,7 +112,7 @@ where
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static str),
|
||||
InvalidHeader(&'static HeaderName),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
@@ -153,10 +156,10 @@ fn get_conn_info(
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
let connection_string = headers
|
||||
.get("Neon-Connection-String")
|
||||
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
|
||||
.get(&CONN_STRING)
|
||||
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
|
||||
|
||||
let connection_url = Url::parse(connection_string)?;
|
||||
|
||||
@@ -179,10 +182,23 @@ fn get_conn_info(
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let password = connection_url
|
||||
.password()
|
||||
.ok_or(ConnInfoError::MissingPassword)?;
|
||||
let password = urlencoding::decode_binary(password.as_bytes());
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingPassword)?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingPassword);
|
||||
};
|
||||
|
||||
let endpoint = match connection_url.host() {
|
||||
Some(url::Host::Domain(hostname)) => {
|
||||
@@ -225,10 +241,7 @@ fn get_conn_info(
|
||||
Ok(ConnInfo {
|
||||
user_info,
|
||||
dbname,
|
||||
password: match password {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
},
|
||||
auth,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -550,9 +563,24 @@ async fn handle_inner(
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let keys = backend
|
||||
.authenticate(ctx, &config.authentication_config, &conn_info)
|
||||
.await?;
|
||||
let keys = match &conn_info.auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
.authenticate_with_password(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
pw,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
|
||||
Reference in New Issue
Block a user