diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index ec2eb02237..0b30034e7e 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -24,7 +24,7 @@ use compute_api::responses::{ComputeMetrics, ComputeStatus}; use compute_api::spec::{ComputeMode, ComputeSpec}; use utils::measured_stream::MeasuredReader; -use remote_storage::{GenericRemoteStorage, RemotePath}; +use remote_storage::{DownloadError, GenericRemoteStorage, RemotePath}; use crate::pg_helpers::*; use crate::spec::*; @@ -925,6 +925,7 @@ LIMIT 100", let spec = &pspec.spec; let custom_ext = spec.custom_extensions.clone().unwrap_or(Vec::new()); info!("custom extensions: {:?}", &custom_ext); + let (ext_remote_paths, library_index) = extension_server::get_available_extensions( ext_remote_storage, &self.pgbin, @@ -944,94 +945,113 @@ LIMIT 100", } // download an archive, unzip and place files in correct locations - pub async fn download_extension(&self, ext_name: &str, is_library: bool) -> Result { - match &self.ext_remote_storage { - None => anyhow::bail!("No remote extension storage"), - Some(remote_storage) => { - let mut real_ext_name = ext_name.to_string(); - if is_library { - // sometimes library names might have a suffix like - // library.so or library.so.3. We strip this off - // because library_index is based on the name without the file extension - let strip_lib_suffix = Regex::new(r"\.so.*").unwrap(); - let lib_raw_name = strip_lib_suffix.replace(&real_ext_name, "").to_string(); - real_ext_name = self - .library_index - .get() - .expect("must have already downloaded the library_index")[&lib_raw_name] - .clone(); - } + pub async fn download_extension( + &self, + ext_name: &str, + is_library: bool, + ) -> Result { + let remote_storage = self + .ext_remote_storage + .as_ref() + .ok_or(DownloadError::BadInput(anyhow::anyhow!( + "Remote extensions storage is not configured", + )))?; - let ext_path = &self - .ext_remote_paths - .get() - .expect("error accessing ext_remote_paths")[&real_ext_name]; - let ext_archive_name = ext_path.object_name().expect("bad path"); + let mut real_ext_name = ext_name; + if is_library { + // sometimes library names might have a suffix like + // library.so or library.so.3. We strip this off + // because library_index is based on the name without the file extension + let strip_lib_suffix = Regex::new(r"\.so.*").unwrap(); + let lib_raw_name = strip_lib_suffix.replace(real_ext_name, "").to_string(); - let mut first_try = false; - if !self - .ext_download_progress - .read() - .expect("lock err") - .contains_key(ext_archive_name) - { - self.ext_download_progress - .write() - .expect("lock err") - .insert(ext_archive_name.to_string(), (Utc::now(), false)); - first_try = true; - } - let (download_start, download_completed) = - self.ext_download_progress.read().expect("lock err")[ext_archive_name]; - let start_time_delta = Utc::now() - .signed_duration_since(download_start) - .to_std() - .unwrap() - .as_millis() as u64; - - // how long to wait for extension download if it was started by another process - const HANG_TIMEOUT: u64 = 3000; // milliseconds - - if download_completed { - info!("extension already downloaded, skipping re-download"); - return Ok(0); - } else if start_time_delta < HANG_TIMEOUT && !first_try { - info!("download {ext_archive_name} already started by another process, hanging untill completion or timeout"); - let mut interval = - tokio::time::interval(tokio::time::Duration::from_millis(500)); - loop { - info!("waiting for download"); - interval.tick().await; - let (_, download_completed_now) = - self.ext_download_progress.read().expect("lock")[ext_archive_name]; - if download_completed_now { - info!("download finished by whoever else downloaded it"); - return Ok(0); - } - } - // NOTE: the above loop will get terminated - // based on the timeout of the download function - } - - // if extension hasn't been downloaded before or the previous - // attempt to download was at least HANG_TIMEOUT ms ago - // then we try to download it here - info!("downloading new extension {ext_archive_name}"); - - let download_size = extension_server::download_extension( - &real_ext_name, - ext_path, - remote_storage, - &self.pgbin, - ) - .await; - self.ext_download_progress - .write() - .expect("bad lock") - .insert(ext_archive_name.to_string(), (download_start, true)); - download_size - } + real_ext_name = self + .library_index + .get() + .expect("must have already downloaded the library_index") + .get(&lib_raw_name) + .ok_or(DownloadError::BadInput(anyhow::anyhow!( + "library {} is not found", + lib_raw_name + )))?; } + + let ext_path = &self + .ext_remote_paths + .get() + .expect("error accessing ext_remote_paths") + .get(real_ext_name) + .ok_or(DownloadError::BadInput(anyhow::anyhow!( + "real_ext_name {} is not found", + real_ext_name + )))?; + + let ext_archive_name = ext_path.object_name().expect("bad path"); + + let mut first_try = false; + if !self + .ext_download_progress + .read() + .expect("lock err") + .contains_key(ext_archive_name) + { + self.ext_download_progress + .write() + .expect("lock err") + .insert(ext_archive_name.to_string(), (Utc::now(), false)); + first_try = true; + } + let (download_start, download_completed) = + self.ext_download_progress.read().expect("lock err")[ext_archive_name]; + let start_time_delta = Utc::now() + .signed_duration_since(download_start) + .to_std() + .unwrap() + .as_millis() as u64; + + // how long to wait for extension download if it was started by another process + const HANG_TIMEOUT: u64 = 3000; // milliseconds + + if download_completed { + info!("extension already downloaded, skipping re-download"); + return Ok(0); + } else if start_time_delta < HANG_TIMEOUT && !first_try { + info!("download {ext_archive_name} already started by another process, hanging untill completion or timeout"); + let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500)); + loop { + info!("waiting for download"); + interval.tick().await; + let (_, download_completed_now) = + self.ext_download_progress.read().expect("lock")[ext_archive_name]; + if download_completed_now { + info!("download finished by whoever else downloaded it"); + return Ok(0); + } + } + // NOTE: the above loop will get terminated + // based on the timeout of the download function + } + + // if extension hasn't been downloaded before or the previous + // attempt to download was at least HANG_TIMEOUT ms ago + // then we try to download it here + info!("downloading new extension {ext_archive_name}"); + + let download_size = extension_server::download_extension( + real_ext_name, + ext_path, + remote_storage, + &self.pgbin, + ) + .await + .map_err(DownloadError::Other); + + self.ext_download_progress + .write() + .expect("bad lock") + .insert(ext_archive_name.to_string(), (download_start, true)); + + download_size } #[tokio::main] @@ -1090,7 +1110,17 @@ LIMIT 100", .as_millis() as u64; info!("Prepare extensions took {prep_ext_time_delta}ms"); + // Don't try to download libraries that are not in the index. + // Assume that they are already present locally. + libs_vec.retain(|lib| { + self.library_index + .get() + .expect("error accessing ext_remote_paths") + .contains_key(lib) + }); + info!("Downloading to shared preload libraries: {:?}", &libs_vec); + let mut download_tasks = Vec::new(); for library in &libs_vec { download_tasks.push(self.download_extension(library, true)); @@ -1104,8 +1134,19 @@ LIMIT 100", prep_extensions_ms: prep_ext_time_delta, }; for result in results { - let download_size = result?; - remote_ext_metrics.num_ext_downloaded += 1; + let download_size = match result { + Ok(res) => { + remote_ext_metrics.num_ext_downloaded += 1; + res + } + Err(err) => { + // if we failed to download an extension, we don't want to fail the whole + // process, but we do want to log the error + error!("Failed to download extension: {}", err); + 0 + } + }; + remote_ext_metrics.largest_ext_size = std::cmp::max(remote_ext_metrics.largest_ext_size, download_size); remote_ext_metrics.total_ext_download_size += download_size; diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index af07412b52..7713d2bb51 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -141,6 +141,15 @@ async fn routes(req: Request, compute: &Arc) -> Response Response::new(Body::from("OK")), Err(e) => {