diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 3e558a7d3c..285be56264 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -1484,6 +1484,28 @@ LIMIT 100", info!("Pageserver config changed"); } } + + // Gather info about installed extensions + pub fn get_installed_extensions(&self) -> Result<()> { + let connstr = self.connstr.clone(); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to create runtime"); + let result = rt + .block_on(crate::installed_extensions::get_installed_extensions( + connstr, + )) + .expect("failed to get installed extensions"); + + info!( + "{}", + serde_json::to_string(&result).expect("failed to serialize extensions list") + ); + + Ok(()) + } } pub fn forward_termination_signal() { diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index fade3bbe6d..79e6158081 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -165,6 +165,32 @@ async fn routes(req: Request, compute: &Arc) -> Response { + info!("serving /installed_extensions GET request"); + let status = compute.get_status(); + if status != ComputeStatus::Running { + let msg = format!( + "invalid compute status for extensions request: {:?}", + status + ); + error!(msg); + return Response::new(Body::from(msg)); + } + + let connstr = compute.connstr.clone(); + let res = crate::installed_extensions::get_installed_extensions(connstr).await; + match res { + Ok(res) => render_json(Body::from(serde_json::to_string(&res).unwrap())), + Err(e) => render_json_error( + &format!("could not get list of installed extensions: {}", e), + StatusCode::INTERNAL_SERVER_ERROR, + ), + } + } + // download extension files from remote extension storage on demand (&Method::POST, route) if route.starts_with("/extension_server/") => { info!("serving {:?} POST request", route); diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index b0ddaeae2b..e9fa66b323 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -53,6 +53,20 @@ paths: schema: $ref: "#/components/schemas/ComputeInsights" + /installed_extensions: + get: + tags: + - Info + summary: Get installed extensions. + description: "" + operationId: getInstalledExtensions + responses: + 200: + description: List of installed extensions + content: + application/json: + schema: + $ref: "#/components/schemas/InstalledExtensions" /info: get: tags: @@ -395,6 +409,24 @@ components: - configuration example: running + InstalledExtensions: + type: object + properties: + extensions: + description: Contains list of installed extensions. + type: array + items: + type: object + properties: + extname: + type: string + versions: + type: array + items: + type: string + n_databases: + type: integer + # # Errors # diff --git a/compute_tools/src/installed_extensions.rs b/compute_tools/src/installed_extensions.rs new file mode 100644 index 0000000000..3d8b22a8a3 --- /dev/null +++ b/compute_tools/src/installed_extensions.rs @@ -0,0 +1,80 @@ +use compute_api::responses::{InstalledExtension, InstalledExtensions}; +use std::collections::HashMap; +use std::collections::HashSet; +use url::Url; + +use anyhow::Result; +use postgres::{Client, NoTls}; +use tokio::task; + +/// We don't reuse get_existing_dbs() just for code clarity +/// and to make database listing query here more explicit. +/// +/// Limit the number of databases to 500 to avoid excessive load. +fn list_dbs(client: &mut Client) -> Result> { + // `pg_database.datconnlimit = -2` means that the database is in the + // invalid state + let databases = client + .query( + "SELECT datname FROM pg_catalog.pg_database + WHERE datallowconn + AND datconnlimit <> - 2 + LIMIT 500", + &[], + )? + .iter() + .map(|row| { + let db: String = row.get("datname"); + db + }) + .collect(); + + Ok(databases) +} + +/// Connect to every database (see list_dbs above) and get the list of installed extensions. +/// Same extension can be installed in multiple databases with different versions, +/// we only keep the highest and lowest version across all databases. +pub async fn get_installed_extensions(connstr: Url) -> Result { + let mut connstr = connstr.clone(); + + task::spawn_blocking(move || { + let mut client = Client::connect(connstr.as_str(), NoTls)?; + let databases: Vec = list_dbs(&mut client)?; + + let mut extensions_map: HashMap = HashMap::new(); + for db in databases.iter() { + connstr.set_path(db); + let mut db_client = Client::connect(connstr.as_str(), NoTls)?; + let extensions: Vec<(String, String)> = db_client + .query( + "SELECT extname, extversion FROM pg_catalog.pg_extension;", + &[], + )? + .iter() + .map(|row| (row.get("extname"), row.get("extversion"))) + .collect(); + + for (extname, v) in extensions.iter() { + let version = v.to_string(); + extensions_map + .entry(extname.to_string()) + .and_modify(|e| { + e.versions.insert(version.clone()); + // count the number of databases where the extension is installed + e.n_databases += 1; + }) + .or_insert(InstalledExtension { + extname: extname.to_string(), + versions: HashSet::from([version.clone()]), + n_databases: 1, + }); + } + } + + Ok(InstalledExtensions { + extensions: extensions_map.values().cloned().collect(), + }) + }) + .await? +} diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs index 477f423aa2..d27ae58fa2 100644 --- a/compute_tools/src/lib.rs +++ b/compute_tools/src/lib.rs @@ -15,6 +15,7 @@ pub mod catalog; pub mod compute; pub mod disk_quota; pub mod extension_server; +pub mod installed_extensions; pub mod local_proxy; pub mod lsn_lease; mod migration; diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index 3f055b914a..5023fce003 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -1,5 +1,6 @@ //! Structs representing the JSON formats used in the compute_ctl's HTTP API. +use std::collections::HashSet; use std::fmt::Display; use chrono::{DateTime, Utc}; @@ -155,3 +156,15 @@ pub enum ControlPlaneComputeStatus { // should be able to start with provided spec. Attached, } + +#[derive(Clone, Debug, Default, Serialize)] +pub struct InstalledExtension { + pub extname: String, + pub versions: HashSet, + pub n_databases: u32, // Number of databases using this extension +} + +#[derive(Clone, Debug, Default, Serialize)] +pub struct InstalledExtensions { + pub extensions: Vec, +} diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index aedd711dbd..26895df8a6 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -23,3 +23,8 @@ class EndpointHttpClient(requests.Session): res = self.get(f"http://localhost:{self.port}/database_schema?database={database}") res.raise_for_status() return res.text + + def installed_extensions(self): + res = self.get(f"http://localhost:{self.port}/installed_extensions") + res.raise_for_status() + return res.json() diff --git a/test_runner/regress/test_installed_extensions.py b/test_runner/regress/test_installed_extensions.py new file mode 100644 index 0000000000..4700db85ee --- /dev/null +++ b/test_runner/regress/test_installed_extensions.py @@ -0,0 +1,87 @@ +from logging import info + +from fixtures.neon_fixtures import NeonEnv + + +def test_installed_extensions(neon_simple_env: NeonEnv): + """basic test for the endpoint that returns the list of installed extensions""" + + env = neon_simple_env + + env.create_branch("test_installed_extensions") + + endpoint = env.endpoints.create_start("test_installed_extensions") + + endpoint.safe_psql("CREATE DATABASE test_installed_extensions") + endpoint.safe_psql("CREATE DATABASE test_installed_extensions_2") + + client = endpoint.http_client() + res = client.installed_extensions() + + info("Extensions list: %s", res) + info("Extensions: %s", res["extensions"]) + # 'plpgsql' is a default extension that is always installed. + assert any( + ext["extname"] == "plpgsql" and ext["versions"] == ["1.0"] for ext in res["extensions"] + ), "The 'plpgsql' extension is missing" + + # check that the neon_test_utils extension is not installed + assert not any( + ext["extname"] == "neon_test_utils" for ext in res["extensions"] + ), "The 'neon_test_utils' extension is installed" + + pg_conn = endpoint.connect(dbname="test_installed_extensions") + with pg_conn.cursor() as cur: + cur.execute("CREATE EXTENSION neon_test_utils") + cur.execute( + "SELECT default_version FROM pg_available_extensions WHERE name = 'neon_test_utils'" + ) + res = cur.fetchone() + neon_test_utils_version = res[0] + + with pg_conn.cursor() as cur: + cur.execute("CREATE EXTENSION neon version '1.1'") + + pg_conn_2 = endpoint.connect(dbname="test_installed_extensions_2") + with pg_conn_2.cursor() as cur: + cur.execute("CREATE EXTENSION neon version '1.2'") + + res = client.installed_extensions() + + info("Extensions list: %s", res) + info("Extensions: %s", res["extensions"]) + + # check that the neon_test_utils extension is installed only in 1 database + # and has the expected version + assert any( + ext["extname"] == "neon_test_utils" + and ext["versions"] == [neon_test_utils_version] + and ext["n_databases"] == 1 + for ext in res["extensions"] + ) + + # check that the plpgsql extension is installed in all databases + # this is a default extension that is always installed + assert any(ext["extname"] == "plpgsql" and ext["n_databases"] == 4 for ext in res["extensions"]) + + # check that the neon extension is installed and has expected versions + for ext in res["extensions"]: + if ext["extname"] == "neon": + assert ext["n_databases"] == 2 + ext["versions"].sort() + assert ext["versions"] == ["1.1", "1.2"] + + with pg_conn.cursor() as cur: + cur.execute("ALTER EXTENSION neon UPDATE TO '1.3'") + + res = client.installed_extensions() + + info("Extensions list: %s", res) + info("Extensions: %s", res["extensions"]) + + # check that the neon_test_utils extension is updated + for ext in res["extensions"]: + if ext["extname"] == "neon": + assert ext["n_databases"] == 2 + ext["versions"].sort() + assert ext["versions"] == ["1.2", "1.3"]