Move and library crates into a dedicated directory and rename them

This commit is contained in:
Kirill Bulatov
2022-04-20 16:38:33 +03:00
committed by Kirill Bulatov
parent 629688fd6c
commit 81cad6277a
127 changed files with 355 additions and 360 deletions

33
libs/utils/src/accum.rs Normal file
View File

@@ -0,0 +1,33 @@
/// A helper to "accumulate" a value similar to `Iterator::reduce`, but lets you
/// feed the accumulated values by calling the 'accum' function, instead of having an
/// iterator.
///
/// For example, to calculate the smallest value among some integers:
///
/// ```
/// use utils::accum::Accum;
///
/// let values = [1, 2, 3];
///
/// let mut min_value: Accum<u32> = Accum(None);
/// for new_value in &values {
/// min_value.accum(std::cmp::min, *new_value);
/// }
///
/// assert_eq!(min_value.0.unwrap(), 1);
/// ```
pub struct Accum<T>(pub Option<T>);
impl<T: Copy> Accum<T> {
pub fn accum<F>(&mut self, func: F, new_value: T)
where
F: FnOnce(T, T) -> T,
{
// If there is no previous value, just store the new value.
// Otherwise call the function to decide which one to keep.
self.0 = Some(if let Some(accum) = self.0 {
func(accum, new_value)
} else {
new_value
});
}
}

98
libs/utils/src/auth.rs Normal file
View File

@@ -0,0 +1,98 @@
// For details about authentication see docs/authentication.md
//
// TODO: use ed25519 keys
// Relevant issue: https://github.com/Keats/jsonwebtoken/issues/162
use serde;
use std::fs;
use std::path::Path;
use anyhow::{bail, Result};
use jsonwebtoken::{
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use crate::zid::ZTenantId;
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
Tenant,
PageServerApi,
}
#[serde_as]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
#[serde(default)]
#[serde_as(as = "Option<DisplayFromStr>")]
pub tenant_id: Option<ZTenantId>,
pub scope: Scope,
}
impl Claims {
pub fn new(tenant_id: Option<ZTenantId>, scope: Scope) -> Self {
Self { tenant_id, scope }
}
}
pub fn check_permission(claims: &Claims, tenantid: Option<ZTenantId>) -> Result<()> {
match (&claims.scope, tenantid) {
(Scope::Tenant, None) => {
bail!("Attempt to access management api with tenant scope. Permission denied")
}
(Scope::Tenant, Some(tenantid)) => {
if claims.tenant_id.unwrap() != tenantid {
bail!("Tenant id mismatch. Permission denied")
}
Ok(())
}
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
}
}
pub struct JwtAuth {
decoding_key: DecodingKey,
validation: Validation,
}
impl JwtAuth {
pub fn new(decoding_key: DecodingKey) -> Self {
let mut validation = Validation::new(JWT_ALGORITHM);
// The default 'required_spec_claims' is 'exp'. But we don't want to require
// expiration.
validation.required_spec_claims = [].into();
Self {
decoding_key,
validation,
}
}
pub fn from_key_path(key_path: &Path) -> Result<Self> {
let public_key = fs::read(key_path)?;
Ok(Self::new(DecodingKey::from_rsa_pem(&public_key)?))
}
pub fn decode(&self, token: &str) -> Result<TokenData<Claims>> {
Ok(decode(token, &self.decoding_key, &self.validation)?)
}
}
impl std::fmt::Debug for JwtAuth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JwtAuth")
.field("validation", &self.validation)
.finish()
}
}
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_rsa_pem(key_data)?;
Ok(encode(&Header::new(JWT_ALGORITHM), claims, &key)?)
}

411
libs/utils/src/bin_ser.rs Normal file
View File

@@ -0,0 +1,411 @@
//! Utilities for binary serialization/deserialization.
//!
//! The [`BeSer`] trait allows us to define data structures
//! that can match data structures that are sent over the wire
//! in big-endian form with no packing.
//!
//! The [`LeSer`] trait does the same thing, in little-endian form.
//!
//! Note: you will get a compile error if you try to `use` both traits
//! in the same module or scope. This is intended to be a safety
//! mechanism: mixing big-endian and little-endian encoding in the same file
//! is error-prone.
#![warn(missing_docs)]
use bincode::Options;
use serde::{de::DeserializeOwned, Serialize};
use std::io::{self, Read, Write};
use thiserror::Error;
/// An error that occurred during a deserialize operation
///
/// This could happen because the input data was too short,
/// or because an invalid value was encountered.
#[derive(Debug, Error)]
pub enum DeserializeError {
/// The deserializer isn't able to deserialize the supplied data.
#[error("deserialize error")]
BadInput,
/// While deserializing from a `Read` source, an `io::Error` occurred.
#[error("deserialize error: {0}")]
Io(io::Error),
}
impl From<bincode::Error> for DeserializeError {
fn from(e: bincode::Error) -> Self {
match *e {
bincode::ErrorKind::Io(io_err) => DeserializeError::Io(io_err),
_ => DeserializeError::BadInput,
}
}
}
/// An error that occurred during a serialize operation
///
/// This probably means our [`Write`] failed, e.g. we tried
/// to write beyond the end of a buffer.
#[derive(Debug, Error)]
pub enum SerializeError {
/// The serializer isn't able to serialize the supplied data.
#[error("serialize error")]
BadInput,
/// While serializing into a `Write` sink, an `io::Error` occurred.
#[error("serialize error: {0}")]
Io(io::Error),
}
impl From<bincode::Error> for SerializeError {
fn from(e: bincode::Error) -> Self {
match *e {
bincode::ErrorKind::Io(io_err) => SerializeError::Io(io_err),
_ => SerializeError::BadInput,
}
}
}
/// A shortcut that configures big-endian binary serialization
///
/// Properties:
/// - Big endian
/// - Fixed integer encoding (i.e. 1u32 is 00000001 not 01)
///
/// Does not allow trailing bytes in deserialization. If this is desired, you
/// may set [`Options::allow_trailing_bytes`] to explicitly accomodate this.
pub fn be_coder() -> impl Options {
bincode::DefaultOptions::new()
.with_big_endian()
.with_fixint_encoding()
}
/// A shortcut that configures little-ending binary serialization
///
/// Properties:
/// - Little endian
/// - Fixed integer encoding (i.e. 1u32 is 00000001 not 01)
///
/// Does not allow trailing bytes in deserialization. If this is desired, you
/// may set [`Options::allow_trailing_bytes`] to explicitly accomodate this.
pub fn le_coder() -> impl Options {
bincode::DefaultOptions::new()
.with_little_endian()
.with_fixint_encoding()
}
/// Binary serialize/deserialize helper functions (Big Endian)
///
pub trait BeSer {
/// Serialize into a byte slice
fn ser_into_slice(&self, mut b: &mut [u8]) -> Result<(), SerializeError>
where
Self: Serialize,
{
// &mut [u8] implements Write, but `ser_into` needs a mutable
// reference to that. So we need the slightly awkward "mutable
// reference to a mutable reference.
self.ser_into(&mut b)
}
/// Serialize into a borrowed writer
///
/// This is useful for most `Write` types except `&mut [u8]`, which
/// can more easily use [`ser_into_slice`](Self::ser_into_slice).
fn ser_into<W: Write>(&self, w: &mut W) -> Result<(), SerializeError>
where
Self: Serialize,
{
be_coder().serialize_into(w, &self).map_err(|e| e.into())
}
/// Serialize into a new heap-allocated buffer
fn ser(&self) -> Result<Vec<u8>, SerializeError>
where
Self: Serialize,
{
be_coder().serialize(&self).map_err(|e| e.into())
}
/// Deserialize from the full contents of a byte slice
///
/// See also: [`BeSer::des_prefix`]
fn des(buf: &[u8]) -> Result<Self, DeserializeError>
where
Self: DeserializeOwned,
{
be_coder()
.deserialize(buf)
.or(Err(DeserializeError::BadInput))
}
/// Deserialize from a prefix of the byte slice
///
/// Uses as much of the byte slice as is necessary to deserialize the
/// type, but does not guarantee that the entire slice is used.
///
/// See also: [`BeSer::des`]
fn des_prefix(buf: &[u8]) -> Result<Self, DeserializeError>
where
Self: DeserializeOwned,
{
be_coder()
.allow_trailing_bytes()
.deserialize(buf)
.or(Err(DeserializeError::BadInput))
}
/// Deserialize from a reader
fn des_from<R: Read>(r: &mut R) -> Result<Self, DeserializeError>
where
Self: DeserializeOwned,
{
be_coder().deserialize_from(r).map_err(|e| e.into())
}
/// Compute the serialized size of a data structure
///
/// Note: it may be faster to serialize to a buffer and then measure the
/// buffer length, than to call `serialized_size` and then `ser_into`.
fn serialized_size(&self) -> Result<u64, SerializeError>
where
Self: Serialize,
{
be_coder().serialized_size(self).map_err(|e| e.into())
}
}
/// Binary serialize/deserialize helper functions (Little Endian)
///
pub trait LeSer {
/// Serialize into a byte slice
fn ser_into_slice(&self, mut b: &mut [u8]) -> Result<(), SerializeError>
where
Self: Serialize,
{
// &mut [u8] implements Write, but `ser_into` needs a mutable
// reference to that. So we need the slightly awkward "mutable
// reference to a mutable reference.
self.ser_into(&mut b)
}
/// Serialize into a borrowed writer
///
/// This is useful for most `Write` types except `&mut [u8]`, which
/// can more easily use [`ser_into_slice`](Self::ser_into_slice).
fn ser_into<W: Write>(&self, w: &mut W) -> Result<(), SerializeError>
where
Self: Serialize,
{
le_coder().serialize_into(w, &self).map_err(|e| e.into())
}
/// Serialize into a new heap-allocated buffer
fn ser(&self) -> Result<Vec<u8>, SerializeError>
where
Self: Serialize,
{
le_coder().serialize(&self).map_err(|e| e.into())
}
/// Deserialize from the full contents of a byte slice
///
/// See also: [`LeSer::des_prefix`]
fn des(buf: &[u8]) -> Result<Self, DeserializeError>
where
Self: DeserializeOwned,
{
le_coder()
.deserialize(buf)
.or(Err(DeserializeError::BadInput))
}
/// Deserialize from a prefix of the byte slice
///
/// Uses as much of the byte slice as is necessary to deserialize the
/// type, but does not guarantee that the entire slice is used.
///
/// See also: [`LeSer::des`]
fn des_prefix(buf: &[u8]) -> Result<Self, DeserializeError>
where
Self: DeserializeOwned,
{
le_coder()
.allow_trailing_bytes()
.deserialize(buf)
.or(Err(DeserializeError::BadInput))
}
/// Deserialize from a reader
fn des_from<R: Read>(r: &mut R) -> Result<Self, DeserializeError>
where
Self: DeserializeOwned,
{
le_coder().deserialize_from(r).map_err(|e| e.into())
}
/// Compute the serialized size of a data structure
///
/// Note: it may be faster to serialize to a buffer and then measure the
/// buffer length, than to call `serialized_size` and then `ser_into`.
fn serialized_size(&self) -> Result<u64, SerializeError>
where
Self: Serialize,
{
le_coder().serialized_size(self).map_err(|e| e.into())
}
}
// Because usage of `BeSer` or `LeSer` can be done with *either* a Serialize or
// DeserializeOwned implementation, the blanket implementation has to be for every type.
impl<T> BeSer for T {}
impl<T> LeSer for T {}
#[cfg(test)]
mod tests {
use super::DeserializeError;
use serde::{Deserialize, Serialize};
use std::io::Cursor;
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct ShortStruct {
a: u8,
b: u32,
}
const SHORT1: ShortStruct = ShortStruct { a: 7, b: 65536 };
const SHORT1_ENC_BE: &[u8] = &[7, 0, 1, 0, 0];
const SHORT1_ENC_BE_TRAILING: &[u8] = &[7, 0, 1, 0, 0, 255, 255, 255];
const SHORT1_ENC_LE: &[u8] = &[7, 0, 0, 1, 0];
const SHORT1_ENC_LE_TRAILING: &[u8] = &[7, 0, 0, 1, 0, 255, 255, 255];
const SHORT2: ShortStruct = ShortStruct {
a: 8,
b: 0x07030000,
};
const SHORT2_ENC_BE: &[u8] = &[8, 7, 3, 0, 0];
const SHORT2_ENC_BE_TRAILING: &[u8] = &[8, 7, 3, 0, 0, 0xff, 0xff, 0xff];
const SHORT2_ENC_LE: &[u8] = &[8, 0, 0, 3, 7];
const SHORT2_ENC_LE_TRAILING: &[u8] = &[8, 0, 0, 3, 7, 0xff, 0xff, 0xff];
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct LongMsg {
pub tag: u8,
pub blockpos: u32,
pub last_flush_position: u64,
pub apply: u64,
pub timestamp: i64,
pub reply_requested: u8,
}
const LONG1: LongMsg = LongMsg {
tag: 42,
blockpos: 0x1000_2000,
last_flush_position: 0x1234_2345_3456_4567,
apply: 0x9876_5432_10FE_DCBA,
timestamp: 0x7788_99AA_BBCC_DDFF,
reply_requested: 1,
};
#[test]
fn be_short() {
use super::BeSer;
assert_eq!(SHORT1.serialized_size().unwrap(), 5);
let encoded = SHORT1.ser().unwrap();
assert_eq!(encoded, SHORT1_ENC_BE);
let decoded = ShortStruct::des(SHORT2_ENC_BE).unwrap();
assert_eq!(decoded, SHORT2);
// with trailing data
let decoded = ShortStruct::des_prefix(SHORT2_ENC_BE_TRAILING).unwrap();
assert_eq!(decoded, SHORT2);
let err = ShortStruct::des(SHORT2_ENC_BE_TRAILING).unwrap_err();
assert!(matches!(err, DeserializeError::BadInput));
// serialize into a `Write` sink.
let mut buf = Cursor::new(vec![0xFF; 8]);
SHORT1.ser_into(&mut buf).unwrap();
assert_eq!(buf.into_inner(), SHORT1_ENC_BE_TRAILING);
// deserialize from a `Write` sink.
let mut buf = Cursor::new(SHORT2_ENC_BE);
let decoded = ShortStruct::des_from(&mut buf).unwrap();
assert_eq!(decoded, SHORT2);
// deserialize from a `Write` sink that terminates early.
let mut buf = Cursor::new([0u8; 4]);
let err = ShortStruct::des_from(&mut buf).unwrap_err();
assert!(matches!(err, DeserializeError::Io(_)));
}
#[test]
fn le_short() {
use super::LeSer;
assert_eq!(SHORT1.serialized_size().unwrap(), 5);
let encoded = SHORT1.ser().unwrap();
assert_eq!(encoded, SHORT1_ENC_LE);
let decoded = ShortStruct::des(SHORT2_ENC_LE).unwrap();
assert_eq!(decoded, SHORT2);
// with trailing data
let decoded = ShortStruct::des_prefix(SHORT2_ENC_LE_TRAILING).unwrap();
assert_eq!(decoded, SHORT2);
let err = ShortStruct::des(SHORT2_ENC_LE_TRAILING).unwrap_err();
assert!(matches!(err, DeserializeError::BadInput));
// serialize into a `Write` sink.
let mut buf = Cursor::new(vec![0xFF; 8]);
SHORT1.ser_into(&mut buf).unwrap();
assert_eq!(buf.into_inner(), SHORT1_ENC_LE_TRAILING);
// deserialize from a `Write` sink.
let mut buf = Cursor::new(SHORT2_ENC_LE);
let decoded = ShortStruct::des_from(&mut buf).unwrap();
assert_eq!(decoded, SHORT2);
// deserialize from a `Write` sink that terminates early.
let mut buf = Cursor::new([0u8; 4]);
let err = ShortStruct::des_from(&mut buf).unwrap_err();
assert!(matches!(err, DeserializeError::Io(_)));
}
#[test]
fn be_long() {
use super::BeSer;
assert_eq!(LONG1.serialized_size().unwrap(), 30);
let msg = LONG1;
let encoded = msg.ser().unwrap();
let expected = hex_literal::hex!(
"2A 1000 2000 1234 2345 3456 4567 9876 5432 10FE DCBA 7788 99AA BBCC DDFF 01"
);
assert_eq!(encoded, expected);
let msg2 = LongMsg::des(&encoded).unwrap();
assert_eq!(msg, msg2);
}
#[test]
fn le_long() {
use super::LeSer;
assert_eq!(LONG1.serialized_size().unwrap(), 30);
let msg = LONG1;
let encoded = msg.ser().unwrap();
let expected = hex_literal::hex!(
"2A 0020 0010 6745 5634 4523 3412 BADC FE10 3254 7698 FFDD CCBB AA99 8877 01"
);
assert_eq!(encoded, expected);
let msg2 = LongMsg::des(&encoded).unwrap();
assert_eq!(msg, msg2);
}
}

