mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-09 14:32:57 +00:00
Use Url::join() when creating the final remote extension URL (#12121)
Url::to_string() adds a trailing slash on the base URL, so when we did the format!(), we were adding a double forward slash. Signed-off-by: Tristan Partin <tristan@neon.tech>
This commit is contained in:
@@ -40,7 +40,7 @@ use std::sync::mpsc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use clap::Parser;
|
||||
use compute_api::responses::ComputeConfig;
|
||||
use compute_tools::compute::{
|
||||
@@ -57,14 +57,14 @@ use tracing::{error, info};
|
||||
use url::Url;
|
||||
use utils::failpoint_support;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(rename_all = "kebab-case")]
|
||||
struct Cli {
|
||||
#[arg(short = 'b', long, default_value = "postgres", env = "POSTGRES_PATH")]
|
||||
pub pgbin: String,
|
||||
|
||||
/// The base URL for the remote extension storage proxy gateway.
|
||||
#[arg(short = 'r', long)]
|
||||
#[arg(short = 'r', long, value_parser = Self::parse_remote_ext_base_url)]
|
||||
pub remote_ext_base_url: Option<Url>,
|
||||
|
||||
/// The port to bind the external listening HTTP server to. Clients running
|
||||
@@ -126,6 +126,25 @@ struct Cli {
|
||||
pub installed_extensions_collection_interval: u64,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
/// Parse a URL from an argument. By default, this isn't necessary, but we
|
||||
/// want to do some sanity checking.
|
||||
fn parse_remote_ext_base_url(value: &str) -> Result<Url> {
|
||||
// Remove extra trailing slashes, and add one. We use Url::join() later
|
||||
// when downloading remote extensions. If the base URL is something like
|
||||
// http://example.com/pg-ext-s3-gateway, and join() is called with
|
||||
// something like "xyz", the resulting URL is http://example.com/xyz.
|
||||
let value = value.trim_end_matches('/').to_owned() + "/";
|
||||
let url = Url::parse(&value)?;
|
||||
|
||||
if url.query_pairs().count() != 0 {
|
||||
bail!("parameters detected in remote extensions base URL")
|
||||
}
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
@@ -252,7 +271,8 @@ fn handle_exit_signal(sig: i32) {
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use clap::CommandFactory;
|
||||
use clap::{CommandFactory, Parser};
|
||||
use url::Url;
|
||||
|
||||
use super::Cli;
|
||||
|
||||
@@ -260,4 +280,43 @@ mod test {
|
||||
fn verify_cli() {
|
||||
Cli::command().debug_assert()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_remote_ext_base_url() {
|
||||
let cli = Cli::parse_from([
|
||||
"compute_ctl",
|
||||
"--pgdata=test",
|
||||
"--connstr=test",
|
||||
"--compute-id=test",
|
||||
"--remote-ext-base-url",
|
||||
"https://example.com/subpath",
|
||||
]);
|
||||
assert_eq!(
|
||||
cli.remote_ext_base_url.unwrap(),
|
||||
Url::parse("https://example.com/subpath/").unwrap()
|
||||
);
|
||||
|
||||
let cli = Cli::parse_from([
|
||||
"compute_ctl",
|
||||
"--pgdata=test",
|
||||
"--connstr=test",
|
||||
"--compute-id=test",
|
||||
"--remote-ext-base-url",
|
||||
"https://example.com//",
|
||||
]);
|
||||
assert_eq!(
|
||||
cli.remote_ext_base_url.unwrap(),
|
||||
Url::parse("https://example.com").unwrap()
|
||||
);
|
||||
|
||||
Cli::try_parse_from([
|
||||
"compute_ctl",
|
||||
"--pgdata=test",
|
||||
"--connstr=test",
|
||||
"--compute-id=test",
|
||||
"--remote-ext-base-url",
|
||||
"https://example.com?hello=world",
|
||||
])
|
||||
.expect_err("URL parameters are not allowed");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ pub async fn download_extension(
|
||||
|
||||
// TODO add retry logic
|
||||
let download_buffer =
|
||||
match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await {
|
||||
match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await {
|
||||
Ok(buffer) => buffer,
|
||||
Err(error_message) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -271,10 +271,14 @@ pub fn create_control_files(remote_extensions: &RemoteExtSpec, pgbin: &str) {
|
||||
}
|
||||
|
||||
// Do request to extension storage proxy, e.g.,
|
||||
// curl http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst
|
||||
// curl http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/latest/v15/extensions/anon.tar.zst
|
||||
// using HTTP GET and return the response body as bytes.
|
||||
async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Result<Bytes> {
|
||||
let uri = format!("{}/{}", remote_ext_base_url, ext_path);
|
||||
async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Result<Bytes> {
|
||||
let uri = remote_ext_base_url.join(ext_path).with_context(|| {
|
||||
format!(
|
||||
"failed to create the remote extension URI for {ext_path} using {remote_ext_base_url}"
|
||||
)
|
||||
})?;
|
||||
let filename = Path::new(ext_path)
|
||||
.file_name()
|
||||
.unwrap_or_else(|| std::ffi::OsStr::new("unknown"))
|
||||
@@ -284,7 +288,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re
|
||||
|
||||
info!("Downloading extension file '{}' from uri {}", filename, uri);
|
||||
|
||||
match do_extension_server_request(&uri).await {
|
||||
match do_extension_server_request(uri).await {
|
||||
Ok(resp) => {
|
||||
info!("Successfully downloaded remote extension data {}", ext_path);
|
||||
REMOTE_EXT_REQUESTS_TOTAL
|
||||
@@ -303,7 +307,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re
|
||||
|
||||
// Do a single remote extensions server request.
|
||||
// Return result or (error message + stringified status code) in case of any failures.
|
||||
async fn do_extension_server_request(uri: &str) -> Result<Bytes, (String, String)> {
|
||||
async fn do_extension_server_request(uri: Url) -> Result<Bytes, (String, String)> {
|
||||
let resp = reqwest::get(uri).await.map_err(|e| {
|
||||
(
|
||||
format!(
|
||||
|
||||
@@ -250,34 +250,44 @@ impl RemoteExtSpec {
|
||||
}
|
||||
|
||||
match self.extension_data.get(real_ext_name) {
|
||||
Some(_ext_data) => {
|
||||
// We have decided to use the Go naming convention due to Kubernetes.
|
||||
|
||||
let arch = match std::env::consts::ARCH {
|
||||
"x86_64" => "amd64",
|
||||
"aarch64" => "arm64",
|
||||
arch => arch,
|
||||
};
|
||||
|
||||
// Construct the path to the extension archive
|
||||
// BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst
|
||||
//
|
||||
// Keep it in sync with path generation in
|
||||
// https://github.com/neondatabase/build-custom-extensions/tree/main
|
||||
let archive_path_str = format!(
|
||||
"{build_tag}/{arch}/{pg_major_version}/extensions/{real_ext_name}.tar.zst"
|
||||
);
|
||||
Ok((
|
||||
real_ext_name.to_string(),
|
||||
RemotePath::from_string(&archive_path_str)?,
|
||||
))
|
||||
}
|
||||
Some(_ext_data) => Ok((
|
||||
real_ext_name.to_string(),
|
||||
Self::build_remote_path(build_tag, pg_major_version, real_ext_name)?,
|
||||
)),
|
||||
None => Err(anyhow::anyhow!(
|
||||
"real_ext_name {} is not found",
|
||||
real_ext_name
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the architecture-specific portion of the remote extension path. We
|
||||
/// use the Go naming convention due to Kubernetes.
|
||||
fn get_arch() -> &'static str {
|
||||
match std::env::consts::ARCH {
|
||||
"x86_64" => "amd64",
|
||||
"aarch64" => "arm64",
|
||||
arch => arch,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a [`RemotePath`] for an extension.
|
||||
fn build_remote_path(
|
||||
build_tag: &str,
|
||||
pg_major_version: &str,
|
||||
ext_name: &str,
|
||||
) -> anyhow::Result<RemotePath> {
|
||||
let arch = Self::get_arch();
|
||||
|
||||
// Construct the path to the extension archive
|
||||
// BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst
|
||||
//
|
||||
// Keep it in sync with path generation in
|
||||
// https://github.com/neondatabase/build-custom-extensions/tree/main
|
||||
RemotePath::from_string(&format!(
|
||||
"{build_tag}/{arch}/{pg_major_version}/extensions/{ext_name}.tar.zst"
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
|
||||
@@ -518,6 +528,37 @@ mod tests {
|
||||
.expect("Library should be found");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_extension_path() {
|
||||
let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({
|
||||
"public_extensions": ["ext"],
|
||||
"custom_extensions": [],
|
||||
"library_index": {
|
||||
"extlib": "ext",
|
||||
},
|
||||
"extension_data": {
|
||||
"ext": {
|
||||
"control_data": {
|
||||
"ext.control": ""
|
||||
},
|
||||
"archive_path": ""
|
||||
}
|
||||
},
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let (_ext_name, ext_path) = rspec
|
||||
.get_ext("ext", false, "latest", "v17")
|
||||
.expect("Extension should be found");
|
||||
// Starting with a forward slash would have consequences for the
|
||||
// Url::join() that occurs when downloading a remote extension.
|
||||
assert!(!ext_path.to_string().starts_with("/"));
|
||||
assert_eq!(
|
||||
ext_path,
|
||||
RemoteExtSpec::build_remote_path("latest", "v17", "ext").unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_file() {
|
||||
let file = File::open("tests/cluster_spec.json").unwrap();
|
||||
|
||||
@@ -159,7 +159,8 @@ def test_remote_extensions(
|
||||
|
||||
# Setup a mock nginx S3 gateway which will return our test extension.
|
||||
(host, port) = httpserver_listen_address
|
||||
extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway"
|
||||
remote_ext_base_url = f"http://{host}:{port}/pg-ext-s3-gateway"
|
||||
log.info(f"remote extensions base URL: {remote_ext_base_url}")
|
||||
|
||||
extension.build(pg_config, test_output_dir)
|
||||
tarball = extension.package(test_output_dir)
|
||||
@@ -221,7 +222,7 @@ def test_remote_extensions(
|
||||
|
||||
endpoint.create_remote_extension_spec(spec)
|
||||
|
||||
endpoint.start(remote_ext_base_url=extensions_endpoint)
|
||||
endpoint.start(remote_ext_base_url=remote_ext_base_url)
|
||||
|
||||
with endpoint.connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
@@ -249,7 +250,7 @@ def test_remote_extensions(
|
||||
# Remove the extension files to force a redownload of the extension.
|
||||
extension.remove(test_output_dir, pg_version)
|
||||
|
||||
endpoint.start(remote_ext_base_url=extensions_endpoint)
|
||||
endpoint.start(remote_ext_base_url=remote_ext_base_url)
|
||||
|
||||
# Test that ALTER EXTENSION UPDATE statements also fetch remote extensions.
|
||||
with endpoint.connect() as conn:
|
||||
|
||||
Reference in New Issue
Block a user