Compare commits

...

5 Commits

Author SHA1 Message Date
Nikita Kalyanov
e29fbb96b5 idempotency 2023-11-17 08:56:03 +01:00
Nikita Kalyanov
90357fbc45 match more precise 2023-11-16 21:24:13 +01:00
Nikita Kalyanov
8987de089c fix query usage 2023-11-16 21:04:27 +01:00
Nikita Kalyanov
c61fca9a5f fmt 2023-11-16 18:52:55 +01:00
Nikita Kalyanov
c83722f348 API 2023-11-16 17:09:03 +01:00
2 changed files with 69 additions and 1 deletions

View File

@@ -115,6 +115,16 @@ pub struct ParsedSpec {
pub storage_auth_token: Option<String>,
}
#[derive(Clone, Debug, serde::Deserialize)]
pub struct RowLevelParams {
pub table_name: String,
pub role: String,
pub user_name: String,
pub password: String,
pub database_name: String,
pub column_name: String,
}
impl TryFrom<ComputeSpec> for ParsedSpec {
type Error = String;
fn try_from(spec: ComputeSpec) -> Result<Self, String> {
@@ -1030,6 +1040,39 @@ LIMIT 100",
download_size
}
pub async fn ensure_row_level_sec(&self, params: RowLevelParams) -> Result<bool> {
let conn_str = self
.connstr
.as_str()
.replace("/postgres", &format!("/{}", params.database_name));
let connect_result = tokio_postgres::connect(&conn_str, NoTls).await;
let (client, connection) = connect_result.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let result = client
.batch_execute(&format!(
"BEGIN;
ALTER TABLE {0} ENABLE ROW LEVEL SECURITY;
DROP POLICY IF EXISTS neon_row_level ON {0};
DROP ROLE IF EXISTS {1};
CREATE USER {1} WITH PASSWORD '{2}' IN GROUP {3};
CREATE POLICY neon_row_level ON {0} TO {3}
USING ({4} = current_user)
WITH CHECK ({4} = current_user);
COMMIT;",
&params.table_name,
&params.user_name,
&params.password,
&params.role,
&params.column_name,
))
.await;
Ok(result.is_ok())
}
#[tokio::main]
pub async fn prepare_preload_libraries(
&self,

View File

@@ -5,7 +5,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::thread;
use crate::compute::{ComputeNode, ComputeState, ParsedSpec};
use crate::compute::{ComputeNode, ComputeState, ParsedSpec, RowLevelParams};
use compute_api::requests::ConfigurationRequest;
use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError};
@@ -199,6 +199,31 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
}
(&Method::POST, "/ensure_row_level_sec") => {
info!("serving /ensure_row_level_sec GET request");
let status = compute.get_status();
if status != ComputeStatus::Running {
let msg = format!("compute is not running, current status: {:?}", status);
error!(msg);
let mut err_resp = Response::new(Body::from(msg));
*err_resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return err_resp;
}
let body_bytes: Vec<u8> = hyper::body::to_bytes(req.into_body()).await.unwrap().into();
let params: RowLevelParams =
serde_json::from_str(&String::from_utf8(body_bytes).unwrap()).unwrap();
let res = compute.ensure_row_level_sec(params).await;
match res {
Ok(true) => (),
_ => {
let mut err_resp = Response::new(Body::from("query failed"));
*err_resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return err_resp;
}
}
Response::new(Body::from(""))
}
// Return the `404 Not Found` for any other routes.
_ => {