mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
apply cargo fmt
This commit is contained in:
@@ -24,7 +24,7 @@ use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
|
||||
use crate::auth::{self};
|
||||
use crate::cancellation::CancellationHandler;
|
||||
use crate::config::{
|
||||
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, RestConfig,
|
||||
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RestConfig, RetryConfig,
|
||||
};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
|
||||
|
||||
@@ -27,7 +27,7 @@ use crate::batch::BatchQueue;
|
||||
use crate::cancellation::{CancellationHandler, CancellationProcessor};
|
||||
use crate::config::{
|
||||
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
|
||||
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml, RestConfig,
|
||||
ProxyConfig, ProxyProtocolV2, RestConfig, remote_storage_from_toml,
|
||||
};
|
||||
use crate::context::parquet::ParquetUploadArgs;
|
||||
use crate::http::health_server::AppMetrics;
|
||||
@@ -500,12 +500,11 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
if let Some(db_schema_cache) = &config.rest_config.db_schema_cache {
|
||||
maintenance_tasks.spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(600)).await;
|
||||
db_schema_cache.flush();
|
||||
}
|
||||
tokio::time::sleep(Duration::from_secs(600)).await;
|
||||
db_schema_cache.flush();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
if let Some(metrics_config) = &config.metric_collection {
|
||||
// TODO: Add gc regardles of the metric collection being enabled.
|
||||
|
||||
2
proxy/src/cache/timed_lru.rs
vendored
2
proxy/src/cache/timed_lru.rs
vendored
@@ -228,7 +228,7 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
|
||||
pub(crate) fn flush(&self) {
|
||||
let now = Instant::now();
|
||||
let mut cache = self.cache.lock();
|
||||
|
||||
|
||||
// Collect keys of expired entries first
|
||||
let expired_keys: Vec<_> = cache
|
||||
.iter()
|
||||
|
||||
@@ -13,9 +13,9 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::serverless::GlobalConnPoolOptions;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
use crate::serverless::rest::DbSchemaCache;
|
||||
pub use crate::tls::server_config::{TlsConfig, configure_tls};
|
||||
use crate::types::Host;
|
||||
use crate::serverless::rest::DbSchemaCache;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: ArcSwapOption<TlsConfig>,
|
||||
|
||||
@@ -396,7 +396,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
.parse()
|
||||
.expect("url is valid"),
|
||||
audience: None,
|
||||
role_names: vec![(&RoleName::from("authenticator")).into(), (&RoleName::from("authenticated")).into(), (&RoleName::from("anon")).into()],
|
||||
role_names: vec![
|
||||
(&RoleName::from("authenticator")).into(),
|
||||
(&RoleName::from("authenticated")).into(),
|
||||
(&RoleName::from("anon")).into(),
|
||||
],
|
||||
}]);
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::time::Duration;
|
||||
|
||||
use clashmap::ClashMap;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug};
|
||||
use tracing::debug;
|
||||
|
||||
use super::{EndpointAccessControl, RoleAccessControl};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
@@ -229,7 +229,8 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
|
||||
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
|
||||
// therefore releasing it is safe from race conditions
|
||||
debug!( //FIXME: is anything depending on this being info?
|
||||
debug!(
|
||||
//FIXME: is anything depending on this being info?
|
||||
name = self.name,
|
||||
shard = i,
|
||||
"performing epoch reclamation on api lock"
|
||||
|
||||
@@ -115,7 +115,8 @@ impl PoolingBackend {
|
||||
|
||||
match &self.auth_backend {
|
||||
crate::auth::Backend::ControlPlane(console, ()) => {
|
||||
let keys = self.config
|
||||
let keys = self
|
||||
.config
|
||||
.authentication_config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
@@ -357,7 +358,6 @@ impl PoolingBackend {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
info!("backend session state initialized");
|
||||
}
|
||||
|
||||
|
||||
@@ -93,6 +93,3 @@ impl HttpCodeError for ReadPayloadError {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -20,11 +20,12 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
|
||||
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
|
||||
pub(crate) type Connect =
|
||||
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ClientDataHttp();
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use http::{Response, StatusCode, HeaderName, HeaderValue, HeaderMap};
|
||||
use http::header::AUTHORIZATION;
|
||||
use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use http_utils::error::ApiError;
|
||||
@@ -12,18 +12,18 @@ use serde::Serialize;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool::ConnInfoWithAuth;
|
||||
use super::conn_pool_lib::ConnInfo;
|
||||
use super::error::{ConnInfoError, Credentials};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::config::{AuthenticationConfig, TlsConfig};
|
||||
use crate::auth::backend::{ComputeUserInfo};
|
||||
use crate::auth::{endpoint_sni, };
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
use super::conn_pool::ConnInfoWithAuth;
|
||||
use super::error::{ConnInfoError, Credentials};
|
||||
use crate::types::{DbName, RoleName};
|
||||
use super::conn_pool::{AuthData};
|
||||
use super::conn_pool_lib::{ConnInfo};
|
||||
|
||||
// Common header names used across serverless modules
|
||||
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
|
||||
@@ -31,7 +31,8 @@ pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connec
|
||||
pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
pub(super) static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
|
||||
pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
|
||||
HeaderName::from_static("neon-batch-isolation-level");
|
||||
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
@@ -140,7 +141,6 @@ pub(crate) fn json_response<T: Serialize>(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
|
||||
pub(crate) fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestContext,
|
||||
@@ -148,7 +148,7 @@ pub(crate) fn get_conn_info(
|
||||
headers: &HeaderMap,
|
||||
tls: Option<&TlsConfig>,
|
||||
) -> Result<ConnInfoWithAuth, ConnInfoError> {
|
||||
let connection_url = match connection_string {
|
||||
let connection_url = match connection_string {
|
||||
Some(connection_string) => Url::parse(connection_string)?,
|
||||
None => {
|
||||
let connection_string = headers
|
||||
@@ -159,7 +159,7 @@ pub(crate) fn get_conn_info(
|
||||
Url::parse(connection_string)?
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
@@ -279,4 +279,4 @@ pub(crate) fn get_conn_info(
|
||||
|
||||
let conn_info = ConnInfo { user_info, dbname };
|
||||
Ok(ConnInfoWithAuth { conn_info, auth })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ mod http_conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod local_conn_pool;
|
||||
mod sql_over_http;
|
||||
pub mod rest;
|
||||
mod sql_over_http;
|
||||
mod websocket;
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
@@ -30,13 +30,13 @@ use futures::future::{Either, select};
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::body::Incoming;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::time::timeout;
|
||||
@@ -497,7 +497,7 @@ async fn request_handler(
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else if config.rest_config.is_rest_broker && request.uri().path().starts_with("/rest") {
|
||||
} else if config.rest_config.is_rest_broker && request.uri().path().starts_with("/rest") {
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
conn_info,
|
||||
@@ -520,7 +520,6 @@ async fn request_handler(
|
||||
rest::handle(config, ctx, request, backend, http_cancellation_token)
|
||||
.instrument(span)
|
||||
.await
|
||||
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
}
|
||||
|
||||
@@ -1,53 +1,66 @@
|
||||
use std::sync::Arc;
|
||||
use super::backend::HttpConnError;
|
||||
use super::backend::{LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool_lib::ConnInfo;
|
||||
use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
|
||||
use super::http_conn_pool::{self, Send};
|
||||
use super::http_util::{
|
||||
ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
|
||||
get_conn_info, json_response, uuid_to_header_value,
|
||||
};
|
||||
use super::json::JsonConversionError;
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::cache::TimedLru;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::read_body_with_limit;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::types::EndpointCacheKey;
|
||||
use bytes::Bytes;
|
||||
use http::Method;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http_body_util::{combinators::BoxBody, Full, BodyExt};
|
||||
use http_body_util::{BodyExt, Full, combinators::BoxBody};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::{body::Incoming, http::{HeaderName, HeaderValue}, Request, Response, StatusCode};
|
||||
use hyper::{
|
||||
Request, Response, StatusCode,
|
||||
body::Incoming,
|
||||
http::{HeaderName, HeaderValue},
|
||||
};
|
||||
use indexmap::IndexMap;
|
||||
use jsonpath_lib::select;
|
||||
use ouroboros::self_referencing;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use super::http_conn_pool::{self, Send,};
|
||||
use serde_json::{value::RawValue, Value as JsonValue};
|
||||
use serde_json::{Value as JsonValue, value::RawValue};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use subzero_core::{
|
||||
api::{
|
||||
ApiResponse, ContentType::*, ListVal, Payload, Preferences, QueryNode::*, Representation,
|
||||
Resolution::*, SingleVal,
|
||||
},
|
||||
config::{db_allowed_select_functions, db_schemas, role_claim_key /*to_tuple*/},
|
||||
content_range_header, content_range_status,
|
||||
dynamic_statement::{JoinIterator, param, sql},
|
||||
error::Error::{
|
||||
self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError,
|
||||
JsonDeserialize, JwtTokenInvalid, NotFound,
|
||||
},
|
||||
error::pg_error_to_status_code,
|
||||
formatter::{
|
||||
Param,
|
||||
Param::*,
|
||||
Snippet, SqlParam,
|
||||
postgresql::{fmt_main_query, generate},
|
||||
},
|
||||
parser::postgrest::parse,
|
||||
permissions::check_safe_functions,
|
||||
schema::{DbSchema, POSTGRESQL_CONFIGURATION_SQL, POSTGRESQL_INTROSPECTION_SQL},
|
||||
};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info};
|
||||
use typed_json::json;
|
||||
use super::backend::{LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::{AuthData};
|
||||
use super::conn_pool_lib::{ConnInfo};
|
||||
use super::error::{HttpCodeError, ConnInfoError, Credentials, ReadPayloadError};
|
||||
use super::http_util::{
|
||||
json_response, uuid_to_header_value, get_conn_info,
|
||||
NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY
|
||||
};
|
||||
use super::json::{JsonConversionError};
|
||||
use crate::auth::backend::{ComputeCredentialKeys};
|
||||
use crate::config::{ProxyConfig, };
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{read_body_with_limit};
|
||||
use crate::metrics::{Metrics, };
|
||||
use super::backend::HttpConnError;
|
||||
use crate::cache::{TimedLru};
|
||||
use crate::types::{EndpointCacheKey};
|
||||
use ouroboros::self_referencing;
|
||||
use std::collections::HashMap;
|
||||
use jsonpath_lib::select;
|
||||
use url::form_urlencoded;
|
||||
use subzero_core::{
|
||||
api::{ApiResponse, SingleVal, ListVal, Payload, ContentType::*, Preferences, QueryNode::*, Representation, Resolution::*,},
|
||||
error::Error::{
|
||||
self as SubzeroCoreError, JsonDeserialize, NotFound, JwtTokenInvalid, InternalError, GucHeadersError, GucStatusError, ContentTypeError,
|
||||
},
|
||||
error::{pg_error_to_status_code},
|
||||
schema::{DbSchema, POSTGRESQL_INTROSPECTION_SQL, POSTGRESQL_CONFIGURATION_SQL},
|
||||
formatter::{Param, Param::*, postgresql::{fmt_main_query, generate}, Snippet, SqlParam},
|
||||
dynamic_statement::{param, sql, JoinIterator},
|
||||
config::{db_schemas, db_allowed_select_functions, role_claim_key, /*to_tuple*/},
|
||||
parser::postgrest::parse,
|
||||
permissions::{check_safe_functions},
|
||||
content_range_header, content_range_status
|
||||
};
|
||||
|
||||
static MAX_SCHEMA_SIZE: usize = 1024 * 1024 * 5; // 5MB
|
||||
static MAX_HTTP_BODY_SIZE: usize = 10 * 1024 * 1024; // 10MB limit
|
||||
@@ -68,16 +81,17 @@ where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
Ok(s.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.collect())
|
||||
Ok(s.split(',').map(|s| s.trim().to_string()).collect())
|
||||
}
|
||||
|
||||
// The ApiConfig is the configuration for the API per endpoint
|
||||
// The configuration is read from the database and cached in the DbSchemaCache
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ApiConfig {
|
||||
#[serde(default = "db_schemas", deserialize_with = "deserialize_comma_separated")]
|
||||
#[serde(
|
||||
default = "db_schemas",
|
||||
deserialize_with = "deserialize_comma_separated"
|
||||
)]
|
||||
pub db_schemas: Vec<String>,
|
||||
pub db_anon_role: Option<String>,
|
||||
pub db_max_rows: Option<String>,
|
||||
@@ -93,23 +107,24 @@ pub struct ApiConfig {
|
||||
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
|
||||
pub(crate) type DbSchemaCache = TimedLru<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>;
|
||||
impl DbSchemaCache {
|
||||
pub async fn get_cached_or_remote(&self,
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
pub async fn get_cached_or_remote(
|
||||
&self,
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
|
||||
match self.get(endpoint_id){
|
||||
Some(entry) => {
|
||||
Ok(entry.value)
|
||||
}
|
||||
match self.get(endpoint_id) {
|
||||
Some(entry) => Ok(entry.value),
|
||||
None => {
|
||||
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
|
||||
let remote_value = self.get_remote(auth_header, connection_string, client, ctx).await;
|
||||
let remote_value = self
|
||||
.get_remote(auth_header, connection_string, client, ctx)
|
||||
.await;
|
||||
let (api_config, schema_owned) = match remote_value {
|
||||
Ok((api_config, schema_owned)) => (api_config, schema_owned),
|
||||
Err(e@RestError::SchemaTooLarge(_, _)) => {
|
||||
Err(e @ RestError::SchemaTooLarge(_, _)) => {
|
||||
// for the case where the schema is too large, we cache an empty dummy value
|
||||
// all the other requests will fail without triggering the introspection query
|
||||
let schema_owned = DbSchemaOwned::new(EMPTY_JSON_SCHEMA.to_string(), |s| {
|
||||
@@ -138,46 +153,53 @@ impl DbSchemaCache {
|
||||
}
|
||||
}
|
||||
}
|
||||
pub async fn get_remote(&self,
|
||||
pub async fn get_remote(
|
||||
&self,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
|
||||
|
||||
let headers = vec![
|
||||
(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
|
||||
(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()),
|
||||
(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()),
|
||||
(
|
||||
&CONN_STRING,
|
||||
HeaderValue::from_str(connection_string).unwrap(),
|
||||
),
|
||||
(
|
||||
&TXN_ISOLATION_LEVEL,
|
||||
HeaderValue::from_str("ReadCommitted").unwrap(),
|
||||
),
|
||||
(&AUTHORIZATION, auth_header.clone()),
|
||||
(&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap()),
|
||||
];
|
||||
|
||||
let body = serde_json::json!({"query": CONFIGURATION_SQL});
|
||||
let (response_status, mut response_json) = make_local_proxy_request(client, headers, body).await?;
|
||||
let (response_status, mut response_json) =
|
||||
make_local_proxy_request(client, headers, body).await?;
|
||||
|
||||
if response_status != StatusCode::OK {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to get endpoint configuration".to_string()
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to get endpoint configuration".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
let rows = response_json["rows"].as_array_mut()
|
||||
.ok_or_else(|| RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'rows' array in second result".to_string()
|
||||
}))?;
|
||||
|
||||
|
||||
let rows = response_json["rows"].as_array_mut().ok_or_else(|| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'rows' array in second result".to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
if rows.is_empty() {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "No rows in second result".to_string()
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "No rows in second result".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
|
||||
// Extract columns from the first (and only) row
|
||||
let mut row = &mut rows[0];
|
||||
let config_string = extract_string(&mut row, "config").unwrap_or_default();
|
||||
|
||||
|
||||
// Parse the configuration response
|
||||
let api_config: ApiConfig = serde_json::from_str(&config_string)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
@@ -185,12 +207,18 @@ impl DbSchemaCache {
|
||||
// now that we have the api_config let's run the second INTROSPECTION_SQL query
|
||||
let headers = vec![
|
||||
(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
|
||||
(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()),
|
||||
(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()),
|
||||
(
|
||||
&CONN_STRING,
|
||||
HeaderValue::from_str(connection_string).unwrap(),
|
||||
),
|
||||
(
|
||||
&TXN_ISOLATION_LEVEL,
|
||||
HeaderValue::from_str("ReadCommitted").unwrap(),
|
||||
),
|
||||
(&AUTHORIZATION, auth_header.clone()),
|
||||
(&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap()),
|
||||
];
|
||||
|
||||
|
||||
let body = serde_json::json!({
|
||||
"query": INTROSPECTION_SQL,
|
||||
"params": [
|
||||
@@ -199,25 +227,27 @@ impl DbSchemaCache {
|
||||
false, // use_internal_permissions
|
||||
]
|
||||
});
|
||||
let (response_status, mut response_json) = make_local_proxy_request(client, headers, body).await?;
|
||||
let (response_status, mut response_json) =
|
||||
make_local_proxy_request(client, headers, body).await?;
|
||||
|
||||
if response_status != StatusCode::OK {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to get endpoint schema".to_string()
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to get endpoint schema".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
let rows = response_json["rows"].as_array_mut()
|
||||
.ok_or_else(|| RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'rows' array in second result".to_string()
|
||||
}))?;
|
||||
|
||||
let rows = response_json["rows"].as_array_mut().ok_or_else(|| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'rows' array in second result".to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
if rows.is_empty() {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "No rows in second result".to_string()
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "No rows in second result".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
|
||||
// Extract columns from the first (and only) row
|
||||
let mut row = &mut rows[0];
|
||||
let json_schema = extract_string(&mut row, "json_schema").unwrap_or_default();
|
||||
@@ -228,22 +258,22 @@ impl DbSchemaCache {
|
||||
}
|
||||
|
||||
let schema_owned = DbSchemaOwned::new(json_schema, |s| {
|
||||
serde_json::from_str::<DbSchema>(s.as_str())
|
||||
.map_err(|e| JsonDeserialize { source: e })
|
||||
serde_json::from_str::<DbSchema>(s.as_str()).map_err(|e| JsonDeserialize { source: e })
|
||||
});
|
||||
|
||||
|
||||
// check if schema is an ok result
|
||||
let schema = schema_owned.borrow_schema();
|
||||
if schema.is_ok() {
|
||||
Ok((api_config, schema_owned))
|
||||
} else {
|
||||
//
|
||||
Err(RestError::SubzeroCore(SubzeroCoreError::InternalError { message: "Failed to get schema".to_string() }))
|
||||
Err(RestError::SubzeroCore(SubzeroCoreError::InternalError {
|
||||
message: "Failed to get schema".to_string(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// A type to represent a postgresql errors
|
||||
// we use our own type (instead of postgres_client::Error) because we get the error from the json response
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -251,7 +281,7 @@ pub(crate) struct PostgresError {
|
||||
pub code: String,
|
||||
pub message: String,
|
||||
pub detail: Option<String>,
|
||||
pub hint: Option<String>
|
||||
pub hint: Option<String>,
|
||||
}
|
||||
impl HttpCodeError for PostgresError {
|
||||
fn get_http_status_code(&self) -> StatusCode {
|
||||
@@ -323,12 +353,12 @@ impl UserFacingError for RestError {
|
||||
// TODO: this is a hack to get the message from the json body
|
||||
let json = s.json_body();
|
||||
let default_message = "Unknown error".to_string();
|
||||
let message = json.get("message").map_or(default_message.clone(), |m|
|
||||
match m {
|
||||
let message = json
|
||||
.get("message")
|
||||
.map_or(default_message.clone(), |m| match m {
|
||||
JsonValue::String(s) => s.clone(),
|
||||
_ => default_message,
|
||||
}
|
||||
);
|
||||
});
|
||||
message
|
||||
}
|
||||
}
|
||||
@@ -367,7 +397,9 @@ fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
|
||||
sql("null")
|
||||
} else {
|
||||
env.iter()
|
||||
.map(|(k, v)| "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)")
|
||||
.map(|(k, v)| {
|
||||
"set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)"
|
||||
})
|
||||
.join(",")
|
||||
}
|
||||
}
|
||||
@@ -375,18 +407,10 @@ fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
|
||||
// TODO: see about removing the need for cloning the values (inner things are &Cow<str> already)
|
||||
fn to_sql_param(p: &Param) -> JsonValue {
|
||||
match p {
|
||||
SV(SingleVal(v, ..)) => {
|
||||
JsonValue::String(v.to_string())
|
||||
}
|
||||
Str(v) => {
|
||||
JsonValue::String(v.to_string())
|
||||
}
|
||||
StrOwned(v) => {
|
||||
JsonValue::String((*v).clone())
|
||||
}
|
||||
PL(Payload(v, ..)) => {
|
||||
JsonValue::String(v.clone().into_owned())
|
||||
}
|
||||
SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()),
|
||||
Str(v) => JsonValue::String(v.to_string()),
|
||||
StrOwned(v) => JsonValue::String((*v).clone()),
|
||||
PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()),
|
||||
LV(ListVal(v, ..)) => {
|
||||
if !v.is_empty() {
|
||||
JsonValue::String(format!(
|
||||
@@ -395,7 +419,7 @@ fn to_sql_param(p: &Param) -> JsonValue {
|
||||
.map(|e| e.replace('\\', "\\\\").replace('\"', "\\\""))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\",\"")
|
||||
))
|
||||
))
|
||||
} else {
|
||||
JsonValue::String(r#"{}"#.to_string())
|
||||
}
|
||||
@@ -427,12 +451,12 @@ async fn make_local_proxy_request(
|
||||
let body_boxed = Full::new(Bytes::from(body_string))
|
||||
.map_err(|never| match never {}) // Convert Infallible to hyper::Error
|
||||
.boxed();
|
||||
|
||||
let req = req
|
||||
.body(body_boxed)
|
||||
.map_err(|_| RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to build request".to_string()
|
||||
}))?;
|
||||
|
||||
let req = req.body(body_boxed).map_err(|_| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to build request".to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
// Send the request to the local proxy
|
||||
let response = client
|
||||
@@ -444,7 +468,7 @@ async fn make_local_proxy_request(
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let response_status = response.status();
|
||||
|
||||
|
||||
// Capture the response body
|
||||
let response_body = response
|
||||
.collect()
|
||||
@@ -553,7 +577,7 @@ pub(crate) async fn handle(
|
||||
let json_body = subzero_err.json_body();
|
||||
let status_code = StatusCode::from_u16(subzero_err.status_code())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
|
||||
json_response(status_code, json_body)?
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -571,13 +595,15 @@ pub(crate) async fn handle(
|
||||
);
|
||||
|
||||
let (code, detail, hint) = match e {
|
||||
RestError::Postgres(e) => {
|
||||
(if e.code.starts_with("PT") {
|
||||
RestError::Postgres(e) => (
|
||||
if e.code.starts_with("PT") {
|
||||
None
|
||||
} else {
|
||||
Some(e.code)
|
||||
}, e.detail, e.hint)
|
||||
},
|
||||
},
|
||||
e.detail,
|
||||
e.hint,
|
||||
),
|
||||
_ => (None, None, None),
|
||||
};
|
||||
|
||||
@@ -587,7 +613,7 @@ pub(crate) async fn handle(
|
||||
"message": message,
|
||||
"code": code,
|
||||
"detail": detail,
|
||||
"hint": hint,
|
||||
"hint": hint,
|
||||
}),
|
||||
)?
|
||||
}
|
||||
@@ -615,12 +641,20 @@ async fn handle_inner(
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
|
||||
let endpoint_id = request.uri().host().unwrap_or("").split('.').next().unwrap_or("");
|
||||
let endpoint_id = request
|
||||
.uri()
|
||||
.host()
|
||||
.unwrap_or("")
|
||||
.split('.')
|
||||
.next()
|
||||
.unwrap_or("");
|
||||
|
||||
// we always use the authenticator role to connect to the database
|
||||
let autheticator_role = "authenticator";
|
||||
let connection_string = format!("postgresql://{}@{}.local.neon.build/database", autheticator_role, endpoint_id);
|
||||
let connection_string = format!(
|
||||
"postgresql://{}@{}.local.neon.build/database",
|
||||
autheticator_role, endpoint_id
|
||||
);
|
||||
|
||||
let conn_info = get_conn_info(
|
||||
&config.authentication_config,
|
||||
@@ -638,11 +672,20 @@ async fn handle_inner(
|
||||
|
||||
match conn_info.auth {
|
||||
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||
handle_rest_inner(config, ctx, request, &connection_string, conn_info.conn_info, jwt, backend).await
|
||||
}
|
||||
_ => {
|
||||
Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(Credentials::Password)))
|
||||
handle_rest_inner(
|
||||
config,
|
||||
ctx,
|
||||
request,
|
||||
&connection_string,
|
||||
conn_info.conn_info,
|
||||
jwt,
|
||||
backend,
|
||||
)
|
||||
.await
|
||||
}
|
||||
_ => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
|
||||
Credentials::Password,
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -655,23 +698,22 @@ async fn handle_rest_inner(
|
||||
jwt: String,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
|
||||
|
||||
// validate the jwt token
|
||||
let jwt_parsed = backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
|
||||
let db_schema_cache = match config.rest_config.db_schema_cache.as_ref() {
|
||||
Some(cache) => cache,
|
||||
None => {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "DB schema cache is not configured".to_string()
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "DB schema cache is not configured".to_string(),
|
||||
}));
|
||||
}
|
||||
};
|
||||
// hardcoded values for now, these should come from a config per tenant
|
||||
|
||||
|
||||
let api_prefix = "/rest/v1/";
|
||||
|
||||
let endpoint_cache_key = conn_info.endpoint_cache_key().unwrap();
|
||||
@@ -679,13 +721,25 @@ async fn handle_rest_inner(
|
||||
let (parts, originial_body) = request.into_parts();
|
||||
let headers_map = parts.headers;
|
||||
let auth_header = headers_map.get(AUTHORIZATION).unwrap();
|
||||
let entry = db_schema_cache.get_cached_or_remote(&endpoint_cache_key, auth_header, &connection_string, &mut client, &ctx).await?;
|
||||
let entry = db_schema_cache
|
||||
.get_cached_or_remote(
|
||||
&endpoint_cache_key,
|
||||
auth_header,
|
||||
&connection_string,
|
||||
&mut client,
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
let (api_config, db_schema_owned) = entry.as_ref();
|
||||
let db_schema = db_schema_owned.borrow_schema().as_ref().map_err(|_| RestError::SubzeroCore(InternalError { message: "Failed to get schema".to_string() }))?;
|
||||
let db_schema = db_schema_owned.borrow_schema().as_ref().map_err(|_| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to get schema".to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let db_schemas = &api_config.db_schemas; // list of schemas available for the api
|
||||
//let db_schema = &*DB_SCHEMA; // use the global static schema
|
||||
|
||||
|
||||
//let db_extra_search_path = "public, extensions".to_string();
|
||||
let db_extra_search_path = &api_config.db_extra_search_path;
|
||||
let role_claim_key = &api_config.role_claim_key;
|
||||
@@ -693,22 +747,23 @@ async fn handle_rest_inner(
|
||||
let db_anon_role = &api_config.db_anon_role;
|
||||
//let max_rows = Some("1000".to_string());
|
||||
let max_rows = api_config.db_max_rows.as_ref().map(|s| s.as_str());
|
||||
let db_allowed_select_functions = api_config.db_allowed_select_functions.iter().map(|s| s.as_str()).collect::<Vec<_>>();
|
||||
let db_allowed_select_functions = api_config
|
||||
.db_allowed_select_functions
|
||||
.iter()
|
||||
.map(|s| s.as_str())
|
||||
.collect::<Vec<_>>();
|
||||
// end hardcoded values
|
||||
|
||||
|
||||
// extract the jwt claims (we'll need them later to set the role and env)
|
||||
let jwt_claims = match jwt_parsed.keys {
|
||||
ComputeCredentialKeys::JwtPayload(payload_bytes) => {
|
||||
// `payload_bytes` contains the raw JWT payload as Vec<u8>
|
||||
// You can deserialize it back to JSON or parse specific claims
|
||||
let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize {source:e }))?;
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
Some(payload)
|
||||
},
|
||||
_ => {
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
//TODO: check if the token is properly cached in the backend (should we cache the parsed claims?)
|
||||
// read the role from the jwt claims (and set it to the "anon" role if not present)
|
||||
@@ -718,23 +773,29 @@ async fn handle_rest_inner(
|
||||
[JsonValue::String(s)] => Ok((Some(s), true)),
|
||||
_ => Ok((db_anon_role.as_ref(), true)),
|
||||
},
|
||||
Err(e) => Err(RestError::SubzeroCore(JwtTokenInvalid { message: format!("{e}") })),
|
||||
Err(e) => Err(RestError::SubzeroCore(JwtTokenInvalid {
|
||||
message: format!("{e}"),
|
||||
})),
|
||||
},
|
||||
None => Ok((db_anon_role.as_ref(), false)),
|
||||
}?;
|
||||
|
||||
|
||||
// do not allow unauthenticated requests when there is no anonymous role setup
|
||||
if let (None, false) = (role, authenticated) {
|
||||
return Err(RestError::SubzeroCore(JwtTokenInvalid {
|
||||
message: "unauthenticated requests not allowed".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
|
||||
// start deconstructing the request because subzero core mostly works with &str
|
||||
|
||||
|
||||
let method = parts.method;
|
||||
let method_str = method.to_string();
|
||||
let path = parts.uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
|
||||
let path = parts
|
||||
.uri
|
||||
.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.unwrap_or("/");
|
||||
|
||||
// this is actually the table name (or rpc/function_name)
|
||||
// TODO: rename this to something more descriptive
|
||||
@@ -744,17 +805,19 @@ async fn handle_rest_inner(
|
||||
target: parts.uri.path().to_string(),
|
||||
})),
|
||||
}?;
|
||||
|
||||
|
||||
// pick the current schema from the headers (or the first one from config)
|
||||
//let schema_name = ¤t_schema(db_schemas, &method, &headers_map).map_err(RestError::SubzeroCore)?;
|
||||
let schema_name = db_schema.pick_current_schema(&method_str, &headers_map).map_err(RestError::SubzeroCore)?;
|
||||
let schema_name = db_schema
|
||||
.pick_current_schema(&method_str, &headers_map)
|
||||
.map_err(RestError::SubzeroCore)?;
|
||||
|
||||
// add the content-profile header to the response
|
||||
let mut response_headers = vec![];
|
||||
if db_schemas.len() > 1 {
|
||||
response_headers.push(("Content-Profile".to_string(), schema_name.to_string()));
|
||||
}
|
||||
|
||||
|
||||
// parse the query string into a Vec<(&str, &str)>
|
||||
let query = match parts.uri.query() {
|
||||
Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
|
||||
@@ -762,24 +825,38 @@ async fn handle_rest_inner(
|
||||
};
|
||||
let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
|
||||
|
||||
|
||||
// convert the headers map to a HashMap<&str, &str>
|
||||
let headers: HashMap<&str, &str> = headers_map.iter()
|
||||
let headers: HashMap<&str, &str> = headers_map
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
|
||||
.collect();
|
||||
|
||||
let cookies = HashMap::new(); // TODO: add cookies
|
||||
|
||||
|
||||
// Read the request body
|
||||
let body_bytes = read_body_with_limit(originial_body, MAX_HTTP_BODY_SIZE).await.map_err(ReadPayloadError::from)?;
|
||||
let body_bytes = read_body_with_limit(originial_body, MAX_HTTP_BODY_SIZE)
|
||||
.await
|
||||
.map_err(ReadPayloadError::from)?;
|
||||
let body_as_string: Option<String> = if body_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(String::from_utf8_lossy(&body_bytes).into_owned())
|
||||
};
|
||||
|
||||
|
||||
// parse the request into an ApiRequest struct
|
||||
let api_request = parse(schema_name, root, db_schema, method_str.as_str(), path, get, body_as_string.as_deref(), headers, cookies, max_rows).map_err(RestError::SubzeroCore)?;
|
||||
let api_request = parse(
|
||||
schema_name,
|
||||
root,
|
||||
db_schema,
|
||||
method_str.as_str(),
|
||||
path,
|
||||
get,
|
||||
body_as_string.as_deref(),
|
||||
headers,
|
||||
cookies,
|
||||
max_rows,
|
||||
)
|
||||
.map_err(RestError::SubzeroCore)?;
|
||||
|
||||
// in case when the role is not set (but authenticated through jwt) the query will be executed with the privileges
|
||||
// of the "authenticator" role unless the DbSchema has internal privileges set
|
||||
@@ -787,7 +864,7 @@ async fn handle_rest_inner(
|
||||
// replace "*" with the list of columns the user has access to
|
||||
// so that he does not encounter permission errors
|
||||
// replace_select_star(db_schema, schema_name, role, &mut api_request.query)?;
|
||||
|
||||
|
||||
let role_str = match role {
|
||||
Some(r) => r,
|
||||
None => "",
|
||||
@@ -796,12 +873,12 @@ async fn handle_rest_inner(
|
||||
// if !disable_internal_permissions {
|
||||
// check_privileges(db_schema, schema_name, role_str, &api_request)?;
|
||||
// }
|
||||
|
||||
|
||||
check_safe_functions(&api_request, &db_allowed_select_functions)?;
|
||||
|
||||
|
||||
// this is not relevant when acting as PostgREST
|
||||
// if !disable_internal_permissions {
|
||||
// insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
|
||||
// insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
|
||||
// }
|
||||
|
||||
// when using internal privileges not switch "current_role"
|
||||
@@ -819,16 +896,14 @@ async fn handle_rest_inner(
|
||||
let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
|
||||
let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
|
||||
let jwt_claims_env = jwt_claims
|
||||
.as_ref()
|
||||
.map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
|
||||
.unwrap_or(
|
||||
if let Some(r) = env_role {
|
||||
let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
|
||||
serde_json::to_string(&claims).unwrap_or(empty_json.clone())
|
||||
} else {
|
||||
empty_json.clone()
|
||||
}
|
||||
);
|
||||
.as_ref()
|
||||
.map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
|
||||
.unwrap_or(if let Some(r) = env_role {
|
||||
let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
|
||||
serde_json::to_string(&claims).unwrap_or(empty_json.clone())
|
||||
} else {
|
||||
empty_json.clone()
|
||||
});
|
||||
let mut env: HashMap<&str, &str> = HashMap::from([
|
||||
("request.method", api_request.method),
|
||||
("request.path", api_request.path),
|
||||
@@ -841,29 +916,46 @@ async fn handle_rest_inner(
|
||||
if let Some(r) = env_role {
|
||||
env.insert("role", r.into());
|
||||
}
|
||||
|
||||
|
||||
if let Some(search_path) = db_extra_search_path {
|
||||
env.insert("search_path", search_path);
|
||||
}
|
||||
// generate the sql statements
|
||||
let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
|
||||
let (main_statement, main_parameters, _) = generate(fmt_main_query(db_schema, api_request.schema_name, &api_request, &env)?);
|
||||
|
||||
let (main_statement, main_parameters, _) = generate(fmt_main_query(
|
||||
db_schema,
|
||||
api_request.schema_name,
|
||||
&api_request,
|
||||
&env,
|
||||
)?);
|
||||
|
||||
let mut headers = vec![
|
||||
(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
|
||||
(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()),
|
||||
(
|
||||
&CONN_STRING,
|
||||
HeaderValue::from_str(connection_string).unwrap(),
|
||||
),
|
||||
(&AUTHORIZATION, auth_header.clone()),
|
||||
(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()),
|
||||
(
|
||||
&TXN_ISOLATION_LEVEL,
|
||||
HeaderValue::from_str("ReadCommitted").unwrap(),
|
||||
),
|
||||
(&ALLOW_POOL, HeaderValue::from_str("true").unwrap()),
|
||||
];
|
||||
|
||||
|
||||
if api_request.read_only {
|
||||
headers.push((&TXN_READ_ONLY, HeaderValue::from_str("true").unwrap()));
|
||||
}
|
||||
|
||||
// convert the parameters from subzero core representation to a Vec<JsonValue>
|
||||
let env_parameters_json = env_parameters.iter().map(|p| to_sql_param(&p.to_param())).collect::<Vec<_>>();
|
||||
let main_parameters_json = main_parameters.iter().map(|p| to_sql_param(&p.to_param())).collect::<Vec<_>>();
|
||||
let env_parameters_json = env_parameters
|
||||
.iter()
|
||||
.map(|p| to_sql_param(&p.to_param()))
|
||||
.collect::<Vec<_>>();
|
||||
let main_parameters_json = main_parameters
|
||||
.iter()
|
||||
.map(|p| to_sql_param(&p.to_param()))
|
||||
.collect::<Vec<_>>();
|
||||
let body = serde_json::json!({
|
||||
"queries": [
|
||||
{
|
||||
@@ -881,9 +973,8 @@ async fn handle_rest_inner(
|
||||
let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
|
||||
|
||||
// send the request to the local proxy
|
||||
let (response_status, mut response_json) = make_local_proxy_request(&mut client, headers, body).await?;
|
||||
|
||||
|
||||
let (response_status, mut response_json) =
|
||||
make_local_proxy_request(&mut client, headers, body).await?;
|
||||
|
||||
// if the response status is greater than 399, then it is an error
|
||||
// FIXME: check if there are other error codes or shapes of the response
|
||||
@@ -895,31 +986,33 @@ async fn handle_rest_inner(
|
||||
detail: extract_string(&mut response_json, "detail"),
|
||||
hint: extract_string(&mut response_json, "hint"),
|
||||
};
|
||||
|
||||
|
||||
return Err(RestError::Postgres(postgres_error));
|
||||
}
|
||||
|
||||
// Extract the second query result (main query)
|
||||
let results = response_json["results"].as_array_mut()
|
||||
.ok_or_else(|| RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'results' array".to_string()
|
||||
}))?;
|
||||
let results = response_json["results"].as_array_mut().ok_or_else(|| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'results' array".to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
if results.len() < 2 {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "Expected at least 2 results".to_string()
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "Expected at least 2 results".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
let second_result = &mut results[1];
|
||||
let rows = second_result["rows"].as_array_mut()
|
||||
.ok_or_else(|| RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'rows' array in second result".to_string()
|
||||
}))?;
|
||||
let rows = second_result["rows"].as_array_mut().ok_or_else(|| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: "Missing 'rows' array in second result".to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
if rows.is_empty() {
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "No rows in second result".to_string()
|
||||
return Err(RestError::SubzeroCore(InternalError {
|
||||
message: "No rows in second result".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -935,7 +1028,9 @@ async fn handle_rest_inner(
|
||||
|
||||
// build the intermediate response object
|
||||
let api_response = ApiResponse {
|
||||
page_total: page_total.map(|v| v.parse::<u64>().unwrap_or(0)).unwrap_or(0),
|
||||
page_total: page_total
|
||||
.map(|v| v.parse::<u64>().unwrap_or(0))
|
||||
.unwrap_or(0),
|
||||
total_result_set: total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
|
||||
top_level_offset: 0, // FIXME: check why this is 0
|
||||
response_headers: response_headers_json,
|
||||
@@ -1036,7 +1131,11 @@ async fn handle_rest_inner(
|
||||
};
|
||||
|
||||
// add the preference-applied header
|
||||
if let Some(Preferences { resolution: Some(r), .. }) = api_request.preferences {
|
||||
if let Some(Preferences {
|
||||
resolution: Some(r),
|
||||
..
|
||||
}) = api_request.preferences
|
||||
{
|
||||
response_headers.push((
|
||||
"Preference-Applied".to_string(),
|
||||
match r {
|
||||
@@ -1049,7 +1148,9 @@ async fn handle_rest_inner(
|
||||
// check if the SQL env set some response status (happens when we called a rpc function)
|
||||
let response_status: Option<String> = api_response.response_status;
|
||||
if let Some(response_status_str) = response_status {
|
||||
status = response_status_str.parse::<u16>().map_err(|_| RestError::SubzeroCore(GucStatusError))?;
|
||||
status = response_status_str
|
||||
.parse::<u16>()
|
||||
.map_err(|_| RestError::SubzeroCore(GucStatusError))?;
|
||||
}
|
||||
|
||||
// set the content type header
|
||||
@@ -1066,21 +1167,22 @@ async fn handle_rest_inner(
|
||||
let response_body = Full::new(Bytes::from(api_response.body))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
|
||||
|
||||
// build the response
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
|
||||
.header("content-type", http_content_type);
|
||||
|
||||
|
||||
// Add all headers from response_headers vector
|
||||
for (header_name, header_value) in response_headers {
|
||||
response = response.header(header_name, header_value);
|
||||
}
|
||||
|
||||
Ok(response.body(response_body).map_err(|_| RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to build response".to_string()
|
||||
}))?)
|
||||
|
||||
Ok(response.body(response_body).map_err(|_| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: "Failed to build response".to_string(),
|
||||
})
|
||||
})?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -1088,7 +1190,5 @@ mod tests {
|
||||
//use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_payload() {
|
||||
|
||||
}
|
||||
fn test_payload() {}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use bytes::Bytes;
|
||||
use futures::future::{Either, select, try_join};
|
||||
use futures::{StreamExt, TryFutureExt};
|
||||
use http::{Method, header::AUTHORIZATION};
|
||||
use http_body_util::{combinators::BoxBody, Full, BodyExt};
|
||||
use http_body_util::{BodyExt, Full, combinators::BoxBody};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{http::{HeaderName, HeaderValue}, Request, Response, StatusCode, header};
|
||||
use hyper::{
|
||||
Request, Response, StatusCode, header,
|
||||
http::{HeaderName, HeaderValue},
|
||||
};
|
||||
use indexmap::IndexMap;
|
||||
use postgres_client::error::{DbError, ErrorPosition, SqlState};
|
||||
use postgres_client::{
|
||||
@@ -15,27 +16,29 @@ use postgres_client::{
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_json::{Value, value::RawValue};
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{self, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info};
|
||||
use typed_json::json;
|
||||
|
||||
use super::backend::{LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::{AuthData,};
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool_lib::{self, ConnInfo};
|
||||
use super::error::{HttpCodeError, ConnInfoError, ReadPayloadError};
|
||||
use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
|
||||
use super::http_util::{
|
||||
json_response, uuid_to_header_value, get_conn_info,
|
||||
NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ARRAY_MODE, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, TXN_DEFERRABLE
|
||||
ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
|
||||
TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
|
||||
};
|
||||
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
|
||||
use crate::auth::backend::{ComputeCredentialKeys,};
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
|
||||
use crate::config::{HttpConfig, ProxyConfig,};
|
||||
use crate::config::{HttpConfig, ProxyConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{read_body_with_limit};
|
||||
use crate::metrics::{HttpDirection, Metrics, };
|
||||
use crate::http::read_body_with_limit;
|
||||
use crate::metrics::{HttpDirection, Metrics};
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
|
||||
use crate::util::run_until_cancelled;
|
||||
@@ -310,7 +313,6 @@ impl HttpCodeError for SqlOverHttpError {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum SqlOverHttpCancel {
|
||||
#[error("query was cancelled")]
|
||||
@@ -495,7 +497,9 @@ async fn handle_db_inner(
|
||||
ComputeCredentialKeys::JwtPayload(payload)
|
||||
if backend.auth_backend.is_local_proxy() =>
|
||||
{
|
||||
let mut client = backend.connect_to_local_postgres(ctx, conn_info, config.disable_pg_session_jwt).await?;
|
||||
let mut client = backend
|
||||
.connect_to_local_postgres(ctx, conn_info, config.disable_pg_session_jwt)
|
||||
.await?;
|
||||
if !config.disable_pg_session_jwt {
|
||||
let (cli_inner, _dsc) = client.client_inner();
|
||||
cli_inner.set_jwt_session(&payload).await?;
|
||||
@@ -598,7 +602,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&TXN_DEFERRABLE,
|
||||
];
|
||||
|
||||
|
||||
async fn handle_auth_broker_inner(
|
||||
ctx: &RequestContext,
|
||||
request: Request<Incoming>,
|
||||
|
||||
Reference in New Issue
Block a user