Documentation and tweaks

Erik Grinaker
2025-07-01 17:39:18 +02:00
parent 0bce818d5e
commit f6761760a2


@@ -1,27 +1,72 @@
//! This module provides various Pageserver gRPC client resource pools.
//!
//! These pools are designed to reuse gRPC resources (connections, clients, and streams) across
//! multiple callers (i.e. Postgres backends). This avoids the resource cost and latency of creating
//! a dedicated TCP connection and server task for every Postgres backend.
//!
//! Each resource has its own, nested pool. The pools are custom-built for the properties of each
//! resource -- these are different enough that a generic pool isn't suitable.
//!
//! * ChannelPool: manages gRPC channels (TCP connections) to a single Pageserver. Multiple clients
//! can acquire and use the same channel concurrently (via HTTP/2 stream multiplexing), up to a
//! per-channel limit. Channels may be closed when they are no longer used by any clients.
//!
//! * ClientPool: manages gRPC clients for a single tenant shard. Each client acquires a (shared)
//! channel from the ChannelPool for client's lifetime. A client can only be acquired by a single
//! caller at a time, and is returned to the pool when dropped. Idle clients may be removed from
//! the pool after some time, to free up the channel.
//!
//! * StreamPool: manages bidirectional gRPC GetPage streams. Each stream acquires a client from
//! the ClientPool for the stream's lifetime. Internal streams are not exposed to callers;
//! instead, callers submit individual GetPage requests to the pool and await a response.
//! Internally, the pool will reuse or spin up a suitable stream for the request, possibly
//! pipelining multiple requests from multiple callers on the same stream (up to some queue
//! depth), and route the response back to the original caller. Idle streams may be removed from
//! the pool after some time, to free up the client.
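//!
//! # Example
//!
//! A minimal sketch of how the pools compose (the endpoint URL and request construction are
//! illustrative; see `page_api` for the actual message types):
//!
//! ```ignore
//! let channel_pool = ChannelPool::new("http://pageserver:51051")?;
//! let client_pool = ClientPool::new(channel_pool, tenant_id, timeline_id, shard_id, None);
//! let stream_pool = StreamPool::new(client_pool);
//!
//! let response = stream_pool.send(request).await?;
//! ```
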
use std::collections::{BTreeMap, HashMap};
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Weak};

use futures::StreamExt as _;
use scopeguard::defer;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
use tonic::transport::{Channel, Endpoint};
use tracing::warn;

use pageserver_page_api::{self as page_api, GetPageRequest, GetPageResponse};
use utils::id::{TenantId, TimelineId};
use utils::shard::ShardIndex;
// TODO: tune these constants, and consider making them configurable.

/// Max number of concurrent clients per channel.
///
/// TODO: consider separate limits for unary and streaming clients, so we don't fill up channels
/// with only streams.
const CLIENTS_PER_CHANNEL: usize = 16;

/// Max number of concurrent clients per `ClientPool`. This also bounds the number of channels at
/// CLIENT_LIMIT / CLIENTS_PER_CHANNEL (here 64 / 16 = 4) when all clients are in use.
const CLIENT_LIMIT: usize = 64;

/// Max number of pipelined requests per gRPC GetPage stream.
const STREAM_QUEUE_DEPTH: usize = 2;
/// A gRPC channel pool, for a single Pageserver. A channel is shared by many clients (via HTTP/2
/// stream multiplexing), up to `CLIENTS_PER_CHANNEL`. The pool does not limit the number of
/// channels, and instead relies on `ClientPool` to limit the number of concurrent clients.
///
/// The pool is always wrapped in an outer `Arc`, to allow long-lived references from guards.
///
/// Tonic will automatically retry the underlying connection if it fails, so there is no need
/// to re-establish connections on errors.
///
/// TODO: reap idle channels.
/// TODO: consider adding a circuit breaker for errors and fail fast.
pub struct ChannelPool {
/// Pageserver endpoint to connect to.
endpoint: Endpoint,
/// Open channels.
@@ -38,77 +83,83 @@ struct ChannelEntry {
}
impl ChannelPool {
/// Creates a new channel pool for the given Pageserver endpoint.
pub fn new<E>(endpoint: E) -> anyhow::Result<Arc<Self>>
where
E: TryInto<Endpoint> + Send + Sync + 'static,
<E as TryInto<Endpoint>>::Error: std::error::Error + Send + Sync,
{
Ok(Arc::new(Self {
endpoint: endpoint.try_into()?,
channels: Default::default(),
}))
}
/// Acquires a gRPC channel for a client. Multiple clients may acquire the same channel.
///
/// This never blocks (except for sync mutex acquisition). The channel is connected lazily on
/// first use, and the `ChannelPool` does not have a channel limit.
///
/// Callers should not clone the returned channel, and must hold onto the returned guard as long
/// as the channel is in use. It is unfortunately not possible to enforce this: the Protobuf
/// client requires an owned `Channel` and we don't have access to the channel's internal
/// refcount.
///
/// NB: this is not very performance-sensitive. It is only called when creating a new client,
/// and clients are cached and reused by ClientPool. The total number of channels will also be
/// small. O(n) performance is therefore okay.
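///
/// # Example
///
/// A sketch of correct guard usage (constructor arguments elided):
///
/// ```ignore
/// let mut guard = channel_pool.get()?;
/// let client = page_api::Client::new(guard.take(), /* ... */)?;
/// // `guard` must stay alive for as long as `client` uses the channel.
/// ```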
pub fn get(self: &Arc<Self>) -> anyhow::Result<ChannelGuard> {
let mut channels = self.channels.lock().unwrap();
// Try to find an existing channel with available capacity. We check entries in BTreeMap
// order, to fill up the lower-ordered channels first. The ClientPool also uses clients with
// lower-ordered channel IDs first. This will cluster clients in lower-ordered channels, and
// free up higher-ordered channels such that they can be reaped.
for (&id, entry) in channels.iter_mut() {
assert!(entry.clients <= CLIENTS_PER_CHANNEL, "channel overflow");
if entry.clients < CLIENTS_PER_CHANNEL {
entry.clients += 1;
return Ok(ChannelGuard {
pool: Arc::downgrade(self),
id,
channel: Some(entry.channel.clone()),
});
}
}
// Create a new channel. We connect lazily on the first use, such that we don't block here
// and other clients can join onto the same channel while it's connecting.
let channel = self.endpoint.connect_lazy();
// Allocate the next channel ID (one past the current max), so we don't overwrite the last
// entry, which may still be at capacity.
let id = channels.keys().last().copied().map(|id| id + 1).unwrap_or_default();
let entry = ChannelEntry {
channel: channel.clone(),
clients: 1, // we're returning the guard below
};
channels.insert(id, entry);
Ok(ChannelGuard {
pool: Arc::downgrade(self),
id,
channel: Some(channel.clone()),
})
}
}
/// Tracks a channel acquired from the pool. The owned inner channel can be obtained with `take()`.
/// However, the caller must hold onto the guard as long as it's using the channel, and should not
/// clone it.
pub struct ChannelGuard {
pool: Weak<ChannelPool>,
id: ChannelID,
channel: Option<Channel>,
}
impl ChannelGuard {
/// Returns the inner channel. Panics if called more than once. The caller must hold onto the
/// guard as long as the channel is in use, and should not clone it.
pub fn take(&mut self) -> Channel {
self.channel.take().expect("channel")
self.channel.take().expect("channel already taken")
}
}
@@ -120,12 +171,20 @@ impl Drop for ChannelGuard {
};
let mut channels = pool.channels.lock().unwrap();
let entry = channels.get_mut(&self.id).expect("unknown channel");
assert!(entry.clients > 0, "channel clients underflow");
assert!(entry.clients > 0, "channel underflow");
entry.clients -= 1;
}
}
/// A pool of gRPC clients for a single tenant shard. Each client acquires a channel from the inner
/// `ChannelPool`. A client is only acquired by a single caller at a time. The pool limits the total
/// number of concurrent clients to `CLIENT_LIMIT` via semaphore.
///
/// The pool is always wrapped in an outer `Arc`, to allow long-lived references from guards.
///
/// TODO: reap idle clients.
/// TODO: error handling (but channel will be reconnected automatically).
/// TODO: rate limiting.
pub struct ClientPool {
/// Tenant ID.
tenant_id: TenantId,
@@ -135,63 +194,68 @@ pub struct ClientPool {
shard_id: ShardIndex,
/// Authentication token, if any.
auth_token: Option<String>,
/// Channel pool to acquire channels from.
channel_pool: Arc<ChannelPool>,
/// Limits the max number of concurrent clients for this pool.
limiter: Arc<Semaphore>,
/// Idle pooled clients. Acquired clients are removed from here and returned on drop.
///
/// The first client in the map will be acquired next. The map is sorted by client ID, which in
/// turn is sorted by the channel ID, such that we prefer acquiring idle clients from
/// lower-ordered channels. This allows us to free up and reap higher-numbered channels as idle
/// clients are reaped.
idle: Mutex<BTreeMap<ClientID, ClientEntry>>,
/// Unique client ID generator.
next_client_id: AtomicUsize,
}
type ClientID = (ChannelID, usize);
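// NB: tuple ordering gives us the desired acquisition order for free: e.g. client (0, 7) sorts
// before (1, 0) in the idle BTreeMap, so idle clients on channel 0 are always preferred over
// clients on channel 1.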
struct ClientEntry {
client: page_api::Client,
channel_guard: ChannelGuard,
}
impl ClientPool {
/// Creates a new client pool for the given tenant shard. Channels are acquired from the given
/// `ChannelPool`, which must point to a Pageserver that hosts the tenant shard.
pub fn new(
channel_pool: Arc<ChannelPool>,
tenant_id: TenantId,
timeline_id: TimelineId,
shard_id: ShardIndex,
auth_token: Option<String>,
) -> Arc<Self> {
Arc::new(Self {
tenant_id,
timeline_id,
shard_id,
auth_token,
channel_pool,
idle: Mutex::default(),
limiter: Arc::new(Semaphore::new(CLIENT_LIMIT)),
next_client_id: AtomicUsize::default(),
})
}
/// Gets a client from the pool, or creates a new one if necessary. Blocks if the pool is at
/// `CLIENT_LIMIT`. The client is returned to the pool when the guard is dropped.
///
/// This is moderately performance-sensitive. It is called for every unary request, but recall
/// that these establish a new gRPC stream per request so it's already expensive. GetPage
/// requests use the `StreamPool` instead.
pub async fn get(self: &Arc<Self>) -> anyhow::Result<ClientGuard> {
let permit = self
.limiter
.clone()
.acquire_owned()
.await
.expect("never closed");
// Fast path: acquire an idle client from the pool.
if let Some((id, entry)) = self.idle.lock().unwrap().pop_first() {
return Ok(ClientGuard {
pool: Arc::downgrade(self),
id,
client: Some(entry.client),
channel_guard: Some(entry.channel_guard),
@@ -200,9 +264,7 @@ impl ClientPool {
}
// Slow path: construct a new client.
let mut channel_guard = self.channel_pool.get()?;
let client = page_api::Client::new(
channel_guard.take(),
self.tenant_id,
@@ -213,8 +275,11 @@ impl ClientPool {
)?;
Ok(ClientGuard {
pool: Arc::downgrade(self),
id: (
channel_guard.id,
self.next_client_id.fetch_add(1, Ordering::Relaxed),
),
client: Some(client),
channel_guard: Some(channel_guard),
permit,
@@ -222,11 +287,13 @@ impl ClientPool {
}
}
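// Example (illustrative sketch): a unary request via the pool. `get_rel_size` is a stand-in for
// whatever unary RPCs `page_api::Client` actually exposes; the guard derefs to the inner client
// and returns it to the pool on drop.
//
//     let mut client = client_pool.get().await?;
//     let size = client.get_rel_size(request).await?;
//     drop(client); // client returns to the idle pool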
/// A client acquired from the pool. The inner client can be accessed via derefs. The client is
/// returned to the pool when dropped.
pub struct ClientGuard {
pool: Weak<ClientPool>,
id: ClientID,
client: Option<page_api::Client>, // Some until dropped
channel_guard: Option<ChannelGuard>, // Some until dropped
permit: OwnedSemaphorePermit,
}
@@ -250,39 +317,35 @@ impl Drop for ClientGuard {
let Some(pool) = self.pool.upgrade() else {
return; // pool was dropped
};
let entry = ClientEntry {
client: self.client.take().expect("dropped once"),
channel_guard: self.channel_guard.take().expect("dropped once"),
};
pool.idle.lock().unwrap().insert(self.id, entry);
// The permit will be returned by its drop handler. Tag it here for visibility.
_ = self.permit;
}
}
/// A pool of bidirectional gRPC streams. Currently only used for GetPage streams. Each stream
/// acquires a client from the inner `ClientPool` for the stream's lifetime.
///
/// Individual streams are not exposed to callers -- instead, callers submit individual requests to
/// the pool and await a response. Internally, requests are multiplexed across streams and channels.
///
/// TODO: reap idle streams.
/// TODO: error handling (but channel will be reconnected automatically).
/// TODO: rate limiting.
/// TODO: consider making this generic over request and response types; not currently needed.
pub struct StreamPool {
/// The client pool to acquire clients from.
client_pool: Arc<ClientPool>,
/// All pooled streams.
///
/// Incoming requests will be sent over an existing stream with available capacity, or a new
/// stream is spun up and added to the pool. Each stream has an associated Tokio task that
/// processes requests and responses.
streams: Arc<Mutex<HashMap<StreamID, StreamEntry>>>,
/// Limits the max number of concurrent requests (not streams).
limiter: Semaphore,
@@ -291,51 +354,72 @@ pub struct StreamPool {
}
type StreamID = usize;
type RequestSender = Sender<(GetPageRequest, ResponseSender)>;
type RequestReceiver = Receiver<(GetPageRequest, ResponseSender)>;
type ResponseSender = oneshot::Sender<tonic::Result<GetPageResponse>>;
struct StreamEntry {
/// Sends caller requests to the stream task. The stream task exits when this is dropped.
sender: RequestSender,
/// Number of in-flight requests on this stream. This is an atomic to allow decrementing it on
/// completion without acquiring the `StreamPool::streams` lock.
queue_depth: Arc<AtomicUsize>,
}
impl StreamPool {
/// Creates a new stream pool, using the given client pool.
pub fn new(client_pool: Arc<ClientPool>) -> Self {
Self {
client_pool,
streams: Arc::default(),
limiter: Semaphore::new(CLIENT_LIMIT * STREAM_QUEUE_DEPTH),
next_stream_id: AtomicUsize::default(),
}
}
/// Sends a request via the stream pool and awaits the response. Blocks if the pool is at
/// capacity (i.e. `CLIENT_LIMIT * STREAM_QUEUE_DEPTH` requests in flight). The
/// `GetPageRequest::request_id` must be unique across in-flight requests.
///
/// NB: errors are often returned as `GetPageResponse::status_code` instead of `tonic::Status`
/// to avoid tearing down the stream for per-request errors. Callers must check this.
///
/// This is very performance-sensitive, as it is on the GetPage hot path.
///
/// TODO: this must do something more sophisticated for performance. We want:
/// * Cheap, concurrent access in the common case where we can use a pooled stream.
/// * Quick acquisition of pooled streams with available capacity.
/// * Prefer streams that belong to lower-numbered channels, to reap idle channels.
/// * Prefer filling up existing streams' queue depth before spinning up new streams.
/// * Don't hold a lock while spinning up new streams.
/// * Allow concurrent clients to join onto streams while they're spun up.
/// * Allow spinning up multiple streams concurrently, but don't overshoot limits.
///
/// For now, we just do something simple and functional, but very inefficient (linear scan).
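///
/// # Example
///
/// A usage sketch; note that per-request errors must be checked on the response itself:
///
/// ```ignore
/// let resp = stream_pool.send(request).await?;
/// // A `tonic::Status` error above means the stream failed; per-request errors are
/// // reported via `resp.status_code` instead, and must be checked by the caller.
/// ```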
pub async fn send(&self, req: GetPageRequest) -> tonic::Result<GetPageResponse> {
// Acquire a permit. For simplicity, we drop it when this method returns. This may exceed
// the queue depth if a caller goes away while a request is in flight, but that's okay. We
// do the same for queue depth tracking.
let _permit = self.limiter.acquire().await.expect("never closed");
// Acquire a stream sender. We increment and decrement the queue depth here instead of in
// the stream task to ensure we don't exceed the queue depth limit.
#[allow(clippy::await_holding_lock)] // TODO: Clippy doesn't understand drop()
let (req_tx, queue_depth) = async {
let mut streams = self.streams.lock().unwrap();
// Try to find an existing stream with available capacity.
for entry in streams.values() {
assert!(
entry.queue_depth.load(Ordering::Relaxed) <= STREAM_QUEUE_DEPTH,
"stream overflow"
);
if entry
.queue_depth
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |queue_depth| {
// Increment the queue depth via compare-and-swap.
// TODO: review ordering.
(queue_depth < STREAM_QUEUE_DEPTH).then_some(queue_depth + 1)
})
.is_ok()
{
@@ -344,25 +428,24 @@ impl StreamPool {
}
// No available stream, spin up a new one. We install the stream entry first and release
// the lock, to allow other callers to join onto this stream and also create additional
// streams concurrently when this fills up.
let id = self.next_stream_id.fetch_add(1, Ordering::Relaxed);
let queue_depth = Arc::new(AtomicUsize::new(1)); // account for this request
let (req_tx, req_rx) = mpsc::channel(STREAM_QUEUE_DEPTH);
let entry = StreamEntry {
sender: req_tx.clone(),
queue_depth: queue_depth.clone(),
};
streams.insert(id, entry);
drop(streams); // drop lock before spinning up stream
let client_pool = self.client_pool.clone();
let streams = self.streams.clone();
tokio::spawn(async move {
if let Err(err) = Self::run_stream(client_pool, req_rx).await {
warn!("stream failed: {err}");
}
// Remove stream from pool on exit.
@@ -375,12 +458,15 @@ impl StreamPool {
.await
.map_err(|err| tonic::Status::internal(err.to_string()))?;
// Decrement the queue depth on return. This may prematurely decrement it if the caller goes
// away while the request is in flight, but that's okay.
defer!(
let prev_queue_depth = queue_depth.fetch_sub(1, Ordering::SeqCst);
assert!(prev_queue_depth > 0, "stream underflow");
);
// Send the request and wait for the response.
let (resp_tx, resp_rx) = oneshot::channel();
req_tx
.send((req, resp_tx))
@@ -392,35 +478,43 @@ impl StreamPool {
.map_err(|_| tonic::Status::unavailable("stream closed"))?
}
/// Runs a stream task. This acquires a client from the `ClientPool` and establishes a
/// bidirectional GetPage stream, then forwards requests and responses between callers and the
/// stream. It does not track or enforce queue depths, see `send()`.
///
/// The task exits when the request channel is closed, or on a stream error. The caller is
/// responsible for removing the stream from the pool on exit.
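///
/// The data flow is roughly (sketch):
///
/// ```text
/// callers -> caller_rx -> req_tx -> GetPage stream -> resp_stream -> resp_tx -> callers
/// ```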
async fn run_stream(
client_pool: Arc<ClientPool>,
mut caller_rx: RequestReceiver,
) -> anyhow::Result<()> {
// Acquire a client from the pool and create a stream.
let mut client = client_pool.get().await?;
let (req_tx, req_rx) = mpsc::channel(STREAM_QUEUE_DEPTH);
let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx);
let mut resp_stream = client.get_pages(req_stream).await?;
// Track caller response channels by request ID. If the task returns early, these response
// channels will be dropped and the callers will receive an error.
let mut callers = HashMap::with_capacity(STREAM_QUEUE_DEPTH);
// Process requests and responses.
loop {
// NB: this can trip if the server doesn't respond to a request, so only debug_assert.
debug_assert!(callers.len() <= STREAM_QUEUE_DEPTH, "stream overflow");
tokio::select! {
// Receive requests from callers and send them to the stream.
req = caller_rx.recv() => {
// Shut down if request channel is closed.
let Some((req, resp_tx)) = req else {
return Ok(());
};
// Store the response channel by request ID.
if callers.contains_key(&req.request_id) {
// Error on request ID duplicates. Ignore callers that went away.
_ = resp_tx.send(Err(tonic::Status::invalid_argument(
format!("duplicate request ID: {}", req.request_id),
)));
@@ -428,7 +522,7 @@ impl StreamPool {
}
callers.insert(req.request_id, resp_tx);
// Send the request on the stream. Bail out if the send fails.
req_tx.send(req).await.map_err(|_| {
tonic::Status::unavailable("stream closed")
})?;
@@ -441,12 +535,12 @@ impl StreamPool {
return Ok(())
};
// Send the response to the caller. Ignore errors if the caller went away.
let Some(resp_tx) = callers.remove(&resp.request_id) else {
warn!("received response for unknown request ID: {}", resp.request_id);
continue;
};
_ = resp_tx.send(Ok(resp));
}
}
}