View File

@@ -0,0 +1,52 @@
use postgres::Config;
pub fn connection_host_port(config: &Config) -> (String, u16) {
assert_eq!(
config.get_hosts().len(),
1,
"only one pair of host and port is supported in connection string"
);
assert_eq!(
config.get_ports().len(),
1,
"only one pair of host and port is supported in connection string"
);
let host = match &config.get_hosts()[0] {
postgres::config::Host::Tcp(host) => host.as_ref(),
postgres::config::Host::Unix(host) => host.to_str().unwrap(),
};
(host.to_owned(), config.get_ports()[0])
}
pub fn connection_address(config: &Config) -> String {
let (host, port) = connection_host_port(config);
format!("{}:{}", host, port)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_connection_host_port() {
let config: Config = "postgresql://no_user@localhost:64000/no_db"
.parse()
.unwrap();
assert_eq!(
connection_host_port(&config),
("localhost".to_owned(), 64000)
);
}
#[test]
#[should_panic(expected = "only one pair of host and port is supported in connection string")]
fn test_connection_host_port_multiple_ports() {
let config: Config = "postgresql://no_user@localhost:64000,localhost:64001/no_db"
.parse()
.unwrap();
assert_eq!(
connection_host_port(&config),
("localhost".to_owned(), 64000)
);
}
}

View File

@@ -0,0 +1,125 @@
use std::{
fs::{self, File},
io,
path::Path,
};
/// Similar to [`std::fs::create_dir`], except we fsync the
/// created directory and its parent.
pub fn create_dir(path: impl AsRef<Path>) -> io::Result<()> {
let path = path.as_ref();
fs::create_dir(path)?;
File::open(path)?.sync_all()?;
if let Some(parent) = path.parent() {
File::open(parent)?.sync_all()
} else {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"can't find parent",
))
}
}
/// Similar to [`std::fs::create_dir_all`], except we fsync all
/// newly created directories and the pre-existing parent.
pub fn create_dir_all(path: impl AsRef<Path>) -> io::Result<()> {
let mut path = path.as_ref();
let mut dirs_to_create = Vec::new();
// Figure out which directories we need to create.
loop {
match path.metadata() {
Ok(metadata) if metadata.is_dir() => break,
Ok(_) => {
return Err(io::Error::new(
io::ErrorKind::AlreadyExists,
format!("non-directory found in path: {}", path.display()),
));
}
Err(ref e) if e.kind() == io::ErrorKind::NotFound => {}
Err(e) => return Err(e),
}
dirs_to_create.push(path);
match path.parent() {
Some(parent) => path = parent,
None => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("can't find parent of path '{}'", path.display()).as_str(),
));
}
}
}
// Create directories from parent to child.
for &path in dirs_to_create.iter().rev() {
fs::create_dir(path)?;
}
// Fsync the created directories from child to parent.
for &path in dirs_to_create.iter() {
File::open(path)?.sync_all()?;
}
// If we created any new directories, fsync the parent.
if !dirs_to_create.is_empty() {
File::open(path)?.sync_all()?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use tempfile::tempdir;
use super::*;
#[test]
fn test_create_dir_fsyncd() {
let dir = tempdir().unwrap();
let existing_dir_path = dir.path();
let err = create_dir(existing_dir_path).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::AlreadyExists);
let child_dir = existing_dir_path.join("child");
create_dir(child_dir).unwrap();
let nested_child_dir = existing_dir_path.join("child1").join("child2");
let err = create_dir(nested_child_dir).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::NotFound);
}
#[test]
fn test_create_dir_all_fsyncd() {
let dir = tempdir().unwrap();
let existing_dir_path = dir.path();
create_dir_all(existing_dir_path).unwrap();
let child_dir = existing_dir_path.join("child");
assert!(!child_dir.exists());
create_dir_all(&child_dir).unwrap();
assert!(child_dir.exists());
let nested_child_dir = existing_dir_path.join("child1").join("child2");
assert!(!nested_child_dir.exists());
create_dir_all(&nested_child_dir).unwrap();
assert!(nested_child_dir.exists());
let file_path = existing_dir_path.join("file");
std::fs::write(&file_path, b"").unwrap();
let err = create_dir_all(&file_path).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::AlreadyExists);
let invalid_dir_path = file_path.join("folder");
create_dir_all(&invalid_dir_path).unwrap_err();
}
}

View File

