This commit is contained in:
Mikhail Kot
2025-05-19 21:22:00 +01:00
parent 5b62749c42
commit 8de4636c2e
14 changed files with 1091 additions and 4 deletions

180
Cargo.lock generated
View File

@@ -29,6 +29,41 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aes-gcm"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]]
name = "ahash"
version = "0.8.11"
@@ -753,6 +788,7 @@ dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-util",
"headers",
"http 1.1.0",
@@ -1173,6 +1209,16 @@ dependencies = [
"half",
]
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
]
[[package]]
name = "clang-sys"
version = "1.6.1"
@@ -1464,6 +1510,21 @@ dependencies = [
"workspace_hack",
]
[[package]]
name = "cookie"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
dependencies = [
"aes-gcm",
"base64 0.22.1",
"percent-encoding",
"rand 0.8.5",
"subtle",
"time",
"version_check",
]
[[package]]
name = "core-foundation"
version = "0.9.3"
@@ -1657,9 +1718,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"rand_core 0.6.4",
"typenum",
]
[[package]]
name = "ctr"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
dependencies = [
"cipher",
]
[[package]]
name = "curve25519-dalek"
version = "4.1.3"
@@ -2510,6 +2581,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "ghash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]]
name = "gimli"
version = "0.31.1"
@@ -3281,6 +3362,15 @@ dependencies = [
"libc",
]
[[package]]
name = "inout"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
dependencies = [
"generic-array",
]
[[package]]
name = "instant"
version = "0.1.12"
@@ -3794,6 +3884,15 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "nanoid"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8"
dependencies = [
"rand 0.8.5",
]
[[package]]
name = "neon-shmem"
version = "0.1.0"
@@ -4066,6 +4165,12 @@ version = "11.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openssl-probe"
version = "0.1.5"
@@ -4585,6 +4690,31 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
[[package]]
name = "paster"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"axum-extra",
"base64 0.13.1",
"chrono",
"nanoid",
"rand 0.8.5",
"reqwest",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"serde",
"serde_json",
"time",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tracing",
"tracing-subscriber",
"workspace_hack",
]
[[package]]
name = "pbkdf2"
version = "0.12.2"
@@ -4757,6 +4887,18 @@ dependencies = [
"never-say-never",
]
[[package]]
name = "polyval"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
dependencies = [
"cfg-if",
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "portable-atomic"
version = "1.10.0"
@@ -6559,6 +6701,32 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "shortener"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"axum-extra",
"base64 0.13.1",
"chrono",
"cookie",
"nanoid",
"rand 0.8.5",
"reqwest",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"serde",
"serde_json",
"time",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tracing",
"tracing-subscriber",
"workspace_hack",
]
[[package]]
name = "signal-hook"
version = "0.3.15"
@@ -7927,6 +8095,16 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -8559,6 +8737,7 @@ dependencies = [
"anyhow",
"axum",
"axum-core",
"axum-extra",
"base64 0.13.1",
"base64 0.21.7",
"base64ct",
@@ -8570,6 +8749,7 @@ dependencies = [
"clap_builder",
"const-oid",
"crypto-bigint 0.5.5",
"crypto-common",
"der 0.7.8",
"deranged",
"digest",

View File

@@ -13,6 +13,8 @@ members = [
"proxy",
"safekeeper",
"safekeeper/client",
"shortener",
"paster",
"storage_broker",
"storage_controller",
"storage_controller/client",

25
paster/Cargo.toml Normal file
View File

@@ -0,0 +1,25 @@
[package]
name = "paster"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
axum.workspace = true
base64.workspace = true
chrono.workspace = true
nanoid = { version = "0.4.0", default-features = false }
rand.workspace = true
reqwest.workspace = true
rustls-native-certs.workspace = true
rustls.workspace = true
serde.workspace = true
serde_json.workspace = true
time = { version = "0.3.36", default-features = false }
tokio-postgres-rustls.workspace = true
tokio-postgres.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
workspace_hack.workspace = true

View File

@@ -0,0 +1,18 @@
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
sub VARCHAR(100) NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS sessions (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL UNIQUE REFERENCES users(id),
session_id VARCHAR NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);
CREATE TABLE IF NOT EXISTS pastes (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL REFERENCES users(id),
paste text NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)

353
paster/src/main.rs Normal file
View File

@@ -0,0 +1,353 @@
//! Paster is a service to share logs or code snippets outside of
//! Slack, not relying on public services
use anyhow::Result;
use shortener::google_oauth_gate::{AuthRequest, State, UserId};
use axum::Form;
use axum::extract::{FromRef, FromRequestParts, Path, Query, State as AxumStateT};
use axum::http::StatusCode;
use axum::response::{Html, IntoResponse};
use axum::response::{Redirect, Response};
use axum::routing::get;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::{Cookie, Key};
use chrono::{Duration, Local, TimeZone, Utc};
use core::num::NonZeroI32;
use serde::Deserialize;
use std::env;
use std::sync::Arc;
use tracing::{error, info};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
const SOCKET: &str = "127.0.0.1:12344";
const HOST: &str = "http://127.0.0.1:12344";
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
fn oauth_redirect_url() -> String {
format!("{HOST}{AUTHORIZED_ROUTE}")
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
roots.add(cert).unwrap();
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
info!("initialized TLS");
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
tokio::spawn(async move {
if let Err(err) = db_conn.await {
error!(%err, "connecting to database");
std::process::exit(1);
}
});
info!("connected to database");
let state = InnerState {
db_client,
cookie_jar_key: Key::generate(),
oauth_client_id,
oauth_client_secret,
};
let router = axum::Router::new()
.route("/", get(index).post(paste))
.route("/authorize", get(authorize))
.route(AUTHORIZED_ROUTE, get(authorized))
.route("/{id}", get(view_paste))
.with_state(State { 0: Arc::new(state) });
let listener = tokio::net::TcpListener::bind(SOCKET)
.await
.expect("failed to bind TcpListener");
info!("listening on {SOCKET}");
axum::serve(listener, router).await.unwrap();
Ok(())
}
#[derive(Deserialize)]
pub struct UserId {
id: NonZeroI32,
}
impl axum::extract::OptionalFromRequestParts<State> for UserId {
type Rejection = Response;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &State,
) -> Result<Option<Self>, Self::Rejection> {
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
.await
.unwrap(); // infallible
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
return Ok(None);
};
let client = &state.db_client;
let query = client
.query_opt(
"SELECT user_id FROM sessions WHERE session_id = $1",
&[&session_id],
)
.await;
let id = match query {
Ok(Some(row)) => row.get::<usize, i32>(0),
Ok(None) => return Ok(None),
Err(err) => {
error!(%err, "querying user session");
return Ok(None);
}
};
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
Ok(Some(Self { id }))
}
}
#[derive(Deserialize)]
struct Paste {
paste: String,
}
fn paste_form() -> Html<String> {
Html(
r#"
<form method="post">
<textarea name="paste" style="width:100%;height:80%"></textarea>
<input type="submit" value="Paste" style="margin-top:10px">
</form>"#
.to_string(),
)
}
fn authorize_link(paste_id: i32) -> String {
format!("<a href=\"/authorize?paste_id={paste_id}\">Authorize</a>")
}
async fn index(user: Option<UserId>) -> Html<String> {
if user.is_some() {
return paste_form();
}
Html(authorize_link(0))
}
async fn paste(
state: AxumState,
user: Option<UserId>,
Form(Paste { paste }): Form<Paste>,
) -> Response {
let user_id = match user {
None => return StatusCode::FORBIDDEN.into_response(),
Some(user) => user.id,
};
if paste.is_empty() {
return paste_form().into_response();
}
let query = state
.db_client
.query_one(
"INSERT INTO pastes (user_id, paste) VALUES ($1, $2) RETURNING id",
&[&user_id.get(), &paste],
)
.await;
let id = match query {
Ok(row) => row.get::<usize, i32>(0),
Err(err) => {
error!(%err, "inserting paste");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
Redirect::to(&format!("/{id}")).into_response()
}
async fn view_paste(state: AxumState, user: Option<UserId>, Path(paste_id): Path<i32>) -> Response {
let user_id = match user {
None => return Html(authorize_link(paste_id)).into_response(),
Some(user) => user.id,
};
let query = state
.db_client
.query_opt("SELECT paste FROM pastes WHERE id = $1", &[&paste_id])
.await;
let row = match query {
Ok(None) => return StatusCode::NOT_FOUND.into_response(),
Ok(Some(row)) => row,
Err(err) => {
error!(%err, %paste_id, %user_id, "querying paste");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
row.get::<usize, String>(0).into_response()
}
#[derive(Deserialize)]
struct AuthRequest {
code: String,
}
#[derive(Deserialize)]
struct AuthResponse {
access_token: String,
id_token: String,
expires_in: u64,
}
#[derive(Deserialize)]
struct UserInfo {
hd: String,
sub: String,
}
fn decode_id_token(token: String) -> Option<UserInfo> {
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
serde_json::from_slice::<UserInfo>(&decoded).ok()
}
#[derive(Deserialize)]
struct AuthorizeQuery {
paste_id: i32,
}
fn generate_csrf_token(num_bytes: u32) -> String {
use rand::{Rng, thread_rng};
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
}
async fn authorize(
state: AxumState,
jar: PrivateCookieJar,
Query(AuthorizeQuery { paste_id }): Query<AuthorizeQuery>,
) -> (PrivateCookieJar, Redirect) {
let csrf_token = generate_csrf_token(16);
let client_id = &state.oauth_client_id;
let redirect_uri = oauth_redirect_url();
let auth_url = format!(
"{OAUTH_BASE_URL}?response_type=code\
&client_id={client_id}\
&state={csrf_token}\
&redirect_uri={redirect_uri}\
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
);
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, paste_id.to_string()))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
let csrf_cookie = Cookie::build((COOKIE_CSRF, csrf_token))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
let jar = jar.add(redirect_cookie).add(csrf_cookie);
let url = Into::<String>::into(auth_url);
(jar, Redirect::to(&url))
}
async fn authorized(
state: AxumState,
jar: PrivateCookieJar,
Query(auth_request): Query<AuthRequest>,
) -> Result<(PrivateCookieJar, Redirect), Response> {
let params = [
("grant_type", "authorization_code"),
("redirect_uri", &oauth_redirect_url()),
("code", &auth_request.code),
("client_id", &state.oauth_client_id),
("client_secret", &state.oauth_client_secret),
];
let auth_response = reqwest::Client::new()
.post(OAUTH_TOKEN_URL)
.form(&params)
.send()
.await
.map_err(|err| {
error!(%err, "exchanging oauth code for token");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
})?
.json::<AuthResponse>()
.await
.map_err(|err| {
error!(%err, "deserializing access token response");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
})?;
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
error!("Failed to decode response id token");
return Err(StatusCode::UNAUTHORIZED.into_response());
};
if hd != ALLOWED_OAUTH_DOMAIN {
error!(hd, "Domain doesn't match {ALLOWED_OAUTH_DOMAIN}");
return Err(StatusCode::UNAUTHORIZED.into_response());
}
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
let session_cookie = Cookie::build((COOKIE_SID, auth_response.access_token.clone()))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.max_age(cookie_max_age)
.build();
state
.db_client
.query(
"WITH user_insert AS (\
INSERT INTO users (sub) VALUES ($1) \
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
INSERT INTO sessions (user_id, session_id, expires_at) \
SELECT id, $2, $3 FROM user_insert \
ON CONFLICT (user_id) DO UPDATE SET \
session_id = excluded.session_id, \
expires_at = excluded.expires_at",
&[&sub, &auth_response.access_token, &expires_at],
)
.await
.map_err(|err| {
error!(%err, %sub, "updating session");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
})?;
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
let jar = jar.remove(csrf_cookie).add(session_cookie);
match jar.get(COOKIE_REDIRECT) {
Some(redirect_cookie) => {
let mut value = redirect_cookie.value_trimmed();
if value == "0" {
value = "";
}
let redirect_url = format!("/{value}");
Ok((jar.remove(redirect_cookie), Redirect::to(&redirect_url)))
}
None => Ok((jar, Redirect::to("/"))),
}
}

26
shortener/Cargo.toml Normal file
View File

@@ -0,0 +1,26 @@
[package]
name = "shortener"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
axum.workspace = true
base64.workspace = true
chrono.workspace = true
cookie = "0.18.1"
nanoid = { version = "0.4.0", default-features = false }
rand.workspace = true
reqwest.workspace = true
rustls-native-certs.workspace = true
rustls.workspace = true
serde.workspace = true
serde_json.workspace = true
time = { version = "0.3.36", default-features = false }
tokio-postgres-rustls.workspace = true
tokio-postgres.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
workspace_hack.workspace = true

View File

@@ -0,0 +1,19 @@
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
sub VARCHAR(100) NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS sessions (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL UNIQUE REFERENCES users(id),
session_id VARCHAR NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);
CREATE TABLE IF NOT EXISTS urls (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL REFERENCES users(id),
short_url VARCHAR(6) NOT NULL UNIQUE,
long_url VARCHAR NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)

View File

@@ -0,0 +1,222 @@
//! Library to gate infrastructure behind Google Oauth for domain.
//!
//! Why not oauth-rs? Oauth .exchange_code() doesn't work with "request failed". Also, we can't get
//! id token from it, and I don't want to pull in whole openid library just for that.
//! Id token saves us a request to openid endpoint and one Oauth scope we don't use
use anyhow::{Context, Result, bail};
use axum::extract::{FromRef, FromRequestParts, Query, State as AxumState};
use axum::response::Redirect;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::{Cookie, Key};
use chrono::{Duration, Local, TimeZone, Utc};
use cookie::CookieBuilder;
use core::num::NonZeroI32;
use reqwest::StatusCode;
use serde::Deserialize;
use std::sync::Arc;
use tokio_postgres::Socket;
const OAUTH_BASE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
const OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
const COOKIE_SID: &str = "sid";
const COOKIE_CSRF: &str = "csrf";
pub struct Config {
pub oauth_client_id: String,
pub oauth_client_secret: String,
pub oauth_redirect_url: String,
pub oauth_allowed_domain: String,
pub cookie_settings: fn(CookieBuilder) -> CookieBuilder,
}
pub struct InnerState {
config: Config,
cookie_jar_key: Key,
pub db_client: tokio_postgres::Client,
}
#[derive(Clone)]
pub struct State(Arc<InnerState>);
type DbConn = tokio_postgres::Connection<Socket, tokio_postgres_rustls::RustlsStream<Socket>>;
impl State {
pub async fn new(config: Config, db_connstr: &str) -> Result<(Self, DbConn)> {
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
roots.add(cert).unwrap();
}
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
let inner = InnerState {
config,
cookie_jar_key: Key::generate(),
db_client,
};
Ok((Self { 0: Arc::new(inner) }, db_conn))
}
}
impl std::ops::Deref for State {
type Target = InnerState;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
impl FromRef<State> for Key {
fn from_ref(state: &State) -> Self {
state.cookie_jar_key.clone()
}
}
#[derive(Deserialize)]
pub struct UserId {
pub id: NonZeroI32,
}
#[derive(Deserialize)]
pub struct AuthRequest {
code: String,
}
#[derive(Deserialize)]
struct AuthResponse {
access_token: String,
id_token: String,
expires_in: u64,
}
#[derive(Deserialize)]
struct UserInfo {
hd: String,
sub: String,
}
impl axum::extract::OptionalFromRequestParts<State> for UserId {
type Rejection = StatusCode;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &State,
) -> Result<Option<Self>, Self::Rejection> {
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
.await
.unwrap(); // infallible
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
return Ok(None);
};
let client = &state.db_client;
let query = client
.query_opt(
"SELECT user_id FROM sessions WHERE session_id = $1",
&[&session_id],
)
.await;
let id = match query {
Ok(Some(row)) => row.get::<usize, i32>(0),
Ok(None) => return Ok(None),
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
Ok(Some(Self { id }))
}
}
fn decode_id_token(token: String) -> Option<UserInfo> {
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
serde_json::from_slice::<UserInfo>(&decoded).ok()
}
fn generate_csrf_token(num_bytes: u32) -> String {
use rand::{Rng, thread_rng};
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
}
pub async fn authorize(
state: AxumState<State>,
jar: PrivateCookieJar,
) -> (PrivateCookieJar, Redirect) {
let csrf_token = generate_csrf_token(16);
let client_id = &state.config.oauth_client_id;
let redirect_uri = &state.config.oauth_redirect_url;
let auth_url = format!(
"{OAUTH_BASE_URL}?response_type=code\
&client_id={client_id}\
&state={csrf_token}\
&redirect_uri={redirect_uri}\
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
);
let csrf_cookie =
(state.config.cookie_settings)(Cookie::build((COOKIE_CSRF, csrf_token))).build();
let url = Into::<String>::into(auth_url);
(jar.add(csrf_cookie), Redirect::to(&url))
}
pub async fn authorized(
state: AxumState<State>,
jar: PrivateCookieJar,
Query(auth_request): Query<AuthRequest>,
) -> Result<PrivateCookieJar> {
let params = [
("grant_type", "authorization_code"),
("redirect_uri", &state.config.oauth_redirect_url),
("code", &auth_request.code),
("client_id", &state.config.oauth_client_id),
("client_secret", &state.config.oauth_client_secret),
];
let auth_response = reqwest::Client::new()
.post(OAUTH_TOKEN_URL)
.form(&params)
.send()
.await
.context("exchanging oauth code for token")?
.json::<AuthResponse>()
.await
.context("deserializing access_token response")?;
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
bail!("failed to decode id token")
};
let allowed_domain = &state.config.oauth_allowed_domain;
if hd != *allowed_domain {
bail!("{hd} doesn't match {allowed_domain}")
}
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
let session_cookie = (state.config.cookie_settings)(Cookie::build((
COOKIE_SID,
auth_response.access_token.clone(),
)))
.max_age(cookie_max_age)
.build();
state
.db_client
.query(
"WITH user_insert AS (\
INSERT INTO users (sub) VALUES ($1) \
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
INSERT INTO sessions (user_id, session_id, expires_at) \
SELECT id, $2, $3 FROM user_insert \
ON CONFLICT (user_id) DO UPDATE SET \
session_id = excluded.session_id, \
expires_at = excluded.expires_at",
&[&sub, &auth_response.access_token, &expires_at],
)
.await
.with_context(|| format!("updating session for {sub}"))?;
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
Ok(jar.remove(csrf_cookie).add(session_cookie))
}

