mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-06 13:02:55 +00:00
cleanup the rest path code
This commit is contained in:
@@ -3,13 +3,28 @@
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use http::{Response, StatusCode, HeaderName, HeaderValue};
|
||||
use http::{Response, StatusCode, HeaderName, HeaderValue, HeaderMap};
|
||||
use http::header::AUTHORIZATION;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use http_utils::error::ApiError;
|
||||
use serde::Serialize;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::config::{AuthenticationConfig, TlsConfig};
|
||||
use crate::auth::backend::{ComputeUserInfo};
|
||||
use crate::auth::{endpoint_sni, };
|
||||
use crate::metrics::{Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
use super::conn_pool::ConnInfoWithAuth;
|
||||
use super::error::{ConnInfoError, Credentials};
|
||||
use crate::types::{DbName, RoleName};
|
||||
use super::conn_pool::{AuthData};
|
||||
use super::conn_pool_lib::{ConnInfo};
|
||||
|
||||
// Common header names used across serverless modules
|
||||
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
|
||||
pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
@@ -124,3 +139,144 @@ pub(crate) fn json_response<T: Serialize>(
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
|
||||
pub(crate) fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestContext,
|
||||
connection_string: Option<&str>,
|
||||
headers: &HeaderMap,
|
||||
tls: Option<&TlsConfig>,
|
||||
) -> Result<ConnInfoWithAuth, ConnInfoError> {
|
||||
let connection_url = match connection_string {
|
||||
Some(connection_string) => Url::parse(connection_string)?,
|
||||
None => {
|
||||
let connection_string = headers
|
||||
.get(&CONN_STRING)
|
||||
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
|
||||
Url::parse(connection_string)?
|
||||
}
|
||||
};
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname: DbName =
|
||||
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
|
||||
ctx.set_dbname(dbname.clone());
|
||||
|
||||
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
||||
if username.is_empty() {
|
||||
return Err(ConnInfoError::MissingUsername);
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
// TODO: make sure this is right in the context of rest broker
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
if !config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
}
|
||||
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
// wrong credentials provided
|
||||
if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
}
|
||||
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
};
|
||||
let endpoint = match connection_url.host() {
|
||||
Some(url::Host::Domain(hostname)) => {
|
||||
if let Some(tls) = tls {
|
||||
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
|
||||
} else {
|
||||
hostname
|
||||
.split_once('.')
|
||||
.map_or(hostname, |(prefix, _)| prefix)
|
||||
.into()
|
||||
}
|
||||
}
|
||||
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
|
||||
return Err(ConnInfoError::MissingHostname);
|
||||
}
|
||||
};
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParams::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
params.insert(&key, &value);
|
||||
if key == "options" {
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
}
|
||||
}
|
||||
|
||||
// check the URL that was used, for metrics
|
||||
{
|
||||
let host_endpoint = headers
|
||||
// get the host header
|
||||
.get("host")
|
||||
// extract the domain
|
||||
.and_then(|h| {
|
||||
let (host, _port) = h.to_str().ok()?.split_once(':')?;
|
||||
Some(host)
|
||||
})
|
||||
// get the endpoint prefix
|
||||
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
|
||||
|
||||
let kind = if host_endpoint == Some(&*endpoint) {
|
||||
SniKind::Sni
|
||||
} else {
|
||||
SniKind::NoSni
|
||||
};
|
||||
|
||||
let protocol = ctx.protocol();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.accepted_connections_by_sni
|
||||
.inc(SniGroup { protocol, kind });
|
||||
}
|
||||
|
||||
ctx.set_user_agent(
|
||||
headers
|
||||
.get(hyper::header::USER_AGENT)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(Into::into),
|
||||
);
|
||||
|
||||
let user_info = ComputeUserInfo {
|
||||
endpoint,
|
||||
user: username,
|
||||
options: options.unwrap_or_default(),
|
||||
};
|
||||
|
||||
let conn_info = ConnInfo { user_info, dbname };
|
||||
Ok(ConnInfoWithAuth { conn_info, auth })
|
||||
}
|
||||
@@ -2,13 +2,9 @@ use std::sync::Arc;
|
||||
use bytes::Bytes;
|
||||
use http::Method;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::Full;
|
||||
use http_body_util::{BodyExt};
|
||||
use http_body_util::{combinators::BoxBody, Full, BodyExt};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use hyper::{HeaderMap, Request, Response, StatusCode};
|
||||
use hyper::{body::Incoming, http::{HeaderName, HeaderValue}, Request, Response, StatusCode};
|
||||
use indexmap::IndexMap;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use super::http_conn_pool::{self, Send,};
|
||||
@@ -16,24 +12,22 @@ use serde_json::{value::RawValue, Value as JsonValue};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info};
|
||||
use typed_json::json;
|
||||
use url::Url;
|
||||
use super::backend::{LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::{AuthData, ConnInfoWithAuth};
|
||||
use super::conn_pool::{AuthData};
|
||||
use super::conn_pool_lib::{ConnInfo};
|
||||
use super::error::{HttpCodeError, ConnInfoError, Credentials, ReadPayloadError};
|
||||
use super::http_util::{json_response, uuid_to_header_value, NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY};
|
||||
use super::http_util::{
|
||||
json_response, uuid_to_header_value, get_conn_info,
|
||||
NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY
|
||||
};
|
||||
use super::json::{JsonConversionError};
|
||||
use crate::auth::backend::{ComputeUserInfo, ComputeCredentialKeys};
|
||||
use crate::auth::{endpoint_sni, };
|
||||
use crate::config::{AuthenticationConfig, ProxyConfig, TlsConfig};
|
||||
use crate::auth::backend::{ComputeCredentialKeys};
|
||||
use crate::config::{ProxyConfig, };
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{read_body_with_limit};
|
||||
use crate::metrics::{Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::types::{DbName, RoleName};
|
||||
use crate::metrics::{Metrics, };
|
||||
use super::backend::HttpConnError;
|
||||
use crate::cache::{TimedLru};
|
||||
use crate::types::{EndpointCacheKey};
|
||||
use ouroboros::self_referencing;
|
||||
@@ -52,6 +46,7 @@ use subzero_core::{
|
||||
config::{db_schemas, db_allowed_select_functions, role_claim_key, /*to_tuple*/},
|
||||
parser::postgrest::parse,
|
||||
permissions::{check_safe_functions},
|
||||
content_range_header, content_range_status
|
||||
};
|
||||
|
||||
static MAX_SCHEMA_SIZE: usize = 1024 * 1024 * 5; // 5MB
|
||||
@@ -59,9 +54,6 @@ static MAX_HTTP_BODY_SIZE: usize = 10 * 1024 * 1024; // 10MB limit
|
||||
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
|
||||
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
|
||||
const CONFIGURATION_SQL: &str = POSTGRESQL_CONFIGURATION_SQL;
|
||||
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&AUTHORIZATION,
|
||||
];
|
||||
|
||||
// A wrapper around the DbSchema that allows for self-referencing
|
||||
#[self_referencing]
|
||||
@@ -101,7 +93,7 @@ pub struct ApiConfig {
|
||||
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
|
||||
pub(crate) type DbSchemaCache = TimedLru<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>;
|
||||
impl DbSchemaCache {
|
||||
pub async fn get_local_or_remote(&self,
|
||||
pub async fn get_cached_or_remote(&self,
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
@@ -153,61 +145,29 @@ impl DbSchemaCache {
|
||||
ctx: &RequestContext,
|
||||
) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
|
||||
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
|
||||
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
|
||||
req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap());
|
||||
req = req.header(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap());
|
||||
req = req.header(AUTHORIZATION, auth_header);
|
||||
//req = req.header(&ARRAY_MODE, HeaderValue::from_str("true").unwrap());
|
||||
req = req.header(&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap());
|
||||
let headers = vec![
|
||||
(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
|
||||
(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()),
|
||||
(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()),
|
||||
(&AUTHORIZATION, auth_header.clone()),
|
||||
(&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap()),
|
||||
];
|
||||
|
||||
let body = json!({"query": CONFIGURATION_SQL}).to_string();
|
||||
let body_boxed = Full::new(Bytes::from(body))
|
||||
.map_err(|never| match never {}) // Convert Infallible to hyper::Error
|
||||
.boxed();
|
||||
let req = req
|
||||
.body(body_boxed)
|
||||
.map_err(|_| RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to build request".to_string()
|
||||
}))?;
|
||||
let body = serde_json::json!({"query": CONFIGURATION_SQL});
|
||||
let (response_status, mut response_json) = make_local_proxy_request(client, headers, body).await?;
|
||||
|
||||
// send the request to the local proxy
|
||||
let response = client
|
||||
.inner
|
||||
.inner
|
||||
.send_request(req)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::from)
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let response_status = response.status();
|
||||
|
||||
if response_status != StatusCode::OK {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to get configuration data".to_string()
|
||||
message: "Failed to get endpoint configuration".to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
// Capture the response body
|
||||
let response_body = response
|
||||
.collect()
|
||||
.await
|
||||
.map_err(ReadPayloadError::from)?
|
||||
.to_bytes();
|
||||
|
||||
//println!("response_body: {:?}", response_body);
|
||||
|
||||
let mut response_json: serde_json::Value = serde_json::from_slice(&response_body)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
|
||||
//println!("response_json: {:?}", response_json);
|
||||
let rows = response_json["rows"].as_array_mut()
|
||||
.ok_or_else(|| RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'rows' array in second result".to_string()
|
||||
}))?;
|
||||
|
||||
//println!("rows: {:?}", rows);
|
||||
|
||||
if rows.is_empty() {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "No rows in second result".to_string()
|
||||
@@ -217,74 +177,41 @@ impl DbSchemaCache {
|
||||
// Extract columns from the first (and only) row
|
||||
let mut row = &mut rows[0];
|
||||
let config_string = extract_string(&mut row, "config").unwrap_or_default();
|
||||
//println!("config_string: {:?}", config_string);
|
||||
// Parse the JSON response and extract the body content efficiently
|
||||
|
||||
// Parse the configuration response
|
||||
let api_config: ApiConfig = serde_json::from_str(&config_string)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
|
||||
//println!("api_config: {:?}", api_config);
|
||||
|
||||
// now that we have the api_config let's run the second INTROSPECTION_SQL query
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
|
||||
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
|
||||
req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap());
|
||||
req = req.header(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap());
|
||||
req = req.header(AUTHORIZATION, auth_header);
|
||||
//req = req.header(&ARRAY_MODE, HeaderValue::from_str("true").unwrap());
|
||||
req = req.header(&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap());
|
||||
let body = json!({
|
||||
let headers = vec![
|
||||
(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
|
||||
(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()),
|
||||
(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()),
|
||||
(&AUTHORIZATION, auth_header.clone()),
|
||||
(&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap()),
|
||||
];
|
||||
|
||||
let body = serde_json::json!({
|
||||
"query": INTROSPECTION_SQL,
|
||||
"params": [
|
||||
&api_config.db_schemas,
|
||||
false, // include_roles_with_login
|
||||
false, // use_internal_permissions
|
||||
]
|
||||
}).to_string();
|
||||
let body_boxed = Full::new(Bytes::from(body))
|
||||
.map_err(|never| match never {}) // Convert Infallible to hyper::Error
|
||||
.boxed();
|
||||
let req = req
|
||||
.body(body_boxed)
|
||||
.map_err(|_| RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to build request".to_string()
|
||||
}))?;
|
||||
|
||||
// send the request to the local proxy
|
||||
let response = client
|
||||
.inner
|
||||
.inner
|
||||
.send_request(req)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::from)
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let response_status = response.status();
|
||||
});
|
||||
let (response_status, mut response_json) = make_local_proxy_request(client, headers, body).await?;
|
||||
|
||||
if response_status != StatusCode::OK {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to get introspection data".to_string()
|
||||
message: "Failed to get endpoint schema".to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let response_body = response
|
||||
.collect()
|
||||
.await
|
||||
.map_err(ReadPayloadError::from)?
|
||||
.to_bytes();
|
||||
|
||||
//println!("second response_body: {:?}", response_body);
|
||||
|
||||
let mut response_json: serde_json::Value = serde_json::from_slice(&response_body)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
|
||||
//println!("response_json: {:?}", response_json);
|
||||
let rows = response_json["rows"].as_array_mut()
|
||||
.ok_or_else(|| RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'rows' array in second result".to_string()
|
||||
}))?;
|
||||
|
||||
//println!("rows: {:?}", rows);
|
||||
if rows.is_empty() {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "No rows in second result".to_string()
|
||||
@@ -297,21 +224,14 @@ impl DbSchemaCache {
|
||||
let string_size = json_schema.len();
|
||||
|
||||
if string_size > MAX_SCHEMA_SIZE {
|
||||
// return Err(RestError::SubzeroCore(InternalError {
|
||||
// message: format!("Schema is too large: {} bytes, max is {} bytes", string_size, MAX_SCHEMA_SIZE)
|
||||
// }));
|
||||
return Err(RestError::SchemaTooLarge(MAX_SCHEMA_SIZE, string_size));
|
||||
}
|
||||
|
||||
|
||||
|
||||
let schema_owned = DbSchemaOwned::new(json_schema, |s| {
|
||||
serde_json::from_str::<DbSchema>(s.as_str())
|
||||
.map_err(|e| JsonDeserialize { source: e })
|
||||
});
|
||||
|
||||
//let schema = schema_owned.borrow_schema().as_ref().unwrap();
|
||||
//println!("schema!!!!!: {:?}", schema);
|
||||
// check if schema is an ok result
|
||||
let schema = schema_owned.borrow_schema();
|
||||
if schema.is_ok() {
|
||||
@@ -323,143 +243,8 @@ impl DbSchemaCache {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestContext,
|
||||
connection_string: &str,
|
||||
headers: &HeaderMap,
|
||||
tls: Option<&TlsConfig>,
|
||||
) -> Result<ConnInfoWithAuth, ConnInfoError> {
|
||||
// let connection_string = headers
|
||||
// .get(&CONN_STRING)
|
||||
// .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
|
||||
// .to_str()
|
||||
// .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
|
||||
|
||||
let connection_url = Url::parse(connection_string)?;
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname: DbName =
|
||||
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
|
||||
ctx.set_dbname(dbname.clone());
|
||||
|
||||
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
||||
if username.is_empty() {
|
||||
return Err(ConnInfoError::MissingUsername);
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
// TODO: make sure this is right in the context of rest broker
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
if !config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
}
|
||||
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
// wrong credentials provided
|
||||
if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
}
|
||||
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
};
|
||||
let endpoint = match connection_url.host() {
|
||||
Some(url::Host::Domain(hostname)) => {
|
||||
if let Some(tls) = tls {
|
||||
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
|
||||
} else {
|
||||
hostname
|
||||
.split_once('.')
|
||||
.map_or(hostname, |(prefix, _)| prefix)
|
||||
.into()
|
||||
}
|
||||
}
|
||||
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
|
||||
return Err(ConnInfoError::MissingHostname);
|
||||
}
|
||||
};
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParams::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
params.insert(&key, &value);
|
||||
if key == "options" {
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
}
|
||||
}
|
||||
|
||||
// check the URL that was used, for metrics
|
||||
{
|
||||
let host_endpoint = headers
|
||||
// get the host header
|
||||
.get("host")
|
||||
// extract the domain
|
||||
.and_then(|h| {
|
||||
let (host, _port) = h.to_str().ok()?.split_once(':')?;
|
||||
Some(host)
|
||||
})
|
||||
// get the endpoint prefix
|
||||
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
|
||||
|
||||
let kind = if host_endpoint == Some(&*endpoint) {
|
||||
SniKind::Sni
|
||||
} else {
|
||||
SniKind::NoSni
|
||||
};
|
||||
|
||||
let protocol = ctx.protocol();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.accepted_connections_by_sni
|
||||
.inc(SniGroup { protocol, kind });
|
||||
}
|
||||
|
||||
ctx.set_user_agent(
|
||||
headers
|
||||
.get(hyper::header::USER_AGENT)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(Into::into),
|
||||
);
|
||||
|
||||
let user_info = ComputeUserInfo {
|
||||
endpoint,
|
||||
user: username,
|
||||
options: options.unwrap_or_default(),
|
||||
};
|
||||
|
||||
let conn_info = ConnInfo { user_info, dbname };
|
||||
Ok(ConnInfoWithAuth { conn_info, auth })
|
||||
}
|
||||
|
||||
// A type to represent a postgres errors
|
||||
// A type to represent a postgresql errors
|
||||
// we use our own type (instead of postgres_client::Error) because we get the error from the json response
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) struct PostgresError {
|
||||
@@ -568,33 +353,14 @@ impl HttpCodeError for RestError {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for the rest broker
|
||||
|
||||
fn content_range_header(lower: i64, upper: i64, total: Option<i64>) -> String {
|
||||
//debug!("content_range_header: lower: {}, upper: {}, total: {:?}", lower, upper, total);
|
||||
let range_string = if total != Some(0) && lower <= upper {
|
||||
format!("{lower}-{upper}")
|
||||
} else {
|
||||
"*".to_string()
|
||||
};
|
||||
let total_string = match total {
|
||||
Some(t) => format!("{t}"),
|
||||
None => "*".to_string(),
|
||||
};
|
||||
format!("{range_string}/{total_string}")
|
||||
}
|
||||
|
||||
fn content_range_status(lower: i64, upper: i64, total: Option<i64>) -> u16 {
|
||||
//debug!("content_range_status: lower: {}, upper: {}, total: {:?}", lower, upper, total);
|
||||
match (lower, upper, total) {
|
||||
//(_, _, None) => 200,
|
||||
(l, _, Some(t)) if l > t => 406,
|
||||
(l, u, Some(t)) if (1 + u - l) < t => 206,
|
||||
_ => 200,
|
||||
impl From<SubzeroCoreError> for RestError {
|
||||
fn from(e: SubzeroCoreError) -> Self {
|
||||
RestError::SubzeroCore(e)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for the rest broker
|
||||
|
||||
fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
|
||||
"select "
|
||||
+ if env.is_empty() {
|
||||
@@ -606,52 +372,6 @@ fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn current_schema(db_schemas: &Vec<String>, method: &Method, headers: &HeaderMap) -> Result<String, SubzeroCoreError> {
|
||||
match (db_schemas.len() > 1, method, headers.get("accept-profile"), headers.get("content-profile")) {
|
||||
(false, ..) => Ok(db_schemas.first().unwrap_or(&"_inexistent_".to_string()).clone()),
|
||||
(_, &Method::DELETE, _, Some(content_profile_header))
|
||||
| (_, &Method::POST, _, Some(content_profile_header))
|
||||
| (_, &Method::PATCH, _, Some(content_profile_header))
|
||||
| (_, &Method::PUT, _, Some(content_profile_header)) => {
|
||||
match content_profile_header.to_str() {
|
||||
Ok(content_profile_str) => {
|
||||
let content_profile = String::from(content_profile_str);
|
||||
if db_schemas.contains(&content_profile) {
|
||||
Ok(content_profile)
|
||||
} else {
|
||||
Err(SubzeroCoreError::UnacceptableSchema {
|
||||
schemas: db_schemas.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(_) => Err(SubzeroCoreError::UnacceptableSchema {
|
||||
schemas: db_schemas.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
(_, _, Some(accept_profile_header), _) => {
|
||||
match accept_profile_header.to_str() {
|
||||
Ok(accept_profile_str) => {
|
||||
let accept_profile = String::from(accept_profile_str);
|
||||
if db_schemas.contains(&accept_profile) {
|
||||
Ok(accept_profile)
|
||||
} else {
|
||||
Err(SubzeroCoreError::UnacceptableSchema {
|
||||
schemas: db_schemas.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(_) => Err(SubzeroCoreError::UnacceptableSchema {
|
||||
schemas: db_schemas.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
_ => Ok(db_schemas.first().unwrap_or(&"_inexistent_".to_string()).clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_core_error(e: SubzeroCoreError) -> RestError { RestError::SubzeroCore(e) }
|
||||
|
||||
// TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
|
||||
fn to_sql_param(p: &Param) -> JsonValue {
|
||||
match p {
|
||||
@@ -690,6 +410,55 @@ fn extract_string(json: &mut serde_json::Value, key: &str) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn make_local_proxy_request(
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
headers: Vec<(&HeaderName, HeaderValue)>,
|
||||
body: JsonValue,
|
||||
) -> Result<(StatusCode, JsonValue), RestError> {
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
|
||||
let req_headers = req.headers_mut().unwrap();
|
||||
// Add all provided headers to the request
|
||||
for (header_name, header_value) in headers.into_iter() {
|
||||
req_headers.insert(header_name, header_value);
|
||||
}
|
||||
|
||||
let body_string = body.to_string();
|
||||
let body_boxed = Full::new(Bytes::from(body_string))
|
||||
.map_err(|never| match never {}) // Convert Infallible to hyper::Error
|
||||
.boxed();
|
||||
|
||||
let req = req
|
||||
.body(body_boxed)
|
||||
.map_err(|_| RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to build request".to_string()
|
||||
}))?;
|
||||
|
||||
// Send the request to the local proxy
|
||||
let response = client
|
||||
.inner
|
||||
.inner
|
||||
.send_request(req)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::from)
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let response_status = response.status();
|
||||
|
||||
// Capture the response body
|
||||
let response_body = response
|
||||
.collect()
|
||||
.await
|
||||
.map_err(ReadPayloadError::from)?
|
||||
.to_bytes();
|
||||
|
||||
// Parse the JSON response
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&response_body)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
|
||||
Ok((response_status, response_json))
|
||||
}
|
||||
|
||||
pub(crate) async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestContext,
|
||||
@@ -856,7 +625,7 @@ async fn handle_inner(
|
||||
let conn_info = get_conn_info(
|
||||
&config.authentication_config,
|
||||
ctx,
|
||||
&connection_string,
|
||||
Some(&connection_string),
|
||||
request.headers(),
|
||||
// todo: race condition?
|
||||
// we're unlikely to change the common names.
|
||||
@@ -910,7 +679,7 @@ async fn handle_rest_inner(
|
||||
let (parts, originial_body) = request.into_parts();
|
||||
let headers_map = parts.headers;
|
||||
let auth_header = headers_map.get(AUTHORIZATION).unwrap();
|
||||
let entry = db_schema_cache.get_local_or_remote(&endpoint_cache_key, auth_header, &connection_string, &mut client, &ctx).await?;
|
||||
let entry = db_schema_cache.get_cached_or_remote(&endpoint_cache_key, auth_header, &connection_string, &mut client, &ctx).await?;
|
||||
let (api_config, db_schema_owned) = entry.as_ref();
|
||||
let db_schema = db_schema_owned.borrow_schema().as_ref().map_err(|_| RestError::SubzeroCore(InternalError { message: "Failed to get schema".to_string() }))?;
|
||||
|
||||
@@ -927,8 +696,7 @@ async fn handle_rest_inner(
|
||||
let db_allowed_select_functions = api_config.db_allowed_select_functions.iter().map(|s| s.as_str()).collect::<Vec<_>>();
|
||||
// end hardcoded values
|
||||
|
||||
|
||||
|
||||
|
||||
// extract the jwt claims (we'll need them later to set the role and env)
|
||||
let jwt_claims = match jwt_parsed.keys {
|
||||
ComputeCredentialKeys::JwtPayload(payload_bytes) => {
|
||||
@@ -978,12 +746,13 @@ async fn handle_rest_inner(
|
||||
}?;
|
||||
|
||||
// pick the current schema from the headers (or the first one from config)
|
||||
let schema_name = ¤t_schema(db_schemas, &method, &headers_map).map_err(RestError::SubzeroCore)?;
|
||||
//let schema_name = ¤t_schema(db_schemas, &method, &headers_map).map_err(RestError::SubzeroCore)?;
|
||||
let schema_name = db_schema.pick_current_schema(&method_str, &headers_map).map_err(RestError::SubzeroCore)?;
|
||||
|
||||
// add the content-profile header to the response
|
||||
let mut response_headers = vec![];
|
||||
if db_schemas.len() > 1 {
|
||||
response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
|
||||
response_headers.push(("Content-Profile".to_string(), schema_name.to_string()));
|
||||
}
|
||||
|
||||
// parse the query string into a Vec<(&str, &str)>
|
||||
@@ -994,17 +763,6 @@ async fn handle_rest_inner(
|
||||
let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
|
||||
|
||||
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
|
||||
|
||||
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
|
||||
// these instead just to ensure they remain normalised.
|
||||
for &h in HEADERS_TO_FORWARD {
|
||||
if let Some(hv) = headers_map.get(h) {
|
||||
req = req.header(h, hv);
|
||||
}
|
||||
}
|
||||
|
||||
// convert the headers map to a HashMap<&str, &str>
|
||||
let headers: HashMap<&str, &str> = headers_map.iter()
|
||||
.map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
|
||||
@@ -1028,7 +786,7 @@ async fn handle_rest_inner(
|
||||
|
||||
// replace "*" with the list of columns the user has access to
|
||||
// so that he does not encounter permission errors
|
||||
// replace_select_star(db_schema, schema_name, role, &mut api_request.query).map_err(to_core_error)?;
|
||||
// replace_select_star(db_schema, schema_name, role, &mut api_request.query)?;
|
||||
|
||||
let role_str = match role {
|
||||
Some(r) => r,
|
||||
@@ -1036,14 +794,14 @@ async fn handle_rest_inner(
|
||||
};
|
||||
// this is not relevant when acting as PostgREST
|
||||
// if !disable_internal_permissions {
|
||||
// check_privileges(db_schema, schema_name, role_str, &api_request).map_err(to_core_error)?;
|
||||
// check_privileges(db_schema, schema_name, role_str, &api_request)?;
|
||||
// }
|
||||
|
||||
check_safe_functions(&api_request, &db_allowed_select_functions).map_err(to_core_error)?;
|
||||
check_safe_functions(&api_request, &db_allowed_select_functions)?;
|
||||
|
||||
// this is not relevant when acting as PostgREST
|
||||
// if !disable_internal_permissions {
|
||||
// insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query).map_err(to_core_error)?;
|
||||
// insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
|
||||
// }
|
||||
|
||||
// when using internal privileges not switch "current_role"
|
||||
@@ -1089,20 +847,24 @@ async fn handle_rest_inner(
|
||||
}
|
||||
// generate the sql statements
|
||||
let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
|
||||
let (main_statement, main_parameters, _) = generate(fmt_main_query(db_schema, api_request.schema_name, &api_request, &env).map_err(to_core_error)?);
|
||||
let (main_statement, main_parameters, _) = generate(fmt_main_query(db_schema, api_request.schema_name, &api_request, &env)?);
|
||||
|
||||
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
|
||||
req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap());
|
||||
req = req.header(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap());
|
||||
req = req.header(&ALLOW_POOL, HeaderValue::from_str("true").unwrap());
|
||||
let mut headers = vec![
|
||||
(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
|
||||
(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()),
|
||||
(&AUTHORIZATION, auth_header.clone()),
|
||||
(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()),
|
||||
(&ALLOW_POOL, HeaderValue::from_str("true").unwrap()),
|
||||
];
|
||||
|
||||
if api_request.read_only {
|
||||
req = req.header(&TXN_READ_ONLY, HeaderValue::from_str("true").unwrap());
|
||||
headers.push((&TXN_READ_ONLY, HeaderValue::from_str("true").unwrap()));
|
||||
}
|
||||
|
||||
// convert the parameters from subzero core representation to a Vec<JsonValue>
|
||||
let env_parameters_json = env_parameters.iter().map(|p| to_sql_param(&p.to_param())).collect::<Vec<_>>();
|
||||
let main_parameters_json = main_parameters.iter().map(|p| to_sql_param(&p.to_param())).collect::<Vec<_>>();
|
||||
let body: String = json!({
|
||||
let body = serde_json::json!({
|
||||
"queries": [
|
||||
{
|
||||
"query": env_statement,
|
||||
@@ -1113,42 +875,13 @@ async fn handle_rest_inner(
|
||||
"params": main_parameters_json,
|
||||
}
|
||||
]
|
||||
}).to_string();
|
||||
|
||||
let body_boxed = Full::new(Bytes::from(body))
|
||||
.map_err(|never| match never {}) // Convert Infallible to hyper::Error
|
||||
.boxed();
|
||||
|
||||
let req = req
|
||||
.body(body_boxed)
|
||||
.map_err(|_| RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to build request".to_string()
|
||||
}))?;
|
||||
});
|
||||
|
||||
// todo: map body to count egress
|
||||
let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
|
||||
|
||||
// send the request to the local proxy
|
||||
let response = client
|
||||
.inner
|
||||
.inner
|
||||
.send_request(req)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::from)
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let response_status = response.status();
|
||||
|
||||
// Capture the response body
|
||||
let response_body = response
|
||||
.collect()
|
||||
.await
|
||||
.map_err(ReadPayloadError::from)?
|
||||
.to_bytes();
|
||||
|
||||
// Parse the JSON response and extract the body content efficiently
|
||||
let mut response_json: serde_json::Value = serde_json::from_slice(&response_body)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
let (response_status, mut response_json) = make_local_proxy_request(&mut client, headers, body).await?;
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,44 +3,40 @@ use std::sync::Arc;
|
||||
use bytes::Bytes;
|
||||
use futures::future::{Either, select, try_join};
|
||||
use futures::{StreamExt, TryFutureExt};
|
||||
use http::Method;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use http::{Method, header::AUTHORIZATION};
|
||||
use http_body_util::{combinators::BoxBody, Full, BodyExt};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use hyper::{HeaderMap, Request, Response, StatusCode, header};
|
||||
use hyper::{http::{HeaderName, HeaderValue}, Request, Response, StatusCode, header};
|
||||
use indexmap::IndexMap;
|
||||
use postgres_client::error::{DbError, ErrorPosition, SqlState};
|
||||
use postgres_client::{
|
||||
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
use serde_json::{Value, value::RawValue};
|
||||
use tokio::time::{self, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info};
|
||||
use typed_json::json;
|
||||
use url::Url;
|
||||
|
||||
use super::backend::{LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::{AuthData, ConnInfoWithAuth};
|
||||
use super::conn_pool::{AuthData,};
|
||||
use super::conn_pool_lib::{self, ConnInfo};
|
||||
use super::error::{HttpCodeError, ConnInfoError, Credentials, ReadPayloadError};
|
||||
use super::http_util::{json_response, uuid_to_header_value, NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ARRAY_MODE, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, TXN_DEFERRABLE};
|
||||
use super::error::{HttpCodeError, ConnInfoError, ReadPayloadError};
|
||||
use super::http_util::{
|
||||
json_response, uuid_to_header_value, get_conn_info,
|
||||
NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ARRAY_MODE, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, TXN_DEFERRABLE
|
||||
};
|
||||
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::{endpoint_sni};
|
||||
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
|
||||
use crate::auth::backend::{ComputeCredentialKeys,};
|
||||
|
||||
use crate::config::{HttpConfig, ProxyConfig,};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{read_body_with_limit};
|
||||
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::metrics::{HttpDirection, Metrics, };
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::types::{DbName, RoleName};
|
||||
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
@@ -78,142 +74,6 @@ where
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestContext,
|
||||
headers: &HeaderMap,
|
||||
tls: Option<&TlsConfig>,
|
||||
) -> Result<ConnInfoWithAuth, ConnInfoError> {
|
||||
let connection_string = headers
|
||||
.get(&CONN_STRING)
|
||||
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
|
||||
|
||||
let connection_url = Url::parse(connection_string)?;
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname: DbName =
|
||||
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
|
||||
ctx.set_dbname(dbname.clone());
|
||||
|
||||
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
||||
if username.is_empty() {
|
||||
return Err(ConnInfoError::MissingUsername);
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
if !config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
}
|
||||
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
// wrong credentials provided
|
||||
if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
}
|
||||
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
};
|
||||
|
||||
let endpoint = match connection_url.host() {
|
||||
Some(url::Host::Domain(hostname)) => {
|
||||
if let Some(tls) = tls {
|
||||
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
|
||||
} else {
|
||||
hostname
|
||||
.split_once('.')
|
||||
.map_or(hostname, |(prefix, _)| prefix)
|
||||
.into()
|
||||
}
|
||||
}
|
||||
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
|
||||
return Err(ConnInfoError::MissingHostname);
|
||||
}
|
||||
};
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParams::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
params.insert(&key, &value);
|
||||
if key == "options" {
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
}
|
||||
}
|
||||
|
||||
// check the URL that was used, for metrics
|
||||
{
|
||||
let host_endpoint = headers
|
||||
// get the host header
|
||||
.get("host")
|
||||
// extract the domain
|
||||
.and_then(|h| {
|
||||
let (host, _port) = h.to_str().ok()?.split_once(':')?;
|
||||
Some(host)
|
||||
})
|
||||
// get the endpoint prefix
|
||||
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
|
||||
|
||||
let kind = if host_endpoint == Some(&*endpoint) {
|
||||
SniKind::Sni
|
||||
} else {
|
||||
SniKind::NoSni
|
||||
};
|
||||
|
||||
let protocol = ctx.protocol();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.accepted_connections_by_sni
|
||||
.inc(SniGroup { protocol, kind });
|
||||
}
|
||||
|
||||
ctx.set_user_agent(
|
||||
headers
|
||||
.get(hyper::header::USER_AGENT)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(Into::into),
|
||||
);
|
||||
|
||||
let user_info = ComputeUserInfo {
|
||||
endpoint,
|
||||
user: username,
|
||||
options: options.unwrap_or_default(),
|
||||
};
|
||||
|
||||
let conn_info = ConnInfo { user_info, dbname };
|
||||
Ok(ConnInfoWithAuth { conn_info, auth })
|
||||
}
|
||||
|
||||
pub(crate) async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestContext,
|
||||
@@ -544,6 +404,7 @@ async fn handle_inner(
|
||||
let conn_info = get_conn_info(
|
||||
&config.authentication_config,
|
||||
ctx,
|
||||
None,
|
||||
request.headers(),
|
||||
// todo: race condition?
|
||||
// we're unlikely to change the common names.
|
||||
|
||||
Reference in New Issue
Block a user