diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 11fee73f03..c9dd4dcfc5 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -28,7 +28,7 @@ use utils::lsn::Lsn; use compute_api::privilege::Privilege; use compute_api::responses::{ComputeMetrics, ComputeStatus}; -use compute_api::spec::{ComputeFeature, ComputeMode, ComputeSpec}; +use compute_api::spec::{ComputeFeature, ComputeMode, ComputeSpec, ExtVersion}; use utils::measured_stream::MeasuredReader; use nix::sys::signal::{kill, Signal}; @@ -1416,6 +1416,56 @@ LIMIT 100", Ok(()) } + pub async fn install_extension( + &self, + ext_name: &PgIdent, + db_name: &PgIdent, + ext_version: ExtVersion, + ) -> Result { + use tokio_postgres::config::Config; + use tokio_postgres::NoTls; + + let mut conf = Config::from_str(self.connstr.as_str()).unwrap(); + conf.dbname(db_name); + + let (db_client, conn) = conf + .connect(NoTls) + .await + .context("Failed to connect to the database")?; + tokio::spawn(conn); + + let version_query = "SELECT extversion FROM pg_extension WHERE extname = $1"; + let version: Option = db_client + .query_opt(version_query, &[&ext_name]) + .await + .with_context(|| format!("Failed to execute query: {}", version_query))? + .map(|row| row.get(0)); + + // sanitize the inputs as postgres idents. + let ext_name: String = ext_name.pg_quote(); + let quoted_version: String = ext_version.pg_quote(); + + if let Some(installed_version) = version { + if installed_version == ext_version { + return Ok(installed_version); + } + let query = format!("ALTER EXTENSION {ext_name} UPDATE TO {quoted_version}"); + db_client + .simple_query(&query) + .await + .with_context(|| format!("Failed to execute query: {}", query))?; + } else { + let query = + format!("CREATE EXTENSION IF NOT EXISTS {ext_name} WITH VERSION {quoted_version}"); + db_client + .simple_query(&query) + .await + .with_context(|| format!("Failed to execute query: {}", query))?; + } + + Ok(ext_version) + } + #[tokio::main] pub async fn prepare_preload_libraries( &self, diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index 133ab9f5af..af35f71bf2 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -9,9 +9,10 @@ use crate::catalog::SchemaDumpError; use crate::catalog::{get_database_schema, get_dbs_and_roles}; use crate::compute::forward_termination_signal; use crate::compute::{ComputeNode, ComputeState, ParsedSpec}; -use compute_api::requests::{ConfigurationRequest, SetRoleGrantsRequest}; +use compute_api::requests::{ConfigurationRequest, ExtensionInstallRequest, SetRoleGrantsRequest}; use compute_api::responses::{ - ComputeStatus, ComputeStatusResponse, GenericAPIError, SetRoleGrantsResponse, + ComputeStatus, ComputeStatusResponse, ExtensionInstallResult, GenericAPIError, + SetRoleGrantsResponse, }; use anyhow::Result; @@ -100,6 +101,38 @@ async fn routes(req: Request, compute: &Arc) -> Response { + info!("serving /extensions POST request"); + let status = compute.get_status(); + if status != ComputeStatus::Running { + let msg = format!( + "invalid compute status for extensions request: {:?}", + status + ); + error!(msg); + return render_json_error(&msg, StatusCode::PRECONDITION_FAILED); + } + + let request = hyper::body::to_bytes(req.into_body()).await.unwrap(); + let request = serde_json::from_slice::(&request).unwrap(); + let res = compute + .install_extension(&request.extension, &request.database, request.version) + .await; + match res { + Ok(version) => render_json(Body::from( + serde_json::to_string(&ExtensionInstallResult { + extension: request.extension, + version, + }) + .unwrap(), + )), + Err(e) => { + error!("install_extension failed: {}", e); + render_json_error(&e.to_string(), StatusCode::INTERNAL_SERVER_ERROR) + } + } + } + (&Method::GET, "/info") => { let num_cpus = num_cpus::get_physical(); info!("serving /info GET request. num_cpus: {}", num_cpus); diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index 73dbdc3ee9..11eee6ccfd 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -179,6 +179,41 @@ paths: description: Error text or 'true' if check passed. example: "true" + /extensions: + post: + tags: + - Extensions + summary: Install extension if possible. + description: "" + operationId: installExtension + requestBody: + description: Extension name and database to install it to. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ExtensionInstallRequest" + responses: + 200: + description: Result from extension installation + content: + application/json: + schema: + $ref: "#/components/schemas/ExtensionInstallResult" + 412: + description: | + Compute is in the wrong state for processing the request. + content: + application/json: + schema: + $ref: "#/components/schemas/GenericError" + 500: + description: Error during extension installation. + content: + application/json: + schema: + $ref: "#/components/schemas/GenericError" + /configure: post: tags: @@ -404,7 +439,7 @@ components: moment, when spec was received. example: "2022-10-12T07:20:50.52Z" status: - $ref: '#/components/schemas/ComputeStatus' + $ref: "#/components/schemas/ComputeStatus" last_active: type: string description: | @@ -444,6 +479,38 @@ components: - configuration example: running + ExtensionInstallRequest: + type: object + required: + - extension + - database + - version + properties: + extension: + type: string + description: Extension name. + example: "pg_session_jwt" + version: + type: string + description: Version of the extension. + example: "1.0.0" + database: + type: string + description: Database name. + example: "neondb" + + ExtensionInstallResult: + type: object + properties: + extension: + description: Name of the extension. + type: string + example: "pg_session_jwt" + version: + description: Version of the extension. + type: string + example: "1.0.0" + InstalledExtensions: type: object properties: diff --git a/libs/compute_api/src/requests.rs b/libs/compute_api/src/requests.rs index fbc7577dd9..fc3757d981 100644 --- a/libs/compute_api/src/requests.rs +++ b/libs/compute_api/src/requests.rs @@ -1,8 +1,7 @@ //! Structs representing the JSON formats used in the compute_ctl's HTTP API. - use crate::{ privilege::Privilege, - spec::{ComputeSpec, PgIdent}, + spec::{ComputeSpec, ExtVersion, PgIdent}, }; use serde::Deserialize; @@ -16,6 +15,13 @@ pub struct ConfigurationRequest { pub spec: ComputeSpec, } +#[derive(Deserialize, Debug)] +pub struct ExtensionInstallRequest { + pub extension: PgIdent, + pub database: PgIdent, + pub version: ExtVersion, +} + #[derive(Deserialize, Debug)] pub struct SetRoleGrantsRequest { pub database: PgIdent, diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index fadf524273..79234be720 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize, Serializer}; use crate::{ privilege::Privilege, - spec::{ComputeSpec, Database, PgIdent, Role}, + spec::{ComputeSpec, Database, ExtVersion, PgIdent, Role}, }; #[derive(Serialize, Debug, Deserialize)] @@ -172,6 +172,11 @@ pub struct InstalledExtensions { pub extensions: Vec, } +#[derive(Clone, Debug, Default, Serialize)] +pub struct ExtensionInstallResult { + pub extension: PgIdent, + pub version: ExtVersion, +} #[derive(Clone, Debug, Default, Serialize)] pub struct SetRoleGrantsResponse { pub database: PgIdent, diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 5903db7055..8a447563dc 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -16,6 +16,9 @@ use remote_storage::RemotePath; /// intended to be used for DB / role names. pub type PgIdent = String; +/// String type alias representing Postgres extension version +pub type ExtVersion = String; + /// Cluster spec or configuration represented as an optional number of /// delta operations + final cluster state description. #[derive(Clone, Debug, Default, Deserialize, Serialize)] diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index e7b014b4a9..ea8291c1e0 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -29,6 +29,16 @@ class EndpointHttpClient(requests.Session): res.raise_for_status() return res.json() + def extensions(self, extension: str, version: str, database: str): + body = { + "extension": extension, + "version": version, + "database": database, + } + res = self.post(f"http://localhost:{self.port}/extensions", json=body) + res.raise_for_status() + return res.json() + def set_role_grants(self, database: str, role: str, schema: str, privileges: list[str]): res = self.post( f"http://localhost:{self.port}/grants", diff --git a/test_runner/regress/test_extensions.py b/test_runner/regress/test_extensions.py new file mode 100644 index 0000000000..100fd4b048 --- /dev/null +++ b/test_runner/regress/test_extensions.py @@ -0,0 +1,50 @@ +from logging import info + +from fixtures.neon_fixtures import NeonEnv + + +def test_extensions(neon_simple_env: NeonEnv): + """basic test for the extensions endpoint testing installing extensions""" + + env = neon_simple_env + + env.create_branch("test_extensions") + + endpoint = env.endpoints.create_start("test_extensions") + extension = "neon_test_utils" + database = "test_extensions" + + endpoint.safe_psql("CREATE DATABASE test_extensions") + + with endpoint.connect(dbname=database) as pg_conn: + with pg_conn.cursor() as cur: + cur.execute( + "SELECT default_version FROM pg_available_extensions WHERE name = 'neon_test_utils'" + ) + res = cur.fetchone() + assert res is not None + version = res[0] + + with pg_conn.cursor() as cur: + cur.execute( + "SELECT extname, extversion FROM pg_extension WHERE extname = 'neon_test_utils'", + ) + res = cur.fetchone() + assert not res, "The 'neon_test_utils' extension is installed" + + client = endpoint.http_client() + install_res = client.extensions(extension, version, database) + + info("Extension install result: %s", res) + assert install_res["extension"] == extension and install_res["version"] == version + + with endpoint.connect(dbname=database) as pg_conn: + with pg_conn.cursor() as cur: + cur.execute( + "SELECT extname, extversion FROM pg_extension WHERE extname = 'neon_test_utils'", + ) + res = cur.fetchone() + assert res is not None + (db_extension_name, db_extension_version) = res + + assert db_extension_name == extension and db_extension_version == version