Use Url::join() when creating the final remote extension URL (#12121)

Url::to_string() adds a trailing slash on the base URL, so when we did
the format!(), we were adding a double forward slash.

Signed-off-by: Tristan Partin <tristan@neon.tech>
This commit is contained in:
Tristan Partin
2025-06-04 10:56:12 -05:00
committed by GitHub
parent e7d6f525b3
commit 3fd5a94a85
4 changed files with 140 additions and 35 deletions

View File

@@ -40,7 +40,7 @@ use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use anyhow::{Context, Result};
use anyhow::{Context, Result, bail};
use clap::Parser;
use compute_api::responses::ComputeConfig;
use compute_tools::compute::{
@@ -57,14 +57,14 @@ use tracing::{error, info};
use url::Url;
use utils::failpoint_support;
#[derive(Parser)]
#[derive(Debug, Parser)]
#[command(rename_all = "kebab-case")]
struct Cli {
#[arg(short = 'b', long, default_value = "postgres", env = "POSTGRES_PATH")]
pub pgbin: String,
/// The base URL for the remote extension storage proxy gateway.
#[arg(short = 'r', long)]
#[arg(short = 'r', long, value_parser = Self::parse_remote_ext_base_url)]
pub remote_ext_base_url: Option<Url>,
/// The port to bind the external listening HTTP server to. Clients running
@@ -126,6 +126,25 @@ struct Cli {
pub installed_extensions_collection_interval: u64,
}
impl Cli {
/// Parse a URL from an argument. By default, this isn't necessary, but we
/// want to do some sanity checking.
fn parse_remote_ext_base_url(value: &str) -> Result<Url> {
// Remove extra trailing slashes, and add one. We use Url::join() later
// when downloading remote extensions. If the base URL is something like
// http://example.com/pg-ext-s3-gateway, and join() is called with
// something like "xyz", the resulting URL is http://example.com/xyz.
let value = value.trim_end_matches('/').to_owned() + "/";
let url = Url::parse(&value)?;
if url.query_pairs().count() != 0 {
bail!("parameters detected in remote extensions base URL")
}
Ok(url)
}
}
fn main() -> Result<()> {
let cli = Cli::parse();
@@ -252,7 +271,8 @@ fn handle_exit_signal(sig: i32) {
#[cfg(test)]
mod test {
use clap::CommandFactory;
use clap::{CommandFactory, Parser};
use url::Url;
use super::Cli;
@@ -260,4 +280,43 @@ mod test {
fn verify_cli() {
Cli::command().debug_assert()
}
#[test]
fn verify_remote_ext_base_url() {
let cli = Cli::parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--remote-ext-base-url",
"https://example.com/subpath",
]);
assert_eq!(
cli.remote_ext_base_url.unwrap(),
Url::parse("https://example.com/subpath/").unwrap()
);
let cli = Cli::parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--remote-ext-base-url",
"https://example.com//",
]);
assert_eq!(
cli.remote_ext_base_url.unwrap(),
Url::parse("https://example.com").unwrap()
);
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--remote-ext-base-url",
"https://example.com?hello=world",
])
.expect_err("URL parameters are not allowed");
}
}

View File

