mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-24 21:10:37 +00:00
Compare commits
5 Commits
ci-run/pr-
...
conrad/pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9cffb16463 | ||
|
|
b2e0ab5dc6 | ||
|
|
c55742b437 | ||
|
|
431a12acba | ||
|
|
fd07ecf58f |
86
Cargo.lock
generated
86
Cargo.lock
generated
@@ -2424,33 +2424,6 @@ dependencies = [
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gcp_auth"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dbf67f30198e045a039264c01fb44659ce82402d7771c50938beb41a5ac87733"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"home",
|
||||
"http 1.1.0",
|
||||
"http-body-util",
|
||||
"hyper 1.4.1",
|
||||
"hyper-rustls 0.27.5",
|
||||
"hyper-util",
|
||||
"ring",
|
||||
"rustls-pemfile 2.1.1",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-futures",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gen_ops"
|
||||
version = "0.4.0"
|
||||
@@ -2749,15 +2722,6 @@ dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "home"
|
||||
version = "0.5.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hostname"
|
||||
version = "0.4.0"
|
||||
@@ -2987,24 +2951,6 @@ dependencies = [
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-rustls"
|
||||
version = "0.27.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"http 1.1.0",
|
||||
"hyper 1.4.1",
|
||||
"hyper-util",
|
||||
"rustls 0.23.18",
|
||||
"rustls-native-certs 0.8.0",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls 0.26.0",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-timeout"
|
||||
version = "0.5.1"
|
||||
@@ -3760,16 +3706,6 @@ version = "0.3.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||
|
||||
[[package]]
|
||||
name = "mime_guess"
|
||||
version = "2.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
|
||||
dependencies = [
|
||||
"mime",
|
||||
"unicase",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minimal-lexical"
|
||||
version = "0.2.1"
|
||||
@@ -5586,11 +5522,8 @@ dependencies = [
|
||||
"bytes",
|
||||
"camino",
|
||||
"camino-tempfile",
|
||||
"chrono",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"gcp_auth",
|
||||
"http 1.1.0",
|
||||
"http-body-util",
|
||||
"http-types",
|
||||
"humantime-serde",
|
||||
@@ -5611,9 +5544,7 @@ dependencies = [
|
||||
"tokio-util",
|
||||
"toml_edit",
|
||||
"tracing",
|
||||
"url",
|
||||
"utils",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5643,7 +5574,6 @@ dependencies = [
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
@@ -7667,16 +7597,6 @@ dependencies = [
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-futures"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2"
|
||||
dependencies = [
|
||||
"pin-project",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-log"
|
||||
version = "0.2.0"
|
||||
@@ -7830,12 +7750,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicase"
|
||||
version = "2.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.17"
|
||||
|
||||
@@ -18,8 +18,7 @@ camino = { workspace = true, features = ["serde1"] }
|
||||
humantime-serde.workspace = true
|
||||
hyper = { workspace = true, features = ["client"] }
|
||||
futures.workspace = true
|
||||
reqwest = { workspace = true, features = ["multipart", "stream"] }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock"] }
|
||||
reqwest.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tokio = { workspace = true, features = ["sync", "fs", "io-util"] }
|
||||
@@ -41,10 +40,6 @@ http-types.workspace = true
|
||||
http-body-util.workspace = true
|
||||
itertools.workspace = true
|
||||
sync_wrapper = { workspace = true, features = ["futures"] }
|
||||
gcp_auth = "0.12.3"
|
||||
url.workspace = true
|
||||
http.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
camino-tempfile.workspace = true
|
||||
|
||||
@@ -41,7 +41,6 @@ impl RemoteStorageKind {
|
||||
RemoteStorageKind::LocalFs { .. } => None,
|
||||
RemoteStorageKind::AwsS3(config) => Some(&config.bucket_name),
|
||||
RemoteStorageKind::AzureContainer(config) => Some(&config.container_name),
|
||||
RemoteStorageKind::GCS(config) => Some(&config.bucket_name),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -52,7 +51,6 @@ impl RemoteStorageConfig {
|
||||
match &self.storage {
|
||||
RemoteStorageKind::LocalFs { .. } => DEFAULT_REMOTE_STORAGE_LOCALFS_CONCURRENCY_LIMIT,
|
||||
RemoteStorageKind::AwsS3(c) => c.concurrency_limit.into(),
|
||||
RemoteStorageKind::GCS(c) => c.concurrency_limit.into(),
|
||||
RemoteStorageKind::AzureContainer(c) => c.concurrency_limit.into(),
|
||||
}
|
||||
}
|
||||
@@ -87,9 +85,6 @@ pub enum RemoteStorageKind {
|
||||
/// Azure Blob based storage, storing all files in the container
|
||||
/// specified by the config
|
||||
AzureContainer(AzureConfig),
|
||||
/// Google Cloud based storage, storing all files in the GCS bucket
|
||||
/// specified by the config
|
||||
GCS(GCSConfig),
|
||||
}
|
||||
|
||||
/// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
|
||||
@@ -159,32 +154,6 @@ impl Debug for S3Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct GCSConfig {
|
||||
/// Name of the bucket to connect to.
|
||||
pub bucket_name: String,
|
||||
/// A "subfolder" in the bucket, to use the same bucket separately by multiple remote storage users at once.
|
||||
pub prefix_in_bucket: Option<String>,
|
||||
#[serde(default = "default_remote_storage_s3_concurrency_limit")]
|
||||
pub concurrency_limit: NonZeroUsize,
|
||||
#[serde(default = "default_max_keys_per_list_response")]
|
||||
pub max_keys_per_list_response: Option<i32>,
|
||||
}
|
||||
|
||||
impl Debug for GCSConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GCSConfig")
|
||||
.field("bucket_name", &self.bucket_name)
|
||||
.field("prefix_in_bucket", &self.prefix_in_bucket)
|
||||
.field("concurrency_limit", &self.concurrency_limit)
|
||||
.field(
|
||||
"max_keys_per_list_response",
|
||||
&self.max_keys_per_list_response,
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Azure bucket coordinates and access credentials to manage the bucket contents (read and write).
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AzureConfig {
|
||||
@@ -299,30 +268,6 @@ timeout = '5s'";
|
||||
);
|
||||
}
|
||||
|
||||
/// A minimal GCS section (bucket name plus prefix) must parse into a
/// `RemoteStorageKind::GCS` config with the default limits and timeouts.
#[test]
fn test_gcs_parsing() {
    let toml = "bucket_name = 'foo-bar'\nprefix_in_bucket = '/pageserver'\n";

    let expected = RemoteStorageConfig {
        storage: RemoteStorageKind::GCS(GCSConfig {
            bucket_name: "foo-bar".into(),
            prefix_in_bucket: Some("pageserver/".into()),
            max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
            concurrency_limit: std::num::NonZero::new(100).unwrap(),
        }),
        timeout: Duration::from_secs(120),
        small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT,
    };

    let config = parse(toml).unwrap();

    assert_eq!(config, expected);
}
|
||||
|
||||
#[test]
|
||||
fn test_s3_parsing() {
|
||||
let toml = "\
|
||||
|
||||
@@ -1,978 +0,0 @@
|
||||
#![allow(dead_code)]
|
||||
#![allow(unused)]
|
||||
|
||||
use crate::config::GCSConfig;
|
||||
use crate::error::Cancelled;
|
||||
pub(super) use crate::metrics::RequestKind;
|
||||
use crate::metrics::{AttemptOutcome, start_counting_cancelled_wait, start_measuring_requests};
|
||||
use crate::{
|
||||
ConcurrencyLimiter, Download, DownloadError, DownloadOpts, GCS_SCOPES, Listing, ListingMode,
|
||||
ListingObject, MAX_KEYS_PER_DELETE_GCS, REMOTE_STORAGE_PREFIX_SEPARATOR, RemotePath,
|
||||
RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
|
||||
};
|
||||
use anyhow::Context;
|
||||
use azure_core::Etag;
|
||||
use bytes::Bytes;
|
||||
use bytes::BytesMut;
|
||||
use chrono::DateTime;
|
||||
use futures::stream::Stream;
|
||||
use futures::stream::TryStreamExt;
|
||||
use futures_util::StreamExt;
|
||||
use gcp_auth::{Token, TokenProvider};
|
||||
use http::Method;
|
||||
use http::StatusCode;
|
||||
use reqwest::{Client, header};
|
||||
use scopeguard::ScopeGuard;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::num::NonZeroU32;
|
||||
use std::pin::{Pin, pin};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
use tokio_util::codec::{BytesCodec, FramedRead};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
// ---------
|
||||
pub struct GCSBucket {
|
||||
token_provider: Arc<dyn TokenProvider>,
|
||||
bucket_name: String,
|
||||
prefix_in_bucket: Option<String>,
|
||||
max_keys_per_list_response: Option<i32>,
|
||||
concurrency_limiter: ConcurrencyLimiter,
|
||||
pub timeout: Duration,
|
||||
}
|
||||
|
||||
/// Parameters of a single object download handled by `GCSBucket::get_object`.
struct GetObjectRequest {
    /// Bucket the object lives in.
    bucket: String,
    /// Full object name inside the bucket.
    key: String,
    /// Optional entity tag supplied by the caller.
    etag: Option<String>,
    /// Optional byte range supplied by the caller.
    range: Option<String>,
}
|
||||
|
||||
// ---------
|
||||
|
||||
impl GCSBucket {
|
||||
pub async fn new(remote_storage_config: &GCSConfig, timeout: Duration) -> anyhow::Result<Self> {
|
||||
tracing::debug!(
|
||||
"creating remote storage for gcs bucket {}",
|
||||
remote_storage_config.bucket_name
|
||||
);
|
||||
|
||||
// clean up 'prefix_in_bucket' if user provides '/pageserver' or 'pageserver/'
|
||||
let prefix_in_bucket = remote_storage_config
|
||||
.prefix_in_bucket
|
||||
.as_deref()
|
||||
.map(|prefix| {
|
||||
let mut prefix = prefix;
|
||||
while prefix.starts_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
|
||||
prefix = &prefix[1..];
|
||||
}
|
||||
|
||||
let mut prefix = prefix.to_string();
|
||||
if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
|
||||
prefix.pop();
|
||||
}
|
||||
|
||||
prefix
|
||||
});
|
||||
|
||||
// get GOOGLE_APPLICATION_CREDENTIALS
|
||||
let provider = gcp_auth::provider().await?;
|
||||
|
||||
Ok(GCSBucket {
|
||||
token_provider: Arc::clone(&provider),
|
||||
bucket_name: remote_storage_config.bucket_name.clone(),
|
||||
prefix_in_bucket,
|
||||
timeout,
|
||||
max_keys_per_list_response: remote_storage_config.max_keys_per_list_response,
|
||||
concurrency_limiter: ConcurrencyLimiter::new(
|
||||
remote_storage_config.concurrency_limit.get(),
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
// convert `RemotePath` -> `String`
|
||||
pub fn relative_path_to_gcs_object(&self, path: &RemotePath) -> String {
|
||||
let path_string = path.get_path().as_str();
|
||||
match &self.prefix_in_bucket {
|
||||
Some(prefix) => prefix.clone() + "/" + path_string,
|
||||
None => path_string.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
// convert `String` -> `RemotePath`
|
||||
pub fn gcs_object_to_relative_path(&self, key: &str) -> RemotePath {
|
||||
let relative_path =
|
||||
match key.strip_prefix(self.prefix_in_bucket.as_deref().unwrap_or_default()) {
|
||||
Some(stripped) => stripped,
|
||||
// we rely on GCS to return properly prefixed paths
|
||||
// for requests with a certain prefix
|
||||
None => panic!(
|
||||
"Key {} does not start with bucket prefix {:?}",
|
||||
key, self.prefix_in_bucket
|
||||
),
|
||||
};
|
||||
RemotePath(
|
||||
relative_path
|
||||
.split(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn bucket_name(&self) -> &str {
|
||||
&self.bucket_name
|
||||
}
|
||||
|
||||
fn max_keys_per_delete(&self) -> usize {
|
||||
MAX_KEYS_PER_DELETE_GCS
|
||||
}
|
||||
|
||||
async fn permit(
|
||||
&self,
|
||||
kind: RequestKind,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
|
||||
let started_at = start_counting_cancelled_wait(kind);
|
||||
let acquire = self.concurrency_limiter.acquire(kind);
|
||||
|
||||
let permit = tokio::select! {
|
||||
permit = acquire => permit.expect("semaphore is never closed"),
|
||||
_ = cancel.cancelled() => return Err(Cancelled),
|
||||
};
|
||||
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
crate::metrics::BUCKET_METRICS
|
||||
.wait_seconds
|
||||
.observe_elapsed(kind, started_at);
|
||||
|
||||
Ok(permit)
|
||||
}
|
||||
|
||||
async fn owned_permit(
|
||||
&self,
|
||||
kind: RequestKind,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<tokio::sync::OwnedSemaphorePermit, Cancelled> {
|
||||
let started_at = start_counting_cancelled_wait(kind);
|
||||
let acquire = self.concurrency_limiter.acquire_owned(kind);
|
||||
|
||||
let permit = tokio::select! {
|
||||
permit = acquire => permit.expect("semaphore is never closed"),
|
||||
_ = cancel.cancelled() => return Err(Cancelled),
|
||||
};
|
||||
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
crate::metrics::BUCKET_METRICS
|
||||
.wait_seconds
|
||||
.observe_elapsed(kind, started_at);
|
||||
Ok(permit)
|
||||
}
|
||||
|
||||
async fn put_object(
|
||||
&self,
|
||||
byte_stream: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
|
||||
fs_size: usize,
|
||||
to: &RemotePath,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
// https://cloud.google.com/storage/docs/xml-api/reference-headers#chunked
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
header::TRANSFER_ENCODING,
|
||||
header::HeaderValue::from_static("chunked"),
|
||||
);
|
||||
|
||||
// TODO Check if we need type 'multipart/related' file to attach metadata like Neon's S3
|
||||
// `.upload()` does.
|
||||
// https://cloud.google.com/storage/docs/uploading-objects#uploading-an-object
|
||||
let upload_uri = format!(
|
||||
"https://storage.googleapis.com/upload/storage/v1/b/{}/o/?uploadType=media&name={}",
|
||||
self.bucket_name.clone(),
|
||||
self.relative_path_to_gcs_object(to).trim_start_matches("/")
|
||||
);
|
||||
|
||||
let upload = Client::new()
|
||||
.post(upload_uri)
|
||||
.body(reqwest::Body::wrap_stream(byte_stream))
|
||||
.headers(headers)
|
||||
.bearer_auth(self.token_provider.token(GCS_SCOPES).await?.as_str())
|
||||
.send();
|
||||
|
||||
// We await it in a race against the Tokio timeout
|
||||
let upload = tokio::time::timeout(self.timeout, upload);
|
||||
let res = tokio::select! {
|
||||
res = upload => res,
|
||||
_ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(Ok(res)) => {
|
||||
if !res.status().is_success() {
|
||||
match res.status() {
|
||||
StatusCode::NOT_FOUND => {
|
||||
return Err(anyhow::anyhow!("GCS error: not found \n\t {:?}", res));
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"GCS PUT response contained no response body \n\t {:?}",
|
||||
res
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Ok(Err(reqw)) => Err(reqw.into()),
|
||||
Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn copy(
|
||||
&self,
|
||||
from: String,
|
||||
to: String,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let kind = RequestKind::Copy;
|
||||
|
||||
let _permit = self.permit(kind, cancel).await?;
|
||||
|
||||
let timeout = tokio::time::sleep(self.timeout);
|
||||
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let copy_uri = format!(
|
||||
"https://storage.googleapis.com/storage/v1/b/{}/o/{}/copyTo/b/{}/o/{}",
|
||||
self.bucket_name.clone(),
|
||||
&from,
|
||||
self.bucket_name.clone(),
|
||||
&to
|
||||
);
|
||||
|
||||
let op = Client::new()
|
||||
.post(copy_uri)
|
||||
.bearer_auth(self.token_provider.token(GCS_SCOPES).await?.as_str())
|
||||
.send();
|
||||
|
||||
let res = tokio::select! {
|
||||
res = op => res,
|
||||
_ = timeout => return Err(TimeoutOrCancel::Timeout.into()),
|
||||
_ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
|
||||
};
|
||||
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
crate::metrics::BUCKET_METRICS
|
||||
.req_seconds
|
||||
.observe_elapsed(kind, &res, started_at);
|
||||
|
||||
res?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_oids(
|
||||
&self,
|
||||
delete_objects: &[String],
|
||||
cancel: &CancellationToken,
|
||||
_permit: &tokio::sync::SemaphorePermit<'_>,
|
||||
) -> anyhow::Result<()> {
|
||||
let kind = RequestKind::Delete;
|
||||
let mut cancel = std::pin::pin!(cancel.cancelled());
|
||||
|
||||
for chunk in delete_objects.chunks(MAX_KEYS_PER_DELETE_GCS) {
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
// Use this to report keys that didn't delete based on 'content_id'
|
||||
let mut delete_objects_status = HashMap::new();
|
||||
|
||||
let mut form = reqwest::multipart::Form::new();
|
||||
let bulk_uri = "https://storage.googleapis.com/batch/storage/v1";
|
||||
|
||||
for (index, path) in delete_objects.iter().enumerate() {
|
||||
delete_objects_status.insert(index + 1, path.clone());
|
||||
|
||||
let path_to_delete: String =
|
||||
url::form_urlencoded::byte_serialize(path.trim_start_matches("/").as_bytes())
|
||||
.collect();
|
||||
|
||||
let delete_req = format!(
|
||||
"
|
||||
DELETE /storage/v1/b/{}/o/{} HTTP/1.1\r\n\
|
||||
Content-Type: application/json\r\n\
|
||||
accept: application/json\r\n\
|
||||
content-length: 0\r\n
|
||||
",
|
||||
self.bucket_name.clone(),
|
||||
path_to_delete
|
||||
)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
let content_id = format!("<{}+{}>", Uuid::new_v4(), index + 1);
|
||||
|
||||
let mut part_headers = header::HeaderMap::new();
|
||||
part_headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/http"),
|
||||
);
|
||||
part_headers.insert(
|
||||
header::TRANSFER_ENCODING,
|
||||
header::HeaderValue::from_static("binary"),
|
||||
);
|
||||
part_headers.insert(
|
||||
header::HeaderName::from_static("content-id"),
|
||||
header::HeaderValue::from_str(&content_id)?,
|
||||
);
|
||||
let part = reqwest::multipart::Part::text(delete_req).headers(part_headers);
|
||||
|
||||
form = form.part(format!("request-{}", index), part);
|
||||
}
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_str(&format!(
|
||||
"multipart/mixed; boundary={}",
|
||||
form.boundary()
|
||||
))?,
|
||||
);
|
||||
|
||||
let req = Client::new()
|
||||
.post(bulk_uri)
|
||||
.bearer_auth(self.token_provider.token(GCS_SCOPES).await?.as_str())
|
||||
.multipart(form)
|
||||
.headers(headers)
|
||||
.send();
|
||||
|
||||
let resp = tokio::select! {
|
||||
resp = req => resp,
|
||||
_ = tokio::time::sleep(self.timeout) => return Err(TimeoutOrCancel::Timeout.into()),
|
||||
_ = &mut cancel => return Err(TimeoutOrCancel::Cancel.into()),
|
||||
};
|
||||
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
crate::metrics::BUCKET_METRICS
|
||||
.req_seconds
|
||||
.observe_elapsed(kind, &resp, started_at);
|
||||
|
||||
let resp = resp.context("request deletion")?;
|
||||
|
||||
crate::metrics::BUCKET_METRICS
|
||||
.deleted_objects_total
|
||||
.inc_by(chunk.len() as u64);
|
||||
|
||||
let res_headers = resp.headers().to_owned();
|
||||
|
||||
let boundary = res_headers
|
||||
.get(header::CONTENT_TYPE)
|
||||
.unwrap()
|
||||
.to_str()?
|
||||
.split("=")
|
||||
.last()
|
||||
.unwrap();
|
||||
|
||||
let res_body = resp.text().await?;
|
||||
|
||||
let parsed: HashMap<String, String> = res_body
|
||||
.split(&format!("--{}", boundary))
|
||||
.filter_map(|c| {
|
||||
let mut lines = c.lines();
|
||||
|
||||
let id = lines.find_map(|line| {
|
||||
line.strip_prefix("Content-ID:")
|
||||
.and_then(|suf| suf.split('+').last())
|
||||
.and_then(|suf| suf.split('>').next())
|
||||
.map(|x| x.trim().to_string())
|
||||
});
|
||||
|
||||
let status_code = lines.find_map(|line| {
|
||||
// Not sure if this protocol version shouldn't be so specific
|
||||
line.strip_prefix("HTTP/1.1")
|
||||
.and_then(|x| x.split_whitespace().next())
|
||||
.map(|x| x.trim().to_string())
|
||||
});
|
||||
|
||||
id.zip(status_code)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Gather failures
|
||||
let errors: HashMap<usize, &String> = parsed
|
||||
.iter()
|
||||
.filter_map(|(x, y)| {
|
||||
if y.chars().next() != Some('2') {
|
||||
x.parse::<usize>().ok().map(|v| (v, y))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !errors.is_empty() {
|
||||
// Report 10 of them like S3
|
||||
const LOG_UP_TO_N_ERRORS: usize = 10;
|
||||
for (id, code) in errors.iter().take(LOG_UP_TO_N_ERRORS) {
|
||||
tracing::warn!(
|
||||
"DeleteObjects key {} failed with code: {}",
|
||||
delete_objects_status.get(id).unwrap(),
|
||||
code
|
||||
);
|
||||
}
|
||||
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to delete {}/{} objects",
|
||||
errors.len(),
|
||||
chunk.len(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_objects_v2(&self, list_uri: String) -> anyhow::Result<reqwest::RequestBuilder> {
|
||||
let res = Client::new()
|
||||
.get(list_uri)
|
||||
.bearer_auth(self.token_provider.token(GCS_SCOPES).await?.as_str());
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
// need a 'bucket', a 'key', and a bytes 'range'.
|
||||
async fn get_object(
|
||||
&self,
|
||||
request: GetObjectRequest,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<Download, DownloadError> {
|
||||
let kind = RequestKind::Get;
|
||||
|
||||
let permit = self.owned_permit(kind, cancel).await?;
|
||||
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let encoded_path: String =
|
||||
url::form_urlencoded::byte_serialize(request.key.as_bytes()).collect();
|
||||
|
||||
/// We do this in two parts:
|
||||
/// 1. Serialize the metadata of the first request to get Etag, last modified, etc
|
||||
/// 2. We do not .await the second request pass on the pinned stream to the 'get_object'
|
||||
/// caller
|
||||
// 1. Serialize Metadata in initial request
|
||||
let metadata_uri_mod = "alt=json";
|
||||
let download_uri = format!(
|
||||
"https://storage.googleapis.com/storage/v1/b/{}/o/{}?{}",
|
||||
self.bucket_name.clone(),
|
||||
encoded_path,
|
||||
metadata_uri_mod
|
||||
);
|
||||
|
||||
let res = Client::new()
|
||||
.get(download_uri)
|
||||
.bearer_auth(
|
||||
self.token_provider
|
||||
.token(GCS_SCOPES)
|
||||
.await
|
||||
.map_err(|e: gcp_auth::Error| DownloadError::Other(e.into()))?
|
||||
.as_str(),
|
||||
)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e: reqwest::Error| DownloadError::Other(e.into()))?;
|
||||
|
||||
if !res.status().is_success() {
|
||||
match res.status() {
|
||||
StatusCode::NOT_FOUND => return Err(DownloadError::NotFound),
|
||||
_ => {
|
||||
return Err(DownloadError::Other(anyhow::anyhow!(
|
||||
"GCS GET resposne contained no response body"
|
||||
)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let body = res
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e: reqwest::Error| DownloadError::Other(e.into()))?;
|
||||
|
||||
let resp: GCSObject = serde_json::from_str(&body)
|
||||
.map_err(|e: serde_json::Error| DownloadError::Other(e.into()))?;
|
||||
|
||||
// 2. Byte Stream request
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(header::RANGE, header::HeaderValue::from_static("bytes=0-"));
|
||||
|
||||
let encoded_path: String =
|
||||
url::form_urlencoded::byte_serialize(request.key.as_bytes()).collect();
|
||||
|
||||
let stream_uri_mod = "alt=media";
|
||||
let stream_uri = format!(
|
||||
"https://storage.googleapis.com/storage/v1/b/{}/o/{}?{}",
|
||||
self.bucket_name.clone(),
|
||||
encoded_path,
|
||||
stream_uri_mod
|
||||
);
|
||||
|
||||
let mut req = Client::new()
|
||||
.get(stream_uri)
|
||||
.headers(headers)
|
||||
.bearer_auth(
|
||||
self.token_provider
|
||||
.token(GCS_SCOPES)
|
||||
.await
|
||||
.map_err(|e: gcp_auth::Error| DownloadError::Other(e.into()))?
|
||||
.as_str(),
|
||||
)
|
||||
.send();
|
||||
|
||||
let get_object = tokio::select! {
|
||||
res = req => res,
|
||||
_ = tokio::time::sleep(self.timeout) => return Err(DownloadError::Timeout),
|
||||
_ = cancel.cancelled() => return Err(DownloadError::Cancelled),
|
||||
};
|
||||
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
|
||||
let object_output = match get_object {
|
||||
Ok(object_output) => {
|
||||
if !object_output.status().is_success() {
|
||||
match object_output.status() {
|
||||
StatusCode::NOT_FOUND => return Err(DownloadError::NotFound),
|
||||
_ => {
|
||||
return Err(DownloadError::Other(anyhow::anyhow!(
|
||||
"GCS GET resposne contained no response body"
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
object_output
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
crate::metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
|
||||
kind,
|
||||
AttemptOutcome::Err,
|
||||
started_at,
|
||||
);
|
||||
|
||||
return Err(DownloadError::Other(
|
||||
anyhow::Error::new(e).context("download s3 object"),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let remaining = self.timeout.saturating_sub(started_at.elapsed());
|
||||
|
||||
let metadata = resp.metadata.map(StorageMetadata);
|
||||
|
||||
let etag = resp
|
||||
.etag
|
||||
.ok_or(DownloadError::Other(anyhow::anyhow!("Missing ETag header")))?
|
||||
.into();
|
||||
|
||||
let last_modified: SystemTime = resp
|
||||
.updated
|
||||
.and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
|
||||
.map(|s| s.into())
|
||||
.unwrap_or(SystemTime::now());
|
||||
|
||||
// But let data stream pass through
|
||||
Ok(Download {
|
||||
download_stream: Box::pin(object_output.bytes_stream().map(|item| {
|
||||
item.map_err(|e: reqwest::Error| std::io::Error::new(std::io::ErrorKind::Other, e))
|
||||
})),
|
||||
etag,
|
||||
last_modified,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteStorage for GCSBucket {
|
||||
// ---------------------------------------
|
||||
// Neon wrappers for GCS client functions
|
||||
// ---------------------------------------
|
||||
|
||||
fn list_streaming(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
max_keys: Option<NonZeroU32>,
|
||||
cancel: &CancellationToken,
|
||||
) -> impl Stream<Item = Result<Listing, DownloadError>> {
|
||||
let kind = RequestKind::List;
|
||||
|
||||
let mut max_keys = max_keys.map(|mk| mk.get() as i32);
|
||||
|
||||
let list_prefix = prefix
|
||||
.map(|p| self.relative_path_to_gcs_object(p))
|
||||
.or_else(|| {
|
||||
self.prefix_in_bucket.clone().map(|mut s| {
|
||||
s.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
s
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let request_max_keys = self
|
||||
.max_keys_per_list_response
|
||||
.into_iter()
|
||||
.chain(max_keys.into_iter())
|
||||
.min()
|
||||
// https://cloud.google.com/storage/docs/json_api/v1/objects/list?hl=en#parameters
|
||||
// TODO set this to default
|
||||
.unwrap_or(1000);
|
||||
|
||||
// We pass URI in to `list_objects_v2` as we'll modify it with `NextPageToken`, hence
|
||||
// `mut`
|
||||
let mut list_uri = format!(
|
||||
"https://storage.googleapis.com/storage/v1/b/{}/o?prefix={}&maxResults={}",
|
||||
self.bucket_name.clone(),
|
||||
list_prefix,
|
||||
request_max_keys,
|
||||
);
|
||||
|
||||
// on ListingMode:
|
||||
// https://github.com/neondatabase/neon/blob/edc11253b65e12a10843711bd88ad277511396d7/libs/remote_storage/src/lib.rs#L158C1-L164C2
|
||||
if let ListingMode::WithDelimiter = mode {
|
||||
list_uri.push_str(&format!(
|
||||
"&delimiter={}",
|
||||
REMOTE_STORAGE_PREFIX_SEPARATOR.to_string()
|
||||
));
|
||||
}
|
||||
|
||||
async_stream::stream! {
|
||||
|
||||
let mut continuation_token = None;
|
||||
|
||||
'outer: loop {
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let request = self.list_objects_v2(list_uri.clone())
|
||||
.await
|
||||
.map_err(DownloadError::Other)?
|
||||
.send();
|
||||
|
||||
// this is like `await`
|
||||
let response = tokio::select! {
|
||||
res = request => Ok(res),
|
||||
_ = tokio::time::sleep(self.timeout) => Err(DownloadError::Timeout),
|
||||
_ = cancel.cancelled() => Err(DownloadError::Cancelled),
|
||||
}?;
|
||||
|
||||
// just mapping our `Result' error variant's type.
|
||||
let response = response
|
||||
.context("Failed to list GCS prefixes")
|
||||
.map_err(DownloadError::Other);
|
||||
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
|
||||
crate::metrics::BUCKET_METRICS
|
||||
.req_seconds
|
||||
.observe_elapsed(kind, &response, started_at);
|
||||
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
// The error is potentially retryable, so we must rewind the loop after yielding.
|
||||
yield Err(e);
|
||||
continue 'outer;
|
||||
},
|
||||
};
|
||||
|
||||
let body = response.text()
|
||||
.await
|
||||
.map_err(|e: reqwest::Error| DownloadError::Other(e.into()))?;
|
||||
|
||||
let resp: GCSListResponse = serde_json::from_str(&body).map_err(|e: serde_json::Error| DownloadError::Other(e.into()))?;
|
||||
|
||||
let prefixes = resp.common_prefixes();
|
||||
let keys = resp.contents();
|
||||
|
||||
tracing::debug!("list: {} prefixes, {} keys", prefixes.len(), keys.len());
|
||||
|
||||
let mut result = Listing::default();
|
||||
|
||||
for res in keys.iter() {
|
||||
|
||||
let last_modified: SystemTime = res.updated.clone()
|
||||
.and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
|
||||
.map(|s| s.into())
|
||||
.unwrap_or(SystemTime::now());
|
||||
|
||||
let size = res.size.clone().unwrap_or("0".to_string()).parse::<u64>().unwrap();
|
||||
|
||||
let key = res.name.clone();
|
||||
|
||||
result.keys.push(
|
||||
ListingObject{
|
||||
key: self.gcs_object_to_relative_path(&key),
|
||||
last_modified,
|
||||
size,
|
||||
}
|
||||
);
|
||||
|
||||
if let Some(mut mk) = max_keys {
|
||||
assert!(mk > 0);
|
||||
mk -= 1;
|
||||
if mk == 0 {
|
||||
tracing::debug!("reached limit set by max_keys");
|
||||
yield Ok(result);
|
||||
break 'outer;
|
||||
}
|
||||
max_keys = Some(mk);
|
||||
};
|
||||
}
|
||||
|
||||
result.prefixes.extend(prefixes.iter().filter_map(|p| {
|
||||
Some(
|
||||
self.gcs_object_to_relative_path(
|
||||
p.trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
),
|
||||
)
|
||||
}));
|
||||
|
||||
yield Ok(result);
|
||||
|
||||
continuation_token = match resp.next_page_token {
|
||||
Some(token) => {
|
||||
list_uri = list_uri + "&pageToken=" + &token;
|
||||
Some(token)
|
||||
},
|
||||
None => break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn head_object(
|
||||
&self,
|
||||
key: &RemotePath,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<ListingObject, DownloadError> {
|
||||
let kind = RequestKind::Head;
|
||||
|
||||
todo!();
|
||||
}
|
||||
|
||||
async fn upload(
|
||||
&self,
|
||||
from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
|
||||
from_size_bytes: usize,
|
||||
to: &RemotePath,
|
||||
metadata: Option<StorageMetadata>,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let kind = RequestKind::Put;
|
||||
let _permit = self.permit(kind, cancel).await?;
|
||||
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let upload = self.put_object(from, from_size_bytes, to, cancel);
|
||||
|
||||
let upload = tokio::time::timeout(self.timeout, upload);
|
||||
|
||||
let res = tokio::select! {
|
||||
res = upload => res,
|
||||
_ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
|
||||
};
|
||||
|
||||
if let Ok(inner) = &res {
|
||||
// do not incl. timeouts as errors in metrics but cancellations
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
crate::metrics::BUCKET_METRICS
|
||||
.req_seconds
|
||||
.observe_elapsed(kind, inner, started_at);
|
||||
}
|
||||
|
||||
match res {
|
||||
Ok(Ok(_put)) => Ok(()),
|
||||
Ok(Err(sdk)) => Err(sdk.into()),
|
||||
Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn copy(
|
||||
&self,
|
||||
from: &RemotePath,
|
||||
to: &RemotePath,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let kind = RequestKind::Copy;
|
||||
let _permit = self.permit(kind, cancel).await?;
|
||||
|
||||
let timeout = tokio::time::sleep(self.timeout);
|
||||
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
// we need to specify bucket_name as a prefix
|
||||
let copy_source = format!(
|
||||
"{}/{}",
|
||||
self.bucket_name,
|
||||
self.relative_path_to_gcs_object(from)
|
||||
);
|
||||
|
||||
todo!();
|
||||
}
|
||||
|
||||
async fn download(
|
||||
&self,
|
||||
from: &RemotePath,
|
||||
opts: &DownloadOpts,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<Download, DownloadError> {
|
||||
// if prefix is not none then download file `prefix/from`
|
||||
// if prefix is none then download file `from`
|
||||
|
||||
self.get_object(
|
||||
GetObjectRequest {
|
||||
bucket: self.bucket_name.clone(),
|
||||
key: self
|
||||
.relative_path_to_gcs_object(from)
|
||||
.trim_start_matches("/")
|
||||
.to_string(),
|
||||
etag: opts.etag.as_ref().map(|e| e.to_string()),
|
||||
range: opts.byte_range_header(),
|
||||
},
|
||||
cancel,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_objects(
|
||||
&self,
|
||||
paths: &[RemotePath],
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let kind = RequestKind::Delete;
|
||||
let permit = self.permit(kind, cancel).await?;
|
||||
|
||||
let mut delete_objects: Vec<String> = Vec::with_capacity(paths.len());
|
||||
|
||||
let delete_objects: Vec<String> = paths
|
||||
.iter()
|
||||
.map(|i| self.relative_path_to_gcs_object(i))
|
||||
.collect();
|
||||
|
||||
self.delete_oids(&delete_objects, cancel, &permit).await
|
||||
}
|
||||
|
||||
fn max_keys_per_delete(&self) -> usize {
|
||||
MAX_KEYS_PER_DELETE_GCS
|
||||
}
|
||||
|
||||
/// Deletes a single object by delegating to `delete_objects` with a
/// one-element slice borrowed from `path`.
async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
    self.delete_objects(std::array::from_ref(path), cancel)
        .await
}
|
||||
|
||||
async fn time_travel_recover(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
timestamp: SystemTime,
|
||||
done_if_after: SystemTime,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<(), TimeTravelError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct GCSListResponse {
|
||||
#[serde(rename = "nextPageToken")]
|
||||
pub next_page_token: Option<String>,
|
||||
pub items: Option<Vec<GCSObject>>,
|
||||
pub prefixes: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct GCSObject {
|
||||
pub name: String,
|
||||
pub bucket: String,
|
||||
pub generation: String,
|
||||
pub metageneration: String,
|
||||
#[serde(rename = "contentType")]
|
||||
pub content_type: Option<String>,
|
||||
#[serde(rename = "storageClass")]
|
||||
pub storage_class: String,
|
||||
pub size: Option<String>,
|
||||
#[serde(rename = "md5Hash")]
|
||||
pub md5_hash: Option<String>,
|
||||
pub crc32c: String,
|
||||
pub etag: Option<String>,
|
||||
#[serde(rename = "timeCreated")]
|
||||
pub time_created: String,
|
||||
pub updated: Option<String>,
|
||||
#[serde(rename = "timeStorageClassUpdated")]
|
||||
pub time_storage_class_updated: String,
|
||||
#[serde(rename = "timeFinalized")]
|
||||
pub time_finalized: String,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl GCSListResponse {
|
||||
pub fn contents(&self) -> &[GCSObject] {
|
||||
self.items.as_deref().unwrap_or_default()
|
||||
}
|
||||
pub fn common_prefixes(&self) -> &[String] {
|
||||
self.prefixes.as_deref().unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    use super::*;
    use gcp_auth;
    use std::num::NonZero;
    use std::pin::pin;
    use std::sync::Arc;

    const BUFFER_SIZE: usize = 32 * 1024;

    // TODO what does Neon want here for integration tests?
    const BUCKET: &str = "https://storage.googleapis.com/storage/v1/b/my-test-bucket";

    /// Lists objects under a fixed prefix and expects at least one key.
    ///
    /// NOTE(review): requires working GCP credentials and network access.
    #[tokio::test]
    async fn list_returns_keys_from_bucket() {
        let provider = gcp_auth::provider().await.unwrap();
        let gcs = GCSBucket {
            token_provider: Arc::clone(&provider),
            bucket_name: BUCKET.to_string(),
            prefix_in_bucket: None,
            max_keys_per_list_response: Some(100),
            concurrency_limiter: ConcurrencyLimiter::new(100),
            timeout: std::time::Duration::from_secs(120),
        };

        let cancel = CancellationToken::new();
        let remote_prefix = "box/tiff/2023/TN".to_string();
        let max_keys: u32 = 100;
        let mut stream = pin!(gcs.list_streaming(Some(remote_prefix), NonZero::new(max_keys)));

        // The first page seeds the accumulator; later pages are appended to it.
        let first_page = stream
            .next()
            .await
            .expect("At least one item required")
            .unwrap();
        let mut combined = first_page;
        while let Some(page) = stream.next().await {
            let page = page.unwrap();
            combined.keys.extend(page.keys);
            combined.prefixes.extend_from_slice(&page.prefixes);
        }

        for object in &combined.keys {
            println!("Item: {} -- {:?}", object.key, object.last_modified);
        }

        assert_ne!(0, combined.keys.len());
    }
}
|
||||
@@ -12,7 +12,6 @@
|
||||
mod azure_blob;
|
||||
mod config;
|
||||
mod error;
|
||||
mod gcs_bucket;
|
||||
mod local_fs;
|
||||
mod metrics;
|
||||
mod s3_bucket;
|
||||
@@ -43,7 +42,6 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
|
||||
pub use self::azure_blob::AzureBlobStorage;
|
||||
pub use self::gcs_bucket::GCSBucket;
|
||||
pub use self::local_fs::LocalFs;
|
||||
pub use self::s3_bucket::S3Bucket;
|
||||
pub use self::simulate_failures::UnreliableWrapper;
|
||||
@@ -82,12 +80,8 @@ pub const MAX_KEYS_PER_DELETE_S3: usize = 1000;
|
||||
/// <https://learn.microsoft.com/en-us/rest/api/storageservices/blob-batch>
|
||||
pub const MAX_KEYS_PER_DELETE_AZURE: usize = 256;
|
||||
|
||||
pub const MAX_KEYS_PER_DELETE_GCS: usize = 1000;
|
||||
|
||||
const REMOTE_STORAGE_PREFIX_SEPARATOR: char = '/';
|
||||
|
||||
const GCS_SCOPES: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
|
||||
/// Path on the remote storage, relative to some inner prefix.
|
||||
/// The prefix is an implementation detail, that allows representing local paths
|
||||
/// as the remote ones, stripping the local storage prefix away.
|
||||
@@ -445,7 +439,6 @@ pub enum GenericRemoteStorage<Other: Clone = Arc<UnreliableWrapper>> {
|
||||
AwsS3(Arc<S3Bucket>),
|
||||
AzureBlob(Arc<AzureBlobStorage>),
|
||||
Unreliable(Other),
|
||||
GCS(Arc<GCSBucket>),
|
||||
}
|
||||
|
||||
impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
@@ -462,7 +455,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.list(prefix, mode, max_keys, cancel).await,
|
||||
Self::AzureBlob(s) => s.list(prefix, mode, max_keys, cancel).await,
|
||||
Self::Unreliable(s) => s.list(prefix, mode, max_keys, cancel).await,
|
||||
Self::GCS(s) => s.list(prefix, mode, max_keys, cancel).await,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,7 +472,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => Box::pin(s.list_streaming(prefix, mode, max_keys, cancel)),
|
||||
Self::AzureBlob(s) => Box::pin(s.list_streaming(prefix, mode, max_keys, cancel)),
|
||||
Self::Unreliable(s) => Box::pin(s.list_streaming(prefix, mode, max_keys, cancel)),
|
||||
Self::GCS(s) => Box::pin(s.list_streaming(prefix, mode, max_keys, cancel)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -495,7 +486,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.head_object(key, cancel).await,
|
||||
Self::AzureBlob(s) => s.head_object(key, cancel).await,
|
||||
Self::Unreliable(s) => s.head_object(key, cancel).await,
|
||||
Self::GCS(_) => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -513,7 +503,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await,
|
||||
Self::AzureBlob(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await,
|
||||
Self::Unreliable(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await,
|
||||
Self::GCS(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -529,7 +518,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.download(from, opts, cancel).await,
|
||||
Self::AzureBlob(s) => s.download(from, opts, cancel).await,
|
||||
Self::Unreliable(s) => s.download(from, opts, cancel).await,
|
||||
Self::GCS(s) => s.download(from, opts, cancel).await,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -544,7 +532,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.delete(path, cancel).await,
|
||||
Self::AzureBlob(s) => s.delete(path, cancel).await,
|
||||
Self::Unreliable(s) => s.delete(path, cancel).await,
|
||||
Self::GCS(s) => s.delete(path, cancel).await,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -559,7 +546,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.delete_objects(paths, cancel).await,
|
||||
Self::AzureBlob(s) => s.delete_objects(paths, cancel).await,
|
||||
Self::Unreliable(s) => s.delete_objects(paths, cancel).await,
|
||||
Self::GCS(s) => s.delete_objects(paths, cancel).await,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -570,7 +556,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.max_keys_per_delete(),
|
||||
Self::AzureBlob(s) => s.max_keys_per_delete(),
|
||||
Self::Unreliable(s) => s.max_keys_per_delete(),
|
||||
Self::GCS(s) => s.max_keys_per_delete(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -585,7 +570,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.delete_prefix(prefix, cancel).await,
|
||||
Self::AzureBlob(s) => s.delete_prefix(prefix, cancel).await,
|
||||
Self::Unreliable(s) => s.delete_prefix(prefix, cancel).await,
|
||||
Self::GCS(_) => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -601,7 +585,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
Self::AwsS3(s) => s.copy(from, to, cancel).await,
|
||||
Self::AzureBlob(s) => s.copy(from, to, cancel).await,
|
||||
Self::Unreliable(s) => s.copy(from, to, cancel).await,
|
||||
Self::GCS(_) => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -630,25 +613,17 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
|
||||
.await
|
||||
}
|
||||
Self::GCS(_) => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GenericRemoteStorage {
|
||||
pub async fn from_config(storage_config: &RemoteStorageConfig) -> anyhow::Result<Self> {
|
||||
info!("RemoteStorageConfig: {:?}", storage_config);
|
||||
|
||||
let timeout = storage_config.timeout;
|
||||
|
||||
// If someone overrides timeout to be small without adjusting small_timeout, then adjust it automatically
|
||||
// If somkeone overrides timeout to be small without adjusting small_timeout, then adjust it automatically
|
||||
let small_timeout = std::cmp::min(storage_config.small_timeout, timeout);
|
||||
|
||||
info!(
|
||||
"RemoteStorageConfig's storage attribute: {:?}",
|
||||
storage_config.storage
|
||||
);
|
||||
|
||||
Ok(match &storage_config.storage {
|
||||
RemoteStorageKind::LocalFs { local_path: path } => {
|
||||
info!("Using fs root '{path}' as a remote storage");
|
||||
@@ -686,16 +661,6 @@ impl GenericRemoteStorage {
|
||||
small_timeout,
|
||||
)?))
|
||||
}
|
||||
RemoteStorageKind::GCS(gcs_config) => {
|
||||
let google_application_credentials =
|
||||
std::env::var("GOOGLE_APPLICATION_CREDENTIALS")
|
||||
.unwrap_or_else(|_| "<none>".into());
|
||||
info!(
|
||||
"Using gcs bucket '{}' as a remote storage, prefix in bucket: '{:?}', GOOGLE_APPLICATION_CREDENTIALS: {google_application_credentials }",
|
||||
gcs_config.bucket_name, gcs_config.prefix_in_bucket
|
||||
);
|
||||
Self::GCS(Arc::new(GCSBucket::new(gcs_config, timeout).await?))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -725,7 +690,6 @@ impl GenericRemoteStorage {
|
||||
Self::AwsS3(s) => Some(s.bucket_name()),
|
||||
Self::AzureBlob(s) => Some(s.container_name()),
|
||||
Self::Unreliable(_s) => None,
|
||||
Self::GCS(s) => Some(s.bucket_name()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,6 @@ impl UnreliableWrapper {
|
||||
GenericRemoteStorage::Unreliable(_s) => {
|
||||
panic!("Can't wrap unreliable wrapper unreliably")
|
||||
}
|
||||
GenericRemoteStorage::GCS(_) => todo!(),
|
||||
};
|
||||
UnreliableWrapper {
|
||||
inner,
|
||||
|
||||
@@ -297,7 +297,7 @@ async fn handle_client(
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
|
||||
match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
|
||||
match copy_bidirectional_client_compute(&mut tls_stream, &mut client, |_, _| {}).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
|
||||
@@ -24,6 +24,7 @@ use crate::config::{
|
||||
use crate::context::parquet::ParquetUploadArgs;
|
||||
use crate::http::health_server::AppMetrics;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::conntrack::ConnectionTracking;
|
||||
use crate::rate_limiter::{
|
||||
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
|
||||
};
|
||||
@@ -418,6 +419,8 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
64,
|
||||
));
|
||||
|
||||
let conntracking = Arc::new(ConnectionTracking::default());
|
||||
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
@@ -431,6 +434,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
conntracking.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -453,6 +457,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
conntracking.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::conntrack::ConnectionTracking;
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::proxy::passthrough::ProxyPassthrough;
|
||||
use crate::proxy::{
|
||||
@@ -25,6 +26,7 @@ pub async fn task_main(
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
conntracking: Arc<ConnectionTracking>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
@@ -50,6 +52,7 @@ pub async fn task_main(
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
let conntracking = Arc::clone(&conntracking);
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
|
||||
@@ -111,6 +114,7 @@ pub async fn task_main(
|
||||
socket,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
conntracking,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
@@ -167,6 +171,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
conntracking: Arc<ConnectionTracking>,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
@@ -264,6 +269,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
conntracking,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
}))
|
||||
|
||||
@@ -200,8 +200,10 @@ pub enum HttpDirection {
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "direction")]
|
||||
pub enum Direction {
|
||||
Tx,
|
||||
Rx,
|
||||
#[label(rename = "tx")]
|
||||
ComputeToClient,
|
||||
#[label(rename = "rx")]
|
||||
ClientToCompute,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
|
||||
564
proxy/src/proxy/conntrack.rs
Normal file
564
proxy/src/proxy/conntrack.rs
Normal file
@@ -0,0 +1,564 @@
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct ConnId(usize);
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ConnectionTracking {
|
||||
conns: clashmap::ClashMap<ConnId, (ConnectionState, SystemTime)>,
|
||||
}
|
||||
|
||||
impl ConnectionTracking {
|
||||
pub fn new_tracker(self: &Arc<Self>) -> ConnectionTracker<Conn> {
|
||||
let conn_id = self.new_conn_id();
|
||||
ConnectionTracker::new(Conn {
|
||||
conn_id,
|
||||
tracking: Arc::clone(self),
|
||||
})
|
||||
}
|
||||
|
||||
fn new_conn_id(&self) -> ConnId {
|
||||
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
let id = ConnId(NEXT_ID.fetch_add(1, Ordering::Relaxed));
|
||||
self.conns
|
||||
.insert(id, (ConnectionState::Idle, SystemTime::now()));
|
||||
id
|
||||
}
|
||||
|
||||
fn update(&self, conn_id: ConnId, new_state: ConnectionState) {
|
||||
let new_timestamp = SystemTime::now();
|
||||
let old_state = self.conns.insert(conn_id, (new_state, new_timestamp));
|
||||
|
||||
if let Some((old_state, _old_timestamp)) = old_state {
|
||||
tracing::debug!(?conn_id, %old_state, %new_state, "conntrack: update");
|
||||
} else {
|
||||
tracing::debug!(?conn_id, %new_state, "conntrack: update");
|
||||
}
|
||||
}
|
||||
|
||||
fn remove(&self, conn_id: ConnId) {
|
||||
if let Some((_, (old_state, _old_timestamp))) = self.conns.remove(&conn_id) {
|
||||
tracing::debug!(?conn_id, %old_state, "conntrack: remove");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Conn {
|
||||
conn_id: ConnId,
|
||||
tracking: Arc<ConnectionTracking>,
|
||||
}
|
||||
|
||||
impl StateChangeObserver for Conn {
|
||||
fn change(&mut self, _old_state: ConnectionState, new_state: ConnectionState) {
|
||||
match new_state {
|
||||
ConnectionState::Init
|
||||
| ConnectionState::Idle
|
||||
| ConnectionState::Transaction
|
||||
| ConnectionState::Busy
|
||||
| ConnectionState::Unknown => self.tracking.update(self.conn_id, new_state),
|
||||
ConnectionState::Closed => self.tracking.remove(self.conn_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Receiver of connection state transitions reported by a `ConnectionTracker`.
pub trait StateChangeObserver {
    /// Invoked iff the connection's state changed, with the state before and
    /// after the transition.
    fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState);
}

/// Coarse state of a tracked connection. `Init` is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum ConnectionState {
    #[default]
    Init = 0,
    Idle = 1,
    Transaction = 2,
    Busy = 3,
    Closed = 4,
    Unknown = 5,
}

impl fmt::Display for ConnectionState {
    /// Writes the lower-case name of the state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Init => "init",
            Self::Idle => "idle",
            Self::Transaction => "transaction",
            Self::Busy => "busy",
            Self::Closed => "closed",
            Self::Unknown => "unknown",
        };
        f.write_str(label)
    }
}
|
||||
|
||||
/// Tracks the `ConnectionState` of a connection by inspecting the frontend and
|
||||
/// backend stream and reacting to specific messages. Used in combination with
|
||||
/// two `TrackedStream`s.
|
||||
pub struct ConnectionTracker<SCO: StateChangeObserver> {
|
||||
state: ConnectionState,
|
||||
observer: SCO,
|
||||
}
|
||||
|
||||
impl<SCO: StateChangeObserver> Drop for ConnectionTracker<SCO> {
|
||||
fn drop(&mut self) {
|
||||
self.observer.change(self.state, ConnectionState::Closed);
|
||||
}
|
||||
}
|
||||
|
||||
impl<SCO: StateChangeObserver> ConnectionTracker<SCO> {
|
||||
pub fn new(observer: SCO) -> Self {
|
||||
ConnectionTracker {
|
||||
state: ConnectionState::default(),
|
||||
observer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn frontend_message_tag(&mut self, tag: Tag) {
|
||||
self.update_state(|old_state| Self::state_from_frontend_tag(old_state, tag));
|
||||
}
|
||||
|
||||
pub fn backend_message_tag(&mut self, tag: Tag) {
|
||||
self.update_state(|old_state| Self::state_from_backend_tag(old_state, tag));
|
||||
}
|
||||
|
||||
fn update_state(&mut self, new_state_fn: impl FnOnce(ConnectionState) -> ConnectionState) {
|
||||
let old_state = self.state;
|
||||
let new_state = new_state_fn(old_state);
|
||||
if old_state != new_state {
|
||||
self.observer.change(old_state, new_state);
|
||||
self.state = new_state;
|
||||
}
|
||||
}
|
||||
|
||||
fn state_from_frontend_tag(_old_state: ConnectionState, fe_tag: Tag) -> ConnectionState {
|
||||
// Most activity from the client puts connection into busy state.
|
||||
// Only the server can put a connection back into idle state.
|
||||
match fe_tag {
|
||||
Tag::Start | Tag::ReadyForQuery(_) | Tag::Message(_) => ConnectionState::Busy,
|
||||
Tag::End => ConnectionState::Closed,
|
||||
Tag::Lost => ConnectionState::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
fn state_from_backend_tag(old_state: ConnectionState, be_tag: Tag) -> ConnectionState {
|
||||
match be_tag {
|
||||
// Check for RFQ and put connection into idle or idle in transaction state.
|
||||
Tag::ReadyForQuery(b'I') => ConnectionState::Idle,
|
||||
Tag::ReadyForQuery(b'T') => ConnectionState::Transaction,
|
||||
Tag::ReadyForQuery(b'E') => ConnectionState::Transaction,
|
||||
// We can't put a connection into idle state for unknown RFQ status.
|
||||
Tag::ReadyForQuery(_) => ConnectionState::Unknown,
|
||||
// Ignore out-fo message from the server.
|
||||
Tag::NOTICE | Tag::NOTIFICATION_RESPONSE | Tag::PARAMETER_STATUS => old_state,
|
||||
// All other activity from server puts connection into busy state.
|
||||
Tag::Start | Tag::Message(_) => ConnectionState::Busy,
|
||||
|
||||
Tag::End => ConnectionState::Closed,
|
||||
Tag::Lost => ConnectionState::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Event reported by the stream scanner for one side of a connection.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Tag {
    /// A complete message with the given tag byte.
    Message(u8),
    /// A ReadyForQuery message carrying the given status byte.
    ReadyForQuery(u8),
    /// The initial message, which has no tag byte.
    Start,
    /// The stream ended on a message boundary.
    End,
    /// The scanner lost track of the message framing.
    Lost,
}

impl Tag {
    const READY_FOR_QUERY: Tag = Self::Message(b'Z');
    const NOTICE: Tag = Self::Message(b'N');
    const NOTIFICATION_RESPONSE: Tag = Self::Message(b'A');
    const PARAMETER_STATUS: Tag = Self::Message(b'S');
}

impl fmt::Display for Tag {
    /// Writes a short name; message tags and RFQ statuses are shown as quoted chars.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Tag::Message(tag) => write!(f, "'{}'", char::from(tag)),
            Tag::ReadyForQuery(status) => write!(f, "ReadyForQuery:'{}'", char::from(status)),
            Tag::Start => f.write_str("start"),
            Tag::End => f.write_str("end"),
            Tag::Lost => f.write_str("lost"),
        }
    }
}
|
||||
|
||||
pub trait TagObserver {
|
||||
fn observe(&mut self, tag: Tag);
|
||||
}
|
||||
|
||||
impl<F: FnMut(Tag)> TagObserver for F {
|
||||
fn observe(&mut self, tag: Tag) {
|
||||
(self)(tag);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub(super) enum StreamScannerState {
|
||||
#[allow(dead_code)]
|
||||
/// Initial state when no message has been read and we are looling for a
|
||||
/// message without a tag.
|
||||
Start,
|
||||
/// Read a message tag.
|
||||
Tag,
|
||||
/// Read the length bytes and calculate the total length.
|
||||
Length {
|
||||
tag: Tag,
|
||||
/// Number of bytes missing to know the full length of the message: 0..=4
|
||||
length_bytes_missing: usize,
|
||||
/// Total length of the message (without tag) that is calculated as we
|
||||
/// read the bytes for the length.
|
||||
calculated_length: usize,
|
||||
},
|
||||
/// Read (= skip) the payload.
|
||||
Payload {
|
||||
tag: Tag,
|
||||
/// If this is the first time payload bytes are read. Important when
|
||||
/// inspecting specific messages, like ReadyForQuery.
|
||||
first: bool,
|
||||
/// Number of payload bytes left to read before looking for a new tag.
|
||||
bytes_to_skip: usize,
|
||||
},
|
||||
/// Stream was terminated.
|
||||
End,
|
||||
/// Stream ended up in a lost state. We only stop tracking the stream, not
|
||||
/// interrupt it.
|
||||
Lost,
|
||||
}
|
||||
|
||||
impl StreamScannerState {
|
||||
pub(super) fn scan_bytes<TO: TagObserver>(&mut self, mut buf: &[u8], observer: &mut TO) {
|
||||
use StreamScannerState as S;
|
||||
|
||||
if matches!(*self, S::End | S::Lost) {
|
||||
return;
|
||||
}
|
||||
if buf.is_empty() {
|
||||
match *self {
|
||||
S::Start | S::Tag => {
|
||||
observer.observe(Tag::End);
|
||||
*self = S::End;
|
||||
return;
|
||||
}
|
||||
S::Length { .. } | S::Payload { .. } => {
|
||||
observer.observe(Tag::Lost);
|
||||
*self = S::Lost;
|
||||
return;
|
||||
}
|
||||
S::End | S::Lost => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
while !buf.is_empty() {
|
||||
match *self {
|
||||
S::Start => {
|
||||
*self = S::Length {
|
||||
tag: Tag::Start,
|
||||
length_bytes_missing: 4,
|
||||
calculated_length: 0,
|
||||
};
|
||||
}
|
||||
|
||||
S::Tag => {
|
||||
let tag = buf.first().copied().expect("buf not empty");
|
||||
buf = &buf[1..];
|
||||
|
||||
*self = S::Length {
|
||||
tag: Tag::Message(tag),
|
||||
length_bytes_missing: 4,
|
||||
calculated_length: 0,
|
||||
};
|
||||
}
|
||||
|
||||
S::Length {
|
||||
tag,
|
||||
mut length_bytes_missing,
|
||||
mut calculated_length,
|
||||
} => {
|
||||
let consume = length_bytes_missing.min(buf.len());
|
||||
|
||||
let (length_bytes, remainder) = buf.split_at(consume);
|
||||
for b in length_bytes {
|
||||
calculated_length <<= 8;
|
||||
calculated_length |= *b as usize;
|
||||
}
|
||||
buf = remainder;
|
||||
|
||||
length_bytes_missing -= consume;
|
||||
if length_bytes_missing == 0 {
|
||||
let Some(bytes_to_skip) = calculated_length.checked_sub(4) else {
|
||||
observer.observe(Tag::Lost);
|
||||
*self = S::Lost;
|
||||
return;
|
||||
};
|
||||
|
||||
if bytes_to_skip == 0 {
|
||||
observer.observe(tag);
|
||||
*self = S::Tag;
|
||||
} else {
|
||||
*self = S::Payload {
|
||||
tag,
|
||||
first: true,
|
||||
bytes_to_skip,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
*self = S::Length {
|
||||
tag,
|
||||
length_bytes_missing,
|
||||
calculated_length,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
S::Payload {
|
||||
tag,
|
||||
first,
|
||||
mut bytes_to_skip,
|
||||
} => {
|
||||
let consume = bytes_to_skip.min(buf.len());
|
||||
bytes_to_skip -= consume;
|
||||
if bytes_to_skip == 0 {
|
||||
if tag == Tag::READY_FOR_QUERY && first && consume == 1 {
|
||||
let status = buf.first().copied().expect("buf not empty");
|
||||
observer.observe(Tag::ReadyForQuery(status));
|
||||
} else {
|
||||
observer.observe(tag);
|
||||
}
|
||||
*self = S::Tag;
|
||||
} else {
|
||||
*self = S::Payload {
|
||||
tag,
|
||||
first: false,
|
||||
bytes_to_skip,
|
||||
};
|
||||
}
|
||||
buf = &buf[consume..];
|
||||
}
|
||||
|
||||
S::End | S::Lost => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::cell::RefCell;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_stream_scanner() {
|
||||
let tags = Rc::new(RefCell::new(Vec::new()));
|
||||
let observer_tags = tags.clone();
|
||||
let mut observer = move |tag| {
|
||||
observer_tags.borrow_mut().push(tag);
|
||||
};
|
||||
let mut state = StreamScannerState::Start;
|
||||
|
||||
state.scan_bytes(&[0, 0], &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[]);
|
||||
assert_eq!(
|
||||
state,
|
||||
StreamScannerState::Length {
|
||||
tag: Tag::Start,
|
||||
length_bytes_missing: 2,
|
||||
calculated_length: 0,
|
||||
}
|
||||
);
|
||||
|
||||
state.scan_bytes(&[0x01, 0x01, 0x00], &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[]);
|
||||
assert_eq!(
|
||||
state,
|
||||
StreamScannerState::Payload {
|
||||
tag: Tag::Start,
|
||||
first: false,
|
||||
bytes_to_skip: 0x00000101 - 4 - 1,
|
||||
}
|
||||
);
|
||||
|
||||
state.scan_bytes(vec![0; 0x00000101 - 4 - 1 - 1].as_slice(), &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[]);
|
||||
assert_eq!(
|
||||
state,
|
||||
StreamScannerState::Payload {
|
||||
tag: Tag::Start,
|
||||
first: false,
|
||||
bytes_to_skip: 1,
|
||||
}
|
||||
);
|
||||
|
||||
state.scan_bytes(&[0x00, b'A', 0x00, 0x00, 0x00, 0x08], &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[Tag::Start]);
|
||||
assert_eq!(
|
||||
state,
|
||||
StreamScannerState::Payload {
|
||||
tag: Tag::Message(b'A'),
|
||||
first: true,
|
||||
bytes_to_skip: 4,
|
||||
}
|
||||
);
|
||||
|
||||
state.scan_bytes(&[0, 0, 0, 0], &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[Tag::Start, Tag::Message(b'A')]);
|
||||
assert_eq!(state, StreamScannerState::Tag);
|
||||
|
||||
state.scan_bytes(&[b'Z', 0x00, 0x00, 0x00, 0x05, b'T'], &mut observer);
|
||||
assert_eq!(
|
||||
tags.borrow().as_slice(),
|
||||
&[Tag::Start, Tag::Message(b'A'), Tag::ReadyForQuery(b'T')]
|
||||
);
|
||||
assert_eq!(state, StreamScannerState::Tag);
|
||||
|
||||
state.scan_bytes(&[], &mut observer);
|
||||
assert_eq!(
|
||||
tags.borrow().as_slice(),
|
||||
&[
|
||||
Tag::Start,
|
||||
Tag::Message(b'A'),
|
||||
Tag::ReadyForQuery(b'T'),
|
||||
Tag::End
|
||||
]
|
||||
);
|
||||
assert_eq!(state, StreamScannerState::End);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connection_tracker() {
|
||||
let transitions: Arc<Mutex<Vec<(ConnectionState, ConnectionState)>>> = Arc::default();
|
||||
struct Observer(Arc<Mutex<Vec<(ConnectionState, ConnectionState)>>>);
|
||||
impl StateChangeObserver for Observer {
|
||||
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) {
|
||||
self.0.lock().unwrap().push((old_state, new_state));
|
||||
}
|
||||
}
|
||||
let mut tracker = ConnectionTracker::new(Observer(transitions.clone()));
|
||||
|
||||
let stream = BufReader::new(
|
||||
&[
|
||||
0, 0, 0, 4, // Init
|
||||
b'Z', 0, 0, 0, 5, b'I', // Init -> Idle
|
||||
b'x', 0, 0, 0, 4, // Idle -> Busy
|
||||
b'Z', 0, 0, 0, 5, b'I', // Busy -> Idle
|
||||
][..],
|
||||
);
|
||||
// AsyncRead
|
||||
let mut stream = TrackedStream::new(stream, |tag| tracker.backend_message_tag(tag));
|
||||
|
||||
let mut readbuf = [0; 2];
|
||||
let n = stream.read_exact(&mut readbuf).await.unwrap();
|
||||
assert_eq!(n, 2);
|
||||
assert_eq!(&readbuf, &[0, 0,]);
|
||||
assert!(transitions.lock().unwrap().is_empty());
|
||||
|
||||
let mut readbuf = [0; 2];
|
||||
let n = stream.read_exact(&mut readbuf).await.unwrap();
|
||||
assert_eq!(n, 2);
|
||||
assert_eq!(&readbuf, &[0, 4]);
|
||||
assert_eq!(
|
||||
transitions.lock().unwrap().as_slice(),
|
||||
&[(ConnectionState::Init, ConnectionState::Busy)]
|
||||
);
|
||||
|
||||
let mut readbuf = [0; 6];
|
||||
let n = stream.read_exact(&mut readbuf).await.unwrap();
|
||||
assert_eq!(n, 6);
|
||||
assert_eq!(&readbuf, &[b'Z', 0, 0, 0, 5, b'I']);
|
||||
assert_eq!(
|
||||
transitions.lock().unwrap().as_slice(),
|
||||
&[
|
||||
(ConnectionState::Init, ConnectionState::Busy),
|
||||
(ConnectionState::Busy, ConnectionState::Idle),
|
||||
]
|
||||
);
|
||||
|
||||
let mut readbuf = [0; 5];
|
||||
let n = stream.read_exact(&mut readbuf).await.unwrap();
|
||||
assert_eq!(n, 5);
|
||||
assert_eq!(&readbuf, &[b'x', 0, 0, 0, 4]);
|
||||
assert_eq!(
|
||||
transitions.lock().unwrap().as_slice(),
|
||||
&[
|
||||
(ConnectionState::Init, ConnectionState::Busy),
|
||||
(ConnectionState::Busy, ConnectionState::Idle),
|
||||
(ConnectionState::Idle, ConnectionState::Busy),
|
||||
]
|
||||
);
|
||||
|
||||
let mut readbuf = [0; 6];
|
||||
let n = stream.read_exact(&mut readbuf).await.unwrap();
|
||||
assert_eq!(n, 6);
|
||||
assert_eq!(&readbuf, &[b'Z', 0, 0, 0, 5, b'I']);
|
||||
assert_eq!(
|
||||
transitions.lock().unwrap().as_slice(),
|
||||
&[
|
||||
(ConnectionState::Init, ConnectionState::Busy),
|
||||
(ConnectionState::Busy, ConnectionState::Idle),
|
||||
(ConnectionState::Idle, ConnectionState::Busy),
|
||||
(ConnectionState::Busy, ConnectionState::Idle),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
pub struct TrackedStream<S, TO> {
|
||||
stream: S,
|
||||
observer: TO,
|
||||
state: StreamScannerState,
|
||||
}
|
||||
|
||||
// `TrackedStream` is `Unpin` whenever the wrapped stream is, regardless of the
// observer type; `poll_read` relies on this to call `Pin::into_inner`.
impl<S, TO> Unpin for TrackedStream<S, TO> where S: Unpin {}
|
||||
|
||||
impl<S: AsyncRead + Unpin, TO: TagObserver> TrackedStream<S, TO> {
|
||||
pub const fn new(stream: S, observer: TO) -> Self {
|
||||
TrackedStream {
|
||||
stream,
|
||||
observer,
|
||||
state: StreamScannerState::Start,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, TO: TagObserver> AsyncRead for TrackedStream<S, TO> {
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let Self {
|
||||
stream,
|
||||
observer,
|
||||
state,
|
||||
} = Pin::into_inner(self);
|
||||
|
||||
let old_len = buf.filled().len();
|
||||
match Pin::new(stream).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let new_len = buf.filled().len();
|
||||
state.scan_bytes(&buf.filled()[old_len..new_len], observer);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,8 @@ use std::task::{Context, Poll, ready};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::info;
|
||||
|
||||
use crate::metrics::Direction;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TransferState {
|
||||
Running(CopyBuffer),
|
||||
@@ -43,6 +45,8 @@ pub enum ErrorSource {
|
||||
fn transfer_one_direction<A, B>(
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut TransferState,
|
||||
direction: Direction,
|
||||
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
r: &mut A,
|
||||
w: &mut B,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
@@ -55,7 +59,8 @@ where
|
||||
loop {
|
||||
match state {
|
||||
TransferState::Running(buf) => {
|
||||
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
|
||||
let count =
|
||||
ready!(buf.poll_copy(cx, direction, conn_tracker, r.as_mut(), w.as_mut()))?;
|
||||
*state = TransferState::ShuttingDown(count);
|
||||
}
|
||||
TransferState::ShuttingDown(count) => {
|
||||
@@ -71,6 +76,7 @@ where
|
||||
pub async fn copy_bidirectional_client_compute<Client, Compute>(
|
||||
client: &mut Client,
|
||||
compute: &mut Compute,
|
||||
mut conn_tracker: impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
) -> Result<(u64, u64), ErrorSource>
|
||||
where
|
||||
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
@@ -80,12 +86,24 @@ where
|
||||
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
|
||||
|
||||
poll_fn(|cx| {
|
||||
let mut client_to_compute_result =
|
||||
transfer_one_direction(cx, &mut client_to_compute, client, compute)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
let mut compute_to_client_result =
|
||||
transfer_one_direction(cx, &mut compute_to_client, compute, client)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
let mut client_to_compute_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut client_to_compute,
|
||||
Direction::ClientToCompute,
|
||||
&mut conn_tracker,
|
||||
client,
|
||||
compute,
|
||||
)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
let mut compute_to_client_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut compute_to_client,
|
||||
Direction::ComputeToClient,
|
||||
&mut conn_tracker,
|
||||
compute,
|
||||
client,
|
||||
)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
|
||||
// TODO: 1 info log, with an enum label for close direction.
|
||||
|
||||
@@ -95,9 +113,15 @@ where
|
||||
info!("Compute is done, terminate client");
|
||||
// Initiate shutdown
|
||||
client_to_compute = TransferState::ShuttingDown(buf.amt);
|
||||
client_to_compute_result =
|
||||
transfer_one_direction(cx, &mut client_to_compute, client, compute)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
client_to_compute_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut client_to_compute,
|
||||
Direction::ClientToCompute,
|
||||
&mut conn_tracker,
|
||||
client,
|
||||
compute,
|
||||
)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,9 +131,15 @@ where
|
||||
info!("Client is done, terminate compute");
|
||||
// Initiate shutdown
|
||||
compute_to_client = TransferState::ShuttingDown(buf.amt);
|
||||
compute_to_client_result =
|
||||
transfer_one_direction(cx, &mut compute_to_client, compute, client)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
compute_to_client_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut compute_to_client,
|
||||
Direction::ComputeToClient,
|
||||
&mut conn_tracker,
|
||||
compute,
|
||||
client,
|
||||
)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,6 +178,8 @@ impl CopyBuffer {
|
||||
fn poll_fill_buf<R>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
direction: Direction,
|
||||
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
reader: Pin<&mut R>,
|
||||
) -> Poll<io::Result<()>>
|
||||
where
|
||||
@@ -158,6 +190,8 @@ impl CopyBuffer {
|
||||
buf.set_filled(me.cap);
|
||||
|
||||
let res = reader.poll_read(cx, &mut buf);
|
||||
conn_tracker(direction, &buf.filled()[me.cap..]);
|
||||
|
||||
if let Poll::Ready(Ok(())) = res {
|
||||
let filled_len = buf.filled().len();
|
||||
me.read_done = me.cap == filled_len;
|
||||
@@ -169,6 +203,8 @@ impl CopyBuffer {
|
||||
fn poll_write_buf<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
direction: Direction,
|
||||
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<usize, ErrorDirection>>
|
||||
@@ -182,7 +218,8 @@ impl CopyBuffer {
|
||||
// Top up the buffer towards full if we can read a bit more
|
||||
// data - this should improve the chances of a large write
|
||||
if !me.read_done && me.cap < me.buf.len() {
|
||||
ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
|
||||
ready!(me.poll_fill_buf(cx, direction, conn_tracker, reader.as_mut()))
|
||||
.map_err(ErrorDirection::Read)?;
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
@@ -193,6 +230,8 @@ impl CopyBuffer {
|
||||
pub(super) fn poll_copy<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
direction: Direction,
|
||||
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
@@ -204,7 +243,7 @@ impl CopyBuffer {
|
||||
// If there is some space left in our buffer, then we try to read some
|
||||
// data to continue, thus maximizing the chances of a large write.
|
||||
if self.cap < self.buf.len() && !self.read_done {
|
||||
match self.poll_fill_buf(cx, reader.as_mut()) {
|
||||
match self.poll_fill_buf(cx, direction, conn_tracker, reader.as_mut()) {
|
||||
Poll::Ready(Ok(())) => (),
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
|
||||
Poll::Pending => {
|
||||
@@ -227,7 +266,13 @@ impl CopyBuffer {
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
|
||||
let i = ready!(self.poll_write_buf(
|
||||
cx,
|
||||
direction,
|
||||
conn_tracker,
|
||||
reader.as_mut(),
|
||||
writer.as_mut()
|
||||
))?;
|
||||
if i == 0 {
|
||||
return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
@@ -278,9 +323,10 @@ mod tests {
|
||||
compute_client.write_all(b"Neon").await.unwrap();
|
||||
compute_client.shutdown().await.unwrap();
|
||||
|
||||
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
|
||||
.await
|
||||
.unwrap();
|
||||
let result =
|
||||
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
@@ -301,9 +347,10 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
|
||||
.await
|
||||
.unwrap();
|
||||
let result =
|
||||
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
mod tests;
|
||||
|
||||
pub(crate) mod connect_compute;
|
||||
pub mod conntrack;
|
||||
mod copy_bidirectional;
|
||||
pub(crate) mod handshake;
|
||||
pub(crate) mod passthrough;
|
||||
@@ -30,6 +31,7 @@ use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::conntrack::ConnectionTracking;
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
@@ -60,6 +62,7 @@ pub async fn task_main(
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conntracking: Arc<ConnectionTracking>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
@@ -85,6 +88,7 @@ pub async fn task_main(
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
let conntracking = Arc::clone(&conntracking);
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
@@ -149,6 +153,7 @@ pub async fn task_main(
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
conntracking,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
@@ -268,6 +273,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
conntracking: Arc<ConnectionTracking>,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
@@ -409,6 +415,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
conntracking,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
}))
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
@@ -9,16 +10,19 @@ use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::proxy::conntrack::{ConnectionTracking, StreamScannerState};
|
||||
use crate::proxy::copy_bidirectional::copy_bidirectional_client_compute;
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
mut client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
mut compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: MetricsAuxInfo,
|
||||
private_link_id: Option<SmolStr>,
|
||||
conntracking: &Arc<ConnectionTracking>,
|
||||
) -> Result<(), ErrorSource> {
|
||||
// we will report ingress at a later date
|
||||
let usage_tx = USAGE_METRICS.register(Ids {
|
||||
@@ -27,35 +31,35 @@ pub(crate) async fn proxy_pass(
|
||||
private_link_id,
|
||||
});
|
||||
|
||||
let mut conn_tracker = conntracking.new_tracker();
|
||||
|
||||
let metrics = &Metrics::get().proxy.io_bytes;
|
||||
let m_sent = metrics.with_labels(Direction::Tx);
|
||||
let mut client = MeasuredStream::new(
|
||||
client,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
// Number of bytes we sent to the client (outbound).
|
||||
metrics.get_metric(m_sent).inc_by(cnt as u64);
|
||||
usage_tx.record_egress(cnt as u64);
|
||||
},
|
||||
);
|
||||
let m_sent = metrics.with_labels(Direction::ComputeToClient);
|
||||
let m_recv = metrics.with_labels(Direction::ClientToCompute);
|
||||
|
||||
let m_recv = metrics.with_labels(Direction::Rx);
|
||||
let mut compute = MeasuredStream::new(
|
||||
compute,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
// Number of bytes the client sent to the compute node (inbound).
|
||||
metrics.get_metric(m_recv).inc_by(cnt as u64);
|
||||
usage_tx.record_ingress(cnt as u64);
|
||||
},
|
||||
);
|
||||
let mut client_to_compute = StreamScannerState::Tag;
|
||||
let mut compute_to_client = StreamScannerState::Tag;
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
debug!("performing the proxy pass...");
|
||||
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
|
||||
&mut client,
|
||||
&mut compute,
|
||||
)
|
||||
|
||||
let _ = copy_bidirectional_client_compute(&mut client, &mut compute, |direction, bytes| {
|
||||
match direction {
|
||||
Direction::ClientToCompute => {
|
||||
client_to_compute
|
||||
.scan_bytes(bytes, &mut |tag| conn_tracker.frontend_message_tag(tag));
|
||||
|
||||
metrics.get_metric(m_recv).inc_by(bytes.len() as u64);
|
||||
usage_tx.record_ingress(bytes.len() as u64);
|
||||
}
|
||||
Direction::ComputeToClient => {
|
||||
compute_to_client
|
||||
.scan_bytes(bytes, &mut |tag| conn_tracker.backend_message_tag(tag));
|
||||
|
||||
metrics.get_metric(m_sent).inc_by(bytes.len() as u64);
|
||||
usage_tx.record_egress(bytes.len() as u64);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
@@ -68,6 +72,7 @@ pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
pub(crate) cancel: cancellation::Session,
|
||||
pub(crate) conntracking: Arc<ConnectionTracking>,
|
||||
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
@@ -83,6 +88,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
self.compute.stream,
|
||||
self.aux,
|
||||
self.private_link_id,
|
||||
&self.conntracking,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = self
|
||||
|
||||
@@ -50,6 +50,7 @@ use crate::context::RequestContext;
|
||||
use crate::ext::TaskExt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::{ChainRW, ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::conntrack::ConnectionTracking;
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
@@ -124,6 +125,9 @@ pub async fn task_main(
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
let conntracking = Arc::new(ConnectionTracking::default());
|
||||
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
if let Err(e) = conn.set_nodelay(true) {
|
||||
@@ -153,6 +157,8 @@ pub async fn task_main(
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellations = cancellations.clone();
|
||||
let conntracking = Arc::clone(&conntracking);
|
||||
|
||||
connections.spawn(
|
||||
async move {
|
||||
let conn_token2 = conn_token.clone();
|
||||
@@ -185,6 +191,7 @@ pub async fn task_main(
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
conntracking,
|
||||
conn,
|
||||
conn_info,
|
||||
session_id,
|
||||
@@ -309,6 +316,7 @@ async fn connection_handler(
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
conntracking: Arc<ConnectionTracking>,
|
||||
conn: AsyncRW,
|
||||
conn_info: ConnectionInfo,
|
||||
session_id: uuid::Uuid,
|
||||
@@ -347,6 +355,7 @@ async fn connection_handler(
|
||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||
let cancellations = cancellations.clone();
|
||||
let conntracking = Arc::clone(&conntracking);
|
||||
let handler = connections.spawn(
|
||||
request_handler(
|
||||
req,
|
||||
@@ -359,6 +368,7 @@ async fn connection_handler(
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellations,
|
||||
conntracking,
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
@@ -407,6 +417,7 @@ async fn request_handler(
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellations: TaskTracker,
|
||||
conntracking: Arc<ConnectionTracking>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -452,6 +463,7 @@ async fn request_handler(
|
||||
endpoint_rate_limiter,
|
||||
host,
|
||||
cancellations,
|
||||
conntracking,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -17,6 +17,7 @@ use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::conntrack::ConnectionTracking;
|
||||
use crate::proxy::{ClientMode, ErrorSource, handle_client};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
|
||||
@@ -133,6 +134,7 @@ pub(crate) async fn serve_websocket(
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
conntracking: Arc<ConnectionTracking>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
@@ -152,6 +154,7 @@ pub(crate) async fn serve_websocket(
|
||||
endpoint_rate_limiter,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
conntracking,
|
||||
))
|
||||
.await;
|
||||
|
||||
|
||||
@@ -271,7 +271,6 @@ impl BucketConfig {
|
||||
"container {}, storage account {:?}, region {}",
|
||||
config.container_name, config.storage_account, config.container_region
|
||||
),
|
||||
RemoteStorageKind::GCS(config) => format!("bucket {}", config.bucket_name),
|
||||
}
|
||||
}
|
||||
pub fn bucket_name(&self) -> Option<&str> {
|
||||
@@ -419,9 +418,6 @@ async fn init_remote(
|
||||
config.prefix_in_container.get_or_insert(default_prefix);
|
||||
}
|
||||
RemoteStorageKind::LocalFs { .. } => (),
|
||||
RemoteStorageKind::GCS(config) => {
|
||||
config.prefix_in_bucket.get_or_insert(default_prefix);
|
||||
}
|
||||
}
|
||||
|
||||
// We already pass the prefix to the remote client above
|
||||
|
||||
Reference in New Issue
Block a user