From c83722f348967bca8ea8affea636d423c4ef9d8c Mon Sep 17 00:00:00 2001 From: Nikita Kalyanov Date: Thu, 16 Nov 2023 17:09:03 +0100 Subject: [PATCH] API --- compute_tools/src/compute.rs | 33 +++++++++++++++++++++++++++++++++ compute_tools/src/http/api.rs | 23 ++++++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 373f05ab2f..e9a3ff0930 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -115,6 +115,16 @@ pub struct ParsedSpec { pub storage_auth_token: Option, } +#[derive(Clone, Debug, serde::Deserialize)] +pub struct RowLevelParams { + pub table_name: String, + pub role: String, + pub user_name: String, + pub password: String, + pub database_name: String, + pub column_name: String, +} + impl TryFrom for ParsedSpec { type Error = String; fn try_from(spec: ComputeSpec) -> Result { @@ -1030,6 +1040,29 @@ LIMIT 100", download_size } + pub async fn ensure_row_level_sec(&self, params: RowLevelParams) -> Result { + let conn_str = self.connstr.as_str().replace("/postgres", &format!("/{}", params.database_name)); + let connect_result = tokio_postgres::connect(&conn_str, NoTls).await; + let (client, connection) = connect_result.unwrap(); + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + let result = client + .query( + "BEGIN; +ALTER TABLE $1 ENABLE ROW LEVEL SECURITY; +CREATE USER $2 WITH PASSWORD $3 IN GROUP $4; +CREATE POLICY neon_row_level ON $1 TO $4 + USING ($5 = current_user); +COMMIT;", +&[¶ms.table_name, ¶ms.user_name, ¶ms.password,¶ms.role, ¶ms.column_name], + ) + .await; + Ok(result.is_ok()) + } + #[tokio::main] pub async fn prepare_preload_libraries( &self, diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index 8851be1ec1..ab922c2e0a 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -5,7 +5,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::thread; -use crate::compute::{ComputeNode, ComputeState, ParsedSpec}; +use crate::compute::{ComputeNode, ComputeState, ParsedSpec, RowLevelParams}; use compute_api::requests::ConfigurationRequest; use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError}; @@ -199,6 +199,27 @@ async fn routes(req: Request, compute: &Arc) -> Response { + info!("serving /ensure_row_level_sec GET request"); + let status = compute.get_status(); + if status != ComputeStatus::Running { + let msg = format!("compute is not running, current status: {:?}", status); + error!(msg); + let mut err_resp = Response::new(Body::from(msg)); + *err_resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return err_resp; + } + let body_bytes: Vec = hyper::body::to_bytes(req.into_body()).await.unwrap().into(); + let params: RowLevelParams = serde_json::from_str(&String::from_utf8(body_bytes).unwrap()).unwrap(); + + let res = compute.ensure_row_level_sec(params).await; + if !res.is_ok() { + let mut err_resp = Response::new(Body::from("query failed")); + *err_resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return err_resp; + } + Response::new(Body::from("")) + } // Return the `404 Not Found` for any other routes. _ => {