diff --git a/Cargo.lock b/Cargo.lock index 961101b151..b1f53404ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1233,8 +1233,10 @@ dependencies = [ "serde_json", "signal-hook", "tar", + "thiserror", "tokio", "tokio-postgres", + "tokio-stream", "tokio-util", "toml_edit", "tracing", diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index 759a117ee9..8f96530a9d 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -27,10 +27,12 @@ reqwest = { workspace = true, features = ["json"] } tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tokio-postgres.workspace = true tokio-util.workspace = true +tokio-stream.workspace = true tracing.workspace = true tracing-opentelemetry.workspace = true tracing-subscriber.workspace = true tracing-utils.workspace = true +thiserror.workspace = true url.workspace = true compute_api.workspace = true diff --git a/compute_tools/src/catalog.rs b/compute_tools/src/catalog.rs new file mode 100644 index 0000000000..4fefa831e0 --- /dev/null +++ b/compute_tools/src/catalog.rs @@ -0,0 +1,116 @@ +use compute_api::{ + responses::CatalogObjects, + spec::{Database, Role}, +}; +use futures::Stream; +use postgres::{Client, NoTls}; +use std::{path::Path, process::Stdio, result::Result, sync::Arc}; +use tokio::{ + io::{AsyncBufReadExt, BufReader}, + process::Command, + task, +}; +use tokio_stream::{self as stream, StreamExt}; +use tokio_util::codec::{BytesCodec, FramedRead}; +use tracing::warn; + +use crate::{ + compute::ComputeNode, + pg_helpers::{get_existing_dbs, get_existing_roles}, +}; + +pub async fn get_dbs_and_roles(compute: &Arc) -> anyhow::Result { + let connstr = compute.connstr.clone(); + task::spawn_blocking(move || { + let mut client = Client::connect(connstr.as_str(), NoTls)?; + let roles: Vec; + { + let mut xact = client.transaction()?; + roles = get_existing_roles(&mut xact)?; + } + let databases: Vec = get_existing_dbs(&mut client)?.values().cloned().collect(); + + Ok(CatalogObjects { roles, databases }) + }) + .await? +} + +#[derive(Debug, thiserror::Error)] +pub enum SchemaDumpError { + #[error("Database does not exist.")] + DatabaseDoesNotExist, + #[error("Failed to execute pg_dump.")] + IO(#[from] std::io::Error), +} + +// It uses the pg_dump utility to dump the schema of the specified database. +// The output is streamed back to the caller and supposed to be streamed via HTTP. +// +// Before return the result with the output, it checks that pg_dump produced any output. +// If not, it tries to parse the stderr output to determine if the database does not exist +// and special error is returned. +// +// To make sure that the process is killed when the caller drops the stream, we use tokio kill_on_drop feature. +pub async fn get_database_schema( + compute: &Arc, + dbname: &str, +) -> Result>, SchemaDumpError> { + let pgbin = &compute.pgbin; + let basepath = Path::new(pgbin).parent().unwrap(); + let pgdump = basepath.join("pg_dump"); + let mut connstr = compute.connstr.clone(); + connstr.set_path(dbname); + let mut cmd = Command::new(pgdump) + .arg("--schema-only") + .arg(connstr.as_str()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .kill_on_drop(true) + .spawn()?; + + let stdout = cmd.stdout.take().ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stdout.") + })?; + + let stderr = cmd.stderr.take().ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::Other, "Failed to capture stderr.") + })?; + + let mut stdout_reader = FramedRead::new(stdout, BytesCodec::new()); + let stderr_reader = BufReader::new(stderr); + + let first_chunk = match stdout_reader.next().await { + Some(Ok(bytes)) if !bytes.is_empty() => bytes, + Some(Err(e)) => { + return Err(SchemaDumpError::IO(e)); + } + _ => { + let mut lines = stderr_reader.lines(); + if let Some(line) = lines.next_line().await? { + if line.contains(&format!("FATAL: database \"{}\" does not exist", dbname)) { + return Err(SchemaDumpError::DatabaseDoesNotExist); + } + warn!("pg_dump stderr: {}", line) + } + tokio::spawn(async move { + while let Ok(Some(line)) = lines.next_line().await { + warn!("pg_dump stderr: {}", line) + } + }); + + return Err(SchemaDumpError::IO(std::io::Error::new( + std::io::ErrorKind::Other, + "failed to start pg_dump", + ))); + } + }; + let initial_stream = stream::once(Ok(first_chunk.freeze())); + // Consume stderr and log warnings + tokio::spawn(async move { + let mut lines = stderr_reader.lines(); + while let Ok(Some(line)) = lines.next_line().await { + warn!("pg_dump stderr: {}", line) + } + }); + Ok(initial_stream.chain(stdout_reader.map(|res| res.map(|b| b.freeze())))) +} diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index 128783b477..0286429cf2 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -5,17 +5,21 @@ use std::net::SocketAddr; use std::sync::Arc; use std::thread; +use crate::catalog::SchemaDumpError; +use crate::catalog::{get_database_schema, get_dbs_and_roles}; use crate::compute::forward_termination_signal; use crate::compute::{ComputeNode, ComputeState, ParsedSpec}; use compute_api::requests::ConfigurationRequest; use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError}; use anyhow::Result; +use hyper::header::CONTENT_TYPE; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Method, Request, Response, Server, StatusCode}; use tokio::task; use tracing::{error, info, warn}; use tracing_utils::http::OtelName; +use utils::http::request::must_get_query_param; fn status_response_from_state(state: &ComputeState) -> ComputeStatusResponse { ComputeStatusResponse { @@ -133,6 +137,34 @@ async fn routes(req: Request, compute: &Arc) -> Response { + info!("serving /dbs_and_roles GET request",); + match get_dbs_and_roles(compute).await { + Ok(res) => render_json(Body::from(serde_json::to_string(&res).unwrap())), + Err(_) => { + render_json_error("can't get dbs and roles", StatusCode::INTERNAL_SERVER_ERROR) + } + } + } + + (&Method::GET, "/database_schema") => { + let database = match must_get_query_param(&req, "database") { + Err(e) => return e.into_response(), + Ok(database) => database, + }; + info!("serving /database_schema GET request with database: {database}",); + match get_database_schema(compute, &database).await { + Ok(res) => render_plain(Body::wrap_stream(res)), + Err(SchemaDumpError::DatabaseDoesNotExist) => { + render_json_error("database does not exist", StatusCode::NOT_FOUND) + } + Err(e) => { + error!("can't get schema dump: {}", e); + render_json_error("can't get schema dump", StatusCode::INTERNAL_SERVER_ERROR) + } + } + } + // download extension files from remote extension storage on demand (&Method::POST, route) if route.starts_with("/extension_server/") => { info!("serving {:?} POST request", route); @@ -303,10 +335,25 @@ fn render_json_error(e: &str, status: StatusCode) -> Response { }; Response::builder() .status(status) + .header(CONTENT_TYPE, "application/json") .body(Body::from(serde_json::to_string(&error).unwrap())) .unwrap() } +fn render_json(body: Body) -> Response { + Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(body) + .unwrap() +} + +fn render_plain(body: Body) -> Response { + Response::builder() + .header(CONTENT_TYPE, "text/plain") + .body(body) + .unwrap() +} + async fn handle_terminate_request(compute: &Arc) -> Result<(), (String, StatusCode)> { { let mut state = compute.state.lock().unwrap(); diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index d2ec54299f..b0ddaeae2b 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -68,6 +68,51 @@ paths: schema: $ref: "#/components/schemas/Info" + /dbs_and_roles: + get: + tags: + - Info + summary: Get databases and roles in the catalog. + description: "" + operationId: getDbsAndRoles + responses: + 200: + description: Compute schema objects + content: + application/json: + schema: + $ref: "#/components/schemas/DbsAndRoles" + + /database_schema: + get: + tags: + - Info + summary: Get schema dump + parameters: + - name: database + in: query + description: Database name to dump. + required: true + schema: + type: string + example: "postgres" + description: Get schema dump in SQL format. + operationId: getDatabaseSchema + responses: + 200: + description: Schema dump + content: + text/plain: + schema: + type: string + description: Schema dump in SQL format. + 404: + description: Non existing database. + content: + application/json: + schema: + $ref: "#/components/schemas/GenericError" + /check_writability: post: tags: @@ -229,6 +274,73 @@ components: num_cpus: type: integer + DbsAndRoles: + type: object + description: Databases and Roles + required: + - roles + - databases + properties: + roles: + type: array + items: + $ref: "#/components/schemas/Role" + databases: + type: array + items: + $ref: "#/components/schemas/Database" + + Database: + type: object + description: Database + required: + - name + - owner + - restrict_conn + - invalid + properties: + name: + type: string + owner: + type: string + options: + type: array + items: + $ref: "#/components/schemas/GenericOption" + restrict_conn: + type: boolean + invalid: + type: boolean + + Role: + type: object + description: Role + required: + - name + properties: + name: + type: string + encrypted_password: + type: string + options: + type: array + items: + $ref: "#/components/schemas/GenericOption" + + GenericOption: + type: object + description: Schema Generic option + required: + - name + - vartype + properties: + name: + type: string + value: + type: string + vartype: + type: string + ComputeState: type: object required: diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs index eac808385c..18c228ba54 100644 --- a/compute_tools/src/lib.rs +++ b/compute_tools/src/lib.rs @@ -8,6 +8,7 @@ pub mod configurator; pub mod http; #[macro_use] pub mod logger; +pub mod catalog; pub mod compute; pub mod extension_server; pub mod monitor; diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index fd0c90d447..d05d625b0a 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -3,7 +3,7 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize, Serializer}; -use crate::spec::ComputeSpec; +use crate::spec::{ComputeSpec, Database, Role}; #[derive(Serialize, Debug, Deserialize)] pub struct GenericAPIError { @@ -113,6 +113,12 @@ pub struct ComputeMetrics { pub total_ext_download_size: u64, } +#[derive(Clone, Debug, Default, Serialize)] +pub struct CatalogObjects { + pub roles: Vec, + pub databases: Vec, +} + /// Response of the `/computes/{compute_id}/spec` control-plane API. /// This is not actually a compute API response, so consider moving /// to a different place. diff --git a/test_runner/fixtures/endpoint/__init__.py b/test_runner/fixtures/endpoint/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py new file mode 100644 index 0000000000..42f0539c19 --- /dev/null +++ b/test_runner/fixtures/endpoint/http.py @@ -0,0 +1,23 @@ +import requests +from requests.adapters import HTTPAdapter + + +class EndpointHttpClient(requests.Session): + def __init__( + self, + port: int, + ): + super().__init__() + self.port = port + + self.mount("http://", HTTPAdapter()) + + def dbs_and_roles(self): + res = self.get(f"http://localhost:{self.port}/dbs_and_roles") + res.raise_for_status() + return res.json() + + def database_schema(self, database: str): + res = self.get(f"http://localhost:{self.port}/database_schema?database={database}") + res.raise_for_status() + return res.text diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index a6fd4792dd..b4761f103b 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -48,6 +48,7 @@ from urllib3.util.retry import Retry from fixtures import overlayfs from fixtures.broker import NeonBroker from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId +from fixtures.endpoint.http import EndpointHttpClient from fixtures.log_helper import log from fixtures.metrics import Metrics, MetricsGetter, parse_metrics from fixtures.pageserver.allowed_errors import ( @@ -3373,6 +3374,13 @@ class Endpoint(PgProtocol): self.active_safekeepers: List[int] = list(map(lambda sk: sk.id, env.safekeepers)) # path to conf is /endpoints//pgdata/postgresql.conf + def http_client( + self, auth_token: Optional[str] = None, retries: Optional[Retry] = None + ) -> EndpointHttpClient: + return EndpointHttpClient( + port=self.http_port, + ) + def create( self, branch_name: str, diff --git a/test_runner/regress/test_compute_catalog.py b/test_runner/regress/test_compute_catalog.py new file mode 100644 index 0000000000..dd36190fcd --- /dev/null +++ b/test_runner/regress/test_compute_catalog.py @@ -0,0 +1,34 @@ +import requests +from fixtures.neon_fixtures import NeonEnv + + +def test_compute_catalog(neon_simple_env: NeonEnv): + env = neon_simple_env + env.neon_cli.create_branch("test_config", "empty") + + endpoint = env.endpoints.create_start("test_config", config_lines=["log_min_messages=debug1"]) + client = endpoint.http_client() + + objects = client.dbs_and_roles() + + # Assert that 'cloud_admin' role exists in the 'roles' list + assert any( + role["name"] == "cloud_admin" for role in objects["roles"] + ), "The 'cloud_admin' role is missing" + + # Assert that 'postgres' database exists in the 'databases' list + assert any( + db["name"] == "postgres" for db in objects["databases"] + ), "The 'postgres' database is missing" + + ddl = client.database_schema(database="postgres") + + assert "-- PostgreSQL database dump" in ddl + + try: + client.database_schema(database="nonexistentdb") + raise AssertionError("Expected HTTPError was not raised") + except requests.exceptions.HTTPError as e: + assert ( + e.response.status_code == 404 + ), f"Expected 404 status code, but got {e.response.status_code}"