subzero integration WIP4

Queries generated by subzero reach the database and execute successfully.
This commit is contained in:
Ruslan Talpa
2025-06-24 15:33:51 +03:00
parent 67d3026fc4
commit d1445cf3eb
10 changed files with 502 additions and 40 deletions

View File

@@ -396,7 +396,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
.parse()
.expect("url is valid"),
audience: None,
role_names: vec![(&RoleName::from("authenticated")).into()],
role_names: vec![(&RoleName::from("authenticator")).into(), (&RoleName::from("authenticated")).into(), (&RoleName::from("anon")).into()],
}]);
}

View File

@@ -229,7 +229,7 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
// therefore releasing it is safe from race conditions
info!(
debug!( //FIXME: is anything depending on this being info?
name = self.name,
shard = i,
"performing epoch reclamation on api lock"

View File

@@ -115,7 +115,7 @@ impl PoolingBackend {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
self.config
let keys = self.config
.authentication_config
.jwks_cache
.check_jwt(
@@ -129,7 +129,9 @@ impl PoolingBackend {
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
// FIXME: why was this set to None?
//keys: crate::auth::backend::ComputeCredentialKeys::None,
keys,
})
}
crate::auth::Backend::Local(_) => {

View File

@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::sync::Arc;
use bytes::Bytes;
@@ -14,7 +15,7 @@ use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use serde_json::value::RawValue;
use serde_json::{value::RawValue, Value as JsonValue};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info};
@@ -28,8 +29,8 @@ use super::conn_pool_lib::{ConnInfo};
use super::error::HttpCodeError;
use super::http_util::json_response;
use super::json::{JsonConversionError};
use crate::auth::backend::{ComputeUserInfo};
use crate::auth::{ComputeUserInfoParseError, endpoint_sni};
use crate::auth::backend::{ComputeUserInfo, ComputeCredentialKeys};
use crate::auth::{ComputeUserInfoParseError, endpoint_sni, };
use crate::config::{AuthenticationConfig, ProxyConfig, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
@@ -40,6 +41,30 @@ use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use subzero_core::{
api::{ApiRequest, ApiResponse, ContentType::*, SingleVal, ListVal, Payload},
error::Error::{self as SubzeroCoreError, SingularityError, PutMatchingPkError, PermissionDenied, JsonDeserialize, NotFound, JwtTokenInvalid,},
schema::DbSchema,
formatter::{
Param,
Param::*,
postgresql::{fmt_main_query, generate},
ToParam, Snippet, SqlParam,
},
error::JsonDeserializeSnafu,
dynamic_statement::{param, sql, JoinIterator},
};
use subzero_core::{
api::{ContentType, ContentType::*, Preferences, QueryNode::*, Representation, Resolution::*,},
error::{*},
parser::postgrest::parse,
permissions::{check_safe_functions, check_privileges, insert_policy_conditions, replace_select_star},
api::DEFAULT_SAFE_SELECT_FUNCTIONS,
};
use std::collections::HashMap;
use jsonpath_lib::select;
use url::form_urlencoded;
@@ -55,6 +80,9 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
// FIXME: remove this header
static HACK_TRUST_ROLE_SWITCHING: HeaderName = HeaderName::from_static("neon-hack-trust-role-switching");
#[derive(Debug, thiserror::Error)]
@@ -416,6 +444,8 @@ pub(crate) enum RestError {
JsonConversion(#[from] JsonConversionError),
#[error("{0}")]
Cancelled(SqlOverHttpCancel),
#[error("{0}")]
SubzeroCore(#[source] SubzeroCoreError),
}
impl ReportableError for RestError {
@@ -436,6 +466,7 @@ impl ReportableError for RestError {
}
RestError::JsonConversion(_) => ErrorKind::Postgres,
RestError::Cancelled(c) => c.get_error_kind(),
RestError::SubzeroCore(s) => ErrorKind::User,
}
}
}
@@ -452,6 +483,18 @@ impl UserFacingError for RestError {
RestError::InternalPostgres(p) => p.to_string(),
RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
RestError::Cancelled(_) => self.to_string(),
RestError::SubzeroCore(s) => {
// TODO: this is a hack to get the message from the json body
let json = s.json_body();
let default_message = "Unknown error".to_string();
let message = json.get("message").map_or(default_message.clone(), |m|
match m {
JsonValue::String(s) => s.clone(),
_ => default_message,
}
);
message
}
}
}
}
@@ -471,6 +514,10 @@ impl HttpCodeError for RestError {
RestError::InternalPostgres(_) => StatusCode::INTERNAL_SERVER_ERROR,
RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
RestError::Cancelled(_) => StatusCode::INTERNAL_SERVER_ERROR,
RestError::SubzeroCore(e) => {
let status = e.status_code();
StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
}
@@ -550,9 +597,10 @@ async fn handle_inner(
let host = request.uri().host().unwrap_or("").split('.').next().unwrap_or("");
let connection_string = format!("postgresql://authenticated@{}.local.neon.build/database", host);
// we always use the authenticator role to connect to the database
let autheticator_role = "authenticator";
let connection_string = format!("postgresql://{}@{}.local.neon.build/database", autheticator_role, host);
let conn_info = get_conn_info(
&config.authentication_config,
@@ -595,6 +643,125 @@ static HEADERS_TO_STRIP: &[&HeaderName] = &[
&TXN_READ_ONLY,
&TXN_DEFERRABLE,
];
/// Hard-coded database schema description, as JSON text.
///
/// It declares a single schema named `test` containing one table, `items`,
/// with an integer primary-key column `id` and a text column `name`, plus
/// empty `foreign_keys` and `permissions` lists.
///
/// NOTE(review): this looks like a placeholder until the schema is obtained
/// by introspecting the database — confirm before relying on it.
static JSON_SCHEMA: &str = r#"
{
"schemas":[
{
"name":"test",
"objects":[
{
"kind":"table",
"name":"items",
"columns":[
{
"name":"id",
"data_type":"integer",
"primary_key":true
},
{
"name":"name",
"data_type":"text"
}
],
"foreign_keys":[],
"permissions":[]
}
]
}
]
}
"#;
/// Builds the `select ...` snippet that applies the request environment.
///
/// Each `(key, value)` entry of `env` becomes one
/// `set_config(<key>, <value>, true)` call, with both key and value passed as
/// bound parameters; the calls are joined with `,`. When `env` is empty the
/// resulting query is `select null`.
pub fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
    let select_list = if env.is_empty() {
        sql("null")
    } else {
        env.iter()
            .map(|(key, value)| {
                "set_config(" + param(key as &SqlParam) + ", " + param(value as &SqlParam) + ", true)"
            })
            .join(",")
    };

    "select " + select_list
}
/// Picks the schema a request should run against.
///
/// Resolution order:
/// 1. If `db_schemas` holds at most one entry, the profile headers are
///    ignored and the first entry is returned (or the placeholder
///    `"_inexistent_"` when the list is empty).
/// 2. For `DELETE`, `POST`, `PATCH` and `PUT` requests that carry a
///    `content-profile` header, that header names the schema.
/// 3. Otherwise, if an `accept-profile` header is present, it names the schema.
/// 4. Otherwise the first entry of `db_schemas` is returned.
///
/// # Errors
///
/// Returns `SubzeroCoreError::UnacceptableSchema` (carrying the full list of
/// allowed schemas) when the selected header is not valid visible ASCII or
/// names a schema that is not in `db_schemas`.
fn current_schema(db_schemas: &Vec<String>, method: &Method, headers: &HeaderMap) -> Result<String, SubzeroCoreError> {
    // Built lazily so the placeholder string is only allocated when needed.
    let default_schema = || {
        db_schemas
            .first()
            .cloned()
            .unwrap_or_else(|| "_inexistent_".to_string())
    };

    if db_schemas.len() <= 1 {
        return Ok(default_schema());
    }

    // Decide which header (if any) selects the schema. Write methods prefer
    // `content-profile`; a write request without it falls through to
    // `accept-profile`, exactly like every other method.
    let profile_header = match (method, headers.get("accept-profile"), headers.get("content-profile")) {
        (&Method::DELETE | &Method::POST | &Method::PATCH | &Method::PUT, _, Some(content_profile)) => {
            Some(content_profile)
        }
        (_, Some(accept_profile), _) => Some(accept_profile),
        _ => None,
    };

    let Some(header_value) = profile_header else {
        return Ok(default_schema());
    };

    match header_value.to_str() {
        Ok(profile) if db_schemas.iter().any(|schema| schema == profile) => Ok(profile.to_owned()),
        // Either the header is not readable as a string or it names a schema
        // that is not exposed.
        _ => Err(SubzeroCoreError::UnacceptableSchema {
            schemas: db_schemas.clone(),
        }),
    }
}
/// Wraps a subzero core error in [`RestError::SubzeroCore`].
///
/// Intended for use with `map_err` on results returned by subzero functions.
pub fn to_core_error(core_error: SubzeroCoreError) -> RestError {
    RestError::SubzeroCore(core_error)
}
/// Converts a subzero query parameter into a JSON string value.
///
/// Every variant is rendered as `JsonValue::String`:
/// - single values, string slices, owned strings and payloads are copied as-is;
/// - list values are rendered as `{"a","b",...}` with each element wrapped in
///   double quotes and any backslash or double quote inside an element
///   prefixed with a backslash; an empty list is rendered as `{}`.
///   NOTE(review): this looks like PostgreSQL array-literal syntax — confirm.
fn to_sql_param(p: &Param) -> JsonValue {
    let rendered: String = match p {
        SV(SingleVal(value, ..)) => value.to_string(),
        Str(value) => value.to_string(),
        StrOwned(value) => (*value).clone(),
        PL(Payload(value, ..)) => value.clone().into_owned(),
        LV(ListVal(items, ..)) => {
            if items.is_empty() {
                r#"{}"#.to_string()
            } else {
                let escaped_items = items
                    .iter()
                    .map(|item| item.replace('\\', "\\\\").replace('\"', "\\\""))
                    .collect::<Vec<_>>();
                format!("{{\"{}\"}}", escaped_items.join("\",\""))
            }
        }
    };

    JsonValue::String(rendered)
}
async fn handle_rest_inner(
ctx: &RequestContext,
request: Request<Incoming>,
@@ -603,16 +770,200 @@ async fn handle_rest_inner(
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
let mut response_headers = vec![];
// hardcoded values for now
let max_http_body_size = 10 * 1024 * 1024; // 10MB limit
let db_schemas = Vec::from(["test".to_string()]); // list of schemas available for the api
let mut db_schema_parsed = serde_json::from_str::<DbSchema>(JSON_SCHEMA) // database schema shape (will come from introspection)
.map_err(|e| RestError::SubzeroCore(JsonDeserialize {source:e }))?;
let disable_internal_permissions = true; // in the context of neon we emulate postgrest (so no internal permissions checks)
db_schema_parsed.use_internal_permissions = false; // TODO: change the introspection query to auto set this to false depending on params
let db_schema = &db_schema_parsed;
let api_prefix = "/rest/v1/";
let db_extra_search_path = "public, extensions".to_string();
let role_claim_key = ".role".to_string();
let role_claim_path = format!("${}", role_claim_key);
println!("role_claim_path: {:?}", role_claim_path);
let db_anon_role = Some("anon".to_string());
//let max_rows = Some("1000".to_string());
let max_rows = None;
let db_allowed_select_functions = DEFAULT_SAFE_SELECT_FUNCTIONS.iter().map(|m| *m).collect::<Vec<_>>();
let jwt_parsed = backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt.clone()) //TODO: do not clone jwt
.await
.map_err(HttpConnError::from)?;
let jwt_claims = match jwt_parsed.keys {
ComputeCredentialKeys::JwtPayload(payload_bytes) => {
// `payload_bytes` contains the raw JWT payload as Vec<u8>
// You can deserialize it back to JSON or parse specific claims
let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
.map_err(|e| RestError::SubzeroCore(JsonDeserialize {source:e }))?;
Some(payload)
},
_ => {
None
}
};
println!("jwt_payload: {:?}", &jwt_claims);
let (role, authenticated) = match &jwt_claims {
Some(claims) => match select(claims, &role_claim_path) {
Ok(v) => match &v[..] {
[JsonValue::String(s)] => Ok((Some(s), true)),
_ => Ok((db_anon_role.as_ref(), true)),
},
Err(e) => Err(RestError::SubzeroCore(JwtTokenInvalid { message: format!("{e}") })),
},
None => Ok((db_anon_role.as_ref(), false)),
}?;
println!("role: {:?}", role);
println!("authenticated: {:?}", authenticated);
// do not allow unauthenticated requests when there is no anonymous role setup
if let (None, false) = (role, authenticated) {
return Err(RestError::SubzeroCore(JwtTokenInvalid {
message: "unauthenticated requests not allowed".to_string(),
}));
}
// println!("jwt: {:?}", jwt.keys);
let role = match role {
Some(r) => r,
None => "",
};
let (parts, originial_body) = request.into_parts();
let method = parts.method.to_string();
let path = parts.uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
// this is actually the table name (or rpc/function_name)
// TODO: rename this to something more descriptive
let root = match parts.uri.path().strip_prefix(api_prefix) {
Some(p) => Ok(p),
None => Err(RestError::SubzeroCore(NotFound {
target: parts.uri.path().to_string(),
})),
}?;
let schema_name = &current_schema(&db_schemas, &parts.method, &parts.headers).map_err(RestError::SubzeroCore)?;
if db_schemas.len() > 1 {
response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
}
let body = Full::new(Bytes::new()).map_err(|never| match never {}).boxed();
// print all the local variables
println!("schema_name: {:?}", schema_name);
println!("db_schemas: {:?}", db_schemas);
println!("db_schema: {:?}", db_schema);
println!("root: {:?}", root);
println!("method: {:?}", method);
println!("path: {:?}", path);
println!("response_headers: {:?}", response_headers);
println!("originial_body: {:?}", originial_body);
//println!("parts: {:?}", parts);
println!("conn_info: {:?}", conn_info);
println!("jwt: {:?}", jwt);
let query = match parts.uri.query() {
Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
None => vec![],
};
let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
let headers_map = parts.headers;
let headers: HashMap<&str, &str> = headers_map.iter()
.map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
.collect();
let cookies = HashMap::new(); // TODO: add cookies
// Read the request body
let body_bytes = read_body_with_limit(originial_body, max_http_body_size).await.map_err(ReadPayloadError::from)?; // 10MB limit
let body_as_string: Option<String> = if body_bytes.is_empty() {
None
} else {
Some(String::from_utf8_lossy(&body_bytes).into_owned())
};
println!("ready to parse!!!!!!!");
let mut api_request = parse(schema_name, root, db_schema, method.as_str(), path, get, body_as_string.as_deref(), headers, cookies, max_rows).map_err(RestError::SubzeroCore)?;
// in case when the role is not set (but authenticated through jwt) the query will be executed with the privileges
// of the "authenticator" role unless the DbSchema has internal privileges set
// replace "*" with the list of columns the user has access to
// so that he does not encounter permission errors
// replace_select_star(db_schema, schema_name, role, &mut api_request.query).map_err(to_core_error)?;
println!("after replace_select_star !!!!!!!");
if !disable_internal_permissions {
// check_privileges(db_schema, schema_name, role, &api_request).map_err(to_core_error)?;
println!("after check_privileges !!!!!!!");
}
println!("after check_privileges 2 !!!!!!!");
check_safe_functions(&api_request, &db_allowed_select_functions).map_err(to_core_error)?;
println!("after check_safe_functions !!!!!!!");
if !disable_internal_permissions {
insert_policy_conditions(db_schema, schema_name, role, &mut api_request.query).map_err(to_core_error)?;
println!("after insert_policy_conditions !!!!!!!");
}
println!("api_request after checks: {:?}", api_request);
// when using internal privileges not switch "current_role"
// TODO: why do we need this?
let env_role = if !disable_internal_permissions && db_schema.use_internal_permissions {
None
} else {
Some(role)
};
let empty_json = "{}".to_string();
let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone());
let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
let jwt_claims_env = jwt_claims.as_ref().map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
.unwrap_or(
if let Some(r) = env_role {
let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
serde_json::to_string(&claims).unwrap_or(empty_json.clone())
} else {
empty_json.clone()
}
);
let mut env: HashMap<&str, &str> = HashMap::from([
("request.method", api_request.method),
("request.path", api_request.path),
("search_path", &db_extra_search_path),
("request.headers", &headers_env),
("request.cookies", &cookies_env),
("request.get", &get_env),
("request.jwt.claims", &jwt_claims_env),
]);
if let Some(r) = env_role {
env.insert("role", r.into());
}
let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
let (main_statement, main_parameters, _) = generate(fmt_main_query(db_schema, api_request.schema_name, &api_request, &env).map_err(to_core_error)?);
println!("env_statement: {:?} \n env_parameters: {:?}", env_statement, env_parameters);
println!("main_statement: {:?} \n main_parameters: {:?}", main_statement, main_parameters);
// now we are ready to send the request to the local proxy
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
let (parts, _originial_body) = request.into_parts();
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
@@ -623,17 +974,33 @@ async fn handle_rest_inner(
// }
// }
// forward all headers except the ones in HEADERS_TO_STRIP
for (h, v) in parts.headers.iter() {
for (h, v) in headers_map.iter() {
if !HEADERS_TO_STRIP.contains(&h) {
req = req.header(h, v);
}
}
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap());
// FIXME: remove this header
req = req.header(&HACK_TRUST_ROLE_SWITCHING, HeaderValue::from_static("true"));
let env_parameters_json = env_parameters.iter().map(|p| to_sql_param(&p.to_param())).collect::<Vec<_>>();
let main_parameters_json = main_parameters.iter().map(|p| to_sql_param(&p.to_param())).collect::<Vec<_>>();
let body: String = json!({
"query": "select 1 as one",
"params": [],
"queries": [
{
"query": env_statement,
"params": env_parameters_json,
},
{
"query": main_statement,
"params": main_parameters_json,
}
]
}).to_string();
let body_boxed = Full::new(Bytes::from(body))
@@ -656,6 +1023,12 @@ async fn handle_rest_inner(
.map_err(LocalProxyConnError::from)
.map_err(HttpConnError::from)?
.map(|b| b.boxed()))
// Ok(Response::builder()
// .status(StatusCode::OK)
// .body(body)
// .unwrap())
}

View File

@@ -869,7 +869,7 @@ async fn handle_auth_broker_inner(
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
let req = req
.body(body.map_err(|e| e).boxed())
.body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here?
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress