mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-04 12:02:55 +00:00
move common error types and http realted functions to error.rs and http_util.rs
This commit is contained in:
@@ -1,5 +1,98 @@
|
||||
use http::StatusCode;
|
||||
use http::header::HeaderName;
|
||||
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::ReadBodyError;
|
||||
|
||||
pub trait HttpCodeError {
|
||||
fn get_http_status_code(&self) -> StatusCode;
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static HeaderName),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
IncorrectScheme,
|
||||
#[error("missing database name")]
|
||||
MissingDbName,
|
||||
#[error("invalid database name")]
|
||||
InvalidDbName,
|
||||
#[error("missing username")]
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing authentication credentials: {0}")]
|
||||
MissingCredentials(Credentials),
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
InvalidEndpoint(#[from] ComputeUserInfoParseError),
|
||||
#[error("malformed endpoint")]
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum Credentials {
|
||||
#[error("required password")]
|
||||
Password,
|
||||
#[error("required authorization bearer token in JWT format")]
|
||||
BearerJwt,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ReadPayloadError {
|
||||
#[error("could not read the HTTP request body: {0}")]
|
||||
Read(#[from] hyper::Error),
|
||||
#[error("request is too large (max is {limit} bytes)")]
|
||||
BodyTooLarge { limit: usize },
|
||||
#[error("could not parse the HTTP request body: {0}")]
|
||||
Parse(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
|
||||
fn from(value: ReadBodyError<hyper::Error>) -> Self {
|
||||
match value {
|
||||
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
|
||||
ReadBodyError::Read(e) => Self::Read(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReadPayloadError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
|
||||
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
|
||||
ReadPayloadError::Parse(_) => ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpCodeError for ReadPayloadError {
|
||||
fn get_http_status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
|
||||
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,11 +3,28 @@
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use http::{Response, StatusCode};
|
||||
use http::{Response, StatusCode, HeaderName, HeaderValue};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use http_utils::error::ApiError;
|
||||
use serde::Serialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
// Common header names used across serverless modules
|
||||
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
|
||||
pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
pub(super) static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
|
||||
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
|
||||
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
|
||||
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
|
||||
.expect("uuid hyphenated format should be all valid header characters")
|
||||
}
|
||||
|
||||
/// Like [`ApiError::into_response`]
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
|
||||
@@ -36,7 +36,7 @@ use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
use sql_over_http::{NEON_REQUEST_ID, uuid_to_header_value};
|
||||
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::time::timeout;
|
||||
|
||||
@@ -13,25 +13,22 @@ use indexmap::IndexMap;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use super::http_conn_pool::{self, Send,};
|
||||
use serde_json::{value::RawValue, Value as JsonValue};
|
||||
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info};
|
||||
use typed_json::json;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::backend::{LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::{AuthData, ConnInfoWithAuth};
|
||||
use super::conn_pool_lib::{ConnInfo};
|
||||
use super::error::HttpCodeError;
|
||||
use super::http_util::json_response;
|
||||
use super::error::{HttpCodeError, ConnInfoError, Credentials, ReadPayloadError};
|
||||
use super::http_util::{json_response, uuid_to_header_value, NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY};
|
||||
use super::json::{JsonConversionError};
|
||||
use crate::auth::backend::{ComputeUserInfo, ComputeCredentialKeys};
|
||||
use crate::auth::{ComputeUserInfoParseError, endpoint_sni, };
|
||||
use crate::auth::{endpoint_sni, };
|
||||
use crate::config::{AuthenticationConfig, ProxyConfig, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{ReadBodyError, read_body_with_limit};
|
||||
use crate::http::{read_body_with_limit};
|
||||
use crate::metrics::{Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
@@ -39,48 +36,41 @@ use crate::serverless::backend::HttpConnError;
|
||||
use crate::types::{DbName, RoleName};
|
||||
use crate::cache::{TimedLru};
|
||||
use crate::types::{EndpointCacheKey};
|
||||
|
||||
use ouroboros::self_referencing;
|
||||
use std::collections::HashMap;
|
||||
use jsonpath_lib::select;
|
||||
use url::form_urlencoded;
|
||||
use subzero_core::{
|
||||
api::{SingleVal, ListVal, Payload},
|
||||
error::Error::{self as SubzeroCoreError, JsonDeserialize, NotFound, JwtTokenInvalid, InternalError, GucHeadersError, GucStatusError, ContentTypeError},
|
||||
schema::{DbSchema},
|
||||
formatter::{
|
||||
Param,
|
||||
Param::*,
|
||||
postgresql::{fmt_main_query, generate},
|
||||
Snippet, SqlParam,
|
||||
api::{ApiResponse, SingleVal, ListVal, Payload, ContentType::*, Preferences, QueryNode::*, Representation, Resolution::*,},
|
||||
error::Error::{
|
||||
self as SubzeroCoreError, JsonDeserialize, NotFound, JwtTokenInvalid, InternalError, GucHeadersError, GucStatusError, ContentTypeError,
|
||||
},
|
||||
error::{pg_error_to_status_code},
|
||||
schema::{DbSchema},
|
||||
formatter::{Param, Param::*, postgresql::{fmt_main_query, generate}, Snippet, SqlParam},
|
||||
dynamic_statement::{param, sql, JoinIterator},
|
||||
config::{db_schemas, db_allowed_select_functions, role_claim_key, /*to_tuple*/},
|
||||
};
|
||||
use subzero_core::{
|
||||
api::{ContentType::*, Preferences, QueryNode::*, Representation, Resolution::*, },
|
||||
error::{*, pg_error_to_status_code},
|
||||
parser::postgrest::parse,
|
||||
permissions::{check_safe_functions},
|
||||
api::ApiResponse,
|
||||
};
|
||||
use ouroboros::self_referencing;
|
||||
|
||||
static MAX_SCHEMA_SIZE: usize = 1024 * 1024 * 5; // 5MB
|
||||
static MAX_HTTP_BODY_SIZE: usize = 10 * 1024 * 1024; // 10MB limit
|
||||
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
|
||||
const INTROSPECTION_SQL: &str = include_str!("../../../../subzero/introspection/postgresql_introspection_query.sql");
|
||||
const CONFIGURATION_SQL: &str = include_str!("../../../../subzero/introspection/postgresql_configuration_query.sql");
|
||||
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&AUTHORIZATION,
|
||||
];
|
||||
|
||||
// A wrapper around the DbSchema that allows for self-referencing
|
||||
#[self_referencing]
|
||||
pub struct DbSchemaOwned {
|
||||
schema_string: String,
|
||||
#[covariant]
|
||||
#[borrows(schema_string)]
|
||||
schema: Result<DbSchema<'this>>,
|
||||
schema: Result<DbSchema<'this>, SubzeroCoreError>,
|
||||
}
|
||||
use std::collections::HashMap;
|
||||
use jsonpath_lib::select;
|
||||
use url::form_urlencoded;
|
||||
|
||||
static MAX_SCHEMA_SIZE: usize = 1024 * 1024 * 5; // 5MB
|
||||
static MAX_HTTP_BODY_SIZE: usize = 10 * 1024 * 1024; // 10MB limit
|
||||
|
||||
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
|
||||
const INTROSPECTION_SQL: &str = include_str!("../../../../subzero/introspection/postgresql_introspection_query.sql");
|
||||
const CONFIGURATION_SQL: &str = include_str!("../../../../subzero/introspection/postgresql_configuration_query.sql");
|
||||
|
||||
|
||||
fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
@@ -91,6 +81,8 @@ where
|
||||
.collect())
|
||||
}
|
||||
|
||||
// The ApiConfig is the configuration for the API per endpoint
|
||||
// The configuration is read from the database and cached in the DbSchemaCache
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ApiConfig {
|
||||
#[serde(default = "db_schemas", deserialize_with = "deserialize_comma_separated")]
|
||||
@@ -105,9 +97,10 @@ pub struct ApiConfig {
|
||||
pub role_claim_key: String,
|
||||
pub db_extra_search_path: Option<String>,
|
||||
}
|
||||
|
||||
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
|
||||
pub(crate) type DbSchemaCache = TimedLru<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>;
|
||||
impl DbSchemaCache {
|
||||
|
||||
pub async fn get_local_or_remote(&self,
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
auth_header: &HeaderValue,
|
||||
@@ -329,64 +322,6 @@ impl DbSchemaCache {
|
||||
}
|
||||
}
|
||||
}
|
||||
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
|
||||
|
||||
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
//static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
|
||||
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
//static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
//static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static HeaderName),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
IncorrectScheme,
|
||||
#[error("missing database name")]
|
||||
MissingDbName,
|
||||
#[error("invalid database name")]
|
||||
InvalidDbName,
|
||||
#[error("missing username")]
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing authentication credentials: {0}")]
|
||||
MissingCredentials(Credentials),
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
InvalidEndpoint(#[from] ComputeUserInfoParseError),
|
||||
#[error("malformed endpoint")]
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum Credentials {
|
||||
#[error("required password")]
|
||||
Password,
|
||||
#[error("required authorization bearer token in JWT format")]
|
||||
BearerJwt,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -524,8 +459,8 @@ fn get_conn_info(
|
||||
Ok(ConnInfoWithAuth { conn_info, auth })
|
||||
}
|
||||
|
||||
|
||||
// we use our own type because we get the error from the json response
|
||||
// A type to represent a postgres errors
|
||||
// we use our own type (instead of postgres_client::Error) because we get the error from the json response
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) struct PostgresError {
|
||||
pub code: String,
|
||||
@@ -553,15 +488,13 @@ impl UserFacingError for PostgresError {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PostgresError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// A type to represent errors that can occur in the rest broker
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum RestError {
|
||||
#[error("{0}")]
|
||||
@@ -570,66 +503,37 @@ pub(crate) enum RestError {
|
||||
ConnectCompute(#[from] HttpConnError),
|
||||
#[error("{0}")]
|
||||
ConnInfo(#[from] ConnInfoError),
|
||||
//#[error("response is too large (max is {0} bytes)")]
|
||||
//ResponseTooLarge(usize),
|
||||
//#[error("invalid isolation level")]
|
||||
//InvalidIsolationLevel,
|
||||
/// for queries our customers choose to run
|
||||
//#[error("{0}")]
|
||||
//Postgres(#[source] postgres_client::Error),
|
||||
#[error("{0}")]
|
||||
Postgres(#[source] PostgresError),
|
||||
/// for queries we choose to run
|
||||
//#[error("{0}")]
|
||||
//InternalPostgres(#[source] postgres_client::Error),
|
||||
#[error("{0}")]
|
||||
JsonConversion(#[from] JsonConversionError),
|
||||
//#[error("{0}")]
|
||||
//Cancelled(SqlOverHttpCancel),
|
||||
#[error("{0}")]
|
||||
SubzeroCore(#[source] SubzeroCoreError),
|
||||
|
||||
#[error("schema is too large (max is {0} bytes, current is {1} bytes)")]
|
||||
SchemaTooLarge(usize, usize),
|
||||
}
|
||||
|
||||
impl ReportableError for RestError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
RestError::ReadPayload(e) => e.get_error_kind(),
|
||||
RestError::ConnectCompute(e) => e.get_error_kind(),
|
||||
RestError::ConnInfo(e) => e.get_error_kind(),
|
||||
//RestError::ResponseTooLarge(_) => ErrorKind::User,
|
||||
//RestError::InvalidIsolationLevel => ErrorKind::User,
|
||||
RestError::Postgres(_) => ErrorKind::Postgres,
|
||||
// RestError::InternalPostgres(p) => {
|
||||
// if p.as_db_error().is_some() {
|
||||
// ErrorKind::Service
|
||||
// } else {
|
||||
// ErrorKind::Compute
|
||||
// }
|
||||
// }
|
||||
RestError::JsonConversion(_) => ErrorKind::Postgres,
|
||||
//RestError::Cancelled(c) => c.get_error_kind(),
|
||||
RestError::SubzeroCore(_) => ErrorKind::User,
|
||||
RestError::SchemaTooLarge(_, _) => ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for RestError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
RestError::ReadPayload(p) => p.to_string(),
|
||||
RestError::ConnectCompute(c) => c.to_string_client(),
|
||||
RestError::ConnInfo(c) => c.to_string_client(),
|
||||
//RestError::ResponseTooLarge(_) => self.to_string(),
|
||||
RestError::SchemaTooLarge(_, _) => self.to_string(),
|
||||
//RestError::InvalidIsolationLevel => self.to_string(),
|
||||
RestError::Postgres(p) => p.to_string_client(),
|
||||
//RestError::InternalPostgres(p) => p.to_string(),
|
||||
RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
|
||||
//RestError::Cancelled(_) => self.to_string(),
|
||||
RestError::SubzeroCore(s) => {
|
||||
// TODO: this is a hack to get the message from the json body
|
||||
let json = s.json_body();
|
||||
@@ -645,7 +549,6 @@ impl UserFacingError for RestError {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpCodeError for RestError {
|
||||
fn get_http_status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
@@ -655,13 +558,9 @@ impl HttpCodeError for RestError {
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
},
|
||||
RestError::ConnInfo(_) => StatusCode::BAD_REQUEST,
|
||||
//RestError::ResponseTooLarge(_) => StatusCode::INSUFFICIENT_STORAGE,
|
||||
//RestError::InvalidIsolationLevel => StatusCode::BAD_REQUEST,
|
||||
RestError::Postgres(e) => e.get_http_status_code(),
|
||||
//RestError::InternalPostgres(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
RestError::SchemaTooLarge(_, _) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
//RestError::Cancelled(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
RestError::SubzeroCore(e) => {
|
||||
let status = e.status_code();
|
||||
StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
@@ -670,58 +569,7 @@ impl HttpCodeError for RestError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ReadPayloadError {
|
||||
#[error("could not read the HTTP request body: {0}")]
|
||||
Read(#[from] hyper::Error),
|
||||
#[error("request is too large (max is {limit} bytes)")]
|
||||
BodyTooLarge { limit: usize },
|
||||
#[error("could not parse the HTTP request body: {0}")]
|
||||
Parse(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
|
||||
fn from(value: ReadBodyError<hyper::Error>) -> Self {
|
||||
match value {
|
||||
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
|
||||
ReadBodyError::Read(e) => Self::Read(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReadPayloadError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
|
||||
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
|
||||
ReadPayloadError::Parse(_) => ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpCodeError for ReadPayloadError {
|
||||
fn get_http_status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
|
||||
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
|
||||
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
|
||||
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
|
||||
.expect("uuid hyphenated format should be all valid header characters")
|
||||
}
|
||||
|
||||
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&AUTHORIZATION,
|
||||
];
|
||||
|
||||
|
||||
// Helper functions for the rest broker
|
||||
|
||||
fn content_range_header(lower: i64, upper: i64, total: Option<i64>) -> String {
|
||||
//debug!("content_range_header: lower: {}, upper: {}, total: {:?}", lower, upper, total);
|
||||
@@ -746,6 +594,7 @@ fn content_range_status(lower: i64, upper: i64, total: Option<i64>) -> u16 {
|
||||
_ => 200,
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
|
||||
"select "
|
||||
+ if env.is_empty() {
|
||||
@@ -803,6 +652,7 @@ fn current_schema(db_schemas: &Vec<String>, method: &Method, headers: &HeaderMap
|
||||
|
||||
fn to_core_error(e: SubzeroCoreError) -> RestError { RestError::SubzeroCore(e) }
|
||||
|
||||
// TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
|
||||
fn to_sql_param(p: &Param) -> JsonValue {
|
||||
match p {
|
||||
SV(SingleVal(v, ..)) => {
|
||||
@@ -833,6 +683,12 @@ fn to_sql_param(p: &Param) -> JsonValue {
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_string(json: &mut serde_json::Value, key: &str) -> Option<String> {
|
||||
match json[key].take() {
|
||||
JsonValue::String(s) => Some(s),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
@@ -974,7 +830,6 @@ pub(crate) async fn handle(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
|
||||
async fn handle_inner(
|
||||
_cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
@@ -1022,15 +877,6 @@ async fn handle_inner(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Helper function to extract optional string from JSON
|
||||
fn extract_string(json: &mut serde_json::Value, key: &str) -> Option<String> {
|
||||
match json[key].take() {
|
||||
JsonValue::String(s) => Some(s),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_rest_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestContext,
|
||||
@@ -1504,7 +1350,6 @@ async fn handle_rest_inner(
|
||||
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
//use super::*;
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::future::{Either, select, try_join};
|
||||
use futures::{StreamExt, TryFutureExt};
|
||||
@@ -25,20 +24,18 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info};
|
||||
use typed_json::json;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::backend::{LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::{AuthData, ConnInfoWithAuth};
|
||||
use super::conn_pool_lib::{self, ConnInfo};
|
||||
use super::error::HttpCodeError;
|
||||
use super::http_util::json_response;
|
||||
use super::error::{HttpCodeError, ConnInfoError, Credentials, ReadPayloadError};
|
||||
use super::http_util::{json_response, uuid_to_header_value, NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ARRAY_MODE, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, TXN_DEFERRABLE};
|
||||
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::{ComputeUserInfoParseError, endpoint_sni};
|
||||
use crate::auth::{endpoint_sni};
|
||||
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{ReadBodyError, read_body_with_limit};
|
||||
use crate::http::{read_body_with_limit};
|
||||
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
@@ -70,16 +67,6 @@ enum Payload {
|
||||
Batch(BatchQueryData),
|
||||
}
|
||||
|
||||
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
|
||||
|
||||
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
|
||||
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
|
||||
@@ -91,52 +78,6 @@ where
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static HeaderName),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
IncorrectScheme,
|
||||
#[error("missing database name")]
|
||||
MissingDbName,
|
||||
#[error("invalid database name")]
|
||||
InvalidDbName,
|
||||
#[error("missing username")]
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing authentication credentials: {0}")]
|
||||
MissingCredentials(Credentials),
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
InvalidEndpoint(#[from] ComputeUserInfoParseError),
|
||||
#[error("malformed endpoint")]
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum Credentials {
|
||||
#[error("required password")]
|
||||
Password,
|
||||
#[error("required authorization bearer token in JWT format")]
|
||||
BearerJwt,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestContext,
|
||||
@@ -509,44 +450,6 @@ impl HttpCodeError for SqlOverHttpError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ReadPayloadError {
|
||||
#[error("could not read the HTTP request body: {0}")]
|
||||
Read(#[from] hyper::Error),
|
||||
#[error("request is too large (max is {limit} bytes)")]
|
||||
BodyTooLarge { limit: usize },
|
||||
#[error("could not parse the HTTP request body: {0}")]
|
||||
Parse(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
|
||||
fn from(value: ReadBodyError<hyper::Error>) -> Self {
|
||||
match value {
|
||||
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
|
||||
ReadBodyError::Read(e) => Self::Read(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReadPayloadError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
|
||||
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
|
||||
ReadPayloadError::Parse(_) => ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpCodeError for ReadPayloadError {
|
||||
fn get_http_status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
|
||||
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum SqlOverHttpCancel {
|
||||
@@ -834,11 +737,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&TXN_DEFERRABLE,
|
||||
];
|
||||
|
||||
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
|
||||
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
|
||||
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
|
||||
.expect("uuid hyphenated format should be all valid header characters")
|
||||
}
|
||||
|
||||
async fn handle_auth_broker_inner(
|
||||
ctx: &RequestContext,
|
||||
|
||||
Reference in New Issue
Block a user