diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 38f3b53f65..b33f4f05dd 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -142,14 +142,14 @@ fn create_neon_superuser(spec: &ComputeSpec, client: &mut Client) -> Result<()> .cluster .roles .iter() - .map(|r| format!("'{}'", escape_literal(&r.name))) + .map(|r| escape_literal(&r.name)) .collect::>(); let dbs = spec .cluster .databases .iter() - .map(|db| format!("'{}'", escape_literal(&db.name))) + .map(|db| escape_literal(&db.name)) .collect::>(); let roles_decl = if roles.is_empty() { diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index 99346433d0..68b943eec8 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -47,30 +47,22 @@ pub fn write_postgres_conf(path: &Path, spec: &ComputeSpec) -> Result<()> { // Add options for connecting to storage writeln!(file, "# Neon storage settings")?; if let Some(s) = &spec.pageserver_connstring { - writeln!( - file, - "neon.pageserver_connstring='{}'", - escape_conf_value(s) - )?; + writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?; } if !spec.safekeeper_connstrings.is_empty() { writeln!( file, - "neon.safekeepers='{}'", + "neon.safekeepers={}", escape_conf_value(&spec.safekeeper_connstrings.join(",")) )?; } if let Some(s) = &spec.tenant_id { - writeln!( - file, - "neon.tenant_id='{}'", - escape_conf_value(&s.to_string()) - )?; + writeln!(file, "neon.tenant_id={}", escape_conf_value(&s.to_string()))?; } if let Some(s) = &spec.timeline_id { writeln!( file, - "neon.timeline_id='{}'", + "neon.timeline_id={}", escape_conf_value(&s.to_string()) )?; } diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs index 6a78bffd1b..75550978d8 100644 --- a/compute_tools/src/pg_helpers.rs +++ b/compute_tools/src/pg_helpers.rs @@ -16,15 +16,26 @@ use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role}; const POSTGRES_WAIT_TIMEOUT: Duration = Duration::from_millis(60 * 1000); // milliseconds -/// Escape a string for including it in a SQL literal +/// Escape a string for including it in a SQL literal. Wrapping the result +/// with `E'{}'` or `'{}'` is not required, as it returns a ready-to-use +/// SQL string literal, e.g. `'db'''` or `E'db\\'`. +/// See https://github.com/postgres/postgres/blob/da98d005cdbcd45af563d0c4ac86d0e9772cd15f/src/backend/utils/adt/quote.c#L47 +/// for the original implementation. pub fn escape_literal(s: &str) -> String { - s.replace('\'', "''").replace('\\', "\\\\") + let res = s.replace('\'', "''").replace('\\', "\\\\"); + + if res.contains('\\') { + format!("E'{}'", res) + } else { + format!("'{}'", res) + } } -/// Escape a string so that it can be used in postgresql.conf. -/// Same as escape_literal, currently. +/// Escape a string so that it can be used in postgresql.conf. Wrapping the result +/// with `'{}'` is not required, as it returns a ready-to-use config string. pub fn escape_conf_value(s: &str) -> String { - s.replace('\'', "''").replace('\\', "\\\\") + let res = s.replace('\'', "''").replace('\\', "\\\\"); + format!("'{}'", res) } trait GenericOptionExt { @@ -37,7 +48,7 @@ impl GenericOptionExt for GenericOption { fn to_pg_option(&self) -> String { if let Some(val) = &self.value { match self.vartype.as_ref() { - "string" => format!("{} '{}'", self.name, escape_literal(val)), + "string" => format!("{} {}", self.name, escape_literal(val)), _ => format!("{} {}", self.name, val), } } else { @@ -49,7 +60,7 @@ impl GenericOptionExt for GenericOption { fn to_pg_setting(&self) -> String { if let Some(val) = &self.value { match self.vartype.as_ref() { - "string" => format!("{} = '{}'", self.name, escape_conf_value(val)), + "string" => format!("{} = {}", self.name, escape_conf_value(val)), _ => format!("{} = {}", self.name, val), } } else { diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index 520696da00..575a5332a8 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -397,10 +397,44 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> { // We do not check either DB exists or not, // Postgres will take care of it for us "delete_db" => { - let query: String = format!("DROP DATABASE IF EXISTS {}", &op.name.pg_quote()); + // In Postgres we can't drop a database if it is a template. + // So we need to unset the template flag first, but it could + // be a retry, so we could've already dropped the database. + // Check that database exists first to make it idempotent. + let unset_template_query: String = format!( + " + DO $$ + BEGIN + IF EXISTS( + SELECT 1 + FROM pg_catalog.pg_database + WHERE datname = {} + ) + THEN + ALTER DATABASE {} is_template false; + END IF; + END + $$;", + escape_literal(&op.name), + &op.name.pg_quote() + ); + // Use FORCE to drop database even if there are active connections. + // We run this from `cloud_admin`, so it should have enough privileges. + // NB: there could be other db states, which prevent us from dropping + // the database. For example, if db is used by any active subscription + // or replication slot. + // TODO: deal with it once we allow logical replication. Proper fix should + // involve returning an error code to the control plane, so it could + // figure out that this is a non-retryable error, return it to the user + // and fail operation permanently. + let drop_db_query: String = format!( + "DROP DATABASE IF EXISTS {} WITH (FORCE)", + &op.name.pg_quote() + ); warn!("deleting database '{}'", &op.name); - client.execute(query.as_str(), &[])?; + client.execute(unset_template_query.as_str(), &[])?; + client.execute(drop_db_query.as_str(), &[])?; } "rename_db" => { let new_name = op.new_name.as_ref().unwrap(); diff --git a/compute_tools/tests/pg_helpers_tests.rs b/compute_tools/tests/pg_helpers_tests.rs index 265556d3b9..7d27d22a78 100644 --- a/compute_tools/tests/pg_helpers_tests.rs +++ b/compute_tools/tests/pg_helpers_tests.rs @@ -89,4 +89,12 @@ test.escaping = 'here''s a backslash \\ and a quote '' and a double-quote " hoor assert_eq!(none_generic_options.find("missed_value"), None); assert_eq!(none_generic_options.find("invalid_value"), None); } + + #[test] + fn test_escape_literal() { + assert_eq!(escape_literal("test"), "'test'"); + assert_eq!(escape_literal("test'"), "'test'''"); + assert_eq!(escape_literal("test\\'"), "E'test\\\\'''"); + assert_eq!(escape_literal("test\\'\\'"), "E'test\\\\''\\\\'''"); + } }