@@ -0,0 +1,181 @@
use crate::auth::{self, Claims, JwtAuth};
use crate::http::error;
use crate::zid::ZTenantId;
use anyhow::anyhow;
use hyper::header::AUTHORIZATION;
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
use lazy_static::lazy_static;
use metrics::{new_common_metric_name, register_int_counter, Encoder, IntCounter, TextEncoder};
use routerify::ext::RequestExt;
use routerify::RequestInfo;
use routerify::{Middleware, Router, RouterBuilder, RouterService};
use tracing::info;
use std::future::Future;
use std::net::TcpListener;
use super::error::ApiError;
lazy_static! {
static ref SERVE_METRICS_COUNT: IntCounter = register_int_counter!(
new_common_metric_name("serve_metrics_count"),
"Number of metric requests made"
)
.expect("failed to define a metric");
}
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
info!("{} {} {}", info.method(), info.uri().path(), res.status(),);
Ok(res)
}
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
SERVE_METRICS_COUNT.inc();
let mut buffer = vec![];
let encoder = TextEncoder::new();
let metrics = metrics::gather();
encoder.encode(&metrics, &mut buffer).unwrap();
let response = Response::builder()
.status(200)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap();
Ok(response)
}
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
Router::builder()
.middleware(Middleware::post_with_info(logger))
.get("/metrics", prometheus_metrics_handler)
.err_handler(error::handler)
}
pub fn attach_openapi_ui(
router_builder: RouterBuilder<hyper::Body, ApiError>,
spec: &'static [u8],
spec_mount_path: &'static str,
ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
router_builder.get(spec_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(spec)).unwrap())
}).get(ui_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
})
}
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
// header must be in form Bearer <token>
let (prefix, token) = header_value
.split_once(' ')
.ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
if prefix != "Bearer" {
return Err(ApiError::Unauthorized(
"malformed authorization header".to_string(),
));
}
Ok(token)
}
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
provide_auth: fn(&Request<Body>) -> Option<&JwtAuth>,
) -> Middleware<B, ApiError> {
Middleware::pre(move |req| async move {
if let Some(auth) = provide_auth(&req) {
match req.headers().get(AUTHORIZATION) {
Some(value) => {
let header_value = value.to_str().map_err(|_| {
ApiError::Unauthorized("malformed authorization header".to_string())
})?;
let token = parse_token(header_value)?;
let data = auth
.decode(token)
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
req.set_context(data.claims);
}
None => {
return Err(ApiError::Unauthorized(
"missing authorization header".to_string(),
))
}
}
}
Ok(req)
})
}
pub fn check_permission(req: &Request<Body>, tenantid: Option<ZTenantId>) -> Result<(), ApiError> {
match req.context::<Claims>() {
Some(claims) => Ok(auth::check_permission(&claims, tenantid)
.map_err(|err| ApiError::Forbidden(err.to_string()))?),
None => Ok(()), // claims is None because auth is disabled
}
}
///
/// Start listening for HTTP requests on given socket.
///
/// 'shutdown_future' can be used to stop. If the Future becomes
/// ready, we stop listening for new requests, and the function returns.
///
pub fn serve_thread_main<S>(
router_builder: RouterBuilder<hyper::Body, ApiError>,
listener: TcpListener,
shutdown_future: S,
) -> anyhow::Result<()>
where
S: Future<Output = ()> + Send + Sync,
{
info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
// Create a Service from the router above to handle incoming requests.
let service = RouterService::new(router_builder.build().map_err(|err| anyhow!(err))?).unwrap();
// Enter a single-threaded tokio runtime bound to the current thread
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let _guard = runtime.enter();
let server = Server::from_tcp(listener)?
.serve(service)
.with_graceful_shutdown(shutdown_future);
runtime.block_on(server)?;
Ok(())
}

View File

@@ -0,0 +1,88 @@
use anyhow::anyhow;
use hyper::{header, Body, Response, StatusCode};
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ApiError {
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Unauthorized: {0}")]
Unauthorized(String),
#[error("NotFound: {0}")]
NotFound(String),
#[error("Conflict: {0}")]
Conflict(String),
#[error(transparent)]
InternalServerError(#[from] anyhow::Error),
}
impl ApiError {
pub fn from_err<E: Into<anyhow::Error>>(err: E) -> Self {
Self::InternalServerError(anyhow!(err))
}
pub fn into_response(self) -> Response<Body> {
match self {
ApiError::BadRequest(_) => HttpErrorBody::response_from_msg_and_status(
self.to_string(),
StatusCode::BAD_REQUEST,
),
ApiError::Forbidden(_) => {
HttpErrorBody::response_from_msg_and_status(self.to_string(), StatusCode::FORBIDDEN)
}
ApiError::Unauthorized(_) => HttpErrorBody::response_from_msg_and_status(
self.to_string(),
StatusCode::UNAUTHORIZED,
),
ApiError::NotFound(_) => {
HttpErrorBody::response_from_msg_and_status(self.to_string(), StatusCode::NOT_FOUND)
}
ApiError::Conflict(_) => {
HttpErrorBody::response_from_msg_and_status(self.to_string(), StatusCode::CONFLICT)
}
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct HttpErrorBody {
pub msg: String,
}
impl HttpErrorBody {
pub fn from_msg(msg: String) -> Self {
HttpErrorBody { msg }
}
pub fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Body> {
HttpErrorBody { msg }.to_response(status)
}
pub fn to_response(&self, status: StatusCode) -> Response<Body> {
Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Body::from(serde_json::to_string(self).unwrap()))
.unwrap()
}
}
pub async fn handler(err: routerify::RouteError) -> Response<Body> {
tracing::error!("Error processing HTTP request: {:?}", err);
err.downcast::<ApiError>()
.expect("handler should always return api error")
.into_response()
}

View File

@@ -0,0 +1,28 @@
use bytes::Buf;
use hyper::{header, Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use super::error::ApiError;
pub async fn json_request<T: for<'de> Deserialize<'de>>(
request: &mut Request<Body>,
) -> Result<T, ApiError> {
let whole_body = hyper::body::aggregate(request.body_mut())
.await
.map_err(ApiError::from_err)?;
serde_json::from_reader(whole_body.reader())
.map_err(|err| ApiError::BadRequest(format!("Failed to parse json request {}", err)))
}
pub fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Body>, ApiError> {
let json = serde_json::to_string(&data).map_err(ApiError::from_err)?;
let response = Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(json))
.map_err(ApiError::from_err)?;
Ok(response)
}

View File

@@ -0,0 +1,8 @@
pub mod endpoint;
pub mod error;
pub mod json;
pub mod request;
/// Current fast way to apply simple http routing in various Zenith binaries.
/// Re-exported for sake of uniform approach, that could be later replaced with better alternatives, if needed.
pub use routerify::{ext::RequestExt, RouterBuilder, RouterService};

View File

@@ -0,0 +1,33 @@
use std::str::FromStr;
use super::error::ApiError;
use hyper::{Body, Request};
use routerify::ext::RequestExt;
pub fn get_request_param<'a>(
request: &'a Request<Body>,
param_name: &str,
) -> Result<&'a str, ApiError> {
match request.param(param_name) {
Some(arg) => Ok(arg),
None => {
return Err(ApiError::BadRequest(format!(
"no {} specified in path param",
param_name
)))
}
}
}
pub fn parse_request_param<T: FromStr>(
request: &Request<Body>,
param_name: &str,
) -> Result<T, ApiError> {
match get_request_param(request, param_name)?.parse() {
Ok(v) => Ok(v),
Err(_) => Err(ApiError::BadRequest(format!(
"failed to parse {}",
param_name
))),
}
}

84
libs/utils/src/lib.rs Normal file
View File

@@ -0,0 +1,84 @@
//! `utils` is intended to be a place to put code that is shared
//! between other crates in this repository.
#![allow(clippy::manual_range_contains)]
/// `Lsn` type implements common tasks on Log Sequence Numbers
pub mod lsn;
/// SeqWait allows waiting for a future sequence number to arrive
pub mod seqwait;
/// append only ordered map implemented with a Vec
pub mod vec_map;
// Async version of SeqWait. Currently unused.
// pub mod seqwait_async;
pub mod bin_ser;
pub mod postgres_backend;
pub mod pq_proto;
// dealing with connstring parsing and handy access to it's parts
pub mod connstring;
// helper functions for creating and fsyncing directories/trees
pub mod crashsafe_dir;
// common authentication routines
pub mod auth;
// utility functions and helper traits for unified unique id generation/serialization etc.
pub mod zid;
// http endpoint utils
pub mod http;
// socket splitting utils
pub mod sock_split;
// common log initialisation routine
pub mod logging;
// Misc
pub mod accum;
pub mod shutdown;
// Tools for calling certain async methods in sync contexts
pub mod sync;
// Utility for binding TcpListeners with proper socket options.
pub mod tcp_listener;
// Utility for putting a raw file descriptor into non-blocking mode
pub mod nonblock;
// Default signal handling
pub mod signals;
// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages
//
// we have several cases:
// * building locally from git repo
// * building in CI from git repo
// * building in docker (either in CI or locally)
//
// One thing to note is that .git is not available in docker (and it is bad to include it there).
// So everything becides docker build is covered by git_version crate.
// For docker use environment variable to pass git version, which is then retrieved by buildscript (build.rs).
// It takes variable from build process env and puts it to the rustc env. And then we can retrieve it here by using env! macro.
// Git version received from environment variable used as a fallback in git_version invokation.
// And to avoid running buildscript every recompilation, we use rerun-if-env-changed option.
// So the build script will be run only when GIT_VERSION envvar has changed.
//
// Why not to use buildscript to get git commit sha directly without procmacro from different crate?
// Caching and workspaces complicates that. In case `utils` is not
// recompiled due to caching then version may become outdated.
// git_version crate handles that case by introducing a dependency on .git internals via include_bytes! macro,
// so if we changed the index state git_version will pick that up and rerun the macro.
//
// Note that with git_version prefix is `git:` and in case of git version from env its `git-env:`.
use git_version::git_version;
pub const GIT_VERSION: &str = git_version!(
prefix = "git:",
fallback = concat!("git-env:", env!("GIT_VERSION")),
args = ["--abbrev=40", "--always", "--dirty=-modified"] // always use full sha
);

42
libs/utils/src/logging.rs Normal file
View File

@@ -0,0 +1,42 @@
use std::{
fs::{File, OpenOptions},
path::Path,
};
use anyhow::{Context, Result};
pub fn init(log_filename: impl AsRef<Path>, daemonize: bool) -> Result<File> {
// Don't open the same file for output multiple times;
// the different fds could overwrite each other's output.
let log_file = OpenOptions::new()
.create(true)
.append(true)
.open(&log_filename)
.with_context(|| format!("failed to open {:?}", log_filename.as_ref()))?;
let default_filter_str = "info";
// We fall back to printing all spans at info-level or above if
// the RUST_LOG environment variable is not set.
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_filter_str));
let base_logger = tracing_subscriber::fmt()
.with_env_filter(env_filter)
.with_target(false) // don't include event targets
.with_ansi(false); // don't use colors in log file;
// we are cloning and returning log file in order to allow redirecting daemonized stdout and stderr to it
// if we do not use daemonization (e.g. in docker) it is better to log to stdout directly
// for example to be in line with docker log command which expects logs comimg from stdout
if daemonize {
let x = log_file.try_clone().unwrap();
base_logger
.with_writer(move || x.try_clone().unwrap())
.init();
} else {
base_logger.init();
}
Ok(log_file)
}

308
libs/utils/src/lsn.rs Normal file
View File

