This commit is contained in:
Nikita Kalyanov
2023-11-16 17:09:03 +01:00
parent 6b82f22ada
commit c83722f348
2 changed files with 55 additions and 1 deletions

View File

@@ -115,6 +115,16 @@ pub struct ParsedSpec {
pub storage_auth_token: Option<String>,
}
#[derive(Clone, Debug, serde::Deserialize)]
pub struct RowLevelParams {
pub table_name: String,
pub role: String,
pub user_name: String,
pub password: String,
pub database_name: String,
pub column_name: String,
}
impl TryFrom<ComputeSpec> for ParsedSpec {
type Error = String;
fn try_from(spec: ComputeSpec) -> Result<Self, String> {
@@ -1030,6 +1040,29 @@ LIMIT 100",
download_size
}
pub async fn ensure_row_level_sec(&self, params: RowLevelParams) -> Result<bool> {
let conn_str = self.connstr.as_str().replace("/postgres", &format!("/{}", params.database_name));
let connect_result = tokio_postgres::connect(&conn_str, NoTls).await;
let (client, connection) = connect_result.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let result = client
.query(
"BEGIN;
ALTER TABLE $1 ENABLE ROW LEVEL SECURITY;
CREATE USER $2 WITH PASSWORD $3 IN GROUP $4;
CREATE POLICY neon_row_level ON $1 TO $4
USING ($5 = current_user);
COMMIT;",
&[&params.table_name, &params.user_name, &params.password,&params.role, &params.column_name],
)
.await;
Ok(result.is_ok())
}
#[tokio::main]
pub async fn prepare_preload_libraries(
&self,

View File

@@ -5,7 +5,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::thread;
use crate::compute::{ComputeNode, ComputeState, ParsedSpec};
use crate::compute::{ComputeNode, ComputeState, ParsedSpec, RowLevelParams};
use compute_api::requests::ConfigurationRequest;
use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError};
@@ -199,6 +199,27 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
}
(&Method::POST, "/ensure_row_level_sec") => {
info!("serving /ensure_row_level_sec GET request");
let status = compute.get_status();
if status != ComputeStatus::Running {
let msg = format!("compute is not running, current status: {:?}", status);
error!(msg);
let mut err_resp = Response::new(Body::from(msg));
*err_resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return err_resp;
}
let body_bytes: Vec<u8> = hyper::body::to_bytes(req.into_body()).await.unwrap().into();
let params: RowLevelParams = serde_json::from_str(&String::from_utf8(body_bytes).unwrap()).unwrap();
let res = compute.ensure_row_level_sec(params).await;
if !res.is_ok() {
let mut err_resp = Response::new(Body::from("query failed"));
*err_resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return err_resp;
}
Response::new(Body::from(""))
}
// Return the `404 Not Found` for any other routes.
_ => {