Object storage proxy (#11357)

Service targeted for storing and retrieving LFC prewarm data.
Can be used for proxying S3 access for Postgres extensions like
pg_mooncake as well.

Requests must include a Bearer JWT token.
Token is validated using a pemfile (should be passed in infra/).

Note: app is not tolerant to extra trailing slashes, see app.rs
`delete_prefix` test for comments.

Resolves: https://github.com/neondatabase/cloud/issues/26342
Unrelated changes: gate a `rename_noreplace` feature and disable it in
`remote_storage` so as `object_storage` can be built with musl
This commit is contained in:
Mikhail Kot
2025-04-08 15:54:53 +01:00
committed by GitHub
parent a7142f3bc6
commit 6138d61592
23 changed files with 1424 additions and 38 deletions

561
object_storage/src/app.rs Normal file
View File

@@ -0,0 +1,561 @@
use anyhow::anyhow;
use axum::body::{Body, Bytes};
use axum::response::{IntoResponse, Response};
use axum::{Router, http::StatusCode};
use object_storage::{PrefixS3Path, S3Path, Storage, bad_request, internal_error, not_found, ok};
use remote_storage::TimeoutOrCancel;
use remote_storage::{DownloadError, DownloadOpts, GenericRemoteStorage, RemotePath};
use std::{sync::Arc, time::SystemTime, time::UNIX_EPOCH};
use tokio_util::sync::CancellationToken;
use tracing::{error, info};
use utils::backoff::retry;
pub fn app(state: Arc<Storage>) -> Router<()> {
use axum::routing::{delete as _delete, get as _get};
let delete_prefix = _delete(delete_prefix);
Router::new()
.route(
"/{tenant_id}/{timeline_id}/{endpoint_id}/{*path}",
_get(get).put(set).delete(delete),
)
.route(
"/{tenant_id}/{timeline_id}/{endpoint_id}",
delete_prefix.clone(),
)
.route("/{tenant_id}/{timeline_id}", delete_prefix.clone())
.route("/{tenant_id}", delete_prefix)
.route("/metrics", _get(metrics))
.route("/status", _get(async || StatusCode::OK.into_response()))
.with_state(state)
}
type Result = anyhow::Result<Response, Response>;
type State = axum::extract::State<Arc<Storage>>;
const CONTENT_TYPE: &str = "content-type";
const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
const WARN_THRESHOLD: u32 = 3;
const MAX_RETRIES: u32 = 10;
async fn metrics() -> Result {
prometheus::TextEncoder::new()
.encode_to_string(&prometheus::gather())
.map(|s| s.into_response())
.map_err(|e| internal_error(e, "/metrics", "collecting metrics"))
}
async fn get(S3Path { path }: S3Path, state: State) -> Result {
info!(%path, "downloading");
let download_err = |e| {
if let DownloadError::NotFound = e {
info!(%path, %e, "downloading"); // 404 is not an issue of _this_ service
return not_found(&path);
}
internal_error(e, &path, "downloading")
};
let cancel = state.cancel.clone();
let opts = &DownloadOpts::default();
let stream = retry(
async || state.storage.download(&path, opts, &cancel).await,
DownloadError::is_permanent,
WARN_THRESHOLD,
MAX_RETRIES,
"downloading",
&cancel,
)
.await
.unwrap_or(Err(DownloadError::Cancelled))
.map_err(download_err)?
.download_stream;
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from_stream(stream))
.map_err(|e| internal_error(e, path, "reading response"))
}
// Best solution for files is multipart upload, but remote_storage doesn't support it,
// so we can either read Bytes in memory and push at once or forward BodyDataStream to
// remote_storage. The latter may seem more peformant, but BodyDataStream doesn't have a
// guaranteed size() which may produce issues while uploading to s3.
// So, currently we're going with an in-memory copy plus a boundary to prevent uploading
// very large files.
async fn set(S3Path { path }: S3Path, state: State, bytes: Bytes) -> Result {
info!(%path, "uploading");
let request_len = bytes.len();
let max_len = state.max_upload_file_limit;
if request_len > max_len {
return Err(bad_request(
anyhow!("File size {request_len} exceeds max {max_len}"),
"uploading",
));
}
let cancel = state.cancel.clone();
let fun = async || {
let stream = bytes_to_stream(bytes.clone());
state
.storage
.upload(stream, request_len, &path, None, &cancel)
.await
};
retry(
fun,
TimeoutOrCancel::caused_by_cancel,
WARN_THRESHOLD,
MAX_RETRIES,
"uploading",
&cancel,
)
.await
.unwrap_or(Err(anyhow!("uploading cancelled")))
.map_err(|e| internal_error(e, path, "reading response"))?;
Ok(ok())
}
async fn delete(S3Path { path }: S3Path, state: State) -> Result {
info!(%path, "deleting");
let cancel = state.cancel.clone();
retry(
async || state.storage.delete(&path, &cancel).await,
TimeoutOrCancel::caused_by_cancel,
WARN_THRESHOLD,
MAX_RETRIES,
"deleting",
&cancel,
)
.await
.unwrap_or(Err(anyhow!("deleting cancelled")))
.map_err(|e| internal_error(e, path, "deleting"))?;
Ok(ok())
}
async fn delete_prefix(PrefixS3Path { path }: PrefixS3Path, state: State) -> Result {
info!(%path, "deleting prefix");
let cancel = state.cancel.clone();
retry(
async || state.storage.delete_prefix(&path, &cancel).await,
TimeoutOrCancel::caused_by_cancel,
WARN_THRESHOLD,
MAX_RETRIES,
"deleting prefix",
&cancel,
)
.await
.unwrap_or(Err(anyhow!("deleting prefix cancelled")))
.map_err(|e| internal_error(e, path, "deleting prefix"))?;
Ok(ok())
}
pub async fn check_storage_permissions(
client: &GenericRemoteStorage,
cancel: CancellationToken,
) -> anyhow::Result<()> {
info!("storage permissions check");
// as_nanos() as multiple instances proxying same bucket may be started at once
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)?
.as_nanos()
.to_string();
let path = RemotePath::from_string(&format!("write_access_{now}"))?;
info!(%path, "uploading");
let body = now.to_string();
let stream = bytes_to_stream(Bytes::from(body.clone()));
client
.upload(stream, body.len(), &path, None, &cancel)
.await?;
use tokio::io::AsyncReadExt;
info!(%path, "downloading");
let download_opts = DownloadOpts {
kind: remote_storage::DownloadKind::Small,
..Default::default()
};
let mut body_read_buf = Vec::new();
let stream = client
.download(&path, &download_opts, &cancel)
.await?
.download_stream;
tokio_util::io::StreamReader::new(stream)
.read_to_end(&mut body_read_buf)
.await?;
let body_read = String::from_utf8(body_read_buf)?;
if body != body_read {
error!(%body, %body_read, "File contents do not match");
anyhow::bail!("Read back file doesn't match original")
}
info!(%path, "removing");
client.delete(&path, &cancel).await
}
fn bytes_to_stream(bytes: Bytes) -> impl futures::Stream<Item = std::io::Result<Bytes>> {
futures::stream::once(futures::future::ready(Ok(bytes)))
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, extract::Request, response::Response};
use http_body_util::BodyExt;
use itertools::iproduct;
use std::env::var;
use std::sync::Arc;
use std::time::Duration;
use test_log::test as testlog;
use tower::{Service, util::ServiceExt};
use utils::id::{TenantId, TimelineId};
// see libs/remote_storage/tests/test_real_s3.rs
const REAL_S3_ENV: &str = "ENABLE_REAL_S3_REMOTE_STORAGE";
const REAL_S3_BUCKET: &str = "REMOTE_STORAGE_S3_BUCKET";
const REAL_S3_REGION: &str = "REMOTE_STORAGE_S3_REGION";
async fn proxy() -> (Storage, Option<camino_tempfile::Utf8TempDir>) {
let cancel = CancellationToken::new();
let (dir, storage) = if var(REAL_S3_ENV).is_err() {
// tests execute in parallel and we need a new directory for each of them
let dir = camino_tempfile::tempdir().unwrap();
let fs =
remote_storage::LocalFs::new(dir.path().into(), Duration::from_secs(5)).unwrap();
(Some(dir), GenericRemoteStorage::LocalFs(fs))
} else {
// test_real_s3::create_s3_client is hard to reference, reimplementing here
let millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis();
use rand::Rng;
let random = rand::thread_rng().r#gen::<u32>();
let s3_config = remote_storage::S3Config {
bucket_name: var(REAL_S3_BUCKET).unwrap(),
bucket_region: var(REAL_S3_REGION).unwrap(),
prefix_in_bucket: Some(format!("test_{millis}_{random:08x}/")),
endpoint: None,
concurrency_limit: std::num::NonZeroUsize::new(100).unwrap(),
max_keys_per_list_response: None,
upload_storage_class: None,
};
let bucket = remote_storage::S3Bucket::new(&s3_config, Duration::from_secs(1))
.await
.unwrap();
(None, GenericRemoteStorage::AwsS3(Arc::new(bucket)))
};
let proxy = Storage {
auth: object_storage::JwtAuth::new(TEST_PUB_KEY_ED25519).unwrap(),
storage,
cancel: cancel.clone(),
max_upload_file_limit: usize::MAX,
};
check_storage_permissions(&proxy.storage, cancel)
.await
.unwrap();
(proxy, dir)
}
// see libs/utils/src/auth.rs
const TEST_PUB_KEY_ED25519: &[u8] = b"
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEARYwaNBayR+eGI0iXB4s3QxE3Nl2g1iWbr6KtLWeVD/w=
-----END PUBLIC KEY-----
";
const TEST_PRIV_KEY_ED25519: &[u8] = br#"
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
-----END PRIVATE KEY-----
"#;
async fn request(req: Request<Body>) -> Response<Body> {
let (proxy, _) = proxy().await;
app(Arc::new(proxy))
.into_service()
.oneshot(req)
.await
.unwrap()
}
#[testlog(tokio::test)]
async fn status() {
let res = Request::builder()
.uri("/status")
.body(Body::empty())
.map(request)
.unwrap()
.await;
assert_eq!(res.status(), StatusCode::OK);
}
fn routes() -> impl Iterator<Item = (&'static str, &'static str)> {
iproduct!(
vec!["/1", "/1/2", "/1/2/3", "/1/2/3/4"],
vec!["GET", "PUT", "DELETE"]
)
}
#[testlog(tokio::test)]
async fn no_token() {
for (uri, method) in routes() {
info!(%uri, %method);
let res = Request::builder()
.uri(uri)
.method(method)
.body(Body::empty())
.map(request)
.unwrap()
.await;
assert!(matches!(
res.status(),
StatusCode::METHOD_NOT_ALLOWED | StatusCode::BAD_REQUEST
));
}
}
#[testlog(tokio::test)]
async fn invalid_token() {
for (uri, method) in routes() {
info!(%uri, %method);
let status = Request::builder()
.uri(uri)
.header("Authorization", "Bearer 123")
.method(method)
.body(Body::empty())
.map(request)
.unwrap()
.await;
assert!(matches!(
status.status(),
StatusCode::METHOD_NOT_ALLOWED | StatusCode::BAD_REQUEST
));
}
}
const TENANT_ID: TenantId =
TenantId::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6]);
const TIMELINE_ID: TimelineId =
TimelineId::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";
fn token() -> String {
let claims = object_storage::Claims {
tenant_id: TENANT_ID,
timeline_id: TIMELINE_ID,
endpoint_id: ENDPOINT_ID.into(),
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(object_storage::VALIDATION_ALGO);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}
#[testlog(tokio::test)]
async fn unauthorized() {
let (proxy, _) = proxy().await;
let mut app = app(Arc::new(proxy)).into_service();
let token = token();
let args = itertools::iproduct!(
vec![TENANT_ID.to_string(), TenantId::generate().to_string()],
vec![TIMELINE_ID.to_string(), TimelineId::generate().to_string()],
vec![ENDPOINT_ID, "ep-ololo"]
)
.skip(1);
for ((uri, method), (tenant, timeline, endpoint)) in iproduct!(routes(), args) {
info!(%uri, %method, %tenant, %timeline, %endpoint);
let request = Request::builder()
.uri(format!("/{tenant}/{timeline}/{endpoint}/sub/path/key"))
.method(method)
.header("Authorization", format!("Bearer {}", token))
.body(Body::empty())
.unwrap();
let status = ServiceExt::ready(&mut app)
.await
.unwrap()
.call(request)
.await
.unwrap()
.status();
assert_eq!(status, StatusCode::UNAUTHORIZED);
}
}
#[testlog(tokio::test)]
async fn method_not_allowed() {
let token = token();
let iter = iproduct!(vec!["", "/.."], vec!["GET", "PUT"]);
for (key, method) in iter {
let status = Request::builder()
.uri(format!("/{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}{key}"))
.method(method)
.header("Authorization", format!("Bearer {token}"))
.body(Body::empty())
.map(request)
.unwrap()
.await
.status();
assert!(matches!(
status,
StatusCode::BAD_REQUEST | StatusCode::METHOD_NOT_ALLOWED
));
}
}
async fn requests_chain(
chain: impl Iterator<Item = (String, &str, &'static str, StatusCode, bool)>,
token: impl Fn(&str) -> String,
) {
let (proxy, _) = proxy().await;
let mut app = app(Arc::new(proxy)).into_service();
for (uri, method, body, expected_status, compare_body) in chain {
info!(%uri, %method, %body, %expected_status);
let bearer = format!("Bearer {}", token(&uri));
let request = Request::builder()
.uri(uri)
.method(method)
.header("Authorization", &bearer)
.body(Body::from(body))
.unwrap();
let response = ServiceExt::ready(&mut app)
.await
.unwrap()
.call(request)
.await
.unwrap();
assert_eq!(response.status(), expected_status);
if !compare_body {
continue;
}
let read_body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(body, read_body);
}
}
#[testlog(tokio::test)]
async fn metrics() {
let uri = format!("/{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}/key");
let req = vec![
(uri.clone(), "PUT", "body", StatusCode::OK, false),
(uri.clone(), "DELETE", "", StatusCode::OK, false),
];
requests_chain(req.into_iter(), |_| token()).await;
let res = Request::builder()
.uri("/metrics")
.body(Body::empty())
.map(request)
.unwrap()
.await;
assert_eq!(res.status(), StatusCode::OK);
let body = res.into_body().collect().await.unwrap().to_bytes();
let body = String::from_utf8_lossy(&body);
tracing::debug!(%body);
// Storage metrics are not gathered for LocalFs
if var(REAL_S3_ENV).is_ok() {
assert!(body.contains("remote_storage_s3_deleted_objects_total"));
}
assert!(body.contains("process_threads"));
}
#[testlog(tokio::test)]
async fn insert_retrieve_remove() {
let uri = format!("/{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}/key");
let chain = vec![
(uri.clone(), "GET", "", StatusCode::NOT_FOUND, false),
(uri.clone(), "PUT", "пыщьпыщь", StatusCode::OK, false),
(uri.clone(), "GET", "пыщьпыщь", StatusCode::OK, true),
(uri.clone(), "DELETE", "", StatusCode::OK, false),
(uri, "GET", "", StatusCode::NOT_FOUND, false),
];
requests_chain(chain.into_iter(), |_| token()).await;
}
fn delete_prefix_token(uri: &str) -> String {
use serde::Serialize;
let parts = uri.split("/").collect::<Vec<&str>>();
#[derive(Serialize)]
struct PrefixClaims {
tenant_id: TenantId,
timeline_id: Option<TimelineId>,
endpoint_id: Option<object_storage::EndpointId>,
exp: u64,
}
let claims = PrefixClaims {
tenant_id: parts.get(1).map(|c| c.parse().unwrap()).unwrap(),
timeline_id: parts.get(2).map(|c| c.parse().unwrap()),
endpoint_id: parts.get(3).map(ToString::to_string),
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(object_storage::VALIDATION_ALGO);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}
// Can't use single digit numbers as they won't be validated as TimelineId and EndpointId
#[testlog(tokio::test)]
async fn delete_prefix() {
let tenant_id =
TenantId::from_array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).to_string();
let t2 = TimelineId::from_array([2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
let t3 = TimelineId::from_array([3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
let t4 = TimelineId::from_array([4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
let f = |timeline, path| format!("/{tenant_id}/{timeline}{path}");
// Why extra slash in string literals? Axum is weird with URIs:
// /1/2 and 1/2/ match different routes, thus first yields OK and second NOT_FOUND
// as it matches /tenant/timeline/endpoint, see https://stackoverflow.com/a/75355932
// The cost of removing trailing slash is suprisingly hard:
// * Add tower dependency with NormalizePath layer
// * wrap Router<()> in this layer https://github.com/tokio-rs/axum/discussions/2377
// * Rewrite make_service() -> into_make_service()
// * Rewrite oneshot() (not available for NormalizePath)
// I didn't manage to get it working correctly
let chain = vec![
// create 1/2/3/4, 1/2/3/5, delete prefix 1/2/3 -> empty
(f(t2, "/3/4"), "PUT", "", StatusCode::OK, false),
(f(t2, "/3/4"), "PUT", "", StatusCode::OK, false), // we can override file contents
(f(t2, "/3/5"), "PUT", "", StatusCode::OK, false),
(f(t2, "/3"), "DELETE", "", StatusCode::OK, false),
(f(t2, "/3/4"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t2, "/3/5"), "GET", "", StatusCode::NOT_FOUND, false),
// create 1/2/3/4, 1/2/5/6, delete prefix 1/2/3 -> 1/2/5/6
(f(t2, "/3/4"), "PUT", "", StatusCode::OK, false),
(f(t2, "/5/6"), "PUT", "", StatusCode::OK, false),
(f(t2, "/3"), "DELETE", "", StatusCode::OK, false),
(f(t2, "/3/4"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t2, "/5/6"), "GET", "", StatusCode::OK, false),
// create 1/2/3/4, 1/2/7/8, delete prefix 1/2 -> empty
(f(t2, "/3/4"), "PUT", "", StatusCode::OK, false),
(f(t2, "/7/8"), "PUT", "", StatusCode::OK, false),
(f(t2, ""), "DELETE", "", StatusCode::OK, false),
(f(t2, "/3/4"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t2, "/7/8"), "GET", "", StatusCode::NOT_FOUND, false),
// create 1/2/3/4, 1/2/5/6, 1/3/8/9, delete prefix 1/2/3 -> 1/2/5/6, 1/3/8/9
(f(t2, "/3/4"), "PUT", "", StatusCode::OK, false),
(f(t2, "/5/6"), "PUT", "", StatusCode::OK, false),
(f(t3, "/8/9"), "PUT", "", StatusCode::OK, false),
(f(t2, "/3"), "DELETE", "", StatusCode::OK, false),
(f(t2, "/3/4"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t2, "/5/6"), "GET", "", StatusCode::OK, false),
(f(t3, "/8/9"), "GET", "", StatusCode::OK, false),
// create 1/4/5/6, delete prefix 1/2 -> 1/3/8/9, 1/4/5/6
(f(t4, "/5/6"), "PUT", "", StatusCode::OK, false),
(f(t2, ""), "DELETE", "", StatusCode::OK, false),
(f(t2, "/3/4"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t2, "/5/6"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t3, "/8/9"), "GET", "", StatusCode::OK, false),
(f(t4, "/5/6"), "GET", "", StatusCode::OK, false),
// delete prefix 1 -> empty
(format!("/{tenant_id}"), "DELETE", "", StatusCode::OK, false),
(f(t2, "/3/4"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t2, "/5/6"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t3, "/8/9"), "GET", "", StatusCode::NOT_FOUND, false),
(f(t4, "/5/6"), "GET", "", StatusCode::NOT_FOUND, false),
];
requests_chain(chain.into_iter(), delete_prefix_token).await;
}
}

344
object_storage/src/lib.rs Normal file
View File

@@ -0,0 +1,344 @@
use anyhow::Result;
use axum::extract::{FromRequestParts, Path};
use axum::response::{IntoResponse, Response};
use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
use axum_extra::TypedHeader;
use axum_extra::headers::{Authorization, authorization::Bearer};
use camino::Utf8PathBuf;
use jsonwebtoken::{DecodingKey, Validation};
use remote_storage::{GenericRemoteStorage, RemotePath};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::result::Result as StdResult;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use utils::id::{TenantId, TimelineId};
// simplified version of utils::auth::JwtAuth
pub struct JwtAuth {
decoding_key: DecodingKey,
validation: Validation,
}
pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
impl JwtAuth {
pub fn new(key: &[u8]) -> Result<Self> {
Ok(Self {
decoding_key: DecodingKey::from_ed_pem(key)?,
validation: Validation::new(VALIDATION_ALGO),
})
}
pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
}
}
fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
let key = clean_utf8(&Utf8PathBuf::from(key));
if key.starts_with("..") || key == "." || key == "/" {
return Err(format!("invalid key {key}"));
}
match key.strip_prefix("/").map(Utf8PathBuf::from) {
Ok(p) => Ok(p),
_ => Ok(key),
}
}
// Copied from path_clean crate with PathBuf->Utf8PathBuf
fn clean_utf8(path: &camino::Utf8Path) -> Utf8PathBuf {
use camino::Utf8Component as Comp;
let mut out = Vec::new();
for comp in path.components() {
match comp {
Comp::CurDir => (),
Comp::ParentDir => match out.last() {
Some(Comp::RootDir) => (),
Some(Comp::Normal(_)) => {
out.pop();
}
None | Some(Comp::CurDir) | Some(Comp::ParentDir) | Some(Comp::Prefix(_)) => {
out.push(comp)
}
},
comp => out.push(comp),
}
}
if !out.is_empty() {
out.iter().collect()
} else {
Utf8PathBuf::from(".")
}
}
pub struct Storage {
pub auth: JwtAuth,
pub storage: GenericRemoteStorage,
pub cancel: CancellationToken,
pub max_upload_file_limit: usize,
}
pub type EndpointId = String; // If needed, reuse small string from proxy/src/types.rc
#[derive(Deserialize, Serialize, PartialEq)]
pub struct Claims {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub endpoint_id: EndpointId,
pub exp: u64,
}
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Claims(tenant_id {} timeline_id {} endpoint_id {} exp {})",
self.tenant_id, self.timeline_id, self.endpoint_id, self.exp
)
}
}
#[derive(Deserialize, Serialize)]
struct KeyRequest {
tenant_id: TenantId,
timeline_id: TimelineId,
endpoint_id: EndpointId,
path: String,
}
#[derive(Debug, PartialEq)]
pub struct S3Path {
pub path: RemotePath,
}
impl TryFrom<&KeyRequest> for S3Path {
type Error = String;
fn try_from(req: &KeyRequest) -> StdResult<Self, Self::Error> {
let KeyRequest {
tenant_id,
timeline_id,
endpoint_id,
path,
} = &req;
let prefix = format!("{tenant_id}/{timeline_id}/{endpoint_id}",);
let path = Utf8PathBuf::from(prefix).join(normalize_key(path)?);
let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
Ok(S3Path { path })
}
}
fn unauthorized(route: impl Display, claims: impl Display) -> Response {
debug!(%route, %claims, "route doesn't match claims");
StatusCode::UNAUTHORIZED.into_response()
}
pub fn bad_request(err: impl Display, desc: &'static str) -> Response {
debug!(%err, desc);
(StatusCode::BAD_REQUEST, err.to_string()).into_response()
}
pub fn ok() -> Response {
StatusCode::OK.into_response()
}
pub fn internal_error(err: impl Display, path: impl Display, desc: &'static str) -> Response {
error!(%err, %path, desc);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
pub fn not_found(key: impl ToString) -> Response {
(StatusCode::NOT_FOUND, key.to_string()).into_response()
}
impl FromRequestParts<Arc<Storage>> for S3Path {
type Rejection = Response;
async fn from_request_parts(
parts: &mut Parts,
state: &Arc<Storage>,
) -> Result<Self, Self::Rejection> {
let Path(path): Path<KeyRequest> = parts
.extract()
.await
.map_err(|e| bad_request(e, "invalid route"))?;
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| bad_request(e, "invalid token"))?;
let claims: Claims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "decoding token"))?;
let route = Claims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,
endpoint_id: path.endpoint_id.clone(),
exp: claims.exp,
};
if route != claims {
return Err(unauthorized(route, claims));
}
(&path)
.try_into()
.map_err(|e| bad_request(e, "invalid route"))
}
}
#[derive(Deserialize, Serialize, PartialEq)]
pub struct PrefixKeyPath {
pub tenant_id: TenantId,
pub timeline_id: Option<TimelineId>,
pub endpoint_id: Option<EndpointId>,
}
impl Display for PrefixKeyPath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PrefixKeyPath(tenant_id {} timeline_id {} endpoint_id {})",
self.tenant_id,
self.timeline_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string()),
self.endpoint_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string())
)
}
}
#[derive(Debug, PartialEq)]
pub struct PrefixS3Path {
pub path: RemotePath,
}
impl From<&PrefixKeyPath> for PrefixS3Path {
fn from(path: &PrefixKeyPath) -> Self {
let timeline_id = path
.timeline_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string());
let endpoint_id = path
.endpoint_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string());
let path = Utf8PathBuf::from(path.tenant_id.to_string())
.join(timeline_id)
.join(endpoint_id);
let path = RemotePath::new(&path).unwrap(); // unwrap() because the path is already relative
PrefixS3Path { path }
}
}
impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
type Rejection = Response;
async fn from_request_parts(
parts: &mut Parts,
state: &Arc<Storage>,
) -> Result<Self, Self::Rejection> {
let Path(path) = parts
.extract::<Path<PrefixKeyPath>>()
.await
.map_err(|e| bad_request(e, "invalid route"))?;
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| bad_request(e, "invalid token"))?;
let claims: PrefixKeyPath = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "invalid token"))?;
if path != claims {
return Err(unauthorized(path, claims));
}
Ok((&path).into())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn normalize_key() {
let f = super::normalize_key;
assert_eq!(f("hello/world/..").unwrap(), Utf8PathBuf::from("hello"));
assert_eq!(
f("ololo/1/../../not_ololo").unwrap(),
Utf8PathBuf::from("not_ololo")
);
assert!(f("ololo/1/../../../").is_err());
assert!(f(".").is_err());
assert!(f("../").is_err());
assert!(f("").is_err());
assert_eq!(f("/1/2/3").unwrap(), Utf8PathBuf::from("1/2/3"));
assert!(f("/1/2/3/../../../").is_err());
assert!(f("/1/2/3/../../../../").is_err());
}
const TENANT_ID: TenantId =
TenantId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6]);
const TIMELINE_ID: TimelineId =
TimelineId::from_array([1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";
#[test]
fn s3_path() {
let auth = Claims {
tenant_id: TENANT_ID,
timeline_id: TIMELINE_ID,
endpoint_id: ENDPOINT_ID.into(),
exp: u64::MAX,
};
let s3_path = |key| {
let path = &format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}/{key}");
let path = RemotePath::from_string(path).unwrap();
S3Path { path }
};
let path = "cache_key".to_string();
let mut key_path = KeyRequest {
path,
tenant_id: auth.tenant_id,
timeline_id: auth.timeline_id,
endpoint_id: auth.endpoint_id,
};
assert_eq!(S3Path::try_from(&key_path).unwrap(), s3_path(key_path.path));
key_path.path = "we/can/have/nested/paths".to_string();
assert_eq!(S3Path::try_from(&key_path).unwrap(), s3_path(key_path.path));
key_path.path = "../error/hello/../".to_string();
assert!(S3Path::try_from(&key_path).is_err());
}
#[test]
fn prefix_s3_path() {
let mut path = PrefixKeyPath {
tenant_id: TENANT_ID,
timeline_id: None,
endpoint_id: None,
};
let prefix_path = |s: String| RemotePath::from_string(&s).unwrap();
assert_eq!(
PrefixS3Path::from(&path).path,
prefix_path(format!("{TENANT_ID}"))
);
path.timeline_id = Some(TIMELINE_ID);
assert_eq!(
PrefixS3Path::from(&path).path,
prefix_path(format!("{TENANT_ID}/{TIMELINE_ID}"))
);
path.endpoint_id = Some(ENDPOINT_ID.into());
assert_eq!(
PrefixS3Path::from(&path).path,
prefix_path(format!("{TENANT_ID}/{TIMELINE_ID}/{ENDPOINT_ID}"))
);
}
}

View File

@@ -0,0 +1,65 @@
//! `object_storage` is a service which provides API for uploading and downloading
//! files. It is used by compute and control plane for accessing LFC prewarm data.
//! This service is deployed either as a separate component or as part of compute image
//! for large computes.
mod app;
use anyhow::Context;
use tracing::info;
use utils::logging;
//see set()
const fn max_upload_file_limit() -> usize {
100 * 1024 * 1024
}
#[derive(serde::Deserialize)]
#[serde(tag = "type")]
struct Config {
listen: std::net::SocketAddr,
pemfile: camino::Utf8PathBuf,
#[serde(flatten)]
storage_config: remote_storage::RemoteStorageConfig,
#[serde(default = "max_upload_file_limit")]
max_upload_file_limit: usize,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
logging::init(
logging::LogFormat::Plain,
logging::TracingErrorLayerEnablement::EnableWithRustLogFilter,
logging::Output::Stdout,
)?;
let config: String = std::env::args().skip(1).take(1).collect();
if config.is_empty() {
anyhow::bail!("Usage: object_storage config.json")
}
info!("Reading config from {config}");
let config = std::fs::read_to_string(config.clone())?;
let config: Config = serde_json::from_str(&config).context("parsing config")?;
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
let auth = object_storage::JwtAuth::new(&pemfile)?;
let listener = tokio::net::TcpListener::bind(config.listen).await.unwrap();
info!("listening on {}", listener.local_addr().unwrap());
let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
let cancel = tokio_util::sync::CancellationToken::new();
app::check_storage_permissions(&storage, cancel.clone()).await?;
let proxy = std::sync::Arc::new(object_storage::Storage {
auth,
storage,
cancel: cancel.clone(),
max_upload_file_limit: config.max_upload_file_limit,
});
tokio::spawn(utils::signals::signal_handler(cancel.clone()));
axum::serve(listener, app::app(proxy))
.with_graceful_shutdown(async move { cancel.cancelled().await })
.await?;
Ok(())
}