@@ -0,0 +1,308 @@
#![warn(missing_docs)]
use serde::{Deserialize, Serialize};
use std::fmt;
use std::ops::{Add, AddAssign};
use std::path::Path;
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::seqwait::MonotonicCounter;
/// Transaction log block size in bytes
pub const XLOG_BLCKSZ: u32 = 8192;
/// A Postgres LSN (Log Sequence Number), also known as an XLogRecPtr
#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Lsn(pub u64);
/// We tried to parse an LSN from a string, but failed
#[derive(Debug, PartialEq, thiserror::Error)]
#[error("LsnParseError")]
pub struct LsnParseError;
impl Lsn {
/// Maximum possible value for an LSN
pub const MAX: Lsn = Lsn(u64::MAX);
/// Subtract a number, returning None on overflow.
pub fn checked_sub<T: Into<u64>>(self, other: T) -> Option<Lsn> {
let other: u64 = other.into();
self.0.checked_sub(other).map(Lsn)
}
/// Subtract a number, returning the difference as i128 to avoid overflow.
pub fn widening_sub<T: Into<u64>>(self, other: T) -> i128 {
let other: u64 = other.into();
i128::from(self.0) - i128::from(other)
}
/// Parse an LSN from a filename in the form `0000000000000000`
pub fn from_filename<F>(filename: F) -> Result<Self, LsnParseError>
where
F: AsRef<Path>,
{
let filename: &Path = filename.as_ref();
let filename = filename.to_str().ok_or(LsnParseError)?;
Lsn::from_hex(filename)
}
/// Parse an LSN from a string in the form `0000000000000000`
pub fn from_hex<S>(s: S) -> Result<Self, LsnParseError>
where
S: AsRef<str>,
{
let s: &str = s.as_ref();
let n = u64::from_str_radix(s, 16).or(Err(LsnParseError))?;
Ok(Lsn(n))
}
/// Compute the offset into a segment
pub fn segment_offset(self, seg_sz: usize) -> usize {
(self.0 % seg_sz as u64) as usize
}
/// Compute the segment number
pub fn segment_number(self, seg_sz: usize) -> u64 {
self.0 / seg_sz as u64
}
/// Compute the offset into a block
pub fn block_offset(self) -> u64 {
const BLCKSZ: u64 = XLOG_BLCKSZ as u64;
self.0 % BLCKSZ
}
/// Compute the bytes remaining in this block
///
/// If the LSN is already at the block boundary, it will return `XLOG_BLCKSZ`.
pub fn remaining_in_block(self) -> u64 {
const BLCKSZ: u64 = XLOG_BLCKSZ as u64;
BLCKSZ - (self.0 % BLCKSZ)
}
/// Compute the bytes remaining to fill a chunk of some size
///
/// If the LSN is already at the chunk boundary, it will return 0.
pub fn calc_padding<T: Into<u64>>(self, sz: T) -> u64 {
let sz: u64 = sz.into();
// By using wrapping_sub, we can subtract first and then mod second.
// If it's done the other way around, then we would return a full
// chunk size if we're already at the chunk boundary.
// (Regular subtraction will panic on overflow in debug builds.)
(sz.wrapping_sub(self.0)) % sz
}
/// Align LSN on 8-byte boundary (alignment of WAL records).
pub fn align(&self) -> Lsn {
Lsn((self.0 + 7) & !7)
}
/// Align LSN on 8-byte boundary (alignment of WAL records).
pub fn is_aligned(&self) -> bool {
*self == self.align()
}
}
impl From<u64> for Lsn {
fn from(n: u64) -> Self {
Lsn(n)
}
}
impl From<Lsn> for u64 {
fn from(lsn: Lsn) -> u64 {
lsn.0
}
}
impl FromStr for Lsn {
type Err = LsnParseError;
/// Parse an LSN from a string in the form `00000000/00000000`
///
/// If the input string is missing the '/' character, then use `Lsn::from_hex`
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut splitter = s.split('/');
if let (Some(left), Some(right), None) = (splitter.next(), splitter.next(), splitter.next())
{
let left_num = u32::from_str_radix(left, 16).map_err(|_| LsnParseError)?;
let right_num = u32::from_str_radix(right, 16).map_err(|_| LsnParseError)?;
Ok(Lsn((left_num as u64) << 32 | right_num as u64))
} else {
Err(LsnParseError)
}
}
}
impl fmt::Display for Lsn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:X}/{:X}", self.0 >> 32, self.0 & 0xffffffff)
}
}
impl fmt::Debug for Lsn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:X}/{:X}", self.0 >> 32, self.0 & 0xffffffff)
}
}
impl Add<u64> for Lsn {
type Output = Lsn;
fn add(self, other: u64) -> Self::Output {
// panic if the addition overflows.
Lsn(self.0.checked_add(other).unwrap())
}
}
impl AddAssign<u64> for Lsn {
fn add_assign(&mut self, other: u64) {
// panic if the addition overflows.
self.0 = self.0.checked_add(other).unwrap();
}
}
/// An [`Lsn`] that can be accessed atomically.
pub struct AtomicLsn {
inner: AtomicU64,
}
impl AtomicLsn {
/// Creates a new atomic `Lsn`.
pub fn new(val: u64) -> Self {
AtomicLsn {
inner: AtomicU64::new(val),
}
}
/// Atomically retrieve the `Lsn` value from memory.
pub fn load(&self) -> Lsn {
Lsn(self.inner.load(Ordering::Acquire))
}
/// Atomically store a new `Lsn` value to memory.
pub fn store(&self, lsn: Lsn) {
self.inner.store(lsn.0, Ordering::Release);
}
/// Adds to the current value, returning the previous value.
///
/// This operation will panic on overflow.
pub fn fetch_add(&self, val: u64) -> Lsn {
let prev = self.inner.fetch_add(val, Ordering::AcqRel);
assert!(prev.checked_add(val).is_some(), "AtomicLsn overflow");
Lsn(prev)
}
/// Atomically sets the Lsn to the max of old and new value, returning the old value.
pub fn fetch_max(&self, lsn: Lsn) -> Lsn {
let prev = self.inner.fetch_max(lsn.0, Ordering::AcqRel);
Lsn(prev)
}
}
impl From<Lsn> for AtomicLsn {
fn from(lsn: Lsn) -> Self {
Self::new(lsn.0)
}
}
/// Pair of LSN's pointing to the end of the last valid record and previous one
#[derive(Debug, Clone, Copy)]
pub struct RecordLsn {
/// LSN at the end of the current record
pub last: Lsn,
/// LSN at the end of the previous record
pub prev: Lsn,
}
/// Expose `self.last` as counter to be able to use RecordLsn in SeqWait
impl MonotonicCounter<Lsn> for RecordLsn {
fn cnt_advance(&mut self, lsn: Lsn) {
assert!(self.last <= lsn);
let new_prev = self.last;
self.last = lsn;
self.prev = new_prev;
}
fn cnt_value(&self) -> Lsn {
self.last
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_lsn_strings() {
assert_eq!("12345678/AAAA5555".parse(), Ok(Lsn(0x12345678AAAA5555)));
assert_eq!("aaaa/bbbb".parse(), Ok(Lsn(0x0000AAAA0000BBBB)));
assert_eq!("1/A".parse(), Ok(Lsn(0x000000010000000A)));
assert_eq!("0/0".parse(), Ok(Lsn(0)));
"ABCDEFG/12345678".parse::<Lsn>().unwrap_err();
"123456789/AAAA5555".parse::<Lsn>().unwrap_err();
"12345678/AAAA55550".parse::<Lsn>().unwrap_err();
"-1/0".parse::<Lsn>().unwrap_err();
"1/-1".parse::<Lsn>().unwrap_err();
assert_eq!(format!("{}", Lsn(0x12345678AAAA5555)), "12345678/AAAA5555");
assert_eq!(format!("{}", Lsn(0x000000010000000A)), "1/A");
assert_eq!(
Lsn::from_hex("12345678AAAA5555"),
Ok(Lsn(0x12345678AAAA5555))
);
assert_eq!(Lsn::from_hex("0"), Ok(Lsn(0)));
assert_eq!(Lsn::from_hex("F12345678AAAA5555"), Err(LsnParseError));
}
#[test]
fn test_lsn_math() {
assert_eq!(Lsn(1234) + 11u64, Lsn(1245));
assert_eq!(
{
let mut lsn = Lsn(1234);
lsn += 11u64;
lsn
},
Lsn(1245)
);
assert_eq!(Lsn(1234).checked_sub(1233u64), Some(Lsn(1)));
assert_eq!(Lsn(1234).checked_sub(1235u64), None);
assert_eq!(Lsn(1235).widening_sub(1234u64), 1);
assert_eq!(Lsn(1234).widening_sub(1235u64), -1);
assert_eq!(Lsn(u64::MAX).widening_sub(0u64), i128::from(u64::MAX));
assert_eq!(Lsn(0).widening_sub(u64::MAX), -i128::from(u64::MAX));
let seg_sz: usize = 16 * 1024 * 1024;
assert_eq!(Lsn(0x1000007).segment_offset(seg_sz), 7);
assert_eq!(Lsn(0x1000007).segment_number(seg_sz), 1u64);
assert_eq!(Lsn(0x4007).block_offset(), 7u64);
assert_eq!(Lsn(0x4000).block_offset(), 0u64);
assert_eq!(Lsn(0x4007).remaining_in_block(), 8185u64);
assert_eq!(Lsn(0x4000).remaining_in_block(), 8192u64);
assert_eq!(Lsn(0xffff01).calc_padding(seg_sz as u64), 255u64);
assert_eq!(Lsn(0x2000000).calc_padding(seg_sz as u64), 0u64);
assert_eq!(Lsn(0xffff01).calc_padding(8u32), 7u64);
assert_eq!(Lsn(0xffff00).calc_padding(8u32), 0u64);
}
#[test]
fn test_atomic_lsn() {
let lsn = AtomicLsn::new(0);
assert_eq!(lsn.fetch_add(1234), Lsn(0));
assert_eq!(lsn.load(), Lsn(1234));
lsn.store(Lsn(5678));
assert_eq!(lsn.load(), Lsn(5678));
assert_eq!(lsn.fetch_max(Lsn(6000)), Lsn(5678));
assert_eq!(lsn.fetch_max(Lsn(5000)), Lsn(6000));
}
}

View File

@@ -0,0 +1,17 @@
use nix::fcntl::{fcntl, OFlag, F_GETFL, F_SETFL};
use std::os::unix::io::RawFd;
/// Put a file descriptor into non-blocking mode
pub fn set_nonblock(fd: RawFd) -> Result<(), std::io::Error> {
let bits = fcntl(fd, F_GETFL)?;
// Safety: If F_GETFL returns some unknown bits, they should be valid
// for passing back to F_SETFL, too. If we left them out, the F_SETFL
// would effectively clear them, which is not what we want.
let mut flags = unsafe { OFlag::from_bits_unchecked(bits) };
flags |= OFlag::O_NONBLOCK;
fcntl(fd, F_SETFL(flags))?;
Ok(())
}

View File

@@ -0,0 +1,500 @@
//! Server-side synchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::pq_proto::{BeMessage, BeParameterStatusMessage, FeMessage, FeStartupPacket};
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
use anyhow::{bail, ensure, Context, Result};
use bytes::{Bytes, BytesMut};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::io::{self, Write};
use std::net::{Shutdown, SocketAddr, TcpStream};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tracing::*;
static PGBACKEND_SHUTDOWN_REQUESTED: AtomicBool = AtomicBool::new(false);
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupPacket) -> Result<()> {
Ok(())
}
/// Check auth md5
fn check_auth_md5(&mut self, _pgb: &mut PostgresBackend, _md5_response: &[u8]) -> Result<()> {
bail!("MD5 auth failed")
}
/// Check auth jwt
fn check_auth_jwt(&mut self, _pgb: &mut PostgresBackend, _jwt_response: &[u8]) -> Result<()> {
bail!("JWT auth failed")
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
MD5,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
ZenithJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"MD5" => Ok(Self::MD5),
"ZenithJWT" => Ok(Self::ZenithJWT),
_ => bail!("invalid value \"{}\" for auth type", s),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::MD5 => "MD5",
AuthType::ZenithJWT => "ZenithJWT",
})
}
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Bidirectional(BidiStream),
WriteOnly(WriteStream),
}
impl Stream {
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
}
}
}
impl io::Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
Self::WriteOnly(write_stream) => write_stream.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
Self::WriteOnly(write_stream) => write_stream.flush(),
}
}
}
pub struct PostgresBackend {
stream: Option<Stream>,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
pub state: ProtoState,
md5_salt: [u8; 4],
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Helper function for socket read loops
pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
for cause in error.chain() {
if let Some(io_error) = cause.downcast_ref::<io::Error>() {
if io_error.kind() == std::io::ErrorKind::WouldBlock {
return true;
}
}
}
false
}
// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations)
// PG protocol strings are always C strings.
fn cstr_to_str(b: &Bytes) -> Result<&str> {
let without_null = if b.last() == Some(&0) {
&b[..b.len() - 1]
} else {
&b[..]
};
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
set_read_timeout: bool,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
if set_read_timeout {
socket
.set_read_timeout(Some(Duration::from_secs(5)))
.unwrap();
}
Ok(Self {
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
md5_salt: [0u8; 4],
auth_type,
tls_config,
peer_addr,
})
}
pub fn into_stream(self) -> Stream {
self.stream.unwrap()
}
/// Get direct reference (into the Option) to the read stream.
fn get_stream_in(&mut self) -> Result<&mut BidiStream> {
match &mut self.stream {
Some(Stream::Bidirectional(stream)) => Ok(stream),
_ => bail!("reader taken"),
}
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
let stream = self.stream.take();
match stream {
Some(Stream::Bidirectional(bidi_stream)) => {
let (read, write) = bidi_stream.split();
self.stream = Some(Stream::WriteOnly(write));
Some(read)
}
stream => {
self.stream = stream;
None
}
}
}
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>> {
let (state, stream) = (self.state, self.get_stream_in()?);
use ProtoState::*;
match state {
Initialization | Encrypted => FeStartupPacket::read(stream),
Authentication | Established => FeMessage::read(stream),
}
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Flush output buffer into the socket.
pub fn flush(&mut self) -> io::Result<&mut Self> {
let stream = self.stream.as_mut().unwrap();
stream.write_all(&self.buf_out)?;
self.buf_out.clear();
Ok(self)
}
/// Write message into internal buffer and flush it.
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub fn run(mut self, handler: &mut impl Handler) -> Result<()> {
let ret = self.run_message_loop(handler);
if let Some(stream) = self.stream.as_mut() {
let _ = stream.shutdown(Shutdown::Both);
}
ret
}
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<()> {
trace!("postgres backend to {:?} started", self.peer_addr);
let mut unnamed_query_string = Bytes::new();
while !PGBACKEND_SHUTDOWN_REQUESTED.load(Ordering::Relaxed) {
match self.read_message() {
Ok(message) => {
if let Some(msg) = message {
trace!("got message {:?}", msg);
match self.process_message(handler, msg, &mut unnamed_query_string)? {
ProcessMsgResult::Continue => continue,
ProcessMsgResult::Break => break,
}
} else {
break;
}
}
Err(e) => {
// If it is a timeout error, continue the loop
if !is_socket_read_timed_out(&e) {
return Err(e);
}
}
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
pub fn start_tls(&mut self) -> anyhow::Result<()> {
match self.stream.take() {
Some(Stream::Bidirectional(bidi_stream)) => {
let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?;
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?));
Ok(())
}
stream => {
self.stream = stream;
bail!("can't start TLs without bidi stream");
}
}
}
fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state < ProtoState::Established {
ensure!(
matches!(
msg,
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
),
"protocol violation"
);
}
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {:?}", m);
match m {
FeStartupPacket::SslRequest => {
info!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls()?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
info!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse("must connect with TLS"))?;
bail!("client did not connect with TLS");
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion("14.1"),
))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::MD5 => {
rand::thread_rng().fill(&mut self.md5_salt);
self.write_message(&BeMessage::AuthenticationMD5Password(
self.md5_salt,
))?;
self.state = ProtoState::Authentication;
}
AuthType::ZenithJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::MD5 => {
let (_, md5_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_md5(self, md5_response) {
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
bail!("auth failed: {}", e);
}
}
AuthType::ZenithJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
bail!("auth failed: {}", e);
}
}
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
FeMessage::Query(m) => {
// remove null terminator
let query_string = cstr_to_str(&m.body)?;
trace!("got query {:?}", query_string);
// xxx distinguish fatal and recoverable errors?
if let Err(e) = handler.process_query(self, query_string) {
// ":?" uses the alternate formatting style, which makes anyhow display the
// full cause of the error, not just the top-level context + its trace.
// We don't want to send that in the ErrorResponse though,
// because it's not relevant to the compute node logs.
error!("query handler for '{}' failed: {:?}", query_string, e);
self.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string()))?;
// TODO: untangle convoluted control flow
if e.to_string().contains("failed to run") {
return Ok(ProcessMsgResult::Break);
}
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {:?}", query_string);
// xxx distinguish fatal and recoverable errors?
if let Err(e) = handler.process_query(self, query_string) {
error!("query handler for '{}' failed: {:?}", query_string, e);
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesnt require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
bail!("unexpected message type: {:?}", msg);
}
}
Ok(ProcessMsgResult::Continue)
}
}
// Set the flag to inform connections to cancel
pub fn set_pgbackend_shutdown_requested() {
PGBACKEND_SHUTDOWN_REQUESTED.swap(true, Ordering::Relaxed);
}

1068
libs/utils/src/pq_proto.rs Normal file

File diff suppressed because it is too large Load Diff

293
libs/utils/src/seqwait.rs Normal file
View File

@@ -0,0 +1,293 @@
#![warn(missing_docs)]
use std::cmp::{Eq, Ordering, PartialOrd};
use std::collections::BinaryHeap;
use std::fmt::Debug;
use std::mem;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Mutex;
use std::time::Duration;
/// An error happened while waiting for a number
#[derive(Debug, PartialEq, thiserror::Error)]
#[error("SeqWaitError")]
pub enum SeqWaitError {
/// The wait timeout was reached
Timeout,
/// [`SeqWait::shutdown`] was called
Shutdown,
}
/// Monotonically increasing value
///
/// It is handy to store some other fields under the same mutex in SeqWait<S>
/// (e.g. store prev_record_lsn). So we allow SeqWait to be parametrized with
/// any type that can expose counter. <V> is the type of exposed counter.
pub trait MonotonicCounter<V> {
/// Bump counter value and check that it goes forward
/// N.B.: new_val is an actual new value, not a difference.
fn cnt_advance(&mut self, new_val: V);
/// Get counter value
fn cnt_value(&self) -> V;
}
/// Internal components of a `SeqWait`
struct SeqWaitInt<S, V>
where
S: MonotonicCounter<V>,
V: Ord,
{
waiters: BinaryHeap<Waiter<V>>,
current: S,
shutdown: bool,
}
struct Waiter<T>
where
T: Ord,
{
wake_num: T, // wake me when this number arrives ...
wake_channel: Sender<()>, // ... by sending a message to this channel
}
// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
// to get that.
impl<T: Ord> PartialOrd for Waiter<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
other.wake_num.partial_cmp(&self.wake_num)
}
}
impl<T: Ord> Ord for Waiter<T> {
fn cmp(&self, other: &Self) -> Ordering {
other.wake_num.cmp(&self.wake_num)
}
}
impl<T: Ord> PartialEq for Waiter<T> {
fn eq(&self, other: &Self) -> bool {
other.wake_num == self.wake_num
}
}
impl<T: Ord> Eq for Waiter<T> {}
/// A tool for waiting on a sequence number
///
/// This provides a way to wait the arrival of a number.
/// As soon as the number arrives by another caller calling
/// [`advance`], then the waiter will be woken up.
///
/// This implementation takes a blocking Mutex on both [`wait_for`]
/// and [`advance`], meaning there may be unexpected executor blocking
/// due to thread scheduling unfairness. There are probably better
/// implementations, but we can probably live with this for now.
///
/// [`wait_for`]: SeqWait::wait_for
/// [`advance`]: SeqWait::advance
///
/// <S> means Storage, <V> is type of counter that this storage exposes.
///
pub struct SeqWait<S, V>
where
S: MonotonicCounter<V>,
V: Ord,
{
internal: Mutex<SeqWaitInt<S, V>>,
}
impl<S, V> SeqWait<S, V>
where
S: MonotonicCounter<V> + Copy,
V: Ord + Copy,
{
/// Create a new `SeqWait`, initialized to a particular number
pub fn new(starting_num: S) -> Self {
let internal = SeqWaitInt {
waiters: BinaryHeap::new(),
current: starting_num,
shutdown: false,
};
SeqWait {
internal: Mutex::new(internal),
}
}
/// Shut down a `SeqWait`, causing all waiters (present and
/// future) to return an error.
pub fn shutdown(&self) {
let waiters = {
// Prevent new waiters; wake all those that exist.
// Wake everyone with an error.
let mut internal = self.internal.lock().unwrap();
// This will steal the entire waiters map.
// When we drop it all waiters will be woken.
mem::take(&mut internal.waiters)
// Drop the lock as we exit this scope.
};
// When we drop the waiters list, each Receiver will
// be woken with an error.
// This drop doesn't need to be explicit; it's done
// here to make it easier to read the code and understand
// the order of events.
drop(waiters);
}
/// Wait for a number to arrive
///
/// This call won't complete until someone has called `advance`
/// with a number greater than or equal to the one we're waiting for.
pub fn wait_for(&self, num: V) -> Result<(), SeqWaitError> {
match self.queue_for_wait(num) {
Ok(None) => Ok(()),
Ok(Some(rx)) => rx.recv().map_err(|_| SeqWaitError::Shutdown),
Err(e) => Err(e),
}
}
/// Wait for a number to arrive
///
/// This call won't complete until someone has called `advance`
/// with a number greater than or equal to the one we're waiting for.
///
/// If that hasn't happened after the specified timeout duration,
/// [`SeqWaitError::Timeout`] will be returned.
pub fn wait_for_timeout(&self, num: V, timeout_duration: Duration) -> Result<(), SeqWaitError> {
match self.queue_for_wait(num) {
Ok(None) => Ok(()),
Ok(Some(rx)) => rx.recv_timeout(timeout_duration).map_err(|e| match e {
std::sync::mpsc::RecvTimeoutError::Timeout => SeqWaitError::Timeout,
std::sync::mpsc::RecvTimeoutError::Disconnected => SeqWaitError::Shutdown,
}),
Err(e) => Err(e),
}
}
/// Register and return a channel that will be notified when a number arrives,
/// or None, if it has already arrived.
fn queue_for_wait(&self, num: V) -> Result<Option<Receiver<()>>, SeqWaitError> {
let mut internal = self.internal.lock().unwrap();
if internal.current.cnt_value() >= num {
return Ok(None);
}
if internal.shutdown {
return Err(SeqWaitError::Shutdown);
}
// Create a new channel.
let (tx, rx) = channel();
internal.waiters.push(Waiter {
wake_num: num,
wake_channel: tx,
});
// Drop the lock as we exit this scope.
Ok(Some(rx))
}
/// Announce a new number has arrived
///
/// All waiters at this value or below will be woken.
///
/// Returns the old number.
pub fn advance(&self, num: V) -> V {
let old_value;
let wake_these = {
let mut internal = self.internal.lock().unwrap();
old_value = internal.current.cnt_value();
if old_value >= num {
return old_value;
}
internal.current.cnt_advance(num);
// Pop all waiters <= num from the heap. Collect them in a vector, and
// wake them up after releasing the lock.
let mut wake_these = Vec::new();
while let Some(n) = internal.waiters.peek() {
if n.wake_num > num {
break;
}
wake_these.push(internal.waiters.pop().unwrap().wake_channel);
}
wake_these
};
for tx in wake_these {
// This can fail if there are no receivers.
// We don't care; discard the error.
let _ = tx.send(());
}
old_value
}
/// Read the current value, without waiting.
pub fn load(&self) -> S {
self.internal.lock().unwrap().current
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread::sleep;
use std::thread::spawn;
use std::time::Duration;
impl MonotonicCounter<i32> for i32 {
fn cnt_advance(&mut self, val: i32) {
assert!(*self <= val);
*self = val;
}
fn cnt_value(&self) -> i32 {
*self
}
}
#[test]
fn seqwait() {
let seq = Arc::new(SeqWait::new(0));
let seq2 = Arc::clone(&seq);
let seq3 = Arc::clone(&seq);
spawn(move || {
seq2.wait_for(42).expect("wait_for 42");
let old = seq2.advance(100);
assert_eq!(old, 99);
seq2.wait_for(999).expect_err("no 999");
});
spawn(move || {
seq3.wait_for(42).expect("wait_for 42");
seq3.wait_for(0).expect("wait_for 0");
});
sleep(Duration::from_secs(1));
let old = seq.advance(99);
assert_eq!(old, 0);
seq.wait_for(100).expect("wait_for 100");
// Calling advance with a smaller value is a no-op
assert_eq!(seq.advance(98), 100);
assert_eq!(seq.load(), 100);
seq.shutdown();
}
#[test]
fn seqwait_timeout() {
let seq = Arc::new(SeqWait::new(0));
let seq2 = Arc::clone(&seq);
spawn(move || {
let timeout = Duration::from_millis(1);
let res = seq2.wait_for_timeout(42, timeout);
assert_eq!(res, Err(SeqWaitError::Timeout));
});
sleep(Duration::from_secs(1));
// This will attempt to wake, but nothing will happen
// because the waiter already dropped its Receiver.
let old = seq.advance(99);
assert_eq!(old, 0)
}
}

View File

@@ -0,0 +1,224 @@
///
/// Async version of 'seqwait.rs'
///
/// NOTE: This is currently unused. If you need this, you'll need to uncomment this in lib.rs.
///
#![warn(missing_docs)]
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::mem;
use std::sync::Mutex;
use std::time::Duration;
use tokio::sync::watch::{channel, Receiver, Sender};
use tokio::time::timeout;
/// An error happened while waiting for a number
#[derive(Debug, PartialEq, thiserror::Error)]
#[error("SeqWaitError")]
pub enum SeqWaitError {
/// The wait timeout was reached
Timeout,
/// [`SeqWait::shutdown`] was called
Shutdown,
}
/// Internal components of a `SeqWait`
struct SeqWaitInt<T>
where
T: Ord,
{
waiters: BTreeMap<T, (Sender<()>, Receiver<()>)>,
current: T,
shutdown: bool,
}
/// A tool for waiting on a sequence number
///
/// This provides a way to await the arrival of a number.
/// As soon as the number arrives by another caller calling
/// [`advance`], then the waiter will be woken up.
///
/// This implementation takes a blocking Mutex on both [`wait_for`]
/// and [`advance`], meaning there may be unexpected executor blocking
/// due to thread scheduling unfairness. There are probably better
/// implementations, but we can probably live with this for now.
///
/// [`wait_for`]: SeqWait::wait_for
/// [`advance`]: SeqWait::advance
///
pub struct SeqWait<T>
where
T: Ord,
{
internal: Mutex<SeqWaitInt<T>>,
}
impl<T> SeqWait<T>
where
T: Ord + Debug + Copy,
{
/// Create a new `SeqWait`, initialized to a particular number
pub fn new(starting_num: T) -> Self {
let internal = SeqWaitInt {
waiters: BTreeMap::new(),
current: starting_num,
shutdown: false,
};
SeqWait {
internal: Mutex::new(internal),
}
}
/// Shut down a `SeqWait`, causing all waiters (present and
/// future) to return an error.
pub fn shutdown(&self) {
let waiters = {
// Prevent new waiters; wake all those that exist.
// Wake everyone with an error.
let mut internal = self.internal.lock().unwrap();
// This will steal the entire waiters map.
// When we drop it all waiters will be woken.
mem::take(&mut internal.waiters)
// Drop the lock as we exit this scope.
};
// When we drop the waiters list, each Receiver will
// be woken with an error.
// This drop doesn't need to be explicit; it's done
// here to make it easier to read the code and understand
// the order of events.
drop(waiters);
}
/// Wait for a number to arrive
///
/// This call won't complete until someone has called `advance`
/// with a number greater than or equal to the one we're waiting for.
pub async fn wait_for(&self, num: T) -> Result<(), SeqWaitError> {
let mut rx = {
let mut internal = self.internal.lock().unwrap();
if internal.current >= num {
return Ok(());
}
if internal.shutdown {
return Err(SeqWaitError::Shutdown);
}
// If we already have a channel for waiting on this number, reuse it.
if let Some((_, rx)) = internal.waiters.get_mut(&num) {
// an Err from changed() means the sender was dropped.
rx.clone()
} else {
// Create a new channel.
let (tx, rx) = channel(());
internal.waiters.insert(num, (tx, rx.clone()));
rx
}
// Drop the lock as we exit this scope.
};
rx.changed().await.map_err(|_| SeqWaitError::Shutdown)
}
/// Wait for a number to arrive
///
/// This call won't complete until someone has called `advance`
/// with a number greater than or equal to the one we're waiting for.
///
/// If that hasn't happened after the specified timeout duration,
/// [`SeqWaitError::Timeout`] will be returned.
pub async fn wait_for_timeout(
&self,
num: T,
timeout_duration: Duration,
) -> Result<(), SeqWaitError> {
timeout(timeout_duration, self.wait_for(num))
.await
.unwrap_or(Err(SeqWaitError::Timeout))
}
/// Announce a new number has arrived
///
/// All waiters at this value or below will be woken.
///
/// `advance` will panic if you send it a lower number than
/// a previous call.
pub fn advance(&self, num: T) {
let wake_these = {
let mut internal = self.internal.lock().unwrap();
if internal.current > num {
panic!(
"tried to advance backwards, from {:?} to {:?}",
internal.current, num
);
}
internal.current = num;
// split_off will give me all the high-numbered waiters,
// so split and then swap. Everything at or above `num`
// stays.
let mut split = internal.waiters.split_off(&num);
std::mem::swap(&mut split, &mut internal.waiters);
// `split_at` didn't get the value at `num`; if it's
// there take that too.
if let Some(sleeper) = internal.waiters.remove(&num) {
split.insert(num, sleeper);
}
split
};
for (_wake_num, (tx, _rx)) in wake_these {
// This can fail if there are no receivers.
// We don't care; discard the error.
let _ = tx.send(());
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use tokio::time::{sleep, Duration};
#[tokio::test]
async fn seqwait() {
let seq = Arc::new(SeqWait::new(0));
let seq2 = Arc::clone(&seq);
let seq3 = Arc::clone(&seq);
tokio::spawn(async move {
seq2.wait_for(42).await.expect("wait_for 42");
seq2.advance(100);
seq2.wait_for(999).await.expect_err("no 999");
});
tokio::spawn(async move {
seq3.wait_for(42).await.expect("wait_for 42");
seq3.wait_for(0).await.expect("wait_for 0");
});
sleep(Duration::from_secs(1)).await;
seq.advance(99);
seq.wait_for(100).await.expect("wait_for 100");
seq.shutdown();
}
#[tokio::test]
async fn seqwait_timeout() {
let seq = Arc::new(SeqWait::new(0));
let seq2 = Arc::clone(&seq);
tokio::spawn(async move {
let timeout = Duration::from_millis(1);
let res = seq2.wait_for_timeout(42, timeout).await;
assert_eq!(res, Err(SeqWaitError::Timeout));
});
sleep(Duration::from_secs(1)).await;
// This will attempt to wake, but nothing will happen
// because the waiter already dropped its Receiver.
seq.advance(99);
}
}

View File

@@ -0,0 +1,6 @@
/// Immediately terminate the calling process without calling
/// atexit callbacks, C runtime destructors etc. We mainly use
/// this to protect coverage data from concurrent writes.
pub fn exit_now(code: u8) {
unsafe { nix::libc::_exit(code as _) };
}

59
libs/utils/src/signals.rs Normal file
View File

@@ -0,0 +1,59 @@
use signal_hook::flag;
use signal_hook::iterator::Signals;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
pub use signal_hook::consts::{signal::*, TERM_SIGNALS};
pub fn install_shutdown_handlers() -> anyhow::Result<ShutdownSignals> {
let term_now = Arc::new(AtomicBool::new(false));
for sig in TERM_SIGNALS {
// When terminated by a second term signal, exit with exit code 1.
// This will do nothing the first time (because term_now is false).
flag::register_conditional_shutdown(*sig, 1, Arc::clone(&term_now))?;
// But this will "arm" the above for the second time, by setting it to true.
// The order of registering these is important, if you put this one first, it will
// first arm and then terminate all in the first round.
flag::register(*sig, Arc::clone(&term_now))?;
}
Ok(ShutdownSignals)
}
pub enum Signal {
Quit,
Interrupt,
Terminate,
}
impl Signal {
pub fn name(&self) -> &'static str {
match self {
Signal::Quit => "SIGQUIT",
Signal::Interrupt => "SIGINT",
Signal::Terminate => "SIGTERM",
}
}
}
pub struct ShutdownSignals;
impl ShutdownSignals {
pub fn handle(
self,
mut handler: impl FnMut(Signal) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
for raw_signal in Signals::new(TERM_SIGNALS)?.into_iter() {
let signal = match raw_signal {
SIGINT => Signal::Interrupt,
SIGTERM => Signal::Terminate,
SIGQUIT => Signal::Quit,
other => panic!("unknown signal: {}", other),
};
handler(signal)?;
}
Ok(())
}
}

View File

@@ -0,0 +1,206 @@
use std::{
io::{self, BufReader, Write},
net::{Shutdown, TcpStream},
sync::Arc,
};
use rustls::Connection;
/// Wrapper supporting reads of a shared TcpStream.
pub struct ArcTcpRead(Arc<TcpStream>);
impl io::Read for ArcTcpRead {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
(&*self.0).read(buf)
}
}
impl std::ops::Deref for ArcTcpRead {
type Target = TcpStream;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
/// Wrapper around a TCP Stream supporting buffered reads.
pub struct BufStream(BufReader<ArcTcpRead>);
impl io::Read for BufStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl io::Write for BufStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.get_ref().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.get_ref().flush()
}
}
impl BufStream {
/// Unwrap into the internal BufReader.
fn into_reader(self) -> BufReader<ArcTcpRead> {
self.0
}
/// Returns a reference to the underlying TcpStream.
fn get_ref(&self) -> &TcpStream {
&*self.0.get_ref().0
}
}
pub enum ReadStream {
Tcp(BufReader<ArcTcpRead>),
Tls(rustls_split::ReadHalf),
}
impl io::Read for ReadStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(reader) => reader.read(buf),
Self::Tls(read_half) => read_half.read(buf),
}
}
}
impl ReadStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
pub enum WriteStream {
Tcp(Arc<TcpStream>),
Tls(rustls_split::WriteHalf),
}
impl WriteStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
impl io::Write for WriteStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.as_ref().write(buf),
Self::Tls(write_half) => write_half.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.as_ref().flush(),
Self::Tls(write_half) => write_half.flush(),
}
}
}
type TlsStream<T> = rustls::StreamOwned<rustls::ServerConnection, T>;
pub enum BidiStream {
Tcp(BufStream),
/// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`].
Tls(Box<TlsStream<BufStream>>),
}
impl BidiStream {
pub fn from_tcp(stream: TcpStream) -> Self {
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
}
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(tls_boxed) => {
if how == Shutdown::Read {
tls_boxed.sock.get_ref().shutdown(how)
} else {
tls_boxed.conn.send_close_notify();
let res = tls_boxed.flush();
tls_boxed.sock.get_ref().shutdown(how)?;
res
}
}
}
}
/// Split the bi-directional stream into two owned read and write halves.
pub fn split(self) -> (ReadStream, WriteStream) {
match self {
Self::Tcp(stream) => {
let reader = stream.into_reader();
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
}
Self::Tls(tls_boxed) => {
let reader = tls_boxed.sock.into_reader();
let buffer_data = reader.buffer().to_owned();
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
// TODO would be nice to avoid the Arc here
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
let (read_half, write_half) = rustls_split::split(
socket,
Connection::Server(tls_boxed.conn),
read_buf_cfg,
write_buf_cfg,
);
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
}
}
}
pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result<Self> {
match self {
Self::Tcp(mut stream) => {
conn.complete_io(&mut stream)?;
assert!(!conn.is_handshaking());
Ok(Self::Tls(Box::new(TlsStream::new(conn, stream))))
}
Self::Tls { .. } => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"TLS is already started on this stream",
)),
}
}
}
impl io::Read for BidiStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::Tls(tls_boxed) => tls_boxed.read(buf),
}
}
}
impl io::Write for BidiStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
Self::Tls(tls_boxed) => tls_boxed.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
Self::Tls(tls_boxed) => tls_boxed.flush(),
}
}
}

