diff --git a/src/meta-srv/src/handler.rs b/src/meta-srv/src/handler.rs index 92916838d8..9cfd4e6079 100644 --- a/src/meta-srv/src/handler.rs +++ b/src/meta-srv/src/handler.rs @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::cmp::Ordering; use std::collections::{BTreeMap, HashSet}; use std::fmt::{Debug, Display}; -use std::ops::Range; +use std::ops::Bound; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; @@ -136,6 +137,26 @@ pub struct PusherId { pub id: u64, } +impl PartialEq for PusherId { + fn eq(&self, other: &Self) -> bool { + self.role as i32 == other.role as i32 && self.id == other.id + } +} + +impl Eq for PusherId {} + +impl PartialOrd for PusherId { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PusherId { + fn cmp(&self, other: &Self) -> Ordering { + (self.role as i32, self.id).cmp(&(other.role as i32, other.id)) + } +} + impl Debug for PusherId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}-{}", self.role, self.id) @@ -153,8 +174,11 @@ impl PusherId { Self { role, id } } - pub fn string_key(&self) -> String { - format!("{}-{}", self.role as i32, self.id) + fn role_range(role: Role) -> (Bound, Bound) { + ( + Bound::Included(Self::new(role, u64::MIN)), + Bound::Included(Self::new(role, u64::MAX)), + ) } } @@ -214,7 +238,7 @@ impl Pusher { /// The group of heartbeat pushers. #[derive(Clone, Default)] -pub struct Pushers(Arc>>); +pub struct Pushers(Arc>>); impl Pushers { async fn push( @@ -222,11 +246,12 @@ impl Pushers { pusher_id: PusherId, mailbox_message: MailboxMessage, ) -> Result { - let pusher_id = pusher_id.string_key(); let pushers = self.0.read().await; let pusher = pushers .get(&pusher_id) - .context(error::PusherNotFoundSnafu { pusher_id })?; + .with_context(|| error::PusherNotFoundSnafu { + pusher_id: pusher_id.to_string(), + })?; pusher .push(HeartbeatResponse { @@ -239,14 +264,10 @@ impl Pushers { Ok(pusher.deregister_signal_receiver.clone()) } - async fn broadcast( - &self, - range: Range, - mailbox_message: &MailboxMessage, - ) -> Result<()> { + async fn broadcast(&self, role: Role, mailbox_message: &MailboxMessage) -> Result<()> { let pushers = self.0.read().await; let pushers = pushers - .range(range) + .range(PusherId::role_range(role)) .map(|(_, value)| value) .collect::>(); let mut results = Vec::with_capacity(pushers.len()); @@ -271,12 +292,12 @@ impl Pushers { Ok(()) } - pub(crate) async fn insert(&self, pusher_id: String, pusher: Pusher) -> Option { + pub(crate) async fn insert(&self, pusher_id: PusherId, pusher: Pusher) -> Option { self.0.write().await.insert(pusher_id, pusher) } - async fn remove(&self, pusher_id: &str) -> Option { - self.0.write().await.remove(pusher_id) + async fn remove(&self, pusher_id: PusherId) -> Option { + self.0.write().await.remove(&pusher_id) } } @@ -308,12 +329,12 @@ impl HeartbeatHandlerGroup { pub async fn register_pusher(&self, pusher_id: PusherId, pusher: Pusher) { METRIC_META_HEARTBEAT_CONNECTION_NUM.inc(); info!("Pusher register: {}", pusher_id); - let _ = self.pushers.insert(pusher_id.string_key(), pusher).await; + let _ = self.pushers.insert(pusher_id, pusher).await; } /// Deregisters the heartbeat response [`Pusher`] with the given key from the group. pub async fn deregister_push(&self, pusher_id: PusherId) { - if self.pushers.remove(&pusher_id.string_key()).await.is_some() { + if self.pushers.remove(pusher_id).await.is_some() { info!("Pusher unregister: {}", pusher_id); METRIC_META_HEARTBEAT_CONNECTION_NUM.dec(); } @@ -323,7 +344,7 @@ impl HeartbeatHandlerGroup { /// Returns whether the group contains the heartbeat response [`Pusher`] with the given key. pub async fn contains_pusher(&self, pusher_id: &PusherId) -> bool { let pushers = self.pushers.0.read().await; - pushers.contains_key(&pusher_id.string_key()) + pushers.contains_key(pusher_id) } /// Returns the [`Pushers`] of the group. @@ -531,7 +552,7 @@ impl Mailbox for HeartbeatMailbox { } async fn broadcast(&self, ch: &BroadcastChannel, msg: &MailboxMessage) -> Result<()> { - self.pushers.broadcast(ch.pusher_range(), msg).await + self.pushers.broadcast(ch.role(), msg).await } async fn on_recv(&self, id: MessageId, maybe_msg: Result) -> Result<()> { @@ -861,6 +882,7 @@ impl HeartbeatHandlerGroupBuilderCustomizer for DefaultHeartbeatHandlerGroupBuil mod tests { use std::assert_matches; + use std::collections::BTreeMap; use std::sync::Arc; use std::time::Duration; @@ -936,6 +958,62 @@ mod tests { (mailbox, receiver) } + #[test] + fn test_pusher_id_role_range() { + let mut pushers = BTreeMap::new(); + pushers.insert(PusherId::new(Role::Datanode, u64::MAX), "datanode"); + pushers.insert(PusherId::new(Role::Frontend, u64::MIN), "frontend-min"); + pushers.insert(PusherId::new(Role::Frontend, u64::MAX), "frontend-max"); + pushers.insert(PusherId::new(Role::Flownode, u64::MIN), "flownode"); + + let frontend_pushers = pushers + .range(PusherId::role_range(Role::Frontend)) + .map(|(_, value)| *value) + .collect::>(); + + assert_eq!(frontend_pushers, vec!["frontend-min", "frontend-max"]); + } + + #[tokio::test] + async fn test_pushers_broadcast_by_role() { + let pushers = Pushers::default(); + let (datanode_tx, mut datanode_rx) = mpsc::channel(1); + let (frontend_tx, mut frontend_rx) = mpsc::channel(1); + let (flownode_tx, mut flownode_rx) = mpsc::channel(1); + + pushers + .insert( + PusherId::new(Role::Datanode, u64::MAX), + Pusher::new(datanode_tx), + ) + .await; + pushers + .insert(PusherId::new(Role::Frontend, 1), Pusher::new(frontend_tx)) + .await; + pushers + .insert( + PusherId::new(Role::Flownode, u64::MIN), + Pusher::new(flownode_tx), + ) + .await; + + let msg = MailboxMessage { + id: 42, + subject: "broadcast-test".to_string(), + timestamp_millis: 123, + ..Default::default() + }; + + pushers.broadcast(Role::Frontend, &msg).await.unwrap(); + + let received = frontend_rx.recv().await.unwrap().unwrap(); + let mailbox_message = received.mailbox_message.unwrap(); + assert_eq!(mailbox_message.id, 0); + assert_eq!(mailbox_message.subject, "broadcast-test"); + assert!(datanode_rx.try_recv().is_err()); + assert!(flownode_rx.try_recv().is_err()); + } + #[test] fn test_handler_group_builder() { let group = HeartbeatHandlerGroupBuilder::new(Pushers::default()) diff --git a/src/meta-srv/src/procedure/test_util.rs b/src/meta-srv/src/procedure/test_util.rs index 5bf60fe32e..318a276676 100644 --- a/src/meta-srv/src/procedure/test_util.rs +++ b/src/meta-srv/src/procedure/test_util.rs @@ -66,7 +66,7 @@ impl MailboxContext { ) { let pusher_id = channel.pusher_id(); let pusher = Pusher::new(tx); - let _ = self.pushers.insert(pusher_id.string_key(), pusher).await; + let _ = self.pushers.insert(pusher_id, pusher).await; } pub fn mailbox(&self) -> &MailboxRef { diff --git a/src/meta-srv/src/service/mailbox.rs b/src/meta-srv/src/service/mailbox.rs index 86b631998b..f3fbdcbffc 100644 --- a/src/meta-srv/src/service/mailbox.rs +++ b/src/meta-srv/src/service/mailbox.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::fmt::{Display, Formatter}; -use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -69,20 +68,11 @@ pub enum BroadcastChannel { } impl BroadcastChannel { - pub(crate) fn pusher_range(&self) -> Range { + pub(crate) fn role(&self) -> Role { match self { - BroadcastChannel::Datanode => Range { - start: format!("{}-", Role::Datanode as i32), - end: format!("{}-", Role::Frontend as i32), - }, - BroadcastChannel::Frontend => Range { - start: format!("{}-", Role::Frontend as i32), - end: format!("{}-", Role::Flownode as i32), - }, - BroadcastChannel::Flownode => Range { - start: format!("{}-", Role::Flownode as i32), - end: format!("{}-", Role::Flownode as i32 + 1), - }, + BroadcastChannel::Datanode => Role::Datanode, + BroadcastChannel::Frontend => Role::Frontend, + BroadcastChannel::Flownode => Role::Flownode, } } } @@ -219,19 +209,10 @@ mod tests { use super::*; #[test] - fn test_channel_pusher_range() { - assert_eq!( - BroadcastChannel::Datanode.pusher_range(), - ("0-".to_string().."1-".to_string()) - ); - assert_eq!( - BroadcastChannel::Frontend.pusher_range(), - ("1-".to_string().."2-".to_string()) - ); - assert_eq!( - BroadcastChannel::Flownode.pusher_range(), - ("2-".to_string().."3-".to_string()) - ); + fn test_broadcast_channel_role() { + assert_eq!(BroadcastChannel::Datanode.role(), Role::Datanode); + assert_eq!(BroadcastChannel::Frontend.role(), Role::Frontend); + assert_eq!(BroadcastChannel::Flownode.role(), Role::Flownode); } #[tokio::test]