diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index be734029df..06a34ec41e 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -413,6 +413,52 @@ struct StartVmMonitorResult { vm_monitor: Option>>, } +/// Databricks-specific environment variables to be passed to the `postgres` sub-process. +pub struct DatabricksEnvVars { + /// The Databricks "endpoint ID" of the compute instance. Used by `postgres` to check + /// the token scopes of internal auth tokens. + pub endpoint_id: String, + /// Hostname of the Databricks workspace URL this compute instance belongs to. + /// Used by postgres to verify Databricks PAT tokens. + pub workspace_host: String, +} + +impl DatabricksEnvVars { + pub fn new(compute_spec: &ComputeSpec, compute_id: Option<&String>) -> Self { + // compute_id is a string format of "{endpoint_id}/{compute_idx}" + // endpoint_id is a uuid. We only need to pass down endpoint_id to postgres. + // Panics if compute_id is not set or not in the expected format. + let endpoint_id = compute_id.unwrap().split('/').next().unwrap().to_string(); + let workspace_host = compute_spec + .databricks_settings + .as_ref() + .map(|s| s.databricks_workspace_host.clone()) + .unwrap_or("".to_string()); + Self { + endpoint_id, + workspace_host, + } + } + + /// Constants for the names of Databricks-specific postgres environment variables. + const DATABRICKS_ENDPOINT_ID_ENVVAR: &'static str = "DATABRICKS_ENDPOINT_ID"; + const DATABRICKS_WORKSPACE_HOST_ENVVAR: &'static str = "DATABRICKS_WORKSPACE_HOST"; + + /// Convert DatabricksEnvVars to a list of string pairs that can be passed as env vars. Consumes `self`. + pub fn to_env_var_list(self) -> Vec<(String, String)> { + vec![ + ( + Self::DATABRICKS_ENDPOINT_ID_ENVVAR.to_string(), + self.endpoint_id.clone(), + ), + ( + Self::DATABRICKS_WORKSPACE_HOST_ENVVAR.to_string(), + self.workspace_host.clone(), + ), + ] + } +} + impl ComputeNode { pub fn new(params: ComputeNodeParams, config: ComputeConfig) -> Result { let connstr = params.connstr.as_str(); @@ -1581,14 +1627,31 @@ impl ComputeNode { pub fn start_postgres(&self, storage_auth_token: Option) -> Result { let pgdata_path = Path::new(&self.params.pgdata); + let env_vars: Vec<(String, String)> = if self.params.lakebase_mode { + let databricks_env_vars = { + let state = self.state.lock().unwrap(); + let spec = &state.pspec.as_ref().unwrap().spec; + DatabricksEnvVars::new(spec, Some(&self.params.compute_id)) + }; + + info!( + "Starting Postgres for databricks endpoint id: {}", + &databricks_env_vars.endpoint_id + ); + + let mut env_vars = databricks_env_vars.to_env_var_list(); + env_vars.extend(storage_auth_token.map(|t| ("NEON_AUTH_TOKEN".to_string(), t))); + env_vars + } else if let Some(storage_auth_token) = &storage_auth_token { + vec![("NEON_AUTH_TOKEN".to_owned(), storage_auth_token.to_owned())] + } else { + vec![] + }; + // Run postgres as a child process. let mut pg = maybe_cgexec(&self.params.pgbin) .args(["-D", &self.params.pgdata]) - .envs(if let Some(storage_auth_token) = &storage_auth_token { - vec![("NEON_AUTH_TOKEN", storage_auth_token)] - } else { - vec![] - }) + .envs(env_vars) .stderr(Stdio::piped()) .spawn() .expect("cannot start postgres process");