179
libs/utils/src/sync.rs Normal file
View File

@@ -0,0 +1,179 @@
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::{io, task};
pin_project! {
/// We use this future to mark certain methods
/// as callable in both sync and async modes.
#[repr(transparent)]
pub struct SyncFuture<S, T: Future> {
#[pin]
inner: T,
_marker: PhantomData<S>,
}
}
/// This wrapper lets us synchronously wait for inner future's completion
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
/// For instance, `S` may be substituted with types implementing
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
impl<S, T: Future> SyncFuture<S, T> {
/// NOTE: caller should carefully pick a type for `S`,
/// because we don't want to enable [`SyncFuture::wait`] when
/// it's in fact impossible to run the future synchronously.
/// Violation of this contract will not cause UB, but
/// panics and async event loop freezes won't please you.
///
/// Example:
///
/// ```
/// # use utils::sync::SyncFuture;
/// # use std::future::Future;
/// # use tokio::io::AsyncReadExt;
/// #
/// // Parse a pair of numbers from a stream
/// pub fn parse_pair<Reader>(
/// stream: &mut Reader,
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
/// where
/// Reader: tokio::io::AsyncRead + Unpin,
/// {
/// // If `Reader` is a `SyncProof`, this will give caller
/// // an opportunity to use `SyncFuture::wait`, because
/// // `.await` will always result in `Poll::Ready`.
/// SyncFuture::new(async move {
/// let x = stream.read_u32().await?;
/// let y = stream.read_u64().await?;
/// Ok((x, y))
/// })
/// }
/// ```
pub fn new(inner: T) -> Self {
Self {
inner,
_marker: PhantomData,
}
}
}
impl<S, T: Future> Future for SyncFuture<S, T> {
type Output = T::Output;
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
/// Postulates that we can call [`SyncFuture::wait`].
/// If implementer is also a [`Future`], it should always
/// return [`task::Poll::Ready`] from [`Future::poll`].
///
/// Each implementation should document which futures
/// specifically are being declared sync-proof.
pub trait SyncPostulate {}
impl<T: SyncPostulate> SyncPostulate for &T {}
impl<T: SyncPostulate> SyncPostulate for &mut T {}
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
/// Synchronously wait for future completion.
pub fn wait(mut self) -> T::Output {
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
std::ptr::null(),
&task::RawWakerVTable::new(
|_| RAW_WAKER,
|_| panic!("SyncFuture: failed to wake"),
|_| panic!("SyncFuture: failed to wake by ref"),
|_| { /* drop is no-op */ },
),
);
// SAFETY: We never move `self` during this call;
// furthermore, it will be dropped in the end regardless of panics
let this = unsafe { Pin::new_unchecked(&mut self) };
// SAFETY: This waker doesn't do anything apart from panicking
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
let context = &mut task::Context::from_waker(&waker);
match this.poll(context) {
task::Poll::Ready(res) => res,
_ => panic!("SyncFuture: unexpected pending!"),
}
}
}
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
/// NOTE: you **should not** use this in async code.
#[repr(transparent)]
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
/// and allows the future to await on any of the [`AsyncRead`]
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
#[inline(always)]
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
task::Poll::Ready(
// `Read::read` will block, meaning we don't need a real event loop!
self.0
.read(buf.initialize_unfilled())
.map(|sz| buf.advance(sz)),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
fn bytes_add<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
SyncFuture::new(async move {
let a = stream.read_u32().await?;
let b = stream.read_u32().await?;
Ok(a + b)
})
}
#[test]
fn test_sync() {
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
.wait()
.unwrap();
assert_eq!(res, 300);
}
// We need a single-threaded executor for this test
#[tokio::test(flavor = "current_thread")]
async fn test_async() {
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
let write = async move {
tx.write_u32(100).await?;
tx.write_u32(200).await?;
Ok(())
};
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
assert_eq!(res, 300);
}
}

