mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 05:22:56 +00:00
subzero pre-integration refactor (#12416)
## Problem integrating subzero requires a bit of refactoring. To make the integration PR a bit more manageable, the refactoring is done in this separate PR. ## Summary of changes * move common types/functions used in sql_over_http to errors.rs and http_util.rs * add the "Local" auth backend to proxy (similar to local_proxy), useful in local testing * change the Connect and Send type for the http client to allow for custom body when making post requests to local_proxy from the proxy --------- Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
This commit is contained in:
@@ -115,7 +115,8 @@ impl PoolingBackend {
|
||||
|
||||
match &self.auth_backend {
|
||||
crate::auth::Backend::ControlPlane(console, ()) => {
|
||||
self.config
|
||||
let keys = self
|
||||
.config
|
||||
.authentication_config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
@@ -129,7 +130,7 @@ impl PoolingBackend {
|
||||
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
keys,
|
||||
})
|
||||
}
|
||||
crate::auth::Backend::Local(_) => {
|
||||
@@ -256,6 +257,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
disable_pg_session_jwt: bool,
|
||||
) -> Result<Client<postgres_client::Client>, HttpConnError> {
|
||||
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
|
||||
return Ok(client);
|
||||
@@ -277,7 +279,7 @@ impl PoolingBackend {
|
||||
.expect("semaphore should never be closed");
|
||||
|
||||
// check again for race
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
|
||||
local_backend
|
||||
.compute_ctl
|
||||
.install_extension(&ExtensionInstallRequest {
|
||||
@@ -313,14 +315,16 @@ impl PoolingBackend {
|
||||
.to_postgres_client_config();
|
||||
config
|
||||
.user(&conn_info.user_info.user)
|
||||
.dbname(&conn_info.dbname)
|
||||
.set_param(
|
||||
.dbname(&conn_info.dbname);
|
||||
if !disable_pg_session_jwt {
|
||||
config.set_param(
|
||||
"options",
|
||||
&format!(
|
||||
"-c pg_session_jwt.jwk={}",
|
||||
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
|
||||
@@ -345,9 +349,11 @@ impl PoolingBackend {
|
||||
debug!("setting up backend session state");
|
||||
|
||||
// initiates the auth session
|
||||
if let Err(e) = client.batch_execute("select auth.init();").await {
|
||||
discard.discard();
|
||||
return Err(e.into());
|
||||
if !disable_pg_session_jwt {
|
||||
if let Err(e) = client.batch_execute("select auth.init();").await {
|
||||
discard.discard();
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
|
||||
info!("backend session state initialized");
|
||||
|
||||
@@ -1,5 +1,93 @@
|
||||
use http::StatusCode;
|
||||
use http::header::HeaderName;
|
||||
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::ReadBodyError;
|
||||
|
||||
pub trait HttpCodeError {
|
||||
fn get_http_status_code(&self) -> StatusCode;
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static HeaderName),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
IncorrectScheme,
|
||||
#[error("missing database name")]
|
||||
MissingDbName,
|
||||
#[error("invalid database name")]
|
||||
InvalidDbName,
|
||||
#[error("missing username")]
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing authentication credentials: {0}")]
|
||||
MissingCredentials(Credentials),
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
InvalidEndpoint(#[from] ComputeUserInfoParseError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum Credentials {
|
||||
#[error("required password")]
|
||||
Password,
|
||||
#[error("required authorization bearer token in JWT format")]
|
||||
BearerJwt,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ReadPayloadError {
|
||||
#[error("could not read the HTTP request body: {0}")]
|
||||
Read(#[from] hyper::Error),
|
||||
#[error("request is too large (max is {limit} bytes)")]
|
||||
BodyTooLarge { limit: usize },
|
||||
#[error("could not parse the HTTP request body: {0}")]
|
||||
Parse(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
|
||||
fn from(value: ReadBodyError<hyper::Error>) -> Self {
|
||||
match value {
|
||||
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
|
||||
ReadBodyError::Read(e) => Self::Read(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReadPayloadError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
|
||||
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
|
||||
ReadPayloadError::Parse(_) => ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpCodeError for ReadPayloadError {
|
||||
fn get_http_status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
|
||||
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,9 +20,12 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
|
||||
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
|
||||
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
|
||||
pub(crate) type Connect =
|
||||
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ClientDataHttp();
|
||||
|
||||
@@ -3,11 +3,43 @@
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use http::{Response, StatusCode};
|
||||
use http::header::AUTHORIZATION;
|
||||
use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use http_utils::error::ApiError;
|
||||
use serde::Serialize;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool::ConnInfoWithAuth;
|
||||
use super::conn_pool_lib::ConnInfo;
|
||||
use super::error::{ConnInfoError, Credentials};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
|
||||
// Common header names used across serverless modules
|
||||
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
|
||||
pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
|
||||
HeaderName::from_static("neon-batch-isolation-level");
|
||||
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
|
||||
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
|
||||
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
|
||||
.expect("uuid hyphenated format should be all valid header characters")
|
||||
}
|
||||
|
||||
/// Like [`ApiError::into_response`]
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
@@ -107,3 +139,136 @@ pub(crate) fn json_response<T: Serialize>(
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub(crate) fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestContext,
|
||||
connection_string: Option<&str>,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<ConnInfoWithAuth, ConnInfoError> {
|
||||
let connection_url = match connection_string {
|
||||
Some(connection_string) => Url::parse(connection_string)?,
|
||||
None => {
|
||||
let connection_string = headers
|
||||
.get(&CONN_STRING)
|
||||
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
|
||||
Url::parse(connection_string)?
|
||||
}
|
||||
};
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname: DbName =
|
||||
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
|
||||
ctx.set_dbname(dbname.clone());
|
||||
|
||||
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
||||
if username.is_empty() {
|
||||
return Err(ConnInfoError::MissingUsername);
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
// TODO: make sure this is right in the context of rest broker
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
if !config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
}
|
||||
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
// wrong credentials provided
|
||||
if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
}
|
||||
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
};
|
||||
let endpoint: EndpointId = match connection_url.host() {
|
||||
Some(url::Host::Domain(hostname)) => hostname
|
||||
.split_once('.')
|
||||
.map_or(hostname, |(prefix, _)| prefix)
|
||||
.into(),
|
||||
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
|
||||
return Err(ConnInfoError::MissingHostname);
|
||||
}
|
||||
};
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParams::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
params.insert(&key, &value);
|
||||
if key == "options" {
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
}
|
||||
}
|
||||
|
||||
// check the URL that was used, for metrics
|
||||
{
|
||||
let host_endpoint = headers
|
||||
// get the host header
|
||||
.get("host")
|
||||
// extract the domain
|
||||
.and_then(|h| {
|
||||
let (host, _port) = h.to_str().ok()?.split_once(':')?;
|
||||
Some(host)
|
||||
})
|
||||
// get the endpoint prefix
|
||||
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
|
||||
|
||||
let kind = if host_endpoint == Some(&*endpoint) {
|
||||
SniKind::Sni
|
||||
} else {
|
||||
SniKind::NoSni
|
||||
};
|
||||
|
||||
let protocol = ctx.protocol();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.accepted_connections_by_sni
|
||||
.inc(SniGroup { protocol, kind });
|
||||
}
|
||||
|
||||
ctx.set_user_agent(
|
||||
headers
|
||||
.get(hyper::header::USER_AGENT)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(Into::into),
|
||||
);
|
||||
|
||||
let user_info = ComputeUserInfo {
|
||||
endpoint,
|
||||
user: username,
|
||||
options: options.unwrap_or_default(),
|
||||
};
|
||||
|
||||
let conn_info = ConnInfo { user_info, dbname };
|
||||
Ok(ConnInfoWithAuth { conn_info, auth })
|
||||
}
|
||||
|
||||
@@ -29,13 +29,13 @@ use futures::future::{Either, select};
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::body::Incoming;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
use sql_over_http::{NEON_REQUEST_ID, uuid_to_header_value};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::time::timeout;
|
||||
|
||||
@@ -1,49 +1,45 @@
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::future::{Either, select, try_join};
|
||||
use futures::{StreamExt, TryFutureExt};
|
||||
use http::Method;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use http::{Method, header::AUTHORIZATION};
|
||||
use http_body_util::{BodyExt, Full, combinators::BoxBody};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use hyper::{HeaderMap, Request, Response, StatusCode, header};
|
||||
use hyper::{
|
||||
Request, Response, StatusCode, header,
|
||||
http::{HeaderName, HeaderValue},
|
||||
};
|
||||
use indexmap::IndexMap;
|
||||
use postgres_client::error::{DbError, ErrorPosition, SqlState};
|
||||
use postgres_client::{
|
||||
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
use serde_json::{Value, value::RawValue};
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{self, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Level, debug, error, info};
|
||||
use typed_json::json;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::backend::{LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::{AuthData, ConnInfoWithAuth};
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool_lib::{self, ConnInfo};
|
||||
use super::error::HttpCodeError;
|
||||
use super::http_util::json_response;
|
||||
use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
|
||||
use super::http_util::{
|
||||
ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
|
||||
TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
|
||||
};
|
||||
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig};
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
|
||||
use crate::config::{HttpConfig, ProxyConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{ReadBodyError, read_body_with_limit};
|
||||
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::http::read_body_with_limit;
|
||||
use crate::metrics::{HttpDirection, Metrics};
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
@@ -70,16 +66,6 @@ enum Payload {
|
||||
Batch(BatchQueryData),
|
||||
}
|
||||
|
||||
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
|
||||
|
||||
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
|
||||
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
|
||||
@@ -91,179 +77,6 @@ where
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static HeaderName),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
IncorrectScheme,
|
||||
#[error("missing database name")]
|
||||
MissingDbName,
|
||||
#[error("invalid database name")]
|
||||
InvalidDbName,
|
||||
#[error("missing username")]
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing authentication credentials: {0}")]
|
||||
MissingCredentials(Credentials),
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
InvalidEndpoint(#[from] ComputeUserInfoParseError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum Credentials {
|
||||
#[error("required password")]
|
||||
Password,
|
||||
#[error("required authorization bearer token in JWT format")]
|
||||
BearerJwt,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestContext,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<ConnInfoWithAuth, ConnInfoError> {
|
||||
let connection_string = headers
|
||||
.get(&CONN_STRING)
|
||||
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
|
||||
|
||||
let connection_url = Url::parse(connection_string)?;
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname: DbName =
|
||||
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
|
||||
ctx.set_dbname(dbname.clone());
|
||||
|
||||
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
||||
if username.is_empty() {
|
||||
return Err(ConnInfoError::MissingUsername);
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
if !config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
}
|
||||
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
// wrong credentials provided
|
||||
if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
}
|
||||
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
};
|
||||
|
||||
let endpoint: EndpointId = match connection_url.host() {
|
||||
Some(url::Host::Domain(hostname)) => hostname
|
||||
.split_once('.')
|
||||
.map_or(hostname, |(prefix, _)| prefix)
|
||||
.into(),
|
||||
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
|
||||
return Err(ConnInfoError::MissingHostname);
|
||||
}
|
||||
};
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParams::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
params.insert(&key, &value);
|
||||
if key == "options" {
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
}
|
||||
}
|
||||
|
||||
// check the URL that was used, for metrics
|
||||
{
|
||||
let host_endpoint = headers
|
||||
// get the host header
|
||||
.get("host")
|
||||
// extract the domain
|
||||
.and_then(|h| {
|
||||
let (host, _port) = h.to_str().ok()?.split_once(':')?;
|
||||
Some(host)
|
||||
})
|
||||
// get the endpoint prefix
|
||||
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
|
||||
|
||||
let kind = if host_endpoint == Some(&*endpoint) {
|
||||
SniKind::Sni
|
||||
} else {
|
||||
SniKind::NoSni
|
||||
};
|
||||
|
||||
let protocol = ctx.protocol();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.accepted_connections_by_sni
|
||||
.inc(SniGroup { protocol, kind });
|
||||
}
|
||||
|
||||
ctx.set_user_agent(
|
||||
headers
|
||||
.get(hyper::header::USER_AGENT)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(Into::into),
|
||||
);
|
||||
|
||||
let user_info = ComputeUserInfo {
|
||||
endpoint,
|
||||
user: username,
|
||||
options: options.unwrap_or_default(),
|
||||
};
|
||||
|
||||
let conn_info = ConnInfo { user_info, dbname };
|
||||
Ok(ConnInfoWithAuth { conn_info, auth })
|
||||
}
|
||||
|
||||
pub(crate) async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestContext,
|
||||
@@ -532,45 +345,6 @@ impl HttpCodeError for SqlOverHttpError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ReadPayloadError {
|
||||
#[error("could not read the HTTP request body: {0}")]
|
||||
Read(#[from] hyper::Error),
|
||||
#[error("request is too large (max is {limit} bytes)")]
|
||||
BodyTooLarge { limit: usize },
|
||||
#[error("could not parse the HTTP request body: {0}")]
|
||||
Parse(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
|
||||
fn from(value: ReadBodyError<hyper::Error>) -> Self {
|
||||
match value {
|
||||
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
|
||||
ReadBodyError::Read(e) => Self::Read(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReadPayloadError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
|
||||
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
|
||||
ReadPayloadError::Parse(_) => ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpCodeError for ReadPayloadError {
|
||||
fn get_http_status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
|
||||
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum SqlOverHttpCancel {
|
||||
#[error("query was cancelled")]
|
||||
@@ -661,7 +435,7 @@ async fn handle_inner(
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let conn_info = get_conn_info(&config.authentication_config, ctx, request.headers())?;
|
||||
let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?;
|
||||
info!(
|
||||
user = conn_info.conn_info.user_info.user.as_str(),
|
||||
"credentials"
|
||||
@@ -747,9 +521,17 @@ async fn handle_db_inner(
|
||||
ComputeCredentialKeys::JwtPayload(payload)
|
||||
if backend.auth_backend.is_local_proxy() =>
|
||||
{
|
||||
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
|
||||
let (cli_inner, _dsc) = client.client_inner();
|
||||
cli_inner.set_jwt_session(&payload).await?;
|
||||
#[cfg(feature = "testing")]
|
||||
let disable_pg_session_jwt = config.disable_pg_session_jwt;
|
||||
#[cfg(not(feature = "testing"))]
|
||||
let disable_pg_session_jwt = false;
|
||||
let mut client = backend
|
||||
.connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt)
|
||||
.await?;
|
||||
if !disable_pg_session_jwt {
|
||||
let (cli_inner, _dsc) = client.client_inner();
|
||||
cli_inner.set_jwt_session(&payload).await?;
|
||||
}
|
||||
Client::Local(client)
|
||||
}
|
||||
_ => {
|
||||
@@ -848,12 +630,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&TXN_DEFERRABLE,
|
||||
];
|
||||
|
||||
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
|
||||
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
|
||||
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
|
||||
.expect("uuid hyphenated format should be all valid header characters")
|
||||
}
|
||||
|
||||
async fn handle_auth_broker_inner(
|
||||
ctx: &RequestContext,
|
||||
request: Request<Incoming>,
|
||||
@@ -883,7 +659,7 @@ async fn handle_auth_broker_inner(
|
||||
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
|
||||
|
||||
let req = req
|
||||
.body(body)
|
||||
.body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here?
|
||||
.expect("all headers and params received via hyper should be valid for request");
|
||||
|
||||
// todo: map body to count egress
|
||||
|
||||
Reference in New Issue
Block a user