@@ -166,7 +166,7 @@ pub async fn download_extension(
// TODO add retry logic
let download_buffer =
match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await {
match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await {
Ok(buffer) => buffer,
Err(error_message) => {
return Err(anyhow::anyhow!(
@@ -271,10 +271,14 @@ pub fn create_control_files(remote_extensions: &RemoteExtSpec, pgbin: &str) {
}
// Do request to extension storage proxy, e.g.,
// curl http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst
// curl http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/latest/v15/extensions/anon.tar.zst
// using HTTP GET and return the response body as bytes.
async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Result<Bytes> {
let uri = format!("{}/{}", remote_ext_base_url, ext_path);
async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Result<Bytes> {
let uri = remote_ext_base_url.join(ext_path).with_context(|| {
format!(
"failed to create the remote extension URI for {ext_path} using {remote_ext_base_url}"
)
})?;
let filename = Path::new(ext_path)
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("unknown"))
@@ -284,7 +288,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re
info!("Downloading extension file '{}' from uri {}", filename, uri);
match do_extension_server_request(&uri).await {
match do_extension_server_request(uri).await {
Ok(resp) => {
info!("Successfully downloaded remote extension data {}", ext_path);
REMOTE_EXT_REQUESTS_TOTAL
@@ -303,7 +307,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re
// Do a single remote extensions server request.
// Return result or (error message + stringified status code) in case of any failures.
async fn do_extension_server_request(uri: &str) -> Result<Bytes, (String, String)> {
async fn do_extension_server_request(uri: Url) -> Result<Bytes, (String, String)> {
let resp = reqwest::get(uri).await.map_err(|e| {
(
format!(

View File

@@ -250,34 +250,44 @@ impl RemoteExtSpec {
}
match self.extension_data.get(real_ext_name) {
Some(_ext_data) => {
// We have decided to use the Go naming convention due to Kubernetes.
let arch = match std::env::consts::ARCH {
"x86_64" => "amd64",
"aarch64" => "arm64",
arch => arch,
};
// Construct the path to the extension archive
// BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst
//
// Keep it in sync with path generation in
// https://github.com/neondatabase/build-custom-extensions/tree/main
let archive_path_str = format!(
"{build_tag}/{arch}/{pg_major_version}/extensions/{real_ext_name}.tar.zst"
);
Ok((
real_ext_name.to_string(),
RemotePath::from_string(&archive_path_str)?,
))
}
Some(_ext_data) => Ok((
real_ext_name.to_string(),
Self::build_remote_path(build_tag, pg_major_version, real_ext_name)?,
)),
None => Err(anyhow::anyhow!(
"real_ext_name {} is not found",
real_ext_name
)),
}
}
/// Get the architecture-specific portion of the remote extension path. We
/// use the Go naming convention due to Kubernetes.
fn get_arch() -> &'static str {
match std::env::consts::ARCH {
"x86_64" => "amd64",
"aarch64" => "arm64",
arch => arch,
}
}
/// Build a [`RemotePath`] for an extension.
fn build_remote_path(
build_tag: &str,
pg_major_version: &str,
ext_name: &str,
) -> anyhow::Result<RemotePath> {
let arch = Self::get_arch();
// Construct the path to the extension archive
// BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst
//
// Keep it in sync with path generation in
// https://github.com/neondatabase/build-custom-extensions/tree/main
RemotePath::from_string(&format!(
"{build_tag}/{arch}/{pg_major_version}/extensions/{ext_name}.tar.zst"
))
}
}
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
@@ -518,6 +528,37 @@ mod tests {
.expect("Library should be found");
}
#[test]
fn remote_extension_path() {
let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({
"public_extensions": ["ext"],
"custom_extensions": [],
"library_index": {
"extlib": "ext",
},
"extension_data": {
"ext": {
"control_data": {
"ext.control": ""
},
"archive_path": ""
}
},
}))
.unwrap();
let (_ext_name, ext_path) = rspec
.get_ext("ext", false, "latest", "v17")
.expect("Extension should be found");
// Starting with a forward slash would have consequences for the
// Url::join() that occurs when downloading a remote extension.
assert!(!ext_path.to_string().starts_with("/"));
assert_eq!(
ext_path,
RemoteExtSpec::build_remote_path("latest", "v17", "ext").unwrap()
);
}
#[test]
fn parse_spec_file() {
let file = File::open("tests/cluster_spec.json").unwrap();

View File

@@ -159,7 +159,8 @@ def test_remote_extensions(
# Setup a mock nginx S3 gateway which will return our test extension.
(host, port) = httpserver_listen_address
extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway"
remote_ext_base_url = f"http://{host}:{port}/pg-ext-s3-gateway"
log.info(f"remote extensions base URL: {remote_ext_base_url}")
extension.build(pg_config, test_output_dir)
tarball = extension.package(test_output_dir)
@@ -221,7 +222,7 @@ def test_remote_extensions(
endpoint.create_remote_extension_spec(spec)
endpoint.start(remote_ext_base_url=extensions_endpoint)
endpoint.start(remote_ext_base_url=remote_ext_base_url)
with endpoint.connect() as conn:
with conn.cursor() as cur:
@@ -249,7 +250,7 @@ def test_remote_extensions(
# Remove the extension files to force a redownload of the extension.
extension.remove(test_output_dir, pg_version)
endpoint.start(remote_ext_base_url=extensions_endpoint)
endpoint.start(remote_ext_base_url=remote_ext_base_url)
# Test that ALTER EXTENSION UPDATE statements also fetch remote extensions.
with endpoint.connect() as conn: