Compare commits

..

1 Commits

Author SHA1 Message Date
John Crowley
d4326bfaf4 Initial implementation of GCS provider. 2025-04-22 13:46:57 -05:00
15 changed files with 1483 additions and 287 deletions

View File

@@ -133,7 +133,6 @@ runs:
fi
PERF_REPORT_DIR="$(realpath test_runner/perf-report-local)"
echo "PERF_REPORT_DIR=${PERF_REPORT_DIR}" >> ${GITHUB_ENV}
rm -rf $PERF_REPORT_DIR
TEST_SELECTION="test_runner/${{ inputs.test_selection }}"
@@ -210,12 +209,11 @@ runs:
--verbose \
-rA $TEST_SELECTION $EXTRA_PARAMS
- name: Upload performance report
if: ${{ !cancelled() && inputs.save_perf_report == 'true' }}
shell: bash -euxo pipefail {0}
run: |
export REPORT_FROM="${PERF_REPORT_DIR}"
scripts/generate_and_push_perf_report.sh
if [[ "${{ inputs.save_perf_report }}" == "true" ]]; then
export REPORT_FROM="$PERF_REPORT_DIR"
export REPORT_TO="$PLATFORM"
scripts/generate_and_push_perf_report.sh
fi
- name: Upload compatibility snapshot
# Note, that we use `github.base_ref` which is a target branch for a PR

86
Cargo.lock generated
View File

@@ -2424,6 +2424,33 @@ dependencies = [
"slab",
]
[[package]]
name = "gcp_auth"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbf67f30198e045a039264c01fb44659ce82402d7771c50938beb41a5ac87733"
dependencies = [
"async-trait",
"base64 0.22.1",
"bytes",
"chrono",
"home",
"http 1.1.0",
"http-body-util",
"hyper 1.4.1",
"hyper-rustls 0.27.5",
"hyper-util",
"ring",
"rustls-pemfile 2.1.1",
"serde",
"serde_json",
"thiserror 1.0.69",
"tokio",
"tracing",
"tracing-futures",
"url",
]
[[package]]
name = "gen_ops"
version = "0.4.0"
@@ -2722,6 +2749,15 @@ dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "hostname"
version = "0.4.0"
@@ -2951,6 +2987,24 @@ dependencies = [
"tower-service",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http 1.1.0",
"hyper 1.4.1",
"hyper-util",
"rustls 0.23.18",
"rustls-native-certs 0.8.0",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.0",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.1"
@@ -3706,6 +3760,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@@ -5522,8 +5586,11 @@ dependencies = [
"bytes",
"camino",
"camino-tempfile",
"chrono",
"futures",
"futures-util",
"gcp_auth",
"http 1.1.0",
"http-body-util",
"http-types",
"humantime-serde",
@@ -5544,7 +5611,9 @@ dependencies = [
"tokio-util",
"toml_edit",
"tracing",
"url",
"utils",
"uuid",
]
[[package]]
@@ -5574,6 +5643,7 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
@@ -7597,6 +7667,16 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "tracing-futures"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2"
dependencies = [
"pin-project",
"tracing",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
@@ -7750,6 +7830,12 @@ dependencies = [
"libc",
]
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-bidi"
version = "0.3.17"

View File

@@ -1677,7 +1677,7 @@ RUN set -e \
&& apt clean && rm -rf /var/lib/apt/lists/*
# Use `dist_man_MANS=` to skip manpage generation (which requires python3/pandoc)
ENV PGBOUNCER_TAG=pgbouncer_1_24_1
ENV PGBOUNCER_TAG=pgbouncer_1_22_1
RUN set -e \
&& git clone --recurse-submodules --depth 1 --branch ${PGBOUNCER_TAG} https://github.com/pgbouncer/pgbouncer.git pgbouncer \
&& cd pgbouncer \

View File

@@ -18,7 +18,8 @@ camino = { workspace = true, features = ["serde1"] }
humantime-serde.workspace = true
hyper = { workspace = true, features = ["client"] }
futures.workspace = true
reqwest.workspace = true
reqwest = { workspace = true, features = ["multipart", "stream"] }
chrono = { version = "0.4", default-features = false, features = ["clock"] }
serde.workspace = true
serde_json.workspace = true
tokio = { workspace = true, features = ["sync", "fs", "io-util"] }
@@ -40,6 +41,10 @@ http-types.workspace = true
http-body-util.workspace = true
itertools.workspace = true
sync_wrapper = { workspace = true, features = ["futures"] }
gcp_auth = "0.12.3"
url.workspace = true
http.workspace = true
uuid.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true

View File

@@ -41,6 +41,7 @@ impl RemoteStorageKind {
RemoteStorageKind::LocalFs { .. } => None,
RemoteStorageKind::AwsS3(config) => Some(&config.bucket_name),
RemoteStorageKind::AzureContainer(config) => Some(&config.container_name),
RemoteStorageKind::GCS(config) => Some(&config.bucket_name),
}
}
}
@@ -51,6 +52,7 @@ impl RemoteStorageConfig {
match &self.storage {
RemoteStorageKind::LocalFs { .. } => DEFAULT_REMOTE_STORAGE_LOCALFS_CONCURRENCY_LIMIT,
RemoteStorageKind::AwsS3(c) => c.concurrency_limit.into(),
RemoteStorageKind::GCS(c) => c.concurrency_limit.into(),
RemoteStorageKind::AzureContainer(c) => c.concurrency_limit.into(),
}
}
@@ -85,6 +87,9 @@ pub enum RemoteStorageKind {
/// Azure Blob based storage, storing all files in the container
/// specified by the config
AzureContainer(AzureConfig),
/// Google Cloud based storage, storing all files in the GCS bucket
/// specified by the config
GCS(GCSConfig),
}
/// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
@@ -154,6 +159,32 @@ impl Debug for S3Config {
}
}
#[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct GCSConfig {
/// Name of the bucket to connect to.
pub bucket_name: String,
/// A "subfolder" in the bucket, to use the same bucket separately by multiple remote storage users at once.
pub prefix_in_bucket: Option<String>,
#[serde(default = "default_remote_storage_s3_concurrency_limit")]
pub concurrency_limit: NonZeroUsize,
#[serde(default = "default_max_keys_per_list_response")]
pub max_keys_per_list_response: Option<i32>,
}
impl Debug for GCSConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GCSConfig")
.field("bucket_name", &self.bucket_name)
.field("prefix_in_bucket", &self.prefix_in_bucket)
.field("concurrency_limit", &self.concurrency_limit)
.field(
"max_keys_per_list_response",
&self.max_keys_per_list_response,
)
.finish()
}
}
/// Azure bucket coordinates and access credentials to manage the bucket contents (read and write).
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AzureConfig {
@@ -268,6 +299,30 @@ timeout = '5s'";
);
}
#[test]
fn test_gcs_parsing() {
let toml = "\
bucket_name = 'foo-bar'
prefix_in_bucket = '/pageserver'
";
let config = parse(toml).unwrap();
assert_eq!(
config,
RemoteStorageConfig {
storage: RemoteStorageKind::GCS(GCSConfig {
bucket_name: "foo-bar".into(),
prefix_in_bucket: Some("pageserver/".into()),
max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
concurrency_limit: std::num::NonZero::new(100).unwrap(),
}),
timeout: Duration::from_secs(120),
small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT
}
);
}
#[test]
fn test_s3_parsing() {
let toml = "\

View File

@@ -0,0 +1,978 @@
#![allow(dead_code)]
#![allow(unused)]
use crate::config::GCSConfig;
use crate::error::Cancelled;
pub(super) use crate::metrics::RequestKind;
use crate::metrics::{AttemptOutcome, start_counting_cancelled_wait, start_measuring_requests};
use crate::{
ConcurrencyLimiter, Download, DownloadError, DownloadOpts, GCS_SCOPES, Listing, ListingMode,
ListingObject, MAX_KEYS_PER_DELETE_GCS, REMOTE_STORAGE_PREFIX_SEPARATOR, RemotePath,
RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
};
use anyhow::Context;
use azure_core::Etag;
use bytes::Bytes;
use bytes::BytesMut;
use chrono::DateTime;
use futures::stream::Stream;
use futures::stream::TryStreamExt;
use futures_util::StreamExt;
use gcp_auth::{Token, TokenProvider};
use http::Method;
use http::StatusCode;
use reqwest::{Client, header};
use scopeguard::ScopeGuard;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use std::num::NonZeroU32;
use std::pin::{Pin, pin};
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use tokio_util::codec::{BytesCodec, FramedRead};
use tokio_util::sync::CancellationToken;
use tracing;
use url::Url;
use uuid::Uuid;
// ---------
pub struct GCSBucket {
token_provider: Arc<dyn TokenProvider>,
bucket_name: String,
prefix_in_bucket: Option<String>,
max_keys_per_list_response: Option<i32>,
concurrency_limiter: ConcurrencyLimiter,
pub timeout: Duration,
}
struct GetObjectRequest {
bucket: String,
key: String,
etag: Option<String>,
range: Option<String>,
}
// ---------
impl GCSBucket {
pub async fn new(remote_storage_config: &GCSConfig, timeout: Duration) -> anyhow::Result<Self> {
tracing::debug!(
"creating remote storage for gcs bucket {}",
remote_storage_config.bucket_name
);
// clean up 'prefix_in_bucket' if user provides '/pageserver' or 'pageserver/'
let prefix_in_bucket = remote_storage_config
.prefix_in_bucket
.as_deref()
.map(|prefix| {
let mut prefix = prefix;
while prefix.starts_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
prefix = &prefix[1..];
}
let mut prefix = prefix.to_string();
if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
prefix.pop();
}
prefix
});
// get GOOGLE_APPLICATION_CREDENTIALS
let provider = gcp_auth::provider().await?;
Ok(GCSBucket {
token_provider: Arc::clone(&provider),
bucket_name: remote_storage_config.bucket_name.clone(),
prefix_in_bucket,
timeout,
max_keys_per_list_response: remote_storage_config.max_keys_per_list_response,
concurrency_limiter: ConcurrencyLimiter::new(
remote_storage_config.concurrency_limit.get(),
),
})
}
// convert `RemotePath` -> `String`
pub fn relative_path_to_gcs_object(&self, path: &RemotePath) -> String {
let path_string = path.get_path().as_str();
match &self.prefix_in_bucket {
Some(prefix) => prefix.clone() + "/" + path_string,
None => path_string.to_string(),
}
}
// convert `String` -> `RemotePath`
pub fn gcs_object_to_relative_path(&self, key: &str) -> RemotePath {
let relative_path =
match key.strip_prefix(self.prefix_in_bucket.as_deref().unwrap_or_default()) {
Some(stripped) => stripped,
// we rely on GCS to return properly prefixed paths
// for requests with a certain prefix
None => panic!(
"Key {} does not start with bucket prefix {:?}",
key, self.prefix_in_bucket
),
};
RemotePath(
relative_path
.split(REMOTE_STORAGE_PREFIX_SEPARATOR)
.collect(),
)
}
pub fn bucket_name(&self) -> &str {
&self.bucket_name
}
fn max_keys_per_delete(&self) -> usize {
MAX_KEYS_PER_DELETE_GCS
}
async fn permit(
&self,
kind: RequestKind,
cancel: &CancellationToken,
) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
let started_at = start_counting_cancelled_wait(kind);
let acquire = self.concurrency_limiter.acquire(kind);
let permit = tokio::select! {
permit = acquire => permit.expect("semaphore is never closed"),
_ = cancel.cancelled() => return Err(Cancelled),
};
let started_at = ScopeGuard::into_inner(started_at);
crate::metrics::BUCKET_METRICS
.wait_seconds
.observe_elapsed(kind, started_at);
Ok(permit)
}
async fn owned_permit(
&self,
kind: RequestKind,
cancel: &CancellationToken,
) -> Result<tokio::sync::OwnedSemaphorePermit, Cancelled> {
let started_at = start_counting_cancelled_wait(kind);
let acquire = self.concurrency_limiter.acquire_owned(kind);
let permit = tokio::select! {
permit = acquire => permit.expect("semaphore is never closed"),
_ = cancel.cancelled() => return Err(Cancelled),
};
let started_at = ScopeGuard::into_inner(started_at);
crate::metrics::BUCKET_METRICS
.wait_seconds
.observe_elapsed(kind, started_at);
Ok(permit)
}
async fn put_object(
&self,
byte_stream: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
fs_size: usize,
to: &RemotePath,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
// https://cloud.google.com/storage/docs/xml-api/reference-headers#chunked
let mut headers = header::HeaderMap::new();
headers.insert(
header::TRANSFER_ENCODING,
header::HeaderValue::from_static("chunked"),
);
// TODO Check if we need type 'multipart/related' file to attach metadata like Neon's S3
// `.upload()` does.
// https://cloud.google.com/storage/docs/uploading-objects#uploading-an-object
let upload_uri = format!(
"https://storage.googleapis.com/upload/storage/v1/b/{}/o/?uploadType=media&name={}",
self.bucket_name.clone(),
self.relative_path_to_gcs_object(to).trim_start_matches("/")
);
let upload = Client::new()
.post(upload_uri)
.body(reqwest::Body::wrap_stream(byte_stream))
.headers(headers)
.bearer_auth(self.token_provider.token(GCS_SCOPES).await?.as_str())
.send();
// We await it in a race against the Tokio timeout
let upload = tokio::time::timeout(self.timeout, upload);
let res = tokio::select! {
res = upload => res,
_ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
};
match res {
Ok(Ok(res)) => {
if !res.status().is_success() {
match res.status() {
StatusCode::NOT_FOUND => {
return Err(anyhow::anyhow!("GCS error: not found \n\t {:?}", res));
}
_ => {
return Err(anyhow::anyhow!(
"GCS PUT response contained no response body \n\t {:?}",
res
));
}
}
} else {
Ok(())
}
}
Ok(Err(reqw)) => Err(reqw.into()),
Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
}
}
async fn copy(
&self,
from: String,
to: String,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
let kind = RequestKind::Copy;
let _permit = self.permit(kind, cancel).await?;
let timeout = tokio::time::sleep(self.timeout);
let started_at = start_measuring_requests(kind);
let copy_uri = format!(
"https://storage.googleapis.com/storage/v1/b/{}/o/{}/copyTo/b/{}/o/{}",
self.bucket_name.clone(),
&from,
self.bucket_name.clone(),
&to
);
let op = Client::new()
.post(copy_uri)
.bearer_auth(self.token_provider.token(GCS_SCOPES).await?.as_str())
.send();
let res = tokio::select! {
res = op => res,
_ = timeout => return Err(TimeoutOrCancel::Timeout.into()),
_ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
};
let started_at = ScopeGuard::into_inner(started_at);
crate::metrics::BUCKET_METRICS
.req_seconds
.observe_elapsed(kind, &res, started_at);
res?;
Ok(())
}
async fn delete_oids(
&self,
delete_objects: &[String],
cancel: &CancellationToken,
_permit: &tokio::sync::SemaphorePermit<'_>,
) -> anyhow::Result<()> {
let kind = RequestKind::Delete;
let mut cancel = std::pin::pin!(cancel.cancelled());
for chunk in delete_objects.chunks(MAX_KEYS_PER_DELETE_GCS) {
let started_at = start_measuring_requests(kind);
// Use this to report keys that didn't delete based on 'content_id'
let mut delete_objects_status = HashMap::new();
let mut form = reqwest::multipart::Form::new();
let bulk_uri = "https://storage.googleapis.com/batch/storage/v1";
for (index, path) in delete_objects.iter().enumerate() {
delete_objects_status.insert(index + 1, path.clone());
let path_to_delete: String =
url::form_urlencoded::byte_serialize(path.trim_start_matches("/").as_bytes())
.collect();
let delete_req = format!(
"
DELETE /storage/v1/b/{}/o/{} HTTP/1.1\r\n\
Content-Type: application/json\r\n\
accept: application/json\r\n\
content-length: 0\r\n
",
self.bucket_name.clone(),
path_to_delete
)
.trim()
.to_string();
let content_id = format!("<{}+{}>", Uuid::new_v4(), index + 1);
let mut part_headers = header::HeaderMap::new();
part_headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/http"),
);
part_headers.insert(
header::TRANSFER_ENCODING,
header::HeaderValue::from_static("binary"),
);
part_headers.insert(
header::HeaderName::from_static("content-id"),
header::HeaderValue::from_str(&content_id)?,
);
let part = reqwest::multipart::Part::text(delete_req).headers(part_headers);
form = form.part(format!("request-{}", index), part);
}
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_str(&format!(
"multipart/mixed; boundary={}",
form.boundary()
))?,
);
let req = Client::new()
.post(bulk_uri)
.bearer_auth(self.token_provider.token(GCS_SCOPES).await?.as_str())
.multipart(form)
.headers(headers)
.send();
let resp = tokio::select! {
resp = req => resp,
_ = tokio::time::sleep(self.timeout) => return Err(TimeoutOrCancel::Timeout.into()),
_ = &mut cancel => return Err(TimeoutOrCancel::Cancel.into()),
};
let started_at = ScopeGuard::into_inner(started_at);
crate::metrics::BUCKET_METRICS
.req_seconds
.observe_elapsed(kind, &resp, started_at);
let resp = resp.context("request deletion")?;
crate::metrics::BUCKET_METRICS
.deleted_objects_total
.inc_by(chunk.len() as u64);
let res_headers = resp.headers().to_owned();
let boundary = res_headers
.get(header::CONTENT_TYPE)
.unwrap()
.to_str()?
.split("=")
.last()
.unwrap();
let res_body = resp.text().await?;
let parsed: HashMap<String, String> = res_body
.split(&format!("--{}", boundary))
.filter_map(|c| {
let mut lines = c.lines();
let id = lines.find_map(|line| {
line.strip_prefix("Content-ID:")
.and_then(|suf| suf.split('+').last())
.and_then(|suf| suf.split('>').next())
.map(|x| x.trim().to_string())
});
let status_code = lines.find_map(|line| {
// Not sure if this protocol version shouldn't be so specific
line.strip_prefix("HTTP/1.1")
.and_then(|x| x.split_whitespace().next())
.map(|x| x.trim().to_string())
});
id.zip(status_code)
})
.collect();
// Gather failures
let errors: HashMap<usize, &String> = parsed
.iter()
.filter_map(|(x, y)| {
if y.chars().next() != Some('2') {
x.parse::<usize>().ok().map(|v| (v, y))
} else {
None
}
})
.collect();
if !errors.is_empty() {
// Report 10 of them like S3
const LOG_UP_TO_N_ERRORS: usize = 10;
for (id, code) in errors.iter().take(LOG_UP_TO_N_ERRORS) {
tracing::warn!(
"DeleteObjects key {} failed with code: {}",
delete_objects_status.get(id).unwrap(),
code
);
}
return Err(anyhow::anyhow!(
"Failed to delete {}/{} objects",
errors.len(),
chunk.len(),
));
}
}
Ok(())
}
async fn list_objects_v2(&self, list_uri: String) -> anyhow::Result<reqwest::RequestBuilder> {
let res = Client::new()
.get(list_uri)
.bearer_auth(self.token_provider.token(GCS_SCOPES).await?.as_str());
Ok(res)
}
// need a 'bucket', a 'key', and a bytes 'range'.
async fn get_object(
&self,
request: GetObjectRequest,
cancel: &CancellationToken,
) -> anyhow::Result<Download, DownloadError> {
let kind = RequestKind::Get;
let permit = self.owned_permit(kind, cancel).await?;
let started_at = start_measuring_requests(kind);
let encoded_path: String =
url::form_urlencoded::byte_serialize(request.key.as_bytes()).collect();
/// We do this in two parts:
/// 1. Serialize the metadata of the first request to get Etag, last modified, etc
/// 2. We do not .await the second request pass on the pinned stream to the 'get_object'
/// caller
// 1. Serialize Metadata in initial request
let metadata_uri_mod = "alt=json";
let download_uri = format!(
"https://storage.googleapis.com/storage/v1/b/{}/o/{}?{}",
self.bucket_name.clone(),
encoded_path,
metadata_uri_mod
);
let res = Client::new()
.get(download_uri)
.bearer_auth(
self.token_provider
.token(GCS_SCOPES)
.await
.map_err(|e: gcp_auth::Error| DownloadError::Other(e.into()))?
.as_str(),
)
.send()
.await
.map_err(|e: reqwest::Error| DownloadError::Other(e.into()))?;
if !res.status().is_success() {
match res.status() {
StatusCode::NOT_FOUND => return Err(DownloadError::NotFound),
_ => {
return Err(DownloadError::Other(anyhow::anyhow!(
"GCS GET resposne contained no response body"
)));
}
}
};
let body = res
.text()
.await
.map_err(|e: reqwest::Error| DownloadError::Other(e.into()))?;
let resp: GCSObject = serde_json::from_str(&body)
.map_err(|e: serde_json::Error| DownloadError::Other(e.into()))?;
// 2. Byte Stream request
let mut headers = header::HeaderMap::new();
headers.insert(header::RANGE, header::HeaderValue::from_static("bytes=0-"));
let encoded_path: String =
url::form_urlencoded::byte_serialize(request.key.as_bytes()).collect();
let stream_uri_mod = "alt=media";
let stream_uri = format!(
"https://storage.googleapis.com/storage/v1/b/{}/o/{}?{}",
self.bucket_name.clone(),
encoded_path,
stream_uri_mod
);
let mut req = Client::new()
.get(stream_uri)
.headers(headers)
.bearer_auth(
self.token_provider
.token(GCS_SCOPES)
.await
.map_err(|e: gcp_auth::Error| DownloadError::Other(e.into()))?
.as_str(),
)
.send();
let get_object = tokio::select! {
res = req => res,
_ = tokio::time::sleep(self.timeout) => return Err(DownloadError::Timeout),
_ = cancel.cancelled() => return Err(DownloadError::Cancelled),
};
let started_at = ScopeGuard::into_inner(started_at);
let object_output = match get_object {
Ok(object_output) => {
if !object_output.status().is_success() {
match object_output.status() {
StatusCode::NOT_FOUND => return Err(DownloadError::NotFound),
_ => {
return Err(DownloadError::Other(anyhow::anyhow!(
"GCS GET resposne contained no response body"
)));
}
}
} else {
object_output
}
}
Err(e) => {
crate::metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
kind,
AttemptOutcome::Err,
started_at,
);
return Err(DownloadError::Other(
anyhow::Error::new(e).context("download s3 object"),
));
}
};
let remaining = self.timeout.saturating_sub(started_at.elapsed());
let metadata = resp.metadata.map(StorageMetadata);
let etag = resp
.etag
.ok_or(DownloadError::Other(anyhow::anyhow!("Missing ETag header")))?
.into();
let last_modified: SystemTime = resp
.updated
.and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
.map(|s| s.into())
.unwrap_or(SystemTime::now());
// But let data stream pass through
Ok(Download {
download_stream: Box::pin(object_output.bytes_stream().map(|item| {
item.map_err(|e: reqwest::Error| std::io::Error::new(std::io::ErrorKind::Other, e))
})),
etag,
last_modified,
metadata,
})
}
}
impl RemoteStorage for GCSBucket {
// ---------------------------------------
// Neon wrappers for GCS client functions
// ---------------------------------------
fn list_streaming(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> impl Stream<Item = Result<Listing, DownloadError>> {
let kind = RequestKind::List;
let mut max_keys = max_keys.map(|mk| mk.get() as i32);
let list_prefix = prefix
.map(|p| self.relative_path_to_gcs_object(p))
.or_else(|| {
self.prefix_in_bucket.clone().map(|mut s| {
s.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
s
})
})
.unwrap();
let request_max_keys = self
.max_keys_per_list_response
.into_iter()
.chain(max_keys.into_iter())
.min()
// https://cloud.google.com/storage/docs/json_api/v1/objects/list?hl=en#parameters
// TODO set this to default
.unwrap_or(1000);
// We pass URI in to `list_objects_v2` as we'll modify it with `NextPageToken`, hence
// `mut`
let mut list_uri = format!(
"https://storage.googleapis.com/storage/v1/b/{}/o?prefix={}&maxResults={}",
self.bucket_name.clone(),
list_prefix,
request_max_keys,
);
// on ListingMode:
// https://github.com/neondatabase/neon/blob/edc11253b65e12a10843711bd88ad277511396d7/libs/remote_storage/src/lib.rs#L158C1-L164C2
if let ListingMode::WithDelimiter = mode {
list_uri.push_str(&format!(
"&delimiter={}",
REMOTE_STORAGE_PREFIX_SEPARATOR.to_string()
));
}
async_stream::stream! {
let mut continuation_token = None;
'outer: loop {
let started_at = start_measuring_requests(kind);
let request = self.list_objects_v2(list_uri.clone())
.await
.map_err(DownloadError::Other)?
.send();
// this is like `await`
let response = tokio::select! {
res = request => Ok(res),
_ = tokio::time::sleep(self.timeout) => Err(DownloadError::Timeout),
_ = cancel.cancelled() => Err(DownloadError::Cancelled),
}?;
// just mapping our `Result' error variant's type.
let response = response
.context("Failed to list GCS prefixes")
.map_err(DownloadError::Other);
let started_at = ScopeGuard::into_inner(started_at);
crate::metrics::BUCKET_METRICS
.req_seconds
.observe_elapsed(kind, &response, started_at);
let response = match response {
Ok(response) => response,
Err(e) => {
// The error is potentially retryable, so we must rewind the loop after yielding.
yield Err(e);
continue 'outer;
},
};
let body = response.text()
.await
.map_err(|e: reqwest::Error| DownloadError::Other(e.into()))?;
let resp: GCSListResponse = serde_json::from_str(&body).map_err(|e: serde_json::Error| DownloadError::Other(e.into()))?;
let prefixes = resp.common_prefixes();
let keys = resp.contents();
tracing::debug!("list: {} prefixes, {} keys", prefixes.len(), keys.len());
let mut result = Listing::default();
for res in keys.iter() {
let last_modified: SystemTime = res.updated.clone()
.and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
.map(|s| s.into())
.unwrap_or(SystemTime::now());
let size = res.size.clone().unwrap_or("0".to_string()).parse::<u64>().unwrap();
let key = res.name.clone();
result.keys.push(
ListingObject{
key: self.gcs_object_to_relative_path(&key),
last_modified,
size,
}
);
if let Some(mut mk) = max_keys {
assert!(mk > 0);
mk -= 1;
if mk == 0 {
tracing::debug!("reached limit set by max_keys");
yield Ok(result);
break 'outer;
}
max_keys = Some(mk);
};
}
result.prefixes.extend(prefixes.iter().filter_map(|p| {
Some(
self.gcs_object_to_relative_path(
p.trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR)
),
)
}));
yield Ok(result);
continuation_token = match resp.next_page_token {
Some(token) => {
list_uri = list_uri + "&pageToken=" + &token;
Some(token)
},
None => break
}
}
}
}
async fn head_object(
&self,
key: &RemotePath,
cancel: &CancellationToken,
) -> Result<ListingObject, DownloadError> {
let kind = RequestKind::Head;
todo!();
}
async fn upload(
&self,
from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
from_size_bytes: usize,
to: &RemotePath,
metadata: Option<StorageMetadata>,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
let kind = RequestKind::Put;
let _permit = self.permit(kind, cancel).await?;
let started_at = start_measuring_requests(kind);
let upload = self.put_object(from, from_size_bytes, to, cancel);
let upload = tokio::time::timeout(self.timeout, upload);
let res = tokio::select! {
res = upload => res,
_ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
};
if let Ok(inner) = &res {
// do not incl. timeouts as errors in metrics but cancellations
let started_at = ScopeGuard::into_inner(started_at);
crate::metrics::BUCKET_METRICS
.req_seconds
.observe_elapsed(kind, inner, started_at);
}
match res {
Ok(Ok(_put)) => Ok(()),
Ok(Err(sdk)) => Err(sdk.into()),
Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()),
}
}
async fn copy(
&self,
from: &RemotePath,
to: &RemotePath,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
let kind = RequestKind::Copy;
let _permit = self.permit(kind, cancel).await?;
let timeout = tokio::time::sleep(self.timeout);
let started_at = start_measuring_requests(kind);
// we need to specify bucket_name as a prefix
let copy_source = format!(
"{}/{}",
self.bucket_name,
self.relative_path_to_gcs_object(from)
);
todo!();
}
async fn download(
&self,
from: &RemotePath,
opts: &DownloadOpts,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
// if prefix is not none then download file `prefix/from`
// if prefix is none then download file `from`
self.get_object(
GetObjectRequest {
bucket: self.bucket_name.clone(),
key: self
.relative_path_to_gcs_object(from)
.trim_start_matches("/")
.to_string(),
etag: opts.etag.as_ref().map(|e| e.to_string()),
range: opts.byte_range_header(),
},
cancel,
)
.await
}
async fn delete_objects(
&self,
paths: &[RemotePath],
cancel: &CancellationToken,
) -> anyhow::Result<()> {
let kind = RequestKind::Delete;
let permit = self.permit(kind, cancel).await?;
let mut delete_objects: Vec<String> = Vec::with_capacity(paths.len());
let delete_objects: Vec<String> = paths
.iter()
.map(|i| self.relative_path_to_gcs_object(i))
.collect();
self.delete_oids(&delete_objects, cancel, &permit).await
}
fn max_keys_per_delete(&self) -> usize {
MAX_KEYS_PER_DELETE_GCS
}
async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
let paths = std::array::from_ref(path);
self.delete_objects(paths, cancel).await
}
async fn time_travel_recover(
&self,
prefix: Option<&RemotePath>,
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
) -> Result<(), TimeTravelError> {
Ok(())
}
}
// ---------
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
pub struct GCSListResponse {
#[serde(rename = "nextPageToken")]
pub next_page_token: Option<String>,
pub items: Option<Vec<GCSObject>>,
pub prefixes: Option<Vec<String>>,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
pub struct GCSObject {
pub name: String,
pub bucket: String,
pub generation: String,
pub metageneration: String,
#[serde(rename = "contentType")]
pub content_type: Option<String>,
#[serde(rename = "storageClass")]
pub storage_class: String,
pub size: Option<String>,
#[serde(rename = "md5Hash")]
pub md5_hash: Option<String>,
pub crc32c: String,
pub etag: Option<String>,
#[serde(rename = "timeCreated")]
pub time_created: String,
pub updated: Option<String>,
#[serde(rename = "timeStorageClassUpdated")]
pub time_storage_class_updated: String,
#[serde(rename = "timeFinalized")]
pub time_finalized: String,
pub metadata: Option<HashMap<String, String>>,
}
impl GCSListResponse {
pub fn contents(&self) -> &[GCSObject] {
self.items.as_deref().unwrap_or_default()
}
pub fn common_prefixes(&self) -> &[String] {
self.prefixes.as_deref().unwrap_or_default()
}
}
#[cfg(test)]
mod tests {
use super::*;
use gcp_auth;
use std::num::NonZero;
use std::pin::pin;
use std::sync::Arc;
const BUFFER_SIZE: usize = 32 * 1024;
// TODO what does Neon want here for integration tests?
const BUCKET: &str = "https://storage.googleapis.com/storage/v1/b/my-test-bucket";
#[tokio::test]
async fn list_returns_keys_from_bucket() {
let provider = gcp_auth::provider().await.unwrap();
let gcs = GCSBucket {
token_provider: Arc::clone(&provider),
bucket_name: BUCKET.to_string(),
prefix_in_bucket: None,
max_keys_per_list_response: Some(100),
concurrency_limiter: ConcurrencyLimiter::new(100),
timeout: std::time::Duration::from_secs(120),
};
let cancel = CancellationToken::new();
let remote_prefix = "box/tiff/2023/TN".to_string();
let max_keys: u32 = 100;
let mut stream = pin!(gcs.list_streaming(Some(remote_prefix), NonZero::new(max_keys)));
let mut combined = stream
.next()
.await
.expect("At least one item required")
.unwrap();
while let Some(list) = stream.next().await {
let list = list.unwrap();
combined.keys.extend(list.keys.into_iter());
combined.prefixes.extend_from_slice(&list.prefixes);
}
for key in combined.keys.iter() {
println!("Item: {} -- {:?}", key.key, key.last_modified);
}
assert_ne!(0, combined.keys.len());
}
}

View File

@@ -12,6 +12,7 @@
mod azure_blob;
mod config;
mod error;
mod gcs_bucket;
mod local_fs;
mod metrics;
mod s3_bucket;
@@ -42,6 +43,7 @@ use tokio_util::sync::CancellationToken;
use tracing::info;
pub use self::azure_blob::AzureBlobStorage;
pub use self::gcs_bucket::GCSBucket;
pub use self::local_fs::LocalFs;
pub use self::s3_bucket::S3Bucket;
pub use self::simulate_failures::UnreliableWrapper;
@@ -80,8 +82,12 @@ pub const MAX_KEYS_PER_DELETE_S3: usize = 1000;
/// <https://learn.microsoft.com/en-us/rest/api/storageservices/blob-batch>
pub const MAX_KEYS_PER_DELETE_AZURE: usize = 256;
pub const MAX_KEYS_PER_DELETE_GCS: usize = 1000;
const REMOTE_STORAGE_PREFIX_SEPARATOR: char = '/';
const GCS_SCOPES: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
/// Path on the remote storage, relative to some inner prefix.
/// The prefix is an implementation detail, that allows representing local paths
/// as the remote ones, stripping the local storage prefix away.
@@ -439,6 +445,7 @@ pub enum GenericRemoteStorage<Other: Clone = Arc<UnreliableWrapper>> {
AwsS3(Arc<S3Bucket>),
AzureBlob(Arc<AzureBlobStorage>),
Unreliable(Other),
GCS(Arc<GCSBucket>),
}
impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
@@ -455,6 +462,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.list(prefix, mode, max_keys, cancel).await,
Self::AzureBlob(s) => s.list(prefix, mode, max_keys, cancel).await,
Self::Unreliable(s) => s.list(prefix, mode, max_keys, cancel).await,
Self::GCS(s) => s.list(prefix, mode, max_keys, cancel).await,
}
}
@@ -472,6 +480,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => Box::pin(s.list_streaming(prefix, mode, max_keys, cancel)),
Self::AzureBlob(s) => Box::pin(s.list_streaming(prefix, mode, max_keys, cancel)),
Self::Unreliable(s) => Box::pin(s.list_streaming(prefix, mode, max_keys, cancel)),
Self::GCS(s) => Box::pin(s.list_streaming(prefix, mode, max_keys, cancel)),
}
}
@@ -486,6 +495,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.head_object(key, cancel).await,
Self::AzureBlob(s) => s.head_object(key, cancel).await,
Self::Unreliable(s) => s.head_object(key, cancel).await,
Self::GCS(_) => todo!(),
}
}
@@ -503,6 +513,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await,
Self::AzureBlob(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await,
Self::Unreliable(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await,
Self::GCS(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await,
}
}
@@ -518,6 +529,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.download(from, opts, cancel).await,
Self::AzureBlob(s) => s.download(from, opts, cancel).await,
Self::Unreliable(s) => s.download(from, opts, cancel).await,
Self::GCS(s) => s.download(from, opts, cancel).await,
}
}
@@ -532,6 +544,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.delete(path, cancel).await,
Self::AzureBlob(s) => s.delete(path, cancel).await,
Self::Unreliable(s) => s.delete(path, cancel).await,
Self::GCS(s) => s.delete(path, cancel).await,
}
}
@@ -546,6 +559,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.delete_objects(paths, cancel).await,
Self::AzureBlob(s) => s.delete_objects(paths, cancel).await,
Self::Unreliable(s) => s.delete_objects(paths, cancel).await,
Self::GCS(s) => s.delete_objects(paths, cancel).await,
}
}
@@ -556,6 +570,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.max_keys_per_delete(),
Self::AzureBlob(s) => s.max_keys_per_delete(),
Self::Unreliable(s) => s.max_keys_per_delete(),
Self::GCS(s) => s.max_keys_per_delete(),
}
}
@@ -570,6 +585,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.delete_prefix(prefix, cancel).await,
Self::AzureBlob(s) => s.delete_prefix(prefix, cancel).await,
Self::Unreliable(s) => s.delete_prefix(prefix, cancel).await,
Self::GCS(_) => todo!(),
}
}
@@ -585,6 +601,7 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
Self::AwsS3(s) => s.copy(from, to, cancel).await,
Self::AzureBlob(s) => s.copy(from, to, cancel).await,
Self::Unreliable(s) => s.copy(from, to, cancel).await,
Self::GCS(_) => todo!(),
}
}
@@ -613,17 +630,25 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
Self::GCS(_) => todo!(),
}
}
}
impl GenericRemoteStorage {
pub async fn from_config(storage_config: &RemoteStorageConfig) -> anyhow::Result<Self> {
info!("RemoteStorageConfig: {:?}", storage_config);
let timeout = storage_config.timeout;
// If somkeone overrides timeout to be small without adjusting small_timeout, then adjust it automatically
// If someone overrides timeout to be small without adjusting small_timeout, then adjust it automatically
let small_timeout = std::cmp::min(storage_config.small_timeout, timeout);
info!(
"RemoteStorageConfig's storage attribute: {:?}",
storage_config.storage
);
Ok(match &storage_config.storage {
RemoteStorageKind::LocalFs { local_path: path } => {
info!("Using fs root '{path}' as a remote storage");
@@ -661,6 +686,16 @@ impl GenericRemoteStorage {
small_timeout,
)?))
}
RemoteStorageKind::GCS(gcs_config) => {
let google_application_credentials =
std::env::var("GOOGLE_APPLICATION_CREDENTIALS")
.unwrap_or_else(|_| "<none>".into());
info!(
"Using gcs bucket '{}' as a remote storage, prefix in bucket: '{:?}', GOOGLE_APPLICATION_CREDENTIALS: {google_application_credentials }",
gcs_config.bucket_name, gcs_config.prefix_in_bucket
);
Self::GCS(Arc::new(GCSBucket::new(gcs_config, timeout).await?))
}
})
}
@@ -690,6 +725,7 @@ impl GenericRemoteStorage {
Self::AwsS3(s) => Some(s.bucket_name()),
Self::AzureBlob(s) => Some(s.container_name()),
Self::Unreliable(_s) => None,
Self::GCS(s) => Some(s.bucket_name()),
}
}
}

View File

@@ -50,6 +50,7 @@ impl UnreliableWrapper {
GenericRemoteStorage::Unreliable(_s) => {
panic!("Can't wrap unreliable wrapper unreliably")
}
GenericRemoteStorage::GCS(_) => todo!(),
};
UnreliableWrapper {
inner,

View File

@@ -1285,10 +1285,6 @@ impl Timeline {
reconstruct_state: &mut ValuesReconstructState,
ctx: &RequestContext,
) -> Result<BTreeMap<Key, Result<Bytes, PageReconstructError>>, GetVectoredError> {
if query.is_empty() {
return Ok(BTreeMap::default());
}
let read_path = if self.conf.enable_read_path_debugging || ctx.read_path_debug() {
Some(ReadPath::new(
query.total_keyspace(),

View File

@@ -16,9 +16,9 @@ use tracing::field::display;
use tracing::{debug, info};
use super::AsyncRW;
use super::conn_pool::poll_client_generic;
use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
@@ -42,9 +42,10 @@ use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<GlobalConnPool<HttpConnPool>>,
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
pub(crate) pool: Arc<GlobalConnPool<EndpointConnPool<postgres_client::Client>>>,
pub(crate) pool:
Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
@@ -211,7 +212,7 @@ impl PoolingBackend {
None
} else {
debug!("pool: looking for an existing connection");
self.pool.get(ctx, &conn_info)
self.pool.get(ctx, &conn_info)?
};
if let Some(client) = maybe_client {
@@ -245,9 +246,9 @@ impl PoolingBackend {
&self,
ctx: &RequestContext,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
debug!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
}
@@ -531,7 +532,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<EndpointConnPool<postgres_client::Client>>>,
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
@@ -577,7 +578,7 @@ impl ConnectMechanism for TokioMechanism {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_client_generic(
Ok(poll_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
@@ -592,7 +593,7 @@ impl ConnectMechanism for TokioMechanism {
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<HttpConnPool>>,
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
@@ -602,7 +603,7 @@ struct HyperMechanism {
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client;
type Connection = http_conn_pool::Client<Send>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
@@ -638,26 +639,15 @@ impl ConnectMechanism for HyperMechanism {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
let client = poll_client_generic(
Ok(poll_http2_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
);
// auth-broker -> local-proxy clients don't return to the pool, since
// they are multiplexing and cloneable. So instead we insert it once here.
if let Some(endpoint) = self.conn_info.endpoint_cache_key() {
self.pool
.get_or_create_endpoint_pool(&endpoint)
.write()
.register(&client);
}
Ok(client)
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}

View File

@@ -11,7 +11,7 @@ use smallvec::SmallVec;
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn};
use tracing::{Instrument, error, info, info_span, warn};
#[cfg(test)]
use {
super::conn_pool_lib::GlobalConnPoolOptions,
@@ -20,7 +20,8 @@ use {
};
use super::conn_pool_lib::{
ClientDataEnum, ClientInnerCommon, ConnInfo, EndpointConnPoolExt, GlobalConnPool,
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
GlobalConnPool,
};
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
@@ -28,7 +29,6 @@ use crate::metrics::Metrics;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type TlsStream = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::Stream;
pub(super) type Conn = postgres_client::Connection<TcpStream, TlsStream>;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
@@ -56,20 +56,20 @@ impl fmt::Display for ConnInfo {
}
}
pub(crate) fn poll_client_generic<P: EndpointConnPoolExt>(
global_pool: Arc<GlobalConnPool<P>>,
pub(crate) fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C, EndpointConnPool<C>>>,
ctx: &RequestContext,
conn_info: ConnInfo,
client: P::ClientInner,
connection: P::Connection,
client: C,
mut connection: postgres_client::Connection<TcpStream, TlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> P::Client {
) -> Client<C> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let mut session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id, %session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
@@ -85,30 +85,27 @@ pub(crate) fn poll_client_generic<P: EndpointConnPoolExt>(
let cancel = CancellationToken::new();
let cancelled = cancel.clone().cancelled_owned();
tokio::spawn(async move {
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let mut idle_timeout = pin!(tokio::time::sleep(idle));
let mut cancelled = pin!(cancelled);
let mut connection = pin!(P::spawn_conn(connection));
poll_fn(move |cx| {
let _enter = span.enter();
if cancelled.as_mut().poll(cx).is_ready() {
info!("connection dropped");
return Poll::Ready(());
return Poll::Ready(())
}
match rx.has_changed() {
Ok(true) => {
let session_id = *rx.borrow_and_update();
span.record("session_id", tracing::field::display(session_id));
info!("changed session");
session_id = *rx.borrow_and_update();
info!(%session_id, "changed session");
idle_timeout.as_mut().reset(Instant::now() + idle);
}
Err(_) => {
info!("connection dropped");
return Poll::Ready(());
return Poll::Ready(())
}
_ => {}
}
@@ -120,25 +117,48 @@ pub(crate) fn poll_client_generic<P: EndpointConnPoolExt>(
if let Some(pool) = pool.clone().upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool.write().remove_conn(db_user.clone(), conn_id) {
if pool.write().remove_client(db_user.clone(), conn_id) {
info!("idle connection removed");
}
}
}
ready!(connection.as_mut().poll(cx));
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
}
Some(Err(e)) => {
error!(%session_id, "connection error: {}", e);
break
}
None => {
info!("connection closed");
break
}
}
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(db_user.clone(), conn_id) {
if pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
})
.await;
});
}).await;
}
.instrument(span));
let inner = ClientInnerCommon {
inner: client,
aux,
@@ -149,42 +169,7 @@ pub(crate) fn poll_client_generic<P: EndpointConnPoolExt>(
}),
};
P::wrap_client(inner, conn_info, pool_clone)
}
pub async fn poll_tokio_postgres_conn_really(mut connection: Conn) {
poll_fn(move |cx| {
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!("notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(
pid = notif.process_id(),
channel = notif.channel(),
"notification received"
);
}
Some(Ok(_)) => {
warn!("unknown message");
}
Some(Err(e)) => {
error!("connection error: {}", e);
break;
}
None => {
info!("connection closed");
break;
}
}
}
Poll::Ready(())
})
.await;
Client::new(inner, conn_info, pool_clone)
}
#[derive(Clone)]
@@ -194,11 +179,11 @@ pub(crate) struct ClientDataRemote {
}
impl ClientDataRemote {
pub fn session(&self) -> &tokio::sync::watch::Sender<uuid::Uuid> {
&self.session
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
&mut self.session
}
pub fn cancel(&self) {
pub fn cancel(&mut self) {
self.cancel.cancel();
}
}
@@ -210,7 +195,6 @@ mod tests {
use super::*;
use crate::proxy::NeonOptions;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::conn_pool_lib::{Client, ClientInnerExt};
use crate::types::{BranchId, EndpointId, ProjectId};
struct MockClient(Arc<AtomicBool>);

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
@@ -11,10 +12,11 @@ use rand::Rng;
use smol_str::ToSmolStr;
use tracing::{Span, debug, info};
use super::conn_pool::{ClientDataRemote, poll_tokio_postgres_conn_really};
use super::backend::HttpConnError;
use super::conn_pool::ClientDataRemote;
use super::http_conn_pool::ClientDataHttp;
use super::local_conn_pool::ClientDataLocal;
use crate::auth::backend::ComputeUserInfo;
use crate::config::HttpConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
@@ -49,6 +51,7 @@ impl ConnInfo {
pub(crate) enum ClientDataEnum {
Remote(ClientDataRemote),
Local(ClientDataLocal),
Http(ClientDataHttp),
}
#[derive(Clone)]
@@ -61,9 +64,14 @@ pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
fn drop(&mut self) {
match &self.data {
ClientDataEnum::Remote(remote_data) => remote_data.cancel(),
ClientDataEnum::Local(local_data) => local_data.cancel(),
match &mut self.data {
ClientDataEnum::Remote(remote_data) => {
remote_data.cancel();
}
ClientDataEnum::Local(local_data) => {
local_data.cancel();
}
ClientDataEnum::Http(_http_data) => (),
}
}
}
@@ -73,8 +81,8 @@ impl<C: ClientInnerExt> ClientInnerCommon<C> {
self.conn_id
}
pub(crate) fn get_data(&self) -> &ClientDataEnum {
&self.data
pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
&mut self.data
}
}
@@ -318,70 +326,12 @@ impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
}
}
pub(crate) trait EndpointConnPoolExt: Send + Sync + 'static {
type Client;
type ClientInner: ClientInnerExt;
type Connection: Send + 'static;
fn create(config: &HttpConfig, global_connections_count: Arc<AtomicUsize>) -> Self;
fn wrap_client(
inner: ClientInnerCommon<Self::ClientInner>,
conn_info: ConnInfo,
pool: Weak<RwLock<Self>>,
) -> Self::Client;
fn get_conn_entry(
&mut self,
db_user: (DbName, RoleName),
) -> Option<ClientInnerCommon<Self::ClientInner>>;
fn remove_conn(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool;
fn spawn_conn(conn: Self::Connection) -> impl Future<Output = ()> + Send + 'static;
pub(crate) trait EndpointConnPoolExt<C: ClientInnerExt> {
fn clear_closed(&mut self) -> usize;
fn total_conns(&self) -> usize;
}
impl<C: ClientInnerExt> EndpointConnPoolExt for EndpointConnPool<C> {
type Client = Client<C>;
type ClientInner = C;
type Connection = super::conn_pool::Conn;
fn create(config: &HttpConfig, global_connections_count: Arc<AtomicUsize>) -> Self {
EndpointConnPool {
pools: HashMap::new(),
total_conns: 0,
max_conns: config.pool_options.max_conns_per_endpoint,
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count,
global_pool_size_max_conns: config.pool_options.max_total_conns,
pool_name: String::from("remote"),
}
}
fn wrap_client(
client: ClientInnerCommon<Self::ClientInner>,
conn_info: ConnInfo,
pool: Weak<RwLock<Self>>,
) -> Self::Client {
Client::new(client, conn_info.clone(), pool)
}
fn get_conn_entry(
&mut self,
db_user: (DbName, RoleName),
) -> Option<ClientInnerCommon<Self::ClientInner>> {
Some(self.get_conn_entry(db_user)?.conn)
}
fn remove_conn(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
self.remove_client(db_user, conn_id)
}
async fn spawn_conn(conn: Self::Connection) {
poll_tokio_postgres_conn_really(conn).await;
}
impl<C: ClientInnerExt> EndpointConnPoolExt<C> for EndpointConnPool<C> {
fn clear_closed(&mut self) -> usize {
let mut clients_removed: usize = 0;
for db_pool in self.pools.values_mut() {
@@ -395,9 +345,10 @@ impl<C: ClientInnerExt> EndpointConnPoolExt for EndpointConnPool<C> {
}
}
pub(crate) struct GlobalConnPool<P>
pub(crate) struct GlobalConnPool<C, P>
where
P: EndpointConnPoolExt,
C: ClientInnerExt,
P: EndpointConnPoolExt<C>,
{
// endpoint -> per-endpoint connection pool
//
@@ -416,6 +367,8 @@ where
pub(crate) global_connections_count: Arc<AtomicUsize>,
pub(crate) config: &'static crate::config::HttpConfig,
_marker: PhantomData<C>,
}
#[derive(Debug, Clone, Copy)]
@@ -438,9 +391,10 @@ pub struct GlobalConnPoolOptions {
pub max_total_conns: usize,
}
impl<P> GlobalConnPool<P>
impl<C, P> GlobalConnPool<C, P>
where
P: EndpointConnPoolExt,
C: ClientInnerExt,
P: EndpointConnPoolExt<C>,
{
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
@@ -449,6 +403,7 @@ where
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
_marker: PhantomData,
})
}
@@ -537,72 +492,80 @@ where
}
}
impl<P: EndpointConnPoolExt> GlobalConnPool<P> {
impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestContext,
conn_info: &ConnInfo,
) -> Option<P::Client> {
let endpoint = conn_info.endpoint_cache_key()?;
) -> Result<Option<Client<C>>, HttpConnError> {
let mut client: Option<ClientInnerCommon<C>> = None;
let Some(endpoint) = conn_info.endpoint_cache_key() else {
return Ok(None);
};
let endpoint_pool = self.get_endpoint_pool(&endpoint)?;
let client = endpoint_pool
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
if let Some(entry) = endpoint_pool
.write()
.get_conn_entry(conn_info.db_and_user())?;
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn);
}
let endpoint_pool = Arc::downgrade(&endpoint_pool);
if client.inner.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return None;
}
tracing::Span::current().record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
debug!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
match client.get_data() {
ClientDataEnum::Local(data) => {
data.session().send(ctx.session_id()).ok()?;
// ok return cached connection if found and establish a new one otherwise
if let Some(mut client) = client {
if client.inner.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
ClientDataEnum::Remote(data) => {
data.session().send(ctx.session_id()).ok()?;
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
debug!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
match client.get_data() {
ClientDataEnum::Local(data) => {
data.session().send(ctx.session_id())?;
}
ClientDataEnum::Remote(data) => {
data.session().send(ctx.session_id())?;
}
ClientDataEnum::Http(_) => (),
}
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
Some(P::wrap_client(client, conn_info.clone(), endpoint_pool))
}
}
impl<P: EndpointConnPoolExt> GlobalConnPool<P> {
pub(crate) fn get_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Option<Arc<RwLock<P>>> {
Some(self.global_pool.get(endpoint)?.clone())
Ok(None)
}
pub(crate) fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<P>> {
) -> Arc<RwLock<EndpointConnPool<C>>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(P::create(
self.config,
self.global_connections_count.clone(),
)));
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
pools: HashMap::new(),
total_conns: 0,
max_conns: self.config.pool_options.max_conns_per_endpoint,
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
pool_name: String::from("remote"),
}));
// find or create a pool for this endpoint
let mut created = false;
@@ -629,7 +592,6 @@ impl<P: EndpointConnPoolExt> GlobalConnPool<P> {
pool
}
}
pub(crate) struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInnerCommon<C>>,

View File

@@ -4,25 +4,32 @@ use std::sync::{Arc, Weak};
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use smol_str::ToSmolStr;
use tracing::{error, info};
use tracing::{Instrument, debug, error, info, info_span};
use super::AsyncRW;
use super::backend::HttpConnError;
use super::conn_pool_lib::{
ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry, EndpointConnPoolExt,
ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry,
EndpointConnPoolExt, GlobalConnPool,
};
use crate::config::HttpConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
#[derive(Clone)]
pub(crate) struct ClientDataHttp();
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct HttpConnPool {
pub(crate) struct HttpConnPool<C: ClientInnerExt + Clone> {
// TODO(conrad):
// either we should open more connections depending on stream count
// (not exposed by hyper, need our own counter)
@@ -32,13 +39,13 @@ pub(crate) struct HttpConnPool {
// seems somewhat redundant though.
//
// Probably we should run a semaphore and just the single conn. TBD.
conns: VecDeque<ConnPoolEntry<Send>>,
conns: VecDeque<ConnPoolEntry<C>>,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
}
impl HttpConnPool {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry<Send>> {
impl<C: ClientInnerExt + Clone> HttpConnPool<C> {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry<C>> {
let Self { conns, .. } = self;
loop {
@@ -76,59 +83,9 @@ impl HttpConnPool {
}
removed > 0
}
pub fn register(&mut self, client: &Client) {
self.conns.push_back(ConnPoolEntry {
conn: client.inner.clone(),
_last_access: std::time::Instant::now(),
});
}
}
impl EndpointConnPoolExt for HttpConnPool {
type Client = Client;
type ClientInner = Send;
type Connection = Connect;
fn create(_config: &HttpConfig, global_connections_count: Arc<AtomicUsize>) -> Self {
HttpConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count,
}
}
fn wrap_client(
inner: ClientInnerCommon<Self::ClientInner>,
_conn_info: ConnInfo,
_pool: Weak<parking_lot::RwLock<Self>>,
) -> Self::Client {
Client::new(inner)
}
fn get_conn_entry(
&mut self,
_db_user: (crate::types::DbName, crate::types::RoleName),
) -> Option<ClientInnerCommon<Self::ClientInner>> {
Some(self.get_conn_entry()?.conn)
}
fn remove_conn(
&mut self,
_db_user: (crate::types::DbName, crate::types::RoleName),
conn_id: uuid::Uuid,
) -> bool {
self.remove_conn(conn_id)
}
async fn spawn_conn(conn: Self::Connection) {
let res = conn.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!("connection error: {e:?}"),
}
}
impl<C: ClientInnerExt + Clone> EndpointConnPoolExt<C> for HttpConnPool<C> {
fn clear_closed(&mut self) -> usize {
let Self { conns, .. } = self;
let old_len = conns.len();
@@ -143,7 +100,7 @@ impl EndpointConnPoolExt for HttpConnPool {
}
}
impl Drop for HttpConnPool {
impl<C: ClientInnerExt + Clone> Drop for HttpConnPool<C> {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
@@ -157,12 +114,154 @@ impl Drop for HttpConnPool {
}
}
pub(crate) struct Client {
pub(crate) inner: ClientInnerCommon<Send>,
impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
#[expect(unused_results)]
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestContext,
conn_info: &ConnInfo,
) -> Result<Option<Client<C>>, HttpConnError> {
let result: Result<Option<Client<C>>, HttpConnError>;
let Some(endpoint) = conn_info.endpoint_cache_key() else {
result = Ok(None);
return result;
};
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let Some(client) = endpoint_pool.write().get_conn_entry() else {
result = Ok(None);
return result;
};
tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
debug!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Ok(Some(Client::new(client.conn.clone())))
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<HttpConnPool<C>>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(HttpConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
impl Client {
pub(self) fn new(inner: ClientInnerCommon<Send>) -> Self {
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
ctx: &RequestContext,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<Send> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
let client = ClientInnerCommon {
inner: client.clone(),
aux: aux.clone(),
conn_id,
data: ClientDataEnum::Http(ClientDataHttp()),
};
pool.write().conns.push_back(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
Arc::downgrade(&pool)
}
None => Weak::new(),
};
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!(%session_id, "connection error: {e:?}"),
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),
);
let client = ClientInnerCommon {
inner: client,
aux,
conn_id,
data: ClientDataEnum::Http(ClientDataHttp()),
};
Client::new(client)
}
pub(crate) struct Client<C: ClientInnerExt + Clone> {
pub(crate) inner: ClientInnerCommon<C>,
}
impl<C: ClientInnerExt + Clone> Client<C> {
pub(self) fn new(inner: ClientInnerCommon<C>) -> Self {
Self { inner }
}

View File

@@ -53,11 +53,11 @@ pub(crate) struct ClientDataLocal {
}
impl ClientDataLocal {
pub fn session(&self) -> &tokio::sync::watch::Sender<uuid::Uuid> {
&self.session
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
&mut self.session
}
pub fn cancel(&self) {
pub fn cancel(&mut self) {
self.cancel.cancel();
}
}
@@ -99,7 +99,7 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
.map(|entry| entry.conn);
// ok return cached connection if found and establish a new one otherwise
if let Some(client) = client {
if let Some(mut client) = client {
if client.inner.is_closed() {
info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
@@ -120,9 +120,11 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
ClientDataEnum::Local(data) => {
data.session().send(ctx.session_id())?;
}
ClientDataEnum::Remote(data) => {
data.session().send(ctx.session_id())?;
}
ClientDataEnum::Http(_) => (),
}
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);

View File

@@ -271,6 +271,7 @@ impl BucketConfig {
"container {}, storage account {:?}, region {}",
config.container_name, config.storage_account, config.container_region
),
RemoteStorageKind::GCS(config) => format!("bucket {}", config.bucket_name),
}
}
pub fn bucket_name(&self) -> Option<&str> {
@@ -418,6 +419,9 @@ async fn init_remote(
config.prefix_in_container.get_or_insert(default_prefix);
}
RemoteStorageKind::LocalFs { .. } => (),
RemoteStorageKind::GCS(config) => {
config.prefix_in_bucket.get_or_insert(default_prefix);
}
}
// We already pass the prefix to the remote client above