From fae734b15c31f889077949f7e9088f99d63fb096 Mon Sep 17 00:00:00 2001 From: William Huang Date: Tue, 13 May 2025 22:42:31 -0700 Subject: [PATCH] [BRC-3082] Monitor commit LSN lag among active SKs (#1002) Commit https://github.com/databricks-eng/hadron/commit/e69c3d632b3919e79c2a1efb8f51d2230f0e137a added metrics (used for alerting) to indicate whether Safekeepers are operating with a degraded quorum due to some of them being down. However, even if all SKs are active/reachable, we probably still want to raise an alert if some of them are really slow or otherwise lagging behind, as it is technically still a "degraded quorum" situation. Added a new field `max_active_safekeeper_commit_lag` to the `neon_perf_counters` view that reports the lag between the most advanced and most lagging commit LSNs among active Safekeepers. Commit LSNs are received from `AppendResponse` messages from SKs and recorded in the `WalProposer`'s shared memory state. Note that this lag is calculated among active SKs only to keep this alert clean. If there are inactive SKs the previous metric on active SKs will capture that instead. Note: @chen-luo_data pointed out during the PR review that we can probably get the benefits of this metric with PromQL query like `max (safekeeper_commit_lsn) by (tenant_id, timeline_id) - min(safekeeper_commit_lsn) by (tenant_id, timeline_id)` on existing metrics exported by SKs. Given that this code is already ready, @haoyu-huang_data suggested that I just check in this change anyway, as the reliability of prometheus metrics (and especially the aggregation operators when the result set cardinality is high) is somewhat questionable based on our prior experience. Added integration test `test_wal_acceptor.py::test_max_active_safekeeper_commit_lag`. --- kind/endpoint_tests.py | 499 +++ libs/walproposer/src/api_bindings.rs | 1 + pgxn/neon/neon_perf_counters.c | 49 +- pgxn/neon/walproposer.h | 2 + pgxn/neon/walproposer_pg.c | 11 + safekeeper/src/receive_wal.rs | 3 + scripts/neon_grep.txt | 13 + storage_controller/src/hadron_k8s.rs | 4906 ++++++++++++++++++++++ test_runner/regress/test_wal_acceptor.py | 84 + 9 files changed, 5553 insertions(+), 15 deletions(-) create mode 100644 kind/endpoint_tests.py create mode 100644 storage_controller/src/hadron_k8s.rs diff --git a/kind/endpoint_tests.py b/kind/endpoint_tests.py new file mode 100644 index 0000000000..72947d2625 --- /dev/null +++ b/kind/endpoint_tests.py @@ -0,0 +1,499 @@ +import json +import logging +import re +import time +import uuid +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import backoff +import jwt +import psycopg2 +import requests +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from kind_test_env import ( + HADRON_COMPUTE_IMAGE_NAME, + HADRON_MT_IMAGE_NAME, + HadronEndpoint, + KindTestEnvironment, + NodeType, + check_prerequisites, + read_table_data, + unique_node_id, + write_table_data, +) + +logging.getLogger("backoff").addHandler(logging.StreamHandler()) + + +def read_private_key_and_public_key( + privkey_filename: str, certificate_path: str +) -> tuple[str, str]: + # Get certificate serial number + with open(certificate_path, "rb") as pem_file: + content = pem_file.read() + certificate = x509.load_pem_x509_certificate(content, default_backend()) + certificate_thumbprint = certificate.fingerprint(hashes.SHA256()).hex() + + # Get private key content string. 
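    # (Editorial aside, not part of this patch.) The SHA-256 thumbprint computed above is what
    # generate_token() below places in the JWT header as `kid`, so a verifier can select the
    # matching certificate. A minimal round-trip check with the same pyjwt library, assuming
    # `public_key` has been extracted from pubkey1.pem via certificate.public_key():
    #     header = jwt.get_unverified_header(token)
    #     assert header["alg"] == "RS256" and header["kid"] == certificate_thumbprint
    #     claims = jwt.decode(token, public_key, algorithms=["RS256"])
    #     assert claims["iss"] == "brickstore.databricks.com" and "endpointId" in claims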
+ private_key_content = Path(privkey_filename).read_text() + return (private_key_content, certificate_thumbprint) + + +def generate_token( + testuser: str, + endpoint_id: str, + database: str, + cert_directory: str, + payload: dict[str, Any] | None = None, +) -> str: + """ + Generate a JWT token for a testuser using the private key specified in the environment. + + :param testuser: user name to generate the token for. + :param endpoint_id: hadron endpoint id. + :param database: database to connect to. + :param cert_directory: directory containing the private key and public key for generating the token. + :param payload: additional payload. It will be merged with the default payload. + :return: + """ + (private_key, certificate_thumbprint) = read_private_key_and_public_key( + f"{cert_directory}/privkey1.pem", f"{cert_directory}/pubkey1.pem" + ) + expiration = time.time() + (10 * 60) # Expiration time is 10 minutes. + + payload = { + "sub": testuser, + "iss": "brickstore.databricks.com", + "exp": int(expiration), + "endpointId": endpoint_id, + "database": database, + } | (payload or {}) + + token = jwt.encode( + payload, private_key, algorithm="RS256", headers={"kid": certificate_thumbprint} + ) + return token + + +def check_token_login( + endpoint: HadronEndpoint, endpoint_id: str, database: str, cert_directory: str +) -> None: + """ + Check that we can login to the endpoint using a JWT token. + """ + # Create test_token_user + testuser = "test_token_user" + with endpoint.cursor() as cur: + cur.execute(f"CREATE ROLE {testuser} LOGIN PASSWORD NULL;") + + # Login with the token + token = generate_token(testuser, endpoint_id, database, cert_directory) + with endpoint.cursor(user=testuser, password=token) as c: + c.execute("select current_user") + result = c.fetchone() + assert result == (testuser,) + + +def check_databricks_roles(endpoint: HadronEndpoint) -> None: + """ + Check that the expected Databricks roles are present in the endpoint. + """ + expected_roles = ["databricks_monitor", "databricks_control_plane", "databricks_gateway"] + with endpoint.cursor() as cur: + for role in expected_roles: + cur.execute(f"SELECT 1 FROM pg_roles WHERE rolname = '{role}';") + result = cur.fetchone() + assert result == (1,) + + +def test_create_endpoint_and_connect() -> None: + """ + Tests that we can create an endpoint on a Hadron deployment and connect to it/run simple queries. + """ + + with KindTestEnvironment() as env: + env.load_image(HADRON_MT_IMAGE_NAME) + env.load_image(HADRON_COMPUTE_IMAGE_NAME) + + # Setup the Hadron deployment in the brand new KIND cluster. + # We have 2 PS, 3 SK, and mocked S3 storage in the test. + env.start_hcc_with_configmaps( + "configmaps/configmap_2ps_3sk.yaml", "configmaps/safe_configmap_1.yaml" + ) + + ps_id_0 = unique_node_id(0, 0) + ps_id_1 = unique_node_id(0, 1) + + # Wait for all the Hadron storage components to come up. 
+ env.wait_for_app_ready("storage-controller", namespace="hadron") + env.wait_for_statefulset_replicas("safe-keeper-0", replicas=3, namespace="hadron") + env.wait_for_statefulset_replicas("page-server-0", replicas=2, namespace="hadron") + env.wait_for_node_registration(NodeType.PAGE_SERVER, {ps_id_0, ps_id_1}) + env.wait_for_node_registration( + NodeType.SAFE_KEEPER, {unique_node_id(0, 0), unique_node_id(0, 1), unique_node_id(0, 2)} + ) + + @backoff.on_exception(backoff.expo, psycopg2.Error, max_tries=10) + def check_superuser_and_basic_data_operation(endpoint): + with endpoint.cursor() as cur: + # Check whether the current user is the super user. + cur.execute("SELECT usesuper FROM pg_user WHERE usename = CURRENT_USER") + is_superuser = cur.fetchone()[0] + + # Create a simple table and insert some data. + cur.execute("DROP TABLE IF EXISTS t") + cur.execute("CREATE TABLE t (x int)") + cur.execute("INSERT INTO t VALUES (1), (2), (3)") + cur.execute("SELECT * FROM t") + rows = cur.fetchall() + + # Test that the HCC-created default user was indeed the super user + # and that we can retrieve the same data we inserted from the table. + assert is_superuser is True, "Current user is not a superuser" + assert rows == [(1,), (2,), (3,)], f"Data retrieval mismatch: {rows}" + + # Verify that the server has ssl turn on . + cur.execute("SHOW SSL;") + result = cur.fetchone() + assert result == ("on",) + + # Verify that the connection is using SSL. + cur.execute("SELECT SSL FROM pg_stat_ssl WHERE pid = pg_backend_pid();") + result = cur.fetchone() + assert result == (True,) + + @backoff.on_exception(backoff.expo, psycopg2.Error, max_tries=10) + def check_databricks_system_tables(endpoint): + with endpoint.cursor(dbname="databricks_system") as cur: + # Verify that the LFC is working by querying the LFC stats table. + # If the LFC is not running, the table will contain a single row with all NULL values. + cur.execute( + "SELECT 1 FROM neon.NEON_STAT_FILE_CACHE WHERE file_cache_misses IS NOT NULL;" + ) + lfcStatsRows = cur.fetchall() + + assert len(lfcStatsRows) == 1, "LFC stats table is empty" + + # Check that the system-level GUCs are set to the expected values. These should be set before the endpoint + # starts accepting connections. 
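        # (Editorial aside, not part of this patch.) check_guc_values below compares the GUC value
        # against urlparse(test_workspace_url).hostname; for the URL used by this test that is
        #     urlparse("https://test-workspace-url/").hostname == "test-workspace-url"
        # i.e. the scheme and trailing slash are stripped before the comparison.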
+ def check_guc_values(endpoint): + with endpoint.cursor() as cur: + cur.execute("SHOW databricks.workspace_url;") + res = cur.fetchone()[0] + print(f"atabricks.workspace_url: {res}") + assert ( + res == urlparse(test_workspace_url).hostname + ), "Failed to get the correct databricks.workspace_url GUC value" + cur.execute("SHOW databricks.enable_databricks_identity_login;") + res = cur.fetchone()[0] + print(f"databricks.enable_databricks_identity_login: {res}") + assert ( + res == "on" + ), "Failed to get the correct databricks.enable_databricks_identity_login GUC value" + cur.execute("SHOW databricks.enable_sql_restrictions;") + res = cur.fetchone()[0] + print(f"databricks.enable_sql_restrictions: {res}") + assert ( + res == "on" + ), "Failed to get the correct databricks.enable_sql_restrictions GUC value" + cur.execute("SHOW databricks.disable_PAT_login;") + res = cur.fetchone()[0] + print(f"databricks.disable_PAT_login: {res}") + assert ( + res == "on" + ), "Failed to get the correct databricks.disable_PAT_login GUC value" + + def check_cert_auth_user(endpoint): + expected_user = "databricks_control_plane" + with endpoint.cursor( + user=expected_user, + sslcert=f"{env.tempDir.name}/pubkey1.pem", + sslkey=f"{env.tempDir.name}/privkey1.pem", + sslmode="require", + ) as cur: + cur.execute("select current_user;") + current_user = cur.fetchone()[0] + assert current_user == expected_user, f"{current_user} is not {expected_user}" + + # Query the "neon.pageserver_connstring" Postgres GUC to see which pageserver the compute node is currently connected to. + def check_current_ps_id(endpoint: HadronEndpoint) -> int: + with endpoint.cursor() as cur: + cur.execute("SHOW neon.pageserver_connstring;") + res = cur.fetchone() + assert res is not None, "Failed to get the current pageserver connection URL" + connection_url = res[0] + print(f"Current pageserver connection URL is {connection_url}") + host = urlparse(connection_url).hostname + # In this test, the hostname is in the form of "page-server-{pool}-{ordinal}.page-server.hadron.svc.cluster.local" + # We extract the "page-server-{pool}-{ordinal}" part and convert the two numbers into the pageserver ID. + pool_id, ordinal_id = host.split(".")[0].split("-")[-2:] + return unique_node_id(int(pool_id), int(ordinal_id)) + + def verify_compute_pod_metadata(pod_name: str, pod_namespace: str, endpoint_id: uuid.UUID): + # Check that dblet-required labels and annotations are present on the compute pod. + # See go/dblet-labels + compute_pod = env.kubectl_get_pod(namespace=pod_namespace, pod_name=pod_name) + assert ( + compute_pod["metadata"]["annotations"]["databricks.com/workspace-url"] + == urlparse(test_workspace_url).hostname + ) + assert compute_pod["metadata"]["labels"]["orgId"] == test_workspace_id + assert compute_pod["metadata"]["labels"]["dblet.dev/appid"] == f"{endpoint_id}-0" + + def check_pg_log_redaction(pod_name: str, container_name: str, pod_namespace: str): + """ + Various checks to ensure that the PG log redactor is working as expected via comparing + PG log vs. redacted log. 
+ + Checks between original from PG and redacted log: + - log folders exist + - there's at least 1 log file + - the number of files match + - number of log entries is close + - the last redacted log entry is in the last few PG log entries, ignoring the + redacted message field + """ + # a little higher than redactor flush entries + lag_tolerance_items = 22 + MESSAGE_FIELD = "message" + LOG_DAEMON_EXPECTED_REGEX = r"hadron-compute-redacted-[a-zA-Z]{3}[0-9]{4}\.json" + redactor_file_count_catchup_timeout_seconds = 60 + + def kex(command: list[str]) -> str: + # mypy can't tell kubectl_exec returns str + result: str = env.kubectl_exec(pod_namespace, pod_name, container_name, command) + return result + + log_folders = kex(["ls", "/databricks/logs/"]).split() + assert "brickstore" in log_folders, "PG log folder not found" + assert "brickstore-redacted" in log_folders, "Redacted log folder not found" + + @backoff.on_exception( + backoff.expo, AssertionError, max_time=redactor_file_count_catchup_timeout_seconds + ) + def get_caught_up_files() -> tuple[list[str], list[str]]: + """ + Get pg (input) and redacted (output) log files after verifying the files exist and the counts match. + + @return: tuple of: + - list of pg log file names + - list of redacted log file names + """ + pg_log_files = kex(["ls", "-t", "/databricks/logs/brickstore/"]).split() + redacted_log_files = kex( + ["ls", "-t", "/databricks/logs/brickstore-redacted/"] + ).split() + pg_log_files = [file for file in pg_log_files if ".json" in file] + print("Compute log files:", pg_log_files, redacted_log_files) + assert len(pg_log_files) > 0, "PG didn't produce any JSON log files" + assert len(redacted_log_files) > 0, "Redactor didn't produce any log files" + assert len(pg_log_files) == len( + redacted_log_files + ), "Redactor didn't process each log file exactly once" + for file in redacted_log_files: + assert re.match( + LOG_DAEMON_EXPECTED_REGEX, file + ), f"Unexpected redacted log file name: {file}" + return pg_log_files, redacted_log_files + + # wait for pg_log_redactor to catch up, by file count + pg_log_files, redacted_log_files = get_caught_up_files() + + # Rest will examine latest files closer + last_pg_log_file = pg_log_files[0] + last_redacted_log_file = redacted_log_files[0] + + pg_log_entries_num = int( + kex(["wc", "-l", f"/databricks/logs/brickstore/{last_pg_log_file}"]).split()[0] + ) + redacted_log_entries_num = int( + kex( + ["wc", "-l", f"/databricks/logs/brickstore-redacted/{last_redacted_log_file}"] + ).split()[0] + ) + assert ( + redacted_log_entries_num <= pg_log_entries_num + ), "Redactor emitted non-PG log messages, either through bug or own error msg." + assert ( + redacted_log_entries_num - pg_log_entries_num < lag_tolerance_items + ), "Redactor lagged behind, more than OS buffering should allow for." 
+ + # Order to decrease chance of lag flakiness + pg_log_tail = kex( + [ + "tail", + "-n", + str(lag_tolerance_items), + f"/databricks/logs/brickstore/{last_pg_log_file}", + ] + ) + redacted_log_tail_item = kex( + [ + "tail", + "-n", + "1", + f"/databricks/logs/brickstore-redacted/{last_redacted_log_file}", + ] + ) + + redacted_log_tail_json = json.loads(redacted_log_tail_item) + if MESSAGE_FIELD in redacted_log_tail_json: + del redacted_log_tail_json[MESSAGE_FIELD] + found_in_pg_log = False + for pg_log_item in pg_log_tail.split("\n"): + pg_log_json = json.loads(pg_log_item) + if MESSAGE_FIELD in pg_log_json: + del pg_log_json[MESSAGE_FIELD] + if redacted_log_tail_json == pg_log_json: + found_in_pg_log = True + break + # Note: lag is possible because tail call is not synced w/ lag check and there's no simple way to + assert found_in_pg_log, "Last log seen in redactor is not a recent log from PG, through lag bug or own error msg" + + # Create an endpoint with random IDs. + test_metastore_id = uuid.uuid4() + test_endpoint_id = uuid.uuid4() + test_workspace_id = "987654321" + test_workspace_url = "https://test-workspace-url/" + compute_namespace = "" + compute_name = "" + with env.hcc_create_endpoint( + test_metastore_id, test_endpoint_id, test_workspace_id, test_workspace_url + ) as endpoint: + check_superuser_and_basic_data_operation(endpoint) + check_databricks_system_tables(endpoint) + check_databricks_roles(endpoint) + check_guc_values(endpoint) + check_cert_auth_user(endpoint) + check_token_login(endpoint, str(test_endpoint_id), "postgres", env.tempDir.name) + + write_table_data(endpoint, "my_table", ["a", "b", "c"]) + assert read_table_data(endpoint, "my_table") == ["a", "b", "c"] + + compute_name = endpoint.name + compute_namespace = endpoint.namespace + + hadron_compute_pods = env.kubectl_pods(namespace=compute_namespace) + hadron_compute_pod_name = hadron_compute_pods[0] + verify_compute_pod_metadata( + hadron_compute_pod_name, compute_namespace, test_endpoint_id + ) + + # Check in compute log that we have initialized the Databricks extension. + logs = env.kubectl_logs(namespace=compute_namespace, pod_name=hadron_compute_pod_name) + assert "Databricks extension initialized" in logs, "Endpoint creation not logged" + + # Check that metrics are exported + r = requests.get(endpoint.metrics_url) + assert r.status_code == 200 + assert "pg_static" in r.text + # Check for this particular metric to make sure prometheus exporter has the permission to + # execute wal-related functions such as `pg_current_wal_lsn`, `pg_wal_lsn_diff`, etc. 
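            # (Editorial aside, not part of this patch.) The pg_metrics_* rows asserted below are
            # scraped from the view this patch extends. The new counter can also be read directly
            # in SQL (assuming the view keeps its usual metric/value columns in the neon schema):
            #     SELECT value FROM neon.neon_perf_counters
            #      WHERE metric = 'max_active_safekeeper_commit_lag';
            # It reports max(commit_lsn) - min(commit_lsn) across the currently active safekeepers.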
+ assert "pg_replication_slots_pg_wal_lsn_diff" in r.text + # Check for these metrics from function or view in neon schema in databricks_system database + # to ensure extra grants were successful + assert "pg_backpressure_throttling_time" in r.text + assert "pg_lfc_hits" in r.text + assert "pg_lfc_working_set_size" in r.text + assert "pg_cluster_size_bytes" in r.text + assert "pg_snapshot_files_count" in r.text + assert re.search(r"pg_writable_bool{.*} 1", r.text) + assert "pg_database_size_bytes" not in r.text + # Check for this label key to ensure that the metrics are being labeled correctly for PuPr model + assert "pg_instance_id=" in r.text + assert "pg_metrics_sql_index_corruption_count" in r.text + assert "pg_metrics_num_active_safekeepers" in r.text + assert "pg_metrics_num_configured_safekeepers" in r.text + assert "pg_metrics_max_active_safekeeper_commit_lag" in r.text + + check_pg_log_redaction(hadron_compute_pod_name, endpoint.container, compute_namespace) + + # Smoke test tenant migration + curr_ps_id = check_current_ps_id(endpoint) + new_ps_id = ps_id_0 if curr_ps_id == ps_id_1 else ps_id_1 + env.hcc_migrate_endpoint(test_endpoint_id, new_ps_id) + assert check_current_ps_id(endpoint) == new_ps_id + # Check that data operation still works after migration + check_superuser_and_basic_data_operation(endpoint) + # Check that the data we wrote before migration stays untouched + assert read_table_data(endpoint, "my_table") == ["a", "b", "c"] + + # Restart the compute endpoint to clear any local caches to be extra sure that tenant migration indeed + # does not lose data + with env.restart_endpoint(compute_name, compute_namespace) as endpoint: + # Check that data persists after the compute node restarts. + assert read_table_data(endpoint, "my_table") == ["a", "b", "c"] + # Check that data operations can resume after the restart + check_superuser_and_basic_data_operation(endpoint) + + # PG compute reconciliation verification test. We intetionally run this test after the tenant migration + # restart test so that the tenant migration test can observe the first, "untainted" restart. + # + # Update the cluster-config map's default PgParams and ensure that the compute instance is reconciled properly. + # In this case, test updating the compute_http_port as a trivial example. + current_configmap = env.get_configmap_json("cluster-config", "hadron") + config_map_key = "config.json" + config_json = json.loads(current_configmap["data"][config_map_key]) + + if "pg_params" not in config_json: + config_json["pg_params"] = {} + + test_http_port = 3456 + + config_json["pg_params"]["compute_http_port"] = test_http_port + + patch_json = {"data": {"config.json": json.dumps(config_json)}} + + env.kubectl_patch( + resource="configmap", + name="cluster-config", + namespace="hadron", + json_patch=json.dumps(patch_json), + ) + + # Ensure that the deployment is updated by the HCC accordingly within 2 mintues. + # Note that waiting this long makes sense since files mounted via a config maps + # are not updated immediately, and instead updated every kubelet sync period (typically 1m) + # on top of the kubelets configmap cahce TTL (which is also typically 1m). 
+ timeout = 120 + start_time = time.time() + while True: + deployment = env.get_deployment_json(compute_name, compute_namespace) + port = deployment["spec"]["template"]["spec"]["containers"][0]["ports"][1] + if port["containerPort"] == test_http_port: + print("Compute succesfully updated") + break + else: + print( + f"Current PG HTTP port spec: {port}, Expected container port: {test_http_port}" + ) + + if time.time() - start_time >= timeout: + raise Exception(f"Compute deployment did not update within {timeout} seconds") + + time.sleep(5) + + # Wait for the updated endpoint to become ready again after a rollout. + env.wait_for_app_ready(compute_name, namespace="hadron-compute", timeout=60) + + # Verify that the updated compute pod has the correct workspace annotations/labels + # as persisted in the metadata database. + hadron_compute_pods = env.kubectl_pods(namespace=compute_namespace) + hadron_compute_pod_name = hadron_compute_pods[0] + verify_compute_pod_metadata( + hadron_compute_pod_name, compute_namespace, test_endpoint_id + ) + + # Delete the endpoint. + env.hcc_delete_endpoint(test_endpoint_id) + # Ensure that k8s resources of the compute endpoint are deleted. + env.wait_for_compute_resource_deletion(compute_name, compute_namespace) + + +if __name__ == "__main__": + check_prerequisites() + test_create_endpoint_and_connect() diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index 7b09ee8080..74662f8b12 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -479,6 +479,7 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState { wal_rate_limiter: empty_wal_rate_limiter, num_safekeepers: 0, safekeeper_status: [0; 32], + safekeeper_commit_lsn: [0; 32], } } diff --git a/pgxn/neon/neon_perf_counters.c b/pgxn/neon/neon_perf_counters.c index ed624ea6d6..fc437e5940 100644 --- a/pgxn/neon/neon_perf_counters.c +++ b/pgxn/neon/neon_perf_counters.c @@ -396,6 +396,7 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) WalproposerShmemState *wp_shmem; uint32 num_safekeepers; uint32 num_active_safekeepers; + XLogRecPtr max_active_safekeeper_commit_lag; /* END_HADRON */ /* We put all the tuples into a tuplestore in one go. */ @@ -451,35 +452,53 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) // Note that we are taking a mutex when reading from walproposer shared memory so that the total safekeeper count is // consistent with the active wal acceptors count. Assuming that we don't query this view too often the mutex should // not be a huge deal. + XLogRecPtr min_commit_lsn = InvalidXLogRecPtr; + XLogRecPtr max_commit_lsn = InvalidXLogRecPtr; + XLogRecPtr lsn; + wp_shmem = GetWalpropShmemState(); SpinLockAcquire(&wp_shmem->mutex); + num_safekeepers = wp_shmem->num_safekeepers; num_active_safekeepers = 0; for (int i = 0; i < num_safekeepers; i++) { if (wp_shmem->safekeeper_status[i] == 1) { num_active_safekeepers++; + // Only track the commit LSN lag among active safekeepers. + // If there are inactive safekeepers we will raise another alert so this lag value + // is less critical. + lsn = wp_shmem->safekeeper_commit_lsn[i]; + if (XLogRecPtrIsInvalid(min_commit_lsn) || lsn < min_commit_lsn) { + min_commit_lsn = lsn; + } + if (XLogRecPtrIsInvalid(max_commit_lsn) || lsn > max_commit_lsn) { + max_commit_lsn = lsn; + } } } + // Calculate max commit LSN lag across active safekeepers + max_active_safekeeper_commit_lag = (XLogRecPtrIsInvalid(min_commit_lsn) ? 
0 : max_commit_lsn - min_commit_lsn); + SpinLockRelease(&wp_shmem->mutex); } { - metric_t databricks_metrics[] = { - {"sql_index_corruption_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->index_corruption_count)}, - {"sql_data_corruption_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->data_corruption_count)}, - {"sql_internal_error_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->internal_error_count)}, - {"ps_corruption_detected", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->ps_corruption_detected)}, - {"num_active_safekeepers", false, 0.0, (double) num_active_safekeepers}, - {"num_configured_safekeepers", false, 0.0, (double) num_safekeepers}, - {NULL, false, 0, 0}, - }; - for (int i = 0; databricks_metrics[i].name != NULL; i++) - { - metric_to_datums(&databricks_metrics[i], &values[0], &nulls[0]); - tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); - } + metric_t databricks_metrics[] = { + {"sql_index_corruption_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->index_corruption_count)}, + {"sql_data_corruption_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->data_corruption_count)}, + {"sql_internal_error_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->internal_error_count)}, + {"ps_corruption_detected", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->ps_corruption_detected)}, + {"num_active_safekeepers", false, 0.0, (double) num_active_safekeepers}, + {"num_configured_safekeepers", false, 0.0, (double) num_safekeepers}, + {"max_active_safekeeper_commit_lag", false, 0.0, (double) max_active_safekeeper_commit_lag}, + {NULL, false, 0, 0}, + }; + for (int i = 0; databricks_metrics[i].name != NULL; i++) + { + metric_to_datums(&databricks_metrics[i], &values[0], &nulls[0]); + tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); } - /* END_HADRON */ } + /* END_HADRON */ pfree(metrics); diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index ac42c2925d..ecc8882bb2 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -436,6 +436,8 @@ typedef struct WalproposerShmemState uint32 num_safekeepers; /* Per-safekeeper status flags: 0=inactive, 1=active */ uint8 safekeeper_status[MAX_SAFEKEEPERS]; + /* Per-safekeeper commit LSN for metrics */ + XLogRecPtr safekeeper_commit_lsn[MAX_SAFEKEEPERS]; /* END_HADRON */ } WalproposerShmemState; diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 0972b2df17..6c1f56d919 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -2106,6 +2106,16 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk) if (wp->config->syncSafekeepers) return; + /* BEGIN_HADRON */ + // Record safekeeper commit LSN in shared memory for lag monitoring + { + WalproposerShmemState *shmem = wp->api.get_shmem_state(wp); + Assert(sk->index < MAX_SAFEKEEPERS); + SpinLockAcquire(&shmem->mutex); + shmem->safekeeper_commit_lsn[sk->index] = sk->appendResponse.commitLsn; + SpinLockRelease(&shmem->mutex); + } + /* END_HADRON */ /* handle fresh ps_feedback */ if (sk->appendResponse.ps_feedback.present) @@ -2243,6 +2253,7 @@ walprop_pg_reset_safekeeper_statuses_for_metrics(WalProposer *wp, uint32 num_saf SpinLockAcquire(&shmem->mutex); shmem->num_safekeepers = num_safekeepers; memset(shmem->safekeeper_status, 0, sizeof(shmem->safekeeper_status)); + 
memset(shmem->safekeeper_commit_lsn, 0, sizeof(shmem->safekeeper_commit_lsn)); SpinLockRelease(&shmem->mutex); } diff --git a/safekeeper/src/receive_wal.rs b/safekeeper/src/receive_wal.rs index eb8eee6ab8..349b228f1c 100644 --- a/safekeeper/src/receive_wal.rs +++ b/safekeeper/src/receive_wal.rs @@ -24,6 +24,7 @@ use tracing::*; use utils::id::TenantTimelineId; use utils::lsn::Lsn; use utils::pageserver_feedback::PageserverFeedback; +use utils::pausable_failpoint; use crate::GlobalTimelines; use crate::handler::SafekeeperPostgresHandler; @@ -598,6 +599,8 @@ impl WalAcceptor { // Note that a flush can still happen on segment bounds, which will result // in an AppendResponse. if let ProposerAcceptorMessage::AppendRequest(append_request) = msg { + // allow tests to pause AppendRequest processing to simulate lag + pausable_failpoint!("sk-acceptor-pausable"); msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); dirty = true; } diff --git a/scripts/neon_grep.txt b/scripts/neon_grep.txt index 8b323ab920..a41cc714c5 100644 --- a/scripts/neon_grep.txt +++ b/scripts/neon_grep.txt @@ -547,6 +547,7 @@ pgxn/neon/neon_perf_counters.c:neon_get_perf_counters(PG_FUNCTION_ARGS) pgxn/neon/neon_perf_counters.c: neon_per_backend_counters totals = {0}; pgxn/neon/neon_perf_counters.c: uint32 num_safekeepers; pgxn/neon/neon_perf_counters.c: uint32 num_active_safekeepers; +pgxn/neon/neon_perf_counters.c: XLogRecPtr max_active_safekeeper_commit_lag; pgxn/neon/neon_perf_counters.c: for (int procno = 0; procno < NUM_NEON_PERF_COUNTER_SLOTS; procno++) pgxn/neon/neon_perf_counters.c: neon_per_backend_counters *counters = &neon_per_backend_counters_shared[procno]; pgxn/neon/neon_perf_counters.c: metrics = neon_perf_counters_to_metrics(&totals); @@ -559,8 +560,14 @@ pgxn/neon/neon_perf_counters.c: num_active_safekeepers = 0; pgxn/neon/neon_perf_counters.c: for (int i = 0; i < num_safekeepers; i++) { pgxn/neon/neon_perf_counters.c: if (wp_shmem->safekeeper_status[i] == 1) { pgxn/neon/neon_perf_counters.c: num_active_safekeepers++; +pgxn/neon/neon_perf_counters.c: // Only track the commit LSN lag among active safekeepers. +pgxn/neon/neon_perf_counters.c: // If there are inactive safekeepers we will raise another alert so this lag value +pgxn/neon/neon_perf_counters.c: lsn = wp_shmem->safekeeper_commit_lsn[i]; +pgxn/neon/neon_perf_counters.c: // Calculate max commit LSN lag across active safekeepers +pgxn/neon/neon_perf_counters.c: max_active_safekeeper_commit_lag = (XLogRecPtrIsInvalid(min_commit_lsn) ? 
0 : max_commit_lsn - min_commit_lsn); pgxn/neon/neon_perf_counters.c: {"num_active_safekeepers", false, 0.0, (double) num_active_safekeepers}, pgxn/neon/neon_perf_counters.c: {"num_configured_safekeepers", false, 0.0, (double) num_safekeepers}, +pgxn/neon/neon_perf_counters.c: {"max_active_safekeeper_commit_lag", false, 0.0, (double) max_active_safekeeper_commit_lag}, pgxn/neon/neon_perf_counters.h: * neon_perf_counters.h pgxn/neon/neon_perf_counters.h: * Performance counters for neon storage requests pgxn/neon/neon_perf_counters.h:#ifndef NEON_PERF_COUNTERS_H @@ -1487,6 +1494,8 @@ pgxn/neon/walproposer.h: /* Number of safekeepers in the config */ pgxn/neon/walproposer.h: uint32 num_safekeepers; pgxn/neon/walproposer.h: /* Per-safekeeper status flags: 0=inactive, 1=active */ pgxn/neon/walproposer.h: uint8 safekeeper_status[MAX_SAFEKEEPERS]; +pgxn/neon/walproposer.h: /* Per-safekeeper commit LSN for metrics */ +pgxn/neon/walproposer.h: XLogRecPtr safekeeper_commit_lsn[MAX_SAFEKEEPERS]; pgxn/neon/walproposer.h: * Report safekeeper state to proposer pgxn/neon/walproposer.h: * Current term of the safekeeper; if it is higher than proposer's, the pgxn/neon/walproposer.h: /* Safekeeper reports back his awareness about which WAL is committed, as */ @@ -1718,6 +1727,9 @@ pgxn/neon/walproposer_pg.c: * Based on commitLsn and safekeeper responses includ pgxn/neon/walproposer_pg.c: * None of that is functional in sync-safekeepers. pgxn/neon/walproposer_pg.c:walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk) pgxn/neon/walproposer_pg.c: if (wp->config->syncSafekeepers) +pgxn/neon/walproposer_pg.c: // Record safekeeper commit LSN in shared memory for lag monitoring +pgxn/neon/walproposer_pg.c: Assert(sk->index < MAX_SAFEKEEPERS); +pgxn/neon/walproposer_pg.c: shmem->safekeeper_commit_lsn[sk->index] = sk->appendResponse.commitLsn; pgxn/neon/walproposer_pg.c: SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize); pgxn/neon/walproposer_pg.c: * hardened and will be fetched from one of safekeepers by pgxn/neon/walproposer_pg.c: * neon_walreader if needed. 
@@ -1729,6 +1741,7 @@ pgxn/neon/walproposer_pg.c:uint64 GetNeonCurrentClusterSize(void); pgxn/neon/walproposer_pg.c:walprop_pg_reset_safekeeper_statuses_for_metrics(WalProposer *wp, uint32 num_safekeepers) pgxn/neon/walproposer_pg.c: shmem->num_safekeepers = num_safekeepers; pgxn/neon/walproposer_pg.c: memset(shmem->safekeeper_status, 0, sizeof(shmem->safekeeper_status)); +pgxn/neon/walproposer_pg.c: memset(shmem->safekeeper_commit_lsn, 0, sizeof(shmem->safekeeper_commit_lsn)); pgxn/neon/walproposer_pg.c:walprop_pg_update_safekeeper_status_for_metrics(WalProposer *wp, uint32 sk_index, uint8 status) pgxn/neon/walproposer_pg.c: Assert(sk_index < MAX_SAFEKEEPERS); pgxn/neon/walproposer_pg.c: shmem->safekeeper_status[sk_index] = status; diff --git a/storage_controller/src/hadron_k8s.rs b/storage_controller/src/hadron_k8s.rs new file mode 100644 index 0000000000..ca91293e37 --- /dev/null +++ b/storage_controller/src/hadron_k8s.rs @@ -0,0 +1,4906 @@ +use crate::config_manager; +use crate::hadron_drain_and_fill::DrainAndFillManager; +use crate::hadron_pageserver_watcher::create_pageserver_pod_watcher; +use crate::hadron_sk_maintenance::SKMaintenanceManager; +use crate::metrics::{self, ConfigWatcherCompleteLabelGroup, ReconcileOutcome}; +use crate::node::transform_pool_id; +use crate::persistence::Persistence; +use crate::service; +use anyhow::{anyhow, Context}; +use async_trait::async_trait; +use core::fmt; +use itertools::Either; +use storage_scrubber::NodeKind; +use tokio::time::sleep; +use utils::env::is_chaos_testing; +use utils::env::is_dev_or_staging; + +use reqwest::Url; +use sha2::{Digest, Sha256}; +use std::collections::BTreeMap; +use std::fmt::Debug; +use std::fs::File; +use std::io::BufReader; +use std::sync::{Arc, RwLock}; +use std::time::{Duration, Instant}; +use tokio::io::AsyncReadExt; +use tokio::runtime::Handle; +use tracing::Instrument; + +use uuid::Uuid; + +use compute_api::spec::{DatabricksSettings, PgComputeTlsSettings}; +use k8s_openapi::api::apps::v1::{ + DaemonSet, DaemonSetSpec, DaemonSetUpdateStrategy, Deployment, DeploymentSpec, ReplicaSet, + ReplicaSetSpec, RollingUpdateDaemonSet, +}; +use k8s_openapi::api::core::v1::{ + Affinity, ConfigMapVolumeSource, Container, ContainerPort, EmptyDirVolumeSource, EnvVar, + EnvVarSource, HTTPGetAction, HostPathVolumeSource, LocalObjectReference, NodeAffinity, + NodeSelector, NodeSelectorRequirement, NodeSelectorTerm, ObjectFieldSelector, + PersistentVolumeClaim, PersistentVolumeClaimSpec, Pod, PodAffinityTerm, PodAntiAffinity, + PodReadinessGate, PodSecurityContext, PodSpec, PodTemplateSpec, Probe, ResourceRequirements, + SecretKeySelector, SecretVolumeSource, SecurityContext, Service, ServicePort, ServiceSpec, + Toleration, Volume, VolumeMount, VolumeResourceRequirements, +}; +use k8s_openapi::apimachinery::pkg::api::resource::Quantity; +use k8s_openapi::apimachinery::pkg::apis::meta; +use k8s_openapi::apimachinery::pkg::util::intstr::IntOrString; +use kube::api::{DeleteParams, ListParams, PostParams}; +use kube::{Api, Client}; +use mockall::automock; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use utils::ip_address::HADRON_NODE_IP_ADDRESS; + +use crate::hadron_token::HadronTokenGeneratorImpl; +use hcc_api::models::{EndpointConfig, EndpointTShirtSize, PostgresConnectionInfo}; +use openkruise::{ + StatefulSet, StatefulSetPersistentVolumeClaimRetentionPolicy, StatefulSetSpec, + StatefulSetUpdateStrategy, StatefulSetUpdateStrategyRollingUpdate, +}; + +#[derive(PartialEq, Clone, Debug, Hash, Eq, Serialize, 
Deserialize)] +pub enum CloudProvider { + AWS, + Azure, +} + +/// Represents the model to use when deploying/managing "compute". +/// - PrivatePreview: In the Private Preview mode, each "compute" is modeled as a Kubernetes Deployment object, +/// a ClusterIP admin Service object, and a LoadBalancer Service object used for direct ingress. +/// - PublicPreview: In the Public Preview mode, each "compute" is modeled as a Kubernetes ReplicaSet object +/// and a ClusterIP admin Service object. +#[derive(Serialize, Deserialize, Debug)] +pub enum ComputeModel { + PrivatePreview, + PublicPreview, +} + +/// Enum representing the Kubernetes Service type. +pub enum K8sServiceType { + ClusterIP, + LoadBalancer, +} + +// Struct representing the various parameters we can set when defining a readiness probe for a container. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ReadinessProbeParams { + /// The endpoint on the server that we want to use to perform the readiness check. If pass + /// the path of /healthz then Kubernetes will send a HTTP GET to /healthz and any code in the + /// range [200, 400) is considered a success and any other code is considered a failure + pub endpoint_path: String, + /// The minimum number of consecutive failures for the probe to be considered failed after having + /// succeeded + pub failure_threshold: i32, + /// The number of seconds after the container has started that the probe is run for the first time + pub initial_delay_seconds: i32, + /// How often, in seconds, does Kubernetes perform a probe + pub period_seconds: i32, + /// Minimum consecutive success for the probe to be considered success after a failure + pub success_threshold: i32, + /// Number of seconds after which the probe times out + pub timeout_seconds: i32, +} + +impl Default for ReadinessProbeParams { + fn default() -> Self { + Self { + endpoint_path: "/status".to_string(), + failure_threshold: 2, + initial_delay_seconds: 10, + period_seconds: 2, + success_threshold: 2, + timeout_seconds: 1, + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PgCompute { + // NB: Consider DNS name length limits when modifying the compute_name format. + // Most cloud DNS providers will limit records to ~250 chars in length, which, + // depending on the number of additional DNS labels, could impose restrictions + // on compute name length specifically. + pub name: String, + pub compute_id: String, + pub control_plane_token: String, + pub workspace_id: Option, + pub workspace_url: Option, + pub image_override: Option, + pub exporter_image_override: Option, + pub node_selector_override: Option>, + pub resources: ResourceRequirements, + pub tshirt_size: EndpointTShirtSize, + pub model: ComputeModel, + pub readiness_probe: Option, + // Used in the PublicPreview model as the endpoint id for PG authentication + pub instance_id: Option, +} + +pub const INSTANCE_ID_LABEL_KEY: &str = "instanceId"; +pub const COMPUTE_SECONDARY_LABEL_KEY: &str = "isSecondary"; +pub const COMPUTE_ID_LABEL_KEY: &str = "computeId"; +pub const BRICKSTORE_POOL_TYPES_LABEL_KEY: &str = "brickstore-pool-types"; +// The default instance ID to use when the instance ID is not provided in PrPr model. 
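// (Editorial aside, not part of this patch.) A minimal sketch of how the ComputeModel variants
// documented above map onto Kubernetes objects; `objects_for` is a hypothetical helper, not an
// API of this module:
//     fn objects_for(model: &ComputeModel) -> &'static [&'static str] {
//         match model {
//             ComputeModel::PrivatePreview => &["Deployment", "Service (ClusterIP admin)", "Service (LoadBalancer ingress)"],
//             ComputeModel::PublicPreview => &["ReplicaSet", "Service (ClusterIP admin)"],
//         }
//     }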
+const DEFAULT_INSTANCE_ID: &str = "00000000-0000-0000-0000-000000000000"; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum BrcDbletNodeGroup { + Dblet2C, + Dblet4C, + Dblet8C, + Dblet16C, +} + +impl fmt::Display for BrcDbletNodeGroup { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BrcDbletNodeGroup::Dblet2C => write!(f, "dbletbrc2c"), + BrcDbletNodeGroup::Dblet4C => write!(f, "dbletbrc4c"), + BrcDbletNodeGroup::Dblet8C => write!(f, "dbletbrc8c"), + BrcDbletNodeGroup::Dblet16C => write!(f, "dbletbrc16c"), + } + } +} + +fn select_node_group_by_tshirt_size(tshirt_size: &EndpointTShirtSize) -> BrcDbletNodeGroup { + match tshirt_size { + EndpointTShirtSize::XSmall => BrcDbletNodeGroup::Dblet2C, + EndpointTShirtSize::Small => BrcDbletNodeGroup::Dblet4C, + EndpointTShirtSize::Medium => BrcDbletNodeGroup::Dblet8C, + EndpointTShirtSize::Large => BrcDbletNodeGroup::Dblet16C, + EndpointTShirtSize::Test => BrcDbletNodeGroup::Dblet4C, + } +} + +impl PgCompute { + fn choose_dblet_node_group(&self) -> String { + // If the user has provided a node selector override and it contains the pool type label key, + // use that value. Otherwise, use the tshirt size to determine the node group. + self.node_selector_override + .as_ref() + .and_then(|node_selector| node_selector.get(BRICKSTORE_POOL_TYPES_LABEL_KEY)) + .cloned() + .unwrap_or_else(|| select_node_group_by_tshirt_size(&self.tshirt_size).to_string()) + } + + fn dblet_node_selector(node_group: String) -> BTreeMap { + vec![(BRICKSTORE_POOL_TYPES_LABEL_KEY.to_string(), node_group)] + .into_iter() + .collect() + } + + fn k8s_label_compute_id(&self) -> String { + // Replace slashes with dashes as k8s does not allow slashes in labels. + self.compute_id.clone().replace('/', "-") + } + + fn dblet_tolerations(&self, node_group: String) -> Vec { + vec![ + Toleration { + key: Some("databricks.com/node-type".to_string()), + operator: Some("Equal".to_string()), + value: Some(node_group), + effect: Some("NoSchedule".to_string()), + ..Default::default() + }, + Toleration { + key: Some("dblet.dev/appid".to_string()), + operator: Some("Equal".to_string()), + value: Some(self.k8s_label_compute_id()), + effect: Some("NoSchedule".to_string()), + ..Default::default() + }, + ] + } + + // Return the node selector to be used for this compute deployment. User overrides take precedence, followed by + // resource-based selection. + pub fn get_node_selector(&self, dblet_node_group: String) -> BTreeMap { + self.node_selector_override + .clone() + .unwrap_or(Self::dblet_node_selector(dblet_node_group)) + } + + pub fn get_tolerations(&self, dblet_node_group: String) -> Vec { + self.dblet_tolerations(dblet_node_group) + } +} + +/// Represents a HadronCluster object in Kubernetes. +/// A JSON object representing this struct is stored in the cluster-config ConfigMap. +/// The ConfigMap is periodically polled by the storage controller to update the cluster. +#[derive(Serialize, Deserialize)] +pub struct HadronCluster { + /// The metadata for the resource, like kind and API version. + pub type_meta: Option, + /// The object metadata, like name and namespace. + pub object_meta: Option, + /// The specification of the HadronCluster. + pub hadron_cluster_spec: Option, +} + +/// Represents the specification of a HadronCluster object in Kubernetes. +#[derive(Serialize, Deserialize)] +pub struct HadronClusterSpec { + /// The name of the service account to access object storage. 
+ pub service_account_name: Option, + /// The configuration for object storage such as bucket name and region. + pub object_storage_config: Option, + /// The specification for the storage broker. + pub storage_broker_spec: Option, + /// The specification for the safe keepers. It is a vector in case there are different version of safe keepers to be deployed. + pub safe_keeper_specs: Option>, + /// The specification for the page servers. It is a vector in case there are different version of page servers to be deployed. + pub page_server_specs: Option>, +} + +#[derive(Serialize, Deserialize)] +pub struct ObjectStorageConfigTestParameters { + pub endpoint: Option, + pub access_key_id: Option, + pub secret_access_key: Option, +} + +/// Represents the configuration for object storage such as bucket name and region. +/// TODO(steve.greene): Refactor this config to support at most one object storage provider at a time. +/// This would be considered a breaking change and would require careful coordination to migrate +/// existing config maps to the new config format. +#[derive(Serialize, Deserialize, Default)] +pub struct HadronObjectStorageConfig { + /// AWS S3 config options. + /// The name of the AWS S3 bucket in the object storage. + bucket_name: Option, + /// The region of the AWS S3 bucket in the object storage. + bucket_region: Option, + + /// Azure storage account options. + /// The (full) Azure storage account resource ID. + storage_account_resource_id: Option, + /// The tenant ID for Azure object storage. + azure_tenant_id: Option, + /// The Azure storage account container name. + storage_container_name: Option, + /// The Azure storage container region. + storage_container_region: Option, + + /// [Test-only] Parameters used in testing to point to a mock object store. + test_params: Option, +} + +/// Helper functions for deducing the object storage provider. +/// TODO(steve.greene): If we refactor HadronObjectStorageConfig to support at most one object storage, +/// we can most likely remove these functions. +impl HadronObjectStorageConfig { + fn is_aws(&self) -> bool { + self.bucket_name.is_some() && self.bucket_region.is_some() + } + + pub fn is_azure(&self) -> bool { + self.storage_account_resource_id.is_some() + && self.azure_tenant_id.is_some() + && self.storage_container_name.is_some() + && self.storage_container_region.is_some() + } +} + +/// Represents the specification for the storage broker. +#[derive(Serialize, Deserialize)] +pub struct HadronStorageBrokerSpec { + /// The hadron image. + pub image: Option, + /// The image pull policy. + pub image_pull_policy: Option, + /// The image pull secrets. + pub image_pull_secrets: Option>, + /// The node selector requirements. + pub node_selector: Option, + /// The resources for the storage broker. + pub resources: Option, +} + +/// Represents the specification for the safe keepers. +#[derive(Serialize, Deserialize)] +pub struct HadronSafeKeeperSpec { + /// The hadron image. + pub image: Option, + /// The image pull policy. + pub image_pull_policy: Option, + /// The image pull secrets. + pub image_pull_secrets: Option>, + /// The pool ID distinguishes between different StatefulSets if there are different version of safe keepers in one cluster. + pub pool_id: Option, + /// The number of replicas. + pub replicas: Option, + /// The node selector requirements. + pub node_selector: Option, + /// Suffix of the availability zone for this pool of safekeepers. Note that this is just the suffix, not the full availability zone name. 
+ /// For example, if the full availability zone name is "us-west-2a", the suffix is "a". Region is assumed to be the same as where HCC runs. + pub availability_zone_suffix: Option, + /// The storage class name. + pub storage_class_name: Option, + /// The resources for the safe keeper. + pub resources: Option, + /// Whether to use low downtime maintenance to upgrade Safekeepers. + pub enable_low_downtime_maintenance: Option, + /// Whether to LDTM checks SK status. + pub enable_ldtm_sk_status_check: Option, +} + +/// Represents the specification for the page servers. +#[derive(Serialize, Deserialize)] +pub struct HadronPageServerSpec { + /// The hadron image. + pub image: Option, + /// The image pull policy. + pub image_pull_policy: Option, + /// The image pull secrets. + pub image_pull_secrets: Option>, + /// The pool ID distinguishes between different StatefulSets if there are different version of page servers in one cluster. + pub pool_id: Option, + /// The number of replicas. + pub replicas: Option, + /// The node selector requirements. + pub node_selector: Option, + /// Suffix of the availability zone for this pool of pageservers. Note that this is just the suffix, not the full availability zone name. + /// For example, if the full availability zone name is "us-west-2a", the suffix is "a". Region is assumed to be the same as where HCC runs. + pub availability_zone_suffix: Option, + /// The storage class name. + pub storage_class_name: Option, + /// The resources for the page server. + /// The storage must be specified with Gi as the suffix and only the limits value is read. + pub resources: Option, + /// Custom pageserver.toml configuration to use for this pool of pageservers. Intended to be used in dev/testing only. + /// Any content specified here is inserted verbatim to the pageserver launch script, so do not use any user-generated + /// content for this field. + pub custom_pageserver_toml: Option, + /// Parallelism to pre-pull pageserver images before starting rolling updates. + pub image_prepull_parallelism: Option, + /// Timeout for pre-pulling pageserver images before starting rolling updates. + pub image_prepull_timeout_seconds: Option, + /// Whether to use drain_and_fill to upgrade PageServers. + pub use_drain_and_fill: Option, +} + +#[automock] +#[async_trait] +pub trait K8sManager: Send + Sync { + fn get_client(&self) -> Arc; + + fn get_current_pg_params(&self) -> Result; + + fn set_pg_params(&self, params: PgParams) -> Result<(), anyhow::Error>; + + async fn deploy_compute(&self, pg_compute: PgCompute) -> kube::Result<()>; + + // Delete compute resources for a given compute name and model. + // Returns Ok(true) if all the required resources are completely cleaned up i.e. + // they are not found in k8s anymore. + // Return Ok(false) if all the required resources are called for deletion but + // the resources are still found in k8s. + // The callers can retry the deletion if Ok(false) is returned with some backoff. + async fn delete_compute( + &self, + pg_compute_name: &str, + model: ComputeModel, + ) -> kube::Result; + + async fn get_http_urls_for_compute_services(&self, service_names: Vec) -> Vec; + + async fn get_databricks_compute_settings( + &self, + workspace_url: Option, + ) -> DatabricksSettings; + + // Retrieve the k8s Service object handling the primary (read/write traffic) ingress to an instance. 
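    // (Editorial aside, not part of this patch.) The delete_compute() contract documented above is
    // poll-until-true; a minimal caller sketch (the fixed 5s backoff is an arbitrary illustration):
    //     async fn delete_until_gone(mgr: &dyn K8sManager, name: &str) -> kube::Result<()> {
    //         while !mgr.delete_compute(name, ComputeModel::PublicPreview).await? {
    //             tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    //         }
    //         Ok(())
    //     }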
+ async fn get_instance_primary_ingress_service( + &self, + instance_id: Uuid, + ) -> kube::Result; + + // Idempotently create or patch the k8s Service object handling primary (read/write traffic) ingress to an instance, + // so that it routes traffic to the compute Pod with the specified ID. + async fn create_or_patch_instance_primary_ingress_service( + &self, + instance_id: Uuid, + compute_id: Uuid, + service_type: K8sServiceType, + ) -> kube::Result; + + async fn create_or_patch_readable_secondary_ingress_service( + &self, + instance_id: Uuid, + ) -> kube::Result; + + // Idempotently delete the k8s Service object handling primary (read/write traffic) ingress to an instance. + // Returns true if the service was found and deleted, false if the service was not found. + async fn delete_instance_primary_ingress_service( + &self, + instance_id: Uuid, + ) -> kube::Result; +} + +/// Hadron K8sManager manages access to Kubernetes API. This object is not mutable and therefore +/// thread-safe. (Mutability and thread-safety is type-checked in Rust.) +pub struct K8sManagerImpl { + // The k8s client to use for all operations by the `K8sManager`. + pub client: Arc, + // The k8s namespace where this HCC is running. + namespace: String, + // The region this HCC runs in. + region: String, + // The reachable DNS name of the Hadron Cluster Coordinator (HCC) that is advertised to other nodes. + hcc_dns_name: String, + // The port of the HCC service that is advertised to storage nodes (PS/SK). + hcc_listening_port: u16, + // The port of the HCC service that is advertised to PG compute nodes. + hcc_compute_listening_port: u16, + // The defaults for PG compute nodes. + pg_params: RwLock, + // The configured cloud provider (when not running in tests). + cloud_provider: Option, +} + +/// Stores the Deployment and Service objects for the StorageBroker. +pub struct StorageBrokerObjs { + pub deployment: Deployment, + pub service: Service, +} + +/// Stores the StatefulSets and Service objects for SafeKeepers. +/// There may be multiple StatefulSets if there are different versions of SafeKeepers. +pub struct SafeKeeperObjs { + pub stateful_sets: Vec, + pub service: Service, +} + +/// Struct describing an image-prepull operation to be performed using an image puller DaemonSet. +const IMAGE_PREPULL_DEFAULT_TIMEOUT: Duration = Duration::from_secs(60); +pub struct ImagePullerDaemonsetInfo { + /// The DaemonSet manifest of the image puller. + pub daemonset: DaemonSet, + /// The timeout to wait for the prepull operation to complete before starting the rolling update. + /// If None, defaults to IMAGE_PREPULL_DEFAULT_TIMEOUT. + pub image_prepull_timeout: Option, +} + +/// Stores the StatefulSets and Service objects for PageServers. +/// There may be multiple StatefulSets if there are different versions of PageServers. +pub struct PageServerObjs { + pub image_puller_daemonsets: Vec, + pub stateful_sets: Vec, + pub service: Service, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum MountType { + ConfigMap, + Secret, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct K8sMount { + pub name: String, + pub mount_type: MountType, + pub mount_path: String, + pub files: Vec, +} + +pub struct K8sSecretVolumesAndMounts { + pub volumes: Vec, + pub volume_mounts: Vec, +} + +/// Stores the parameter values for PG compute nodes. +/// TODO: Find a better way to do unwraps with default values. 
+#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub struct PgParams { + /// The namespace for compute nodes. Required. + pub compute_namespace: String, + /// The image for compute nodes. Required. + pub compute_image: String, + /// The image for the prometheus exporter that runs in the same pod as the compute nodes. Required. + pub prometheus_exporter_image: String, + /// The compute node image pull secret name. + #[serde(default = "pg_params_default_compute_image_pull_secret")] + pub compute_image_pull_secret: Option, + /// The pg port for compute nodes. + #[serde(default = "pg_params_default_compute_pg_port")] + pub compute_pg_port: Option, + /// The http port for compute nodes. + #[serde(default = "pg_params_default_compute_http_port")] + pub compute_http_port: Option, + /// The kubernetes secret name, its mount path in the container and chmod mode. + #[serde(default = "pg_params_default_compute_mounts")] + pub compute_mounts: Option>, + #[serde(default = "pg_params_pg_compute_tls_settings")] + pub pg_compute_tls_settings: Option, + #[serde(default = "pg_params_default_databricks_pg_hba")] + pub databricks_pg_hba: Option, + #[serde(default = "pg_params_default_databricks_pg_ident")] + pub databricks_pg_ident: Option, +} + +impl Default for PgParams { + fn default() -> Self { + PgParams { + compute_namespace: String::new(), + compute_image: String::new(), + prometheus_exporter_image: String::new(), + compute_image_pull_secret: pg_params_default_compute_image_pull_secret(), + compute_pg_port: pg_params_default_compute_pg_port(), + compute_http_port: pg_params_default_compute_http_port(), + compute_mounts: pg_params_default_compute_mounts(), + pg_compute_tls_settings: pg_params_pg_compute_tls_settings(), + databricks_pg_hba: pg_params_default_databricks_pg_hba(), + databricks_pg_ident: pg_params_default_databricks_pg_ident(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PageServerBillingMetricsConfig { + pub metric_collection_endpoint: Option, + pub metric_collection_interval: Option, + pub synthetic_size_calculation_interval: Option, +} + +impl PageServerBillingMetricsConfig { + pub fn to_toml(&self) -> String { + let mut toml_str = String::new(); + + if let Some(ref endpoint) = self.metric_collection_endpoint { + toml_str.push_str(&format!("metric_collection_endpoint = \"{}\"\n", endpoint)); + } + if let Some(ref interval) = self.metric_collection_interval { + toml_str.push_str(&format!("metric_collection_interval = \"{}\"\n", interval)); + } + if let Some(ref synthetic_interval) = self.synthetic_size_calculation_interval { + toml_str.push_str(&format!( + "synthetic_size_calculation_interval = \"{}\"\n", + synthetic_interval + )); + } + + toml_str + } +} + +/// The information in /etc/config/config.json. +#[derive(Serialize, Deserialize)] +pub struct ConfigData { + pub hadron_cluster: Option, + pub pg_params: Option, + pub page_server_billing_metrics_config: Option, +} + +// Implement default for PgParams optional fields. 
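// (Editorial aside, not part of this patch.) Because every optional field above carries a
// #[serde(default = "...")] attribute, a cluster-config entry only needs the three required
// fields; everything else falls back to the functions below. Illustrative check (serde_json
// assumed to be available in this crate):
//     let p: PgParams = serde_json::from_str(
//         r#"{"compute_namespace": "hadron-compute",
//             "compute_image": "compute:dev",
//             "prometheus_exporter_image": "exporter:dev"}"#,
//     ).unwrap();
//     assert_eq!(p.compute_pg_port, Some(55432));
//     assert_eq!(p.compute_http_port, Some(55433)); // the value the KIND test later patches to 3456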
+fn pg_params_default_compute_image_pull_secret() -> Option { + Some("harbor-image-pull-secret".to_string()) +} + +fn pg_params_default_compute_pg_port() -> Option { + Some(55432) +} + +fn pg_params_default_compute_http_port() -> Option { + Some(55433) +} + +fn brickstore_internal_token_verification_key_mount_path() -> String { + "/databricks/secrets/brickstore-internal-token-public-keys".to_string() +} + +fn brickstore_internal_token_verification_key_secret_mount() -> K8sMount { + K8sMount { + name: "brickstore-internal-token-public-keys".to_string(), + mount_type: MountType::Secret, + mount_path: brickstore_internal_token_verification_key_mount_path(), + files: vec!["key1.pem".to_string(), "key2.pem".to_string()], + } +} + +pub fn azure_storage_accont_service_principal_mount_path() -> String { + "/databricks/secrets/azure-service-principal".to_string() +} + +fn azure_storage_account_service_principal_secret_mount() -> K8sMount { + K8sMount { + name: "brickstore-hadron-storage-account-service-principal-secret".to_string(), + mount_type: MountType::Secret, + mount_path: azure_storage_accont_service_principal_mount_path(), + files: vec!["client.pem".to_string()], + } +} + +fn pg_params_default_compute_mounts() -> Option> { + Some(vec![ + brickstore_internal_token_verification_key_secret_mount(), + K8sMount { + name: "brickstore-domain-certs".to_string(), + mount_type: MountType::Secret, + mount_path: "/databricks/secrets/brickstore-domain-certs".to_string(), + files: vec!["server.key".to_string(), "server.crt".to_string()], + }, + K8sMount { + name: "trusted-ca-certificates".to_string(), + mount_type: MountType::Secret, + mount_path: "/databricks/secrets/trusted-ca".to_string(), + files: vec!["data-plane-misc-root-ca-cert.pem".to_string()], + }, + K8sMount { + name: "pg-compute-config".to_string(), + mount_type: MountType::ConfigMap, + mount_path: "/databricks/pg_config".to_string(), + files: vec![ + "databricks_pg_hba.conf".to_string(), + "databricks_pg_ident.conf".to_string(), + ], + }, + ]) +} + +fn pg_params_pg_compute_tls_settings() -> Option { + Some(PgComputeTlsSettings { + cert_file: "/databricks/secrets/brickstore-domain-certs/server.crt".to_string(), + key_file: "/databricks/secrets/brickstore-domain-certs/server.key".to_string(), + ca_file: "/databricks/secrets/trusted-ca/data-plane-misc-root-ca-cert.pem".to_string(), + }) +} + +fn pg_params_default_databricks_pg_hba() -> Option { + Some("/databricks/pg_config/databricks_pg_hba.conf".to_string()) +} + +fn pg_params_default_databricks_pg_ident() -> Option { + Some("/databricks/pg_config/databricks_pg_ident.conf".to_string()) +} + +/// Gets the volume claim template for PS, SK. 
+fn get_volume_claim_template( + resources: Option, + storage_class_name: Option, +) -> anyhow::Result>> { + let volume_resource_requirements = VolumeResourceRequirements { + requests: resources + .as_ref() + .and_then(|res| res.limits.as_ref()) + .and_then(|limits| limits.get("storage")) + .map(|storage| { + let mut map = BTreeMap::new(); + map.insert("storage".to_string(), storage.clone()); + map + }), + ..Default::default() + }; + + Ok(Some(vec![PersistentVolumeClaim { + metadata: meta::v1::ObjectMeta { + name: Some("local-data".to_string()), + ..Default::default() + }, + spec: Some(PersistentVolumeClaimSpec { + access_modes: Some(vec!["ReadWriteOnce".to_string()]), + storage_class_name: storage_class_name.clone(), + resources: Some(volume_resource_requirements), + ..Default::default() + }), + ..Default::default() + }])) +} + +/// Gets the pod metadata for PS, SK, SB. +fn get_pod_metadata(app_name: String, prometheus_port: u16) -> Option { + Some(meta::v1::ObjectMeta { + labels: Some( + vec![("app".to_string(), app_name.clone())] + .into_iter() + .collect(), + ), + annotations: { + Some( + vec![ + ("prometheus.io/path".to_string(), "/metrics".to_string()), + ( + "prometheus.io/port".to_string(), + prometheus_port.to_string(), + ), + ("prometheus.io/scrape".to_string(), "true".to_string()), + ("enableLogDaemon".to_string(), "true".to_string()), + ( + "logDaemonDockerLoggingGroup".to_string(), + "docker-common-log-group".to_string(), + ), + ] + .into_iter() + .collect(), + ) + }, + ..Default::default() + }) +} + +/// Gets the security context for PS, SK. +fn get_pod_security_context() -> Option { + Some(PodSecurityContext { + run_as_user: Some(1000), + fs_group: Some(2000), + run_as_non_root: Some(true), + ..Default::default() + }) +} + +/// Gets the volume mounts for PS, SK. +fn get_local_data_volume_mounts() -> Vec { + vec![VolumeMount { + mount_path: "/data/.neon/".to_string(), + name: "local-data".to_string(), + ..Default::default() + }] +} + +/// Gets the container ports for PS, SK, SB. +fn get_container_ports(ports: Vec) -> Option> { + Some( + ports + .iter() + .map(|port| ContainerPort { + container_port: *port, + ..Default::default() + }) + .collect(), + ) +} + +/// Gets the environment variables for PS, SK. +pub fn get_env_vars( + object_storage_config: &HadronObjectStorageConfig, + additional_env_vars: Vec, +) -> Option> { + let mut env_vars = vec![EnvVar { + name: "BROKER_ENDPOINT".to_string(), + value: Some("http://storage-broker:50051".to_string()), + ..Default::default() + }]; + + if object_storage_config.is_aws() { + env_vars.push(EnvVar { + name: "S3_BUCKET_URI".to_string(), + value: object_storage_config.bucket_name.clone(), + ..Default::default() + }); + env_vars.push(EnvVar { + name: "S3_REGION".to_string(), + value: object_storage_config.bucket_region.clone(), + ..Default::default() + }); + } + + if object_storage_config.is_azure() { + // The following azure env vars come from the object storage config directly. + env_vars.push(EnvVar { + // Need to fish out the account name for the full storage account resource ID. + // Luckily its the last component when splitting on slashes by design. 
+ name: "AZURE_STORAGE_ACCOUNT_NAME".to_string(), + value: object_storage_config + .storage_account_resource_id + .clone() + .unwrap() + .split('/') + .last() + .map(|s| s.to_string()), + ..Default::default() + }); + env_vars.push(EnvVar { + name: "AZURE_TENANT_ID".to_string(), + value: object_storage_config.azure_tenant_id.clone(), + ..Default::default() + }); + env_vars.push(EnvVar { + name: "AZURE_STORAGE_CONTAINER_NAME".to_string(), + value: object_storage_config.storage_container_name.clone(), + ..Default::default() + }); + env_vars.push(EnvVar { + name: "AZURE_STORAGE_CONTAINER_REGION".to_string(), + value: object_storage_config.storage_container_region.clone(), + ..Default::default() + }); + + // The following azure env vars come from the mounted azure service principal secret. + env_vars.push(EnvVar { + name: "AZURE_CLIENT_ID".to_string(), + // Extrat client_id from service principal secret. + value_from: Some(EnvVarSource { + secret_key_ref: Some(SecretKeySelector { + key: "client-id".to_string(), + name: azure_storage_account_service_principal_secret_mount().name, + optional: Some(false), + }), + ..Default::default() + }), + ..Default::default() + }); + env_vars.push(EnvVar { + name: "AZURE_CLIENT_CERTIFICATE_PATH".to_string(), + value: Some(format!( + "{}/client.pem", + azure_storage_accont_service_principal_mount_path() + )), + ..Default::default() + }); + } + + // Add EnvVars used in KIND tests if specified. + if let Some(test_params) = &object_storage_config.test_params { + // AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are only used in tests to authenticate with + // the mocked S3 service. We use service accounts and STS to authenticate with S3 in production. + if let Some(key_id) = &test_params.access_key_id { + env_vars.push(EnvVar { + name: "AWS_ACCESS_KEY_ID".to_string(), + value: Some(key_id.clone()), + ..Default::default() + }); + } + if let Some(access_key) = &test_params.secret_access_key { + env_vars.push(EnvVar { + name: "AWS_SECRET_ACCESS_KEY".to_string(), + value: Some(access_key.clone()), + ..Default::default() + }); + } + } + + env_vars.extend(additional_env_vars); + + Some(env_vars) +} + +/// Gets the headless service for PS, SK. +fn get_service(name: String, namespace: String, port: i32, admin_port: i32) -> Service { + Service { + metadata: meta::v1::ObjectMeta { + name: Some(name.clone()), + namespace: Some(namespace.clone()), + ..Default::default() + }, + spec: Some(ServiceSpec { + selector: Some(BTreeMap::from([("app".to_string(), name.clone())])), + ports: Some(vec![ + ServicePort { + port, + target_port: Some(IntOrString::Int(port)), + name: Some(name.clone()), + ..Default::default() + }, + ServicePort { + port: admin_port, + target_port: Some(IntOrString::Int(admin_port)), + name: Some(format!("{}-{}", name.clone(), "admin")), + ..Default::default() + }, + ]), + cluster_ip: Some("None".to_string()), + ..Default::default() + }), + ..Default::default() + } +} + +pub async fn hash_file_contents(file_path: &str) -> std::io::Result { + // Open and read file + let mut file = tokio::fs::File::open(file_path).await?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer).await?; + + // Hash file + let mut hasher = Sha256::new(); + hasher.update(&buffer); + let hash_result = hasher.finalize(); + + // Convert the hash result to a hex string + let hash_hex = format!("{:x}", hash_result); + Ok(hash_hex) +} + +/// Helper function parse a workspace URL (if any) string to a proper `Url` object. 
Returns an error if a URL string is
+/// present but it is invalid.
+fn parse_to_url(url_str: Option<String>) -> anyhow::Result<Option<Url>> {
+    match url_str {
+        Some(url) => Url::parse(&url).map(Some).map_err(|e| anyhow::anyhow!(e)),
+        None => Ok(None),
+    }
+}
+
+pub fn endpoint_default_resources() -> ResourceRequirements {
+    // Regardless of T-shirt size, each PG pod requests 500m CPU and 4 GiB memory (1/4 of a 2-core node), and sets no limit.
+    // This does not reflect the actual resource usage of the PG pod; it is just a short-term solution to balance operation, perf and HA:
+    // - perf: PG should be able to leverage all idle resources on the node. Therefore, we don't want to limit it.
+    // - HA: PG should be able to schedule on a node, regardless of unregulated increases from daemonsets, sidecar containers, or node allocatable resources.
+    //   Low requests make sure scheduling works for all pods when unregulated increases are not crazily high.
+    //   High priority makes sure PG pods are not preempted by other pods when their increases are unreasonably high.
+    // - operation: We don't need to frequently adjust/backfill the resource requests for PG pods. Low requests make sure of that.
+    // Note that the QoS limitation still stands - when the node is under memory pressure, PG pods will be OOM killed if other pods' memory usage is under their requests.
+    // TODO(yan): follow up with the Compute Lifecycle team about the long-term solution tracked in https://databricks.atlassian.net/browse/ES-1282825.
+    ResourceRequirements {
+        requests: Some(
+            vec![
+                ("cpu".to_string(), Quantity("500m".to_string())),
+                ("memory".to_string(), Quantity("4Gi".to_string())),
+            ]
+            .into_iter()
+            .collect(),
+        ),
+        ..Default::default()
+    }
+}
+
+impl K8sManagerImpl {
+    /// Creates a new K8sManager.
+    /// - `region`: The region in which the Hadron Cluster Coordinator runs.
+    /// - `namespace`: The namespace in which the Hadron Cluster Coordinator (this binary) service is running.
+    /// - `advertised_hcc_host`: The reachable hostname of this HCC advertised to other nodes.
+    /// - `advertised_hcc_port`: The reachable port of this HCC advertised to trusted storage nodes (PS/SK).
+    /// - `advertised_hcc_compute_port`: The reachable port of this HCC advertised to PG compute nodes.
+    /// - `pg_params`: Parameters used to launch Postgres compute nodes.
+    /// - `cloud_provider`: The cloud provider where this K8sManager is running.
+    #[allow(clippy::too_many_arguments)] // Clippy is too opinionated about this.
+ pub async fn new( + client: Arc, + region: String, + namespace: String, + advertised_hcc_host: String, + advertised_hcc_port: u16, + advertised_hcc_compute_port: u16, + pg_params: PgParams, + cloud_provider: Option, + ) -> anyhow::Result { + Ok(Self { + client, + namespace, + region, + hcc_dns_name: advertised_hcc_host, + hcc_listening_port: advertised_hcc_port, + hcc_compute_listening_port: advertised_hcc_compute_port, + pg_params: RwLock::new(pg_params), + cloud_provider, + }) + } + + #[cfg(test)] + fn new_for_test( + mock_client: Arc, + region: String, + cloud_provider: Option, + ) -> Self { + Self { + client: mock_client, + namespace: "test-namespace".to_string(), + region, + hcc_dns_name: "localhost".to_string(), + hcc_listening_port: 1234, + hcc_compute_listening_port: 1236, + pg_params: RwLock::new( + serde_json::from_str::( + r#"{ + "compute_namespace": "test-namespace", + "compute_image": "test-image", + "prometheus_exporter_image": "test-prometheus-exporter-image" + }"#, + ) + .unwrap(), + ), + cloud_provider, + } + } + + async fn k8s_create_or_replace(api: Api, name: &str, data: T) -> kube::Result<()> + where + T: Clone + Serialize + DeserializeOwned + Debug, + { + if (api.get_opt(name).await?).is_some() { + api.replace(name, &PostParams::default(), &data).await?; + } else { + api.create(&PostParams::default(), &data).await?; + } + Ok(()) + } + + async fn k8s_get(api: Api, name: &str) -> kube::Result> + where + T: Clone + Serialize + DeserializeOwned + Debug, + { + api.get_opt(name).await + } + + // Wait for a DaemonSet to reach the desired state. + // Returns an error if we cannot confirm that the DaemonSet has reached the desired state by the specified deadline. + async fn k8s_wait_for_daemonset( + api: Api, + name: &str, + deadline: Instant, + ) -> anyhow::Result<()> { + let mut interval = tokio::time::interval(Duration::from_secs(10)); + let timeout = tokio::time::timeout_at(deadline.into(), async { + // Poll the image puller daemonset pods every 10 seconds and inspect statuses, until timeout. + loop { + match api.get(name).await { + Ok(ds) => { + if let Some(status) = ds.status { + let generation_is_current = ds.metadata.generation.is_some() + && (ds.metadata.generation == status.observed_generation); + let all_pods_updated = status.desired_number_scheduled + == status.updated_number_scheduled.unwrap_or( + if let Some(observed_generation) = status.observed_generation { + if observed_generation == 1 { + // If observed_generation is 1, the DaemonSet has just been created and the + // `updated_number_scheduled` field is not populated. We simply use the + // current number of scheduled pods. + // Note that generation number starts from 1 in Kubernetes. + status.current_number_scheduled + } else { + // For all other generations we expect `update_number_scheduled` to be + // (eventually) present. If it's not present it means no pods have been + // updated yet, so we return 0 here. + 0 + } + } else { + // We don't know anything if there isn't an `observed_generation`, so just bail + // out with -1 here to force a retry. + -1 + }, + ); + if generation_is_current && all_pods_updated { + tracing::info!( + "DaemonSet {} is ready, successfully updated {} pods", + name, + status.desired_number_scheduled + ); + break; + } else { + tracing::info!( + "DaemonSet {} is not ready yet. 
Generation is current: {}, All pods updated: {}", + name, + generation_is_current, + all_pods_updated + ); + } + } + } + Err(e) => { + tracing::warn!("Error retriving status of DaemonSet {}: {:?}", name, e); + } + } + interval.tick().await; + } + }); + + timeout.await.map_err(|_| { + anyhow::anyhow!( + "Cannot confirm DaemonSet {} reached the desired state by {:?}", + name, + deadline + ) + }) + } + + pub(crate) async fn k8s_get_stateful_set( + api: &Api, + name: &str, + ) -> Option { + match api.get(name).await { + Ok(stateful_set) => Some(stateful_set), + Err(e) => { + tracing::info!( + "Error retrieving StatefulSet {}, treating it as non-existent. Error: {:?}", + name, + e + ); + None + } + } + } + + pub(crate) fn k8s_is_update_paused(stateful_set: &StatefulSet) -> bool { + stateful_set + .spec + .update_strategy + .as_ref() + .and_then(|update_strategy| update_strategy.rolling_update.as_ref()) + .and_then(|rolling_update| rolling_update.paused) + .unwrap_or(false) + } + + /** + * Return all pods for a given stateful set. K8s does not have an API to do this directly, so here + * we first fetch all pods based on the app_name, then filter by ownerReferences. + */ + pub(crate) async fn k8s_get_stateful_set_pods( + api: &Api, + name: &str, + app_name: &str, + ) -> anyhow::Result> { + let params = ListParams::default().labels(format!("app={}", app_name).as_str()); + let pods = api.list(¶ms).await?; + let mut result = Vec::new(); + for pod in pods { + if let Some(owner_references) = pod.metadata.owner_references.as_ref() { + if owner_references + .iter() + .any(|owner| owner.kind == "StatefulSet" && owner.name == name) + { + result.push(pod); + } + } + } + Ok(result) + } + + pub(crate) fn k8s_set_stateful_set_paused( + name: &str, + stateful_set: &mut StatefulSet, + paused: bool, + ) -> anyhow::Result<()> { + let update_strategy = stateful_set.spec.update_strategy.as_mut().ok_or_else(|| { + anyhow::anyhow!("StatefulSet {} does not have an update strategy", name) + })?; + let rolling_update = update_strategy.rolling_update.as_mut().ok_or_else(|| { + anyhow::anyhow!("StatefulSet {} does not have a rolling update", name) + })?; + rolling_update.paused = Some(paused); + Ok(()) + } + + pub(crate) async fn k8s_replace_stateful_set( + api: &Api, + stateful_set: &StatefulSet, + name: &str, + ) -> anyhow::Result<()> { + api.replace(name, &PostParams::default(), stateful_set) + .await?; + + Ok(()) + } + + pub(crate) async fn k8s_create_or_replace_advanced_stateful_set( + api: Api, + name: &str, + data: StatefulSet, + ) -> anyhow::Result<()> { + let existing_stateful_set: Option = + Self::k8s_get_stateful_set(&api, name).await; + + // If the StatefulSet is marked as "paused" explicitly, we have a manual operation (likely node pool rotation) + // going on and the HCC should not touch this StatefulSet until this marker is removed. 
+ if let Some(sts) = existing_stateful_set.as_ref() { + if Self::k8s_is_update_paused(sts) { + return Err(anyhow::anyhow!( + "StatefulSet {} is paused, not proceeding with the update", + name + )); + } + } + + let mut data = data; + data.metadata.resource_version = + existing_stateful_set.and_then(|sts| sts.metadata.resource_version); + + // If current resource is None, then create otherwise replace + if data.metadata.resource_version.is_none() { + api.create(&PostParams::default(), &data).await?; + } else { + api.replace(name, &PostParams::default(), &data).await?; + } + + Ok(()) + } + + // Get kubernetes volumes and mounts for hadron kubernete secrets + // The secrets are deployed using HCC release pipeline in universe repostiory. + fn get_hadron_volumes_and_mounts(&self, mounts: Vec) -> K8sSecretVolumesAndMounts { + let mut volumes = Vec::new(); + let mut volume_mounts = Vec::new(); + + for mount in mounts { + if mount.mount_type == MountType::ConfigMap { + volumes.push(Volume { + name: mount.name.clone(), + config_map: Some(ConfigMapVolumeSource { + name: mount.name.clone(), + ..Default::default() + }), + ..Default::default() + }); + } else { + volumes.push(Volume { + name: mount.name.clone(), + secret: Some(SecretVolumeSource { + secret_name: Some(mount.name.clone()), + ..Default::default() + }), + ..Default::default() + }); + } + + // Create and add the VolumeMount + volume_mounts.push(VolumeMount { + name: mount.name.clone(), + read_only: Some(true), + mount_path: mount.mount_path.to_string(), + ..Default::default() + }); + } + + K8sSecretVolumesAndMounts { + volumes, + volume_mounts, + } + } + + pub async fn deploy_compute(&self, pg_compute: PgCompute) -> kube::Result<()> { + let client = Arc::clone(&self.client); + let ( + pg_params_compute_namespace, + pg_params_compute_http_port, + pg_params_compute_pg_port, + pg_params_compute_image, + pg_params_compute_image_pull_secret, + pg_params_secret_mounts, + pg_params_prometheus_exporter_image, + ) = { + let pg_params = self.pg_params.read().expect("pg_params lock poisoned"); + ( + pg_params.compute_namespace.clone(), + pg_params + .compute_http_port + .unwrap_or(pg_params_default_compute_http_port().unwrap()), + pg_params + .compute_pg_port + .unwrap_or(pg_params_default_compute_pg_port().unwrap()), + pg_params.compute_image.clone(), + pg_params.compute_image_pull_secret.clone(), + pg_params + .compute_mounts + .clone() + .unwrap_or(pg_params_default_compute_mounts().unwrap()), + pg_params.prometheus_exporter_image.clone(), + ) + }; + + // echo 0x31 > /proc/self/coredump_filter is to disable dumping shared buffers. + // Child processes (PG) inherit this setting. + let launch_command = format!( + r#"echo 0x31 > /proc/self/coredump_filter +if [ $? -ne 0 ]; then + echo "Failed to set coredump filter" + exit 1 +fi + +shutdown() {{ + echo "Shutting down compute_ctl running at pid $pid" + kill -TERM $pid + wait $pid +}} + +trap shutdown TERM + +/usr/local/bin/compute_ctl --http-port {http_port} \ + --pgdata "/var/db/postgres/compute" \ + --connstr "postgresql://cloud_admin@localhost:{pg_port}/postgres" \ + --compute-id "{compute_id}" \ + --control-plane-uri "http://{advertised_host}:{advertised_port}/hadron" \ + --pgbin /usr/local/bin/postgres & + +pid=$! 
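+# Block here until compute_ctl exits; the TERM trap above forwards termination to it.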
+ +wait +"#, + http_port = pg_params_compute_http_port, + pg_port = pg_params_compute_pg_port, + compute_id = pg_compute.compute_id.clone(), + advertised_host = self.hcc_dns_name, + advertised_port = self.hcc_compute_listening_port + ); + + // Choose a dblet node group to use for this compute node based on resource settings. + let selected_node_group = pg_compute.choose_dblet_node_group(); + let secret_volumes_and_mounts = self.get_hadron_volumes_and_mounts(pg_params_secret_mounts); + let node_selector = pg_compute.get_node_selector(selected_node_group.clone()); + let tolerations = pg_compute.get_tolerations(selected_node_group.clone()); + + let pg_exporter_launch_command = format!( + r#"#!/bin/sh +# queries.yaml is technically deprecated +# define extra metrics to collect +cat < /tmp/queries.yaml +# metric name: pg_backpressure_throttling_time +pg_backpressure: + query: "SELECT backpressure_throttling_time AS throttling_time FROM neon.backpressure_throttling_time()" + metrics: + - throttling_time: + usage: "COUNTER" + description: "Total time spent throttling since the system was started." +# metric name: pg_lfc_hits, etc. +# All the MAX(CASE WHEN...) lines are needed to transpose rows into columns. +pg_lfc: + query: " +WITH lfc_stats AS ( + SELECT stat_name, count + FROM neon.neon_get_lfc_stats() AS t(stat_name text, count bigint) +) +SELECT + MAX(count) FILTER (WHERE stat_name = 'file_cache_misses') AS misses, + MAX(count) FILTER (WHERE stat_name = 'file_cache_hits') AS hits, + MAX(count) FILTER (WHERE stat_name = 'file_cache_used') AS used, + MAX(count) FILTER (WHERE stat_name = 'file_cache_writes') AS writes, + MAX(count) FILTER (WHERE stat_name = 'file_cache_writes_eviction') AS writes_eviction, + MAX(count) FILTER (WHERE stat_name = 'file_cache_writes_ps_read') AS writes_ps_read, + MAX(count) FILTER (WHERE stat_name = 'file_cache_writes_extend') AS writes_extend, + MAX(count) FILTER (WHERE stat_name = 'file_cache_used_pages') AS pages, + MAX(count) FILTER (WHERE stat_name = 'outstanding_reads') AS outstanding_reads, + MAX(count) FILTER (WHERE stat_name = 'outstanding_writes') AS outstanding_writes, + MAX(count) FILTER (WHERE stat_name = 'skipped_writes') AS skipped_writes, + MAX(count) FILTER (WHERE stat_name = 'file_cache_evictions') AS evictions, + MAX(count) FILTER (WHERE stat_name = 'cumulative_read_time') AS cumulative_read_time, + MAX(count) FILTER (WHERE stat_name = 'cumulative_write_time') AS cumulative_write_time, + MAX(count) FILTER (WHERE stat_name = 'blocks_per_chunk') AS blocks_per_chunk, + pg_size_bytes(current_setting('neon.file_cache_size_limit')) as size_limit +FROM lfc_stats" + metrics: + - misses: + usage: "COUNTER" + description: "Number of GetPage@LSN requests not found in LFC. Same as Hadron Storage GetPage QPS." + - hits: + usage: "COUNTER" + description: "Number of GetPage@LSN requests satisfied by the LFC." + - used: + usage: "GAUGE" + description: "Number of chunks in LFC." + - writes: + usage: "COUNTER" + description: "Total number of writes to LFC for any reason." + - writes_eviction: + usage: "COUNTER" + description: "Number of writes to LFC due to buffer pool eviction." + - writes_ps_read: + usage: "COUNTER" + description: "Number of writes to LFC due to page server read." + - writes_extend: + usage: "COUNTER" + description: "Number of writes to LFC due to extending a relation/file." + - pages: + usage: "GAUGE" + description: "Number of live pages in LFC." + - outstanding_reads: + usage: "GAUGE" + description: "Number of outstanding reads IOs." 
+ - outstanding_writes: + usage: "GAUGE" + description: "Number of outstanding write IOs." + - skipped_writes: + usage: "COUNTER" + description: "Number of LFC writes skipped to to too many outstanding IOs." + - evictions: + usage: "COUNTER" + description: "Number of LFC evictions." + - cumulative_read_time: + usage: "COUNTER" + description: "Cumulative time spent reading from LFC." + - cumulative_write_time: + usage: "COUNTER" + description: "Cumulative time spent writing to LFC." + - blocks_per_chunk: + usage: "GAUGE" + description: "Number of pages/blocks per chunk." + - size_limit: + usage: "GAUGE" + description: "Currently set LFC max size in bytes (dynamically configurable)." +# metric name: pg_lfc_working_set_size{{duration=1m, 5m, 15m, 1h}} +pg_lfc_working_set: + query: " + select + x as duration, + neon.approximate_working_set_size_seconds(extract('epoch' from x::interval)::int)::bigint*8192 as size + from (values ('1m'), ('5m'),('15m'),('1h')) as t (x)" + metrics: + - duration: + usage: "LABEL" + description: "Estimation window." + - size: + usage: "GAUGE" + description: "Estimated working set size in bytes." +# metric name: pg_writable_bool +pg_writable: + query: "SELECT health_check_write_succeeds AS bool FROM health_check_write_succeeds()" + metrics: + - bool: + usage: "GAUGE" + description: "Whether the last write probe succeeded (1 for yes, 0 for no)." +# metric name: pg_cluster_size_bytes +pg_cluster_size: + query: "SELECT pg_cluster_size AS bytes FROM neon.pg_cluster_size()" + metrics: + - bytes: + usage: "GAUGE" + description: "Hadron logical size." +# metric name: pg_snapshot_files_count +pg_snapshot_files: + query: "SELECT COUNT(*) AS count FROM pg_ls_dir('/var/db/postgres/compute/pg_logical/snapshots/') WHERE pg_ls_dir LIKE '%.snap'" + cache_seconds: 120 + metrics: + - count: + usage: "GAUGE" + description: "Number of .snap files currently accumulated on the Postgres compute node" +# metric name: getpage, data_corruptions, etc. +# All the MAX(CASE WHEN...) lines are needed to transpose rows into columns. 
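+# For example, a source row (metric = 'num_active_safekeepers', value = 3) surfaces as a
+# single wide column named num_active_safekeepers in the one-row result consumed below.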
+pg_metrics: + query: " +WITH pg_perf_counters AS ( + SELECT metric, value + FROM neon.neon_perf_counters +) +SELECT + MAX(value) FILTER (WHERE metric = 'sql_index_corruption_count') AS sql_index_corruption_count, + MAX(value) FILTER (WHERE metric = 'sql_data_corruption_count') AS sql_data_corruption_count, + MAX(value) FILTER (WHERE metric = 'sql_internal_error_count') AS sql_internal_error_count, + MAX(value) FILTER (WHERE metric = 'ps_corruption_detected') AS ps_corruption_detected, + MAX(value) FILTER (WHERE metric = 'getpage_wait_seconds_count') AS getpage_wait_seconds_count, + MAX(value) FILTER (WHERE metric = 'getpage_wait_seconds_sum') AS getpage_wait_seconds_sum, + MAX(value) FILTER (WHERE metric = 'file_cache_read_wait_seconds_count') AS file_cache_read_wait_seconds_count, + MAX(value) FILTER (WHERE metric = 'file_cache_read_wait_seconds_sum') AS file_cache_read_wait_seconds_sum, + MAX(value) FILTER (WHERE metric = 'file_cache_write_wait_seconds_count') AS file_cache_write_wait_seconds_count, + MAX(value) FILTER (WHERE metric = 'file_cache_write_wait_seconds_sum') AS file_cache_write_wait_seconds_sum, + MAX(value) FILTER (WHERE metric = 'pageserver_disconnects_total') AS pageserver_disconnects_total, + MAX(value) FILTER (WHERE metric = 'pageserver_open_requests') AS pageserver_open_requests, + MAX(value) FILTER (WHERE metric = 'num_active_safekeepers') AS num_active_safekeepers, + MAX(value) FILTER (WHERE metric = 'num_configured_safekeepers') AS num_configured_safekeepers, + MAX(value) FILTER (WHERE metric = 'max_active_safekeeper_commit_lag') AS max_active_safekeeper_commit_lag +FROM pg_perf_counters" + metrics: + - sql_index_corruption_count: + usage: "COUNTER" + description: "Number of index corruption errors." + - sql_data_corruption_count: + usage: "COUNTER" + description: "Number of data corruption errors." + - sql_internal_error_count: + usage: "COUNTER" + description: "Number of internal errors." + - ps_corruption_detected: + usage: "COUNTER" + description: "Number of page server corruption errors." + - getpage_wait_seconds_count: + usage: "COUNTER" + description: "Number of GetPage@LSN waits." + - getpage_wait_seconds_sum: + usage: "COUNTER" + description: "Number of GetPage@LSN wait seconds." + - file_cache_read_wait_seconds_count: + usage: "COUNTER" + description: "Number of file cache read waits." + - file_cache_read_wait_seconds_sum: + usage: "COUNTER" + description: "Number of file cache read wait seconds." + - file_cache_write_wait_seconds_count: + usage: "COUNTER" + description: "Number of file cache write waits." + - file_cache_write_wait_seconds_sum: + usage: "COUNTER" + description: "Number of file cache write wait seconds." + - pageserver_disconnects_total: + usage: "COUNTER" + description: "Number of page server disconnects." + - pageserver_open_requests: + usage: "GAUGE" + description: "Number of open page server requests." + - num_active_safekeepers: + usage: "GAUGE" + description: "Number of active safekeepers." + - num_configured_safekeepers: + usage: "GAUGE" + description: "Number of configured safekeepers." + - max_active_safekeeper_commit_lag: + usage: "GAUGE" + description: "Maximum commit lag (in LSN bytes) among active safekeepers." 
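+# A hypothetical alerting expression on the exported gauge (the metric name follows the
+# pg_metrics_ prefix convention noted above; the threshold is purely illustrative):
+#   pg_metrics_max_active_safekeeper_commit_lag > 512 * 1024 * 1024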
+EOF
+
+# constantLabels is technically deprecated
+/bin/postgres_exporter --constantLabels=pg_endpoint={0},pg_compute_id={1},pg_instance_id={2} \
+    --extend.query-path="/tmp/queries.yaml" --no-collector.database"#,
+            pg_compute.name.clone(),
+            pg_compute.k8s_label_compute_id(),
+            pg_compute
+                .instance_id
+                .clone()
+                .unwrap_or(DEFAULT_INSTANCE_ID.to_string()),
+        );
+
+        // Optionally defines a readiness probe that Kubernetes can use to determine when the container
+        // is ready to start serving traffic. Note that we use a readiness probe rather than a startup
+        // or liveness check: Kubernetes will kill/restart containers if they fail those checks, whereas
+        // a failing readiness check only stops directing traffic to the container while keeping it
+        // running. We want this behaviour because our compute manager, rather than Kubernetes, should
+        // be responsible for killing/restarting compute.
+        let readiness_probe = pg_compute
+            .readiness_probe
+            .clone()
+            .map(|readiness_probe_params| Probe {
+                failure_threshold: Some(readiness_probe_params.failure_threshold),
+                http_get: Some(HTTPGetAction {
+                    path: Some(readiness_probe_params.endpoint_path.clone()),
+                    port: IntOrString::Int(pg_params_compute_http_port as i32),
+                    ..Default::default()
+                }),
+                initial_delay_seconds: Some(readiness_probe_params.initial_delay_seconds),
+                period_seconds: Some(readiness_probe_params.period_seconds),
+                success_threshold: Some(readiness_probe_params.success_threshold),
+                timeout_seconds: Some(readiness_probe_params.timeout_seconds),
+                ..Default::default()
+            });
+
+        let pod_template_spec = PodTemplateSpec {
+            metadata: Some(meta::v1::ObjectMeta {
+                annotations: Some(
+                    vec![
+                        (
+                            // This annotation is consumed by networking components to allow egress traffic
+                            // from this compute Pod to the workspace URL/control plane. The networking components
+                            // require the "URL" to be just the hostname, without any schemes or trailing "/".
+                            "databricks.com/workspace-url".to_string(),
+                            pg_compute.workspace_url.clone()
+                                .map(|url| url.host_str().unwrap_or_default().to_string())
+                                .unwrap_or("".to_string()),
+                        ),
+                        ("enableLogDaemon".to_string(), "true".to_string()),
+                    ]
+                    .into_iter()
+                    .filter(|(_, v)| !v.is_empty())
+                    .collect(),
+                ),
+                labels: Some(
+                    vec![
+                        ("app".to_string(), pg_compute.name.clone()),
+                        (
+                            "dblet.dev/appid".to_string(),
+                            pg_compute.k8s_label_compute_id(),
+                        ),
+                        ("dblet.dev/managed".to_string(), "true".to_string()),
+                        ("orgId".to_string(), pg_compute.workspace_id.clone().unwrap_or_default()),
+                        ("hadron-component".to_string(), "compute".to_string()),
+                        (COMPUTE_ID_LABEL_KEY.to_string(), pg_compute.k8s_label_compute_id()),
+                        (INSTANCE_ID_LABEL_KEY.to_string(), pg_compute.instance_id.clone().unwrap_or("".to_string())),
+                    ]
+                    .into_iter()
+                    .filter(|(_, v)| !v.is_empty())
+                    .collect(),
+                ),
+                ..Default::default()
+            }),
+            spec: Some(PodSpec {
+                // Pod Anti Affinity: Do not schedule to nodes that already have compute pods running.
+                // dblet nodes are single-compute. They are single-tenant and use NodePort for
+                // Pod networking, so running multiple compute Pods on the same node is not
+                // going to work.
+ affinity: Some(Affinity { + pod_anti_affinity: Some(PodAntiAffinity { + required_during_scheduling_ignored_during_execution: Some(vec![ + PodAffinityTerm { + label_selector: Some(meta::v1::LabelSelector { + match_labels: Some( + vec![( + "hadron-component".to_string(), + "compute".to_string(), + )] + .into_iter() + .collect(), + ), + ..Default::default() + }), + topology_key: "kubernetes.io/hostname".to_string(), + ..Default::default() + }, + ]), + ..Default::default() + }), + ..Default::default() + }), + node_selector: Some(node_selector), + tolerations: Some(tolerations), + image_pull_secrets: pg_params_compute_image_pull_secret.map( + |secret_name| { + vec![LocalObjectReference { + name: secret_name, + }] + }, + ), + priority_class_name: Some("pg-compute".to_string()), + init_containers: Some(vec![Container { + name: "compute-local-ssd-init".to_string(), + image: Some( + pg_compute + .image_override + .clone() + .unwrap_or_else(|| pg_params_compute_image.clone()), + ), + // Copy the pgdata dir to the mounted local SSD to transparently + // initialize any existing PG data/metadata set in the docker + // image before the compute container starts. Use `mv` with the + // `-f` flag in place of `cp` to overwrite any existing data explicitly. + // + // While root, create brickstore dir for logs. dbctl runs as root group, so need to give + // at least 550 so that it can cd to serve dbctl app-logs. Need at least 007 so that + // PG can create new log files to write logs to in the directory, so 777 is most sensible. + // + // While we are running as root, also configure the *hosts* core dump settings. + // In vanilla neon, `compute_ctl` expects core dumps to be written to the PG data directory + // in a specific format as mentioned in https://github.com/neondatabase/autoscaling/issues/784. + // Since `compute_ctl` wipes the data dir across container restarts, we have patched it to + // read core dump files from a fixed path (`/databricks/logs/brickstore/core`) instead, + // since that will *always* persist container restarts. + command: Some(vec!["/bin/bash".to_string(), "-c".to_string()]), + args: Some(vec![ + "rm -rf /local_ssd/compute && mv /var/db/postgres/compute /local_ssd/ && \ + mkdir -p -m 777 /databricks/logs/brickstore && \ + echo '/databricks/logs/brickstore/core' > /proc/sys/kernel/core_pattern && \ + echo '1' > /proc/sys/kernel/core_uses_pid".to_string(), + ]), + volume_mounts: Some(vec![VolumeMount { + name: "local-ssd".to_string(), + mount_path: "/local_ssd".to_string(), + ..Default::default() + },VolumeMount { + name: "logs".to_string(), + mount_path: "/databricks/logs".to_string(), + ..Default::default() + }]), + security_context: Some(SecurityContext { + // The init container runs as root so that it can create the pgdata directory. + // The compute container runs as the postgres user (and should not + // run as the root user) afterwards using an explicit subdir mount + // on /local_ssd/compute for file isolation purposes (the dblet + // host may leverage this local disk for image layer storage, etc). 
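+                            // Writing the host's /proc/sys/kernel/core_pattern from the script above
+                            // requires root in a privileged context, so only this short-lived init
+                            // container is granted these elevated permissions.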
+ privileged: Some(true), + run_as_user: Some(0), + ..Default::default() + }), + ..Default::default() + }]), + containers: vec![Container { + name: "compute".to_string(), + image: Some( + pg_compute + .image_override + .clone() + .unwrap_or_else(|| pg_params_compute_image.clone()), + ), + ports: Some(vec![ + ContainerPort { + container_port: pg_params_compute_pg_port as i32, + ..Default::default() + }, + ContainerPort { + container_port: pg_params_compute_http_port as i32, + ..Default::default() + }, + ]), + env: Some({ + let mut env = vec![EnvVar { + name: "NEON_CONTROL_PLANE_TOKEN".to_string(), + value: Some(pg_compute.control_plane_token.clone()), + ..Default::default() + }]; + if pg_compute.instance_id.is_some() { + env.push(EnvVar { + name: "INSTANCE_ID".to_string(), + value: pg_compute.instance_id.clone(), + ..Default::default() + }); + }; + env + }), + resources: Some( + pg_compute + .resources + .clone() + ), + command: Some(vec!["/bin/sh".to_string(), "-c".to_string()]), + args: Some(vec![launch_command]), + volume_mounts: Some( + [ + &secret_volumes_and_mounts.volume_mounts[..], + &[VolumeMount { + name: "local-ssd".to_string(), + // Explicitly mount over the PGDATA dir with the pre-initialized + // local-SSD path. + mount_path: "/var/db/postgres/compute".to_string(), + // Important that we mount the compute subpath here to avoid giving + // pg visibility to any other data that may exist on the local disk. + sub_path: Some("compute".to_string()), + ..Default::default() + },VolumeMount { + name: "logs".to_string(), + mount_path: "/databricks/logs".to_string(), + ..Default::default() + },VolumeMount { + // Default shared memory size is 64MB, which is not enough for PG. + // See https://stackoverflow.com/questions/43373463/how-to-increase-shm-size-of-a-kubernetes-container-shm-size-equivalent-of-doc. 
+ name: "dshm".to_string(), + mount_path: "/dev/shm".to_string(), + ..Default::default() + }], + ] + .concat(), + ), + readiness_probe, + ..Default::default() + }, Container { + name: "prometheus-exporter".to_string(), + image: Some( + pg_compute + .exporter_image_override.clone() + .unwrap_or_else(|| pg_params_prometheus_exporter_image.clone()), + ), + ports: Some(vec![ + ContainerPort { + container_port: 9187, + name: Some("info-service".to_string()), + ..Default::default() + } + ]), + env: Some(vec![EnvVar { + name: "DATA_SOURCE_NAME".to_string(), + value: Some("user=databricks_monitor host=127.0.0.1 port=55432 sslmode=disable database=databricks_system".to_string()), + ..Default::default() + }]), + resources: Some(ResourceRequirements { + requests: Some(BTreeMap::from([ + ("cpu".to_string(), Quantity("50m".to_string())), + ("memory".to_string(), Quantity("128Mi".to_string())), + ])), + limits: Some(BTreeMap::from([ + ("cpu".to_string(), Quantity("200m".to_string())), + ("memory".to_string(), Quantity("128Mi".to_string())), + ])), + ..Default::default() + }), + command: Some(vec!["/bin/sh".to_string(), "-c".to_string()]), + args: Some(vec![pg_exporter_launch_command]), + volume_mounts: Some(secret_volumes_and_mounts.volume_mounts), + ..Default::default() + }, Container { + name: "pg-log-redactor".to_string(), + image: Some( + pg_compute + .image_override + .clone() + .unwrap_or_else(|| pg_params_compute_image.clone()), + ), + resources: Some(ResourceRequirements { + requests: Some(BTreeMap::from([ + ("cpu".to_string(), Quantity("50m".to_string())), + ("memory".to_string(), Quantity("128Mi".to_string())), + ])), + limits: Some(BTreeMap::from([ + ("cpu".to_string(), Quantity("200m".to_string())), + ("memory".to_string(), Quantity("128Mi".to_string())), + ])), + ..Default::default() + }), + command: Some(vec!["python3".to_string()]), + args: Some(vec!["/usr/local/bin/pg_log_redactor.py".to_string()]), + volume_mounts: Some( + vec![VolumeMount { + name: "logs".to_string(), + mount_path: "/databricks/logs".to_string(), + ..Default::default() + }], + ), + ..Default::default() + }], + volumes: Some( + [ + &secret_volumes_and_mounts.volumes[..], + &[Volume { + name: "local-ssd".to_string(), + host_path: Some(HostPathVolumeSource { + path: "/local_disk0".to_string(), + // Require the directory to exist beforehand. We could in + // theory set this to `DirectoryOrCreate`, but that might require + // the entire pod to run as root in some cases, which we want to avoid. + type_: Some("Directory".to_string()), + }), + ..Default::default() + },Volume { + name: "logs".to_string(), + // Empty dir means we lose local logs on pod restart. This is ok because + // log-daemon backs them up. + empty_dir: Some(EmptyDirVolumeSource::default()), + ..Default::default() + },Volume { + // Default shared memory size is 64MB, which is not enough for PG. + // See https://stackoverflow.com/questions/43373463/how-to-increase-shm-size-of-a-kubernetes-container-shm-size-equivalent-of-doc. 
+ name: "dshm".to_string(), + empty_dir: Some(EmptyDirVolumeSource { medium: Some("Memory".to_string()), size_limit: None }), + ..Default::default() + }], + ] + .concat(), + ), + ..Default::default() + }), + }; + + let admin_service = Service { + metadata: meta::v1::ObjectMeta { + name: Some(format!("{}-admin", &pg_compute.name)), + ..Default::default() + }, + spec: Some(ServiceSpec { + ports: Some(vec![ServicePort { + port: 80, + target_port: Some(IntOrString::Int(pg_params_compute_http_port as i32)), + ..Default::default() + }]), + type_: Some("ClusterIP".to_string()), + selector: Some( + vec![("app".to_string(), pg_compute.name.clone())] + .into_iter() + .collect(), + ), + ..Default::default() + }), + ..Default::default() + }; + + tracing::info!("Deploying compute node {:?}", pg_compute.name); + tracing::debug!("Compute Admin Service: {:?}", admin_service); + + match pg_compute.model { + ComputeModel::PrivatePreview => { + let lb_service = self + .build_loadbalancer_service(pg_params_compute_pg_port, pg_compute.name.clone()); + tracing::debug!("PrPr LB Service: {:?}", lb_service); + + self.create_k8s_deployment( + &client, + &pg_compute, + pg_params_compute_namespace.clone(), + pod_template_spec, + ) + .await?; + Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()), + &pg_compute.name, + lb_service, + ) + .await?; + Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()), + &admin_service.metadata.name.clone().unwrap(), + admin_service, + ) + .await + } + ComputeModel::PublicPreview => { + self.create_k8s_replica_set( + &client, + &pg_compute, + pg_params_compute_namespace.clone(), + pod_template_spec, + ) + .await?; + // Note that in the Public Preview model, we don't create any k8s Service objects here to handle postgres protocol + // ingress. The ingress Service objects are created via trait methods `create_or_patch_cluster_primary_ingress_service()` + // and `create_or_patch_instance_ingress_service()`, invoked directly from reconcilers in `compute_manager/src/reconcilers`. 
+ Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()), + &admin_service.metadata.name.clone().unwrap(), + admin_service, + ) + .await + } + } + } + + fn allow_not_found_error( + result: kube::Error, + ) -> kube::Result> { + match result { + kube::Error::Api(e) if e.code == 404 => { + Ok(Either::Right(kube::core::Status::success())) + } + e => Err(e), + } + } + + fn get_hadron_compute_spec_metadata(&self, name: String) -> meta::v1::ObjectMeta { + meta::v1::ObjectMeta { + name: Some(name), + annotations: Some( + vec![("hadron.dev/managed".to_string(), "true".to_string())] + .into_iter() + .collect(), + ), + ..Default::default() + } + } + + fn get_hadron_compute_label_selector(&self, name: String) -> meta::v1::LabelSelector { + meta::v1::LabelSelector { + match_labels: Some(vec![("app".to_string(), name)].into_iter().collect()), + ..Default::default() + } + } + + async fn create_k8s_deployment( + &self, + client: &Client, + pg_compute: &PgCompute, + pg_params_compute_namespace: String, + pod_template_spec: PodTemplateSpec, + ) -> kube::Result<()> { + let deployment = Deployment { + metadata: self.get_hadron_compute_spec_metadata(pg_compute.name.clone()), + spec: Some(DeploymentSpec { + replicas: Some(1), + selector: self.get_hadron_compute_label_selector(pg_compute.name.clone()), + template: pod_template_spec, + ..Default::default() + }), + ..Default::default() + }; + + tracing::debug!("K8s deployment: {:?}", deployment); + + Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()), + &pg_compute.name, + deployment, + ) + .await + } + + async fn create_k8s_replica_set( + &self, + client: &Client, + pg_compute: &PgCompute, + pg_params_compute_namespace: String, + pod_template_spec: PodTemplateSpec, + ) -> kube::Result<()> { + let replica_set = ReplicaSet { + metadata: self.get_hadron_compute_spec_metadata(pg_compute.name.clone()), + spec: Some(ReplicaSetSpec { + replicas: Some(1), + selector: self.get_hadron_compute_label_selector(pg_compute.name.clone()), + template: Some(pod_template_spec), + ..Default::default() + }), + ..Default::default() + }; + + tracing::debug!("K8s replica set: {:?}", replica_set); + + Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()), + &pg_compute.name, + replica_set, + ) + .await + } + + fn get_loadbalancer_annotations( + &self, + dns_name_hint: &str, + ) -> Option> { + match self.cloud_provider { + Some(CloudProvider::AWS) => { + Some( + vec![ + // AWS specific annotations. + // xref https://kubernetes-sigs.github.io/aws-load-balancer-controller/v2.2/guide/service/annotations/ + ( + "service.beta.kubernetes.io/aws-load-balancer-type".to_string(), + "external".to_string(), + ), + ( + "service.beta.kubernetes.io/aws-load-balancer-nlb-target-type" + .to_string(), + "ip".to_string(), + ), + ( + "service.beta.kubernetes.io/aws-load-balancer-scheme".to_string(), + "internet-facing".to_string(), + ), + ] + .into_iter() + .collect(), + ) + } + Some(CloudProvider::Azure) => { + Some( + vec![ + // Azure specific annotations. 
+ // xref https://cloud-provider-azure.sigs.k8s.io/topics/loadbalancer/#loadbalancer-annotations + ( + "service.beta.kubernetes.io/azure-dns-label-name".to_string(), + dns_name_hint.to_string(), + ), + ( + "service.beta.kubernetes.io/azure-load-balancer-internal".to_string(), + "false".to_string(), + ), + ] + .into_iter() + .collect(), + ) + } + _ => None, + } + } + + fn build_loadbalancer_service( + &self, + pg_params_compute_pg_port: u16, + pg_compute_name: String, + ) -> Service { + if std::env::var("HADRON_STRESS_TEST_MODE").is_ok() { + // In stress test mode, we deploy lots of computes, so we don't want to create LoadBalancer Service. + // Instead, we create ClusterIP Services and deploy benchmark Pods within the cluster. + return Service { + metadata: meta::v1::ObjectMeta { + name: Some(pg_compute_name.clone()), + ..Default::default() + }, + spec: Some(ServiceSpec { + ports: Some(vec![ServicePort { + port: 5432, + target_port: Some(IntOrString::Int(pg_params_compute_pg_port as i32)), + name: Some("postgres".to_string()), + ..Default::default() + }]), + type_: Some("ClusterIP".to_string()), + selector: Some( + vec![("app".to_string(), pg_compute_name)] + .into_iter() + .collect(), + ), + ..Default::default() + }), + ..Default::default() + }; + } + + Service { + metadata: meta::v1::ObjectMeta { + name: Some(pg_compute_name.clone()), + annotations: self.get_loadbalancer_annotations(pg_compute_name.as_str()), + ..Default::default() + }, + spec: Some(ServiceSpec { + ports: Some(vec![ServicePort { + port: 5432, + target_port: Some(IntOrString::Int(pg_params_compute_pg_port as i32)), + name: Some("postgres".to_string()), + ..Default::default() + }]), + type_: Some("LoadBalancer".to_string()), + selector: Some( + vec![("app".to_string(), pg_compute_name)] + .into_iter() + .collect(), + ), + external_traffic_policy: match self.cloud_provider { + Some(CloudProvider::AWS) => Some("Local".to_string()), + Some(CloudProvider::Azure) => Some("Cluster".to_string()), + _ => Some("Local".to_string()), + }, + ..Default::default() + }), + ..Default::default() + } + } + + // A helper to check if a resource has been deleted. + async fn is_resource_deleted(api: &Api, name: &str) -> kube::Result + where + T: kube::Resource + Clone + DeserializeOwned + Debug, + { + match api.get(name).await { + Err(kube::Error::Api(e)) if e.code == 404 => Ok(true), + Err(e) => Err(e), + Ok(_) => Ok(false), + } + } + + pub async fn delete_compute( + &self, + compute_name: &str, + model: ComputeModel, + ) -> kube::Result { + let pg_params_compute_namespace = self + .pg_params + .read() + .expect("pg_param lock poisoned") + .compute_namespace + .clone(); + + let client = Arc::clone(&self.client); + let service_api: Api = + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()); + + tracing::info!("Deleting resources for compute {compute_name}"); + match model { + ComputeModel::PrivatePreview => { + let deployment_api: Api = + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()); + deployment_api + .delete(compute_name, &DeleteParams::default()) + .await + .or_else(Self::allow_not_found_error)?; + + // Azure load balancer services require special attention when deleting since + // we leverage the `azure-dns-label-name` annotation to generate the DNS name. + // On Azure, we have to wipe out this annotation before deleting the service, + // in order to ensure that the reserved DNS name (and its associated IP address) + // is properly cleaned. 
Note that its OK to make this update right before deletion + // (the AKS LB controller is designed to handle this kind of cleanup in one shot). + if self.cloud_provider == Some(CloudProvider::Azure) { + let service = service_api.get(compute_name).await?; + let mut service = service.clone(); + if let Some(annotations) = service.metadata.annotations.as_mut() { + annotations.insert( + "service.beta.kubernetes.io/azure-dns-label-name".to_string(), + "".to_string(), + ); + } + service_api + .replace(compute_name, &PostParams::default(), &service) + .await?; + } + + service_api + .delete(compute_name, &DeleteParams::default()) + .await + .or_else(Self::allow_not_found_error)?; + } + ComputeModel::PublicPreview => { + let replica_set_api: Api = + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()); + replica_set_api + .delete(compute_name, &DeleteParams::default()) + .await + .or_else(Self::allow_not_found_error)?; + } + } + + // Delete the admin service resource. + let admin_name = format!("{}-admin", compute_name); + service_api + .delete(&admin_name, &DeleteParams::default()) + .await + .or_else(Self::allow_not_found_error)?; + + // Verification: Check that all resources have been removed. + let mut all_deleted = true; + + match model { + ComputeModel::PrivatePreview => { + // Check the Deployment and Service for the compute_name. + let deployment_api: Api = + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()); + let deployment_deleted = + Self::is_resource_deleted(&deployment_api, compute_name).await?; + all_deleted &= deployment_deleted; + + let service_deleted = Self::is_resource_deleted(&service_api, compute_name).await?; + all_deleted &= service_deleted; + } + ComputeModel::PublicPreview => { + // Check the ReplicaSet for the compute_name. + let replica_set_api: Api = + Api::namespaced((*client).clone(), pg_params_compute_namespace.as_str()); + let replicaset_deleted = + Self::is_resource_deleted(&replica_set_api, compute_name).await?; + all_deleted &= replicaset_deleted; + } + } + + // Check the admin service deletion. + let admin_deleted = Self::is_resource_deleted(&service_api, &admin_name).await?; + all_deleted &= admin_deleted; + + Ok(all_deleted) + } + + pub async fn get_compute_connection_info( + &self, + compute_name: &str, + ) -> Option { + let pg_params_compute_namespace = self + .pg_params + .read() + .expect("pg_param lock poisoned") + .compute_namespace + .clone(); + let service_api: Api = + Api::namespaced((*self.client).clone(), pg_params_compute_namespace.as_str()); + match service_api.get(compute_name).await { + Ok(svc) => svc + .status + .as_ref() + .and_then(|status| status.load_balancer.as_ref()) + .and_then(|lb| lb.ingress.as_ref()) + .and_then(|ingress| ingress.first()) + .map(|ingress| PostgresConnectionInfo { + // On Azure, ingress.ip is set instead of ingress.hostname. In that case, + // return the DNS name generated by the `azure-dns-label-name` annotation, + // since the raw IP address isn't ideal for clients. + host: ingress.hostname.clone().map_or_else( + || { + if ingress.ip.is_some() + && self.cloud_provider == Some(CloudProvider::Azure) + { + // It is important that we ensure that ingress.ip is in fact set + // to avoid returning an incorrect/broken hostname during LB creation. 
+ Some(format!( + "{}.{}.cloudapp.azure.com", + compute_name, self.region + )) + } else { + None + } + }, + Some, + ), + port: svc.spec.as_ref().and_then(|spec| { + spec.ports.as_ref().and_then(|ports| { + ports.first().map(|port: &ServicePort| port.port as u16) + }) + }), + }), + Err(e) => { + // TODO: Figure out if there is a better way to match 404 errors. Ideally we still want to surface + // other errors to the caller. + tracing::error!("Failed to get service {:?}: {:?}", compute_name, e); + None + } + } + } + + /// This function returns additional settings that will be injected into ComputeSpec when deploying a compute node. + pub async fn get_databricks_compute_settings( + &self, + workspace_url: Option, + ) -> DatabricksSettings { + let pg_params = self.pg_params.read().expect("pg_param lock poisoned"); + + let pg_compute_tls_settings = pg_params + .pg_compute_tls_settings + .clone() + .unwrap_or(pg_params_pg_compute_tls_settings().unwrap()); + + let databricks_pg_hba = pg_params + .databricks_pg_hba + .clone() + .unwrap_or(pg_params_default_databricks_pg_hba().unwrap()); + + let databricks_pg_ident = pg_params + .databricks_pg_ident + .clone() + .unwrap_or(pg_params_default_databricks_pg_ident().unwrap()); + + DatabricksSettings { + pg_compute_tls_settings, + databricks_pg_hba, + databricks_pg_ident, + databricks_workspace_host: workspace_url + .and_then(|url| url.host_str().map(|s| s.to_string())) + .unwrap_or("".to_string()), + } + } + + // Functions manipulating k8s Service objects used for the intra-cluster last-leg DpApiProxy -> Postgres + // routing (a.k.a. ingress services). + + // Constructs the name of the primary ingress k8s Service object of a Hadron database instance. + // DpApiProxy selects this service as the "upstream" for `instance-$instance_id.database.$TLD` SNI matches. + fn instance_primary_ingress_service_name(instance_id: Uuid) -> String { + format!("instance-{}", instance_id) + } + + fn instance_read_only_ingress_service_name(instance_id: Uuid) -> String { + format!("instance-ro-{}", instance_id) + } + + // Gets a k8s Service object by name, in the namespace where PG computes are deployed. + async fn get_ingress_service(&self, service_name: &str) -> kube::Result { + let pg_params_compute_namespace = self + .pg_params + .read() + .expect("pg_params lock poisoned") + .compute_namespace + .clone(); + + let api: Api = + Api::namespaced((*self.client).clone(), pg_params_compute_namespace.as_str()); + api.get(service_name).await + } + + // Idempotently create the readable secondary ingress services. + async fn create_if_not_exists_readable_secondary_ingress( + &self, + service_name: String, + instance_id: Uuid, + ) -> kube::Result { + let (pg_params_compute_namespace, pg_params_compute_pg_port) = { + let pg_params = self.pg_params.read().expect("pg_param lock poisoned"); + ( + pg_params.compute_namespace.clone(), + pg_params + .compute_pg_port + .unwrap_or(pg_params_default_compute_pg_port().unwrap()), + ) + }; + + // We match every compute in the compute_pool with a mode equal to secondary. 
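+        // Concretely, that means Pods carrying this instance's id label and the secondary marker label set to "true".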
+ let ingress_service_selector = Some( + vec![ + (INSTANCE_ID_LABEL_KEY.to_string(), instance_id.to_string()), + (COMPUTE_SECONDARY_LABEL_KEY.to_string(), true.to_string()), + ] + .into_iter() + .collect(), + ); + + let ingress_service_type = Some("ClusterIP".to_string()); + + let ingress_service = Service { + metadata: meta::v1::ObjectMeta { + name: Some(service_name.to_string()), + annotations: None, + ..Default::default() + }, + spec: Some(ServiceSpec { + selector: ingress_service_selector.clone(), + ports: Some(vec![ServicePort { + port: 5432, + target_port: Some(IntOrString::Int(pg_params_compute_pg_port as i32)), + name: Some("postgres".to_string()), + protocol: Some("TCP".to_string()), + ..Default::default() + }]), + type_: ingress_service_type.clone(), + external_traffic_policy: None, + ..Default::default() + }), + ..Default::default() + }; + + let api: Api = + Api::namespaced((*self.client).clone(), pg_params_compute_namespace.as_str()); + match api.get(&service_name).await { + // If the service exists, no-op + Ok(service) => Ok(service), + + // If the service does not exist, create it. + Err(kube::Error::Api(e)) if e.code == 404 => { + api.create(&PostParams::default(), &ingress_service).await + } + Err(e) => Err(e), + } + } + + // Idempotently create or patch a k8s Service object with the given name to route traffic to the specified compute. + async fn create_or_patch_ingress_service( + &self, + service_name: &str, + compute_id: Uuid, + service_type: K8sServiceType, + ) -> kube::Result { + let (pg_params_compute_namespace, pg_params_compute_pg_port) = { + let pg_params = self.pg_params.read().expect("pg_param lock poisoned"); + ( + pg_params.compute_namespace.clone(), + pg_params + .compute_pg_port + .unwrap_or(pg_params_default_compute_pg_port().unwrap()), + ) + }; + + let ingress_service_selector = Some( + vec![(COMPUTE_ID_LABEL_KEY.to_string(), compute_id.to_string())] + .into_iter() + .collect(), + ); + let ingress_service_type = match service_type { + K8sServiceType::ClusterIP => Some("ClusterIP".to_string()), + K8sServiceType::LoadBalancer => Some("LoadBalancer".to_string()), + }; + let annotations = if matches!(service_type, K8sServiceType::LoadBalancer) { + self.get_loadbalancer_annotations(service_name) + } else { + None + }; + let external_traffic_policy = if matches!(service_type, K8sServiceType::LoadBalancer) { + // External traffic policy is only relevant for LoadBalancer services. Normally we are okay with default setting (Local), which + // means ingress traffic from the public network is routed to the target Pod (as defined by the `Service` selector) directly by + // the cloud provider's load balancer. However, in Azure, this does not work if there are security policy configured for the node + // pool receiving the traffic (which is the case in our setup). Techncially in AWS these security policies are also in place, but + // the Network Load Balancers are smart enough to punch holes in the security group to allow the traffic to flow. In Azure, even + // traffic from the Network Load Balancers are always subject to the explicit configurations in the security policy. As a + // workaround, we set the external traffic policy to "Cluster" in Azure to have the traffic routed through the cluster's internal + // network. This results in some performance penalty, but since we are retiring this NLB-based ingress stack soon we are not going + // to optimize this. 
+ match self.cloud_provider { + Some(CloudProvider::AWS) => Some("Local".to_string()), + Some(CloudProvider::Azure) => Some("Cluster".to_string()), + _ => Some("Local".to_string()), + } + } else { + None + }; + + let ingress_service = Service { + metadata: meta::v1::ObjectMeta { + name: Some(service_name.to_string()), + annotations: annotations.clone(), + ..Default::default() + }, + spec: Some(ServiceSpec { + selector: ingress_service_selector.clone(), + ports: Some(vec![ServicePort { + port: 5432, + target_port: Some(IntOrString::Int(pg_params_compute_pg_port as i32)), + name: Some("postgres".to_string()), + protocol: Some("TCP".to_string()), + ..Default::default() + }]), + type_: ingress_service_type.clone(), + external_traffic_policy: external_traffic_policy.clone(), + ..Default::default() + }), + ..Default::default() + }; + + let api: Api = + Api::namespaced((*self.client).clone(), pg_params_compute_namespace.as_str()); + match api.get(service_name).await { + // If the service exists, patch it. + Ok(mut service) => { + // We only need to update select fields such as "selector" in the service spec. Everything else, including the resource version, + // should be left untouched so that the "replace" API can perform the atomic update. + service.metadata.annotations = annotations.clone(); + service.spec.iter_mut().for_each(|spec| { + spec.selector = ingress_service_selector.clone(); + spec.type_ = ingress_service_type.clone(); + spec.external_traffic_policy = external_traffic_policy.clone(); + }); + api.replace(service_name, &PostParams::default(), &service) + .await + } + // If the service does not exist, create it. + Err(kube::Error::Api(e)) if e.code == 404 => { + api.create(&PostParams::default(), &ingress_service).await + } + Err(e) => Err(e), + } + } + + async fn delete_ingress_service(&self, service_name: &str) -> kube::Result { + // Retrieve the namespace from the internal pg_params configuration. + let pg_params_compute_namespace = { + let pg_params = self.pg_params.read().expect("pg_params lock poisoned"); + pg_params.compute_namespace.clone() + }; + + // Create a namespaced API for Service resources. + let api: Api = + Api::namespaced((*self.client).clone(), pg_params_compute_namespace.as_str()); + + // Define the deletion parameters. + let dp = DeleteParams::default(); + + // Attempt to delete the service. + match api.delete(service_name, &dp).await { + // Service delete request accepted, but the service may not have been deleted yet. + Ok(_) => { + // Check if the service still exists. + let service_opt = api.get_opt(service_name).await?; + Ok(service_opt.is_none()) + } + // Service doesn't exist, so it's already deleted. + Err(kube::Error::Api(e)) if e.code == 404 => Ok(true), + // Any other error is returned. + Err(e) => Err(e), + } + } + + /// Watches the config file for changes. If the config file changes, the HadronCluster is deployed and PgParams are updated. + /// Note: Currently even if only PgParams are changed, deploy_storage for the HadronCluster is called. This is not a correctness issue. + /// This is because the hash of the entire config file is used to detect changes. + /// TODO: Consider refactoring the k8s manager struct to take in the additional Arc params here. 
+ pub async fn config_watcher(
+ &self,
+ persistence: Arc,
+ token_generator: Option,
+ startup_delay: Duration,
+ service: Arc,
+ ) {
+ let mut last_successfully_applied_hash = None;
+
+ // Wait for the startup to be fully completed (e.g., startup reconciliations are scheduled).
+ service.startup_complete.clone().wait().await;
+ tracing::info!("Storage controller startup completed, waiting for all spawned reconciliation tasks to complete.");
+
+ const MAX_WAIT_SECONDS: i32 = 600;
+ let mut waits = 0;
+ let mut active_tasks: usize;
+ loop {
+ active_tasks = service.get_active_reconcile_tasks();
+ if active_tasks == 0 || waits >= MAX_WAIT_SECONDS {
+ break;
+ }
+ tokio::time::sleep(Duration::from_secs(1)).await;
+ waits += 1;
+ }
+
+ tracing::info!("Finished waiting for startup reconciliation tasks ({} tasks remaining). Starting config watcher.", active_tasks);
+
+ // Wait for the startup delay before starting the reconciliation loop. The startup delay is useful to avoid the HCC updating PS/SK/PG nodes
+ // (which causes restarts) when the HCC itself also just started and needs to contact these nodes to discover cluster state.
+ tokio::time::sleep(startup_delay).await;
+
+ // Instantiate the page server watchdog to asynchronously monitor all page server pods in the given namespace for PVC breakages.
+ let client = self.client.clone();
+ let namespace = self.namespace.clone();
+ tokio::task::spawn_blocking(move || {
+ Handle::current().block_on(async {
+ let result =
+ create_pageserver_pod_watcher(client, namespace, "page-server".to_string())
+ .await;
+
+ if let Err(e) = result {
+ tracing::error!("Error running page server pod watcher: {:?}", e);
+ }
+ })
+ });
+
+ loop {
+ tokio::time::sleep(Duration::from_secs(1)).await;
+
+ let result = self
+ .watch_config(
+ &mut last_successfully_applied_hash,
+ Arc::clone(&persistence),
+ token_generator.clone(),
+ service.clone(),
+ )
+ .await;
+
+ let outcome_label = if result.is_err() {
+ tracing::error!("Error watching and applying config: {:?}", result);
+ ReconcileOutcome::Error
+ } else {
+ ReconcileOutcome::Success
+ };
+
+ metrics::METRICS_REGISTRY
+ .metrics_group
+ .storage_controller_config_watcher_complete
+ .inc(ConfigWatcherCompleteLabelGroup {
+ status: outcome_label,
+ });
+ }
+ }
+
+ // Read the cluster config file and validate that it contains the required fields.
+ // Returns:
+ // - The HadronCluster object defining the storage cluster.
+ // - The PgParams object defining parameters used to launch Postgres compute nodes.
+ // - The hash of the contents of the config file.
+ // - Or an error if the config file is missing or otherwise invalid.
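+ //
+ // Illustrative shape of the accepted config file (placeholder values; the authoritative schema is
+ // the ConfigData / HadronCluster / PgParams types):
+ //   {
+ //     "hadron_cluster": { "hadron_cluster_spec": { ... } },
+ //     "pg_params": {
+ //       "compute_namespace": "...",
+ //       "compute_image": "...",
+ //       "prometheus_exporter_image": "..."
+ //     },
+ //     "page_server_billing_metrics_config": { ... }
+ //   }
+ // "hadron_cluster" and "pg_params" are required; "page_server_billing_metrics_config" is optional.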
+ pub async fn read_and_validate_cluster_config( + config_file_path: &str, + ) -> anyhow::Result<( + HadronCluster, + PgParams, + String, + Option, + )> { + let file = File::open(config_file_path) + .context(format!("Failed to open config file {config_file_path}"))?; + let reader = BufReader::new(file); + let config: ConfigData = + serde_json::from_reader(reader).context("Failed to parse config file")?; + let file_hash = hash_file_contents(config_file_path) + .await + .context("Failed to hash file contents")?; + + let hadron_cluster = config + .hadron_cluster + .ok_or(anyhow::anyhow!("hadron_cluster is required in config.json"))?; + let pg_params = config + .pg_params + .ok_or(anyhow::anyhow!("pg_params is required in config.json"))?; + let billing_metrics_conf = config.page_server_billing_metrics_config; + + Ok((hadron_cluster, pg_params, file_hash, billing_metrics_conf)) + } + + /// Update the gauge metrics to publish the desired number of pageservers and safekeepers managed by us. + fn publish_desired_ps_sk_counts(&self, cluster: &HadronCluster) { + let num_pageservers = cluster + .hadron_cluster_spec + .as_ref() + .and_then(|spec| spec.page_server_specs.as_ref()) + .map(|pools| pools.iter().map(|p| p.replicas.unwrap_or(0) as i64).sum()) + .unwrap_or(0); + metrics::METRICS_REGISTRY + .metrics_group + .storage_controller_num_pageservers_desired + .set(num_pageservers); + + let num_safekeepers = cluster + .hadron_cluster_spec + .as_ref() + .and_then(|spec| spec.safe_keeper_specs.as_ref()) + .map(|pools| pools.iter().map(|p| p.replicas.unwrap_or(0) as i64).sum()) + .unwrap_or(0); + metrics::METRICS_REGISTRY + .metrics_group + .storage_controller_num_safekeeper_desired + .set(num_safekeepers); + } + + async fn watch_config( + &self, + last_successfully_applied_hash: &mut Option, + persistence: Arc, + token_generator: Option, + service: Arc, + ) -> anyhow::Result<()> { + let cluster_config_file_path = config_manager::get_cluster_config_file_path(); + let (hadron_cluster, pg_params, hash, billing_metrics_conf) = + Self::read_and_validate_cluster_config(&cluster_config_file_path).await?; + + // Report desired pageserver and safekeeper counts. + self.publish_desired_ps_sk_counts(&hadron_cluster); + + // Check if the config file has changed since the last successful deployment. + if Some(hash.as_str()) == last_successfully_applied_hash.as_deref() { + return Ok(()); + } + + // Deploy storage + if let Err(e) = self + .deploy_storage(hadron_cluster, billing_metrics_conf, service) + .await + { + return Err(anyhow::anyhow!("Failed to deploy HadronCluster: {:?}", e)); + } else { + tracing::info!("Successfully deployed HadronCluster"); + } + + // Get the PG compute defaults from the config file and update the local in-memory defaults, + // as well all available managed PG compute deployments. + let compute_namespace = pg_params.compute_namespace.clone(); + *self.pg_params.write().expect("pg_params lock poisoned") = pg_params; + tracing::info!("Updated PG params in K8sManager"); + + // Iterate over all active PG endpoints and update their compute deployments. + let active_endpoints = match persistence.get_active_endpoint_infos().await { + Ok(endpoints) => endpoints, + Err(e) => { + return Err(anyhow::anyhow!("Failed to get active endpoints: {:?}", e)); + } + }; + + tracing::info!( + "Successfully retrieved {} active endpoints", + active_endpoints.len() + ); + + for endpoint in active_endpoints { + // Re-generate the PG HCC auth token for the tenant endpoint / compute. 
+ let pg_hcc_auth_token: String = match token_generator
+ .as_ref()
+ .unwrap()
+ .generate_tenant_endpoint_scope_token(endpoint.endpoint_id)
+ {
+ Ok(token) => token,
+ Err(e) => {
+ return Err(anyhow::anyhow!(
+ "Failed to generate PG HCC auth token: {:?}",
+ e
+ ));
+ }
+ };
+
+ for compute in endpoint.computes {
+ // Deserialize the compute's config into the per-compute EndpointConfig.
+ let endpoint_config: EndpointConfig =
+ match serde_json::from_str(&compute.compute_config) {
+ Ok(config) => config,
+ Err(e) => {
+ return Err(anyhow::anyhow!(
+ "Failed to deserialize endpoint config: {:?}",
+ e
+ ));
+ }
+ };
+
+ let workspace_url = {
+ match parse_to_url(endpoint.workspace_url.clone()) {
+ Ok(url) => url,
+ Err(e) => {
+ tracing::error!("Failed to parse workspace URL: {:?}", e);
+ None
+ }
+ }
+ };
+
+ let pg_compute = PgCompute {
+ name: compute.compute_name.clone(),
+ compute_id: format!("{}/{}", endpoint.endpoint_id, compute.compute_index),
+ workspace_id: endpoint.workspace_id.clone(),
+ workspace_url,
+ image_override: endpoint_config.image,
+ node_selector_override: endpoint_config.node_selector,
+ control_plane_token: pg_hcc_auth_token.clone(),
+ resources: endpoint_config
+ .resources
+ .unwrap_or(endpoint_default_resources()),
+ // T-shirt size is guaranteed to exist when creating the endpoint. The same value should exist in meta PG as well.
+ tshirt_size: endpoint_config.tshirt_size.unwrap_or_else(|| {
+ panic!("T-shirt size did not exist in the endpoint config from meta PG");
+ }),
+ exporter_image_override: endpoint_config.prometheus_exporter_image,
+ model: ComputeModel::PrivatePreview,
+ readiness_probe: None,
+ instance_id: None,
+ };
+
+ // Fetch the deployment and check if it has the skip-reconciliation annotation.
+ let deployments: Api =
+ Api::namespaced((*self.client).clone(), &compute_namespace);
+
+ let deployment = Self::k8s_get(deployments, &compute.compute_name).await;
+
+ // If the deployment exists and has the hadron.dev/managed annotation set to false, skip reconciling the deployment.
+ if let Ok(Some(deployment)) = deployment {
+ if let Some(annotations) = deployment.metadata.annotations {
+ if let Some(managed) = annotations.get("hadron.dev/managed") {
+ if managed == "false" {
+ tracing::info!("Skipping syncing deployment for compute {} as it is not managed", compute.compute_name);
+ continue;
+ }
+ }
+ }
+ }
+
+ match self.deploy_compute(pg_compute).await {
+ Ok(_) => {
+ tracing::info!("Deployed compute {} successfully", compute.compute_name)
+ }
+ Err(e) => {
+ return Err(anyhow::anyhow!("Failed to deploy compute: {:?}", e));
+ }
+ }
+ }
+ }
+
+ tracing::info!("Successfully updated PG compute deployments");
+
+ *last_successfully_applied_hash = Some(hash);
+ Ok(())
+ }
+
+ /// Deploys the HadronCluster to the Kubernetes cluster.
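+ /// Rough order of operations (see the body below): StorageBroker Deployment/Service first, then
+ /// SafeKeeper StatefulSets (via SKMaintenanceManager) and Service, then the image puller DaemonSets
+ /// that pre-pull pageserver images, and finally PageServer StatefulSets (via DrainAndFillManager)
+ /// and Service.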
+ async fn deploy_storage( + &self, + storage: HadronCluster, + billing_metrics_conf: Option, + service: Arc, + ) -> anyhow::Result<()> { + let client = Arc::clone(&self.client); + + // Get objects + let storage_broker_objs: StorageBrokerObjs = self.get_storage_broker_objs(&storage)?; + tracing::info!("Got StorageBroker objects"); + let safe_keeper_objs: SafeKeeperObjs = self.get_safe_keeper_objs(&storage)?; + tracing::info!("Got SafeKeeper objects"); + let page_server_objs: PageServerObjs = + self.get_page_server_objs(&storage, billing_metrics_conf)?; + tracing::info!("Got PageServer objects"); + + // Create objects + Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), &self.namespace), + "storage-broker", + storage_broker_objs.deployment, + ) + .await?; + tracing::info!("Deployed StorageBroker Deployment"); + Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), &self.namespace), + "storage-broker", + storage_broker_objs.service, + ) + .await?; + tracing::info!("Deployed StorageBroker Service"); + + let hadron_cluster_spec = storage + .hadron_cluster_spec + .as_ref() + .ok_or(anyhow::anyhow!("Expected HadronCluster spec"))?; + let safe_keeper_specs = hadron_cluster_spec + .safe_keeper_specs + .as_ref() + .ok_or(anyhow::anyhow!("Expected SafeKeeper spec"))?; + + for (i, spec) in safe_keeper_specs.iter().enumerate() { + let span = tracing::info_span!("sk_maintenance_manager"); + SKMaintenanceManager::new( + Api::namespaced((*client).clone(), &self.namespace), + service.clone(), + safe_keeper_objs.stateful_sets[i] + .metadata + .name + .clone() + .ok_or(anyhow::anyhow!("Expected name"))?, + safe_keeper_objs.stateful_sets[i].clone(), + spec.pool_id, + ) + .run() + .instrument(span) + .await?; + } + + tracing::info!("Deployed SafeKeeper Deployment"); + Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), &self.namespace), + "safe-keeper", + safe_keeper_objs.service, + ) + .await?; + tracing::info!("Deployed SafeKeeper Service"); + + // Handle PS updates + // 1. Create ot update the image puller DaemonSet, which will pre-download hadron images to pageserver nodes. + // 2. Wait (best-effort, with timeout) for all image puller DaemonSet Pods be updated. + // 3. Create or update the pageserver StatefulSet and Service objects. + let mut image_puller_names_and_deadlines: Vec<(String, Instant)> = Vec::new(); + for ImagePullerDaemonsetInfo { + daemonset: image_puller, + image_prepull_timeout: timeout, + } in page_server_objs.image_puller_daemonsets + { + let image_pull_job_name = &image_puller + .metadata + .name + .clone() + .ok_or(anyhow::anyhow!("Expected name"))?; + if let Err(e) = Self::k8s_create_or_replace( + Api::namespaced((*client).clone(), &self.namespace), + image_pull_job_name, + image_puller, + ) + .await + { + tracing::warn!( + "Failed to create or update image puller Daemonset {}: {:?}", + image_pull_job_name, + e + ); + // Skip instead of failing the whole call. Image puller is just an optimization, we still want to deploy + // the PS/SK/SB if there are issues with the image puller. + continue; + } + // We successfully created/updated the DaemonSet. Add it to the list of DaemonSets we need to wait for. + // Note that the wait deadline is calculated from each DaemonSet's creation/update time + timeout. 
+ image_puller_names_and_deadlines.push((
+ image_pull_job_name.to_owned(),
+ Instant::now() + timeout.unwrap_or(IMAGE_PREPULL_DEFAULT_TIMEOUT),
+ ));
+ }
+
+ for (ds_name, deadline) in image_puller_names_and_deadlines {
+ match Self::k8s_wait_for_daemonset(
+ Api::namespaced((*client).clone(), &self.namespace),
+ &ds_name,
+ deadline,
+ )
+ .await
+ {
+ Ok(_) => {
+ tracing::info!(
+ "Image puller Daemonset {} preloaded images successfully",
+ &ds_name
+ )
+ }
+ Err(e) => {
+ // The image puller DaemonSet is a best-effort performance optimization and is not strictly
+ // required for the system to function. If it fails for any reason, just log a warning and
+ // proceed with updating the pageserver StatefulSets.
+ tracing::warn!(
+ "Image puller Daemonset {} did not preload images successfully on all nodes, proceeding. Error: {:?}",
+ &ds_name,
+ e
+ )
+ }
+ }
+ }
+
+ if self.cloud_provider.is_some() {
+ // During a release we may first update the config map and then restart the storage controller. Add a sleep here
+ // to prevent the old storage controller from starting the drain-and-fill process, which is currently not resumable.
+ tracing::info!("Waiting for 60 seconds before deploying page servers");
+ sleep(Duration::from_secs(60)).await;
+ }
+
+ for stateful_set in page_server_objs.stateful_sets {
+ let span = tracing::info_span!("drain_and_fill_manager");
+ DrainAndFillManager::new(
+ Api::namespaced((*client).clone(), &self.namespace),
+ Api::namespaced((*client).clone(), &self.namespace),
+ service.clone(),
+ stateful_set
+ .metadata
+ .name
+ .clone()
+ .ok_or(anyhow::anyhow!("Expected name"))?,
+ stateful_set,
+ )
+ .run()
+ .instrument(span)
+ .await?;
+ }
+ tracing::info!("Deployed PageServer Deployment");
+ Self::k8s_create_or_replace(
+ Api::namespaced((*client).clone(), &self.namespace),
+ "page-server",
+ page_server_objs.service,
+ )
+ .await?;
+ tracing::info!("Deployed PageServer Service");
+
+ kube::Result::Ok(())
+ }
+
+ fn extract_cpu_memory_resources(
+ &self,
+ resources: ResourceRequirements,
+ ) -> anyhow::Result {
+ // Note that we only extract resource requests, not limits. Hadron storage components (PS/SK) should not have CPU or memory limits.
+ // Other containers running on the node should be evicted/killed first if the node is oversubscribed.
+ let requests = resources
+ .requests
+ .clone()
+ .ok_or(anyhow::anyhow!("Expected resource requests"))?;
+ Ok(ResourceRequirements {
+ requests: Some(BTreeMap::from([
+ ("cpu".to_string(), requests["cpu"].clone()),
+ ("memory".to_string(), requests["memory"].clone()),
+ ])),
+ ..Default::default()
+ })
+ }
+
+ /// Calculate the node affinity setting to use for a PS/SK/SB component.
+ /// - `node_group_requirement`: Node selector requirement specifying which node group (usually identified by the "brickstore-pool-types" node label)
+ /// we should use to deploy the component.
+ /// - `availability_zone_suffix`: The availability zone suffix to use for the component. This is used to schedule different pools to different
+ /// AZs.
+ fn compute_node_affinity(
+ &self,
+ node_group_requirement: Option<&NodeSelectorRequirement>,
+ availability_zone_suffix: Option<&String>,
+ ) -> Option {
+ // There are two node selector requirements we potentially need:
+ // 1. The node group node selector requirement, which specifies which node label value we require for the "brickstore-pool-types" label key. This is
+ // needed if we want to schedule a component to a specific type of VM instances/node groups.
+ // 2. 
The availability zone node selector requirement, which specifies which node label value we require for the "topology.kubernetes.io/zone" label key. + // This is needed if we want to schedule a component to a specific availability zone. + let mut requirements: Vec = + node_group_requirement.into_iter().cloned().collect(); + + let az_requirement = availability_zone_suffix.map(|az_suffix| NodeSelectorRequirement { + key: "topology.kubernetes.io/zone".to_string(), + operator: "In".to_string(), + values: Some(vec![self.region.clone() + az_suffix]), + }); + + requirements.extend(az_requirement); + + if requirements.is_empty() { + None + } else { + // From official documentation: https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node/#node-affinity + // There are 2 ways to specify affinity rules: NodeSelectorTerms and MatchExpressions. They work as follows: + // + // - nodeSelectorTerms: + // - matchExpressions: + // - cond1 + // - cond2 + // - nodeSelectorTerms: + // - matchExpressions: + // - cond3 + // + // A node is considered a match if (cond1 AND cond2) OR (cond3). + // + // In other words, the nodeSelectorTerms are OR-ed together, and the matchExpressions within a nodeSelectorTerm are AND-ed together. + // In our use case we want the node selector requirements to be AND-ed together, so we put them under a MatchExpressions within a + // single NodeSelectorTerms. + Some(Affinity { + node_affinity: Some(NodeAffinity { + required_during_scheduling_ignored_during_execution: Some(NodeSelector { + node_selector_terms: vec![NodeSelectorTerm { + match_expressions: Some(requirements), + ..Default::default() + }], + }), + ..Default::default() + }), + ..Default::default() + }) + } + } + + fn node_selector_requirement_to_tolerations( + &self, + node_selector: Option<&NodeSelectorRequirement>, + ) -> Option> { + node_selector.and_then(|node_selector| { + node_selector.values.as_ref().map(|values| { + values + .iter() + .map(|label_value| Toleration { + key: Some("databricks.com/node-type".to_string()), + operator: Some("Equal".to_string()), + value: Some(label_value.clone()), + effect: match self.cloud_provider { + Some(CloudProvider::AWS) => Some("NoSchedule".to_string()), + // AKS node pools support PreferNoSchedule but not NoSchedule. + Some(CloudProvider::Azure) => Some("PreferNoSchedule".to_string()), + _ => Some("NoSchedule".to_string()), + }, + ..Default::default() + }) + .collect() + }) + }) + } + + /// Generate an image puller DaemonSet manifest. An image puller DaemonSet downloads the specified image to + /// the specified nodes by running a dummy container (runs a simple `sleep infinity` command). + /// - `name`: The name of the ImagePullJob. + /// - `image`: The image ref of the image to download. + /// - `image_pull_secrets`: Any image pull secrets to use when downloading the image. + /// - `node_selector_requirement`: Node selector requirement specifying which nodes to download the image to. + /// - `availability_zone_suffix`: Specify which available zone to download the image to. + /// The `node_selector_requirement` and `availability_zone_suffix` parameters together determine the nodes + /// to download the image to. If none are specified, the image will be downloaded to all nodes. + /// - `image_pull_parallelism`: The max number of nodes that can pre-download the image in parallel. 
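+ /// For example (illustrative), `image_pull_parallelism = Some(10)` becomes the DaemonSet rolling-update
+ /// setting `maxUnavailable: 10` below, so at most ten nodes pre-pull the image concurrently.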
+ fn generate_image_puller_daemonset( + &self, + name: &str, + image: &Option, + image_pull_secrets: &Option>, + node_selector_requirement: &Option, + availability_zone_suffix: &Option, + image_pull_parallelism: &Option, + ) -> anyhow::Result { + let image_puller_ds = DaemonSet { + metadata: meta::v1::ObjectMeta { + name: Some(name.to_string()), + namespace: Some(self.namespace.clone()), + ..Default::default() + }, + spec: Some(DaemonSetSpec { + selector: meta::v1::LabelSelector { + match_labels: Some( + vec![("app".to_string(), name.to_string())] + .into_iter() + .collect(), + ), + ..Default::default() + }, + update_strategy: Some(DaemonSetUpdateStrategy { + type_: Some("RollingUpdate".to_string()), + rolling_update: Some(RollingUpdateDaemonSet { + // We use max_unavailable to control the parallelism of the image pre-pull operation, as + // Kubernetes will allow up to max_unavailable pods to start updating at the same time. + max_unavailable: image_pull_parallelism.map(IntOrString::Int), + ..Default::default() + }), + }), + template: PodTemplateSpec { + metadata: Some(meta::v1::ObjectMeta { + labels: Some( + vec![("app".to_string(), name.to_string())] + .into_iter() + .collect(), + ), + ..Default::default() + }), + spec: Some(PodSpec { + image_pull_secrets: image_pull_secrets.clone(), + containers: vec![Container { + name: "sleep".to_string(), + image: image.clone(), + image_pull_policy: Some("IfNotPresent".to_string()), + command: Some(vec!["/bin/sleep".to_string(), "infinity".to_string()]), + resources: Some(ResourceRequirements { + // Set tiny requests/limits as this container doesn't really do anything. + requests: Some( + vec![ + ("cpu".to_string(), Quantity("10m".to_string())), + ("memory".to_string(), Quantity("10Mi".to_string())), + ] + .into_iter() + .collect(), + ), + limits: Some( + vec![ + ("cpu".to_string(), Quantity("10m".to_string())), + ("memory".to_string(), Quantity("20Mi".to_string())), + ] + .into_iter() + .collect(), + ), + ..Default::default() + }), + ..Default::default() + }], + affinity: self.compute_node_affinity( + node_selector_requirement.as_ref(), + availability_zone_suffix.as_ref(), + ), + // The image puller pod can be terminated immediately because as all it does is `sleep`. + termination_grace_period_seconds: Some(0), + tolerations: self.node_selector_requirement_to_tolerations( + node_selector_requirement.as_ref(), + ), + priority_class_name: Some("databricks-daemonset".to_string()), + ..Default::default() + }), + }, + ..Default::default() + }), + ..Default::default() + }; + + Ok(image_puller_ds) + } + + /// Returns the Deployment and Service objects for the StorageBroker. 
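+ /// The broker is deployed as a single-replica Deployment exposed through a ClusterIP Service on
+ /// port 50051 (see the manifests constructed below).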
+ fn get_storage_broker_objs( + &self, + storage: &HadronCluster, + ) -> anyhow::Result { + let hadron_cluster_spec = storage + .hadron_cluster_spec + .as_ref() + .ok_or(anyhow::anyhow!("Expected HadronCluster spec"))?; + let storage_broker_spec = hadron_cluster_spec + .storage_broker_spec + .as_ref() + .ok_or(anyhow::anyhow!("Expected StorageBroker spec"))?; + let launch_command = r#"# + /usr/local/bin/storage_broker --listen-addr=0.0.0.0:50051 \ + --timeline-chan-size=1024 \ + --all-keys-chan-size=524288 + "# + .to_string(); + + let deployment = Deployment { + metadata: meta::v1::ObjectMeta { + name: Some("storage-broker".to_string()), + namespace: Some(self.namespace.clone()), + ..Default::default() + }, + spec: Some(DeploymentSpec { + replicas: Some(1), + selector: meta::v1::LabelSelector { + match_labels: Some( + vec![("app".to_string(), "storage-broker".to_string())] + .into_iter() + .collect(), + ), + ..Default::default() + }, + template: PodTemplateSpec { + metadata: get_pod_metadata("storage-broker".to_string(), 50051), + spec: Some(PodSpec { + affinity: self.compute_node_affinity( + storage_broker_spec.node_selector.as_ref(), + // StorageBroker is a single Pod, so we allow it to be in any AZ. + None, + ), + tolerations: self.node_selector_requirement_to_tolerations( + storage_broker_spec.node_selector.as_ref(), + ), + image_pull_secrets: storage_broker_spec.image_pull_secrets.clone(), + containers: vec![Container { + name: "storage-broker".to_string(), + image: storage_broker_spec.image.clone(), + image_pull_policy: storage_broker_spec.image_pull_policy.clone(), + ports: get_container_ports(vec![50051]), + resources: storage_broker_spec.resources.clone(), + command: Some(vec!["/bin/sh".to_string(), "-c".to_string()]), + args: Some(vec![launch_command]), + ..Default::default() + }], + ..Default::default() + }), + }, + ..Default::default() + }), + ..Default::default() + }; + + let service = Service { + metadata: meta::v1::ObjectMeta { + name: Some("storage-broker".to_string()), + namespace: Some(self.namespace.clone()), + ..Default::default() + }, + spec: Some(ServiceSpec { + selector: Some(BTreeMap::from([( + "app".to_string(), + "storage-broker".to_string(), + )])), + ports: Some(vec![ServicePort { + port: 50051, + target_port: Some(IntOrString::Int(50051)), + ..Default::default() + }]), + type_: Some("ClusterIP".to_string()), + ..Default::default() + }), + ..Default::default() + }; + + Ok(StorageBrokerObjs { + deployment, + service, + }) + } + + pub fn get_remote_storage_args( + node_kind: NodeKind, + object_storage_config: &HadronObjectStorageConfig, + ) -> String { + let test_endpoint_opt = object_storage_config + .test_params + .as_ref() + .and_then(|test_params| test_params.endpoint.as_ref()) + .map(|endpoint| format!("endpoint='{endpoint}', ")) + .unwrap_or_default(); + + let mut arg_string = "".to_string(); + // This key is different between AWS and Azure, which is very subtle. 
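+ // Illustrative outputs of this function (placeholder values; the prefix is 'pageserver/' or 'safekeeper/'
+ // depending on node_kind, and empty when a test endpoint is configured):
+ //   AWS:   {bucket_name='<bucket>', bucket_region='<region>', prefix_in_bucket='pageserver/'}
+ //   Azure: {storage_account='<account>', container_name='<container>', container_region='<region>', prefix_in_container='pageserver/'}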
+ let mut prefix_in_bucket_arg = "".to_string(); + + if object_storage_config.is_aws() { + arg_string = format!( + "bucket_name='{}', bucket_region='{}'", + object_storage_config + .bucket_name + .clone() + .unwrap_or_default(), + object_storage_config + .bucket_region + .clone() + .unwrap_or_default() + ); + prefix_in_bucket_arg = "prefix_in_bucket".to_string(); + } else if object_storage_config.is_azure() { + arg_string = format!( + "storage_account='{}', container_name='{}', container_region='{}'", + object_storage_config + .storage_account_resource_id + .clone() + .unwrap_or_default() + .split('/') + .last() + .map(|s| s.to_string()) + .unwrap_or_default(), + object_storage_config + .storage_container_name + .clone() + .unwrap_or_default(), + object_storage_config + .storage_container_region + .clone() + .unwrap_or_default(), + ); + prefix_in_bucket_arg = "prefix_in_container".to_string(); + } + + let mut prefix_in_bucket = match node_kind { + NodeKind::Pageserver => "pageserver/", + NodeKind::Safekeeper => "safekeeper/", + }; + + // Tests use empty prefix. + if !test_endpoint_opt.is_empty() { + prefix_in_bucket = ""; + } + + format!( + "{{{}{}, {}='{}'}}", + test_endpoint_opt, arg_string, prefix_in_bucket_arg, prefix_in_bucket + ) + } + + fn get_remote_storage_startup_args( + &self, + object_storage_config: &HadronObjectStorageConfig, + ) -> String { + let test_endpoint_opt = object_storage_config + .test_params + .as_ref() + .and_then(|test_params| test_params.endpoint.as_ref()) + .map(|endpoint| format!("endpoint='{endpoint}', ")) + .unwrap_or_default(); + + let mut arg_string = "".to_string(); + // This key is different between AWS and Azure, which is very subtle. + let mut prefix_in_bucket_arg = "".to_string(); + + if object_storage_config.is_aws() { + arg_string = "bucket_name='$S3_BUCKET_URI', bucket_region='$S3_REGION'".to_string(); + prefix_in_bucket_arg = "prefix_in_bucket".to_string(); + } else if object_storage_config.is_azure() { + arg_string = "storage_account='$AZURE_STORAGE_ACCOUNT_NAME', container_name='$AZURE_STORAGE_CONTAINER_NAME', container_region='$AZURE_STORAGE_CONTAINER_REGION'".to_string(); + prefix_in_bucket_arg = "prefix_in_container".to_string(); + } + + format!( + "{{{}{}, {}='$PREFIX_IN_BUCKET'}}", + test_endpoint_opt, arg_string, prefix_in_bucket_arg + ) + } + + /// Returns the StatefulSets and Service objects for SafeKeepers. + fn get_safe_keeper_objs(&self, storage: &HadronCluster) -> anyhow::Result { + let hadron_cluster_spec = storage + .hadron_cluster_spec + .as_ref() + .ok_or(anyhow::anyhow!("Expected HadronCluster spec"))?; + let safe_keeper_specs = hadron_cluster_spec + .safe_keeper_specs + .as_ref() + .ok_or(anyhow::anyhow!("Expected SafeKeeper spec"))?; + let object_storage_config = hadron_cluster_spec + .object_storage_config + .as_ref() + .ok_or(anyhow::anyhow!("Expected ObjectStorageConfig"))?; + + // We reserve 100GB disk space on the safekeeper in case it runs of disk space and needs to recover. 
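+ // (Implemented below by fallocate-ing a 100G placeholder file under /data/.neon; the file can be
+ // deleted to free up space in an emergency. This is only done when running on a real cloud provider.)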
+ let reserved_file_cmd = if self.cloud_provider.is_some() { + "mkdir -p /data/.neon && fallocate -l 100G /data/.neon/reserved_file" + } else { + "" + }; + + let wal_reader_fanout = if is_dev_or_staging() { + "--wal-reader-fanout --max-delta-for-fanout=2147483648" + } else { + "" + }; + let pull_timeline_on_startup = if is_dev_or_staging() { + "--enable-pull-timeline-on-startup" + } else { + "" + }; + + let mut stateful_sets: Vec = Vec::new(); + for safe_keeper_spec in safe_keeper_specs { + let transformed_pool_id = transform_pool_id(safe_keeper_spec.pool_id); + let remote_storage_opt = self.get_remote_storage_startup_args(object_storage_config); + let token_verification_key_mount_path = + brickstore_internal_token_verification_key_mount_path(); + // PS prefers streaming WALs from SK replica in the same AZ. + let availability_zone = if safe_keeper_spec.availability_zone_suffix.is_some() { + format!( + "--availability-zone='az-{}'", + safe_keeper_spec.availability_zone_suffix.clone().unwrap() + ) + } else { + "".to_string() + }; + // TODO: Fully parameterize the port number constants + let launch_command = format!( + r#"#!/bin/sh +# Extract the ordinal number from the hostname +ordinal=$(hostname | rev | cut -d- -f1 | rev) + +# Set the SAFEKEEPER_ID and PREFIX_IN_BUCKET based on the ordinal number +SAFEKEEPER_ID=$((ordinal + {transformed_pool_id})) +# Do NOT specify a leading / in the prefix (this breaks Azure). +PREFIX_IN_BUCKET="safekeeper/" + +SAFEKEEPER_FQDN="${{HOSTNAME}}.safe-keeper.${{MY_NAMESPACE}}.svc.cluster.local" + +# Make sure SIGTERM received by the shell is propagated to the safekeeper process. +shutdown() {{ + echo "Shutting down safekeeper running at pid $pid" + kill -TERM $pid + wait $pid +}} + +trap shutdown TERM + +{reserved_file_cmd} + +# Start Safekeeper. Notes on ports: +# Port 5454 accpets PG wire protocol connections from compute nodes and only accepts tenant-scoped tokens. This is the only port allowed from the untrusted worker subnet. +# Port 5455 accepts PG wire protocol connections from PS (and maybe other trusted components) and is advertised via the storage broker. +/usr/local/bin/safekeeper --listen-pg='0.0.0.0:5455' \ + --listen-pg-tenant-only='0.0.0.0:5454' \ + --advertise-pg-tenant-only="$SAFEKEEPER_FQDN:5454" \ + --hcc-base-url=$STORAGE_CONTROLLER_URL \ + --advertise-pg="$SAFEKEEPER_FQDN:5455" \ + --listen-http='0.0.0.0:7676' \ + --token-auth-type='HadronJWT' \ + --pg-tenant-only-auth-public-key-path='{token_verification_key_mount_path}' \ + --id=$SAFEKEEPER_ID \ + --broker-endpoint="$BROKER_ENDPOINT" \ + --max-offloader-lag=1073741824 \ + --max-reelect-offloader-lag-bytes=4294967296 \ + --wal-backup-parallel-jobs=64 \ + -D /data/.neon \ + {wal_reader_fanout} \ + {pull_timeline_on_startup} \ + {availability_zone} \ + --remote-storage="{remote_storage_opt}" & +pid=$! + +wait +"# + ); + + // Get Brickstore secrets we need to mount to safe keepers. Currently it's just the token verification keys, + // and the azure service principal secret (when appropriate). 
+ let mut k8s_mounts = vec![brickstore_internal_token_verification_key_secret_mount()]; + + if object_storage_config.is_azure() { + k8s_mounts.push(azure_storage_account_service_principal_secret_mount()) + } + + let K8sSecretVolumesAndMounts { + volumes: secret_volumes, + volume_mounts: secret_volume_mounts, + } = self.get_hadron_volumes_and_mounts(k8s_mounts); + + let stateful_set = StatefulSet { + metadata: meta::v1::ObjectMeta { + name: Some(format!( + "{}-{}", + "safe-keeper", + safe_keeper_spec.pool_id.unwrap_or(0) + )), + namespace: Some(self.namespace.clone()), + annotations: Some( + vec![ + ( + SKMaintenanceManager::SK_LOW_DOWNTIME_MAINTENANCE_KEY.to_string(), + safe_keeper_spec + .enable_low_downtime_maintenance + .unwrap_or( + SKMaintenanceManager::SK_LOW_DOWNTIME_MAINTENANCE_DEFAULT, + ) + .to_string(), + ), + ( + SKMaintenanceManager::SK_LDTM_SK_STATUS_CHECK_KEY.to_string(), + safe_keeper_spec + .enable_ldtm_sk_status_check + .unwrap_or( + SKMaintenanceManager::SK_LDTM_SK_STATUS_CHECK_DEFAULT, + ) + .to_string(), + ), + ] + .into_iter() + .collect(), + ), + ..Default::default() + }, + spec: StatefulSetSpec { + replicas: safe_keeper_spec.replicas, + selector: meta::v1::LabelSelector { + match_labels: Some( + vec![("app".to_string(), "safe-keeper".to_string())] + .into_iter() + .collect(), + ), + ..Default::default() + }, + service_name: Some("safe-keeper".to_string()), + volume_claim_templates: get_volume_claim_template( + safe_keeper_spec.resources.clone(), + safe_keeper_spec.storage_class_name.clone(), + )?, + template: PodTemplateSpec { + metadata: get_pod_metadata("safe-keeper".to_string(), 7676), + spec: Some(PodSpec { + affinity: self.compute_node_affinity( + safe_keeper_spec.node_selector.as_ref(), + safe_keeper_spec.availability_zone_suffix.as_ref(), + ), + tolerations: self.node_selector_requirement_to_tolerations( + safe_keeper_spec.node_selector.as_ref(), + ), + image_pull_secrets: safe_keeper_spec.image_pull_secrets.clone(), + service_account_name: hadron_cluster_spec.service_account_name.clone(), + security_context: get_pod_security_context(), + volumes: Some(secret_volumes), + // Set the priority class to the very-high-priority "pg-compute", which should allow + // the safekeeper to preempt all other pods on the same nodes (including log daemon) + // if we run low on resources for whatever reason. 
+ priority_class_name: Some("pg-compute".to_string()), + containers: vec![Container { + name: "safe-keeper".to_string(), + image: safe_keeper_spec.image.clone(), + image_pull_policy: safe_keeper_spec.image_pull_policy.clone(), + ports: get_container_ports(vec![5454, 7676]), + resources: Some( + self.extract_cpu_memory_resources( + safe_keeper_spec + .resources + .clone() + .ok_or(anyhow::anyhow!("Expected resources"))?, + )?, + ), + volume_mounts: Some(itertools::concat(vec![ + get_local_data_volume_mounts(), + secret_volume_mounts, + ])), + command: Some(vec!["/bin/sh".to_string(), "-c".to_string()]), + args: Some(vec![launch_command]), + env: { + let additional_env_vars = vec![ + EnvVar { + name: "MY_NAMESPACE".to_string(), + value_from: Some(EnvVarSource { + field_ref: Some(ObjectFieldSelector { + field_path: "metadata.namespace".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }, + EnvVar { + name: "STORAGE_CONTROLLER_URL".to_string(), + value: Some(format!( + "http://{}:{}", + self.hcc_dns_name, self.hcc_listening_port + )), + ..Default::default() + }, + EnvVar { + name: HADRON_NODE_IP_ADDRESS.to_string(), + value_from: Some(EnvVarSource { + field_ref: Some(ObjectFieldSelector { + field_path: "status.podIP".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }, + ]; + + get_env_vars(object_storage_config, additional_env_vars) + }, + ..Default::default() + }], + ..Default::default() + }), + }, + ..Default::default() + }, + status: None, + }; + stateful_sets.push(stateful_set); + } + + let service = get_service( + "safe-keeper".to_string(), + self.namespace.clone(), + 5454, + 7676, + ); + + Ok(SafeKeeperObjs { + stateful_sets, + service, + }) + } + + /// Returns the StatefulSets and Service objects for PageServers. + fn get_page_server_objs( + &self, + storage: &HadronCluster, + billing_metrics_conf: Option, + ) -> anyhow::Result { + let hadron_cluster_spec = storage + .hadron_cluster_spec + .as_ref() + .ok_or(anyhow::anyhow!("Expected HadronCluster spec"))?; + let page_server_specs = hadron_cluster_spec + .page_server_specs + .as_ref() + .ok_or(anyhow::anyhow!("Expected PageServer spec"))?; + let object_storage_config = hadron_cluster_spec + .object_storage_config + .as_ref() + .ok_or(anyhow::anyhow!("Expected ObjectStorageConfig"))?; + + let billing_metrics_configs = + billing_metrics_conf.map_or("".to_string(), move |conf| conf.to_toml()); + + let mut image_puller_daemonsets: Vec = Vec::new(); + let mut stateful_sets: Vec = Vec::new(); + for page_server_spec in page_server_specs { + // First, generate an "image puller" DaemonSet for this pageserver pool. + // The image puller DaemonSet's purpose is to pre-download the new pageserver image to the nodes running this pageserver + // pool before we start shutting down pageservers for the rolling upgrade. This speeds up rolling upgrades and reduces + // downtime between pageserver shutdown and restart significantly. 
+ let image_pull_ds_name = + format!("image-puller-ps-{}", page_server_spec.pool_id.unwrap_or(0)); + let image_puller_daemonset = self.generate_image_puller_daemonset( + &image_pull_ds_name, + &page_server_spec.image, + &page_server_spec.image_pull_secrets, + &page_server_spec.node_selector, + &page_server_spec.availability_zone_suffix, + &page_server_spec.image_prepull_parallelism, + ); + match image_puller_daemonset { + Ok(d) => { + image_puller_daemonsets.push(ImagePullerDaemonsetInfo { + daemonset: d, + image_prepull_timeout: page_server_spec + .image_prepull_timeout_seconds + .map(Duration::from_secs), + }); + } + Err(e) => { + // We may not be able to generate an image puller DaemonSet manifest if the cluster config is missing crucial + // fields (e.g., doesn't specify an image). When this happens we just skip it with a warning but don't fail + // anything else. The ImagePullJob is just a performance optimization and it is not strictly required for the + // system to function. Besides, any error here due to malformed cluster spec will likely result in a more permant + // error down below where we construct the main workload, the pageserver StatefulSet manifest, so we will bail + // out there with a more informative error message if it comes to that. + tracing::warn!( + "Failed to generate image puller daemonset for page server pool {}, skipping: {:?}", + page_server_spec.pool_id.unwrap_or(0), + e + ); + } + }; + + // Now generate the pageserver StatefulSet spec. + // Hacky way to convert from Gi to bytes + let storage_quantity = page_server_spec + .resources + .clone() + .ok_or(anyhow::anyhow!("Expected resources"))? + .limits + .clone() + .ok_or(anyhow::anyhow!("Expected resource limits"))?["storage"] + .clone(); + let mut storage_size_bytes_str: String = + serde_json::to_string(&storage_quantity)?.replace("Gi", ""); + storage_size_bytes_str.pop(); + storage_size_bytes_str.remove(0); + let storage_size_bytes = storage_size_bytes_str.parse::()? * 1073741824; + + let transformed_pool_id = transform_pool_id(page_server_spec.pool_id); + // Extract any additional pageserver configs that we should append to pageserver.toml. + let additional_pageserver_configs = page_server_spec + .custom_pageserver_toml + .clone() + .unwrap_or_default(); + let remote_storage_opt = self.get_remote_storage_startup_args(object_storage_config); + let token_verification_key_mount_path = + brickstore_internal_token_verification_key_mount_path(); + let wal_receiver_protocol = if is_dev_or_staging() { + "wal_receiver_protocol = {type='interpreted', args={format='protobuf', compression={zstd={level=1} } } }" + } else { + "" + }; + let s3_fault_injection = if is_chaos_testing() { + "test_remote_failures = 10000 \n test_remote_failures_probability = 20" + } else { + "" + }; + + let image_layer_force_creation_period = if is_dev_or_staging() { + ", image_layer_force_creation_period='1d'" + } else { + "" + }; + // PS prefers streaming WALs from SK replica in the same AZ. + // It also notifies SC about its AZ so that SC can optimize its placements, + // e.g., placing primary and secondary in different AZs, + // co-locating the primary shard in the same AZ as the tenant. + let availability_zone = if page_server_spec.availability_zone_suffix.is_some() { + format!( + "availability_zone='az-{}'", + page_server_spec.availability_zone_suffix.clone().unwrap() + ) + } else { + "".to_string() + }; + + let launch_command = format!( + r#"#!/bin/sh +# Extract the ordinal index from the hostname (e.g. 
"page-server-0" -> 0) +ordinal=$(hostname | rev | cut -d- -f1 | rev) + +# Set the id and prefix_in_bucket dynamically +PAGESERVER_ID=$((ordinal + {transformed_pool_id})) +# Do NOT specify a leading / in the prefix (this breaks Azure). +PREFIX_IN_BUCKET="pageserver/" + +# Compute the in-cluster FQDN of the page server +PAGESERVER_FQDN="${{HOSTNAME}}.page-server.${{NAMESPACE}}.svc.cluster.local" + +# Write the page server identity file. +cat < /data/.neon/identity.toml +id=$PAGESERVER_ID +EOF + +# Create the page server metadata.json file so that it auto-registers with the storage controller. +cat < /data/.neon/metadata.json +{{ + "host": "$PAGESERVER_FQDN", + "port": 6400, + "http_host": "$PAGESERVER_FQDN", + "http_port": 9898, + "other": {{}} +}} +EOF + +# Write the pageserver.toml config file. +cat < /data/.neon/pageserver.toml +pg_distrib_dir='/usr/local/' +pg_auth_type='HadronJWT' +auth_validation_public_key_path='{token_verification_key_mount_path}' +broker_endpoint='$BROKER_ENDPOINT' +control_plane_api='$STORAGE_CONTROLLER_ENDPOINT' +listen_pg_addr='0.0.0.0:6400' +listen_http_addr='0.0.0.0:9898' +max_file_descriptors=10000 +remote_storage={remote_storage_opt} +{s3_fault_injection} +ephemeral_bytes_per_memory_kb=512 +disk_usage_based_eviction={{max_usage_pct=80, min_avail_bytes=$MIN_DISK_AVAIL_BYTES, period='1m'}} +tenant_config = {{checkpoint_distance=1_073_741_824, compaction_target_size=134_217_728 {image_layer_force_creation_period}}} +{availability_zone} +{billing_metrics_configs} +{wal_receiver_protocol} +{additional_pageserver_configs} +EOF + +# Make sure SIGTERM received by the shell is propagated to the pageserver process. +shutdown() {{ + echo "Shutting down pageserver running at pid $pid" + kill -TERM $pid + wait $pid +}} + +trap shutdown TERM + +# Start the pageserver binary. +/usr/local/bin/pageserver -D /data/.neon/ & +pid=$! + +wait +"# + ); + + // Get Brickstore secerts we need to mount to page servers. Currently it's just the token verification keys, + // and the azure service principal secret (when appropriate). + let mut k8s_mounts = vec![brickstore_internal_token_verification_key_secret_mount()]; + + if object_storage_config.is_azure() { + k8s_mounts.push(azure_storage_account_service_principal_secret_mount()) + } + + let K8sSecretVolumesAndMounts { + volumes: secret_volumes, + volume_mounts: secret_volume_mounts, + } = self.get_hadron_volumes_and_mounts(k8s_mounts); + + let stateful_set = StatefulSet { + metadata: meta::v1::ObjectMeta { + annotations: Some( + vec![ + // Set up PersistentPodState for page servers so that they are re-scheduled to the same nodes upon restarts, + // if possible. This helps performance because page servers use local SSDs for caches and it would be nice + // to not lose this cache due to k8s moving Pods around. + // Note that PersistentPodState is an OpenKruise AdvancedStatefulSet feature. 
+ // https://openkruise.io/docs/user-manuals/persistentpodstate/#annotation-auto-generate-persistentpodstate + ( + "kruise.io/auto-generate-persistent-pod-state".to_string(), + "true".to_string(), + ), + ( + "kruise.io/preferred-persistent-topology".to_string(), + "kubernetes.io/hostname".to_string(), + ), + ( + DrainAndFillManager::DRAIN_AND_FILL_KEY.to_string(), + page_server_spec + .use_drain_and_fill + .unwrap_or(DrainAndFillManager::DRAIN_AND_FILL_DEFAULT) + .to_string(), + ), + ] + .into_iter() + .collect(), + ), + name: Some(format!( + "{}-{}", + "page-server", + page_server_spec.pool_id.unwrap_or(0) + )), + namespace: Some(self.namespace.clone()), + ..Default::default() + }, + spec: StatefulSetSpec { + // NB: The pageservers persistent volumes are instance-bound, making + // them effectively ephemeral. We want the corresponding PVCs to be automatically + // cleaned up on deletion of the StatefulSet or scaling down since we don't + // have to worry about data loss. + persistent_volume_claim_retention_policy: Some( + StatefulSetPersistentVolumeClaimRetentionPolicy { + when_deleted: Some("Delete".to_string()), + when_scaled: Some("Delete".to_string()), + }, + ), + replicas: page_server_spec.replicas, + selector: meta::v1::LabelSelector { + match_labels: Some( + vec![("app".to_string(), "page-server".to_string())] + .into_iter() + .collect(), + ), + ..Default::default() + }, + service_name: Some("page-server".to_string()), + volume_claim_templates: get_volume_claim_template( + page_server_spec.resources.clone(), + page_server_spec.storage_class_name.clone(), + )?, + update_strategy: Some(StatefulSetUpdateStrategy { + r#type: Some("RollingUpdate".to_string()), + rolling_update: Some(StatefulSetUpdateStrategyRollingUpdate { + // TODO(william.huang): Evaluate whether this 30-second wait between pod restarts is warranted. + min_ready_seconds: Some(30), + max_unavailable: Some(IntOrString::Int(1)), + in_place_update_strategy: None, + partition: None, + // Use the "InPlaceIfPossible" pod update policy to avoid recreating the Pod when updating container images. + // Recreating the Pod results in Pod IP changes, which can be disruptive to compute nodes who uses a Cloud + // DNS mechanism to locate the pageservers. The compute nodes can experience downtime of O(10 sec) on every + // pageserver IP address change due to DNS record TTLs. Recreating the Pod also re-subjects the Pod to cluster + // admission control and CNI (network address assignment) delays, which could also lead to potential (and + // unnecessary) downtime for routine container image updates. + // + // Note that the Pod won't be updated in-place (and will be recreated) if any fields unsupported by the + // "InPlaceIfPossible" policy are changed. Notably, the launch `command` and `arg` fields of the containers + // are NOT supported in in-place updates. 
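+ // For example (illustrative): bumping only the container image of this StatefulSet restarts the
+ // container in place and keeps the Pod IP, while any change to the generated launch script passed
+ // via `args` forces a full Pod recreation despite this policy.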
+ pod_update_policy: Some("InPlaceIfPossible".to_string()), + paused: None, + unordered_update: None, + }), + }), + template: PodTemplateSpec { + metadata: get_pod_metadata("page-server".to_string(), 9898), + spec: Some(PodSpec { + affinity: self.compute_node_affinity( + page_server_spec.node_selector.as_ref(), + page_server_spec.availability_zone_suffix.as_ref(), + ), + tolerations: self.node_selector_requirement_to_tolerations( + page_server_spec.node_selector.as_ref(), + ), + image_pull_secrets: page_server_spec.image_pull_secrets.clone(), + service_account_name: hadron_cluster_spec.service_account_name.clone(), + security_context: get_pod_security_context(), + // Set the pririty class to the very-high-priority "pg-compute", which should allow + // the safekeeper to preempt all other pods on the same nodes (including log daemon) + // if we run low on resources for whatever reason. + priority_class_name: Some("pg-compute".to_string()), + // The "InPlaceUpdateReady" readiness gate is required to use the "InPlaceIfPossible" pod update policy. + // See https://openkruise.io/docs/v1.6/user-manuals/advancedstatefulset/#in-place-update for details. + readiness_gates: Some(vec![PodReadinessGate { + condition_type: "InPlaceUpdateReady".to_string(), + }]), + volumes: Some(secret_volumes), + containers: vec![Container { + name: "page-server".to_string(), + image: page_server_spec.image.clone(), + image_pull_policy: page_server_spec.image_pull_policy.clone(), + ports: get_container_ports(vec![6400, 9898]), + resources: Some( + self.extract_cpu_memory_resources( + page_server_spec + .resources + .clone() + .ok_or(anyhow::anyhow!("Expected resources"))?, + )?, + ), + volume_mounts: Some(itertools::concat(vec![ + get_local_data_volume_mounts(), + secret_volume_mounts, + ])), + command: Some(vec!["/bin/sh".to_string(), "-c".to_string()]), + args: Some(vec![launch_command]), + env: { + let additional_env_vars = vec![ + EnvVar { + name: "STORAGE_CONTROLLER_ENDPOINT".to_string(), + value: Some(format!( + "http://{}:{}/upcall/v1/", + self.hcc_dns_name, self.hcc_listening_port + )), + ..Default::default() + }, + EnvVar { + name: "MIN_DISK_AVAIL_BYTES".to_string(), + value: Some(format!("{}", storage_size_bytes / 5)), + ..Default::default() + }, + EnvVar { + name: "NAMESPACE".to_string(), + value_from: Some(EnvVarSource { + field_ref: Some(ObjectFieldSelector { + field_path: "metadata.namespace".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }, + EnvVar { + name: HADRON_NODE_IP_ADDRESS.to_string(), + value_from: Some(EnvVarSource { + field_ref: Some(ObjectFieldSelector { + field_path: "status.podIP".to_string(), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }, + ]; + + get_env_vars(object_storage_config, additional_env_vars) + }, + ..Default::default() + }], + ..Default::default() + }), + }, + ..Default::default() + }, + status: None, + }; + stateful_sets.push(stateful_set); + } + + let service = get_service( + "page-server".to_string(), + self.namespace.clone(), + 6400, + 9898, + ); + Ok(PageServerObjs { + image_puller_daemonsets, + stateful_sets, + service, + }) + } + + pub async fn get_http_urls_for_compute_services(&self, service_names: Vec) -> Vec { + let namespace = self + .pg_params + .read() + .expect("pg_params lock poisoned") + .compute_namespace + .clone(); + + service_names + .into_iter() + .map(|svc_name| { + Url::parse(&format!( + "http://{svc_name}.{namespace}.svc.cluster.local.:80" + )) + .unwrap() + }) + 
.collect() + } +} + +#[async_trait] +impl K8sManager for K8sManagerImpl { + fn get_client(&self) -> Arc { + self.client.clone() + } + + fn get_current_pg_params(&self) -> Result { + // Try to acquire the read lock + let read_lock = self + .pg_params + .read() + .map_err(|err| anyhow!("Failed to acquire read lock due to err: {:?}", err))?; + + Ok((*read_lock).clone()) + } + + fn set_pg_params(&self, params: PgParams) -> Result<(), anyhow::Error> { + // Try to acquire the write lock + let mut write_lock = self + .pg_params + .write() + .map_err(|err| anyhow!("Failed to acquire write lock due to err: {:?}", err))?; + + // If we were able to then override the params + *write_lock = params; + Ok(()) + } + + async fn deploy_compute(&self, pg_compute: PgCompute) -> kube::Result<()> { + self.deploy_compute(pg_compute).await + } + + async fn delete_compute( + &self, + pg_compute_name: &str, + model: ComputeModel, + ) -> kube::Result { + self.delete_compute(pg_compute_name, model).await + } + + async fn get_http_urls_for_compute_services(&self, service_names: Vec) -> Vec { + self.get_http_urls_for_compute_services(service_names).await + } + + async fn get_databricks_compute_settings( + &self, + workspace_url: Option, + ) -> DatabricksSettings { + self.get_databricks_compute_settings(workspace_url).await + } + + async fn get_instance_primary_ingress_service( + &self, + instance_id: Uuid, + ) -> kube::Result { + self.get_ingress_service(&Self::instance_primary_ingress_service_name(instance_id)) + .await + } + + async fn create_or_patch_instance_primary_ingress_service( + &self, + instance_id: Uuid, + compute_id: Uuid, + service_type: K8sServiceType, + ) -> kube::Result { + self.create_or_patch_ingress_service( + &Self::instance_primary_ingress_service_name(instance_id), + compute_id, + service_type, + ) + .await + } + + async fn create_or_patch_readable_secondary_ingress_service( + &self, + instance_id: Uuid, + ) -> kube::Result { + self.create_if_not_exists_readable_secondary_ingress( + Self::instance_read_only_ingress_service_name(instance_id), + instance_id, + ) + .await + } + + async fn delete_instance_primary_ingress_service( + &self, + instance_id: Uuid, + ) -> kube::Result { + self.delete_ingress_service(&Self::instance_primary_ingress_service_name(instance_id)) + .await + } +} + +#[cfg(test)] +mod tests { + use crate::hadron_k8s::{CloudProvider, BRICKSTORE_POOL_TYPES_LABEL_KEY}; + use camino_tempfile::Utf8TempDir; + use compute_api::spec::{DatabricksSettings, PgComputeTlsSettings}; + use http::{Request, Response}; + use k8s_openapi::api::core::v1::ConfigMapVolumeSource; + use k8s_openapi::api::core::v1::{ + NodeSelectorRequirement, SecretVolumeSource, Volume, VolumeMount, + }; + use kube::client::Body; + use kube::Client; + use reqwest::Url; + use std::fs::File; + use std::io::Write; + use std::sync::Arc; + + use crate::hadron_k8s::HadronObjectStorageConfig; + use crate::hadron_k8s::{ + endpoint_default_resources, select_node_group_by_tshirt_size, BrcDbletNodeGroup, K8sMount, + MountType, PgParams, + }; + use hcc_api::models::EndpointTShirtSize; + use tower_test::mock; + + use super::K8sManagerImpl; + + use super::ConfigData; + + // Test PgParams parsing behavior. + #[test] + fn test_pg_params_parsing() { + // Test that PgParams parsing fails when the input is empty or malformed. 
+ assert!(serde_json::from_str::("").is_err()); + assert!(serde_json::from_str::("{").is_err()); + + // Sanity check that PgParams now contains required fields and there is no "default" behavior for + // required fields. + assert!(serde_json::from_str::("{}").is_err()); + // Test that PgParams parsing fails when a required field is mis-spelled ("comptue_image" instead of "compute_image"). + assert!(serde_json::from_str::( + r#"{ + "compute_namespace": "hadron-compute", + "comptue_image": "hadron/compute-image:v1", + "prometheus_exporter_image": "hadron/prometheus-exporter-image:v1" + }"# + ) + .is_err()); + + // Tests that once we can parse out the required fields, the rest of the feilds are optional and + // fall back to expected default values. + let parsed_config: PgParams = serde_json::from_str( + r#"{ + "compute_namespace": "hadron-compute", + "compute_image": "hadron/compute-image:v1", + "prometheus_exporter_image": "hadron/prometheus-exporter-image:v1" + }"#, + ) + .unwrap(); + assert_eq!( + parsed_config.compute_namespace, + "hadron-compute".to_string() + ); + assert_eq!( + parsed_config.compute_image, + "hadron/compute-image:v1".to_string() + ); + assert_eq!( + parsed_config.prometheus_exporter_image, + "hadron/prometheus-exporter-image:v1".to_string() + ); + assert_eq!( + parsed_config.compute_image_pull_secret, + Some("harbor-image-pull-secret".to_string()) + ); + assert_eq!(parsed_config.compute_pg_port, Some(55432)); + assert_eq!(parsed_config.compute_http_port, Some(55433)); + assert_eq!( + parsed_config.compute_mounts, + Some(vec![ + K8sMount { + name: "brickstore-internal-token-public-keys".to_string(), + mount_type: MountType::Secret, + mount_path: "/databricks/secrets/brickstore-internal-token-public-keys" + .to_string(), + files: vec!["key1.pem".to_string(), "key2.pem".to_string()], + }, + K8sMount { + name: "brickstore-domain-certs".to_string(), + mount_type: MountType::Secret, + mount_path: "/databricks/secrets/brickstore-domain-certs".to_string(), + files: vec!["server.key".to_string(), "server.crt".to_string()], + }, + K8sMount { + name: "trusted-ca-certificates".to_string(), + mount_type: MountType::Secret, + mount_path: "/databricks/secrets/trusted-ca".to_string(), + files: vec!["data-plane-misc-root-ca-cert.pem".to_string()], + }, + K8sMount { + name: "pg-compute-config".to_string(), + mount_type: MountType::ConfigMap, + mount_path: "/databricks/pg_config".to_string(), + files: vec![ + "databricks_pg_hba.conf".to_string(), + "databricks_pg_ident.conf".to_string() + ], + } + ]) + ); + assert_eq!( + parsed_config.pg_compute_tls_settings, + Some(PgComputeTlsSettings { + key_file: "/databricks/secrets/brickstore-domain-certs/server.key".to_string(), + cert_file: "/databricks/secrets/brickstore-domain-certs/server.crt".to_string(), + ca_file: "/databricks/secrets/trusted-ca/data-plane-misc-root-ca-cert.pem" + .to_string(), + }) + ); + assert_eq!( + parsed_config.databricks_pg_hba, + Some("/databricks/pg_config/databricks_pg_hba.conf".to_string()) + ); + assert_eq!( + parsed_config.databricks_pg_ident, + Some("/databricks/pg_config/databricks_pg_ident.conf".to_string()) + ); + } + + // Test storage billing params serialization to TOML. 
+ #[test] + fn test_page_server_billing_metrics_config_to_toml() { + let mut billing_metrics_config = super::PageServerBillingMetricsConfig { + metric_collection_endpoint: Some("http://localhost:8080/metrics".to_string()), + metric_collection_interval: Some("5 min".to_string()), + synthetic_size_calculation_interval: Some("5 min".to_string()), + }; + + let mut toml_str = billing_metrics_config.to_toml(); + assert_eq!( + toml_str, + r#"metric_collection_endpoint = "http://localhost:8080/metrics" +metric_collection_interval = "5 min" +synthetic_size_calculation_interval = "5 min" +"# + ); + + billing_metrics_config.metric_collection_endpoint = None; + toml_str = billing_metrics_config.to_toml(); + + assert_eq!( + toml_str, + r#"metric_collection_interval = "5 min" +synthetic_size_calculation_interval = "5 min" +"# + ); + } + + // Test demonstrating the cluster-config ConfigMap parsing behavior. + #[test] + fn test_cluster_config_parsing() { + let parsed_config_data: ConfigData = serde_json::from_str( + r#"{ + "hadron_cluster": { + "hadron_cluster_spec": { + "storage_broker_spec": { + "image": "sb-image" + }, + "safe_keeper_specs": [ + { + "pool_id": 0, + "replicas": 1, + "image": "sk-image", + "storage_class_name": "sk-storage-class" + } + ], + "page_server_specs": [ + { + "pool_id": 0, + "replicas": 1, + "image": "ps-image", + "storage_class_name": "ps-storage-class" + + } + ] + } + }, + "pg_params": { + "compute_namespace": "test-ns", + "compute_image": "compute-image", + "prometheus_exporter_image": "prometheus-exporter-image", + "compute_image_pull_secret": "test-image-pull-secret", + "compute_mounts" : [ + { + "name": "secret", + "mount_path": "/databricks/secrets/secret", + "mount_type": "secret", + "files": ["file1", "file2"] + }, + { + "name": "another-secret", + "mount_path": "/databricks/secrets/another-secret", + "mount_type": "secret", + "files": ["another_file1", "another_file2"] + }, + { + "name": "config-map", + "mount_type": "config_map", + "mount_path": "/databricks/pg_config", + "files": ["config1", "config2"] + } + ], + "pg_compute_tls_settings": { + "key_file": "/databricks/secrets/some-directory/server.key", + "cert_file": "/databricks/secrets/some-directory/server.crt", + "ca_file": "/databricks/secrets/some-directory/ca.crt" + }, + "databricks_pg_hba": "/databricks/pg_config/hba", + "databricks_pg_ident": "/databricks/pg_config/ident" + }, + "page_server_billing_metrics_config": { + "metric_collection_endpoint": "http://localhost:8080/metrics", + "metric_collection_interval": "5 min", + "synthetic_size_calculation_interval": "5 min" + } + }"#, + ) + .unwrap(); + + let hadron_cluster_spec = parsed_config_data + .hadron_cluster + .unwrap() + .hadron_cluster_spec + .unwrap(); + + let storage_broker_spec = hadron_cluster_spec.storage_broker_spec.unwrap(); + assert_eq!(storage_broker_spec.image, Some("sb-image".to_string())); + + let safe_keeper_specs = hadron_cluster_spec.safe_keeper_specs.unwrap(); + let safe_keeper_spec = safe_keeper_specs.first().unwrap(); + assert_eq!(safe_keeper_spec.pool_id, Some(0)); + assert_eq!(safe_keeper_spec.replicas, Some(1)); + assert_eq!(safe_keeper_spec.image, Some("sk-image".to_string())); + assert_eq!( + safe_keeper_spec.storage_class_name, + Some("sk-storage-class".to_string()) + ); + + let page_server_specs = hadron_cluster_spec.page_server_specs.unwrap(); + let page_server_spec = page_server_specs.first().unwrap(); + assert_eq!(page_server_spec.pool_id, Some(0)); + assert_eq!(page_server_spec.replicas, Some(1)); + 
assert_eq!(page_server_spec.image, Some("ps-image".to_string())); + assert_eq!( + page_server_spec.storage_class_name, + Some("ps-storage-class".to_string()) + ); + + let pg_params = parsed_config_data.pg_params.unwrap(); + assert_eq!(pg_params.compute_namespace, "test-ns".to_string()); + assert_eq!(pg_params.compute_image, "compute-image".to_string()); + assert_eq!( + pg_params.prometheus_exporter_image, + "prometheus-exporter-image".to_string() + ); + assert_eq!( + pg_params.compute_image_pull_secret, + Some("test-image-pull-secret".to_string()) + ); + assert_eq!( + pg_params.compute_mounts, + Some(vec![ + K8sMount { + name: "secret".to_string(), + mount_type: MountType::Secret, + mount_path: "/databricks/secrets/secret".to_string(), + files: vec!["file1".to_string(), "file2".to_string()], + }, + K8sMount { + name: "another-secret".to_string(), + mount_type: MountType::Secret, + mount_path: "/databricks/secrets/another-secret".to_string(), + files: vec!["another_file1".to_string(), "another_file2".to_string()], + }, + K8sMount { + name: "config-map".to_string(), + mount_type: MountType::ConfigMap, + mount_path: "/databricks/pg_config".to_string(), + files: vec!["config1".to_string(), "config2".to_string()], + } + ]) + ); + assert_eq!( + pg_params.pg_compute_tls_settings, + Some(PgComputeTlsSettings { + key_file: "/databricks/secrets/some-directory/server.key".to_string(), + cert_file: "/databricks/secrets/some-directory/server.crt".to_string(), + ca_file: "/databricks/secrets/some-directory/ca.crt".to_string() + }) + ); + assert_eq!( + pg_params.databricks_pg_hba, + Some("/databricks/pg_config/hba".to_string()) + ); + assert_eq!( + pg_params.databricks_pg_ident, + Some("/databricks/pg_config/ident".to_string()) + ); + } + + #[test] + fn test_node_group_to_string() { + assert_eq!(BrcDbletNodeGroup::Dblet2C.to_string(), "dbletbrc2c"); + assert_eq!(BrcDbletNodeGroup::Dblet4C.to_string(), "dbletbrc4c"); + assert_eq!(BrcDbletNodeGroup::Dblet8C.to_string(), "dbletbrc8c"); + assert_eq!(BrcDbletNodeGroup::Dblet16C.to_string(), "dbletbrc16c"); + } + + #[test] + fn test_node_group_selection() { + assert_eq!( + select_node_group_by_tshirt_size(&EndpointTShirtSize::XSmall), + BrcDbletNodeGroup::Dblet2C + ); + assert_eq!( + select_node_group_by_tshirt_size(&EndpointTShirtSize::Small), + BrcDbletNodeGroup::Dblet4C + ); + assert_eq!( + select_node_group_by_tshirt_size(&EndpointTShirtSize::Medium), + BrcDbletNodeGroup::Dblet8C + ); + assert_eq!( + select_node_group_by_tshirt_size(&EndpointTShirtSize::Large), + BrcDbletNodeGroup::Dblet16C + ); + assert_eq!( + select_node_group_by_tshirt_size(&EndpointTShirtSize::Test), + BrcDbletNodeGroup::Dblet4C + ); + } + + #[test] + fn test_default_endpoint_config() { + let default_resources = endpoint_default_resources(); + let requests = default_resources + .requests + .clone() + .expect("resources.requests exist"); + assert_eq!( + requests.get("cpu").expect("resources.requests.cpu exist").0, + "500m" + ); + assert_eq!( + requests + .get("memory") + .expect("resources.requests.memory exist") + .0, + "4Gi" + ); + assert!(default_resources.limits.is_none()); + } + + #[tokio::test] + async fn test_node_affinity_generation() { + // We don't really use the k8s client in this test, so just use a non-functional mock client so that we can construct a K8sManager object. 
+ let (mock_service, mut _handle) = mock::pair::, Response>(); + let mock_client = Arc::new(Client::new(mock_service, "default")); + + let test_k8s_manager = + K8sManagerImpl::new_for_test(mock_client, "test-region-1".to_string(), None); + + // Test the trivial case. + assert!(test_k8s_manager.compute_node_affinity(None, None).is_none()); + + let node_group_req = NodeSelectorRequirement { + key: BRICKSTORE_POOL_TYPES_LABEL_KEY.to_string(), + operator: "In".to_string(), + values: Some(vec!["brc16cn".to_string()]), + }; + + // Test that the topology.kubernetes.io/zone label is not added to "matchExpressions" when the availability zone suffix is not specified. + assert_eq!( + test_k8s_manager + .compute_node_affinity(Some(&node_group_req), None) + .unwrap(), + serde_json::from_str( + r#"{ + "nodeAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": { + "nodeSelectorTerms": [ + { + "matchExpressions": [ + { + "key": "brickstore-pool-types", + "operator": "In", + "values": ["brc16cn"] + } + ] + } + ] + } + } + }"# + ) + .unwrap() + ); + + // Test that specifying the availability zone suffix results in the correct "topology.kubernetes.io/zone" match expression to be added. + assert_eq!( + test_k8s_manager + .compute_node_affinity(Some(&node_group_req), Some(&"a".to_string())) + .unwrap(), + serde_json::from_str( + r#"{ + "nodeAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": { + "nodeSelectorTerms": [ + { + "matchExpressions": [ + { + "key": "brickstore-pool-types", + "operator": "In", + "values": ["brc16cn"] + }, + { + "key": "topology.kubernetes.io/zone", + "operator": "In", + "values": ["test-region-1a"] + } + ] + } + ] + } + } + }"# + ) + .unwrap() + ); + } + + /// Test that K8sSecretMount is converted to k8s Volume and VolumeMount correctly. + #[tokio::test] + async fn test_get_volumes_and_mounts() { + let (mock_service, mut _handle) = mock::pair::, Response>(); + let mock_client = Arc::new(Client::new(mock_service, "default")); + let test_k8s_manager = + K8sManagerImpl::new_for_test(mock_client, "test-region-1".to_string(), None); + + let mounts = vec![ + K8sMount { + name: "secret".to_string(), + mount_type: MountType::Secret, + mount_path: "/databricks/secrets/dir1".to_string(), + files: vec!["file1".to_string(), "file2".to_string()], + }, + K8sMount { + name: "config-map".to_string(), + mount_type: MountType::ConfigMap, + mount_path: "/databricks/config".to_string(), + files: vec!["config1".to_string(), "config2".to_string()], + }, + ]; + + let volumes_and_mounts = test_k8s_manager.get_hadron_volumes_and_mounts(mounts); + + let expected_volumes = vec![ + Volume { + name: "secret".to_string(), + secret: Some(SecretVolumeSource { + secret_name: Some("secret".to_string()), + ..Default::default() + }), + ..Default::default() + }, + Volume { + name: "config-map".to_string(), + config_map: Some(ConfigMapVolumeSource { + name: "config-map".to_string(), + ..Default::default() + }), + ..Default::default() + }, + ]; + + let expected_volume_mounts = vec![ + VolumeMount { + name: "secret".to_string(), + read_only: Some(true), + mount_path: "/databricks/secrets/dir1".to_string(), + ..Default::default() + }, + VolumeMount { + name: "config-map".to_string(), + read_only: Some(true), + mount_path: "/databricks/config".to_string(), + ..Default::default() + }, + ]; + + assert_eq!(volumes_and_mounts.volumes, expected_volumes); + assert_eq!(volumes_and_mounts.volume_mounts, expected_volume_mounts); + } + + /// Test that get_databricks_compute_settings returns the correct settings. 
+ #[tokio::test] + async fn test_get_databricks_compute_settings() { + let (mock_service, mut _handle) = mock::pair::, Response>(); + let mock_client = Arc::new(Client::new(mock_service, "default")); + let test_k8s_manager = + K8sManagerImpl::new_for_test(mock_client, "test-region-1".to_string(), None); + + let databricks_compute_settings = test_k8s_manager + .get_databricks_compute_settings(Some( + Url::parse("https://test-workspace.databricks.com").unwrap(), + )) + .await; + + let expected_settings = DatabricksSettings { + pg_compute_tls_settings: PgComputeTlsSettings { + key_file: "/databricks/secrets/brickstore-domain-certs/server.key".to_string(), + cert_file: "/databricks/secrets/brickstore-domain-certs/server.crt".to_string(), + ca_file: "/databricks/secrets/trusted-ca/data-plane-misc-root-ca-cert.pem" + .to_string(), + }, + databricks_pg_hba: "/databricks/pg_config/databricks_pg_hba.conf".to_string(), + databricks_pg_ident: "/databricks/pg_config/databricks_pg_ident.conf".to_string(), + databricks_workspace_host: "test-workspace.databricks.com".to_string(), + }; + + assert_eq!(databricks_compute_settings, expected_settings) + } + + // Test the functionality of `K8sManager::get_http_urls_for_compute_services()`. Just to make sure we don't have stupid + // bugs/typos that would cause the function to panic when unwrapping Url::parse() results. + #[tokio::test] + async fn test_compute_service_url_generation() { + let (mock_service, mut _handle) = mock::pair::, Response>(); + let mock_client = Arc::new(Client::new(mock_service, "default")); + let test_k8s_manager = + K8sManagerImpl::new_for_test(mock_client, "test-region-1".to_string(), None); + + let service_names = vec!["pg-abc-admin".to_string(), "pg-xyz-admin".to_string()]; + + let expected_urls = vec![ + Url::parse("http://pg-abc-admin.test-namespace.svc.cluster.local.:80").unwrap(), + Url::parse("http://pg-xyz-admin.test-namespace.svc.cluster.local.:80").unwrap(), + ]; + + let actual_urls = test_k8s_manager + .get_http_urls_for_compute_services(service_names) + .await; + + assert_eq!(actual_urls, expected_urls); + } + + #[tokio::test] + async fn test_get_remote_storage_startup_args() { + let (mock_service, mut _handle) = mock::pair::, Response>(); + let mock_client = Arc::new(Client::new(mock_service, "default")); + let test_k8s_manager = + K8sManagerImpl::new_for_test(mock_client, "test-region-1".to_string(), None); + + let aws_storage_config = HadronObjectStorageConfig { + bucket_name: Some("test-bucket".to_string()), + bucket_region: Some("us-west-2".to_string()), + ..Default::default() + }; + + let mut result = test_k8s_manager.get_remote_storage_startup_args(&aws_storage_config); + + assert_eq!(result, "{bucket_name='$S3_BUCKET_URI', bucket_region='$S3_REGION', prefix_in_bucket='$PREFIX_IN_BUCKET'}"); + + let azure_storage_config = HadronObjectStorageConfig { + storage_account_resource_id: Some("/subscriptions/123/resourceGroups/xyz/providers/Microsoft.Storage/storageAccounts/def".to_string()), + azure_tenant_id: Some("456".to_string()), + storage_container_name: Some("container".to_string()), + storage_container_region: Some("westus".to_string()), + ..Default::default() + }; + + result = test_k8s_manager.get_remote_storage_startup_args(&azure_storage_config); + + assert_eq!(result, "{storage_account='$AZURE_STORAGE_ACCOUNT_NAME', container_name='$AZURE_STORAGE_CONTAINER_NAME', container_region='$AZURE_STORAGE_CONTAINER_REGION', prefix_in_container='$PREFIX_IN_BUCKET'}"); + } + + #[tokio::test] + async fn 
test_build_compute_service() { + let (mock_service, mut _handle) = mock::pair::, Response>(); + let mock_client = Arc::new(Client::new(mock_service, "default")); + let mut test_k8s_manager = K8sManagerImpl::new_for_test( + mock_client.clone(), + "test-region-1".to_string(), + Some(CloudProvider::AWS), + ); + + let test_compute_name = "test-pg".to_string(); + + let mut service = + test_k8s_manager.build_loadbalancer_service(1234, test_compute_name.clone()); + + // On AWS, the service should have 3 annotations and use an externalTrafficPolicy of "Local". + assert_eq!(service.metadata.annotations.clone().unwrap().len(), 3); + // All of the annotations should contain the string "aws". + for annotation in service.metadata.annotations.clone().unwrap().keys() { + assert!(annotation.contains("aws")); + } + assert_eq!( + service.spec.unwrap().external_traffic_policy.unwrap(), + "Local" + ); + + test_k8s_manager = K8sManagerImpl::new_for_test( + mock_client.clone(), + "test-region-1".to_string(), + Some(CloudProvider::Azure), + ); + + service = test_k8s_manager.build_loadbalancer_service(1234, test_compute_name.clone()); + + // On Azure, the service should have 2 annotations and use an externalTrafficPolicy of "Cluster". + assert_eq!(service.metadata.annotations.clone().unwrap().len(), 2); + // All of the annotations should contain the string "azure". + for annotation in service.metadata.annotations.clone().unwrap().keys() { + assert!(annotation.contains("azure")); + } + assert_eq!( + service.spec.unwrap().external_traffic_policy.unwrap(), + "Cluster" + ); + // Ensure the DNS label annotation is properly set. + assert_eq!( + service + .metadata + .annotations + .clone() + .unwrap() + .get("service.beta.kubernetes.io/azure-dns-label-name") + .unwrap() + .to_string(), + test_compute_name.clone() + ); + } + + #[tokio::test] + async fn test_node_selector_requirement_to_tolerations() { + let (mock_service, mut _handle) = mock::pair::, Response>(); + let mock_client = Arc::new(Client::new(mock_service, "default")); + + let clouds = vec![CloudProvider::AWS, CloudProvider::Azure]; + + for cloud in clouds { + let test_k8s_manager = K8sManagerImpl::new_for_test( + mock_client.clone(), + "test-region-1".to_string(), + Some(cloud.clone()), + ); + + let node_selector_requirement = NodeSelectorRequirement { + key: BRICKSTORE_POOL_TYPES_LABEL_KEY.to_string(), + operator: "In".to_string(), + values: Some(vec!["brc16cn".to_string()]), + }; + + let tolerations = test_k8s_manager + .node_selector_requirement_to_tolerations(Some(&node_selector_requirement)) + .unwrap(); + + assert_eq!(tolerations.len(), 1); + assert_eq!( + tolerations[0].key, + Some("databricks.com/node-type".to_string()) + ); + assert_eq!(tolerations[0].operator, Some("Equal".to_string())); + assert_eq!(tolerations[0].value, Some("brc16cn".to_string())); + + if cloud.clone() == CloudProvider::AWS { + assert_eq!(tolerations[0].effect, Some("NoSchedule".to_string())); + } else { + assert_eq!(tolerations[0].effect, Some("PreferNoSchedule".to_string())); + } + } + } + + #[tokio::test] + async fn test_read_and_validate_cluster_config() { + // Create a temporary directory to hold test config files. The directory is cleaned up when `tmp_dir` goes out of scope. 
+ let tmp_dir = Utf8TempDir::new().unwrap(); + + // Case 1: Nonexistent file + let bad_path = tmp_dir.path().join("nonexistent.json"); + let err = K8sManagerImpl::read_and_validate_cluster_config(bad_path.as_str()).await; + assert!(err.is_err(), "Should fail on nonexistent file"); + + // Case 2: Malformed JSON + let malformed_path = tmp_dir.path().join("malformed.json"); + { + let mut file = File::create(malformed_path.as_std_path()).unwrap(); + writeln!(file, "{{ not valid json").unwrap(); + } + let err = K8sManagerImpl::read_and_validate_cluster_config(malformed_path.as_str()).await; + assert!(err.is_err(), "Should fail on malformed JSON"); + + // Case 3: Missing required fields + let missing_fields_path = tmp_dir.path().join("missing_fields.json"); + { + let mut file = File::create(missing_fields_path.as_std_path()).unwrap(); + writeln!( + file, + r#"{{ "hadron_cluster": {{ "hadron_cluster_spec": null }} }}"# + ) + .unwrap(); + } + let err = + K8sManagerImpl::read_and_validate_cluster_config(missing_fields_path.as_str()).await; + assert!( + err.is_err(), + "Should fail on config missing required fields" + ); + + // Case 4: Happy path + let valid_path = tmp_dir.path().join("valid.json"); + { + let mut file = File::create(valid_path.as_std_path()).unwrap(); + writeln!( + file, + r#" + {{ + "hadron_cluster": {{ + "hadron_cluster_spec": {{ + "object_storage_config": {{ + "bucket_name": "mybucket" + }} + }} + }}, + "pg_params": {{ + "compute_namespace": "default", + "compute_image": "postgres:latest", + "prometheus_exporter_image": "exporter:latest" + }} + }}"# + ) + .unwrap(); + } + let result = K8sManagerImpl::read_and_validate_cluster_config(valid_path.as_str()).await; + assert!(result.is_ok(), "Should succeed on valid config"); + let (cluster, pg_params, hash, _) = result.unwrap(); + assert!( + cluster.hadron_cluster_spec.is_some(), + "Cluster spec present" + ); + assert_eq!( + pg_params.compute_namespace, "default", + "Parsed namespace correctly" + ); + assert!(!hash.is_empty(), "Hash should be populated"); + } +} diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index c478604834..ce1adc7b67 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -2776,6 +2776,11 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder): "SELECT value FROM neon_perf_counters WHERE metric = 'num_configured_safekeepers'" ) assert cur.fetchone() == (1,), "Expected 1 configured safekeeper" + # Check that max_active_safekeeper_commit_lag metric exists and is zero with single safekeeper + cur.execute( + "SELECT value FROM neon_perf_counters WHERE metric = 'max_active_safekeeper_commit_lag'" + ) + assert cur.fetchone() == (0,), "Expected zero commit lag with one safekeeper" # Get the safekeeper sk = env.safekeepers[0] @@ -2819,6 +2824,11 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder): "SELECT value FROM neon_perf_counters WHERE metric = 'num_configured_safekeepers'" ) assert cur.fetchone() == (1,), "Expected 1 configured safekeeper" + # Check that max_active_safekeeper_commit_lag metric exists and is zero with no active safekeepers + cur.execute( + "SELECT value FROM neon_perf_counters WHERE metric = 'max_active_safekeeper_commit_lag'" + ) + assert cur.fetchone() == (0,), "Expected zero commit lag with no active safekeepers" # Sanity check that the hanging insert is indeed still hanging. Otherwise means the circuit breaker we # implemented didn't work as expected. 
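Both the checks in the hunk above and the new test in the next hunk read a single value out of the neon_perf_counters view with the same one-row query. A small helper along these lines could keep that pattern in one place; get_perf_counter is a hypothetical name, not something added by this patch, and it assumes safe_psql returns the result rows as a list of tuples, as the rest of this test suite does for SELECTs.

    # Hypothetical test helper (a sketch, not part of this patch): fetch one metric
    # value from the neon_perf_counters view exposed by the neon extension.
    def get_perf_counter(endpoint, metric: str) -> float:
        rows = endpoint.safe_psql(
            f"SELECT value FROM neon_perf_counters WHERE metric = '{metric}'"
        )
        assert rows, f"metric {metric} not found in neon_perf_counters"
        return float(rows[0][0])

    # A test body could then reduce the repeated query to:
    #   assert get_perf_counter(endpoint, "max_active_safekeeper_commit_lag") == 0.0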
@@ -2933,3 +2943,77 @@ def test_global_disk_usage_limit(neon_env_builder: NeonEnvBuilder): with conn.cursor() as cur: cur.execute("select count(*) from t2") assert cur.fetchone() == (3000,) + +@pytest.mark.skip(reason="Lakebase Mode") +def test_max_active_safekeeper_commit_lag(neon_env_builder: NeonEnvBuilder): + """ + This test validates the `max_active_safekeeper_commit_lag` metric. The + strategy is to intentionally create a scenario where one safekeeper falls + behind (by pausing it with a failpoint), observe that the metric correctly + reports this lag, and then confirm that the metric returns to zero after the + lagging safekeeper catches up (once the failpoint is removed). + """ + neon_env_builder.num_safekeepers = 2 + env = neon_env_builder.init_start() + # Create branch and start endpoint + env.create_branch("test_commit_lsn_lag_failpoint") + endpoint = env.endpoints.create_start("test_commit_lsn_lag_failpoint") + # Enable neon extension and table + endpoint.safe_psql("CREATE EXTENSION IF NOT EXISTS neon") + endpoint.safe_psql("CREATE TABLE t(key int primary key, value text)") + + # Identify the lagging safekeeper and configure failpoint to pause + lagging_sk = env.safekeepers[1] + with lagging_sk.http_client() as http_cli: + http_cli.configure_failpoints(("sk-acceptor-pausable", "pause")) + + # Note: Insert could hang because the failpoint above causes the safekeepers to lose quorum. + def run_hanging_insert(): + endpoint.safe_psql("INSERT INTO t SELECT generate_series(1,500), 'payload'") + + # Start the insert in a background thread + bg_thread = threading.Thread(target=run_hanging_insert) + bg_thread.start() + + # Wait for the lag metric to become positive + def lag_is_positive(): + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT value FROM neon_perf_counters WHERE metric = 'max_active_safekeeper_commit_lag'" + ) + row = cur.fetchone() + assert row is not None, "max_active_safekeeper_commit_lag metric not found" + lag = row[0] + log.info(f"Current commit lag: {lag}") + if lag == 0.0: + raise Exception("Commit lag is still zero, trying again...") + + # Confirm that we can observe a positive lag value + wait_until(lag_is_positive) + + # Unpause the failpoint so that the safekeepers sync back up. This should also unstuck the hanging insert. + with lagging_sk.http_client() as http_cli: + http_cli.configure_failpoints(("sk-acceptor-pausable", "off")) + + # Wait for the hanging insert to complete + bg_thread.join(timeout=30) + assert not bg_thread.is_alive(), "Hanging insert did not complete within timeout" + log.info("Hanging insert is unstuck successfully") + + def lag_is_zero(): + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT value FROM neon_perf_counters WHERE metric = 'max_active_safekeeper_commit_lag'" + ) + row = cur.fetchone() + assert ( + row is not None + ), "max_active_safekeeper_commit_lag metric not found in lag_is_zero" + lag = row[0] + log.info(f"Current commit lag: {lag}") + return lag == 0.0 + + # Confirm that the lag eventually returns to zero + wait_until(lag_is_zero)
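
Outside of the integration test, the same counter can be polled from any client connection to drive alerting. The following is a minimal sketch under stated assumptions: the connection string, polling interval, and lag threshold are placeholders (the metric is an LSN difference, so the value is assumed to be in bytes of WAL), and printing stands in for whatever alerting hook is actually used; only the view and metric names come from this change.

    # Minimal monitoring sketch (not part of this patch). DSN, threshold, and the
    # alerting action are placeholders.
    import time

    import psycopg2

    DSN = "postgresql://cloud_admin@localhost:55432/postgres"  # placeholder connection string
    LAG_ALERT_THRESHOLD = 16 * 1024 * 1024  # placeholder: assumed bytes of WAL

    def poll_commit_lag_once(dsn: str = DSN) -> float:
        # Read the current max commit LSN lag among active safekeepers.
        with psycopg2.connect(dsn) as conn, conn.cursor() as cur:
            cur.execute(
                "SELECT value FROM neon_perf_counters "
                "WHERE metric = 'max_active_safekeeper_commit_lag'"
            )
            row = cur.fetchone()
            return float(row[0]) if row is not None else 0.0

    if __name__ == "__main__":
        while True:
            lag = poll_commit_lag_once()
            if lag > LAG_ALERT_THRESHOLD:
                print(f"ALERT: active safekeeper commit LSN lag is {lag:.0f}")
            time.sleep(60)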