View File

@@ -0,0 +1,16 @@
use std::{
io,
net::{TcpListener, ToSocketAddrs},
os::unix::prelude::AsRawFd,
};
use nix::sys::socket::{setsockopt, sockopt::ReuseAddr};
/// Bind a [`TcpListener`] to addr with `SO_REUSEADDR` set to true.
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<TcpListener> {
let listener = TcpListener::bind(addr)?;
setsockopt(listener.as_raw_fd(), ReuseAddr, &true)?;
Ok(listener)
}

316
libs/utils/src/vec_map.rs Normal file
View File

@@ -0,0 +1,316 @@
use std::{alloc::Layout, cmp::Ordering, ops::RangeBounds};
use serde::{Deserialize, Serialize};
/// Ordered map datastructure implemented in a Vec.
/// Append only - can only add keys that are larger than the
/// current max key.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VecMap<K, V>(Vec<(K, V)>);
impl<K, V> Default for VecMap<K, V> {
fn default() -> Self {
VecMap(Default::default())
}
}
#[derive(Debug)]
pub struct InvalidKey;
impl<K: Ord, V> VecMap<K, V> {
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn as_slice(&self) -> &[(K, V)] {
self.0.as_slice()
}
/// This function may panic if given a range where the lower bound is
/// greater than the upper bound.
pub fn slice_range<R: RangeBounds<K>>(&self, range: R) -> &[(K, V)] {
use std::ops::Bound::*;
let binary_search = |k: &K| self.0.binary_search_by_key(&k, extract_key);
let start_idx = match range.start_bound() {
Unbounded => 0,
Included(k) => binary_search(k).unwrap_or_else(std::convert::identity),
Excluded(k) => match binary_search(k) {
Ok(idx) => idx + 1,
Err(idx) => idx,
},
};
let end_idx = match range.end_bound() {
Unbounded => self.0.len(),
Included(k) => match binary_search(k) {
Ok(idx) => idx + 1,
Err(idx) => idx,
},
Excluded(k) => binary_search(k).unwrap_or_else(std::convert::identity),
};
&self.0[start_idx..end_idx]
}
/// Add a key value pair to the map.
/// If `key` is less than or equal to the current maximum key
/// the pair will not be added and InvalidKey error will be returned.
pub fn append(&mut self, key: K, value: V) -> Result<usize, InvalidKey> {
if let Some((last_key, _last_value)) = self.0.last() {
if &key <= last_key {
return Err(InvalidKey);
}
}
let delta_size = self.instrument_vec_op(|vec| vec.push((key, value)));
Ok(delta_size)
}
/// Update the maximum key value pair or add a new key value pair to the map.
/// If `key` is less than the current maximum key no updates or additions
/// will occur and InvalidKey error will be returned.
pub fn append_or_update_last(
&mut self,
key: K,
mut value: V,
) -> Result<(Option<V>, usize), InvalidKey> {
if let Some((last_key, last_value)) = self.0.last_mut() {
match key.cmp(last_key) {
Ordering::Less => return Err(InvalidKey),
Ordering::Equal => {
std::mem::swap(last_value, &mut value);
const DELTA_SIZE: usize = 0;
return Ok((Some(value), DELTA_SIZE));
}
Ordering::Greater => {}
}
}
let delta_size = self.instrument_vec_op(|vec| vec.push((key, value)));
Ok((None, delta_size))
}
/// Split the map into two.
///
/// The left map contains everything before `cutoff` (exclusive).
/// Right map contains `cutoff` and everything after (inclusive).
pub fn split_at(&self, cutoff: &K) -> (Self, Self)
where
K: Clone,
V: Clone,
{
let split_idx = self
.0
.binary_search_by_key(&cutoff, extract_key)
.unwrap_or_else(std::convert::identity);
(
VecMap(self.0[..split_idx].to_vec()),
VecMap(self.0[split_idx..].to_vec()),
)
}
/// Move items from `other` to the end of `self`, leaving `other` empty.
/// If any keys in `other` is less than or equal to any key in `self`,
/// `InvalidKey` error will be returned and no mutation will occur.
pub fn extend(&mut self, other: &mut Self) -> Result<usize, InvalidKey> {
let self_last_opt = self.0.last().map(extract_key);
let other_first_opt = other.0.last().map(extract_key);
if let (Some(self_last), Some(other_first)) = (self_last_opt, other_first_opt) {
if self_last >= other_first {
return Err(InvalidKey);
}
}
let delta_size = self.instrument_vec_op(|vec| vec.append(&mut other.0));
Ok(delta_size)
}
/// Instrument an operation on the underlying [`Vec`].
/// Will panic if the operation decreases capacity.
/// Returns the increase in memory usage caused by the op.
fn instrument_vec_op(&mut self, op: impl FnOnce(&mut Vec<(K, V)>)) -> usize {
let old_cap = self.0.capacity();
op(&mut self.0);
let new_cap = self.0.capacity();
match old_cap.cmp(&new_cap) {
Ordering::Less => {
let old_size = Layout::array::<(K, V)>(old_cap).unwrap().size();
let new_size = Layout::array::<(K, V)>(new_cap).unwrap().size();
new_size - old_size
}
Ordering::Equal => 0,
Ordering::Greater => panic!("VecMap capacity shouldn't ever decrease"),
}
}
}
fn extract_key<K, V>(entry: &(K, V)) -> &K {
&entry.0
}
#[cfg(test)]
mod tests {
use std::{collections::BTreeMap, ops::Bound};
use super::VecMap;
#[test]
fn unbounded_range() {
let mut vec = VecMap::default();
vec.append(0, ()).unwrap();
assert_eq!(vec.slice_range(0..0), &[]);
}
#[test]
#[should_panic]
fn invalid_ordering_range() {
let mut vec = VecMap::default();
vec.append(0, ()).unwrap();
#[allow(clippy::reversed_empty_ranges)]
vec.slice_range(1..0);
}
#[test]
fn range_tests() {
let mut vec = VecMap::default();
vec.append(0, ()).unwrap();
vec.append(2, ()).unwrap();
vec.append(4, ()).unwrap();
assert_eq!(vec.slice_range(0..0), &[]);
assert_eq!(vec.slice_range(0..1), &[(0, ())]);
assert_eq!(vec.slice_range(0..2), &[(0, ())]);
assert_eq!(vec.slice_range(0..3), &[(0, ()), (2, ())]);
assert_eq!(vec.slice_range(..0), &[]);
assert_eq!(vec.slice_range(..1), &[(0, ())]);
assert_eq!(vec.slice_range(..3), &[(0, ()), (2, ())]);
assert_eq!(vec.slice_range(..3), &[(0, ()), (2, ())]);
assert_eq!(vec.slice_range(0..=0), &[(0, ())]);
assert_eq!(vec.slice_range(0..=1), &[(0, ())]);
assert_eq!(vec.slice_range(0..=2), &[(0, ()), (2, ())]);
assert_eq!(vec.slice_range(0..=3), &[(0, ()), (2, ())]);
assert_eq!(vec.slice_range(..=0), &[(0, ())]);
assert_eq!(vec.slice_range(..=1), &[(0, ())]);
assert_eq!(vec.slice_range(..=2), &[(0, ()), (2, ())]);
assert_eq!(vec.slice_range(..=3), &[(0, ()), (2, ())]);
}
struct BoundIter {
min: i32,
max: i32,
next: Option<Bound<i32>>,
}
impl BoundIter {
fn new(min: i32, max: i32) -> Self {
Self {
min,
max,
next: Some(Bound::Unbounded),
}
}
}
impl Iterator for BoundIter {
type Item = Bound<i32>;
fn next(&mut self) -> Option<Self::Item> {
let cur = self.next?;
self.next = match &cur {
Bound::Unbounded => Some(Bound::Included(self.min)),
Bound::Included(x) => {
if *x >= self.max {
Some(Bound::Excluded(self.min))
} else {
Some(Bound::Included(x + 1))
}
}
Bound::Excluded(x) => {
if *x >= self.max {
None
} else {
Some(Bound::Excluded(x + 1))
}
}
};
Some(cur)
}
}
#[test]
fn range_exhaustive() {
let map: BTreeMap<i32, ()> = (1..=7).step_by(2).map(|x| (x, ())).collect();
let mut vec = VecMap::default();
for &key in map.keys() {
vec.append(key, ()).unwrap();
}
const RANGE_MIN: i32 = 0;
const RANGE_MAX: i32 = 8;
for lower_bound in BoundIter::new(RANGE_MIN, RANGE_MAX) {
let ub_min = match lower_bound {
Bound::Unbounded => RANGE_MIN,
Bound::Included(x) => x,
Bound::Excluded(x) => x + 1,
};
for upper_bound in BoundIter::new(ub_min, RANGE_MAX) {
let map_range: Vec<(i32, ())> = map
.range((lower_bound, upper_bound))
.map(|(&x, _)| (x, ()))
.collect();
let vec_slice = vec.slice_range((lower_bound, upper_bound));
assert_eq!(map_range, vec_slice);
}
}
}
#[test]
fn extend() {
let mut left = VecMap::default();
left.append(0, ()).unwrap();
assert_eq!(left.as_slice(), &[(0, ())]);
let mut empty = VecMap::default();
left.extend(&mut empty).unwrap();
assert_eq!(left.as_slice(), &[(0, ())]);
assert_eq!(empty.as_slice(), &[]);
let mut right = VecMap::default();
right.append(1, ()).unwrap();
left.extend(&mut right).unwrap();
assert_eq!(left.as_slice(), &[(0, ()), (1, ())]);
assert_eq!(right.as_slice(), &[]);
let mut zero_map = VecMap::default();
zero_map.append(0, ()).unwrap();
left.extend(&mut zero_map).unwrap_err();
assert_eq!(left.as_slice(), &[(0, ()), (1, ())]);
assert_eq!(zero_map.as_slice(), &[(0, ())]);
let mut one_map = VecMap::default();
one_map.append(1, ()).unwrap();
left.extend(&mut one_map).unwrap_err();
assert_eq!(left.as_slice(), &[(0, ()), (1, ())]);
assert_eq!(one_map.as_slice(), &[(1, ())]);
}
}

