mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-07 12:40:38 +00:00
Compare commits
4 Commits
conrad/pro
...
erik/grpc-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d2d29bf38 | ||
|
|
e8ebb8e433 | ||
|
|
232591e457 | ||
|
|
8daf272561 |
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -4236,6 +4236,7 @@ name = "pagebench"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"camino",
|
||||
"clap",
|
||||
"futures",
|
||||
@@ -4244,12 +4245,15 @@ dependencies = [
|
||||
"humantime-serde",
|
||||
"pageserver_api",
|
||||
"pageserver_client",
|
||||
"pageserver_page_api",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tonic 0.13.1",
|
||||
"tracing",
|
||||
"utils",
|
||||
"workspace_hack",
|
||||
|
||||
@@ -354,9 +354,6 @@ pub struct ShardImportProgressV1 {
|
||||
pub completed: usize,
|
||||
/// Hash of the plan
|
||||
pub import_plan_hash: u64,
|
||||
/// Soft limit for the job size
|
||||
/// This needs to remain constant throughout the import
|
||||
pub job_soft_size_limit: usize,
|
||||
}
|
||||
|
||||
impl ShardImportStatus {
|
||||
@@ -1934,7 +1931,7 @@ pub enum PagestreamFeMessage {
|
||||
}
|
||||
|
||||
// Wrapped in libpq CopyData
|
||||
#[derive(strum_macros::EnumProperty)]
|
||||
#[derive(Debug, strum_macros::EnumProperty)]
|
||||
pub enum PagestreamBeMessage {
|
||||
Exists(PagestreamExistsResponse),
|
||||
Nblocks(PagestreamNblocksResponse),
|
||||
@@ -2045,7 +2042,7 @@ pub enum PagestreamProtocolVersion {
|
||||
|
||||
pub type RequestId = u64;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct PagestreamRequest {
|
||||
pub reqid: RequestId,
|
||||
pub request_lsn: Lsn,
|
||||
@@ -2064,7 +2061,7 @@ pub struct PagestreamNblocksRequest {
|
||||
pub rel: RelTag,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct PagestreamGetPageRequest {
|
||||
pub hdr: PagestreamRequest,
|
||||
pub rel: RelTag,
|
||||
|
||||
@@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
|
||||
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
|
||||
// Then we could replace the custom Ord and PartialOrd implementations below with
|
||||
// deriving them. This will require changes in walredoproc.c.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RelTag {
|
||||
pub forknum: u8,
|
||||
pub spcnode: Oid,
|
||||
|
||||
@@ -584,6 +584,7 @@ impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
|
||||
// TODO: can a segment legitimately be empty?
|
||||
if segment.is_empty() {
|
||||
return Err(ProtocolError::Missing("segment"));
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
camino.workspace = true
|
||||
clap.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -15,14 +16,17 @@ hdrhistogram.workspace = true
|
||||
humantime.workspace = true
|
||||
humantime-serde.workspace = true
|
||||
rand.workspace = true
|
||||
reqwest.workspace=true
|
||||
reqwest.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tracing.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tonic.workspace = true
|
||||
|
||||
pageserver_client.workspace = true
|
||||
pageserver_api.workspace = true
|
||||
pageserver_page_api.workspace = true
|
||||
utils = { path = "../../libs/utils/" }
|
||||
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
|
||||
@@ -7,11 +7,15 @@ use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use camino::Utf8PathBuf;
|
||||
use pageserver_api::key::Key;
|
||||
use pageserver_api::keyspace::KeySpaceAccum;
|
||||
use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest};
|
||||
use pageserver_api::models::{
|
||||
PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest,
|
||||
};
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use pageserver_page_api::proto;
|
||||
use rand::prelude::*;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -22,6 +26,12 @@ use utils::lsn::Lsn;
|
||||
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
|
||||
use crate::util::{request_stats, tokio_thread_local_stats};
|
||||
|
||||
#[derive(clap::ValueEnum, Clone, Debug)]
|
||||
enum Protocol {
|
||||
Libpq,
|
||||
Grpc,
|
||||
}
|
||||
|
||||
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
|
||||
#[derive(clap::Parser)]
|
||||
pub(crate) struct Args {
|
||||
@@ -35,6 +45,8 @@ pub(crate) struct Args {
|
||||
num_clients: NonZeroUsize,
|
||||
#[clap(long)]
|
||||
runtime: Option<humantime::Duration>,
|
||||
#[clap(long, value_enum, default_value = "libpq")]
|
||||
protocol: Protocol,
|
||||
/// Each client sends requests at the given rate.
|
||||
///
|
||||
/// If a request takes too long and we should be issuing a new request already,
|
||||
@@ -303,7 +315,20 @@ async fn main_impl(
|
||||
.unwrap();
|
||||
|
||||
Box::pin(async move {
|
||||
client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await
|
||||
let client: Box<dyn Client> = match args.protocol {
|
||||
Protocol::Libpq => Box::new(
|
||||
LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline)
|
||||
.await
|
||||
.unwrap(),
|
||||
),
|
||||
|
||||
Protocol::Grpc => Box::new(
|
||||
GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline)
|
||||
.await
|
||||
.unwrap(),
|
||||
),
|
||||
};
|
||||
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
|
||||
})
|
||||
};
|
||||
|
||||
@@ -355,23 +380,15 @@ async fn main_impl(
|
||||
anyhow::Ok(())
|
||||
}
|
||||
|
||||
async fn client_libpq(
|
||||
async fn run_worker(
|
||||
args: &Args,
|
||||
worker_id: WorkerId,
|
||||
mut client: Box<dyn Client>,
|
||||
shared_state: Arc<SharedState>,
|
||||
cancel: CancellationToken,
|
||||
rps_period: Option<Duration>,
|
||||
ranges: Vec<KeyRange>,
|
||||
weights: rand::distributions::weighted::WeightedIndex<i128>,
|
||||
) {
|
||||
let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
let mut client = client
|
||||
.pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
shared_state.start_work_barrier.wait().await;
|
||||
let client_start = Instant::now();
|
||||
let mut ticks_processed = 0;
|
||||
@@ -415,12 +432,12 @@ async fn client_libpq(
|
||||
blkno: block_no,
|
||||
}
|
||||
};
|
||||
client.getpage_send(req).await.unwrap();
|
||||
client.send_get_page(req).await.unwrap();
|
||||
inflight.push_back(start);
|
||||
}
|
||||
|
||||
let start = inflight.pop_front().unwrap();
|
||||
client.getpage_recv().await.unwrap();
|
||||
client.recv_get_page().await.unwrap();
|
||||
let end = Instant::now();
|
||||
shared_state.live_stats.request_done();
|
||||
ticks_processed += 1;
|
||||
@@ -442,3 +459,101 @@ async fn client_libpq(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A benchmark client, to allow switching out the transport protocol.
|
||||
///
|
||||
/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could
|
||||
/// return a future that resolves when the response is received, but we don't really need it.
|
||||
#[async_trait]
|
||||
trait Client: Send {
|
||||
/// Sends an asynchronous GetPage request to the pageserver.
|
||||
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>;
|
||||
|
||||
/// Receives the next GetPage response from the pageserver.
|
||||
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse>;
|
||||
}
|
||||
|
||||
/// A libpq-based Pageserver client.
|
||||
struct LibpqClient {
|
||||
inner: pageserver_client::page_service::PagestreamClient,
|
||||
}
|
||||
|
||||
impl LibpqClient {
|
||||
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
|
||||
let inner = pageserver_client::page_service::Client::new(connstring)
|
||||
.await?
|
||||
.pagestream(ttid.tenant_id, ttid.timeline_id)
|
||||
.await?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for LibpqClient {
|
||||
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
|
||||
self.inner.getpage_send(req).await
|
||||
}
|
||||
|
||||
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
|
||||
self.inner.getpage_recv().await
|
||||
}
|
||||
}
|
||||
|
||||
/// A gRPC client using the raw, no-frills gRPC client.
|
||||
struct GrpcClient {
|
||||
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
|
||||
resp_rx: tonic::Streaming<proto::GetPageResponse>,
|
||||
}
|
||||
|
||||
impl GrpcClient {
|
||||
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
|
||||
let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?;
|
||||
|
||||
let (req_tx, req_rx) = tokio::sync::mpsc::channel(1);
|
||||
let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx);
|
||||
let mut req = tonic::Request::new(req_stream);
|
||||
let metadata = req.metadata_mut();
|
||||
metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?);
|
||||
metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?);
|
||||
metadata.insert("neon-shard-id", "0000".try_into()?);
|
||||
|
||||
let resp = client.get_pages(req).await?;
|
||||
let resp_stream = resp.into_inner();
|
||||
|
||||
Ok(Self {
|
||||
req_tx,
|
||||
resp_rx: resp_stream,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for GrpcClient {
|
||||
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
|
||||
let req = proto::GetPageRequest {
|
||||
request_id: 0,
|
||||
request_class: proto::GetPageClass::Normal as i32,
|
||||
read_lsn: Some(proto::ReadLsn {
|
||||
request_lsn: req.hdr.request_lsn.0,
|
||||
not_modified_since_lsn: req.hdr.not_modified_since.0,
|
||||
}),
|
||||
rel: Some(req.rel.into()),
|
||||
block_number: vec![req.blkno],
|
||||
};
|
||||
self.req_tx.send(req).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
|
||||
let resp = self.resp_rx.message().await?.unwrap();
|
||||
anyhow::ensure!(
|
||||
resp.status_code == proto::GetPageStatusCode::Ok as i32,
|
||||
"unexpected status code: {}",
|
||||
resp.status_code
|
||||
);
|
||||
Ok(PagestreamGetPageResponse {
|
||||
page: resp.page_image[0].clone(),
|
||||
req: PagestreamGetPageRequest::default(), // dummy
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -804,7 +804,7 @@ fn start_pageserver(
|
||||
} else {
|
||||
None
|
||||
},
|
||||
basebackup_cache.clone(),
|
||||
basebackup_cache,
|
||||
);
|
||||
|
||||
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
|
||||
@@ -816,12 +816,10 @@ fn start_pageserver(
|
||||
let mut page_service_grpc = None;
|
||||
if let Some(grpc_listener) = grpc_listener {
|
||||
page_service_grpc = Some(page_service::spawn_grpc(
|
||||
conf,
|
||||
tenant_manager.clone(),
|
||||
grpc_auth,
|
||||
otel_guard.as_ref().map(|g| g.dispatch.clone()),
|
||||
grpc_listener,
|
||||
basebackup_cache,
|
||||
)?);
|
||||
}
|
||||
|
||||
|
||||
@@ -837,30 +837,7 @@ async fn collect_eviction_candidates(
|
||||
continue;
|
||||
}
|
||||
let info = tl.get_local_layers_for_disk_usage_eviction().await;
|
||||
debug!(
|
||||
tenant_id=%tl.tenant_shard_id.tenant_id,
|
||||
shard_id=%tl.tenant_shard_id.shard_slug(),
|
||||
timeline_id=%tl.timeline_id,
|
||||
"timeline resident layers count: {}", info.resident_layers.len()
|
||||
);
|
||||
|
||||
tenant_candidates.extend(info.resident_layers.into_iter());
|
||||
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));
|
||||
|
||||
if cancel.is_cancelled() {
|
||||
return Ok(EvictionCandidates::Cancelled);
|
||||
}
|
||||
}
|
||||
|
||||
// Also consider layers of timelines being imported for eviction
|
||||
for tl in tenant.list_importing_timelines() {
|
||||
let info = tl.timeline.get_local_layers_for_disk_usage_eviction().await;
|
||||
debug!(
|
||||
tenant_id=%tl.timeline.tenant_shard_id.tenant_id,
|
||||
shard_id=%tl.timeline.tenant_shard_id.shard_slug(),
|
||||
timeline_id=%tl.timeline.timeline_id,
|
||||
"timeline resident layers count: {}", info.resident_layers.len()
|
||||
);
|
||||
debug!(tenant_id=%tl.tenant_shard_id.tenant_id, shard_id=%tl.tenant_shard_id.shard_slug(), timeline_id=%tl.timeline_id, "timeline resident layers count: {}", info.resident_layers.len());
|
||||
|
||||
tenant_candidates.extend(info.resident_layers.into_iter());
|
||||
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! The Page Service listens for client connections and serves their GetPage@LSN
|
||||
//! requests.
|
||||
|
||||
use std::any::Any;
|
||||
use std::borrow::Cow;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::os::fd::AsRawFd;
|
||||
@@ -31,6 +32,7 @@ use pageserver_api::models::{
|
||||
};
|
||||
use pageserver_api::reltag::SlruKind;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use pageserver_page_api as page_api;
|
||||
use pageserver_page_api::proto;
|
||||
use postgres_backend::{
|
||||
AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error,
|
||||
@@ -39,6 +41,7 @@ use postgres_ffi::BLCKSZ;
|
||||
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
|
||||
use pq_proto::framed::ConnectionError;
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, RowDescriptor};
|
||||
use smallvec::{SmallVec, smallvec};
|
||||
use strum_macros::IntoStaticStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter};
|
||||
use tokio::task::JoinHandle;
|
||||
@@ -76,6 +79,7 @@ use crate::tenant::mgr::{
|
||||
GetActiveTenantError, GetTenantError, ShardResolveResult, ShardSelector, TenantManager,
|
||||
};
|
||||
use crate::tenant::storage_layer::IoConcurrency;
|
||||
use crate::tenant::timeline::handle::{HandleUpgradeError, WeakHandle};
|
||||
use crate::tenant::timeline::{self, WaitLsnError};
|
||||
use crate::tenant::{GetTimelineError, PageReconstructError, Timeline};
|
||||
use crate::{CancellableTask, PERF_TRACE_TARGET, timed_after_cancellation};
|
||||
@@ -165,15 +169,14 @@ pub fn spawn(
|
||||
|
||||
/// Spawns a gRPC server for the page service.
|
||||
///
|
||||
/// TODO: move this onto GrpcPageServiceHandler::spawn().
|
||||
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
|
||||
/// need to reimplement the TCP+TLS accept loop ourselves.
|
||||
pub fn spawn_grpc(
|
||||
conf: &'static PageServerConf,
|
||||
tenant_manager: Arc<TenantManager>,
|
||||
auth: Option<Arc<SwappableJwtAuth>>,
|
||||
perf_trace_dispatch: Option<Dispatch>,
|
||||
listener: std::net::TcpListener,
|
||||
basebackup_cache: Arc<BasebackupCache>,
|
||||
) -> anyhow::Result<CancellableTask> {
|
||||
let cancel = CancellationToken::new();
|
||||
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
|
||||
@@ -202,21 +205,16 @@ pub fn spawn_grpc(
|
||||
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
|
||||
|
||||
// Main page service.
|
||||
let page_service_handler = PageServerHandler::new(
|
||||
let page_service_handler = GrpcPageServiceHandler {
|
||||
tenant_manager,
|
||||
auth.clone(),
|
||||
PageServicePipeliningConfig::Serial, // TODO: unused with gRPC
|
||||
conf.get_vectored_concurrent_io,
|
||||
ConnectionPerfSpanFields::default(),
|
||||
basebackup_cache,
|
||||
ctx,
|
||||
cancel.clone(),
|
||||
gate.enter().expect("just created"),
|
||||
);
|
||||
};
|
||||
|
||||
let mut received_at_interceptor = ReceivedAtInterceptor;
|
||||
let mut tenant_interceptor = TenantMetadataInterceptor;
|
||||
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
|
||||
let interceptors = move |mut req: tonic::Request<()>| {
|
||||
req = received_at_interceptor.call(req)?;
|
||||
req = tenant_interceptor.call(req)?;
|
||||
req = auth_interceptor.call(req)?;
|
||||
Ok(req)
|
||||
@@ -709,6 +707,89 @@ enum PageStreamError {
|
||||
BadRequest(Cow<'static, str>),
|
||||
}
|
||||
|
||||
impl PageStreamError {
|
||||
/// Converts a PageStreamError into a proto::GetPageResponse with the appropriate status
|
||||
/// code, or a gRPC status if it should terminate the stream (e.g. shutdown). This is a
|
||||
/// convenience method for use from a get_pages gRPC stream.
|
||||
#[allow(clippy::result_large_err)]
|
||||
fn into_get_page_response(
|
||||
self,
|
||||
request_id: page_api::RequestID,
|
||||
) -> Result<proto::GetPageResponse, tonic::Status> {
|
||||
use page_api::GetPageStatusCode;
|
||||
use tonic::Code;
|
||||
|
||||
// We dispatch to Into<tonic::Status> first, and then map it to a GetPageResponse.
|
||||
let status: tonic::Status = self.into();
|
||||
let status_code = match status.code() {
|
||||
// We shouldn't see an OK status here, because we're emitting an error.
|
||||
Code::Ok => {
|
||||
debug_assert_ne!(status.code(), Code::Ok);
|
||||
return Err(tonic::Status::internal(format!(
|
||||
"unexpected OK status: {status:?}",
|
||||
)));
|
||||
}
|
||||
|
||||
// These are per-request errors, returned as GetPageResponses.
|
||||
Code::AlreadyExists => GetPageStatusCode::InvalidRequest,
|
||||
Code::DataLoss => GetPageStatusCode::InternalError,
|
||||
Code::FailedPrecondition => GetPageStatusCode::InvalidRequest,
|
||||
Code::InvalidArgument => GetPageStatusCode::InvalidRequest,
|
||||
Code::Internal => GetPageStatusCode::InternalError,
|
||||
Code::NotFound => GetPageStatusCode::NotFound,
|
||||
Code::OutOfRange => GetPageStatusCode::InvalidRequest,
|
||||
Code::ResourceExhausted => GetPageStatusCode::SlowDown,
|
||||
|
||||
// These should terminate the stream.
|
||||
Code::Aborted => return Err(status),
|
||||
Code::Cancelled => return Err(status),
|
||||
Code::DeadlineExceeded => return Err(status),
|
||||
Code::PermissionDenied => return Err(status),
|
||||
Code::Unauthenticated => return Err(status),
|
||||
Code::Unavailable => return Err(status),
|
||||
Code::Unimplemented => return Err(status),
|
||||
Code::Unknown => return Err(status),
|
||||
};
|
||||
|
||||
Ok(page_api::GetPageResponse {
|
||||
request_id,
|
||||
status_code,
|
||||
reason: Some(status.message().to_string()),
|
||||
page_images: SmallVec::new(),
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PageStreamError> for tonic::Status {
|
||||
fn from(err: PageStreamError) -> Self {
|
||||
use tonic::Code;
|
||||
let code = match &err {
|
||||
PageStreamError::Reconnect(_) => Code::Unavailable,
|
||||
PageStreamError::Shutdown => Code::Unavailable,
|
||||
PageStreamError::Read(err) => match err {
|
||||
PageReconstructError::Cancelled => Code::Unavailable,
|
||||
PageReconstructError::MissingKey(_) => Code::NotFound,
|
||||
PageReconstructError::AncestorLsnTimeout(err) => match err {
|
||||
WaitLsnError::Timeout(_) => Code::Internal,
|
||||
WaitLsnError::BadState(_) => Code::Internal,
|
||||
WaitLsnError::Shutdown => Code::Unavailable,
|
||||
},
|
||||
PageReconstructError::Other(_) => Code::Internal,
|
||||
PageReconstructError::WalRedo(_) => Code::Internal,
|
||||
},
|
||||
PageStreamError::LsnTimeout(err) => match err {
|
||||
WaitLsnError::Timeout(_) => Code::Internal,
|
||||
WaitLsnError::BadState(_) => Code::Internal,
|
||||
WaitLsnError::Shutdown => Code::Unavailable,
|
||||
},
|
||||
PageStreamError::NotFound(_) => Code::NotFound,
|
||||
PageStreamError::BadRequest(_) => Code::InvalidArgument,
|
||||
};
|
||||
tonic::Status::new(code, format!("{err}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PageReconstructError> for PageStreamError {
|
||||
fn from(value: PageReconstructError) -> Self {
|
||||
match value {
|
||||
@@ -789,37 +870,37 @@ enum BatchedFeMessage {
|
||||
Exists {
|
||||
span: Span,
|
||||
timer: SmgrOpTimer,
|
||||
shard: timeline::handle::WeakHandle<TenantManagerTypes>,
|
||||
shard: WeakHandle<TenantManagerTypes>,
|
||||
req: models::PagestreamExistsRequest,
|
||||
},
|
||||
Nblocks {
|
||||
span: Span,
|
||||
timer: SmgrOpTimer,
|
||||
shard: timeline::handle::WeakHandle<TenantManagerTypes>,
|
||||
shard: WeakHandle<TenantManagerTypes>,
|
||||
req: models::PagestreamNblocksRequest,
|
||||
},
|
||||
GetPage {
|
||||
span: Span,
|
||||
shard: timeline::handle::WeakHandle<TenantManagerTypes>,
|
||||
pages: smallvec::SmallVec<[BatchedGetPageRequest; 1]>,
|
||||
shard: WeakHandle<TenantManagerTypes>,
|
||||
pages: SmallVec<[BatchedGetPageRequest; 1]>,
|
||||
batch_break_reason: GetPageBatchBreakReason,
|
||||
},
|
||||
DbSize {
|
||||
span: Span,
|
||||
timer: SmgrOpTimer,
|
||||
shard: timeline::handle::WeakHandle<TenantManagerTypes>,
|
||||
shard: WeakHandle<TenantManagerTypes>,
|
||||
req: models::PagestreamDbSizeRequest,
|
||||
},
|
||||
GetSlruSegment {
|
||||
span: Span,
|
||||
timer: SmgrOpTimer,
|
||||
shard: timeline::handle::WeakHandle<TenantManagerTypes>,
|
||||
shard: WeakHandle<TenantManagerTypes>,
|
||||
req: models::PagestreamGetSlruSegmentRequest,
|
||||
},
|
||||
#[cfg(feature = "testing")]
|
||||
Test {
|
||||
span: Span,
|
||||
shard: timeline::handle::WeakHandle<TenantManagerTypes>,
|
||||
shard: WeakHandle<TenantManagerTypes>,
|
||||
requests: Vec<BatchedTestRequest>,
|
||||
},
|
||||
RespondError {
|
||||
@@ -1068,26 +1149,6 @@ impl PageServerHandler {
|
||||
let neon_fe_msg =
|
||||
PagestreamFeMessage::parse(&mut copy_data_bytes.reader(), protocol_version)?;
|
||||
|
||||
// TODO: turn in to async closure once available to avoid repeating received_at
|
||||
async fn record_op_start_and_throttle(
|
||||
shard: &timeline::handle::Handle<TenantManagerTypes>,
|
||||
op: metrics::SmgrQueryType,
|
||||
received_at: Instant,
|
||||
) -> Result<SmgrOpTimer, QueryError> {
|
||||
// It's important to start the smgr op metric recorder as early as possible
|
||||
// so that the _started counters are incremented before we do
|
||||
// any serious waiting, e.g., for throttle, batching, or actual request handling.
|
||||
let mut timer = shard.query_metrics.start_smgr_op(op, received_at);
|
||||
let now = Instant::now();
|
||||
timer.observe_throttle_start(now);
|
||||
let throttled = tokio::select! {
|
||||
res = shard.pagestream_throttle.throttle(1, now) => res,
|
||||
_ = shard.cancel.cancelled() => return Err(QueryError::Shutdown),
|
||||
};
|
||||
timer.observe_throttle_done(throttled);
|
||||
Ok(timer)
|
||||
}
|
||||
|
||||
let batched_msg = match neon_fe_msg {
|
||||
PagestreamFeMessage::Exists(req) => {
|
||||
let shard = timeline_handles
|
||||
@@ -1095,7 +1156,7 @@ impl PageServerHandler {
|
||||
.await?;
|
||||
debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id();
|
||||
let span = tracing::info_span!(parent: &parent_span, "handle_get_rel_exists_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug());
|
||||
let timer = record_op_start_and_throttle(
|
||||
let timer = Self::record_op_start_and_throttle(
|
||||
&shard,
|
||||
metrics::SmgrQueryType::GetRelExists,
|
||||
received_at,
|
||||
@@ -1113,7 +1174,7 @@ impl PageServerHandler {
|
||||
.get(tenant_id, timeline_id, ShardSelector::Zero)
|
||||
.await?;
|
||||
let span = tracing::info_span!(parent: &parent_span, "handle_get_nblocks_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug());
|
||||
let timer = record_op_start_and_throttle(
|
||||
let timer = Self::record_op_start_and_throttle(
|
||||
&shard,
|
||||
metrics::SmgrQueryType::GetRelSize,
|
||||
received_at,
|
||||
@@ -1131,7 +1192,7 @@ impl PageServerHandler {
|
||||
.get(tenant_id, timeline_id, ShardSelector::Zero)
|
||||
.await?;
|
||||
let span = tracing::info_span!(parent: &parent_span, "handle_db_size_request", dbnode = %req.dbnode, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug());
|
||||
let timer = record_op_start_and_throttle(
|
||||
let timer = Self::record_op_start_and_throttle(
|
||||
&shard,
|
||||
metrics::SmgrQueryType::GetDbSize,
|
||||
received_at,
|
||||
@@ -1149,7 +1210,7 @@ impl PageServerHandler {
|
||||
.get(tenant_id, timeline_id, ShardSelector::Zero)
|
||||
.await?;
|
||||
let span = tracing::info_span!(parent: &parent_span, "handle_get_slru_segment_request", kind = %req.kind, segno = %req.segno, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug());
|
||||
let timer = record_op_start_and_throttle(
|
||||
let timer = Self::record_op_start_and_throttle(
|
||||
&shard,
|
||||
metrics::SmgrQueryType::GetSlruSegment,
|
||||
received_at,
|
||||
@@ -1274,7 +1335,7 @@ impl PageServerHandler {
|
||||
// request handler log messages contain the request-specific fields.
|
||||
let span = mkspan!(shard.tenant_shard_id.shard_slug());
|
||||
|
||||
let timer = record_op_start_and_throttle(
|
||||
let timer = Self::record_op_start_and_throttle(
|
||||
&shard,
|
||||
metrics::SmgrQueryType::GetPageAtLsn,
|
||||
received_at,
|
||||
@@ -1321,7 +1382,7 @@ impl PageServerHandler {
|
||||
BatchedFeMessage::GetPage {
|
||||
span,
|
||||
shard: shard.downgrade(),
|
||||
pages: smallvec::smallvec![BatchedGetPageRequest {
|
||||
pages: smallvec![BatchedGetPageRequest {
|
||||
req,
|
||||
timer,
|
||||
lsn_range: LsnRange {
|
||||
@@ -1343,9 +1404,12 @@ impl PageServerHandler {
|
||||
.get(tenant_id, timeline_id, ShardSelector::Zero)
|
||||
.await?;
|
||||
let span = tracing::info_span!(parent: &parent_span, "handle_test_request", shard_id = %shard.tenant_shard_id.shard_slug());
|
||||
let timer =
|
||||
record_op_start_and_throttle(&shard, metrics::SmgrQueryType::Test, received_at)
|
||||
.await?;
|
||||
let timer = Self::record_op_start_and_throttle(
|
||||
&shard,
|
||||
metrics::SmgrQueryType::Test,
|
||||
received_at,
|
||||
)
|
||||
.await?;
|
||||
BatchedFeMessage::Test {
|
||||
span,
|
||||
shard: shard.downgrade(),
|
||||
@@ -1356,6 +1420,26 @@ impl PageServerHandler {
|
||||
Ok(Some(batched_msg))
|
||||
}
|
||||
|
||||
/// Starts a SmgrOpTimer at received_at and throttles the request.
|
||||
async fn record_op_start_and_throttle(
|
||||
shard: &timeline::handle::Handle<TenantManagerTypes>,
|
||||
op: metrics::SmgrQueryType,
|
||||
received_at: Instant,
|
||||
) -> Result<SmgrOpTimer, QueryError> {
|
||||
// It's important to start the smgr op metric recorder as early as possible
|
||||
// so that the _started counters are incremented before we do
|
||||
// any serious waiting, e.g., for throttle, batching, or actual request handling.
|
||||
let mut timer = shard.query_metrics.start_smgr_op(op, received_at);
|
||||
let now = Instant::now();
|
||||
timer.observe_throttle_start(now);
|
||||
let throttled = tokio::select! {
|
||||
res = shard.pagestream_throttle.throttle(1, now) => res,
|
||||
_ = shard.cancel.cancelled() => return Err(QueryError::Shutdown),
|
||||
};
|
||||
timer.observe_throttle_done(throttled);
|
||||
Ok(timer)
|
||||
}
|
||||
|
||||
/// Post-condition: `batch` is Some()
|
||||
#[instrument(skip_all, level = tracing::Level::TRACE)]
|
||||
#[allow(clippy::boxed_local)]
|
||||
@@ -1453,8 +1537,11 @@ impl PageServerHandler {
|
||||
let (mut handler_results, span) = {
|
||||
// TODO: we unfortunately have to pin the future on the heap, since GetPage futures are huge and
|
||||
// won't fit on the stack.
|
||||
let mut boxpinned =
|
||||
Box::pin(self.pagestream_dispatch_batched_message(batch, io_concurrency, ctx));
|
||||
let mut boxpinned = Box::pin(Self::pagestream_dispatch_batched_message(
|
||||
batch,
|
||||
io_concurrency,
|
||||
ctx,
|
||||
));
|
||||
log_slow(
|
||||
log_slow_name,
|
||||
LOG_SLOW_GETPAGE_THRESHOLD,
|
||||
@@ -1610,7 +1697,6 @@ impl PageServerHandler {
|
||||
/// Helper which dispatches a batched message to the appropriate handler.
|
||||
/// Returns a vec of results, along with the extracted trace span.
|
||||
async fn pagestream_dispatch_batched_message(
|
||||
&mut self,
|
||||
batch: BatchedFeMessage,
|
||||
io_concurrency: IoConcurrency,
|
||||
ctx: &RequestContext,
|
||||
@@ -1640,10 +1726,10 @@ impl PageServerHandler {
|
||||
let (shard, ctx) = upgrade_handle_and_set_context!(shard);
|
||||
(
|
||||
vec![
|
||||
self.handle_get_rel_exists_request(&shard, &req, &ctx)
|
||||
Self::handle_get_rel_exists_request(&shard, &req, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|msg| (msg, timer, ctx))
|
||||
.map(|msg| (PagestreamBeMessage::Exists(msg), timer, ctx))
|
||||
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
|
||||
],
|
||||
span,
|
||||
@@ -1659,10 +1745,10 @@ impl PageServerHandler {
|
||||
let (shard, ctx) = upgrade_handle_and_set_context!(shard);
|
||||
(
|
||||
vec![
|
||||
self.handle_get_nblocks_request(&shard, &req, &ctx)
|
||||
Self::handle_get_nblocks_request(&shard, &req, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|msg| (msg, timer, ctx))
|
||||
.map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx))
|
||||
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
|
||||
],
|
||||
span,
|
||||
@@ -1680,16 +1766,15 @@ impl PageServerHandler {
|
||||
{
|
||||
let npages = pages.len();
|
||||
trace!(npages, "handling getpage request");
|
||||
let res = self
|
||||
.handle_get_page_at_lsn_request_batched(
|
||||
&shard,
|
||||
pages,
|
||||
io_concurrency,
|
||||
batch_break_reason,
|
||||
&ctx,
|
||||
)
|
||||
.instrument(span.clone())
|
||||
.await;
|
||||
let res = Self::handle_get_page_at_lsn_request_batched(
|
||||
&shard,
|
||||
pages,
|
||||
io_concurrency,
|
||||
batch_break_reason,
|
||||
&ctx,
|
||||
)
|
||||
.instrument(span.clone())
|
||||
.await;
|
||||
assert_eq!(res.len(), npages);
|
||||
res
|
||||
},
|
||||
@@ -1706,10 +1791,10 @@ impl PageServerHandler {
|
||||
let (shard, ctx) = upgrade_handle_and_set_context!(shard);
|
||||
(
|
||||
vec![
|
||||
self.handle_db_size_request(&shard, &req, &ctx)
|
||||
Self::handle_db_size_request(&shard, &req, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|msg| (msg, timer, ctx))
|
||||
.map(|msg| (PagestreamBeMessage::DbSize(msg), timer, ctx))
|
||||
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
|
||||
],
|
||||
span,
|
||||
@@ -1725,10 +1810,10 @@ impl PageServerHandler {
|
||||
let (shard, ctx) = upgrade_handle_and_set_context!(shard);
|
||||
(
|
||||
vec![
|
||||
self.handle_get_slru_segment_request(&shard, &req, &ctx)
|
||||
Self::handle_get_slru_segment_request(&shard, &req, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|msg| (msg, timer, ctx))
|
||||
.map(|msg| (PagestreamBeMessage::GetSlruSegment(msg), timer, ctx))
|
||||
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
|
||||
],
|
||||
span,
|
||||
@@ -1746,8 +1831,7 @@ impl PageServerHandler {
|
||||
{
|
||||
let npages = requests.len();
|
||||
trace!(npages, "handling getpage request");
|
||||
let res = self
|
||||
.handle_test_request_batch(&shard, requests, &ctx)
|
||||
let res = Self::handle_test_request_batch(&shard, requests, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await;
|
||||
assert_eq!(res.len(), npages);
|
||||
@@ -2301,11 +2385,10 @@ impl PageServerHandler {
|
||||
|
||||
#[instrument(skip_all, fields(shard_id))]
|
||||
async fn handle_get_rel_exists_request(
|
||||
&mut self,
|
||||
timeline: &Timeline,
|
||||
req: &PagestreamExistsRequest,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<PagestreamBeMessage, PageStreamError> {
|
||||
) -> Result<PagestreamExistsResponse, PageStreamError> {
|
||||
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(
|
||||
timeline,
|
||||
@@ -2327,19 +2410,15 @@ impl PageServerHandler {
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(PagestreamBeMessage::Exists(PagestreamExistsResponse {
|
||||
req: *req,
|
||||
exists,
|
||||
}))
|
||||
Ok(PagestreamExistsResponse { req: *req, exists })
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(shard_id))]
|
||||
async fn handle_get_nblocks_request(
|
||||
&mut self,
|
||||
timeline: &Timeline,
|
||||
req: &PagestreamNblocksRequest,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<PagestreamBeMessage, PageStreamError> {
|
||||
) -> Result<PagestreamNblocksResponse, PageStreamError> {
|
||||
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(
|
||||
timeline,
|
||||
@@ -2361,19 +2440,18 @@ impl PageServerHandler {
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(PagestreamBeMessage::Nblocks(PagestreamNblocksResponse {
|
||||
Ok(PagestreamNblocksResponse {
|
||||
req: *req,
|
||||
n_blocks,
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(shard_id))]
|
||||
async fn handle_db_size_request(
|
||||
&mut self,
|
||||
timeline: &Timeline,
|
||||
req: &PagestreamDbSizeRequest,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<PagestreamBeMessage, PageStreamError> {
|
||||
) -> Result<PagestreamDbSizeResponse, PageStreamError> {
|
||||
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(
|
||||
timeline,
|
||||
@@ -2397,23 +2475,19 @@ impl PageServerHandler {
|
||||
.await?;
|
||||
let db_size = total_blocks as i64 * BLCKSZ as i64;
|
||||
|
||||
Ok(PagestreamBeMessage::DbSize(PagestreamDbSizeResponse {
|
||||
req: *req,
|
||||
db_size,
|
||||
}))
|
||||
Ok(PagestreamDbSizeResponse { req: *req, db_size })
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn handle_get_page_at_lsn_request_batched(
|
||||
&mut self,
|
||||
timeline: &Timeline,
|
||||
requests: smallvec::SmallVec<[BatchedGetPageRequest; 1]>,
|
||||
requests: SmallVec<[BatchedGetPageRequest; 1]>,
|
||||
io_concurrency: IoConcurrency,
|
||||
batch_break_reason: GetPageBatchBreakReason,
|
||||
ctx: &RequestContext,
|
||||
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>
|
||||
{
|
||||
debug_assert_current_span_has_tenant_and_timeline_id();
|
||||
//debug_assert_current_span_has_tenant_and_timeline_id();
|
||||
|
||||
timeline
|
||||
.query_metrics
|
||||
@@ -2532,11 +2606,10 @@ impl PageServerHandler {
|
||||
|
||||
#[instrument(skip_all, fields(shard_id))]
|
||||
async fn handle_get_slru_segment_request(
|
||||
&mut self,
|
||||
timeline: &Timeline,
|
||||
req: &PagestreamGetSlruSegmentRequest,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<PagestreamBeMessage, PageStreamError> {
|
||||
) -> Result<PagestreamGetSlruSegmentResponse, PageStreamError> {
|
||||
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(
|
||||
timeline,
|
||||
@@ -2551,16 +2624,13 @@ impl PageServerHandler {
|
||||
.ok_or(PageStreamError::BadRequest("invalid SLRU kind".into()))?;
|
||||
let segment = timeline.get_slru_segment(kind, req.segno, lsn, ctx).await?;
|
||||
|
||||
Ok(PagestreamBeMessage::GetSlruSegment(
|
||||
PagestreamGetSlruSegmentResponse { req: *req, segment },
|
||||
))
|
||||
Ok(PagestreamGetSlruSegmentResponse { req: *req, segment })
|
||||
}
|
||||
|
||||
// NB: this impl mimics what we do for batched getpage requests.
|
||||
#[cfg(feature = "testing")]
|
||||
#[instrument(skip_all, fields(shard_id))]
|
||||
async fn handle_test_request_batch(
|
||||
&mut self,
|
||||
timeline: &Timeline,
|
||||
requests: Vec<BatchedTestRequest>,
|
||||
_ctx: &RequestContext,
|
||||
@@ -3300,57 +3370,342 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements the page service over gRPC.
|
||||
/// Serves the page service over gRPC. Dispatches to PageServerHandler for request processing.
|
||||
///
|
||||
/// TODO: not yet implemented, all methods return unimplemented.
|
||||
/// TODO: add trace spans, interceptors, and sampling.
|
||||
/// TODO: rename to PageServiceHandler when libpq impl is removed.
|
||||
pub struct GrpcPageServiceHandler {
|
||||
tenant_manager: Arc<TenantManager>,
|
||||
ctx: RequestContext,
|
||||
}
|
||||
|
||||
impl GrpcPageServiceHandler {
|
||||
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
|
||||
/// relations and their sizes, as well as SLRU segments and other data.
|
||||
#[allow(clippy::result_large_err)]
|
||||
fn ensure_shard_zero(req: &tonic::Request<impl Any>) -> Result<(), tonic::Status> {
|
||||
match Self::extract::<ShardIndex>(req).shard_number.0 {
|
||||
0 => Ok(()),
|
||||
shard => Err(tonic::Status::invalid_argument(format!(
|
||||
"request must execute on shard zero (is shard {shard})",
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts the given type from the request extensions. It must have been set by an
|
||||
/// interceptor.
|
||||
fn extract<T: Send + Sync + 'static>(req: &tonic::Request<impl Any>) -> &T {
|
||||
req.extensions()
|
||||
.get::<T>()
|
||||
.expect("extension should be set by interceptor")
|
||||
}
|
||||
|
||||
/// Generates a PagestreamRequest header from a ReadLsn and request ID.
|
||||
fn make_hdr(read_lsn: page_api::ReadLsn, req_id: u64) -> PagestreamRequest {
|
||||
PagestreamRequest {
|
||||
reqid: req_id,
|
||||
request_lsn: read_lsn.request_lsn,
|
||||
not_modified_since: read_lsn.not_modified_since_lsn.unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquires a timeline handle for the given request. The request must have been decorated by
|
||||
/// TenantMetadataInterceptor first.
|
||||
async fn get_request_timeline(
|
||||
&self,
|
||||
req: &tonic::Request<impl Any>,
|
||||
) -> Result<timeline::handle::Handle<TenantManagerTypes>, GetActiveTimelineError> {
|
||||
let ttid = *Self::extract::<TenantTimelineId>(req);
|
||||
let shard_index = *Self::extract::<ShardIndex>(req);
|
||||
let shard_selector = ShardSelector::Known(shard_index);
|
||||
|
||||
// TODO: untangle this from TenantManagerWrapper::resolve() and Cache::get(), to avoid the
|
||||
// unnecessary overhead.
|
||||
TimelineHandles::new(self.tenant_manager.clone())
|
||||
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start.
|
||||
/// Only errors if the timeline is shutting down.
|
||||
///
|
||||
/// TODO: revamp request timers -- in particular,
|
||||
/// TODO: consider moving throttling out and returning SlowDown errors.
|
||||
async fn record_op_start_and_throttle(
|
||||
timeline: &timeline::handle::Handle<TenantManagerTypes>,
|
||||
op: metrics::SmgrQueryType,
|
||||
received_at: Instant,
|
||||
) -> Result<SmgrOpTimer, tonic::Status> {
|
||||
let mut timer = PageServerHandler::record_op_start_and_throttle(timeline, op, received_at)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
// record_op_start_and_throttle() only returns Shutdown.
|
||||
QueryError::Shutdown => tonic::Status::unavailable(format!("{err}")),
|
||||
err => tonic::Status::internal(format!("unexpected error: {err}")),
|
||||
})?;
|
||||
timer.observe_execution_start(Instant::now());
|
||||
Ok(timer)
|
||||
}
|
||||
|
||||
/// Processes a GetPage batch request, via the GetPages bidirectional streaming RPC.
|
||||
///
|
||||
/// NB: errors will terminate the stream. Per-request errors should return a GetPageResponse
|
||||
/// with an appropriate status code instead.
|
||||
async fn handle_get_page_request(
|
||||
ctx: &RequestContext,
|
||||
timeline: &WeakHandle<TenantManagerTypes>,
|
||||
req: proto::GetPageRequest,
|
||||
io_concurrency: IoConcurrency,
|
||||
) -> Result<proto::GetPageResponse, tonic::Status> {
|
||||
let received_at = Instant::now();
|
||||
let timeline = timeline.upgrade().map_err(|err| match err {
|
||||
HandleUpgradeError::ShutDown => tonic::Status::unavailable("timeline is shutting down"),
|
||||
})?;
|
||||
let ctx = ctx.with_scope_page_service_pagestream(&timeline);
|
||||
|
||||
// Validate the request and convert it to a Pagestream request.
|
||||
let req: page_api::GetPageRequest = req.try_into()?;
|
||||
|
||||
let effective_lsn = match PageServerHandler::effective_request_lsn(
|
||||
&timeline,
|
||||
timeline.get_last_record_lsn(),
|
||||
req.read_lsn.request_lsn,
|
||||
req.read_lsn.not_modified_since_lsn.unwrap_or_default(),
|
||||
&timeline.get_applied_gc_cutoff_lsn(),
|
||||
) {
|
||||
Ok(lsn) => lsn,
|
||||
Err(err) => return err.into_get_page_response(req.request_id),
|
||||
};
|
||||
|
||||
let mut batch = SmallVec::with_capacity(req.block_numbers.len());
|
||||
for blkno in req.block_numbers {
|
||||
// TODO: this creates one timer per page and throttles it. We should have a timer for
|
||||
// the entire batch, and throttle only the batch, but this is equivalent to what
|
||||
// PageServerHandler does already so we keep it for now.
|
||||
let timer = Self::record_op_start_and_throttle(
|
||||
&timeline,
|
||||
metrics::SmgrQueryType::GetPageAtLsn,
|
||||
received_at,
|
||||
)
|
||||
.await?;
|
||||
|
||||
batch.push(BatchedGetPageRequest {
|
||||
req: PagestreamGetPageRequest {
|
||||
hdr: Self::make_hdr(req.read_lsn, req.request_id),
|
||||
rel: req.rel,
|
||||
blkno,
|
||||
},
|
||||
lsn_range: LsnRange {
|
||||
effective_lsn,
|
||||
request_lsn: req.read_lsn.request_lsn,
|
||||
},
|
||||
timer,
|
||||
ctx: ctx.attached_child(),
|
||||
batch_wait_ctx: None, // TODO: add tracing
|
||||
});
|
||||
}
|
||||
|
||||
let results = PageServerHandler::handle_get_page_at_lsn_request_batched(
|
||||
&timeline,
|
||||
batch,
|
||||
io_concurrency,
|
||||
GetPageBatchBreakReason::BatchFull, // TODO: not relevant for gRPC batches
|
||||
&ctx,
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut resp = page_api::GetPageResponse {
|
||||
request_id: req.request_id,
|
||||
status_code: page_api::GetPageStatusCode::Ok,
|
||||
reason: None,
|
||||
page_images: SmallVec::with_capacity(results.len()),
|
||||
};
|
||||
|
||||
for result in results {
|
||||
match result {
|
||||
Ok((PagestreamBeMessage::GetPage(r), _, _)) => resp.page_images.push(r.page),
|
||||
Ok((resp, _, _)) => {
|
||||
return Err(tonic::Status::internal(format!(
|
||||
"unexpected response: {resp:?}"
|
||||
)));
|
||||
}
|
||||
Err(err) => return err.err.into_get_page_response(req.request_id),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(resp.into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements the gRPC page service.
|
||||
///
|
||||
/// TODO: when the libpq impl is removed, simplify this:
|
||||
/// * Add Tower middleware for timeline handle, rate limiting, and timing.
|
||||
/// * Remove the intermediate Pagestream types.
|
||||
/// * Inline the handler code.
|
||||
#[tonic::async_trait]
|
||||
impl proto::PageService for PageServerHandler {
|
||||
impl proto::PageService for GrpcPageServiceHandler {
|
||||
type GetBaseBackupStream = Pin<
|
||||
Box<dyn Stream<Item = Result<proto::GetBaseBackupResponseChunk, tonic::Status>> + Send>,
|
||||
>;
|
||||
|
||||
type GetPagesStream =
|
||||
Pin<Box<dyn Stream<Item = Result<proto::GetPageResponse, tonic::Status>> + Send>>;
|
||||
|
||||
async fn check_rel_exists(
|
||||
&self,
|
||||
_: tonic::Request<proto::CheckRelExistsRequest>,
|
||||
req: tonic::Request<proto::CheckRelExistsRequest>,
|
||||
) -> Result<tonic::Response<proto::CheckRelExistsResponse>, tonic::Status> {
|
||||
Err(tonic::Status::unimplemented("not implemented"))
|
||||
let received_at = Self::extract::<ReceivedAt>(&req).0;
|
||||
let timeline = self.get_request_timeline(&req).await?;
|
||||
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
|
||||
|
||||
// Validate the request and convert it to a Pagestream request.
|
||||
Self::ensure_shard_zero(&req)?;
|
||||
let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?;
|
||||
let req = PagestreamExistsRequest {
|
||||
hdr: Self::make_hdr(req.read_lsn, 0),
|
||||
rel: req.rel,
|
||||
};
|
||||
|
||||
// Execute the request and convert the response.
|
||||
let _timer = Self::record_op_start_and_throttle(
|
||||
&timeline,
|
||||
metrics::SmgrQueryType::GetRelExists,
|
||||
received_at,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?;
|
||||
let resp: page_api::CheckRelExistsResponse = resp.exists;
|
||||
Ok(tonic::Response::new(resp.into()))
|
||||
}
|
||||
|
||||
async fn get_base_backup(
|
||||
&self,
|
||||
_: tonic::Request<proto::GetBaseBackupRequest>,
|
||||
) -> Result<tonic::Response<Self::GetBaseBackupStream>, tonic::Status> {
|
||||
Err(tonic::Status::unimplemented("not implemented"))
|
||||
Err(tonic::Status::unimplemented("not implemented")) // TODO
|
||||
}
|
||||
|
||||
async fn get_db_size(
|
||||
&self,
|
||||
_: tonic::Request<proto::GetDbSizeRequest>,
|
||||
req: tonic::Request<proto::GetDbSizeRequest>,
|
||||
) -> Result<tonic::Response<proto::GetDbSizeResponse>, tonic::Status> {
|
||||
Err(tonic::Status::unimplemented("not implemented"))
|
||||
let received_at = Self::extract::<ReceivedAt>(&req).0;
|
||||
let timeline = self.get_request_timeline(&req).await?;
|
||||
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
|
||||
|
||||
// Validate the request and convert it to a Pagestream request.
|
||||
Self::ensure_shard_zero(&req)?;
|
||||
let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?;
|
||||
let req = PagestreamDbSizeRequest {
|
||||
hdr: Self::make_hdr(req.read_lsn, 0),
|
||||
dbnode: req.db_oid,
|
||||
};
|
||||
|
||||
// Execute the request and convert the response.
|
||||
let _timer = Self::record_op_start_and_throttle(
|
||||
&timeline,
|
||||
metrics::SmgrQueryType::GetDbSize,
|
||||
received_at,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let resp = PageServerHandler::handle_db_size_request(&timeline, &req, &ctx).await?;
|
||||
let resp = resp.db_size as page_api::GetDbSizeResponse;
|
||||
Ok(tonic::Response::new(resp.into()))
|
||||
}
|
||||
|
||||
async fn get_pages(
|
||||
&self,
|
||||
_: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
|
||||
req: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
|
||||
) -> Result<tonic::Response<Self::GetPagesStream>, tonic::Status> {
|
||||
Err(tonic::Status::unimplemented("not implemented"))
|
||||
// Extract the timeline from the request and check that it exists.
|
||||
let ttid = *Self::extract::<TenantTimelineId>(&req);
|
||||
let shard_index = *Self::extract::<ShardIndex>(&req);
|
||||
let shard_selector = ShardSelector::Known(shard_index);
|
||||
|
||||
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
|
||||
handles
|
||||
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
|
||||
.await?;
|
||||
|
||||
let ctx = self.ctx.attached_child();
|
||||
let mut reqs = req.into_inner();
|
||||
|
||||
let resps = async_stream::try_stream! {
|
||||
let timeline = handles
|
||||
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
|
||||
.await?
|
||||
.downgrade();
|
||||
while let Some(req) = reqs.message().await? {
|
||||
// TODO: implement IoConcurrency sidecar.
|
||||
yield Self::handle_get_page_request(&ctx, &timeline, req, IoConcurrency::Sequential).await?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(tonic::Response::new(Box::pin(resps)))
|
||||
}
|
||||
|
||||
async fn get_rel_size(
|
||||
&self,
|
||||
_: tonic::Request<proto::GetRelSizeRequest>,
|
||||
req: tonic::Request<proto::GetRelSizeRequest>,
|
||||
) -> Result<tonic::Response<proto::GetRelSizeResponse>, tonic::Status> {
|
||||
Err(tonic::Status::unimplemented("not implemented"))
|
||||
let received_at = Self::extract::<ReceivedAt>(&req).0;
|
||||
let timeline = self.get_request_timeline(&req).await?;
|
||||
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
|
||||
|
||||
// Validate the request and convert it to a Pagestream request.
|
||||
Self::ensure_shard_zero(&req)?;
|
||||
let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?;
|
||||
let req = PagestreamNblocksRequest {
|
||||
hdr: Self::make_hdr(req.read_lsn, 0),
|
||||
rel: req.rel,
|
||||
};
|
||||
|
||||
// Execute the request and convert the response.
|
||||
let _timer = Self::record_op_start_and_throttle(
|
||||
&timeline,
|
||||
metrics::SmgrQueryType::GetRelSize,
|
||||
received_at,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?;
|
||||
let resp: page_api::GetRelSizeResponse = resp.n_blocks;
|
||||
Ok(tonic::Response::new(resp.into()))
|
||||
}
|
||||
|
||||
async fn get_slru_segment(
|
||||
&self,
|
||||
_: tonic::Request<proto::GetSlruSegmentRequest>,
|
||||
req: tonic::Request<proto::GetSlruSegmentRequest>,
|
||||
) -> Result<tonic::Response<proto::GetSlruSegmentResponse>, tonic::Status> {
|
||||
Err(tonic::Status::unimplemented("not implemented"))
|
||||
let received_at = Self::extract::<ReceivedAt>(&req).0;
|
||||
let timeline = self.get_request_timeline(&req).await?;
|
||||
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
|
||||
|
||||
// Validate the request and convert it to a Pagestream request.
|
||||
Self::ensure_shard_zero(&req)?;
|
||||
let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?;
|
||||
let req = PagestreamGetSlruSegmentRequest {
|
||||
hdr: Self::make_hdr(req.read_lsn, 0),
|
||||
kind: req.kind as u8,
|
||||
segno: req.segno,
|
||||
};
|
||||
|
||||
// Execute the request and convert the response.
|
||||
let _timer = Self::record_op_start_and_throttle(
|
||||
&timeline,
|
||||
metrics::SmgrQueryType::GetSlruSegment,
|
||||
received_at,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let resp =
|
||||
PageServerHandler::handle_get_slru_segment_request(&timeline, &req, &ctx).await?;
|
||||
let resp: page_api::GetSlruSegmentResponse = resp.segment;
|
||||
Ok(tonic::Response::new(resp.try_into()?))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3370,10 +3725,24 @@ impl From<GetActiveTenantError> for QueryError {
|
||||
}
|
||||
}
|
||||
|
||||
/// gRPC interceptor that records the start time of request processing as a ReceivedAt extension.
|
||||
///
|
||||
/// TODO: generalize this for other observability information.
|
||||
#[derive(Clone)]
|
||||
struct ReceivedAtInterceptor;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ReceivedAt(Instant);
|
||||
|
||||
impl tonic::service::Interceptor for ReceivedAtInterceptor {
|
||||
fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
|
||||
req.extensions_mut().insert(ReceivedAt(Instant::now()));
|
||||
Ok(req)
|
||||
}
|
||||
}
|
||||
|
||||
/// gRPC interceptor that decodes tenant metadata and stores it as request extensions of type
|
||||
/// TenantTimelineId and ShardIndex.
|
||||
///
|
||||
/// TODO: consider looking up the timeline handle here and storing it.
|
||||
#[derive(Clone)]
|
||||
struct TenantMetadataInterceptor;
|
||||
|
||||
@@ -3486,14 +3855,36 @@ impl From<GetActiveTimelineError> for QueryError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::tenant::timeline::handle::HandleUpgradeError> for QueryError {
|
||||
fn from(e: crate::tenant::timeline::handle::HandleUpgradeError) -> Self {
|
||||
impl From<HandleUpgradeError> for QueryError {
|
||||
fn from(e: HandleUpgradeError) -> Self {
|
||||
match e {
|
||||
crate::tenant::timeline::handle::HandleUpgradeError::ShutDown => QueryError::Shutdown,
|
||||
HandleUpgradeError::ShutDown => QueryError::Shutdown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GetActiveTimelineError> for tonic::Status {
|
||||
fn from(err: GetActiveTimelineError) -> Self {
|
||||
use tonic::Code;
|
||||
let code = match &err {
|
||||
GetActiveTimelineError::Tenant(err) => match err {
|
||||
GetActiveTenantError::Broken(_) => Code::Internal,
|
||||
GetActiveTenantError::Cancelled => Code::Unavailable,
|
||||
GetActiveTenantError::NotFound(_) => Code::NotFound,
|
||||
GetActiveTenantError::SwitchedTenant => Code::Unavailable,
|
||||
GetActiveTenantError::WaitForActiveTimeout { .. } => Code::Unavailable,
|
||||
GetActiveTenantError::WillNotBecomeActive(_) => Code::Unavailable,
|
||||
},
|
||||
GetActiveTimelineError::Timeline(err) => match err {
|
||||
GetTimelineError::NotFound { .. } => Code::NotFound,
|
||||
GetTimelineError::NotActive { .. } => Code::Unavailable,
|
||||
GetTimelineError::ShuttingDown => Code::Unavailable,
|
||||
},
|
||||
};
|
||||
tonic::Status::new(code, format!("{err}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn set_tracing_field_shard_id(timeline: &Timeline) {
|
||||
debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id();
|
||||
tracing::Span::current().record(
|
||||
|
||||
@@ -274,7 +274,7 @@ impl Timeline {
|
||||
io_concurrency: IoConcurrency,
|
||||
ctx: &RequestContext,
|
||||
) -> Vec<Result<Bytes, PageReconstructError>> {
|
||||
debug_assert_current_span_has_tenant_and_timeline_id();
|
||||
//debug_assert_current_span_has_tenant_and_timeline_id();
|
||||
|
||||
let mut slots_filled = 0;
|
||||
let page_count = pages.len();
|
||||
|
||||
@@ -300,7 +300,7 @@ pub struct TenantShard {
|
||||
/// as in progress.
|
||||
/// * Imported timelines are removed when the storage controller calls the post timeline
|
||||
/// import activation endpoint.
|
||||
timelines_importing: std::sync::Mutex<HashMap<TimelineId, Arc<ImportingTimeline>>>,
|
||||
timelines_importing: std::sync::Mutex<HashMap<TimelineId, ImportingTimeline>>,
|
||||
|
||||
/// The last tenant manifest known to be in remote storage. None if the manifest has not yet
|
||||
/// been either downloaded or uploaded. Always Some after tenant attach.
|
||||
@@ -672,7 +672,6 @@ pub enum MaybeOffloaded {
|
||||
pub enum TimelineOrOffloaded {
|
||||
Timeline(Arc<Timeline>),
|
||||
Offloaded(Arc<OffloadedTimeline>),
|
||||
Importing(Arc<ImportingTimeline>),
|
||||
}
|
||||
|
||||
impl TimelineOrOffloaded {
|
||||
@@ -684,9 +683,6 @@ impl TimelineOrOffloaded {
|
||||
TimelineOrOffloaded::Offloaded(offloaded) => {
|
||||
TimelineOrOffloadedArcRef::Offloaded(offloaded)
|
||||
}
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
TimelineOrOffloadedArcRef::Importing(importing)
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn tenant_shard_id(&self) -> TenantShardId {
|
||||
@@ -699,16 +695,12 @@ impl TimelineOrOffloaded {
|
||||
match self {
|
||||
TimelineOrOffloaded::Timeline(timeline) => &timeline.delete_progress,
|
||||
TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.delete_progress,
|
||||
TimelineOrOffloaded::Importing(importing) => &importing.delete_progress,
|
||||
}
|
||||
}
|
||||
fn maybe_remote_client(&self) -> Option<Arc<RemoteTimelineClient>> {
|
||||
match self {
|
||||
TimelineOrOffloaded::Timeline(timeline) => Some(timeline.remote_client.clone()),
|
||||
TimelineOrOffloaded::Offloaded(_offloaded) => None,
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
Some(importing.timeline.remote_client.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -716,7 +708,6 @@ impl TimelineOrOffloaded {
|
||||
pub enum TimelineOrOffloadedArcRef<'a> {
|
||||
Timeline(&'a Arc<Timeline>),
|
||||
Offloaded(&'a Arc<OffloadedTimeline>),
|
||||
Importing(&'a Arc<ImportingTimeline>),
|
||||
}
|
||||
|
||||
impl TimelineOrOffloadedArcRef<'_> {
|
||||
@@ -724,14 +715,12 @@ impl TimelineOrOffloadedArcRef<'_> {
|
||||
match self {
|
||||
TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.tenant_shard_id,
|
||||
TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.tenant_shard_id,
|
||||
TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.tenant_shard_id,
|
||||
}
|
||||
}
|
||||
pub fn timeline_id(&self) -> TimelineId {
|
||||
match self {
|
||||
TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.timeline_id,
|
||||
TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.timeline_id,
|
||||
TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.timeline_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -748,12 +737,6 @@ impl<'a> From<&'a Arc<OffloadedTimeline>> for TimelineOrOffloadedArcRef<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a Arc<ImportingTimeline>> for TimelineOrOffloadedArcRef<'a> {
|
||||
fn from(timeline: &'a Arc<ImportingTimeline>) -> Self {
|
||||
Self::Importing(timeline)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
|
||||
pub enum GetTimelineError {
|
||||
#[error("Timeline is shutting down")]
|
||||
@@ -1806,25 +1789,20 @@ impl TenantShard {
|
||||
},
|
||||
) => {
|
||||
let timeline_id = timeline.timeline_id;
|
||||
let import_task_gate = Gate::default();
|
||||
let import_task_guard = import_task_gate.enter().unwrap();
|
||||
let import_task_handle =
|
||||
tokio::task::spawn(self.clone().create_timeline_import_pgdata_task(
|
||||
timeline.clone(),
|
||||
import_pgdata,
|
||||
guard,
|
||||
import_task_guard,
|
||||
ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn),
|
||||
));
|
||||
|
||||
let prev = self.timelines_importing.lock().unwrap().insert(
|
||||
timeline_id,
|
||||
Arc::new(ImportingTimeline {
|
||||
ImportingTimeline {
|
||||
timeline: timeline.clone(),
|
||||
import_task_handle,
|
||||
import_task_gate,
|
||||
delete_progress: TimelineDeleteProgress::default(),
|
||||
}),
|
||||
},
|
||||
);
|
||||
|
||||
assert!(prev.is_none());
|
||||
@@ -2442,17 +2420,6 @@ impl TenantShard {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Lists timelines the tenant contains.
|
||||
/// It's up to callers to omit certain timelines that are not considered ready for use.
|
||||
pub fn list_importing_timelines(&self) -> Vec<Arc<ImportingTimeline>> {
|
||||
self.timelines_importing
|
||||
.lock()
|
||||
.unwrap()
|
||||
.values()
|
||||
.map(Arc::clone)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Lists timelines the tenant manages, including offloaded ones.
|
||||
///
|
||||
/// It's up to callers to omit certain timelines that are not considered ready for use.
|
||||
@@ -2886,25 +2853,19 @@ impl TenantShard {
|
||||
|
||||
let (timeline, timeline_create_guard) = uninit_timeline.finish_creation_myself();
|
||||
|
||||
let import_task_gate = Gate::default();
|
||||
let import_task_guard = import_task_gate.enter().unwrap();
|
||||
|
||||
let import_task_handle = tokio::spawn(self.clone().create_timeline_import_pgdata_task(
|
||||
timeline.clone(),
|
||||
index_part,
|
||||
timeline_create_guard,
|
||||
import_task_guard,
|
||||
timeline_ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn),
|
||||
));
|
||||
|
||||
let prev = self.timelines_importing.lock().unwrap().insert(
|
||||
timeline.timeline_id,
|
||||
Arc::new(ImportingTimeline {
|
||||
ImportingTimeline {
|
||||
timeline: timeline.clone(),
|
||||
import_task_handle,
|
||||
import_task_gate,
|
||||
delete_progress: TimelineDeleteProgress::default(),
|
||||
}),
|
||||
},
|
||||
);
|
||||
|
||||
// Idempotency is enforced higher up the stack
|
||||
@@ -2963,7 +2924,6 @@ impl TenantShard {
|
||||
timeline: Arc<Timeline>,
|
||||
index_part: import_pgdata::index_part_format::Root,
|
||||
timeline_create_guard: TimelineCreateGuard,
|
||||
_import_task_guard: GateGuard,
|
||||
ctx: RequestContext,
|
||||
) {
|
||||
debug_assert_current_span_has_tenant_and_timeline_id();
|
||||
@@ -3875,9 +3835,6 @@ impl TenantShard {
|
||||
.build_timeline_client(offloaded.timeline_id, self.remote_storage.clone());
|
||||
Arc::new(remote_client)
|
||||
}
|
||||
TimelineOrOffloadedArcRef::Importing(_) => {
|
||||
unreachable!("Importing timelines are not included in the iterator")
|
||||
}
|
||||
};
|
||||
|
||||
// Shut down the timeline's remote client: this means that the indices we write
|
||||
@@ -5087,14 +5044,6 @@ impl TenantShard {
|
||||
info!("timeline already exists but is offloaded");
|
||||
Err(CreateTimelineError::Conflict)
|
||||
}
|
||||
Err(TimelineExclusionError::AlreadyExists {
|
||||
existing: TimelineOrOffloaded::Importing(_existing),
|
||||
..
|
||||
}) => {
|
||||
// If there's a timeline already importing, then we would hit
|
||||
// the [`TimelineExclusionError::AlreadyCreating`] branch above.
|
||||
unreachable!("Importing timelines hold the creation guard")
|
||||
}
|
||||
Err(TimelineExclusionError::AlreadyExists {
|
||||
existing: TimelineOrOffloaded::Timeline(existing),
|
||||
arg,
|
||||
|
||||
@@ -1348,21 +1348,6 @@ impl RemoteTimelineClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn schedule_unlinking_of_layers_from_index_part<I>(
|
||||
self: &Arc<Self>,
|
||||
names: I,
|
||||
) -> Result<(), NotInitialized>
|
||||
where
|
||||
I: IntoIterator<Item = LayerName>,
|
||||
{
|
||||
let mut guard = self.upload_queue.lock().unwrap();
|
||||
let upload_queue = guard.initialized_mut()?;
|
||||
|
||||
self.schedule_unlinking_of_layers_from_index_part0(upload_queue, names);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the remote index file, removing the to-be-deleted files from the index,
|
||||
/// allowing scheduling of actual deletions later.
|
||||
fn schedule_unlinking_of_layers_from_index_part0<I>(
|
||||
|
||||
@@ -206,8 +206,8 @@ pub struct GcCompactionQueue {
|
||||
}
|
||||
|
||||
static CONCURRENT_GC_COMPACTION_TASKS: Lazy<Arc<Semaphore>> = Lazy::new(|| {
|
||||
// Only allow one timeline on one pageserver to run gc compaction at a time.
|
||||
Arc::new(Semaphore::new(1))
|
||||
// Only allow two timelines on one pageserver to run gc compaction at a time.
|
||||
Arc::new(Semaphore::new(2))
|
||||
});
|
||||
|
||||
impl GcCompactionQueue {
|
||||
|
||||
@@ -121,7 +121,6 @@ async fn remove_maybe_offloaded_timeline_from_tenant(
|
||||
// This observes the locking order between timelines and timelines_offloaded
|
||||
let mut timelines = tenant.timelines.lock().unwrap();
|
||||
let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
|
||||
let mut timelines_importing = tenant.timelines_importing.lock().unwrap();
|
||||
let offloaded_children_exist = timelines_offloaded
|
||||
.iter()
|
||||
.any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id()));
|
||||
@@ -151,12 +150,8 @@ async fn remove_maybe_offloaded_timeline_from_tenant(
|
||||
.expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map");
|
||||
offloaded_timeline.delete_from_ancestor_with_timelines(&timelines);
|
||||
}
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
timelines_importing.remove(&importing.timeline.timeline_id);
|
||||
}
|
||||
}
|
||||
|
||||
drop(timelines_importing);
|
||||
drop(timelines_offloaded);
|
||||
drop(timelines);
|
||||
|
||||
@@ -208,17 +203,8 @@ impl DeleteTimelineFlow {
|
||||
guard.mark_in_progress()?;
|
||||
|
||||
// Now that the Timeline is in Stopping state, request all the related tasks to shut down.
|
||||
// TODO(vlad): shut down imported timeline here
|
||||
match &timeline {
|
||||
TimelineOrOffloaded::Timeline(timeline) => {
|
||||
timeline.shutdown(super::ShutdownMode::Hard).await;
|
||||
}
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
importing.shutdown().await;
|
||||
}
|
||||
TimelineOrOffloaded::Offloaded(_offloaded) => {
|
||||
// Nothing to shut down in this case
|
||||
}
|
||||
if let TimelineOrOffloaded::Timeline(timeline) = &timeline {
|
||||
timeline.shutdown(super::ShutdownMode::Hard).await;
|
||||
}
|
||||
|
||||
tenant.gc_block.before_delete(&timeline.timeline_id());
|
||||
@@ -403,18 +389,10 @@ impl DeleteTimelineFlow {
|
||||
Err(anyhow::anyhow!("failpoint: timeline-delete-before-rm"))?
|
||||
});
|
||||
|
||||
match timeline {
|
||||
TimelineOrOffloaded::Timeline(timeline) => {
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
|
||||
}
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, &importing.timeline)
|
||||
.await;
|
||||
}
|
||||
TimelineOrOffloaded::Offloaded(_offloaded) => {
|
||||
// Offloaded timelines have no local state
|
||||
// TODO: once we persist offloaded information, delete the timeline from there, too
|
||||
}
|
||||
// Offloaded timelines have no local state
|
||||
// TODO: once we persist offloaded information, delete the timeline from there, too
|
||||
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
|
||||
}
|
||||
|
||||
fail::fail_point!("timeline-delete-after-rm", |_| {
|
||||
@@ -473,16 +451,12 @@ pub(super) fn make_timeline_delete_guard(
|
||||
// For more context see this discussion: `https://github.com/neondatabase/neon/pull/4552#discussion_r1253437346`
|
||||
let timelines = tenant.timelines.lock().unwrap();
|
||||
let timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
|
||||
let timelines_importing = tenant.timelines_importing.lock().unwrap();
|
||||
|
||||
let timeline = match timelines.get(&timeline_id) {
|
||||
Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)),
|
||||
None => match timelines_offloaded.get(&timeline_id) {
|
||||
Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)),
|
||||
None => match timelines_importing.get(&timeline_id) {
|
||||
Some(t) => TimelineOrOffloaded::Importing(Arc::clone(t)),
|
||||
None => return Err(DeleteTimelineError::NotFound),
|
||||
},
|
||||
None => return Err(DeleteTimelineError::NotFound),
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -8,10 +8,8 @@ use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
use utils::lsn::Lsn;
|
||||
use utils::pausable_failpoint;
|
||||
use utils::sync::gate::Gate;
|
||||
|
||||
use super::{Timeline, TimelineDeleteProgress};
|
||||
use super::Timeline;
|
||||
use crate::context::RequestContext;
|
||||
use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient};
|
||||
use crate::tenant::metadata::TimelineMetadata;
|
||||
@@ -21,23 +19,15 @@ mod importbucket_client;
|
||||
mod importbucket_format;
|
||||
pub(crate) mod index_part_format;
|
||||
|
||||
pub struct ImportingTimeline {
|
||||
pub(crate) struct ImportingTimeline {
|
||||
pub import_task_handle: JoinHandle<()>,
|
||||
pub import_task_gate: Gate,
|
||||
pub timeline: Arc<Timeline>,
|
||||
pub delete_progress: TimelineDeleteProgress,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ImportingTimeline {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "ImportingTimeline<{}>", self.timeline.timeline_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl ImportingTimeline {
|
||||
pub async fn shutdown(&self) {
|
||||
pub(crate) async fn shutdown(self) {
|
||||
self.import_task_handle.abort();
|
||||
self.import_task_gate.close().await;
|
||||
let _ = self.import_task_handle.await;
|
||||
|
||||
self.timeline.remote_client.shutdown().await;
|
||||
}
|
||||
@@ -111,8 +101,6 @@ pub async fn doit(
|
||||
.schedule_index_upload_for_file_changes()?;
|
||||
timeline.remote_client.wait_completion().await?;
|
||||
|
||||
pausable_failpoint!("import-timeline-pre-success-notify-pausable");
|
||||
|
||||
// Communicate that shard is done.
|
||||
// Ensure at-least-once delivery of the upcall to storage controller
|
||||
// before we mark the task as done and never come here again.
|
||||
|
||||
@@ -30,7 +30,6 @@
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::num::NonZeroUsize;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -101,24 +100,8 @@ async fn run_v1(
|
||||
tasks: Vec::default(),
|
||||
};
|
||||
|
||||
// Use the job size limit encoded in the progress if we are resuming an import.
|
||||
// This ensures that imports have stable plans even if the pageserver config changes.
|
||||
let import_config = {
|
||||
match &import_progress {
|
||||
Some(progress) => {
|
||||
let base = &timeline.conf.timeline_import_config;
|
||||
TimelineImportConfig {
|
||||
import_job_soft_size_limit: NonZeroUsize::new(progress.job_soft_size_limit)
|
||||
.unwrap(),
|
||||
import_job_concurrency: base.import_job_concurrency,
|
||||
import_job_checkpoint_threshold: base.import_job_checkpoint_threshold,
|
||||
}
|
||||
}
|
||||
None => timeline.conf.timeline_import_config.clone(),
|
||||
}
|
||||
};
|
||||
|
||||
let plan = planner.plan(&import_config).await?;
|
||||
let import_config = &timeline.conf.timeline_import_config;
|
||||
let plan = planner.plan(import_config).await?;
|
||||
|
||||
// Hash the plan and compare with the hash of the plan we got back from the storage controller.
|
||||
// If the two match, it means that the planning stage had the same output.
|
||||
@@ -143,7 +126,7 @@ async fn run_v1(
|
||||
pausable_failpoint!("import-timeline-pre-execute-pausable");
|
||||
|
||||
let start_from_job_idx = import_progress.map(|progress| progress.completed);
|
||||
plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx)
|
||||
plan.execute(timeline, start_from_job_idx, plan_hash, import_config, ctx)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -470,7 +453,6 @@ impl Plan {
|
||||
jobs: jobs_in_plan,
|
||||
completed: last_completed_job_idx,
|
||||
import_plan_hash,
|
||||
job_soft_size_limit: import_config.import_job_soft_size_limit.into(),
|
||||
};
|
||||
|
||||
timeline.remote_client.schedule_index_upload_for_file_changes()?;
|
||||
@@ -982,15 +964,6 @@ impl ChunkProcessingJob {
|
||||
.cloned();
|
||||
match existing_layer {
|
||||
Some(existing) => {
|
||||
// Unlink the remote layer from the index without scheduling its deletion.
|
||||
// When `existing_layer` drops [`LayerInner::drop`] will schedule its deletion from
|
||||
// remote storage, but that assumes that the layer was unlinked from the index first.
|
||||
timeline
|
||||
.remote_client
|
||||
.schedule_unlinking_of_layers_from_index_part(std::iter::once(
|
||||
existing.layer_desc().layer_name(),
|
||||
))?;
|
||||
|
||||
guard.open_mut()?.rewrite_layers(
|
||||
&[(existing.clone(), resident_layer.clone())],
|
||||
&[],
|
||||
|
||||
@@ -17,7 +17,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
|
||||
use crate::auth::{
|
||||
self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
|
||||
};
|
||||
use crate::cache::Cached;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
@@ -135,6 +137,16 @@ impl<'a, T> Backend<'a, T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a, T, E> Backend<'a, Result<T, E>> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
|
||||
match self {
|
||||
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
|
||||
Self::Local(l) => Ok(Backend::Local(l)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ComputeCredentials {
|
||||
pub(crate) info: ComputeUserInfo,
|
||||
@@ -272,7 +284,7 @@ async fn auth_quirks(
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
@@ -289,12 +301,15 @@ async fn auth_quirks(
|
||||
debug!("fetching authentication info and allowlists");
|
||||
|
||||
// check allowed list
|
||||
if config.ip_allowlist_check_enabled {
|
||||
let allowed_ips = if config.ip_allowlist_check_enabled {
|
||||
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
}
|
||||
allowed_ips
|
||||
} else {
|
||||
Cached::new_uncached(Arc::new(vec![]))
|
||||
};
|
||||
|
||||
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
|
||||
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
|
||||
@@ -353,7 +368,7 @@ async fn auth_quirks(
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(keys) => Ok(keys),
|
||||
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
|
||||
Err(e) => {
|
||||
if e.is_password_failed() {
|
||||
// The password could have been changed, so we invalidate the cache.
|
||||
@@ -405,39 +420,53 @@ async fn authenticate_with_secret(
|
||||
classic::authenticate(ctx, info, client, config, secret).await
|
||||
}
|
||||
|
||||
impl ControlPlaneClient {
|
||||
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
/// Get username from the credentials.
|
||||
pub(crate) fn get_user(&self) -> &str {
|
||||
match self {
|
||||
Self::ControlPlane(_, user_info) => &user_info.user,
|
||||
Self::Local(_) => "local",
|
||||
}
|
||||
}
|
||||
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
|
||||
pub(crate) async fn authenticate(
|
||||
&self,
|
||||
self,
|
||||
ctx: &RequestContext,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
debug!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
|
||||
let res = match self {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
debug!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let credentials = auth_quirks(
|
||||
ctx,
|
||||
self,
|
||||
user_info,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
let (credentials, ip_allowlist) = auth_quirks(
|
||||
ctx,
|
||||
&*api,
|
||||
user_info,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
|
||||
}
|
||||
Self::Local(_) => {
|
||||
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: replace with some metric
|
||||
info!("user successfully authenticated");
|
||||
|
||||
Ok(credentials)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -507,25 +536,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ControlPlaneWakeCompute<'a> {
|
||||
pub cplane: &'a ControlPlaneClient,
|
||||
pub creds: ComputeCredentials,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
self.cplane.wake_compute(ctx, &self.creds.info).await
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
&self.creds.keys
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unimplemented, clippy::unwrap_used)]
|
||||
@@ -542,7 +552,6 @@ mod tests {
|
||||
use postgres_protocol::message::backend::Message as PgMessage;
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use super::jwt::JwkCache;
|
||||
use super::{AuthRateLimiter, auth_quirks};
|
||||
@@ -693,7 +702,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -775,7 +784,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -829,7 +838,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -878,7 +887,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(creds.info.endpoint, "my-endpoint");
|
||||
assert_eq!(creds.0.info.endpoint, "my-endpoint");
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::{Backend, ControlPlaneWakeCompute};
|
||||
pub use backend::Backend;
|
||||
|
||||
mod credentials;
|
||||
pub(crate) use credentials::{
|
||||
|
||||
@@ -18,7 +18,6 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, error, info};
|
||||
use utils::project_git_version;
|
||||
use utils::sentry_init::init_sentry;
|
||||
@@ -227,8 +226,7 @@ pub(super) async fn task_main(
|
||||
let dest_suffix = Arc::clone(&dest_suffix);
|
||||
let compute_tls_config = compute_tls_config.clone();
|
||||
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(
|
||||
connections.spawn(
|
||||
async move {
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
@@ -251,7 +249,6 @@ pub(super) async fn task_main(
|
||||
compute_tls_config,
|
||||
tls_server_end_point,
|
||||
socket,
|
||||
tracker,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -277,11 +274,10 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
raw_stream: S,
|
||||
tracker: TaskTrackerToken,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use pq_proto::FeStartupPacket::SslRequest;
|
||||
@@ -295,7 +291,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf, tracker) = stream.into_inner();
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empty.
|
||||
// However, you could imagine pipelining of postgres
|
||||
@@ -306,16 +302,13 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
|
||||
Ok((
|
||||
Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
},
|
||||
tracker,
|
||||
))
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
})
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
@@ -336,10 +329,8 @@ async fn handle_client(
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let (mut tls_stream, _tracker) =
|
||||
ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
|
||||
@@ -323,7 +323,7 @@ impl CancellationHandler {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: Arc<Self>) -> Session {
|
||||
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
|
||||
// we intentionally generate a random "backend pid" and "secret key" here.
|
||||
// we use the corresponding u64 as an identifier for the
|
||||
// actual endpoint+pid+secret for postgres/pgbouncer.
|
||||
@@ -340,7 +340,7 @@ impl CancellationHandler {
|
||||
Session {
|
||||
key,
|
||||
redis_key,
|
||||
cancellation_handler: self,
|
||||
cancellation_handler: Arc::clone(self),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, debug, error, info};
|
||||
|
||||
use crate::auth::backend::ConsoleRedirectBackend;
|
||||
@@ -15,8 +14,10 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::proxy::passthrough::passthrough;
|
||||
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
|
||||
use crate::proxy::passthrough::ProxyPassthrough;
|
||||
use crate::proxy::{
|
||||
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
|
||||
};
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
@@ -34,6 +35,7 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -47,11 +49,11 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(async move {
|
||||
connections.spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
@@ -101,80 +103,99 @@ pub async fn task_main(
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let span = ctx.span();
|
||||
let mut slot = Some(ctx);
|
||||
let res = handle_client(
|
||||
config,
|
||||
backend,
|
||||
&mut slot,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(span)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
match (slot, res) {
|
||||
(None, _) => {}
|
||||
(Some(ctx), Ok(())) => {
|
||||
ctx.success();
|
||||
}
|
||||
(Some(ctx), Err(e)) => {
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
backend: &'static ConsoleRedirectBackend,
|
||||
ctx_slot: &mut Option<RequestContext>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> Result<(), ClientRequestError> {
|
||||
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
|
||||
debug!(%protocol, "handling interactive connection from client");
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let request_gauge = metrics.connection_requests.guard(protocol);
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let data = {
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
|
||||
tokio::time::timeout(config.handshake_timeout, do_handshake).await??
|
||||
};
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match data {
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data, tracker) => {
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
tokio::spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
let _tracker = tracker;
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
@@ -184,17 +205,15 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
|
||||
.ok();
|
||||
}
|
||||
.instrument(cancel_span)
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(());
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, user_info, _ip_allowlist) = match backend
|
||||
@@ -209,13 +228,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
TcpMechanism {
|
||||
&TcpMechanism {
|
||||
user_info,
|
||||
params_compat: true,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
node_info,
|
||||
&node_info,
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
@@ -233,22 +252,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf, tracker) = stream.into_inner();
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
ctx.set_success();
|
||||
|
||||
tokio::spawn(passthrough(
|
||||
ctx,
|
||||
&config.connect_to_compute,
|
||||
stream,
|
||||
node,
|
||||
session,
|
||||
request_gauge,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
));
|
||||
|
||||
Ok(())
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
private_link_id: None,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ pub struct RequestContext(
|
||||
/// I would typically use a RefCell but that would break the `Send` requirements
|
||||
/// so we need something with thread-safety. `TryLock` is a cheap alternative
|
||||
/// that offers similar semantics to a `RefCell` but with synchronisation.
|
||||
TryLock<Box<RequestContextInner>>,
|
||||
TryLock<RequestContextInner>,
|
||||
);
|
||||
|
||||
struct RequestContextInner {
|
||||
@@ -89,7 +89,7 @@ pub(crate) enum AuthMethod {
|
||||
impl Clone for RequestContext {
|
||||
fn clone(&self) -> Self {
|
||||
let inner = self.0.try_lock().expect("should not deadlock");
|
||||
let new = Box::new(RequestContextInner {
|
||||
let new = RequestContextInner {
|
||||
conn_info: inner.conn_info.clone(),
|
||||
session_id: inner.session_id,
|
||||
protocol: inner.protocol,
|
||||
@@ -117,7 +117,7 @@ impl Clone for RequestContext {
|
||||
disconnect_sender: None,
|
||||
latency_timer: LatencyTimer::noop(inner.protocol),
|
||||
disconnect_timestamp: inner.disconnect_timestamp,
|
||||
});
|
||||
};
|
||||
|
||||
Self(TryLock::new(new))
|
||||
}
|
||||
@@ -140,7 +140,7 @@ impl RequestContext {
|
||||
role = tracing::field::Empty,
|
||||
);
|
||||
|
||||
let inner = Box::new(RequestContextInner {
|
||||
let inner = RequestContextInner {
|
||||
conn_info,
|
||||
session_id,
|
||||
protocol,
|
||||
@@ -168,7 +168,7 @@ impl RequestContext {
|
||||
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
|
||||
latency_timer: LatencyTimer::new(protocol),
|
||||
disconnect_timestamp: None,
|
||||
});
|
||||
};
|
||||
|
||||
Self(TryLock::new(inner))
|
||||
}
|
||||
@@ -522,7 +522,7 @@ impl Drop for RequestContextInner {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DisconnectLogger(Box<RequestContextInner>);
|
||||
pub struct DisconnectLogger(RequestContextInner);
|
||||
|
||||
impl Drop for DisconnectLogger {
|
||||
fn drop(&mut self) {
|
||||
|
||||
@@ -53,25 +53,6 @@ pub(crate) trait ConnectMechanism {
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: ConnectMechanism + Sync> ConnectMechanism for &T {
|
||||
type Connection = T::Connection;
|
||||
type ConnectError = T::ConnectError;
|
||||
type Error = T::Error;
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
T::connect_once(self, ctx, node_info, config).await
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg) {
|
||||
T::update_connect_config(self, conf);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ComputeConnectBackend {
|
||||
async fn wake_compute(
|
||||
@@ -124,8 +105,8 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
ctx: &RequestContext,
|
||||
mechanism: M,
|
||||
backend: B,
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
wake_compute_retry_config: RetryConfig,
|
||||
compute: &ComputeConfig,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
@@ -135,9 +116,9 @@ where
|
||||
{
|
||||
let mut num_retries = 0;
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
|
||||
node_info.set_keys(backend.get_keys());
|
||||
node_info.set_keys(user_info.get_keys());
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
@@ -178,7 +159,7 @@ where
|
||||
let old_node_info = invalidate_cache(node_info);
|
||||
// TODO: increment num_retries?
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
node_info.reuse_settings(old_node_info);
|
||||
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
@@ -67,6 +67,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn copy_bidirectional_client_compute<Client, Compute>(
|
||||
client: &mut Client,
|
||||
compute: &mut Compute,
|
||||
|
||||
@@ -5,7 +5,6 @@ use pq_proto::{
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::endpoint_sni;
|
||||
@@ -52,7 +51,7 @@ impl ReportableError for HandshakeError {
|
||||
|
||||
pub(crate) enum HandshakeData<S> {
|
||||
Startup(PqStream<Stream<S>>, StartupMessageParams),
|
||||
Cancel(CancelKeyData, TaskTrackerToken),
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
/// Establish a (most probably, secure) connection with the client.
|
||||
@@ -63,7 +62,6 @@ pub(crate) enum HandshakeData<S> {
|
||||
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
stream: S,
|
||||
tracker: TaskTrackerToken,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
record_handshake_error: bool,
|
||||
) -> Result<HandshakeData<S>, HandshakeError> {
|
||||
@@ -73,7 +71,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream), tracker);
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
match msg {
|
||||
@@ -159,13 +157,15 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let (_, tls_server_end_point) =
|
||||
tls.cert_resolver.resolve(conn_info.server_name());
|
||||
|
||||
stream.framed = Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
stream = PqStream {
|
||||
framed: Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -248,7 +248,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
FeStartupPacket::CancelRequest(cancel_key_data) => {
|
||||
info!(session_type = "cancellation", "successful handshake");
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker));
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,27 +10,26 @@ pub(crate) mod wake_compute;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use futures::TryFutureExt;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use passthrough::passthrough;
|
||||
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, format_smolstr};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use self::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use self::passthrough::ProxyPassthrough;
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
@@ -71,6 +70,7 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -84,12 +84,12 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(async move {
|
||||
connections.spawn(async move {
|
||||
let (socket, conn_info) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
@@ -138,41 +138,60 @@ pub async fn task_main(
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
let span = ctx.span();
|
||||
let mut ctx = Some(ctx);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
auth_backend,
|
||||
&mut ctx,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(span)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
match (ctx, res) {
|
||||
(None, _) => {}
|
||||
(Some(ctx), Ok(())) => {
|
||||
ctx.success();
|
||||
}
|
||||
(Some(ctx), Err(e)) => {
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -239,79 +258,46 @@ impl ReportableError for ClientRequestError {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx_slot: &mut Option<RequestContext>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> Result<(), ClientRequestError> {
|
||||
let cplane = match auth_backend {
|
||||
auth::Backend::ControlPlane(cplane, ()) => &**cplane,
|
||||
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
|
||||
};
|
||||
|
||||
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
|
||||
debug!(%protocol, "handling interactive connection from client");
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let request_gauge = metrics.connection_requests.guard(protocol);
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let handshake_result: Result<_, ClientRequestError> = async {
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let data = {
|
||||
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
tokio::time::timeout(
|
||||
config.handshake_timeout,
|
||||
handshake(
|
||||
ctx,
|
||||
stream,
|
||||
tracker,
|
||||
mode.handshake_tls(tls),
|
||||
record_handshake_error,
|
||||
),
|
||||
)
|
||||
.await??
|
||||
};
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
|
||||
match data {
|
||||
HandshakeData::Startup(mut stream, params) => {
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let host = mode.hostname(stream.get_ref());
|
||||
let cn = tls.map(|tls| &tls.common_names);
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, host, cn);
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e, Some(ctx)).await?,
|
||||
};
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
Ok(Some((stream, params, session, user_info)))
|
||||
}
|
||||
HandshakeData::Cancel(cancel_key_data, tracker) => {
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
ctx.set_success();
|
||||
|
||||
let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
tokio::spawn(async move {
|
||||
// ensure the proxy doesn't shutdown until we complete this task.
|
||||
let _tracker = tracker;
|
||||
|
||||
cancellation_handler
|
||||
async move {
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
@@ -319,108 +305,111 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
config.authentication_config.is_vpc_acccess_proxy,
|
||||
auth_backend.get_api(),
|
||||
)
|
||||
.instrument(cancel_span)
|
||||
.await
|
||||
.unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
|
||||
});
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
});
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
.await;
|
||||
|
||||
let Some((mut stream, params, session, user_info)) = handshake_result? else {
|
||||
return Ok(());
|
||||
};
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
drop(pause);
|
||||
|
||||
let auth_result: Result<_, ClientRequestError> = async {
|
||||
let user = user_info.user.clone();
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
match cplane
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
user_info,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => Ok(auth_result),
|
||||
Err(e) => {
|
||||
let db = params.get("database");
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
stream
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await?
|
||||
}
|
||||
}
|
||||
}
|
||||
.await;
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
|
||||
let compute_creds = auth_result?;
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
let connect_result: Result<_, ClientRequestError> = async {
|
||||
let compute_user_info = compute_creds.info.clone();
|
||||
let params_compat = compute_user_info
|
||||
.options
|
||||
.get(NeonOptions::PARAMS_COMPAT)
|
||||
.is_some();
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let result = auth_backend
|
||||
.as_ref()
|
||||
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
|
||||
.transpose();
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e, Some(ctx)).await?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let (user_info, _ip_allowlist) = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
TcpMechanism {
|
||||
user_info: compute_user_info,
|
||||
params_compat,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
auth::ControlPlaneWakeCompute {
|
||||
cplane,
|
||||
creds: compute_creds,
|
||||
},
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
&mut stream,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e, Some(ctx)))
|
||||
.await?;
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
let db = params.get("database");
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
return stream
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await?;
|
||||
}
|
||||
};
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf, tracker) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
let compute_user_info = match &user_info {
|
||||
auth::Backend::ControlPlane(_, info) => &info.info,
|
||||
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
|
||||
};
|
||||
let params_compat = compute_user_info
|
||||
.options
|
||||
.get(NeonOptions::PARAMS_COMPAT)
|
||||
.is_some();
|
||||
|
||||
Ok((node, stream, tracker))
|
||||
}
|
||||
.await;
|
||||
|
||||
let (node, stream, tracker) = connect_result?;
|
||||
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
ctx.set_success();
|
||||
|
||||
tokio::spawn(passthrough(
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: compute_user_info.clone(),
|
||||
params_compat,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&user_info,
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
stream,
|
||||
node,
|
||||
session,
|
||||
request_gauge,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
));
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e, Some(ctx)))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use smol_str::{SmolStr, ToSmolStr};
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
@@ -8,14 +7,13 @@ use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
@@ -63,53 +61,41 @@ pub(crate) async fn proxy_pass(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
ctx: RequestContext,
|
||||
compute_config: &'static ComputeConfig,
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: PostgresConnection,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
pub(crate) cancel: cancellation::Session,
|
||||
|
||||
client: Stream<S>,
|
||||
compute: PostgresConnection,
|
||||
cancel: cancellation::Session,
|
||||
|
||||
_req: NumConnectionRequestsGuard<'static>,
|
||||
_conn: NumClientConnectionsGuard<'static>,
|
||||
_tracker: TaskTrackerToken,
|
||||
) {
|
||||
let session_id = ctx.session_id();
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let _disconnect = ctx.log_connect();
|
||||
let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
|
||||
|
||||
match res {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
tracing::warn!(
|
||||
session_id = ?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
tracing::error!(
|
||||
session_id = ?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(err) = compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
// we don't need a result. If the queue is full, we just log the error
|
||||
drop(cancel.remove_cancel_key());
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(
|
||||
self.client,
|
||||
self.compute.stream,
|
||||
self.aux,
|
||||
self.private_link_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = self
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +38,6 @@ async fn proxy_mitm(
|
||||
let (end_client, startup) = match handshake(
|
||||
&RequestContext::test(),
|
||||
client1,
|
||||
TaskTracker::new().token(),
|
||||
Some(&server_config1),
|
||||
false,
|
||||
)
|
||||
@@ -46,7 +45,7 @@ async fn proxy_mitm(
|
||||
.unwrap()
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
|
||||
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
|
||||
};
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
|
||||
@@ -15,7 +15,6 @@ use rstest::rstest;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio_util::task::TaskTracker;
|
||||
use tracing_test::traced_test;
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
@@ -179,12 +178,10 @@ async fn dummy_proxy(
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let (client, _) = read_proxy_protocol(client).await?;
|
||||
let t = TaskTracker::new().token();
|
||||
let mut stream =
|
||||
match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
|
||||
};
|
||||
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
};
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
@@ -625,7 +622,7 @@ async fn connect_to_compute_success() {
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -639,7 +636,7 @@ async fn connect_to_compute_retry() {
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -654,7 +651,7 @@ async fn connect_to_compute_non_retry_1() {
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -669,7 +666,7 @@ async fn connect_to_compute_non_retry_2() {
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -694,7 +691,7 @@ async fn connect_to_compute_non_retry_3() {
|
||||
connect_to_compute(
|
||||
&ctx,
|
||||
&mechanism,
|
||||
user_info,
|
||||
&user_info,
|
||||
wake_compute_retry_config,
|
||||
&config,
|
||||
)
|
||||
@@ -712,7 +709,7 @@ async fn wake_retry() {
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -727,7 +724,7 @@ async fn wake_non_retry() {
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -746,7 +743,7 @@ async fn fail_but_wake_invalidates_cache() {
|
||||
let user = helper_create_connect_info(&mech);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
|
||||
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -767,7 +764,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
|
||||
let user = helper_create_connect_info(&mech);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
|
||||
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -788,7 +785,7 @@ async fn retry_but_wake_invalidates_cache() {
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -811,7 +808,7 @@ async fn retry_no_wake_skips_invalidation() {
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
@@ -224,13 +224,13 @@ impl PoolingBackend {
|
||||
let backend = self.auth_backend.as_ref().map(|()| keys);
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
TokioMechanism {
|
||||
&TokioMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
backend,
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
&self.config.connect_to_compute,
|
||||
)
|
||||
@@ -268,13 +268,13 @@ impl PoolingBackend {
|
||||
});
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
HyperMechanism {
|
||||
&HyperMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.http_conn_pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
backend,
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
&self.config.connect_to_compute,
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tokio_util::task::TaskTracker;
|
||||
use tracing::{Instrument, info, warn};
|
||||
|
||||
use crate::cancellation::CancellationHandler;
|
||||
@@ -124,6 +124,7 @@ pub async fn task_main(
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
if let Err(e) = conn.set_nodelay(true) {
|
||||
@@ -149,11 +150,11 @@ pub async fn task_main(
|
||||
let conn_token = cancellation_token.child_token();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let backend = backend.clone();
|
||||
let connections2 = connections.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(
|
||||
let cancellations = cancellations.clone();
|
||||
connections.spawn(
|
||||
async move {
|
||||
let conn_token2 = conn_token.clone();
|
||||
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
|
||||
@@ -180,7 +181,8 @@ pub async fn task_main(
|
||||
Box::pin(connection_handler(
|
||||
config,
|
||||
backend,
|
||||
tracker,
|
||||
connections2,
|
||||
cancellations,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
@@ -303,7 +305,8 @@ async fn connection_startup(
|
||||
async fn connection_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
tracker: TaskTrackerToken,
|
||||
connections: TaskTracker,
|
||||
cancellations: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -344,17 +347,19 @@ async fn connection_handler(
|
||||
|
||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||
let handler = tokio::spawn(
|
||||
let cancellations = cancellations.clone();
|
||||
let handler = connections.spawn(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
backend.clone(),
|
||||
tracker.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
session_id,
|
||||
conn_info2.clone(),
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellations,
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
@@ -395,13 +400,14 @@ async fn request_handler(
|
||||
mut request: hyper::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
tracker: TaskTrackerToken,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
session_id: uuid::Uuid,
|
||||
conn_info: ConnectionInfo,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellations: TaskTracker,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -435,17 +441,10 @@ async fn request_handler(
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
tokio::spawn(
|
||||
let cancellations = cancellations.clone();
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
let websocket = match websocket.await {
|
||||
Err(e) => {
|
||||
warn!("could not upgrade websocket connection: {e:#}");
|
||||
return;
|
||||
}
|
||||
Ok(websocket) => websocket,
|
||||
};
|
||||
|
||||
websocket::serve_websocket(
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
backend.auth_backend,
|
||||
ctx,
|
||||
@@ -453,9 +452,12 @@ async fn request_handler(
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
host,
|
||||
tracker,
|
||||
cancellations,
|
||||
)
|
||||
.await;
|
||||
.await
|
||||
{
|
||||
warn!("error in websocket connection: {e:#}");
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
|
||||
@@ -2,14 +2,14 @@ use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll, ready};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use framed_websockets::{Frame, OpCode, WebSocketServer};
|
||||
use futures::{Sink, Stream};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::upgrade::OnUpgrade;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::cancellation::CancellationHandler;
|
||||
@@ -17,7 +17,7 @@ use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::{ClientMode, handle_client};
|
||||
use crate::proxy::{ClientMode, ErrorSource, handle_client};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
|
||||
pin_project! {
|
||||
@@ -128,12 +128,13 @@ pub(crate) async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
ctx: RequestContext,
|
||||
websocket: Upgraded,
|
||||
websocket: OnUpgrade,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
tracker: TaskTrackerToken,
|
||||
) {
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
@@ -141,28 +142,36 @@ pub(crate) async fn serve_websocket(
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Ws);
|
||||
|
||||
let mut ctx_slot = Some(ctx);
|
||||
let res = handle_client(
|
||||
let res = Box::pin(handle_client(
|
||||
config,
|
||||
auth_backend,
|
||||
&mut ctx_slot,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
)
|
||||
cancellations,
|
||||
))
|
||||
.await;
|
||||
|
||||
match (ctx_slot, res) {
|
||||
(None, _) => {}
|
||||
(Some(ctx), Err(e)) => {
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
Err(e.into())
|
||||
}
|
||||
(Some(ctx), Ok(())) => {
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
Ok(())
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::control_plane::messages::ColdStartInfo;
|
||||
@@ -25,22 +24,19 @@ use crate::tls::TlsServerEndPoint;
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
pub(crate) framed: Framed<S>,
|
||||
pub(crate) tracker: TaskTrackerToken,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper.
|
||||
pub fn new(stream: S, tracker: TaskTrackerToken) -> Self {
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
framed: Framed::new(stream),
|
||||
tracker,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the underlying stream and read buffer.
|
||||
pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) {
|
||||
let (stream, read) = self.framed.into_inner();
|
||||
(stream, read, self.tracker)
|
||||
pub fn into_inner(self) -> (S, BytesMut) {
|
||||
self.framed.into_inner()
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
|
||||
@@ -44,7 +44,6 @@ struct GlobalTimelinesState {
|
||||
// on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as
|
||||
// this map is dropped on restart.
|
||||
tombstones: HashMap<TenantTimelineId, Instant>,
|
||||
tenant_tombstones: HashMap<TenantId, Instant>,
|
||||
|
||||
conf: Arc<SafeKeeperConf>,
|
||||
broker_active_set: Arc<TimelinesSet>,
|
||||
@@ -82,25 +81,10 @@ impl GlobalTimelinesState {
|
||||
}
|
||||
}
|
||||
|
||||
fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool {
|
||||
self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id)
|
||||
}
|
||||
|
||||
/// Removes all blocking tombstones for the given timeline ID.
|
||||
/// Returns `true` if there have been actual changes.
|
||||
fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool {
|
||||
self.tombstones.remove(ttid).is_some()
|
||||
|| self.tenant_tombstones.remove(&ttid.tenant_id).is_some()
|
||||
}
|
||||
|
||||
fn delete(&mut self, ttid: TenantTimelineId) {
|
||||
self.timelines.remove(&ttid);
|
||||
self.tombstones.insert(ttid, Instant::now());
|
||||
}
|
||||
|
||||
fn add_tenant_tombstone(&mut self, tenant_id: TenantId) {
|
||||
self.tenant_tombstones.insert(tenant_id, Instant::now());
|
||||
}
|
||||
}
|
||||
|
||||
/// A struct used to manage access to the global timelines map.
|
||||
@@ -115,7 +99,6 @@ impl GlobalTimelines {
|
||||
state: Mutex::new(GlobalTimelinesState {
|
||||
timelines: HashMap::new(),
|
||||
tombstones: HashMap::new(),
|
||||
tenant_tombstones: HashMap::new(),
|
||||
conf,
|
||||
broker_active_set: Arc::new(TimelinesSet::default()),
|
||||
global_rate_limiter: RateLimiter::new(1, 1),
|
||||
@@ -262,7 +245,7 @@ impl GlobalTimelines {
|
||||
return Ok(timeline);
|
||||
}
|
||||
|
||||
if state.has_tombstone(&ttid) {
|
||||
if state.tombstones.contains_key(&ttid) {
|
||||
anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate");
|
||||
}
|
||||
|
||||
@@ -312,14 +295,13 @@ impl GlobalTimelines {
|
||||
_ => {}
|
||||
}
|
||||
if check_tombstone {
|
||||
if state.has_tombstone(&ttid) {
|
||||
if state.tombstones.contains_key(&ttid) {
|
||||
anyhow::bail!("timeline {ttid} is deleted, refusing to recreate");
|
||||
}
|
||||
} else {
|
||||
// We may be have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust
|
||||
// that the human doing this manual intervention knows what they are doing, and remove its tombstone.
|
||||
// It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed.
|
||||
if state.remove_tombstone(&ttid) {
|
||||
if state.tombstones.remove(&ttid).is_some() {
|
||||
warn!("un-deleted timeline {ttid}");
|
||||
}
|
||||
}
|
||||
@@ -500,7 +482,6 @@ impl GlobalTimelines {
|
||||
let tli_res = {
|
||||
let state = self.state.lock().unwrap();
|
||||
|
||||
// Do NOT check tenant tombstones here: those were set earlier
|
||||
if state.tombstones.contains_key(ttid) {
|
||||
// Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do.
|
||||
info!("Timeline {ttid} was already deleted");
|
||||
@@ -576,10 +557,6 @@ impl GlobalTimelines {
|
||||
action: DeleteOrExclude,
|
||||
) -> Result<HashMap<TenantTimelineId, TimelineDeleteResult>> {
|
||||
info!("deleting all timelines for tenant {}", tenant_id);
|
||||
|
||||
// Adding a tombstone before getting the timelines to prevent new timeline additions
|
||||
self.state.lock().unwrap().add_tenant_tombstone(*tenant_id);
|
||||
|
||||
let to_delete = self.get_all_for_tenant(*tenant_id);
|
||||
|
||||
let mut err = None;
|
||||
@@ -623,9 +600,6 @@ impl GlobalTimelines {
|
||||
state
|
||||
.tombstones
|
||||
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
|
||||
state
|
||||
.tenant_tombstones
|
||||
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -482,10 +482,6 @@ async fn handle_tenant_timeline_delete(
|
||||
ForwardOutcome::NotForwarded(_req) => {}
|
||||
};
|
||||
|
||||
service
|
||||
.maybe_delete_timeline_import(tenant_id, timeline_id)
|
||||
.await?;
|
||||
|
||||
// For timeline deletions, which both implement an "initially return 202, then 404 once
|
||||
// we're done" semantic, we wrap with a retry loop to expose a simpler API upstream.
|
||||
async fn deletion_wrapper<R, F>(service: Arc<Service>, f: F) -> Result<Response<Body>, ApiError>
|
||||
|
||||
@@ -99,8 +99,8 @@ use crate::tenant_shard::{
|
||||
ScheduleOptimization, ScheduleOptimizationAction, TenantShard,
|
||||
};
|
||||
use crate::timeline_import::{
|
||||
FinalizingImport, ImportResult, ShardImportStatuses, TimelineImport,
|
||||
TimelineImportFinalizeError, TimelineImportState, UpcallClient,
|
||||
ImportResult, ShardImportStatuses, TimelineImport, TimelineImportFinalizeError,
|
||||
TimelineImportState, UpcallClient,
|
||||
};
|
||||
|
||||
const WAITER_FILL_DRAIN_POLL_TIMEOUT: Duration = Duration::from_millis(500);
|
||||
@@ -232,9 +232,6 @@ struct ServiceState {
|
||||
|
||||
/// Queue of tenants who are waiting for concurrency limits to permit them to reconcile
|
||||
delayed_reconcile_rx: tokio::sync::mpsc::Receiver<TenantShardId>,
|
||||
|
||||
/// Tracks ongoing timeline import finalization tasks
|
||||
imports_finalizing: BTreeMap<(TenantId, TimelineId), FinalizingImport>,
|
||||
}
|
||||
|
||||
/// Transform an error from a pageserver into an error to return to callers of a storage
|
||||
@@ -311,7 +308,6 @@ impl ServiceState {
|
||||
scheduler,
|
||||
ongoing_operation: None,
|
||||
delayed_reconcile_rx,
|
||||
imports_finalizing: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4101,58 +4097,13 @@ impl Service {
|
||||
///
|
||||
/// If this method gets pre-empted by shut down, it will be called again at start-up (on-going
|
||||
/// imports are stored in the database).
|
||||
///
|
||||
/// # Cancel-Safety
|
||||
/// Not cancel safe.
|
||||
/// If the caller stops polling, the import will not be removed from
|
||||
/// [`ServiceState::imports_finalizing`].
|
||||
#[instrument(skip_all, fields(
|
||||
tenant_id=%import.tenant_id,
|
||||
timeline_id=%import.timeline_id,
|
||||
))]
|
||||
|
||||
async fn finalize_timeline_import(
|
||||
self: &Arc<Self>,
|
||||
import: TimelineImport,
|
||||
) -> Result<(), TimelineImportFinalizeError> {
|
||||
let tenant_timeline = (import.tenant_id, import.timeline_id);
|
||||
|
||||
let (_finalize_import_guard, cancel) = {
|
||||
let mut locked = self.inner.write().unwrap();
|
||||
let gate = Gate::default();
|
||||
let cancel = CancellationToken::default();
|
||||
|
||||
let guard = gate.enter().unwrap();
|
||||
|
||||
locked.imports_finalizing.insert(
|
||||
tenant_timeline,
|
||||
FinalizingImport {
|
||||
gate,
|
||||
cancel: cancel.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
(guard, cancel)
|
||||
};
|
||||
|
||||
let res = tokio::select! {
|
||||
res = self.finalize_timeline_import_impl(import) => {
|
||||
res
|
||||
},
|
||||
_ = cancel.cancelled() => {
|
||||
Err(TimelineImportFinalizeError::Cancelled)
|
||||
}
|
||||
};
|
||||
|
||||
let mut locked = self.inner.write().unwrap();
|
||||
locked.imports_finalizing.remove(&tenant_timeline);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
async fn finalize_timeline_import_impl(
|
||||
self: &Arc<Self>,
|
||||
import: TimelineImport,
|
||||
) -> Result<(), TimelineImportFinalizeError> {
|
||||
tracing::info!("Finalizing timeline import");
|
||||
|
||||
@@ -4352,46 +4303,6 @@ impl Service {
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Delete a timeline import if it exists
|
||||
///
|
||||
/// Firstly, delete the entry from the database. Any updates
|
||||
/// from pageservers after the update will fail with a 404, so the
|
||||
/// import cannot progress into finalizing state if it's not there already.
|
||||
/// Secondly, cancel the finalization if one is in progress.
|
||||
pub(crate) async fn maybe_delete_timeline_import(
|
||||
self: &Arc<Self>,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let tenant_has_ongoing_import = {
|
||||
let locked = self.inner.read().unwrap();
|
||||
locked
|
||||
.tenants
|
||||
.range(TenantShardId::tenant_range(tenant_id))
|
||||
.any(|(_tid, shard)| shard.importing == TimelineImportState::Importing)
|
||||
};
|
||||
|
||||
if !tenant_has_ongoing_import {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.persistence
|
||||
.delete_timeline_import(tenant_id, timeline_id)
|
||||
.await?;
|
||||
|
||||
let maybe_finalizing = {
|
||||
let mut locked = self.inner.write().unwrap();
|
||||
locked.imports_finalizing.remove(&(tenant_id, timeline_id))
|
||||
};
|
||||
|
||||
if let Some(finalizing) = maybe_finalizing {
|
||||
finalizing.cancel.cancel();
|
||||
finalizing.gate.close().await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn tenant_timeline_archival_config(
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
@@ -8627,9 +8538,8 @@ impl Service {
|
||||
Some(ShardCount(new_shard_count))
|
||||
}
|
||||
|
||||
/// Fetches the top tenant shards from every available node, in descending order of
|
||||
/// max logical size. Offline nodes are skipped, and any errors from available nodes
|
||||
/// will be logged and ignored.
|
||||
/// Fetches the top tenant shards from every node, in descending order of
|
||||
/// max logical size. Any node errors will be logged and ignored.
|
||||
async fn get_top_tenant_shards(
|
||||
&self,
|
||||
request: &TopTenantShardsRequest,
|
||||
@@ -8640,7 +8550,6 @@ impl Service {
|
||||
.unwrap()
|
||||
.nodes
|
||||
.values()
|
||||
.filter(|node| node.is_available())
|
||||
.cloned()
|
||||
.collect_vec();
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use pageserver_api::models::{ShardImportProgress, ShardImportStatus};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use utils::sync::gate::Gate;
|
||||
use utils::{
|
||||
id::{TenantId, TimelineId},
|
||||
shard::ShardIndex,
|
||||
@@ -56,8 +55,6 @@ pub(crate) enum TimelineImportUpdateFollowUp {
|
||||
pub(crate) enum TimelineImportFinalizeError {
|
||||
#[error("Shut down interrupted import finalize")]
|
||||
ShuttingDown,
|
||||
#[error("Import finalization was cancelled")]
|
||||
Cancelled,
|
||||
#[error("Mismatched shard detected during import finalize: {0}")]
|
||||
MismatchedShards(ShardIndex),
|
||||
}
|
||||
@@ -167,11 +164,6 @@ impl TimelineImport {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct FinalizingImport {
|
||||
pub(crate) gate: Gate,
|
||||
pub(crate) cancel: CancellationToken,
|
||||
}
|
||||
|
||||
pub(crate) type ImportResult = Result<(), String>;
|
||||
|
||||
pub(crate) struct UpcallClient {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -12,7 +11,6 @@ from _pytest.config import Config
|
||||
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_cli import AbstractNeonCli
|
||||
from fixtures.neon_fixtures import Endpoint, VanillaPostgres
|
||||
from fixtures.pg_version import PgVersion
|
||||
from fixtures.remote_storage import MockS3Server
|
||||
|
||||
@@ -163,57 +161,3 @@ def fast_import(
|
||||
f.write(fi.cmd.stderr)
|
||||
|
||||
log.info("Written logs to %s", test_output_dir)
|
||||
|
||||
|
||||
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
|
||||
"""
|
||||
Mock the import S3 bucket into a local directory for a provided vanilla PG instance.
|
||||
"""
|
||||
assert not vanilla_pg.is_running()
|
||||
|
||||
path.mkdir()
|
||||
# what cplane writes before scheduling fast_import
|
||||
specpath = path / "spec.json"
|
||||
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
|
||||
# what fast_import writes
|
||||
vanilla_pg.pgdatadir.rename(path / "pgdata")
|
||||
statusdir = path / "status"
|
||||
statusdir.mkdir()
|
||||
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
|
||||
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
|
||||
|
||||
|
||||
def populate_vanilla_pg(vanilla_pg: VanillaPostgres, target_relblock_size: int) -> int:
|
||||
assert vanilla_pg.is_running()
|
||||
|
||||
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
|
||||
# fillfactor so we don't need to produce that much data
|
||||
# 900 byte per row is > 10% => 1 row per page
|
||||
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
|
||||
|
||||
nrows = 0
|
||||
while True:
|
||||
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
|
||||
log.info(
|
||||
f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
|
||||
)
|
||||
if relblock_size >= target_relblock_size:
|
||||
break
|
||||
addrows = int((target_relblock_size - relblock_size) // 8192)
|
||||
assert addrows >= 1, "forward progress"
|
||||
vanilla_pg.safe_psql(
|
||||
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
|
||||
)
|
||||
nrows += addrows
|
||||
|
||||
return nrows
|
||||
|
||||
|
||||
def validate_import_from_vanilla_pg(endpoint: Endpoint, nrows: int):
|
||||
assert endpoint.safe_psql_many(
|
||||
[
|
||||
"set effective_io_concurrency=32;",
|
||||
"SET statement_timeout='300s';",
|
||||
"select count(*), sum(data::bigint)::bigint from t",
|
||||
]
|
||||
) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]]
|
||||
|
||||
@@ -2337,22 +2337,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
def import_status(
|
||||
self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, generation: int
|
||||
):
|
||||
payload = {
|
||||
"tenant_shard_id": str(tenant_shard_id),
|
||||
"timeline_id": str(timeline_id),
|
||||
"generation": generation,
|
||||
}
|
||||
|
||||
self.request(
|
||||
"GET",
|
||||
f"{self.api}/upcall/v1/timeline_import_status",
|
||||
headers=self.headers(TokenScope.GENERATIONS_API),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def reconcile_all(self):
|
||||
r = self.request(
|
||||
"POST",
|
||||
@@ -2829,11 +2813,6 @@ class NeonPageserver(PgProtocol, LogUtils):
|
||||
if self.running:
|
||||
self.http_client().configure_failpoints([(name, action)])
|
||||
|
||||
def clear_persistent_failpoint(self, name: str):
|
||||
del self._persistent_failpoints[name]
|
||||
if self.running:
|
||||
self.http_client().configure_failpoints([(name, "off")])
|
||||
|
||||
def timeline_dir(
|
||||
self,
|
||||
tenant_shard_id: TenantId | TenantShardId,
|
||||
|
||||
@@ -675,7 +675,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
|
||||
|
||||
def timeline_delete(
|
||||
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, **kwargs
|
||||
) -> int:
|
||||
):
|
||||
"""
|
||||
Note that deletion is not instant, it is scheduled and performed mostly in the background.
|
||||
So if you need to wait for it to complete use `timeline_delete_wait_completed`.
|
||||
@@ -688,8 +688,6 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
|
||||
res_json = res.json()
|
||||
assert res_json is None
|
||||
|
||||
return res.status_code
|
||||
|
||||
def timeline_gc(
|
||||
self,
|
||||
tenant_id: TenantId | TenantShardId,
|
||||
|
||||
@@ -1,41 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import json
|
||||
import time
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from threading import Event
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fixtures.common_types import Lsn, TenantId, TimelineId
|
||||
from fixtures.fast_import import mock_import_bucket, populate_vanilla_pg
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import (
|
||||
NeonEnv,
|
||||
NeonEnvBuilder,
|
||||
NeonPageserver,
|
||||
PgBin,
|
||||
VanillaPostgres,
|
||||
wait_for_last_flush_lsn,
|
||||
)
|
||||
from fixtures.pageserver.http import (
|
||||
ImportPgdataIdemptencyKey,
|
||||
)
|
||||
from fixtures.pageserver.utils import wait_for_upload_queue_empty
|
||||
from fixtures.remote_storage import RemoteStorageKind
|
||||
from fixtures.utils import human_bytes, run_only_on_default_postgres, wait_until
|
||||
from werkzeug.wrappers.response import Response
|
||||
from fixtures.utils import human_bytes, wait_until
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from fixtures.pageserver.http import PageserverHttpClient
|
||||
from pytest_httpserver import HTTPServer
|
||||
from werkzeug.wrappers.request import Request
|
||||
|
||||
|
||||
GLOBAL_LRU_LOG_LINE = "tenant_min_resident_size-respecting LRU would not relieve pressure, evicting more following global LRU policy"
|
||||
@@ -174,7 +164,6 @@ class EvictionEnv:
|
||||
min_avail_bytes,
|
||||
mock_behavior,
|
||||
eviction_order: EvictionOrder,
|
||||
wait_logical_size: bool = True,
|
||||
):
|
||||
"""
|
||||
Starts pageserver up with mocked statvfs setup. The startup is
|
||||
@@ -212,12 +201,11 @@ class EvictionEnv:
|
||||
pageserver.start()
|
||||
|
||||
# we now do initial logical size calculation on startup, which on debug builds can fight with disk usage based eviction
|
||||
if wait_logical_size:
|
||||
for tenant_id, timeline_id in self.timelines:
|
||||
tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id)
|
||||
# Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test
|
||||
if tenant_ps is not None:
|
||||
tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id)
|
||||
for tenant_id, timeline_id in self.timelines:
|
||||
tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id)
|
||||
# Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test
|
||||
if tenant_ps is not None:
|
||||
tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id)
|
||||
|
||||
def statvfs_called():
|
||||
pageserver.assert_log_contains(".*running mocked statvfs.*")
|
||||
@@ -894,121 +882,3 @@ def test_secondary_mode_eviction(eviction_env_ha: EvictionEnv):
|
||||
assert total_size - post_eviction_total_size >= evict_bytes, (
|
||||
"we requested at least evict_bytes worth of free space"
|
||||
)
|
||||
|
||||
|
||||
@run_only_on_default_postgres(reason="PG version is irrelevant here")
|
||||
def test_import_timeline_disk_pressure_eviction(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
vanilla_pg: VanillaPostgres,
|
||||
make_httpserver: HTTPServer,
|
||||
pg_bin: PgBin,
|
||||
):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
# Set up mock control plane HTTP server to listen for import completions
|
||||
import_completion_signaled = Event()
|
||||
|
||||
def handler(request: Request) -> Response:
|
||||
log.info(f"control plane /import_complete request: {request.json}")
|
||||
import_completion_signaled.set()
|
||||
return Response(json.dumps({}), status=200)
|
||||
|
||||
cplane_mgmt_api_server = make_httpserver
|
||||
cplane_mgmt_api_server.expect_request(
|
||||
"/storage/api/v1/import_complete", method="PUT"
|
||||
).respond_with_handler(handler)
|
||||
|
||||
# Plug the cplane mock in
|
||||
neon_env_builder.control_plane_hooks_api = (
|
||||
f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
|
||||
)
|
||||
|
||||
# The import will specifiy a local filesystem path mocking remote storage
|
||||
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
|
||||
|
||||
vanilla_pg.start()
|
||||
target_relblock_size = 1024 * 1024 * 128
|
||||
populate_vanilla_pg(vanilla_pg, target_relblock_size)
|
||||
vanilla_pg.stop()
|
||||
|
||||
env = neon_env_builder.init_configs()
|
||||
env.start()
|
||||
|
||||
importbucket_path = neon_env_builder.repo_dir / "test_import_completion_bucket"
|
||||
mock_import_bucket(vanilla_pg, importbucket_path)
|
||||
|
||||
tenant_id = TenantId.generate()
|
||||
timeline_id = TimelineId.generate()
|
||||
idempotency = ImportPgdataIdemptencyKey.random()
|
||||
|
||||
eviction_env = EvictionEnv(
|
||||
timelines=[(tenant_id, timeline_id)],
|
||||
neon_env=env,
|
||||
pageserver_http=env.pageserver.http_client(),
|
||||
layer_size=5 * 1024 * 1024, # Doesn't apply here
|
||||
pg_bin=pg_bin, # Not used here
|
||||
pgbench_init_lsns={}, # Not used here
|
||||
)
|
||||
|
||||
# Pause before delivering the final notification to storcon.
|
||||
# This keeps the import in progress.
|
||||
failpoint_name = "import-timeline-pre-success-notify-pausable"
|
||||
env.pageserver.add_persistent_failpoint(failpoint_name, "pause")
|
||||
|
||||
env.storage_controller.tenant_create(tenant_id)
|
||||
env.storage_controller.timeline_create(
|
||||
tenant_id,
|
||||
{
|
||||
"new_timeline_id": str(timeline_id),
|
||||
"import_pgdata": {
|
||||
"idempotency_key": str(idempotency),
|
||||
"location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def hit_failpoint():
|
||||
log.info("Checking log for pattern...")
|
||||
try:
|
||||
assert env.pageserver.log_contains(f".*at failpoint {failpoint_name}.*")
|
||||
except Exception:
|
||||
log.exception("Failed to find pattern in log")
|
||||
raise
|
||||
|
||||
wait_until(hit_failpoint)
|
||||
assert not import_completion_signaled.is_set()
|
||||
|
||||
env.pageserver.stop()
|
||||
|
||||
total_size, _, _ = eviction_env.timelines_du(env.pageserver)
|
||||
blocksize = 512
|
||||
total_blocks = (total_size + (blocksize - 1)) // blocksize
|
||||
|
||||
eviction_env.pageserver_start_with_disk_usage_eviction(
|
||||
env.pageserver,
|
||||
period="1s",
|
||||
max_usage_pct=33,
|
||||
min_avail_bytes=0,
|
||||
mock_behavior={
|
||||
"type": "Success",
|
||||
"blocksize": blocksize,
|
||||
"total_blocks": total_blocks,
|
||||
# Only count layer files towards used bytes in the mock_statvfs.
|
||||
# This avoids accounting for metadata files & tenant conf in the tests.
|
||||
"name_filter": ".*__.*",
|
||||
},
|
||||
eviction_order=EvictionOrder.RELATIVE_ORDER_SPARE,
|
||||
wait_logical_size=False,
|
||||
)
|
||||
|
||||
wait_until(lambda: env.pageserver.assert_log_contains(".*disk usage pressure relieved"))
|
||||
|
||||
env.pageserver.clear_persistent_failpoint(failpoint_name)
|
||||
|
||||
def cplane_notified():
|
||||
assert import_completion_signaled.is_set()
|
||||
|
||||
wait_until(cplane_notified)
|
||||
|
||||
env.pageserver.allowed_errors.append(r".* running disk usage based eviction due to pressure.*")
|
||||
|
||||
@@ -12,19 +12,13 @@ import psycopg2
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
|
||||
from fixtures.fast_import import (
|
||||
FastImport,
|
||||
mock_import_bucket,
|
||||
populate_vanilla_pg,
|
||||
validate_import_from_vanilla_pg,
|
||||
)
|
||||
from fixtures.fast_import import FastImport
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import (
|
||||
NeonEnvBuilder,
|
||||
PageserverImportConfig,
|
||||
PgBin,
|
||||
PgProtocol,
|
||||
StorageControllerApiException,
|
||||
StorageControllerMigrationConfig,
|
||||
VanillaPostgres,
|
||||
)
|
||||
@@ -65,6 +59,24 @@ smoke_params = [
|
||||
]
|
||||
|
||||
|
||||
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
|
||||
"""
|
||||
Mock the import S3 bucket into a local directory for a provided vanilla PG instance.
|
||||
"""
|
||||
assert not vanilla_pg.is_running()
|
||||
|
||||
path.mkdir()
|
||||
# what cplane writes before scheduling fast_import
|
||||
specpath = path / "spec.json"
|
||||
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
|
||||
# what fast_import writes
|
||||
vanilla_pg.pgdatadir.rename(path / "pgdata")
|
||||
statusdir = path / "status"
|
||||
statusdir.mkdir()
|
||||
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
|
||||
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
|
||||
|
||||
|
||||
@skip_in_debug_build("MULTIPLE_RELATION_SEGMENTS has non trivial amount of data")
|
||||
@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params)
|
||||
def test_pgdata_import_smoke(
|
||||
@@ -119,6 +131,10 @@ def test_pgdata_import_smoke(
|
||||
# Put data in vanilla pg
|
||||
#
|
||||
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
|
||||
|
||||
log.info("create relblock data")
|
||||
if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE:
|
||||
target_relblock_size = stripe_size * 8192
|
||||
elif rel_block_size == RelBlockSize.TWO_STRPES_PER_SHARD:
|
||||
@@ -129,8 +145,45 @@ def test_pgdata_import_smoke(
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
vanilla_pg.start()
|
||||
rows_inserted = populate_vanilla_pg(vanilla_pg, target_relblock_size)
|
||||
# fillfactor so we don't need to produce that much data
|
||||
# 900 byte per row is > 10% => 1 row per page
|
||||
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
|
||||
|
||||
nrows = 0
|
||||
while True:
|
||||
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
|
||||
log.info(
|
||||
f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
|
||||
)
|
||||
if relblock_size >= target_relblock_size:
|
||||
break
|
||||
addrows = int((target_relblock_size - relblock_size) // 8192)
|
||||
assert addrows >= 1, "forward progress"
|
||||
vanilla_pg.safe_psql(
|
||||
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
|
||||
)
|
||||
nrows += addrows
|
||||
expect_nrows = nrows
|
||||
expect_sum = (
|
||||
(nrows) * (nrows + 1) // 2
|
||||
) # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n
|
||||
|
||||
def validate_vanilla_equivalence(ep):
|
||||
# TODO: would be nicer to just compare pgdump
|
||||
|
||||
# Enable IO concurrency for batching on large sequential scan, to avoid making
|
||||
# this test unnecessarily onerous on CPU. Especially on debug mode, it's still
|
||||
# pretty onerous though, so increase statement_timeout to avoid timeouts.
|
||||
assert ep.safe_psql_many(
|
||||
[
|
||||
"set effective_io_concurrency=32;",
|
||||
"SET statement_timeout='300s';",
|
||||
"select count(*), sum(data::bigint)::bigint from t",
|
||||
]
|
||||
) == [[], [], [(expect_nrows, expect_sum)]]
|
||||
|
||||
validate_vanilla_equivalence(vanilla_pg)
|
||||
|
||||
vanilla_pg.stop()
|
||||
|
||||
#
|
||||
@@ -221,14 +274,14 @@ def test_pgdata_import_smoke(
|
||||
config_lines=ep_config,
|
||||
)
|
||||
|
||||
validate_import_from_vanilla_pg(ro_endpoint, rows_inserted)
|
||||
validate_vanilla_equivalence(ro_endpoint)
|
||||
|
||||
# ensure the import survives restarts
|
||||
ro_endpoint.stop()
|
||||
env.pageserver.stop(immediate=True)
|
||||
env.pageserver.start()
|
||||
ro_endpoint.start()
|
||||
validate_import_from_vanilla_pg(ro_endpoint, rows_inserted)
|
||||
validate_vanilla_equivalence(ro_endpoint)
|
||||
|
||||
#
|
||||
# validate the layer files in each shard only have the shard-specific data
|
||||
@@ -268,7 +321,7 @@ def test_pgdata_import_smoke(
|
||||
child_workload = workload.branch(timeline_id=child_timeline_id, branch_name="br-tip")
|
||||
child_workload.validate()
|
||||
|
||||
validate_import_from_vanilla_pg(child_workload.endpoint(), rows_inserted)
|
||||
validate_vanilla_equivalence(child_workload.endpoint())
|
||||
|
||||
# ... at the initdb lsn
|
||||
_ = env.create_branch(
|
||||
@@ -283,7 +336,7 @@ def test_pgdata_import_smoke(
|
||||
tenant_id=tenant_id,
|
||||
config_lines=ep_config,
|
||||
)
|
||||
validate_import_from_vanilla_pg(br_initdb_endpoint, rows_inserted)
|
||||
validate_vanilla_equivalence(br_initdb_endpoint)
|
||||
with pytest.raises(psycopg2.errors.UndefinedTable):
|
||||
br_initdb_endpoint.safe_psql(f"select * from {workload.table}")
|
||||
|
||||
@@ -370,12 +423,8 @@ def test_import_completion_on_restart(
|
||||
|
||||
|
||||
@run_only_on_default_postgres(reason="PG version is irrelevant here")
|
||||
@pytest.mark.parametrize("action", ["restart", "delete"])
|
||||
def test_import_respects_timeline_lifecycle(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
vanilla_pg: VanillaPostgres,
|
||||
make_httpserver: HTTPServer,
|
||||
action: str,
|
||||
def test_import_respects_tenant_shutdown(
|
||||
neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer
|
||||
):
|
||||
"""
|
||||
Validate that importing timelines respect the usual timeline life cycle:
|
||||
@@ -443,33 +492,16 @@ def test_import_respects_timeline_lifecycle(
|
||||
wait_until(hit_failpoint)
|
||||
assert not import_completion_signaled.is_set()
|
||||
|
||||
if action == "restart":
|
||||
# Restart the pageserver while an import job is in progress.
|
||||
# This clears the failpoint and we expect that the import starts up afresh
|
||||
# after the restart and eventually completes.
|
||||
env.pageserver.stop()
|
||||
env.pageserver.start()
|
||||
# Restart the pageserver while an import job is in progress.
|
||||
# This clears the failpoint and we expect that the import starts up afresh
|
||||
# after the restart and eventually completes.
|
||||
env.pageserver.stop()
|
||||
env.pageserver.start()
|
||||
|
||||
def cplane_notified():
|
||||
assert import_completion_signaled.is_set()
|
||||
def cplane_notified():
|
||||
assert import_completion_signaled.is_set()
|
||||
|
||||
wait_until(cplane_notified)
|
||||
elif action == "delete":
|
||||
status = env.storage_controller.pageserver_api().timeline_delete(tenant_id, timeline_id)
|
||||
assert status == 200
|
||||
|
||||
timeline_path = env.pageserver.timeline_dir(tenant_id, timeline_id)
|
||||
assert not timeline_path.exists(), "Timeline dir exists after deletion"
|
||||
|
||||
shard_zero = TenantShardId(tenant_id, 0, 0)
|
||||
location = env.storage_controller.inspect(shard_zero)
|
||||
assert location is not None
|
||||
generation = location[0]
|
||||
|
||||
with pytest.raises(StorageControllerApiException, match="not found"):
|
||||
env.storage_controller.import_status(shard_zero, timeline_id, generation)
|
||||
else:
|
||||
raise RuntimeError(f"{action} param not recognized")
|
||||
wait_until(cplane_notified)
|
||||
|
||||
|
||||
@skip_in_debug_build("Validation query takes too long in debug builds")
|
||||
@@ -524,8 +556,23 @@ def test_import_chaos(
|
||||
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
|
||||
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
|
||||
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
|
||||
|
||||
inserted_rows = populate_vanilla_pg(vanilla_pg, TARGET_RELBOCK_SIZE)
|
||||
nrows = 0
|
||||
while True:
|
||||
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
|
||||
log.info(
|
||||
f"relblock size: {relblock_size / 8192} pages (target: {TARGET_RELBOCK_SIZE // 8192}) pages"
|
||||
)
|
||||
if relblock_size >= TARGET_RELBOCK_SIZE:
|
||||
break
|
||||
addrows = int((TARGET_RELBOCK_SIZE - relblock_size) // 8192)
|
||||
assert addrows >= 1, "forward progress"
|
||||
vanilla_pg.safe_psql(
|
||||
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
|
||||
)
|
||||
nrows += addrows
|
||||
|
||||
vanilla_pg.stop()
|
||||
|
||||
@@ -693,7 +740,13 @@ def test_import_chaos(
|
||||
endpoint = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id)
|
||||
|
||||
# Validate the imported data is legit
|
||||
validate_import_from_vanilla_pg(endpoint, inserted_rows)
|
||||
assert endpoint.safe_psql_many(
|
||||
[
|
||||
"set effective_io_concurrency=32;",
|
||||
"SET statement_timeout='300s';",
|
||||
"select count(*), sum(data::bigint)::bigint from t",
|
||||
]
|
||||
) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]]
|
||||
|
||||
endpoint.stop()
|
||||
|
||||
|
||||
@@ -4192,10 +4192,10 @@ def test_storcon_create_delete_sk_down(
|
||||
# ensure the safekeeper deleted the timeline
|
||||
def timeline_deleted_on_active_sks():
|
||||
env.safekeepers[0].assert_log_contains(
|
||||
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
|
||||
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
|
||||
)
|
||||
env.safekeepers[2].assert_log_contains(
|
||||
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
|
||||
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
|
||||
)
|
||||
|
||||
wait_until(timeline_deleted_on_active_sks)
|
||||
@@ -4210,7 +4210,7 @@ def test_storcon_create_delete_sk_down(
|
||||
# ensure that there is log msgs for the third safekeeper too
|
||||
def timeline_deleted_on_sk():
|
||||
env.safekeepers[1].assert_log_contains(
|
||||
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
|
||||
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
|
||||
)
|
||||
|
||||
wait_until(timeline_deleted_on_sk)
|
||||
|
||||
Reference in New Issue
Block a user