240
shortener/src/main.rs Normal file
View File

@@ -0,0 +1,240 @@
//! Shortener is a service to gate access to internal infrastructure
//! URLs behind team authorisation to expose less private information.
pub mod google_oauth_gate;
use crate::google_oauth_gate::{AuthRequest, State, UserId};
use anyhow::Result;
use axum::Form;
use axum::extract::State as AxumState;
use axum::extract::{Path, Query};
use axum::http::StatusCode;
use axum::response::{Html, IntoResponse};
use axum::response::{Redirect, Response};
use axum::routing::get;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::Cookie;
use cookie::CookieBuilder;
use google_oauth_gate::Config;
use serde::Deserialize;
use std::env;
use tracing::{error, info};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
const SOCKET: &str = "127.0.0.1:12344";
const HOST: &str = "http://127.0.0.1:12344";
const COOKIE_REDIRECT: &str = "redirect";
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
const AUTHORIZED_ROUTE: &str = "/authorized";
const SHORT_URL_LEN: usize = 6;
fn cookie_settings(b: CookieBuilder) -> CookieBuilder {
if HOST.contains("127.0.0.1") {
b.path("/")
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
} else {
b.path("/")
.domain(ALLOWED_OAUTH_DOMAIN)
.secure(true)
.http_only(false)
}
}
fn oauth_redirect_url() -> String {
format!("{HOST}{AUTHORIZED_ROUTE}")
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
let config = Config {
oauth_client_id,
oauth_client_secret,
oauth_redirect_url: oauth_redirect_url(),
oauth_allowed_domain: ALLOWED_OAUTH_DOMAIN.to_string(),
cookie_settings,
};
let (state, db_conn) = State::new(config, &db_connstr).await?;
tokio::spawn(async move {
if let Err(err) = db_conn.await {
error!(%err, "connecting to database");
std::process::exit(1);
}
});
let router = axum::Router::new()
.route("/", get(index).post(shorten))
.route("/authorize", get(authorize))
.route(AUTHORIZED_ROUTE, get(authorized))
.route("/{short_url}", get(redirect))
.with_state(state);
let listener = tokio::net::TcpListener::bind(SOCKET)
.await
.expect("failed to bind TcpListener");
info!("listening on {SOCKET}");
axum::serve(listener, router).await.unwrap();
Ok(())
}
#[derive(Deserialize)]
struct LongUrl {
url: String,
}
fn shorten_form(short_url: &str) -> Html<String> {
let mut form = r#"
<div style="margin:auto;width:50%;padding:10px">
<form method="post">
<input type="text" name="url" style="width:100%">
<input type="submit" value="Shorten" style="margin-top:10px">
</form>"#
.to_string();
if !short_url.is_empty() {
form += &format!(
r#"
<p>
<a id="short" href="{0}">{0}</a>
<button onclick="copy()">Copy</button>
</p>
<script>
function copy() {{
navigator.clipboard.writeText(document.querySelector("\#short").textContent);
}}
</script>"#,
short_url
);
}
form += "</div>";
Html(form)
}
fn authorize_link(short_url: &str) -> Html<String> {
Html(format!(
"<a href=\"/authorize?short_url={short_url}\">Authorize</a>"
))
}
async fn index(user: Option<UserId>) -> Html<String> {
if user.is_some() {
return shorten_form("");
}
authorize_link("")
}
async fn shorten(
state: AxumState<State>,
user: Option<UserId>,
Form(LongUrl { url }): Form<LongUrl>,
) -> Response {
let user_id = match user {
None => return StatusCode::FORBIDDEN.into_response(),
Some(user) => user.id.get(),
};
if url.is_empty() {
return shorten_form("").into_response();
}
let mut short_url = "".to_string();
for i in 0..20 {
short_url = nanoid::nanoid!(SHORT_URL_LEN);
let query = state
.db_client
.query_opt(
"INSERT INTO urls (user_id, short_url, long_url) VALUES ($1, $2, $3) \
ON CONFLICT (short_url) DO NOTHING \
RETURNING short_url",
&[&user_id, &short_url, &url],
)
.await;
match query {
Ok(Some(_)) => break,
Ok(None) => {
info!(short_url, "url clash, retry {i}");
continue;
}
Err(err) => {
error!(%err, "inserting shortened url");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
}
shorten_form(&format!("{HOST}/{short_url}")).into_response()
}
async fn redirect(
state: AxumState<State>,
user: Option<UserId>,
Path(short_url): Path<String>,
) -> Response {
let user_id = match user {
None => return authorize_link(&short_url).into_response(),
Some(user) => user.id,
};
let query = state
.db_client
.query_opt(
"SELECT long_url FROM urls WHERE short_url = $1",
&[&short_url],
)
.await;
match query {
Ok(Some(row)) => Redirect::permanent(row.get(0)).into_response(),
Ok(None) => StatusCode::NOT_FOUND.into_response(),
Err(err) => {
error!(%err, %short_url, %user_id, "querying long url");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
#[derive(Deserialize)]
struct AuthorizeQuery {
short_url: String,
}
async fn authorize(
state: AxumState<State>,
jar: PrivateCookieJar,
Query(AuthorizeQuery { short_url }): Query<AuthorizeQuery>,
) -> (PrivateCookieJar, Redirect) {
let (jar, auth_redirect) = google_oauth_gate::authorize(state, jar).await;
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, short_url))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
(jar.add(redirect_cookie), auth_redirect)
}
async fn authorized(
state: AxumState<State>,
jar: PrivateCookieJar,
query: Query<AuthRequest>,
) -> Result<(PrivateCookieJar, Redirect), Response> {
use google_oauth_gate::authorized;
let jar = authorized(state, jar, query).await.map_err(|err| {
error!(%err, "authorizing");
return StatusCode::UNAUTHORIZED.into_response();
})?;
let Some(redirect_cookie) = jar.get(COOKIE_REDIRECT) else {
return Ok((jar, Redirect::to("/")));
};
let redirect_url = Redirect::to(&format!("/{}", redirect_cookie.value_trimmed()));
Ok((jar.remove(redirect_cookie), redirect_url))
}

View File

@@ -20,6 +20,7 @@ anstream = { version = "0.6" }
anyhow = { version = "1", features = ["backtrace"] }
axum = { version = "0.8", features = ["ws"] }
axum-core = { version = "0.5", default-features = false, features = ["tracing"] }
axum-extra = { version = "0.10", features = ["cookie-private", "typed-header"] }
base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] }
base64-647d43efb71741da = { package = "base64", version = "0.21" }
base64ct = { version = "1", default-features = false, features = ["std"] }
@@ -30,6 +31,7 @@ clap = { version = "4", features = ["derive", "env", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
const-oid = { version = "0.9", default-features = false, features = ["db", "std"] }
crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] }
crypto-common = { version = "0.1", default-features = false, features = ["getrandom", "std"] }
der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] }
deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] }
digest = { version = "0.10", features = ["mac", "oid", "std"] }