235
libs/utils/src/zid.rs Normal file
View File

@@ -0,0 +1,235 @@
use std::{fmt, str::FromStr};
use hex::FromHex;
use rand::Rng;
use serde::{Deserialize, Serialize};
/// Zenith ID is a 128-bit random ID.
/// Used to represent various identifiers. Provides handy utility methods and impls.
///
/// NOTE: It (de)serializes as an array of hex bytes, so the string representation would look
/// like `[173,80,132,115,129,226,72,254,170,201,135,108,199,26,228,24]`.
///
/// Use `#[serde_as(as = "DisplayFromStr")]` to (de)serialize it as hex string instead: `ad50847381e248feaac9876cc71ae418`.
/// Check the `serde_with::serde_as` documentation for options for more complex types.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
struct ZId([u8; 16]);
impl ZId {
pub fn get_from_buf(buf: &mut dyn bytes::Buf) -> ZId {
let mut arr = [0u8; 16];
buf.copy_to_slice(&mut arr);
ZId::from(arr)
}
pub fn as_arr(&self) -> [u8; 16] {
self.0
}
pub fn generate() -> Self {
let mut tli_buf = [0u8; 16];
rand::thread_rng().fill(&mut tli_buf);
ZId::from(tli_buf)
}
fn hex_encode(&self) -> String {
static HEX: &[u8] = b"0123456789abcdef";
let mut buf = vec![0u8; self.0.len() * 2];
for (&b, chunk) in self.0.as_ref().iter().zip(buf.chunks_exact_mut(2)) {
chunk[0] = HEX[((b >> 4) & 0xf) as usize];
chunk[1] = HEX[(b & 0xf) as usize];
}
unsafe { String::from_utf8_unchecked(buf) }
}
}
impl FromStr for ZId {
type Err = hex::FromHexError;
fn from_str(s: &str) -> Result<ZId, Self::Err> {
Self::from_hex(s)
}
}
// this is needed for pretty serialization and deserialization of ZId's using serde integration with hex crate
impl FromHex for ZId {
type Error = hex::FromHexError;
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
let mut buf: [u8; 16] = [0u8; 16];
hex::decode_to_slice(hex, &mut buf)?;
Ok(ZId(buf))
}
}
impl AsRef<[u8]> for ZId {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl From<[u8; 16]> for ZId {
fn from(b: [u8; 16]) -> Self {
ZId(b)
}
}
impl fmt::Display for ZId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.hex_encode())
}
}
impl fmt::Debug for ZId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.hex_encode())
}
}
macro_rules! zid_newtype {
($t:ident) => {
impl $t {
pub fn get_from_buf(buf: &mut dyn bytes::Buf) -> $t {
$t(ZId::get_from_buf(buf))
}
pub fn as_arr(&self) -> [u8; 16] {
self.0.as_arr()
}
pub fn generate() -> Self {
$t(ZId::generate())
}
pub const fn from_array(b: [u8; 16]) -> Self {
$t(ZId(b))
}
}
impl FromStr for $t {
type Err = hex::FromHexError;
fn from_str(s: &str) -> Result<$t, Self::Err> {
let value = ZId::from_str(s)?;
Ok($t(value))
}
}
impl From<[u8; 16]> for $t {
fn from(b: [u8; 16]) -> Self {
$t(ZId::from(b))
}
}
impl FromHex for $t {
type Error = hex::FromHexError;
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
Ok($t(ZId::from_hex(hex)?))
}
}
impl AsRef<[u8]> for $t {
fn as_ref(&self) -> &[u8] {
&self.0 .0
}
}
impl fmt::Display for $t {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl fmt::Debug for $t {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
};
}
/// Zenith timeline IDs are different from PostgreSQL timeline
/// IDs. They serve a similar purpose though: they differentiate
/// between different "histories" of the same cluster. However,
/// PostgreSQL timeline IDs are a bit cumbersome, because they are only
/// 32-bits wide, and they must be in ascending order in any given
/// timeline history. Those limitations mean that we cannot generate a
/// new PostgreSQL timeline ID by just generating a random number. And
/// that in turn is problematic for the "pull/push" workflow, where you
/// have a local copy of a zenith repository, and you periodically sync
/// the local changes with a remote server. When you work "detached"
/// from the remote server, you cannot create a PostgreSQL timeline ID
/// that's guaranteed to be different from all existing timelines in
/// the remote server. For example, if two people are having a clone of
/// the repository on their laptops, and they both create a new branch
/// with different name. What timeline ID would they assign to their
/// branches? If they pick the same one, and later try to push the
/// branches to the same remote server, they will get mixed up.
///
/// To avoid those issues, Zenith has its own concept of timelines that
/// is separate from PostgreSQL timelines, and doesn't have those
/// limitations. A zenith timeline is identified by a 128-bit ID, which
/// is usually printed out as a hex string.
///
/// NOTE: It (de)serializes as an array of hex bytes, so the string representation would look
/// like `[173,80,132,115,129,226,72,254,170,201,135,108,199,26,228,24]`.
/// See [`ZId`] for alternative ways to serialize it.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Serialize, Deserialize)]
pub struct ZTimelineId(ZId);
zid_newtype!(ZTimelineId);
/// Zenith Tenant Id represents identifiar of a particular tenant.
/// Is used for distinguishing requests and data belonging to different users.
///
/// NOTE: It (de)serializes as an array of hex bytes, so the string representation would look
/// like `[173,80,132,115,129,226,72,254,170,201,135,108,199,26,228,24]`.
/// See [`ZId`] for alternative ways to serialize it.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ZTenantId(ZId);
zid_newtype!(ZTenantId);
// A pair uniquely identifying Zenith instance.
#[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct ZTenantTimelineId {
pub tenant_id: ZTenantId,
pub timeline_id: ZTimelineId,
}
impl ZTenantTimelineId {
pub fn new(tenant_id: ZTenantId, timeline_id: ZTimelineId) -> Self {
ZTenantTimelineId {
tenant_id,
timeline_id,
}
}
pub fn generate() -> Self {
Self::new(ZTenantId::generate(), ZTimelineId::generate())
}
pub fn empty() -> Self {
Self::new(ZTenantId::from([0u8; 16]), ZTimelineId::from([0u8; 16]))
}
}
impl fmt::Display for ZTenantTimelineId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}-{}", self.tenant_id, self.timeline_id)
}
}
// Unique ID of a storage node (safekeeper or pageserver). Supposed to be issued
// by the console.
#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ZNodeId(pub u64);
impl fmt::Display for ZNodeId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}