mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-14 11:40:38 +00:00
API
This commit is contained in:
@@ -115,6 +115,16 @@ pub struct ParsedSpec {
|
||||
pub storage_auth_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Deserialize)]
|
||||
pub struct RowLevelParams {
|
||||
pub table_name: String,
|
||||
pub role: String,
|
||||
pub user_name: String,
|
||||
pub password: String,
|
||||
pub database_name: String,
|
||||
pub column_name: String,
|
||||
}
|
||||
|
||||
impl TryFrom<ComputeSpec> for ParsedSpec {
|
||||
type Error = String;
|
||||
fn try_from(spec: ComputeSpec) -> Result<Self, String> {
|
||||
@@ -1030,6 +1040,29 @@ LIMIT 100",
|
||||
download_size
|
||||
}
|
||||
|
||||
pub async fn ensure_row_level_sec(&self, params: RowLevelParams) -> Result<bool> {
|
||||
let conn_str = self.connstr.as_str().replace("/postgres", &format!("/{}", params.database_name));
|
||||
let connect_result = tokio_postgres::connect(&conn_str, NoTls).await;
|
||||
let (client, connection) = connect_result.unwrap();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
let result = client
|
||||
.query(
|
||||
"BEGIN;
|
||||
ALTER TABLE $1 ENABLE ROW LEVEL SECURITY;
|
||||
CREATE USER $2 WITH PASSWORD $3 IN GROUP $4;
|
||||
CREATE POLICY neon_row_level ON $1 TO $4
|
||||
USING ($5 = current_user);
|
||||
COMMIT;",
|
||||
&[¶ms.table_name, ¶ms.user_name, ¶ms.password,¶ms.role, ¶ms.column_name],
|
||||
)
|
||||
.await;
|
||||
Ok(result.is_ok())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
pub async fn prepare_preload_libraries(
|
||||
&self,
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
use crate::compute::{ComputeNode, ComputeState, ParsedSpec};
|
||||
use crate::compute::{ComputeNode, ComputeState, ParsedSpec, RowLevelParams};
|
||||
use compute_api::requests::ConfigurationRequest;
|
||||
use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError};
|
||||
|
||||
@@ -199,6 +199,27 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
}
|
||||
}
|
||||
}
|
||||
(&Method::POST, "/ensure_row_level_sec") => {
|
||||
info!("serving /ensure_row_level_sec GET request");
|
||||
let status = compute.get_status();
|
||||
if status != ComputeStatus::Running {
|
||||
let msg = format!("compute is not running, current status: {:?}", status);
|
||||
error!(msg);
|
||||
let mut err_resp = Response::new(Body::from(msg));
|
||||
*err_resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return err_resp;
|
||||
}
|
||||
let body_bytes: Vec<u8> = hyper::body::to_bytes(req.into_body()).await.unwrap().into();
|
||||
let params: RowLevelParams = serde_json::from_str(&String::from_utf8(body_bytes).unwrap()).unwrap();
|
||||
|
||||
let res = compute.ensure_row_level_sec(params).await;
|
||||
if !res.is_ok() {
|
||||
let mut err_resp = Response::new(Body::from("query failed"));
|
||||
*err_resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return err_resp;
|
||||
}
|
||||
Response::new(Body::from(""))
|
||||
}
|
||||
|
||||
// Return the `404 Not Found` for any other routes.
|
||||
_ => {
|
||||
|
||||
Reference in New Issue
Block a user