diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 1fe13d5c3a..f9efb155b4 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::HashSet; use std::fs; use std::os::unix::fs::PermissionsExt; use std::path::Path; @@ -18,7 +18,7 @@ use utils::lsn::Lsn; use compute_api::responses::{ComputeMetrics, ComputeStatus}; use compute_api::spec::{ComputeMode, ComputeSpec}; -use remote_storage::{GenericRemoteStorage, RemotePath}; +use remote_storage::GenericRemoteStorage; use crate::pg_helpers::*; use crate::spec::*; @@ -54,7 +54,7 @@ pub struct ComputeNode { pub ext_remote_storage: Option, // cached lists of available extensions and libraries // pub available_libraries: OnceLock>>, - pub available_extensions: OnceLock>, + pub available_extensions: OnceLock>, } #[derive(Clone, Debug)] @@ -723,17 +723,7 @@ LIMIT 100", match &self.ext_remote_storage { None => anyhow::bail!("No remote extension storage"), Some(remote_storage) => { - extension_server::download_extension( - ext_name, - self.available_extensions - .get() - .context("extension download error")? - .get(ext_name) - .context("cannot find extension")?, - remote_storage, - &self.pgbin, - ) - .await + extension_server::download_extension(ext_name, remote_storage, &self.pgbin).await } } } @@ -741,7 +731,7 @@ LIMIT 100", #[tokio::main] pub async fn prepare_preload_libraries(&self, compute_state: &ComputeState) -> Result<()> { // TODO: revive some of the old logic for downloading shared preload libaries - info!("ERRRRRORRRR"); + info!("I HAVENT IMPLEMENTED DOWNLOADING SHARED PRELOAD LIBRARIES YET"); Ok(()) } } diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 0ea6f52fd3..4284f21a88 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -1,11 +1,28 @@ // Download extension files from the extension store // and put them in the right place in the postgres directory +/* +The layout of the S3 bucket is as follows: + +v14/ext_index.json + -- this contains information necessary to create control files +v14/extensions/test_ext1.tar.gz + -- this contains the library files and sql files necessary to create this extension +v14/extensions/custom_ext1.tar.gz + +The difference between a private and public extensions is determined by who can +load the extension this is specified in ext_index.json + +Speicially, ext_index.json has a list of public extensions, and a list of +extensions enabled for specific tenant-ids. +*/ use crate::compute::ComputeNode; use anyhow::{self, bail, Result}; +use futures::future::ok; use remote_storage::*; -use serde_json::{self, Value}; -use std::collections::HashMap; +use serde_json::{self, Map, Value}; +use std::collections::HashSet; use std::fs::File; +use std::hash::Hash; use std::io::BufWriter; use std::io::Write; use std::num::{NonZeroU32, NonZeroUsize}; @@ -14,7 +31,7 @@ use std::str; use std::sync::Arc; use std::thread; use tokio::io::AsyncReadExt; -use tracing::{info, warn}; +use tracing::info; fn get_pg_config(argument: &str, pgbin: &str) -> String { // gives the result of `pg_config [argument]` @@ -49,87 +66,104 @@ pub async fn get_available_extensions( pgbin: &str, pg_version: &str, custom_ext_prefixes: &Vec, -) -> Result> { +) -> Result> { + // TODO: in this function change expect's to pass the error instead of panic-ing + let local_sharedir = Path::new(&get_pg_config("--sharedir", pgbin)).join("extension"); - let index_path = RemotePath::new(Path::new(&format!("{:?}/ext_index.json", pg_version))) - .expect("error forming path"); + let index_path = pg_version.to_owned() + "/ext_index.json"; + let index_path = RemotePath::new(Path::new(&index_path)).expect("error forming path"); + info!("download extension index json: {:?}", &index_path); + let all_files = remote_storage.list_files(None).await?; + + dbg!(all_files); + let mut download = remote_storage.download(&index_path).await?; let mut write_data_buffer = Vec::new(); download .download_stream .read_to_end(&mut write_data_buffer) .await?; - let ext_index_str = - serde_json::to_string(&write_data_buffer).expect("Failed to convert to JSON"); + let ext_index_str = match str::from_utf8(&write_data_buffer) { + Ok(v) => v, + Err(e) => panic!("Invalid UTF-8 sequence: {}", e), + }; + + dbg!(ext_index_str); + let ext_index_full = match serde_json::from_str(&ext_index_str) { Ok(Value::Object(map)) => map, _ => bail!("error parsing json"), }; + let control_data = ext_index_full["control_data"] + .as_object() + .expect("json parse error"); + let enabled_extensions = ext_index_full["enabled_extensions"] + .as_object() + .expect("json parse error"); + + dbg!(ext_index_full.clone()); + dbg!(control_data.clone()); + dbg!(enabled_extensions.clone()); let mut prefixes = vec!["public".to_string()]; prefixes.extend(custom_ext_prefixes.clone()); - let mut ext_index_limited = HashMap::new(); + dbg!(prefixes.clone()); + let mut all_extensions = HashSet::new(); for prefix in prefixes { - let ext_details_str = ext_index_full.get(&prefix); - if let Some(ext_details_str) = ext_details_str { - let ext_details = - serde_json::to_string(ext_details_str).expect("Failed to convert to JSON"); - let ext_details = match serde_json::from_str(&ext_details) { - Ok(Value::Object(map)) => map, - _ => bail!("error parsing json"), - }; - let control_contents = match ext_details.get("control").expect("broken json file") { - Value::String(s) => s, - _ => bail!("broken json file"), - }; - let path = RemotePath::new(Path::new(&format!( - "{:?}/{:?}", - pg_version, - ext_details.get("path") - ))) - .expect("error forming path"); - - let control_path = format!("{:?}/{:?}.control", &local_sharedir, &prefix); - std::fs::write(control_path, &control_contents)?; - - ext_index_limited.insert(prefix, path); - } else { - warn!("BAD PREFIX {:?}", prefix); + let prefix_extensions = match enabled_extensions.get(&prefix) { + Some(Value::Array(ext_name)) => ext_name, + _ => { + info!("prefix {} has no extensions", prefix); + continue; + } + }; + dbg!(prefix_extensions); + for ext_name in prefix_extensions { + all_extensions.insert(ext_name.as_str().expect("json parse error").to_string()); } } - Ok(ext_index_limited) + + // TODO: this is probably I/O bound, could benefit from parallelizing + for prefix in &all_extensions { + let control_contents = control_data[prefix].as_str().expect("json parse error"); + let control_path = local_sharedir.join(prefix.to_owned() + ".control"); + + info!("WRITING FILE {:?}{:?}", control_path, control_contents); + std::fs::write(control_path, &control_contents)?; + } + + Ok(all_extensions.into_iter().collect()) } // download all sqlfiles (and possibly data files) for a given extension name // pub async fn download_extension( ext_name: &str, - ext_path: &RemotePath, remote_storage: &GenericRemoteStorage, pgbin: &str, ) -> Result<()> { - let local_sharedir = Path::new(&get_pg_config("--sharedir", pgbin)).join("extension"); - let local_libdir = Path::new(&get_pg_config("--libdir", pgbin)).to_owned(); - info!("Start downloading extension {:?}", ext_name); - let mut download = remote_storage.download(&ext_path).await?; - let mut write_data_buffer = Vec::new(); - download - .download_stream - .read_to_end(&mut write_data_buffer) - .await?; - let zip_name = ext_path.object_name().expect("invalid extension path"); - let mut output_file = BufWriter::new(File::create(zip_name)?); - output_file.write_all(&write_data_buffer)?; - info!("Download {:?} completed successfully", &ext_path); - info!("Unzipping extension {:?}", zip_name); + todo!(); + // let local_sharedir = Path::new(&get_pg_config("--sharedir", pgbin)).join("extension"); + // let local_libdir = Path::new(&get_pg_config("--libdir", pgbin)).to_owned(); + // info!("Start downloading extension {:?}", ext_name); + // let mut download = remote_storage.download(&ext_path).await?; + // let mut write_data_buffer = Vec::new(); + // download + // .download_stream + // .read_to_end(&mut write_data_buffer) + // .await?; + // let zip_name = ext_path.object_name().expect("invalid extension path"); + // let mut output_file = BufWriter::new(File::create(zip_name)?); + // output_file.write_all(&write_data_buffer)?; + // info!("Download {:?} completed successfully", &ext_path); + // info!("Unzipping extension {:?}", zip_name); - // TODO unzip and place files in appropriate locations - info!("unzip {zip_name:?}"); - info!("place extension files in {local_sharedir:?}"); - info!("place library files in {local_libdir:?}"); - - Ok(()) + // // TODO unzip and place files in appropriate locations + // info!("unzip {zip_name:?}"); + // info!("place extension files in {local_sharedir:?}"); + // info!("place library files in {local_libdir:?}"); + // Ok(()) } // This function initializes the necessary structs to use remmote storage (should be fairly cheap) diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 6f32ababdf..ffe7211378 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -499,7 +499,7 @@ impl Endpoint { // in spec, but we don't have a way to do that yet in the python tests. // NEW HACK: we enable the anon custom extension for everyone! this is of course just for testing // how will we do it for real? - custom_extensions: Some(vec!["anon".to_string(), self.tenant_id.to_string()]), + custom_extensions: Some(vec!["123454321".to_string(), self.tenant_id.to_string()]), }; let spec_path = self.endpoint_path().join("spec.json"); std::fs::write(spec_path, serde_json::to_string_pretty(&spec)?)?; diff --git a/test_runner/regress/data/extension_test/ext_index.json b/test_runner/regress/data/extension_test/ext_index.json new file mode 100644 index 0000000000..7fa10701f4 --- /dev/null +++ b/test_runner/regress/data/extension_test/ext_index.json @@ -0,0 +1,14 @@ +{ + "enabled_extensions": { + "123454321": [ + "anon" + ], + "public": [ + "embedding" + ] + }, + "control_data": { + "embedding": "comment = 'hnsw index' \ndefault_version = '0.1.0' \nmodule_pathname = '$libdir/embedding' \nrelocatable = true \ntrusted = true", + "anon": "# PostgreSQL Anonymizer (anon) extension \ncomment = 'Data anonymization tools' \ndefault_version = '1.1.0' \ndirectory='extension/anon' \nrelocatable = false \nrequires = 'pgcrypto' \nsuperuser = false \nmodule_pathname = '$libdir/anon' \ntrusted = true \n" + } +} \ No newline at end of file diff --git a/test_runner/regress/test_download_extensions.py b/test_runner/regress/test_download_extensions.py index 4e298567e1..0fe7775bd4 100644 --- a/test_runner/regress/test_download_extensions.py +++ b/test_runner/regress/test_download_extensions.py @@ -29,7 +29,7 @@ def test_remote_extensions( return None neon_env_builder.enable_remote_storage( - remote_storage_kind=RemoteStorageKind.MOCK_S3, + remote_storage_kind=remote_storage_kind, test_name="test_remote_extensions", enable_remote_extensions=True, ) @@ -43,17 +43,24 @@ def test_remote_extensions( # For MOCK_S3 we upload some test files. for REAL_S3 we use the files created in CICD if remote_storage_kind == RemoteStorageKind.MOCK_S3: + log.info("Uploading test files to mock bucket") + with open("test_runner/regress/data/extension_test/ext_index.json", "rb") as f: + env.remote_storage_client.upload_fileobj( + f, + env.ext_remote_storage.bucket_name, + f"ext/v{pg_version}/ext_index.json", + ) with open("test_runner/regress/data/extension_test/anon.tar.gz", "rb") as f: env.remote_storage_client.upload_fileobj( - f.read(), + f, env.ext_remote_storage.bucket_name, - f"{pg_version}/{str(tenant_id)}/anon.tar.gz", + f"ext/v{pg_version}/anon.tar.gz", ) with open("test_runner/regress/data/extension_test/embedding.tar.gz", "rb") as f: env.remote_storage_client.upload_fileobj( - f.read(), + f, env.ext_remote_storage.bucket_name, - f"{pg_version}/public/embedding.tar.gz", + f"ext/v{pg_version}/embedding.tar.gz", ) # Start a compute node and check that it can download the extensions