From ebc55e6ae87723a95303e62e9e7b16dae218676c Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Wed, 5 Feb 2025 08:58:33 -0600 Subject: [PATCH] Fix logic for checking if a compute can install a remote extension (#10656) Given a remote extensions manifest of the following: ```json { "public_extensions": [], "custom_extensions": null, "library_index": { "pg_search": "pg_search" }, "extension_data": { "pg_search": { "control_data": { "pg_search.control": "comment = 'pg_search: Full text search for PostgreSQL using BM25'\ndefault_version = '0.14.1'\nmodule_pathname = '$libdir/pg_search'\nrelocatable = false\nsuperuser = true\nschema = paradedb\ntrusted = true\n" }, "archive_path": "13117844657/v14/extensions/pg_search.tar.zst" } } } ``` We were allowing a compute to install a remote extension that wasn't listed in either public_extensions or custom_extensions. Signed-off-by: Tristan Partin --- libs/compute_api/src/spec.rs | 108 ++++++++++++++++-- .../regress/test_download_extensions.py | 2 + 2 files changed, 102 insertions(+), 8 deletions(-) diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index b3f18dc6da..2fc95c47c6 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -204,14 +204,16 @@ impl RemoteExtSpec { // Check if extension is present in public or custom. // If not, then it is not allowed to be used by this compute. - if let Some(public_extensions) = &self.public_extensions { - if !public_extensions.contains(&real_ext_name.to_string()) { - if let Some(custom_extensions) = &self.custom_extensions { - if !custom_extensions.contains(&real_ext_name.to_string()) { - return Err(anyhow::anyhow!("extension {} is not found", real_ext_name)); - } - } - } + if !self + .public_extensions + .as_ref() + .is_some_and(|exts| exts.iter().any(|e| e == ext_name)) + && !self + .custom_extensions + .as_ref() + .is_some_and(|exts| exts.iter().any(|e| e == ext_name)) + { + return Err(anyhow::anyhow!("extension {} is not found", real_ext_name)); } match self.extension_data.get(real_ext_name) { @@ -340,6 +342,96 @@ mod tests { use super::*; use std::fs::File; + #[test] + fn allow_installing_remote_extensions() { + let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({ + "public_extensions": null, + "custom_extensions": null, + "library_index": {}, + "extension_data": {}, + })) + .unwrap(); + + rspec + .get_ext("ext", false, "latest", "v17") + .expect_err("Extension should not be found"); + + let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({ + "public_extensions": [], + "custom_extensions": null, + "library_index": {}, + "extension_data": {}, + })) + .unwrap(); + + rspec + .get_ext("ext", false, "latest", "v17") + .expect_err("Extension should not be found"); + + let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({ + "public_extensions": [], + "custom_extensions": [], + "library_index": { + "ext": "ext" + }, + "extension_data": { + "ext": { + "control_data": { + "ext.control": "" + }, + "archive_path": "" + } + }, + })) + .unwrap(); + + rspec + .get_ext("ext", false, "latest", "v17") + .expect_err("Extension should not be found"); + + let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({ + "public_extensions": [], + "custom_extensions": ["ext"], + "library_index": { + "ext": "ext" + }, + "extension_data": { + "ext": { + "control_data": { + "ext.control": "" + }, + "archive_path": "" + } + }, + })) + .unwrap(); + + rspec + .get_ext("ext", false, "latest", "v17") + .expect("Extension should be found"); + + let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({ + "public_extensions": ["ext"], + "custom_extensions": [], + "library_index": { + "ext": "ext" + }, + "extension_data": { + "ext": { + "control_data": { + "ext.control": "" + }, + "archive_path": "" + } + }, + })) + .unwrap(); + + rspec + .get_ext("ext", false, "latest", "v17") + .expect("Extension should be found"); + } + #[test] fn parse_spec_file() { let file = File::open("tests/cluster_spec.json").unwrap(); diff --git a/test_runner/regress/test_download_extensions.py b/test_runner/regress/test_download_extensions.py index d7e6e9de56..7f12c14073 100644 --- a/test_runner/regress/test_download_extensions.py +++ b/test_runner/regress/test_download_extensions.py @@ -95,6 +95,8 @@ def test_remote_extensions( # mock remote_extensions spec spec: dict[str, Any] = { + "public_extensions": ["anon"], + "custom_extensions": None, "library_index": { "anon": "anon", },