diff --git a/pageserver/client_grpc/src/request_tracker.rs b/pageserver/client_grpc/src/request_tracker.rs index 9cce6a06e1..5f5e767c49 100644 --- a/pageserver/client_grpc/src/request_tracker.rs +++ b/pageserver/client_grpc/src/request_tracker.rs @@ -33,15 +33,20 @@ use std::time::Duration; use client_cache::PooledItemFactory; +/// StreamReturner represents a gRPC stream to a pageserver. +/// +/// To send a request: +/// 1. insert the request's ID, along with a channel to receive the response +/// 2. send the request to 'sender' #[derive(Clone)] pub struct StreamReturner { sender: tokio::sync::mpsc::Sender, sender_hashmap: Arc< - tokio::sync::Mutex< + tokio::sync::Mutex>, - >, + >>, >, >, } @@ -97,7 +102,7 @@ impl PooledItemFactory for StreamFactory { let stream_returner = StreamReturner { sender: sender.clone(), sender_hashmap: Arc::new(tokio::sync::Mutex::new( - std::collections::HashMap::new(), + Some(std::collections::HashMap::new()), )), }; let map = Arc::clone(&stream_returner.sender_hashmap); @@ -106,32 +111,42 @@ impl PooledItemFactory for StreamFactory { let map_clone = Arc::clone(&map); let mut inner = resp.into_inner(); loop { - let resp = inner.message().await; - if !resp.is_ok() { - break; // Exit the loop if no more messages - } - let response = resp.unwrap().unwrap(); - - // look up stream in hash map - let mut hashmap = map_clone.lock().await; - if let Some(sender) = hashmap.get(&response.request_id) { - // Send the response to the original request sender - if let Err(e) = sender.send(Ok(response.clone())).await { - eprintln!("Failed to send response: {}", e); + match inner.message().await { + Err(e) => { + tracing::info!("error received on getpage stream: {e}"); + break; // Exit the loop if no more messages + } + Ok(None) => { + break; // Sender closed the stream + } + Ok(Some(response)) => { + // look up stream in hash map + let mut hashmap = map_clone.lock().await; + let hashmap = hashmap.as_mut().expect("no other task clears the hashmap"); + if let Some(sender) = hashmap.get(&response.request_id) { + // Send the response to the original request sender + if let Err(e) = sender.send(Ok(response.clone())).await { + eprintln!("Failed to send response: {}", e); + } + hashmap.remove(&response.request_id); + } else { + eprintln!("No sender found for request ID: {}", response.request_id); + } } - hashmap.remove(&response.request_id); - } else { - eprintln!("No sender found for request ID: {}", response.request_id); } } + // Don't accept any more requests + // Close every sender stream in the hashmap - let hashmap = map_clone.lock().await; + let mut hashmap_opt = map_clone.lock().await; + let hashmap = hashmap_opt.as_mut().expect("no other task clears the hashmap"); for sender in hashmap.values() { let error = Status::new(Code::Unknown, "Stream closed"); if let Err(e) = sender.send(Err(error)).await { eprintln!("Failed to send close response: {}", e); } } + *hashmap_opt = None; }); Ok(Ok(stream_returner)) @@ -288,8 +303,18 @@ impl RequestTracker { let map = returner.sender_hashmap.clone(); // Insert the response sender into the hashmap { - let mut map_inner = map.lock().await; - map_inner.insert(request_id, response_sender); + if let Some(map_inner) = map.lock().await.as_mut() { + let old = map_inner.insert(request_id, response_sender); + + // request IDs must be unique + if old.is_some() { + panic!("request with ID {request_id} is already in-flight"); + } + } else { + // The stream was closed. Try a different one. + tracing::info!("stream was concurrently closed"); + continue; + } } let sent = returner .sender @@ -299,9 +324,10 @@ impl RequestTracker { if let Err(_e) = sent { // Remove the request from the map if sending failed { - let mut map_inner = map.lock().await; - // remove from hashmap - map_inner.remove(&request_id); + if let Some(map_inner) = map.lock().await.as_mut() { + // remove from hashmap + map_inner.remove(&request_id); + } } stream_returner .finish(Err(Status::new(Code::Unknown, "Failed to send request")))