diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 036dd39a49..95ddf20901 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -1368,32 +1368,27 @@ LIMIT 100", download_size } - pub fn install_extension(&self, ext_name: &str, db_name: &str) -> Result { - let mut conf = - Config::from_str(self.connstr.as_str()).context("Failed to parse connection string")?; + pub fn install_extension( + &self, + ext_name: &str, + db_name: &str, + ext_version: &str, + ) -> Result { + let mut conf = Config::from_str(self.connstr.as_str()).unwrap(); conf.dbname(db_name); let mut db_client = conf .connect(NoTls) .context("Failed to connect to the database")?; - let query = format!( - "CREATE EXTENSION IF NOT EXISTS {}", - ext_name.to_string().pg_quote() - ); - info!("creating extension with query: {}", query); - + let query = "CREATE EXTENSION IF NOT EXISTS $1 WITH VERSION $2"; db_client - .execute(&query, &[]) + .query(query, &[&ext_name, &ext_version]) .context(format!("Failed to execute query: {}", query))?; - let version_query = format!( - "SELECT extversion FROM pg_extension WHERE extname = '{}'", - ext_name.to_string().pg_quote() - ); - + let version_query = "SELECT extversion FROM pg_extension WHERE extname = $1"; let version: String = db_client - .query_one(&version_query, &[]) + .query_one(version_query, &[&ext_name]) .context(format!("Failed to execute query: {}", version_query))? .get(0); diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index 7542f5ef36..f92e260071 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -9,9 +9,10 @@ use crate::catalog::SchemaDumpError; use crate::catalog::{get_database_schema, get_dbs_and_roles}; use crate::compute::forward_termination_signal; use crate::compute::{ComputeNode, ComputeState, ParsedSpec}; -use compute_api::requests::ConfigurationRequest; -use compute_api::responses::ExtensionInstallResult; -use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError}; +use compute_api::requests::{ConfigurationRequest, ExtensionInstallRequest}; +use compute_api::responses::{ + ComputeStatus, ComputeStatusResponse, ExtensionInstallResult, GenericAPIError, +}; use anyhow::Result; use hyper::header::CONTENT_TYPE; @@ -111,15 +112,14 @@ async fn routes(req: Request, compute: &Arc) -> Response(&body).unwrap(); - let extension = body["extension"].as_str().unwrap(); - let database = body["database"].as_str().unwrap(); - let res = compute.install_extension(extension, database); + let request = hyper::body::to_bytes(req.into_body()).await.unwrap(); + let request = serde_json::from_slice::(&request).unwrap(); + let res = + compute.install_extension(&request.extension, &request.database, &request.version); match res { Ok(res) => render_json(Body::from( serde_json::to_string(&ExtensionInstallResult { - extension: extension.to_string(), + extension: request.extension, version: res, }) .unwrap(), diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index 9bf5aeebf1..90401ef2e5 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -10,7 +10,7 @@ paths: /status: get: tags: - - Info + - Info summary: Get compute node internal status. description: "" operationId: getComputeStatus @@ -25,7 +25,7 @@ paths: /metrics.json: get: tags: - - Info + - Info summary: Get compute node startup metrics in JSON format. description: "" operationId: getComputeMetricsJSON @@ -40,7 +40,7 @@ paths: /insights: get: tags: - - Info + - Info summary: Get current compute insights in JSON format. description: | Note, that this doesn't include any historical data. @@ -56,7 +56,7 @@ paths: /installed_extensions: get: tags: - - Info + - Info summary: Get installed extensions. description: "" operationId: getInstalledExtensions @@ -70,7 +70,7 @@ paths: /info: get: tags: - - Info + - Info summary: Get info about the compute pod / VM. description: "" operationId: getInfo @@ -130,7 +130,7 @@ paths: /check_writability: post: tags: - - Check + - Check summary: Check that we can write new data on this compute. description: "" operationId: checkComputeWritability @@ -147,7 +147,7 @@ paths: /extensions/install: post: tags: - - Extensions + - Extensions summary: Install extension if possible. description: "" operationId: installExtension @@ -157,18 +157,7 @@ paths: content: application/json: schema: - type: object - required: - - extension - - database - properties: - extension: - type: string - description: Extension name. - database: - type: string - description: Database name. - example: "neondb" + $ref: "#/components/schemas/ExtensionInstallRequest" responses: 200: description: Result from extension installation @@ -186,7 +175,7 @@ paths: /configure: post: tags: - - Configure + - Configure summary: Perform compute node configuration. description: | This is a blocking API endpoint, i.e. it blocks waiting until @@ -240,7 +229,7 @@ paths: /extension_server: post: tags: - - Extension + - Extension summary: Download extension from S3 to local folder. description: "" operationId: downloadExtension @@ -269,7 +258,7 @@ paths: /terminate: post: tags: - - Terminate + - Terminate summary: Terminate Postgres and wait for it to exit description: "" operationId: terminate @@ -408,7 +397,7 @@ components: moment, when spec was received. example: "2022-10-12T07:20:50.52Z" status: - $ref: '#/components/schemas/ComputeStatus' + $ref: "#/components/schemas/ComputeStatus" last_active: type: string description: | @@ -448,15 +437,37 @@ components: - configuration example: running + ExtensionInstallRequest: + type: object + required: + - extension + - database + - version + properties: + extension: + type: string + description: Extension name. + example: "pg_session_jwt" + version: + type: string + description: Version of the extension. + example: "1.0.0" + database: + type: string + description: Database name. + example: "neondb" + ExtensionInstallResult: type: object properties: extension: description: Name of the extension. type: string + example: "pg_session_jwt" version: description: Version of the extension. type: string + example: "1.0.0" InstalledExtensions: type: object diff --git a/libs/compute_api/src/requests.rs b/libs/compute_api/src/requests.rs index 5896c7dc65..34771834f3 100644 --- a/libs/compute_api/src/requests.rs +++ b/libs/compute_api/src/requests.rs @@ -12,3 +12,10 @@ use serde::Deserialize; pub struct ConfigurationRequest { pub spec: ComputeSpec, } + +#[derive(Deserialize, Debug)] +pub struct ExtensionInstallRequest { + pub extension: String, + pub database: String, + pub version: String, +}