Fix handling the case that server closes the stream

- avoid panic by checking for Ok(None) response from
  tonic::Streaming::message() instead of just using unwrap()
- There was a race condition, if the caller sent the message, but the
  receiver task concurrently received Ok(None) indicating the stream
  was closed. (I didn't see that in action, but I think it could happen
  by reading the code)
This commit is contained in:
Heikki Linnakangas
2025-06-29 22:43:25 +03:00
parent 7020476bf5
commit 924c6a6fdf

View File

@@ -33,15 +33,20 @@ use std::time::Duration;
use client_cache::PooledItemFactory;
/// StreamReturner represents a gRPC stream to a pageserver.
///
/// To send a request:
/// 1. insert the request's ID, along with a channel to receive the response
/// 2. send the request to 'sender'
#[derive(Clone)]
pub struct StreamReturner {
sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
sender_hashmap: Arc<
tokio::sync::Mutex<
tokio::sync::Mutex<Option<
std::collections::HashMap<
u64,
tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>,
>,
>>,
>,
>,
}
@@ -97,7 +102,7 @@ impl PooledItemFactory<StreamReturner> for StreamFactory {
let stream_returner = StreamReturner {
sender: sender.clone(),
sender_hashmap: Arc::new(tokio::sync::Mutex::new(
std::collections::HashMap::new(),
Some(std::collections::HashMap::new()),
)),
};
let map = Arc::clone(&stream_returner.sender_hashmap);
@@ -106,32 +111,42 @@ impl PooledItemFactory<StreamReturner> for StreamFactory {
let map_clone = Arc::clone(&map);
let mut inner = resp.into_inner();
loop {
let resp = inner.message().await;
if !resp.is_ok() {
break; // Exit the loop if no more messages
}
let response = resp.unwrap().unwrap();
// look up stream in hash map
let mut hashmap = map_clone.lock().await;
if let Some(sender) = hashmap.get(&response.request_id) {
// Send the response to the original request sender
if let Err(e) = sender.send(Ok(response.clone())).await {
eprintln!("Failed to send response: {}", e);
match inner.message().await {
Err(e) => {
tracing::info!("error received on getpage stream: {e}");
break; // Exit the loop if no more messages
}
Ok(None) => {
break; // Sender closed the stream
}
Ok(Some(response)) => {
// look up stream in hash map
let mut hashmap = map_clone.lock().await;
let hashmap = hashmap.as_mut().expect("no other task clears the hashmap");
if let Some(sender) = hashmap.get(&response.request_id) {
// Send the response to the original request sender
if let Err(e) = sender.send(Ok(response.clone())).await {
eprintln!("Failed to send response: {}", e);
}
hashmap.remove(&response.request_id);
} else {
eprintln!("No sender found for request ID: {}", response.request_id);
}
}
hashmap.remove(&response.request_id);
} else {
eprintln!("No sender found for request ID: {}", response.request_id);
}
}
// Don't accept any more requests
// Close every sender stream in the hashmap
let hashmap = map_clone.lock().await;
let mut hashmap_opt = map_clone.lock().await;
let hashmap = hashmap_opt.as_mut().expect("no other task clears the hashmap");
for sender in hashmap.values() {
let error = Status::new(Code::Unknown, "Stream closed");
if let Err(e) = sender.send(Err(error)).await {
eprintln!("Failed to send close response: {}", e);
}
}
*hashmap_opt = None;
});
Ok(Ok(stream_returner))
@@ -288,8 +303,18 @@ impl RequestTracker {
let map = returner.sender_hashmap.clone();
// Insert the response sender into the hashmap
{
let mut map_inner = map.lock().await;
map_inner.insert(request_id, response_sender);
if let Some(map_inner) = map.lock().await.as_mut() {
let old = map_inner.insert(request_id, response_sender);
// request IDs must be unique
if old.is_some() {
panic!("request with ID {request_id} is already in-flight");
}
} else {
// The stream was closed. Try a different one.
tracing::info!("stream was concurrently closed");
continue;
}
}
let sent = returner
.sender
@@ -299,9 +324,10 @@ impl RequestTracker {
if let Err(_e) = sent {
// Remove the request from the map if sending failed
{
let mut map_inner = map.lock().await;
// remove from hashmap
map_inner.remove(&request_id);
if let Some(map_inner) = map.lock().await.as_mut() {
// remove from hashmap
map_inner.remove(&request_id);
}
}
stream_returner
.finish(Err(Status::new(Code::Unknown, "Failed to send request")))