Make pull_timeline work with auth enabled.

- Make safekeeper read SAFEKEEPER_AUTH_TOKEN env variable with JWT
  token to connect to other safekeepers.
- Set it in neon_local when auth is enabled.
- Create simple rust http client supporting it, and use it in pull_timeline
  implementation.
- Enable auth in all pull_timeline tests.
- Make sk http_client() by default generate safekeeper wide token, it makes
  easier enabling auth in all tests by default.
This commit is contained in:
Arseny Sher
2024-06-13 22:47:58 +03:00
committed by Arseny Sher
parent 29a41fc7b9
commit 4feb6ba29c
10 changed files with 245 additions and 41 deletions

View File

@@ -14,6 +14,7 @@ use camino::Utf8PathBuf;
use postgres_connection::PgConnectionConfig;
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::auth::{Claims, Scope};
use utils::{http::error::HttpErrorBody, id::NodeId};
use crate::{
@@ -197,7 +198,7 @@ impl SafekeeperNode {
&datadir,
&self.env.safekeeper_bin(),
&args,
[],
self.safekeeper_env_variables()?,
background_process::InitialPidFile::Expect(self.pid_file()),
|| async {
match self.check_status().await {
@@ -210,6 +211,18 @@ impl SafekeeperNode {
.await
}
fn safekeeper_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
// Generate a token to connect from safekeeper to peers
if self.conf.auth_enabled {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::SafekeeperData))?;
Ok(vec![("SAFEKEEPER_AUTH_TOKEN".to_owned(), token)])
} else {
Ok(Vec::new())
}
}
///
/// Stop the server.
///

View File

@@ -13,7 +13,9 @@ use tokio::runtime::Handle;
use tokio::signal::unix::{signal, SignalKind};
use tokio::task::JoinError;
use toml_edit::Document;
use utils::logging::SecretString;
use std::env::{var, VarError};
use std::fs::{self, File};
use std::io::{ErrorKind, Write};
use std::str::FromStr;
@@ -287,6 +289,22 @@ async fn main() -> anyhow::Result<()> {
}
};
// Load JWT auth token to connect to other safekeepers for pull_timeline.
let sk_auth_token = match var("SAFEKEEPER_AUTH_TOKEN") {
Ok(v) => {
info!("loaded JWT token for authentication with safekeepers");
Some(SecretString::from(v))
}
Err(VarError::NotPresent) => {
info!("no JWT token for authentication with safekeepers detected");
None
}
Err(_) => {
warn!("JWT token for authentication with safekeepers is not unicode");
None
}
};
let conf = SafeKeeperConf {
workdir,
my_id: id,
@@ -307,6 +325,7 @@ async fn main() -> anyhow::Result<()> {
pg_auth,
pg_tenant_only_auth,
http_auth,
sk_auth_token,
current_thread_runtime: args.current_thread_runtime,
walsenders_keep_horizon: args.walsenders_keep_horizon,
partial_backup_enabled: args.partial_backup_enabled,

View File

@@ -0,0 +1,139 @@
//! Safekeeper http client.
//!
//! Partially copied from pageserver client; some parts might be better to be
//! united.
//!
//! It would be also good to move it out to separate crate, but this needs
//! duplication of internal-but-reported structs like WalSenderState, ServerInfo
//! etc.
use reqwest::{IntoUrl, Method, StatusCode};
use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
logging::SecretString,
};
use super::routes::TimelineStatus;
#[derive(Debug, Clone)]
pub struct Client {
mgmt_api_endpoint: String,
authorization_header: Option<SecretString>,
client: reqwest::Client,
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Failed to receive body (reqwest error).
#[error("receive body: {0}")]
ReceiveBody(reqwest::Error),
/// Status is not ok, but failed to parse body as `HttpErrorBody`.
#[error("receive error body: {0}")]
ReceiveErrorBody(String),
/// Status is not ok; parsed error in body as `HttpErrorBody`.
#[error("safekeeper API: {1}")]
ApiError(StatusCode, String),
}
pub type Result<T> = std::result::Result<T, Error>;
pub trait ResponseErrorMessageExt: Sized {
fn error_from_body(self) -> impl std::future::Future<Output = Result<Self>> + Send;
}
/// If status is not ok, try to extract error message from the body.
impl ResponseErrorMessageExt for reqwest::Response {
async fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
}
let url = self.url().to_owned();
Err(match self.json::<HttpErrorBody>().await {
Ok(HttpErrorBody { msg }) => Error::ApiError(status, msg),
Err(_) => {
Error::ReceiveErrorBody(format!("http error ({}) at {}.", status.as_u16(), url))
}
})
}
}
impl Client {
pub fn new(mgmt_api_endpoint: String, jwt: Option<SecretString>) -> Self {
Self::from_client(reqwest::Client::new(), mgmt_api_endpoint, jwt)
}
pub fn from_client(
client: reqwest::Client,
mgmt_api_endpoint: String,
jwt: Option<SecretString>,
) -> Self {
Self {
mgmt_api_endpoint,
authorization_header: jwt
.map(|jwt| SecretString::from(format!("Bearer {}", jwt.get_contents()))),
client,
}
}
pub async fn timeline_status(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<TimelineStatus> {
let uri = format!(
"{}/v1/tenant/{}/timeline/{}",
self.mgmt_api_endpoint, tenant_id, timeline_id
);
let resp = self.get(&uri).await?;
resp.json().await.map_err(Error::ReceiveBody)
}
pub async fn snapshot(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<reqwest::Response> {
let uri = format!(
"{}/v1/tenant/{}/timeline/{}/snapshot",
self.mgmt_api_endpoint, tenant_id, timeline_id
);
self.get(&uri).await
}
async fn get<U: IntoUrl>(&self, uri: U) -> Result<reqwest::Response> {
self.request(Method::GET, uri, ()).await
}
/// Send the request and check that the status code is good.
async fn request<B: serde::Serialize, U: reqwest::IntoUrl>(
&self,
method: Method,
uri: U,
body: B,
) -> Result<reqwest::Response> {
let res = self.request_noerror(method, uri, body).await?;
let response = res.error_from_body().await?;
Ok(response)
}
/// Just send the request.
async fn request_noerror<B: serde::Serialize, U: reqwest::IntoUrl>(
&self,
method: Method,
uri: U,
body: B,
) -> Result<reqwest::Response> {
let req = self.client.request(method, uri);
let req = if let Some(value) = &self.authorization_header {
req.header(reqwest::header::AUTHORIZATION, value.get_contents())
} else {
req
};
req.json(&body).send().await.map_err(Error::ReceiveBody)
}
}

View File

@@ -1,3 +1,4 @@
pub mod client;
pub mod routes;
pub use routes::make_router;

View File

@@ -195,8 +195,9 @@ async fn timeline_pull_handler(mut request: Request<Body>) -> Result<Response<Bo
check_permission(&request, None)?;
let data: pull_timeline::Request = json_request(&mut request).await?;
let conf = get_conf(&request);
let resp = pull_timeline::handle_request(data)
let resp = pull_timeline::handle_request(data, conf.sk_auth_token.clone())
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, resp)

View File

@@ -7,7 +7,7 @@ use tokio::runtime::Runtime;
use std::time::Duration;
use storage_broker::Uri;
use utils::{auth::SwappableJwtAuth, id::NodeId};
use utils::{auth::SwappableJwtAuth, id::NodeId, logging::SecretString};
mod auth;
pub mod broker;
@@ -78,6 +78,8 @@ pub struct SafeKeeperConf {
pub pg_auth: Option<Arc<JwtAuth>>,
pub pg_tenant_only_auth: Option<Arc<JwtAuth>>,
pub http_auth: Option<Arc<SwappableJwtAuth>>,
/// JWT token to connect to other safekeepers with.
pub sk_auth_token: Option<SecretString>,
pub current_thread_runtime: bool,
pub walsenders_keep_horizon: bool,
pub partial_backup_enabled: bool,
@@ -114,6 +116,7 @@ impl SafeKeeperConf {
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,
sk_auth_token: None,
heartbeat_timeout: Duration::new(5, 0),
max_offloader_lag_bytes: defaults::DEFAULT_MAX_OFFLOADER_LAG_BYTES,
current_thread_runtime: false,

View File

@@ -27,7 +27,10 @@ use tracing::{error, info, instrument};
use crate::{
control_file::{self, CONTROL_FILE_NAME},
debug_dump,
http::routes::TimelineStatus,
http::{
client::{self, Client},
routes::TimelineStatus,
},
safekeeper::Term,
timeline::{get_tenant_dir, get_timeline_dir, FullAccessTimeline, Timeline, TimelineError},
wal_storage::{self, open_wal_file, Storage},
@@ -36,6 +39,7 @@ use crate::{
use utils::{
crashsafe::{durable_rename, fsync_async_opt},
id::{TenantId, TenantTimelineId, TimelineId},
logging::SecretString,
lsn::Lsn,
pausable_failpoint,
};
@@ -163,6 +167,11 @@ impl FullAccessTimeline {
// We need to stream since the oldest segment someone (s3 or pageserver)
// still needs. This duplicates calc_horizon_lsn logic.
//
// We know that WAL wasn't removed up to this point because it cannot be
// removed further than `backup_lsn`. Since we're holding shared_state
// lock and setting `wal_removal_on_hold` later, it guarantees that WAL
// won't be removed until we're done.
let from_lsn = min(
shared_state.sk.state.remote_consistent_lsn,
shared_state.sk.state.backup_lsn,
@@ -255,7 +264,10 @@ pub struct DebugDumpResponse {
}
/// Find the most advanced safekeeper and pull timeline from it.
pub async fn handle_request(request: Request) -> Result<Response> {
pub async fn handle_request(
request: Request,
sk_auth_token: Option<SecretString>,
) -> Result<Response> {
let existing_tli = GlobalTimelines::get(TenantTimelineId::new(
request.tenant_id,
request.timeline_id,
@@ -264,26 +276,22 @@ pub async fn handle_request(request: Request) -> Result<Response> {
bail!("Timeline {} already exists", request.timeline_id);
}
let client = reqwest::Client::new();
let http_hosts = request.http_hosts.clone();
// Send request to /v1/tenant/:tenant_id/timeline/:timeline_id
let responses = futures::future::join_all(http_hosts.iter().map(|url| {
let url = format!(
"{}/v1/tenant/{}/timeline/{}",
url, request.tenant_id, request.timeline_id
);
client.get(url).send()
}))
.await;
// Figure out statuses of potential donors.
let responses: Vec<Result<TimelineStatus, client::Error>> =
futures::future::join_all(http_hosts.iter().map(|url| async {
let cclient = Client::new(url.clone(), sk_auth_token.clone());
let info = cclient
.timeline_status(request.tenant_id, request.timeline_id)
.await?;
Ok(info)
}))
.await;
let mut statuses = Vec::new();
for (i, response) in responses.into_iter().enumerate() {
let response = response.context(format!("fetching status from {}", http_hosts[i]))?;
response
.error_for_status_ref()
.context(format!("checking status from {}", http_hosts[i]))?;
let status: crate::http::routes::TimelineStatus = response.json().await?;
let status = response.context(format!("fetching status from {}", http_hosts[i]))?;
statuses.push((status, i));
}
@@ -303,10 +311,14 @@ pub async fn handle_request(request: Request) -> Result<Response> {
assert!(status.tenant_id == request.tenant_id);
assert!(status.timeline_id == request.timeline_id);
pull_timeline(status, safekeeper_host).await
pull_timeline(status, safekeeper_host, sk_auth_token).await
}
async fn pull_timeline(status: TimelineStatus, host: String) -> Result<Response> {
async fn pull_timeline(
status: TimelineStatus,
host: String,
sk_auth_token: Option<SecretString>,
) -> Result<Response> {
let ttid = TenantTimelineId::new(status.tenant_id, status.timeline_id);
info!(
"pulling timeline {} from safekeeper {}, commit_lsn={}, flush_lsn={}, term={}, epoch={}",
@@ -322,17 +334,11 @@ async fn pull_timeline(status: TimelineStatus, host: String) -> Result<Response>
let (_tmp_dir, tli_dir_path) = create_temp_timeline_dir(conf, ttid).await?;
let client = reqwest::Client::new();
let client = Client::new(host.clone(), sk_auth_token.clone());
// Request stream with basebackup archive.
let bb_resp = client
.get(format!(
"{}/v1/tenant/{}/timeline/{}/snapshot",
host, status.tenant_id, status.timeline_id
))
.send()
.snapshot(status.tenant_id, status.timeline_id)
.await?;
bb_resp.error_for_status_ref()?;
// Make Stream of Bytes from it...
let bb_stream = bb_resp.bytes_stream().map_err(std::io::Error::other);

View File

@@ -174,6 +174,7 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,
sk_auth_token: None,
current_thread_runtime: false,
walsenders_keep_horizon: false,
partial_backup_enabled: false,

View File

@@ -3847,7 +3847,15 @@ class Safekeeper(LogUtils):
assert isinstance(res, dict)
return res
def http_client(self, auth_token: Optional[str] = None) -> SafekeeperHttpClient:
def http_client(
self, auth_token: Optional[str] = None, gen_sk_wide_token: bool = True
) -> SafekeeperHttpClient:
"""
When auth_token is None but gen_sk_wide is True creates safekeeper wide
token, which is a reasonable default.
"""
if auth_token is None and gen_sk_wide_token:
auth_token = self.env.auth_keys.generate_safekeeper_token()
is_testing_enabled = '"testing"' in self.env.get_binary_version("safekeeper")
return SafekeeperHttpClient(
port=self.port.http, auth_token=auth_token, is_testing_enabled=is_testing_enabled
@@ -4450,6 +4458,7 @@ def wait_for_last_flush_lsn(
tenant: TenantId,
timeline: TimelineId,
pageserver_id: Optional[int] = None,
auth_token: Optional[str] = None,
) -> Lsn:
"""Wait for pageserver to catch up the latest flush LSN, returns the last observed lsn."""
@@ -4463,7 +4472,7 @@ def wait_for_last_flush_lsn(
f"wait_for_last_flush_lsn: waiting for {last_flush_lsn} on shard {tenant_shard_id} on pageserver {pageserver.id})"
)
waited = wait_for_last_record_lsn(
pageserver.http_client(), tenant_shard_id, timeline, last_flush_lsn
pageserver.http_client(auth_token=auth_token), tenant_shard_id, timeline, last_flush_lsn
)
assert waited >= last_flush_lsn
@@ -4559,6 +4568,7 @@ def last_flush_lsn_upload(
tenant_id: TenantId,
timeline_id: TimelineId,
pageserver_id: Optional[int] = None,
auth_token: Optional[str] = None,
) -> Lsn:
"""
Wait for pageserver to catch to the latest flush LSN of given endpoint,
@@ -4566,11 +4576,11 @@ def last_flush_lsn_upload(
reaching flush LSN).
"""
last_flush_lsn = wait_for_last_flush_lsn(
env, endpoint, tenant_id, timeline_id, pageserver_id=pageserver_id
env, endpoint, tenant_id, timeline_id, pageserver_id=pageserver_id, auth_token=auth_token
)
shards = tenant_get_shards(env, tenant_id, pageserver_id)
for tenant_shard_id, pageserver in shards:
ps_http = pageserver.http_client()
ps_http = pageserver.http_client(auth_token=auth_token)
wait_for_last_record_lsn(ps_http, tenant_shard_id, timeline_id, last_flush_lsn)
# force a checkpoint to trigger upload
ps_http.timeline_checkpoint(tenant_shard_id, timeline_id)

View File

@@ -374,7 +374,7 @@ def test_wal_removal(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
http_cli_other = env.safekeepers[0].http_client(
auth_token=env.auth_keys.generate_tenant_token(TenantId.generate())
)
http_cli_noauth = env.safekeepers[0].http_client()
http_cli_noauth = env.safekeepers[0].http_client(gen_sk_wide_token=False)
# Pretend WAL is offloaded to s3.
if auth_enabled:
@@ -830,7 +830,7 @@ def test_timeline_status(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
auth_token=env.auth_keys.generate_tenant_token(TenantId.generate())
)
wa_http_cli_bad.check_status()
wa_http_cli_noauth = wa.http_client()
wa_http_cli_noauth = wa.http_client(gen_sk_wide_token=False)
wa_http_cli_noauth.check_status()
# debug endpoint requires safekeeper scope
@@ -964,7 +964,7 @@ def test_sk_auth(neon_env_builder: NeonEnvBuilder):
# By default, neon_local enables auth on all services if auth is configured,
# so http must require the token.
sk_http_cli_noauth = sk.http_client()
sk_http_cli_noauth = sk.http_client(gen_sk_wide_token=False)
sk_http_cli_auth = sk.http_client(auth_token=env.auth_keys.generate_tenant_token(tenant_id))
with pytest.raises(sk_http_cli_noauth.HTTPError, match="Forbidden|Unauthorized"):
sk_http_cli_noauth.timeline_status(tenant_id, timeline_id)
@@ -1640,7 +1640,7 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
sk_http_other = sk.http_client(
auth_token=env.auth_keys.generate_tenant_token(tenant_id_other)
)
sk_http_noauth = sk.http_client()
sk_http_noauth = sk.http_client(gen_sk_wide_token=False)
assert (sk_data_dir / str(tenant_id) / str(timeline_id_1)).is_dir()
assert (sk_data_dir / str(tenant_id) / str(timeline_id_2)).is_dir()
assert (sk_data_dir / str(tenant_id) / str(timeline_id_3)).is_dir()
@@ -1723,7 +1723,10 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
cur.execute("INSERT INTO t (key) VALUES (123)")
# Basic pull_timeline test.
def test_pull_timeline(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
def execute_payload(endpoint: Endpoint):
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
@@ -1739,7 +1742,7 @@ def test_pull_timeline(neon_env_builder: NeonEnvBuilder):
def show_statuses(safekeepers: List[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId):
for sk in safekeepers:
http_cli = sk.http_client()
http_cli = sk.http_client(auth_token=env.auth_keys.generate_tenant_token(tenant_id))
try:
status = http_cli.timeline_status(tenant_id, timeline_id)
log.info(f"Safekeeper {sk.id} status: {status}")
@@ -1769,7 +1772,7 @@ def test_pull_timeline(neon_env_builder: NeonEnvBuilder):
res = (
env.safekeepers[3]
.http_client()
.http_client(auth_token=env.auth_keys.generate_safekeeper_token())
.pull_timeline(
{
"tenant_id": str(tenant_id),
@@ -1817,6 +1820,7 @@ def test_pull_timeline(neon_env_builder: NeonEnvBuilder):
# Expected to fail while holding off WAL gc plus fetching commit_lsn WAL
# segment is not implemented.
def test_pull_timeline_gc(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
neon_env_builder.num_safekeepers = 3
neon_env_builder.enable_safekeeper_remote_storage(default_remote_storage())
env = neon_env_builder.init_start()
@@ -1846,7 +1850,13 @@ def test_pull_timeline_gc(neon_env_builder: NeonEnvBuilder):
# ensure segment exists
endpoint.safe_psql("insert into t select generate_series(1, 180000), 'papaya'")
lsn = last_flush_lsn_upload(env, endpoint, tenant_id, timeline_id)
lsn = last_flush_lsn_upload(
env,
endpoint,
tenant_id,
timeline_id,
auth_token=env.auth_keys.generate_tenant_token(tenant_id),
)
assert lsn > Lsn("0/2000000")
# Checkpoint timeline beyond lsn.
src_sk.checkpoint_up_to(tenant_id, timeline_id, lsn, wait_wal_removal=False)
@@ -1886,6 +1896,7 @@ def test_pull_timeline_gc(neon_env_builder: NeonEnvBuilder):
#
# Expected to fail while term check is not implemented.
def test_pull_timeline_term_change(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
neon_env_builder.num_safekeepers = 3
neon_env_builder.enable_safekeeper_remote_storage(default_remote_storage())
env = neon_env_builder.init_start()