mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 08:52:56 +00:00
initial
This commit is contained in:
180
Cargo.lock
generated
180
Cargo.lock
generated
@@ -29,6 +29,41 @@ version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
|
||||
|
||||
[[package]]
|
||||
name = "aead"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aes"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aes-gcm"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
|
||||
dependencies = [
|
||||
"aead",
|
||||
"aes",
|
||||
"cipher",
|
||||
"ctr",
|
||||
"ghash",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.11"
|
||||
@@ -753,6 +788,7 @@ dependencies = [
|
||||
"axum",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"cookie",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http 1.1.0",
|
||||
@@ -1173,6 +1209,16 @@ dependencies = [
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cipher"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"inout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clang-sys"
|
||||
version = "1.6.1"
|
||||
@@ -1464,6 +1510,21 @@ dependencies = [
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cookie"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
|
||||
dependencies = [
|
||||
"aes-gcm",
|
||||
"base64 0.22.1",
|
||||
"percent-encoding",
|
||||
"rand 0.8.5",
|
||||
"subtle",
|
||||
"time",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.3"
|
||||
@@ -1657,9 +1718,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"rand_core 0.6.4",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctr"
|
||||
version = "0.9.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
|
||||
dependencies = [
|
||||
"cipher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "curve25519-dalek"
|
||||
version = "4.1.3"
|
||||
@@ -2510,6 +2581,16 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ghash"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
|
||||
dependencies = [
|
||||
"opaque-debug",
|
||||
"polyval",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.31.1"
|
||||
@@ -3281,6 +3362,15 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inout"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "instant"
|
||||
version = "0.1.12"
|
||||
@@ -3794,6 +3884,15 @@ version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
|
||||
|
||||
[[package]]
|
||||
name = "nanoid"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8"
|
||||
dependencies = [
|
||||
"rand 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "neon-shmem"
|
||||
version = "0.1.0"
|
||||
@@ -4066,6 +4165,12 @@ version = "11.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.1.5"
|
||||
@@ -4585,6 +4690,31 @@ version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
|
||||
|
||||
[[package]]
|
||||
name = "paster"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"axum-extra",
|
||||
"base64 0.13.1",
|
||||
"chrono",
|
||||
"nanoid",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"rustls 0.23.27",
|
||||
"rustls-native-certs 0.8.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pbkdf2"
|
||||
version = "0.12.2"
|
||||
@@ -4757,6 +4887,18 @@ dependencies = [
|
||||
"never-say-never",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "polyval"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"opaque-debug",
|
||||
"universal-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.10.0"
|
||||
@@ -6559,6 +6701,32 @@ version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "shortener"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"axum-extra",
|
||||
"base64 0.13.1",
|
||||
"chrono",
|
||||
"cookie",
|
||||
"nanoid",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"rustls 0.23.27",
|
||||
"rustls-native-certs 0.8.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook"
|
||||
version = "0.3.15"
|
||||
@@ -7927,6 +8095,16 @@ version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
|
||||
|
||||
[[package]]
|
||||
name = "universal-hash"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
@@ -8559,6 +8737,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"axum-core",
|
||||
"axum-extra",
|
||||
"base64 0.13.1",
|
||||
"base64 0.21.7",
|
||||
"base64ct",
|
||||
@@ -8570,6 +8749,7 @@ dependencies = [
|
||||
"clap_builder",
|
||||
"const-oid",
|
||||
"crypto-bigint 0.5.5",
|
||||
"crypto-common",
|
||||
"der 0.7.8",
|
||||
"deranged",
|
||||
"digest",
|
||||
|
||||
@@ -13,6 +13,8 @@ members = [
|
||||
"proxy",
|
||||
"safekeeper",
|
||||
"safekeeper/client",
|
||||
"shortener",
|
||||
"paster",
|
||||
"storage_broker",
|
||||
"storage_controller",
|
||||
"storage_controller/client",
|
||||
|
||||
25
paster/Cargo.toml
Normal file
25
paster/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "paster"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
|
||||
axum.workspace = true
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
nanoid = { version = "0.4.0", default-features = false }
|
||||
rand.workspace = true
|
||||
reqwest.workspace = true
|
||||
rustls-native-certs.workspace = true
|
||||
rustls.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
time = { version = "0.3.36", default-features = false }
|
||||
tokio-postgres-rustls.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
18
paster/migrations/1_initial.sql
Normal file
18
paster/migrations/1_initial.sql
Normal file
@@ -0,0 +1,18 @@
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
sub VARCHAR(100) NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT NOT NULL UNIQUE REFERENCES users(id),
|
||||
session_id VARCHAR NOT NULL,
|
||||
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pastes (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT NOT NULL REFERENCES users(id),
|
||||
paste text NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
353
paster/src/main.rs
Normal file
353
paster/src/main.rs
Normal file
@@ -0,0 +1,353 @@
|
||||
//! Paster is a service to share logs or code snippets outside of
|
||||
//! Slack, not relying on public services
|
||||
use anyhow::Result;
|
||||
use shortener::google_oauth_gate::{AuthRequest, State, UserId};
|
||||
use axum::Form;
|
||||
use axum::extract::{FromRef, FromRequestParts, Path, Query, State as AxumStateT};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{Html, IntoResponse};
|
||||
use axum::response::{Redirect, Response};
|
||||
use axum::routing::get;
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use axum_extra::extract::cookie::{Cookie, Key};
|
||||
use chrono::{Duration, Local, TimeZone, Utc};
|
||||
use core::num::NonZeroI32;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
const SOCKET: &str = "127.0.0.1:12344";
|
||||
const HOST: &str = "http://127.0.0.1:12344";
|
||||
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
|
||||
|
||||
fn oauth_redirect_url() -> String {
|
||||
format!("{HOST}{AUTHORIZED_ROUTE}")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
|
||||
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
|
||||
|
||||
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
|
||||
roots.add(cert).unwrap();
|
||||
}
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth();
|
||||
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
|
||||
info!("initialized TLS");
|
||||
|
||||
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = db_conn.await {
|
||||
error!(%err, "connecting to database");
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
info!("connected to database");
|
||||
|
||||
let state = InnerState {
|
||||
db_client,
|
||||
cookie_jar_key: Key::generate(),
|
||||
oauth_client_id,
|
||||
oauth_client_secret,
|
||||
};
|
||||
let router = axum::Router::new()
|
||||
.route("/", get(index).post(paste))
|
||||
.route("/authorize", get(authorize))
|
||||
.route(AUTHORIZED_ROUTE, get(authorized))
|
||||
.route("/{id}", get(view_paste))
|
||||
.with_state(State { 0: Arc::new(state) });
|
||||
let listener = tokio::net::TcpListener::bind(SOCKET)
|
||||
.await
|
||||
.expect("failed to bind TcpListener");
|
||||
info!("listening on {SOCKET}");
|
||||
axum::serve(listener, router).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UserId {
|
||||
id: NonZeroI32,
|
||||
}
|
||||
|
||||
impl axum::extract::OptionalFromRequestParts<State> for UserId {
|
||||
type Rejection = Response;
|
||||
async fn from_request_parts(
|
||||
parts: &mut axum::http::request::Parts,
|
||||
state: &State,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
|
||||
.await
|
||||
.unwrap(); // infallible
|
||||
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let client = &state.db_client;
|
||||
let query = client
|
||||
.query_opt(
|
||||
"SELECT user_id FROM sessions WHERE session_id = $1",
|
||||
&[&session_id],
|
||||
)
|
||||
.await;
|
||||
let id = match query {
|
||||
Ok(Some(row)) => row.get::<usize, i32>(0),
|
||||
Ok(None) => return Ok(None),
|
||||
Err(err) => {
|
||||
error!(%err, "querying user session");
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
|
||||
Ok(Some(Self { id }))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Paste {
|
||||
paste: String,
|
||||
}
|
||||
|
||||
fn paste_form() -> Html<String> {
|
||||
Html(
|
||||
r#"
|
||||
<form method="post">
|
||||
<textarea name="paste" style="width:100%;height:80%"></textarea>
|
||||
<input type="submit" value="Paste" style="margin-top:10px">
|
||||
</form>"#
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
fn authorize_link(paste_id: i32) -> String {
|
||||
format!("<a href=\"/authorize?paste_id={paste_id}\">Authorize</a>")
|
||||
}
|
||||
|
||||
async fn index(user: Option<UserId>) -> Html<String> {
|
||||
if user.is_some() {
|
||||
return paste_form();
|
||||
}
|
||||
Html(authorize_link(0))
|
||||
}
|
||||
|
||||
async fn paste(
|
||||
state: AxumState,
|
||||
user: Option<UserId>,
|
||||
Form(Paste { paste }): Form<Paste>,
|
||||
) -> Response {
|
||||
let user_id = match user {
|
||||
None => return StatusCode::FORBIDDEN.into_response(),
|
||||
Some(user) => user.id,
|
||||
};
|
||||
if paste.is_empty() {
|
||||
return paste_form().into_response();
|
||||
}
|
||||
|
||||
let query = state
|
||||
.db_client
|
||||
.query_one(
|
||||
"INSERT INTO pastes (user_id, paste) VALUES ($1, $2) RETURNING id",
|
||||
&[&user_id.get(), &paste],
|
||||
)
|
||||
.await;
|
||||
let id = match query {
|
||||
Ok(row) => row.get::<usize, i32>(0),
|
||||
Err(err) => {
|
||||
error!(%err, "inserting paste");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
Redirect::to(&format!("/{id}")).into_response()
|
||||
}
|
||||
|
||||
async fn view_paste(state: AxumState, user: Option<UserId>, Path(paste_id): Path<i32>) -> Response {
|
||||
let user_id = match user {
|
||||
None => return Html(authorize_link(paste_id)).into_response(),
|
||||
Some(user) => user.id,
|
||||
};
|
||||
|
||||
let query = state
|
||||
.db_client
|
||||
.query_opt("SELECT paste FROM pastes WHERE id = $1", &[&paste_id])
|
||||
.await;
|
||||
let row = match query {
|
||||
Ok(None) => return StatusCode::NOT_FOUND.into_response(),
|
||||
Ok(Some(row)) => row,
|
||||
Err(err) => {
|
||||
error!(%err, %paste_id, %user_id, "querying paste");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
row.get::<usize, String>(0).into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthRequest {
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthResponse {
|
||||
access_token: String,
|
||||
id_token: String,
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UserInfo {
|
||||
hd: String,
|
||||
sub: String,
|
||||
}
|
||||
|
||||
fn decode_id_token(token: String) -> Option<UserInfo> {
|
||||
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
|
||||
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
|
||||
serde_json::from_slice::<UserInfo>(&decoded).ok()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeQuery {
|
||||
paste_id: i32,
|
||||
}
|
||||
|
||||
fn generate_csrf_token(num_bytes: u32) -> String {
|
||||
use rand::{Rng, thread_rng};
|
||||
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
|
||||
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
|
||||
}
|
||||
|
||||
async fn authorize(
|
||||
state: AxumState,
|
||||
jar: PrivateCookieJar,
|
||||
Query(AuthorizeQuery { paste_id }): Query<AuthorizeQuery>,
|
||||
) -> (PrivateCookieJar, Redirect) {
|
||||
let csrf_token = generate_csrf_token(16);
|
||||
let client_id = &state.oauth_client_id;
|
||||
let redirect_uri = oauth_redirect_url();
|
||||
let auth_url = format!(
|
||||
"{OAUTH_BASE_URL}?response_type=code\
|
||||
&client_id={client_id}\
|
||||
&state={csrf_token}\
|
||||
&redirect_uri={redirect_uri}\
|
||||
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
|
||||
);
|
||||
|
||||
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, paste_id.to_string()))
|
||||
.path("/")
|
||||
//.TODO secure(true) not true for localhost
|
||||
//.domain(COOKIE_DOMAIN)
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
.build();
|
||||
let csrf_cookie = Cookie::build((COOKIE_CSRF, csrf_token))
|
||||
.path("/")
|
||||
//.TODO secure(true) not true for localhost
|
||||
//.domain(COOKIE_DOMAIN)
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
.build();
|
||||
let jar = jar.add(redirect_cookie).add(csrf_cookie);
|
||||
let url = Into::<String>::into(auth_url);
|
||||
(jar, Redirect::to(&url))
|
||||
}
|
||||
|
||||
async fn authorized(
|
||||
state: AxumState,
|
||||
jar: PrivateCookieJar,
|
||||
Query(auth_request): Query<AuthRequest>,
|
||||
) -> Result<(PrivateCookieJar, Redirect), Response> {
|
||||
let params = [
|
||||
("grant_type", "authorization_code"),
|
||||
("redirect_uri", &oauth_redirect_url()),
|
||||
("code", &auth_request.code),
|
||||
("client_id", &state.oauth_client_id),
|
||||
("client_secret", &state.oauth_client_secret),
|
||||
];
|
||||
let auth_response = reqwest::Client::new()
|
||||
.post(OAUTH_TOKEN_URL)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!(%err, "exchanging oauth code for token");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
})?
|
||||
.json::<AuthResponse>()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!(%err, "deserializing access token response");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
})?;
|
||||
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
|
||||
error!("Failed to decode response id token");
|
||||
return Err(StatusCode::UNAUTHORIZED.into_response());
|
||||
};
|
||||
if hd != ALLOWED_OAUTH_DOMAIN {
|
||||
error!(hd, "Domain doesn't match {ALLOWED_OAUTH_DOMAIN}");
|
||||
return Err(StatusCode::UNAUTHORIZED.into_response());
|
||||
}
|
||||
|
||||
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
|
||||
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
|
||||
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
|
||||
|
||||
let session_cookie = Cookie::build((COOKIE_SID, auth_response.access_token.clone()))
|
||||
.path("/")
|
||||
//.TODO secure(true) not true for localhost
|
||||
//.domain(COOKIE_DOMAIN)
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
.max_age(cookie_max_age)
|
||||
.build();
|
||||
|
||||
state
|
||||
.db_client
|
||||
.query(
|
||||
"WITH user_insert AS (\
|
||||
INSERT INTO users (sub) VALUES ($1) \
|
||||
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
|
||||
INSERT INTO sessions (user_id, session_id, expires_at) \
|
||||
SELECT id, $2, $3 FROM user_insert \
|
||||
ON CONFLICT (user_id) DO UPDATE SET \
|
||||
session_id = excluded.session_id, \
|
||||
expires_at = excluded.expires_at",
|
||||
&[&sub, &auth_response.access_token, &expires_at],
|
||||
)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!(%err, %sub, "updating session");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
})?;
|
||||
|
||||
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
|
||||
let jar = jar.remove(csrf_cookie).add(session_cookie);
|
||||
match jar.get(COOKIE_REDIRECT) {
|
||||
Some(redirect_cookie) => {
|
||||
let mut value = redirect_cookie.value_trimmed();
|
||||
if value == "0" {
|
||||
value = "";
|
||||
}
|
||||
let redirect_url = format!("/{value}");
|
||||
Ok((jar.remove(redirect_cookie), Redirect::to(&redirect_url)))
|
||||
}
|
||||
None => Ok((jar, Redirect::to("/"))),
|
||||
}
|
||||
}
|
||||
26
shortener/Cargo.toml
Normal file
26
shortener/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[package]
|
||||
name = "shortener"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
|
||||
axum.workspace = true
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
cookie = "0.18.1"
|
||||
nanoid = { version = "0.4.0", default-features = false }
|
||||
rand.workspace = true
|
||||
reqwest.workspace = true
|
||||
rustls-native-certs.workspace = true
|
||||
rustls.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
time = { version = "0.3.36", default-features = false }
|
||||
tokio-postgres-rustls.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
19
shortener/migrations/1_initial.sql
Normal file
19
shortener/migrations/1_initial.sql
Normal file
@@ -0,0 +1,19 @@
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
sub VARCHAR(100) NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT NOT NULL UNIQUE REFERENCES users(id),
|
||||
session_id VARCHAR NOT NULL,
|
||||
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS urls (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT NOT NULL REFERENCES users(id),
|
||||
short_url VARCHAR(6) NOT NULL UNIQUE,
|
||||
long_url VARCHAR NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
222
shortener/src/google_oauth_gate.rs
Normal file
222
shortener/src/google_oauth_gate.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
//! Library to gate infrastructure behind Google Oauth for domain.
|
||||
//!
|
||||
//! Why not oauth-rs? Oauth .exchange_code() doesn't work with "request failed". Also, we can't get
|
||||
//! id token from it, and I don't want to pull in whole openid library just for that.
|
||||
//! Id token saves us a request to openid endpoint and one Oauth scope we don't use
|
||||
use anyhow::{Context, Result, bail};
|
||||
use axum::extract::{FromRef, FromRequestParts, Query, State as AxumState};
|
||||
use axum::response::Redirect;
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use axum_extra::extract::cookie::{Cookie, Key};
|
||||
use chrono::{Duration, Local, TimeZone, Utc};
|
||||
use cookie::CookieBuilder;
|
||||
use core::num::NonZeroI32;
|
||||
use reqwest::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio_postgres::Socket;
|
||||
|
||||
const OAUTH_BASE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
|
||||
const OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
|
||||
const COOKIE_SID: &str = "sid";
|
||||
const COOKIE_CSRF: &str = "csrf";
|
||||
|
||||
pub struct Config {
|
||||
pub oauth_client_id: String,
|
||||
pub oauth_client_secret: String,
|
||||
pub oauth_redirect_url: String,
|
||||
pub oauth_allowed_domain: String,
|
||||
pub cookie_settings: fn(CookieBuilder) -> CookieBuilder,
|
||||
}
|
||||
|
||||
pub struct InnerState {
|
||||
config: Config,
|
||||
cookie_jar_key: Key,
|
||||
pub db_client: tokio_postgres::Client,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct State(Arc<InnerState>);
|
||||
type DbConn = tokio_postgres::Connection<Socket, tokio_postgres_rustls::RustlsStream<Socket>>;
|
||||
|
||||
impl State {
|
||||
pub async fn new(config: Config, db_connstr: &str) -> Result<(Self, DbConn)> {
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
|
||||
{
|
||||
roots.add(cert).unwrap();
|
||||
}
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth();
|
||||
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
|
||||
|
||||
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
|
||||
let inner = InnerState {
|
||||
config,
|
||||
cookie_jar_key: Key::generate(),
|
||||
db_client,
|
||||
};
|
||||
Ok((Self { 0: Arc::new(inner) }, db_conn))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for State {
|
||||
type Target = InnerState;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&*self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for Key {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.cookie_jar_key.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UserId {
|
||||
pub id: NonZeroI32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AuthRequest {
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthResponse {
|
||||
access_token: String,
|
||||
id_token: String,
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UserInfo {
|
||||
hd: String,
|
||||
sub: String,
|
||||
}
|
||||
|
||||
impl axum::extract::OptionalFromRequestParts<State> for UserId {
|
||||
type Rejection = StatusCode;
|
||||
async fn from_request_parts(
|
||||
parts: &mut axum::http::request::Parts,
|
||||
state: &State,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
|
||||
.await
|
||||
.unwrap(); // infallible
|
||||
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let client = &state.db_client;
|
||||
let query = client
|
||||
.query_opt(
|
||||
"SELECT user_id FROM sessions WHERE session_id = $1",
|
||||
&[&session_id],
|
||||
)
|
||||
.await;
|
||||
let id = match query {
|
||||
Ok(Some(row)) => row.get::<usize, i32>(0),
|
||||
Ok(None) => return Ok(None),
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
};
|
||||
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
|
||||
Ok(Some(Self { id }))
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_id_token(token: String) -> Option<UserInfo> {
|
||||
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
|
||||
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
|
||||
serde_json::from_slice::<UserInfo>(&decoded).ok()
|
||||
}
|
||||
|
||||
fn generate_csrf_token(num_bytes: u32) -> String {
|
||||
use rand::{Rng, thread_rng};
|
||||
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
|
||||
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
|
||||
}
|
||||
|
||||
pub async fn authorize(
|
||||
state: AxumState<State>,
|
||||
jar: PrivateCookieJar,
|
||||
) -> (PrivateCookieJar, Redirect) {
|
||||
let csrf_token = generate_csrf_token(16);
|
||||
let client_id = &state.config.oauth_client_id;
|
||||
let redirect_uri = &state.config.oauth_redirect_url;
|
||||
let auth_url = format!(
|
||||
"{OAUTH_BASE_URL}?response_type=code\
|
||||
&client_id={client_id}\
|
||||
&state={csrf_token}\
|
||||
&redirect_uri={redirect_uri}\
|
||||
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
|
||||
);
|
||||
|
||||
let csrf_cookie =
|
||||
(state.config.cookie_settings)(Cookie::build((COOKIE_CSRF, csrf_token))).build();
|
||||
let url = Into::<String>::into(auth_url);
|
||||
(jar.add(csrf_cookie), Redirect::to(&url))
|
||||
}
|
||||
|
||||
pub async fn authorized(
|
||||
state: AxumState<State>,
|
||||
jar: PrivateCookieJar,
|
||||
Query(auth_request): Query<AuthRequest>,
|
||||
) -> Result<PrivateCookieJar> {
|
||||
let params = [
|
||||
("grant_type", "authorization_code"),
|
||||
("redirect_uri", &state.config.oauth_redirect_url),
|
||||
("code", &auth_request.code),
|
||||
("client_id", &state.config.oauth_client_id),
|
||||
("client_secret", &state.config.oauth_client_secret),
|
||||
];
|
||||
let auth_response = reqwest::Client::new()
|
||||
.post(OAUTH_TOKEN_URL)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.context("exchanging oauth code for token")?
|
||||
.json::<AuthResponse>()
|
||||
.await
|
||||
.context("deserializing access_token response")?;
|
||||
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
|
||||
bail!("failed to decode id token")
|
||||
};
|
||||
|
||||
let allowed_domain = &state.config.oauth_allowed_domain;
|
||||
if hd != *allowed_domain {
|
||||
bail!("{hd} doesn't match {allowed_domain}")
|
||||
}
|
||||
|
||||
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
|
||||
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
|
||||
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
|
||||
|
||||
let session_cookie = (state.config.cookie_settings)(Cookie::build((
|
||||
COOKIE_SID,
|
||||
auth_response.access_token.clone(),
|
||||
)))
|
||||
.max_age(cookie_max_age)
|
||||
.build();
|
||||
|
||||
state
|
||||
.db_client
|
||||
.query(
|
||||
"WITH user_insert AS (\
|
||||
INSERT INTO users (sub) VALUES ($1) \
|
||||
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
|
||||
INSERT INTO sessions (user_id, session_id, expires_at) \
|
||||
SELECT id, $2, $3 FROM user_insert \
|
||||
ON CONFLICT (user_id) DO UPDATE SET \
|
||||
session_id = excluded.session_id, \
|
||||
expires_at = excluded.expires_at",
|
||||
&[&sub, &auth_response.access_token, &expires_at],
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("updating session for {sub}"))?;
|
||||
|
||||
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
|
||||
Ok(jar.remove(csrf_cookie).add(session_cookie))
|
||||
}
|
||||
240
shortener/src/main.rs
Normal file
240
shortener/src/main.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
//! Shortener is a service to gate access to internal infrastructure
|
||||
//! URLs behind team authorisation to expose less private information.
|
||||
pub mod google_oauth_gate;
|
||||
use crate::google_oauth_gate::{AuthRequest, State, UserId};
|
||||
use anyhow::Result;
|
||||
use axum::Form;
|
||||
use axum::extract::State as AxumState;
|
||||
use axum::extract::{Path, Query};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{Html, IntoResponse};
|
||||
use axum::response::{Redirect, Response};
|
||||
use axum::routing::get;
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use axum_extra::extract::cookie::Cookie;
|
||||
use cookie::CookieBuilder;
|
||||
use google_oauth_gate::Config;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
const SOCKET: &str = "127.0.0.1:12344";
|
||||
const HOST: &str = "http://127.0.0.1:12344";
|
||||
const COOKIE_REDIRECT: &str = "redirect";
|
||||
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
|
||||
const AUTHORIZED_ROUTE: &str = "/authorized";
|
||||
const SHORT_URL_LEN: usize = 6;
|
||||
|
||||
fn cookie_settings(b: CookieBuilder) -> CookieBuilder {
|
||||
if HOST.contains("127.0.0.1") {
|
||||
b.path("/")
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
} else {
|
||||
b.path("/")
|
||||
.domain(ALLOWED_OAUTH_DOMAIN)
|
||||
.secure(true)
|
||||
.http_only(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn oauth_redirect_url() -> String {
|
||||
format!("{HOST}{AUTHORIZED_ROUTE}")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
|
||||
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
|
||||
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
|
||||
|
||||
let config = Config {
|
||||
oauth_client_id,
|
||||
oauth_client_secret,
|
||||
oauth_redirect_url: oauth_redirect_url(),
|
||||
oauth_allowed_domain: ALLOWED_OAUTH_DOMAIN.to_string(),
|
||||
cookie_settings,
|
||||
};
|
||||
let (state, db_conn) = State::new(config, &db_connstr).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = db_conn.await {
|
||||
error!(%err, "connecting to database");
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
|
||||
let router = axum::Router::new()
|
||||
.route("/", get(index).post(shorten))
|
||||
.route("/authorize", get(authorize))
|
||||
.route(AUTHORIZED_ROUTE, get(authorized))
|
||||
.route("/{short_url}", get(redirect))
|
||||
.with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind(SOCKET)
|
||||
.await
|
||||
.expect("failed to bind TcpListener");
|
||||
info!("listening on {SOCKET}");
|
||||
axum::serve(listener, router).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LongUrl {
|
||||
url: String,
|
||||
}
|
||||
|
||||
fn shorten_form(short_url: &str) -> Html<String> {
|
||||
let mut form = r#"
|
||||
<div style="margin:auto;width:50%;padding:10px">
|
||||
<form method="post">
|
||||
<input type="text" name="url" style="width:100%">
|
||||
<input type="submit" value="Shorten" style="margin-top:10px">
|
||||
</form>"#
|
||||
.to_string();
|
||||
if !short_url.is_empty() {
|
||||
form += &format!(
|
||||
r#"
|
||||
<p>
|
||||
<a id="short" href="{0}">{0}</a>
|
||||
<button onclick="copy()">Copy</button>
|
||||
</p>
|
||||
<script>
|
||||
function copy() {{
|
||||
navigator.clipboard.writeText(document.querySelector("\#short").textContent);
|
||||
}}
|
||||
</script>"#,
|
||||
short_url
|
||||
);
|
||||
}
|
||||
form += "</div>";
|
||||
Html(form)
|
||||
}
|
||||
|
||||
fn authorize_link(short_url: &str) -> Html<String> {
|
||||
Html(format!(
|
||||
"<a href=\"/authorize?short_url={short_url}\">Authorize</a>"
|
||||
))
|
||||
}
|
||||
|
||||
async fn index(user: Option<UserId>) -> Html<String> {
|
||||
if user.is_some() {
|
||||
return shorten_form("");
|
||||
}
|
||||
authorize_link("")
|
||||
}
|
||||
|
||||
async fn shorten(
|
||||
state: AxumState<State>,
|
||||
user: Option<UserId>,
|
||||
Form(LongUrl { url }): Form<LongUrl>,
|
||||
) -> Response {
|
||||
let user_id = match user {
|
||||
None => return StatusCode::FORBIDDEN.into_response(),
|
||||
Some(user) => user.id.get(),
|
||||
};
|
||||
if url.is_empty() {
|
||||
return shorten_form("").into_response();
|
||||
}
|
||||
|
||||
let mut short_url = "".to_string();
|
||||
for i in 0..20 {
|
||||
short_url = nanoid::nanoid!(SHORT_URL_LEN);
|
||||
let query = state
|
||||
.db_client
|
||||
.query_opt(
|
||||
"INSERT INTO urls (user_id, short_url, long_url) VALUES ($1, $2, $3) \
|
||||
ON CONFLICT (short_url) DO NOTHING \
|
||||
RETURNING short_url",
|
||||
&[&user_id, &short_url, &url],
|
||||
)
|
||||
.await;
|
||||
match query {
|
||||
Ok(Some(_)) => break,
|
||||
Ok(None) => {
|
||||
info!(short_url, "url clash, retry {i}");
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
error!(%err, "inserting shortened url");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
}
|
||||
shorten_form(&format!("{HOST}/{short_url}")).into_response()
|
||||
}
|
||||
|
||||
async fn redirect(
|
||||
state: AxumState<State>,
|
||||
user: Option<UserId>,
|
||||
Path(short_url): Path<String>,
|
||||
) -> Response {
|
||||
let user_id = match user {
|
||||
None => return authorize_link(&short_url).into_response(),
|
||||
Some(user) => user.id,
|
||||
};
|
||||
|
||||
let query = state
|
||||
.db_client
|
||||
.query_opt(
|
||||
"SELECT long_url FROM urls WHERE short_url = $1",
|
||||
&[&short_url],
|
||||
)
|
||||
.await;
|
||||
match query {
|
||||
Ok(Some(row)) => Redirect::permanent(row.get(0)).into_response(),
|
||||
Ok(None) => StatusCode::NOT_FOUND.into_response(),
|
||||
Err(err) => {
|
||||
error!(%err, %short_url, %user_id, "querying long url");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeQuery {
|
||||
short_url: String,
|
||||
}
|
||||
|
||||
async fn authorize(
|
||||
state: AxumState<State>,
|
||||
jar: PrivateCookieJar,
|
||||
Query(AuthorizeQuery { short_url }): Query<AuthorizeQuery>,
|
||||
) -> (PrivateCookieJar, Redirect) {
|
||||
let (jar, auth_redirect) = google_oauth_gate::authorize(state, jar).await;
|
||||
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, short_url))
|
||||
.path("/")
|
||||
//.TODO secure(true) not true for localhost
|
||||
//.domain(COOKIE_DOMAIN)
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
.build();
|
||||
(jar.add(redirect_cookie), auth_redirect)
|
||||
}
|
||||
|
||||
async fn authorized(
|
||||
state: AxumState<State>,
|
||||
jar: PrivateCookieJar,
|
||||
query: Query<AuthRequest>,
|
||||
) -> Result<(PrivateCookieJar, Redirect), Response> {
|
||||
use google_oauth_gate::authorized;
|
||||
let jar = authorized(state, jar, query).await.map_err(|err| {
|
||||
error!(%err, "authorizing");
|
||||
return StatusCode::UNAUTHORIZED.into_response();
|
||||
})?;
|
||||
let Some(redirect_cookie) = jar.get(COOKIE_REDIRECT) else {
|
||||
return Ok((jar, Redirect::to("/")));
|
||||
};
|
||||
let redirect_url = Redirect::to(&format!("/{}", redirect_cookie.value_trimmed()));
|
||||
Ok((jar.remove(redirect_cookie), redirect_url))
|
||||
}
|
||||
2
vendor/postgres-v14
vendored
2
vendor/postgres-v14
vendored
Submodule vendor/postgres-v14 updated: 55c0d45abe...b6eece3f52
2
vendor/postgres-v15
vendored
2
vendor/postgres-v15
vendored
Submodule vendor/postgres-v15 updated: de7640f55d...20f8491225
2
vendor/postgres-v16
vendored
2
vendor/postgres-v16
vendored
Submodule vendor/postgres-v16 updated: 0bf96bd6d7...77c63bfebf
2
vendor/postgres-v17
vendored
2
vendor/postgres-v17
vendored
Submodule vendor/postgres-v17 updated: 8be779fd3a...32d704d965
@@ -20,6 +20,7 @@ anstream = { version = "0.6" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
axum = { version = "0.8", features = ["ws"] }
|
||||
axum-core = { version = "0.5", default-features = false, features = ["tracing"] }
|
||||
axum-extra = { version = "0.10", features = ["cookie-private", "typed-header"] }
|
||||
base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] }
|
||||
base64-647d43efb71741da = { package = "base64", version = "0.21" }
|
||||
base64ct = { version = "1", default-features = false, features = ["std"] }
|
||||
@@ -30,6 +31,7 @@ clap = { version = "4", features = ["derive", "env", "string"] }
|
||||
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
|
||||
const-oid = { version = "0.9", default-features = false, features = ["db", "std"] }
|
||||
crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] }
|
||||
crypto-common = { version = "0.1", default-features = false, features = ["getrandom", "std"] }
|
||||
der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] }
|
||||
deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] }
|
||||
digest = { version = "0.10", features = ["mac", "oid", "std"] }
|
||||
|
||||
Reference in New Issue
Block a user