diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index db6835da61..8b502a058e 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -40,7 +40,7 @@ use std::sync::mpsc; use std::thread; use std::time::Duration; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, bail}; use clap::Parser; use compute_api::responses::ComputeConfig; use compute_tools::compute::{ @@ -57,14 +57,14 @@ use tracing::{error, info}; use url::Url; use utils::failpoint_support; -#[derive(Parser)] +#[derive(Debug, Parser)] #[command(rename_all = "kebab-case")] struct Cli { #[arg(short = 'b', long, default_value = "postgres", env = "POSTGRES_PATH")] pub pgbin: String, /// The base URL for the remote extension storage proxy gateway. - #[arg(short = 'r', long)] + #[arg(short = 'r', long, value_parser = Self::parse_remote_ext_base_url)] pub remote_ext_base_url: Option, /// The port to bind the external listening HTTP server to. Clients running @@ -126,6 +126,25 @@ struct Cli { pub installed_extensions_collection_interval: u64, } +impl Cli { + /// Parse a URL from an argument. By default, this isn't necessary, but we + /// want to normalize the trailing slash and reject query parameters. + fn parse_remote_ext_base_url(value: &str) -> Result { + // Remove extra trailing slashes, and add one. We use Url::join() later + // when downloading remote extensions. Without a trailing slash, joining + // something like "xyz" onto http://example.com/pg-ext-s3-gateway drops + // the last path segment, and the resulting URL is http://example.com/xyz. 
+ let value = value.trim_end_matches('/').to_owned() + "/"; + let url = Url::parse(&value)?; + + if url.query_pairs().count() != 0 { + bail!("parameters detected in remote extensions base URL") + } + + Ok(url) + } +} + fn main() -> Result<()> { let cli = Cli::parse(); @@ -252,7 +271,8 @@ fn handle_exit_signal(sig: i32) { #[cfg(test)] mod test { - use clap::CommandFactory; + use clap::{CommandFactory, Parser}; + use url::Url; use super::Cli; @@ -260,4 +280,43 @@ mod test { fn verify_cli() { Cli::command().debug_assert() } + + #[test] + fn verify_remote_ext_base_url() { + let cli = Cli::parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com/subpath", + ]); + assert_eq!( + cli.remote_ext_base_url.unwrap(), + Url::parse("https://example.com/subpath/").unwrap() + ); + + let cli = Cli::parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com//", + ]); + assert_eq!( + cli.remote_ext_base_url.unwrap(), + Url::parse("https://example.com").unwrap() + ); + + Cli::try_parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com?hello=world", + ]) + .expect_err("URL parameters are not allowed"); + } } diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 1857afa08c..3764bc1525 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -166,7 +166,7 @@ pub async fn download_extension( // TODO add retry logic let download_buffer = - match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await { + match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await { Ok(buffer) => buffer, Err(error_message) => { return Err(anyhow::anyhow!( @@ -271,10 +271,14 @@ pub fn create_control_files(remote_extensions: 
&RemoteExtSpec, pgbin: &str) { } // Do request to extension storage proxy, e.g., -// curl http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst +// curl http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/latest/v15/extensions/anon.tar.zst // using HTTP GET and return the response body as bytes. -async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Result { - let uri = format!("{}/{}", remote_ext_base_url, ext_path); +async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Result { + let uri = remote_ext_base_url.join(ext_path).with_context(|| { + format!( + "failed to create the remote extension URI for {ext_path} using {remote_ext_base_url}" + ) + })?; let filename = Path::new(ext_path) .file_name() .unwrap_or_else(|| std::ffi::OsStr::new("unknown")) @@ -284,7 +288,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re info!("Downloading extension file '{}' from uri {}", filename, uri); - match do_extension_server_request(&uri).await { + match do_extension_server_request(uri).await { Ok(resp) => { info!("Successfully downloaded remote extension data {}", ext_path); REMOTE_EXT_REQUESTS_TOTAL @@ -303,7 +307,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re // Do a single remote extensions server request. // Return result or (error message + stringified status code) in case of any failures. -async fn do_extension_server_request(uri: &str) -> Result { +async fn do_extension_server_request(uri: Url) -> Result { let resp = reqwest::get(uri).await.map_err(|e| { ( format!( diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 9b7cf43bb9..343923d446 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -250,34 +250,44 @@ impl RemoteExtSpec { } match self.extension_data.get(real_ext_name) { - Some(_ext_data) => { - // We have decided to use the Go naming convention due to Kubernetes. 
- - let arch = match std::env::consts::ARCH { - "x86_64" => "amd64", - "aarch64" => "arm64", - arch => arch, - }; - - // Construct the path to the extension archive - // BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst - // - // Keep it in sync with path generation in - // https://github.com/neondatabase/build-custom-extensions/tree/main - let archive_path_str = format!( - "{build_tag}/{arch}/{pg_major_version}/extensions/{real_ext_name}.tar.zst" - ); - Ok(( - real_ext_name.to_string(), - RemotePath::from_string(&archive_path_str)?, - )) - } + Some(_ext_data) => Ok(( + real_ext_name.to_string(), + Self::build_remote_path(build_tag, pg_major_version, real_ext_name)?, + )), None => Err(anyhow::anyhow!( "real_ext_name {} is not found", real_ext_name )), } } + + /// Get the architecture-specific portion of the remote extension path. We + /// use the Go naming convention due to Kubernetes. + fn get_arch() -> &'static str { + match std::env::consts::ARCH { + "x86_64" => "amd64", + "aarch64" => "arm64", + arch => arch, + } + } + + /// Build a [`RemotePath`] for an extension. 
+ fn build_remote_path( + build_tag: &str, + pg_major_version: &str, + ext_name: &str, + ) -> anyhow::Result { + let arch = Self::get_arch(); + + // Construct the path to the extension archive + // BUILD_TAG/ARCH/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst + // + // Keep it in sync with path generation in + // https://github.com/neondatabase/build-custom-extensions/tree/main + RemotePath::from_string(&format!( + "{build_tag}/{arch}/{pg_major_version}/extensions/{ext_name}.tar.zst" + )) + } } #[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)] @@ -518,6 +528,37 @@ mod tests { .expect("Library should be found"); } + #[test] + fn remote_extension_path() { + let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({ + "public_extensions": ["ext"], + "custom_extensions": [], + "library_index": { + "extlib": "ext", + }, + "extension_data": { + "ext": { + "control_data": { + "ext.control": "" + }, + "archive_path": "" + } + }, + })) + .unwrap(); + + let (_ext_name, ext_path) = rspec + .get_ext("ext", false, "latest", "v17") + .expect("Extension should be found"); + // A leading forward slash would make the Url::join() that occurs when + // downloading a remote extension discard the base URL's path. + assert!(!ext_path.to_string().starts_with("/")); + assert_eq!( + ext_path, + RemoteExtSpec::build_remote_path("latest", "v17", "ext").unwrap() + ); + } + #[test] fn parse_spec_file() { let file = File::open("tests/cluster_spec.json").unwrap(); diff --git a/test_runner/regress/test_download_extensions.py b/test_runner/regress/test_download_extensions.py index 24ba0713d2..fe3b220c67 100644 --- a/test_runner/regress/test_download_extensions.py +++ b/test_runner/regress/test_download_extensions.py @@ -159,7 +159,8 @@ def test_remote_extensions( # Setup a mock nginx S3 gateway which will return our test extension. 
(host, port) = httpserver_listen_address - extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway" + remote_ext_base_url = f"http://{host}:{port}/pg-ext-s3-gateway" + log.info(f"remote extensions base URL: {remote_ext_base_url}") extension.build(pg_config, test_output_dir) tarball = extension.package(test_output_dir) @@ -221,7 +222,7 @@ def test_remote_extensions( endpoint.create_remote_extension_spec(spec) - endpoint.start(remote_ext_base_url=extensions_endpoint) + endpoint.start(remote_ext_base_url=remote_ext_base_url) with endpoint.connect() as conn: with conn.cursor() as cur: @@ -249,7 +250,7 @@ def test_remote_extensions( # Remove the extension files to force a redownload of the extension. extension.remove(test_output_dir, pg_version) - endpoint.start(remote_ext_base_url=extensions_endpoint) + endpoint.start(remote_ext_base_url=remote_ext_base_url) # Test that ALTER EXTENSION UPDATE statements also fetch remote extensions. with endpoint.connect() as conn: