From 8de4636c2ed66041cf77b35c44da5ac4cec0f8aa Mon Sep 17 00:00:00 2001 From: Mikhail Kot Date: Mon, 19 May 2025 21:22:00 +0100 Subject: [PATCH] initial --- Cargo.lock | 180 +++++++++++++++ Cargo.toml | 2 + paster/Cargo.toml | 25 ++ paster/migrations/1_initial.sql | 18 ++ paster/src/main.rs | 353 +++++++++++++++++++++++++++++ shortener/Cargo.toml | 26 +++ shortener/migrations/1_initial.sql | 19 ++ shortener/src/google_oauth_gate.rs | 222 ++++++++++++++++++ shortener/src/main.rs | 240 ++++++++++++++++++++ vendor/postgres-v14 | 2 +- vendor/postgres-v15 | 2 +- vendor/postgres-v16 | 2 +- vendor/postgres-v17 | 2 +- workspace_hack/Cargo.toml | 2 + 14 files changed, 1091 insertions(+), 4 deletions(-) create mode 100644 paster/Cargo.toml create mode 100644 paster/migrations/1_initial.sql create mode 100644 paster/src/main.rs create mode 100644 shortener/Cargo.toml create mode 100644 shortener/migrations/1_initial.sql create mode 100644 shortener/src/google_oauth_gate.rs create mode 100644 shortener/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 89351432c1..5e7ad7646f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,41 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.8.11" @@ -753,6 +788,7 @@ dependencies = [ "axum", "axum-core", "bytes", + "cookie", "futures-util", "headers", "http 1.1.0", @@ -1173,6 +1209,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" version = "1.6.1" @@ -1464,6 +1510,21 @@ dependencies = [ "workspace_hack", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "aes-gcm", + "base64 0.22.1", + "percent-encoding", + "rand 0.8.5", + "subtle", + "time", + "version_check", +] + [[package]] name = "core-foundation" version = "0.9.3" @@ -1657,9 +1718,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "curve25519-dalek" version = "4.1.3" @@ -2510,6 +2581,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.31.1" @@ -3281,6 +3362,15 @@ dependencies = [ "libc", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "instant" version = "0.1.12" @@ -3794,6 +3884,15 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" +[[package]] +name = "nanoid" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8" +dependencies = [ + "rand 0.8.5", +] + [[package]] name = "neon-shmem" version = "0.1.0" @@ -4066,6 +4165,12 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "openssl-probe" version = "0.1.5" @@ -4585,6 +4690,31 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "paster" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum", + "axum-extra", + "base64 0.13.1", + "chrono", + "nanoid", + "rand 0.8.5", + "reqwest", + "rustls 0.23.27", + "rustls-native-certs 0.8.0", + "serde", + "serde_json", + "time", + "tokio", + "tokio-postgres", + "tokio-postgres-rustls", + "tracing", + "tracing-subscriber", + "workspace_hack", +] + [[package]] name = "pbkdf2" version = "0.12.2" @@ -4757,6 +4887,18 @@ dependencies = [ "never-say-never", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.10.0" @@ -6559,6 +6701,32 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "shortener" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum", + "axum-extra", + "base64 0.13.1", + "chrono", + "cookie", + "nanoid", + "rand 0.8.5", + "reqwest", + "rustls 0.23.27", + "rustls-native-certs 0.8.0", + "serde", + "serde_json", + "time", + "tokio", + "tokio-postgres", + "tokio-postgres-rustls", + "tracing", + "tracing-subscriber", + "workspace_hack", +] + [[package]] name = "signal-hook" version = "0.3.15" @@ -7927,6 +8095,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.9.0" @@ -8559,6 +8737,7 @@ dependencies = [ "anyhow", "axum", "axum-core", + "axum-extra", "base64 0.13.1", "base64 0.21.7", "base64ct", @@ -8570,6 +8749,7 @@ dependencies = [ "clap_builder", "const-oid", "crypto-bigint 0.5.5", + "crypto-common", "der 0.7.8", "deranged", "digest", diff --git a/Cargo.toml b/Cargo.toml index a040010fb7..5b3b92b71b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,8 @@ members = [ "proxy", "safekeeper", "safekeeper/client", + "shortener", + "paster", "storage_broker", "storage_controller", "storage_controller/client", diff --git a/paster/Cargo.toml b/paster/Cargo.toml new file mode 100644 index 0000000000..ed8db04e6e --- /dev/null +++ b/paster/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "paster" +version = "0.1.0" +edition.workspace = true +license.workspace = true +[dependencies] +anyhow.workspace = true +axum-extra = { workspace = true, features = ["cookie", "cookie-private"] } +axum.workspace = true +base64.workspace = true +chrono.workspace = true +nanoid = { version = "0.4.0", default-features = false } +rand.workspace = true +reqwest.workspace = true +rustls-native-certs.workspace = true +rustls.workspace = true +serde.workspace = true +serde_json.workspace = true +time = { version = "0.3.36", default-features = false } +tokio-postgres-rustls.workspace = true +tokio-postgres.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true +workspace_hack.workspace = true diff --git a/paster/migrations/1_initial.sql b/paster/migrations/1_initial.sql new file mode 100644 index 0000000000..d65d2b0f88 --- /dev/null +++ b/paster/migrations/1_initial.sql @@ -0,0 +1,18 @@ +CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + sub VARCHAR(100) NOT NULL UNIQUE +); + +CREATE TABLE IF NOT EXISTS sessions ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL UNIQUE REFERENCES users(id), + session_id VARCHAR NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL +); + +CREATE TABLE IF NOT EXISTS pastes ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL REFERENCES users(id), + paste text NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +) diff --git a/paster/src/main.rs b/paster/src/main.rs new file mode 100644 index 0000000000..0de0923ddf --- /dev/null +++ b/paster/src/main.rs @@ -0,0 +1,353 @@ +//! Paster is a service to share logs or code snippets outside of +//! Slack, not relying on public services +use anyhow::Result; +use shortener::google_oauth_gate::{AuthRequest, State, UserId}; +use axum::Form; +use axum::extract::{FromRef, FromRequestParts, Path, Query, State as AxumStateT}; +use axum::http::StatusCode; +use axum::response::{Html, IntoResponse}; +use axum::response::{Redirect, Response}; +use axum::routing::get; +use axum_extra::extract::PrivateCookieJar; +use axum_extra::extract::cookie::{Cookie, Key}; +use chrono::{Duration, Local, TimeZone, Utc}; +use core::num::NonZeroI32; +use serde::Deserialize; +use std::env; +use std::sync::Arc; +use tracing::{error, info}; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +const SOCKET: &str = "127.0.0.1:12344"; +const HOST: &str = "http://127.0.0.1:12344"; +const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech"; + +fn oauth_redirect_url() -> String { + format!("{HOST}{AUTHORIZED_ROUTE}") +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID"); + let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET"); + + let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR"); + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { + roots.add(cert).unwrap(); + } + let config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config); + info!("initialized TLS"); + + let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?; + tokio::spawn(async move { + if let Err(err) = db_conn.await { + error!(%err, "connecting to database"); + std::process::exit(1); + } + }); + info!("connected to database"); + + let state = InnerState { + db_client, + cookie_jar_key: Key::generate(), + oauth_client_id, + oauth_client_secret, + }; + let router = axum::Router::new() + .route("/", get(index).post(paste)) + .route("/authorize", get(authorize)) + .route(AUTHORIZED_ROUTE, get(authorized)) + .route("/{id}", get(view_paste)) + .with_state(State { 0: Arc::new(state) }); + let listener = tokio::net::TcpListener::bind(SOCKET) + .await + .expect("failed to bind TcpListener"); + info!("listening on {SOCKET}"); + axum::serve(listener, router).await.unwrap(); + Ok(()) +} + +#[derive(Deserialize)] +pub struct UserId { + id: NonZeroI32, +} + +impl axum::extract::OptionalFromRequestParts for UserId { + type Rejection = Response; + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &State, + ) -> Result, Self::Rejection> { + let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state) + .await + .unwrap(); // infallible + let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else { + return Ok(None); + }; + + let client = &state.db_client; + let query = client + .query_opt( + "SELECT user_id FROM sessions WHERE session_id = $1", + &[&session_id], + ) + .await; + let id = match query { + Ok(Some(row)) => row.get::(0), + Ok(None) => return Ok(None), + Err(err) => { + error!(%err, "querying user session"); + return Ok(None); + } + }; + let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero + Ok(Some(Self { id })) + } +} + +#[derive(Deserialize)] +struct Paste { + paste: String, +} + +fn paste_form() -> Html { + Html( + r#" +
+ + +
"# + .to_string(), + ) +} + +fn authorize_link(paste_id: i32) -> String { + format!("Authorize") +} + +async fn index(user: Option) -> Html { + if user.is_some() { + return paste_form(); + } + Html(authorize_link(0)) +} + +async fn paste( + state: AxumState, + user: Option, + Form(Paste { paste }): Form, +) -> Response { + let user_id = match user { + None => return StatusCode::FORBIDDEN.into_response(), + Some(user) => user.id, + }; + if paste.is_empty() { + return paste_form().into_response(); + } + + let query = state + .db_client + .query_one( + "INSERT INTO pastes (user_id, paste) VALUES ($1, $2) RETURNING id", + &[&user_id.get(), &paste], + ) + .await; + let id = match query { + Ok(row) => row.get::(0), + Err(err) => { + error!(%err, "inserting paste"); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + Redirect::to(&format!("/{id}")).into_response() +} + +async fn view_paste(state: AxumState, user: Option, Path(paste_id): Path) -> Response { + let user_id = match user { + None => return Html(authorize_link(paste_id)).into_response(), + Some(user) => user.id, + }; + + let query = state + .db_client + .query_opt("SELECT paste FROM pastes WHERE id = $1", &[&paste_id]) + .await; + let row = match query { + Ok(None) => return StatusCode::NOT_FOUND.into_response(), + Ok(Some(row)) => row, + Err(err) => { + error!(%err, %paste_id, %user_id, "querying paste"); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + row.get::(0).into_response() +} + +#[derive(Deserialize)] +struct AuthRequest { + code: String, +} + +#[derive(Deserialize)] +struct AuthResponse { + access_token: String, + id_token: String, + expires_in: u64, +} + +#[derive(Deserialize)] +struct UserInfo { + hd: String, + sub: String, +} + +fn decode_id_token(token: String) -> Option { + let payload = token.split(".").skip(1).take(1).collect::>(); + let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?; + serde_json::from_slice::(&decoded).ok() +} + +#[derive(Deserialize)] +struct AuthorizeQuery { + paste_id: i32, +} + +fn generate_csrf_token(num_bytes: u32) -> String { + use rand::{Rng, thread_rng}; + let random_bytes: Vec = (0..num_bytes).map(|_| thread_rng().r#gen::()).collect(); + base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD) +} + +async fn authorize( + state: AxumState, + jar: PrivateCookieJar, + Query(AuthorizeQuery { paste_id }): Query, +) -> (PrivateCookieJar, Redirect) { + let csrf_token = generate_csrf_token(16); + let client_id = &state.oauth_client_id; + let redirect_uri = oauth_redirect_url(); + let auth_url = format!( + "{OAUTH_BASE_URL}?response_type=code\ + &client_id={client_id}\ + &state={csrf_token}\ + &redirect_uri={redirect_uri}\ + &scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email" + ); + + let redirect_cookie = Cookie::build((COOKIE_REDIRECT, paste_id.to_string())) + .path("/") + //.TODO secure(true) not true for localhost + //.domain(COOKIE_DOMAIN) + .secure(false) + .same_site(axum_extra::extract::cookie::SameSite::Lax) + .http_only(true) + .build(); + let csrf_cookie = Cookie::build((COOKIE_CSRF, csrf_token)) + .path("/") + //.TODO secure(true) not true for localhost + //.domain(COOKIE_DOMAIN) + .secure(false) + .same_site(axum_extra::extract::cookie::SameSite::Lax) + .http_only(true) + .build(); + let jar = jar.add(redirect_cookie).add(csrf_cookie); + let url = Into::::into(auth_url); + (jar, Redirect::to(&url)) +} + +async fn authorized( + state: AxumState, + jar: PrivateCookieJar, + Query(auth_request): Query, +) -> Result<(PrivateCookieJar, Redirect), Response> { + let params = [ + ("grant_type", "authorization_code"), + ("redirect_uri", &oauth_redirect_url()), + ("code", &auth_request.code), + ("client_id", &state.oauth_client_id), + ("client_secret", &state.oauth_client_secret), + ]; + let auth_response = reqwest::Client::new() + .post(OAUTH_TOKEN_URL) + .form(¶ms) + .send() + .await + .map_err(|err| { + error!(%err, "exchanging oauth code for token"); + StatusCode::INTERNAL_SERVER_ERROR.into_response() + })? + .json::() + .await + .map_err(|err| { + error!(%err, "deserializing access token response"); + StatusCode::INTERNAL_SERVER_ERROR.into_response() + })?; + let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else { + error!("Failed to decode response id token"); + return Err(StatusCode::UNAUTHORIZED.into_response()); + }; + if hd != ALLOWED_OAUTH_DOMAIN { + error!(hd, "Domain doesn't match {ALLOWED_OAUTH_DOMAIN}"); + return Err(StatusCode::UNAUTHORIZED.into_response()); + } + + let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap(); + let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration)); + let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0); + + let session_cookie = Cookie::build((COOKIE_SID, auth_response.access_token.clone())) + .path("/") + //.TODO secure(true) not true for localhost + //.domain(COOKIE_DOMAIN) + .secure(false) + .same_site(axum_extra::extract::cookie::SameSite::Lax) + .http_only(true) + .max_age(cookie_max_age) + .build(); + + state + .db_client + .query( + "WITH user_insert AS (\ + INSERT INTO users (sub) VALUES ($1) \ + ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\ + INSERT INTO sessions (user_id, session_id, expires_at) \ + SELECT id, $2, $3 FROM user_insert \ + ON CONFLICT (user_id) DO UPDATE SET \ + session_id = excluded.session_id, \ + expires_at = excluded.expires_at", + &[&sub, &auth_response.access_token, &expires_at], + ) + .await + .map_err(|err| { + error!(%err, %sub, "updating session"); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + })?; + + let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize() + let jar = jar.remove(csrf_cookie).add(session_cookie); + match jar.get(COOKIE_REDIRECT) { + Some(redirect_cookie) => { + let mut value = redirect_cookie.value_trimmed(); + if value == "0" { + value = ""; + } + let redirect_url = format!("/{value}"); + Ok((jar.remove(redirect_cookie), Redirect::to(&redirect_url))) + } + None => Ok((jar, Redirect::to("/"))), + } +} diff --git a/shortener/Cargo.toml b/shortener/Cargo.toml new file mode 100644 index 0000000000..2dce9acd06 --- /dev/null +++ b/shortener/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "shortener" +version = "0.1.0" +edition.workspace = true +license.workspace = true +[dependencies] +anyhow.workspace = true +axum-extra = { workspace = true, features = ["cookie", "cookie-private"] } +axum.workspace = true +base64.workspace = true +chrono.workspace = true +cookie = "0.18.1" +nanoid = { version = "0.4.0", default-features = false } +rand.workspace = true +reqwest.workspace = true +rustls-native-certs.workspace = true +rustls.workspace = true +serde.workspace = true +serde_json.workspace = true +time = { version = "0.3.36", default-features = false } +tokio-postgres-rustls.workspace = true +tokio-postgres.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true +workspace_hack.workspace = true diff --git a/shortener/migrations/1_initial.sql b/shortener/migrations/1_initial.sql new file mode 100644 index 0000000000..65ebabc5d4 --- /dev/null +++ b/shortener/migrations/1_initial.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + sub VARCHAR(100) NOT NULL UNIQUE +); + +CREATE TABLE IF NOT EXISTS sessions ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL UNIQUE REFERENCES users(id), + session_id VARCHAR NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL +); + +CREATE TABLE IF NOT EXISTS urls ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL REFERENCES users(id), + short_url VARCHAR(6) NOT NULL UNIQUE, + long_url VARCHAR NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +) diff --git a/shortener/src/google_oauth_gate.rs b/shortener/src/google_oauth_gate.rs new file mode 100644 index 0000000000..7ee2534a28 --- /dev/null +++ b/shortener/src/google_oauth_gate.rs @@ -0,0 +1,222 @@ +//! Library to gate infrastructure behind Google Oauth for domain. +//! +//! Why not oauth-rs? Oauth .exchange_code() doesn't work with "request failed". Also, we can't get +//! id token from it, and I don't want to pull in whole openid library just for that. +//! Id token saves us a request to openid endpoint and one Oauth scope we don't use +use anyhow::{Context, Result, bail}; +use axum::extract::{FromRef, FromRequestParts, Query, State as AxumState}; +use axum::response::Redirect; +use axum_extra::extract::PrivateCookieJar; +use axum_extra::extract::cookie::{Cookie, Key}; +use chrono::{Duration, Local, TimeZone, Utc}; +use cookie::CookieBuilder; +use core::num::NonZeroI32; +use reqwest::StatusCode; +use serde::Deserialize; +use std::sync::Arc; +use tokio_postgres::Socket; + +const OAUTH_BASE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth"; +const OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token"; +const COOKIE_SID: &str = "sid"; +const COOKIE_CSRF: &str = "csrf"; + +pub struct Config { + pub oauth_client_id: String, + pub oauth_client_secret: String, + pub oauth_redirect_url: String, + pub oauth_allowed_domain: String, + pub cookie_settings: fn(CookieBuilder) -> CookieBuilder, +} + +pub struct InnerState { + config: Config, + cookie_jar_key: Key, + pub db_client: tokio_postgres::Client, +} + +#[derive(Clone)] +pub struct State(Arc); +type DbConn = tokio_postgres::Connection>; + +impl State { + pub async fn new(config: Config, db_connstr: &str) -> Result<(Self, DbConn)> { + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + roots.add(cert).unwrap(); + } + let tls_config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config); + + let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?; + let inner = InnerState { + config, + cookie_jar_key: Key::generate(), + db_client, + }; + Ok((Self { 0: Arc::new(inner) }, db_conn)) + } +} + +impl std::ops::Deref for State { + type Target = InnerState; + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +impl FromRef for Key { + fn from_ref(state: &State) -> Self { + state.cookie_jar_key.clone() + } +} + +#[derive(Deserialize)] +pub struct UserId { + pub id: NonZeroI32, +} + +#[derive(Deserialize)] +pub struct AuthRequest { + code: String, +} + +#[derive(Deserialize)] +struct AuthResponse { + access_token: String, + id_token: String, + expires_in: u64, +} + +#[derive(Deserialize)] +struct UserInfo { + hd: String, + sub: String, +} + +impl axum::extract::OptionalFromRequestParts for UserId { + type Rejection = StatusCode; + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &State, + ) -> Result, Self::Rejection> { + let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state) + .await + .unwrap(); // infallible + let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else { + return Ok(None); + }; + + let client = &state.db_client; + let query = client + .query_opt( + "SELECT user_id FROM sessions WHERE session_id = $1", + &[&session_id], + ) + .await; + let id = match query { + Ok(Some(row)) => row.get::(0), + Ok(None) => return Ok(None), + Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR), + }; + let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero + Ok(Some(Self { id })) + } +} + +fn decode_id_token(token: String) -> Option { + let payload = token.split(".").skip(1).take(1).collect::>(); + let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?; + serde_json::from_slice::(&decoded).ok() +} + +fn generate_csrf_token(num_bytes: u32) -> String { + use rand::{Rng, thread_rng}; + let random_bytes: Vec = (0..num_bytes).map(|_| thread_rng().r#gen::()).collect(); + base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD) +} + +pub async fn authorize( + state: AxumState, + jar: PrivateCookieJar, +) -> (PrivateCookieJar, Redirect) { + let csrf_token = generate_csrf_token(16); + let client_id = &state.config.oauth_client_id; + let redirect_uri = &state.config.oauth_redirect_url; + let auth_url = format!( + "{OAUTH_BASE_URL}?response_type=code\ + &client_id={client_id}\ + &state={csrf_token}\ + &redirect_uri={redirect_uri}\ + &scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email" + ); + + let csrf_cookie = + (state.config.cookie_settings)(Cookie::build((COOKIE_CSRF, csrf_token))).build(); + let url = Into::::into(auth_url); + (jar.add(csrf_cookie), Redirect::to(&url)) +} + +pub async fn authorized( + state: AxumState, + jar: PrivateCookieJar, + Query(auth_request): Query, +) -> Result { + let params = [ + ("grant_type", "authorization_code"), + ("redirect_uri", &state.config.oauth_redirect_url), + ("code", &auth_request.code), + ("client_id", &state.config.oauth_client_id), + ("client_secret", &state.config.oauth_client_secret), + ]; + let auth_response = reqwest::Client::new() + .post(OAUTH_TOKEN_URL) + .form(¶ms) + .send() + .await + .context("exchanging oauth code for token")? + .json::() + .await + .context("deserializing access_token response")?; + let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else { + bail!("failed to decode id token") + }; + + let allowed_domain = &state.config.oauth_allowed_domain; + if hd != *allowed_domain { + bail!("{hd} doesn't match {allowed_domain}") + } + + let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap(); + let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration)); + let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0); + + let session_cookie = (state.config.cookie_settings)(Cookie::build(( + COOKIE_SID, + auth_response.access_token.clone(), + ))) + .max_age(cookie_max_age) + .build(); + + state + .db_client + .query( + "WITH user_insert AS (\ + INSERT INTO users (sub) VALUES ($1) \ + ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\ + INSERT INTO sessions (user_id, session_id, expires_at) \ + SELECT id, $2, $3 FROM user_insert \ + ON CONFLICT (user_id) DO UPDATE SET \ + session_id = excluded.session_id, \ + expires_at = excluded.expires_at", + &[&sub, &auth_response.access_token, &expires_at], + ) + .await + .with_context(|| format!("updating session for {sub}"))?; + + let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize() + Ok(jar.remove(csrf_cookie).add(session_cookie)) +} diff --git a/shortener/src/main.rs b/shortener/src/main.rs new file mode 100644 index 0000000000..24e229ac3d --- /dev/null +++ b/shortener/src/main.rs @@ -0,0 +1,240 @@ +//! Shortener is a service to gate access to internal infrastructure +//! URLs behind team authorisation to expose less private information. +pub mod google_oauth_gate; +use crate::google_oauth_gate::{AuthRequest, State, UserId}; +use anyhow::Result; +use axum::Form; +use axum::extract::State as AxumState; +use axum::extract::{Path, Query}; +use axum::http::StatusCode; +use axum::response::{Html, IntoResponse}; +use axum::response::{Redirect, Response}; +use axum::routing::get; +use axum_extra::extract::PrivateCookieJar; +use axum_extra::extract::cookie::Cookie; +use cookie::CookieBuilder; +use google_oauth_gate::Config; +use serde::Deserialize; +use std::env; +use tracing::{error, info}; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +const SOCKET: &str = "127.0.0.1:12344"; +const HOST: &str = "http://127.0.0.1:12344"; +const COOKIE_REDIRECT: &str = "redirect"; +const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech"; +const AUTHORIZED_ROUTE: &str = "/authorized"; +const SHORT_URL_LEN: usize = 6; + +fn cookie_settings(b: CookieBuilder) -> CookieBuilder { + if HOST.contains("127.0.0.1") { + b.path("/") + .secure(false) + .same_site(axum_extra::extract::cookie::SameSite::Lax) + .http_only(true) + } else { + b.path("/") + .domain(ALLOWED_OAUTH_DOMAIN) + .secure(true) + .http_only(false) + } +} + +fn oauth_redirect_url() -> String { + format!("{HOST}{AUTHORIZED_ROUTE}") +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID"); + let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET"); + let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR"); + + let config = Config { + oauth_client_id, + oauth_client_secret, + oauth_redirect_url: oauth_redirect_url(), + oauth_allowed_domain: ALLOWED_OAUTH_DOMAIN.to_string(), + cookie_settings, + }; + let (state, db_conn) = State::new(config, &db_connstr).await?; + tokio::spawn(async move { + if let Err(err) = db_conn.await { + error!(%err, "connecting to database"); + std::process::exit(1); + } + }); + + let router = axum::Router::new() + .route("/", get(index).post(shorten)) + .route("/authorize", get(authorize)) + .route(AUTHORIZED_ROUTE, get(authorized)) + .route("/{short_url}", get(redirect)) + .with_state(state); + let listener = tokio::net::TcpListener::bind(SOCKET) + .await + .expect("failed to bind TcpListener"); + info!("listening on {SOCKET}"); + axum::serve(listener, router).await.unwrap(); + Ok(()) +} + +#[derive(Deserialize)] +struct LongUrl { + url: String, +} + +fn shorten_form(short_url: &str) -> Html { + let mut form = r#" +
+
+ + +
"# + .to_string(); + if !short_url.is_empty() { + form += &format!( + r#" +

+ {0} + +

+ "#, + short_url + ); + } + form += "
"; + Html(form) +} + +fn authorize_link(short_url: &str) -> Html { + Html(format!( + "Authorize" + )) +} + +async fn index(user: Option) -> Html { + if user.is_some() { + return shorten_form(""); + } + authorize_link("") +} + +async fn shorten( + state: AxumState, + user: Option, + Form(LongUrl { url }): Form, +) -> Response { + let user_id = match user { + None => return StatusCode::FORBIDDEN.into_response(), + Some(user) => user.id.get(), + }; + if url.is_empty() { + return shorten_form("").into_response(); + } + + let mut short_url = "".to_string(); + for i in 0..20 { + short_url = nanoid::nanoid!(SHORT_URL_LEN); + let query = state + .db_client + .query_opt( + "INSERT INTO urls (user_id, short_url, long_url) VALUES ($1, $2, $3) \ + ON CONFLICT (short_url) DO NOTHING \ + RETURNING short_url", + &[&user_id, &short_url, &url], + ) + .await; + match query { + Ok(Some(_)) => break, + Ok(None) => { + info!(short_url, "url clash, retry {i}"); + continue; + } + Err(err) => { + error!(%err, "inserting shortened url"); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; + } + shorten_form(&format!("{HOST}/{short_url}")).into_response() +} + +async fn redirect( + state: AxumState, + user: Option, + Path(short_url): Path, +) -> Response { + let user_id = match user { + None => return authorize_link(&short_url).into_response(), + Some(user) => user.id, + }; + + let query = state + .db_client + .query_opt( + "SELECT long_url FROM urls WHERE short_url = $1", + &[&short_url], + ) + .await; + match query { + Ok(Some(row)) => Redirect::permanent(row.get(0)).into_response(), + Ok(None) => StatusCode::NOT_FOUND.into_response(), + Err(err) => { + error!(%err, %short_url, %user_id, "querying long url"); + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } +} + +#[derive(Deserialize)] +struct AuthorizeQuery { + short_url: String, +} + +async fn authorize( + state: AxumState, + jar: PrivateCookieJar, + Query(AuthorizeQuery { short_url }): Query, +) -> (PrivateCookieJar, Redirect) { + let (jar, auth_redirect) = google_oauth_gate::authorize(state, jar).await; + let redirect_cookie = Cookie::build((COOKIE_REDIRECT, short_url)) + .path("/") + //.TODO secure(true) not true for localhost + //.domain(COOKIE_DOMAIN) + .secure(false) + .same_site(axum_extra::extract::cookie::SameSite::Lax) + .http_only(true) + .build(); + (jar.add(redirect_cookie), auth_redirect) +} + +async fn authorized( + state: AxumState, + jar: PrivateCookieJar, + query: Query, +) -> Result<(PrivateCookieJar, Redirect), Response> { + use google_oauth_gate::authorized; + let jar = authorized(state, jar, query).await.map_err(|err| { + error!(%err, "authorizing"); + return StatusCode::UNAUTHORIZED.into_response(); + })?; + let Some(redirect_cookie) = jar.get(COOKIE_REDIRECT) else { + return Ok((jar, Redirect::to("/"))); + }; + let redirect_url = Redirect::to(&format!("/{}", redirect_cookie.value_trimmed())); + Ok((jar.remove(redirect_cookie), redirect_url)) +} diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 55c0d45abe..b6eece3f52 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 55c0d45abe6467c02084c2192bca117eda6ce1e7 +Subproject commit b6eece3f528fdc380e6e2c13381434470606787f diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index de7640f55d..20f8491225 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit de7640f55da07512834d5cc40c4b3fb376b5f04f +Subproject commit 20f8491225f86bdedbc986e9a69ebafb1c94aa99 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 0bf96bd6d7..77c63bfebf 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 0bf96bd6d70301a0b43b0b3457bb3cf8fb43c198 +Subproject commit 77c63bfebff5c833682cc2654e2191fec4d5b24e diff --git a/vendor/postgres-v17 b/vendor/postgres-v17 index 8be779fd3a..32d704d965 160000 --- a/vendor/postgres-v17 +++ b/vendor/postgres-v17 @@ -1 +1 @@ -Subproject commit 8be779fd3ab9e87206da96a7e4842ef1abf04f44 +Subproject commit 32d704d965d8ad632c0ddef64b45a5ba95536442 diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 2b07889871..56cc900fe1 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -20,6 +20,7 @@ anstream = { version = "0.6" } anyhow = { version = "1", features = ["backtrace"] } axum = { version = "0.8", features = ["ws"] } axum-core = { version = "0.5", default-features = false, features = ["tracing"] } +axum-extra = { version = "0.10", features = ["cookie-private", "typed-header"] } base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] } base64-647d43efb71741da = { package = "base64", version = "0.21" } base64ct = { version = "1", default-features = false, features = ["std"] } @@ -30,6 +31,7 @@ clap = { version = "4", features = ["derive", "env", "string"] } clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] } const-oid = { version = "0.9", default-features = false, features = ["db", "std"] } crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] } +crypto-common = { version = "0.1", default-features = false, features = ["getrandom", "std"] } der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] } deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] } digest = { version = "0.10", features = ["mac", "oid", "std"] }