diff --git a/src/common/query/src/error.rs b/src/common/query/src/error.rs index e70b9f4833..dd2b29adf3 100644 --- a/src/common/query/src/error.rs +++ b/src/common/query/src/error.rs @@ -199,6 +199,18 @@ pub enum Error { #[snafu(display("Invalid character in prefix config: {}", prefix))] InvalidColumnPrefix { prefix: String }, + + #[snafu(display( + "DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes", + payload_size_bytes, + max_payload_bytes + ))] + DynFilterPayloadTooLarge { + payload_size_bytes: usize, + max_payload_bytes: usize, + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -229,6 +241,8 @@ impl ErrorExt for Error { | Error::InvalidFuncArgs { .. } | Error::InvalidColumnPrefix { .. } => StatusCode::InvalidArguments, + Error::DynFilterPayloadTooLarge { .. } => StatusCode::PlanQuery, + Error::ConvertDfRecordBatchStream { source, .. } => source.status_code(), Error::DecodePlan { source, .. } diff --git a/src/common/query/src/request.rs b/src/common/query/src/request.rs index 121292709f..1a2a58b8ca 100644 --- a/src/common/query/src/request.rs +++ b/src/common/query/src/request.rs @@ -19,10 +19,10 @@ use std::sync::Arc; use api::v1::region::RegionRequestHeader; use datafusion::execution::TaskContext; -use datafusion::physical_expr::expressions::Column; +use datafusion::physical_expr::expressions::{Column, InListExpr, lit}; use datafusion::physical_plan::PhysicalExpr; use datafusion::physical_plan::joins::HashTableLookupExpr; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_expr::LogicalPlan; use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; @@ -31,6 +31,7 @@ use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; use datafusion_proto::protobuf::PhysicalExprNode; use prost::Message; use serde::{Deserialize, Serialize}; +use snafu::ensure; use store_api::storage::RegionId; /// Current wire-format version for remote dynamic filter payload updates. @@ -39,6 +40,7 @@ pub use self::initial_remote_dyn_filter_reg::{ INITIAL_REMOTE_DYN_FILTER_REGS_MAX_TOTAL_PROTO_BYTES, InitialDynFilterReg, InitialDynFilterRegs, InitialDynFilterSnapshot, }; +use crate::error::{DynFilterPayloadTooLargeSnafu, Error as CommonQueryError}; pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1; @@ -63,24 +65,23 @@ pub enum DynFilterPayload { impl DynFilterPayload { /// Encodes a DataFusion physical expression into a bounded dynamic filter payload. /// - /// This rejects expressions that cannot be safely shipped as dynamic filter - /// predicates and fails if the serialized payload exceeds `max_payload_bytes`. + /// Runtime-only hash lookup predicates are degraded to `true` before encoding so + /// serializable min/max bounds around them can still be shipped to remote scans. + /// If the full serializable predicate is still larger than `max_payload_bytes`, large + /// membership predicates (`IN (...)`) are also degraded to `true` as a bounds-only fallback. pub fn from_datafusion_expr( expr: &Arc, max_payload_bytes: usize, ) -> DataFusionResult { - validate_supported_payload_expr(expr)?; - - let codec = DefaultPhysicalExtensionCodec {}; - let proto = serialize_physical_expr(expr, &codec)?; - let mut bytes = Vec::new(); - proto.encode(&mut bytes).map_err(|e| { - DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}")) - })?; - - validate_payload_size(bytes.len(), max_payload_bytes)?; - - Ok(Self::Datafusion(bytes)) + match encode_remote_dyn_filter_expr(expr, max_payload_bytes, false) { + Ok(bytes) => Ok(Self::Datafusion(bytes)), + Err(CommonQueryError::DynFilterPayloadTooLarge { .. }) => { + encode_remote_dyn_filter_expr(expr, max_payload_bytes, true) + .map(Self::Datafusion) + .map_err(DataFusionError::from) + } + Err(error) => Err(DataFusionError::from(error)), + } } /// Decodes a DataFusion dynamic filter payload against the provided input schema. @@ -95,7 +96,7 @@ impl DynFilterPayload { max_payload_bytes: usize, ) -> DataFusionResult> { let Self::Datafusion(bytes) = self; - validate_payload_size(bytes.len(), max_payload_bytes)?; + validate_payload_size(bytes.len(), max_payload_bytes).map_err(DataFusionError::from)?; let codec = DefaultPhysicalExtensionCodec {}; let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| { DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}")) @@ -118,13 +119,41 @@ fn encode_physical_expr_to_bytes(expr: &Arc) -> DataFusionResu Ok(bytes) } +fn encode_remote_dyn_filter_expr( + expr: &Arc, + max_payload_bytes: usize, + bounds_only: bool, +) -> Result, CommonQueryError> { + let expr = portable_remote_dyn_filter_expr(Arc::clone(expr), bounds_only) + .map_err(CommonQueryError::from)?; + let bytes = encode_physical_expr_to_bytes(&expr).map_err(CommonQueryError::from)?; + validate_payload_size(bytes.len(), max_payload_bytes)?; + Ok(bytes) +} + +fn portable_remote_dyn_filter_expr( + expr: Arc, + bounds_only: bool, +) -> DataFusionResult> { + expr.transform_up(|node| { + if node.as_any().is::() + || (bounds_only && node.as_any().is::()) + { + Ok(Transformed::yes(lit(true))) + } else { + Ok(Transformed::no(node)) + } + }) + .map(|transformed| transformed.data) +} + pub(crate) fn decode_physical_expr_from_bytes( bytes: &[u8], task_ctx: &TaskContext, input_schema: &datafusion::arrow::datatypes::Schema, max_payload_bytes: usize, ) -> DataFusionResult> { - validate_payload_size(bytes.len(), max_payload_bytes)?; + validate_payload_size(bytes.len(), max_payload_bytes).map_err(DataFusionError::from)?; let codec = DefaultPhysicalExtensionCodec {}; let proto = PhysicalExprNode::decode(bytes).map_err(|e| { DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}")) @@ -139,13 +168,14 @@ pub(crate) fn decode_physical_expr_from_bytes( fn validate_payload_size( payload_size_bytes: usize, max_payload_bytes: usize, -) -> DataFusionResult<()> { - if payload_size_bytes > max_payload_bytes { - return Err(DataFusionError::Plan(format!( - "DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes", - payload_size_bytes, max_payload_bytes - ))); - } +) -> Result<(), CommonQueryError> { + ensure!( + payload_size_bytes <= max_payload_bytes, + DynFilterPayloadTooLargeSnafu { + payload_size_bytes, + max_payload_bytes, + } + ); Ok(()) } @@ -263,7 +293,11 @@ mod tests { use base64::Engine; use base64::prelude::BASE64_STANDARD; use datafusion::arrow::datatypes::{DataType, Field, Schema}; - use datafusion::physical_expr::expressions::Column; + use datafusion::physical_expr::expressions::{BinaryExpr, Column, InListExpr, lit}; + use datafusion::physical_plan::expressions::col; + use datafusion::physical_plan::joins::join_hash_map::JoinHashMapU32; + use datafusion::physical_plan::joins::{HashTableLookupExpr, Map, SeededRandomState}; + use datafusion_expr::Operator; use super::*; @@ -407,12 +441,114 @@ mod tests { assert!(matches!(err, DataFusionError::Plan(_))); } + #[test] + fn dyn_filter_payload_hash_lookup_fallback_preserves_bounds() { + let schema = Arc::new(Schema::new(vec![Field::new( + "device_id", + DataType::Int32, + false, + )])); + let device_id = col("device_id", &schema).unwrap(); + let lower_bound = Arc::new(BinaryExpr::new( + Arc::clone(&device_id), + Operator::GtEq, + lit(10i32), + )) as Arc; + let lookup = Arc::new(HashTableLookupExpr::new( + vec![Arc::clone(&device_id)], + SeededRandomState::with_seeds(0, 0, 0, 0), + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(0)))), + "hash_lookup".to_string(), + )) as Arc; + let expr = + Arc::new(BinaryExpr::new(lower_bound, Operator::And, lookup)) as Arc; + + let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap(); + let decoded = payload + .decode_datafusion_expr(&TaskContext::default(), &schema, 1024) + .unwrap(); + + assert!(!contains_expr::(&decoded)); + let decoded_display = decoded.to_string(); + assert!(decoded_display.contains("device_id")); + assert!(decoded_display.contains(">=")); + assert!(!decoded_display.contains("hash_lookup")); + } + + #[test] + fn dyn_filter_payload_oversized_inlist_falls_back_to_bounds() { + let schema = Arc::new(Schema::new(vec![Field::new( + "device_id", + DataType::Int32, + false, + )])); + let device_id = col("device_id", &schema).unwrap(); + let lower_bound = Arc::new(BinaryExpr::new( + Arc::clone(&device_id), + Operator::GtEq, + lit(8192i32), + )) as Arc; + let upper_bound = Arc::new(BinaryExpr::new( + Arc::clone(&device_id), + Operator::LtEq, + lit(8255i32), + )) as Arc; + let bounds = Arc::new(BinaryExpr::new(lower_bound, Operator::And, upper_bound)) + as Arc; + let in_list = Arc::new( + InListExpr::try_new( + Arc::clone(&device_id), + (8192..8256).map(lit).collect(), + false, + &schema, + ) + .unwrap(), + ) as Arc; + let expr = Arc::new(BinaryExpr::new(Arc::clone(&bounds), Operator::And, in_list)) + as Arc; + let bounds_only = portable_remote_dyn_filter_expr(Arc::clone(&expr), true).unwrap(); + let bounds_only_size = encode_physical_expr_to_bytes(&bounds_only).unwrap().len(); + let full_size = encode_physical_expr_to_bytes(&expr).unwrap().len(); + assert!(full_size > bounds_only_size); + + let payload = DynFilterPayload::from_datafusion_expr(&expr, bounds_only_size).unwrap(); + let decoded = payload + .decode_datafusion_expr(&TaskContext::default(), &schema, bounds_only_size) + .unwrap(); + + assert!(!contains_expr::(&decoded)); + let decoded_display = decoded.to_string(); + assert!(decoded_display.contains("device_id")); + assert!(decoded_display.contains(">=")); + assert!(decoded_display.contains("<=")); + } + #[test] fn dyn_filter_payload_rejects_oversized_payload() { let expr: Arc = Arc::new(Column::new("host", 0)); let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err(); - assert!(matches!(err, DataFusionError::Plan(_))); + let DataFusionError::External(error) = err else { + panic!("expected external common query error, got: {err:?}"); + }; + assert!(matches!( + error.downcast_ref::(), + Some(CommonQueryError::DynFilterPayloadTooLarge { .. }) + )); + } + + fn contains_expr(expr: &Arc) -> bool { + let mut found = false; + expr.apply(|node| { + if node.as_any().is::() { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .unwrap(); + found } } diff --git a/src/query/src/dist_plan/merge_scan.rs b/src/query/src/dist_plan/merge_scan.rs index 2201d0c52f..951bc7319a 100644 --- a/src/query/src/dist_plan/merge_scan.rs +++ b/src/query/src/dist_plan/merge_scan.rs @@ -85,6 +85,23 @@ fn acquire_remote_dyn_filter_registry_lease( ) } +fn query_context_for_remote_dyn_filter_region( + query_ctx: &QueryContextRef, + region_id: RegionId, + remote_dyn_filter_registry_lease: Option<&RemoteDynFilterRegistryLease>, + captured_dyn_filters: &[CapturedDynFilter], +) -> session::context::QueryContext { + if let Some(remote_dyn_filter_registry_lease) = remote_dyn_filter_registry_lease { + register_dyn_filters_for_region( + remote_dyn_filter_registry_lease.registry(), + region_id, + captured_dyn_filters, + ); + } + + query_context_with_initial_dyn_filter_regs(query_ctx, region_id, captured_dyn_filters) +} + #[derive(Debug, Hash, PartialOrd, PartialEq, Eq, Clone)] pub struct MergeScanLogicalPlan { /// In logical plan phase it only contains one input @@ -346,25 +363,16 @@ impl MergeScanExec { .step_by(target_partition) .copied() { - if let Some(remote_dyn_filter_registry_lease) = - remote_dyn_filter_registry_lease.as_ref() - { - register_dyn_filters_for_region( - remote_dyn_filter_registry_lease.registry(), - region_id, - &captured_remote_dyn_filters, - ); - } - let region_span = tracing_context.attach(tracing::info_span!( parent: &Span::current(), "merge_scan_region", region_id = %region_id, partition = partition )); - let region_query_ctx = query_context_with_initial_dyn_filter_regs( + let region_query_ctx = query_context_for_remote_dyn_filter_region( &query_ctx, region_id, + remote_dyn_filter_registry_lease.as_ref(), &captured_remote_dyn_filters, ); let request = QueryRequest { @@ -397,6 +405,13 @@ impl MergeScanExec { })?; let do_get_cost = do_get_start.elapsed(); + if let Some(remote_dyn_filter_registry_lease) = + remote_dyn_filter_registry_lease.as_ref() + { + remote_dyn_filter_registry_lease + .ensure_fanout_task(region_query_handler.clone()); + } + ready_timer.stop(); let mut poll_duration = Duration::ZERO; @@ -869,6 +884,7 @@ mod tests { use std::collections::BTreeSet; use async_trait::async_trait; + use common_query::request::INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY; use datafusion::config::ConfigOptions; use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::filter_pushdown::ChildFilterPushdownResult; @@ -885,13 +901,54 @@ mod tests { use uuid::Uuid; use super::*; - use crate::dist_plan::DynFilterRegistryManager; + use crate::dist_plan::{DynFilterRegistryManager, Subscriber}; use crate::region_query::RegionQueryHandler; fn test_query_id(value: u128) -> QueryId { QueryId::from(Uuid::from_u128(value)) } + #[test] + fn remote_dyn_filter_region_query_context_registers_before_do_get() { + let registry_manager = Arc::new(DynFilterRegistryManager::default()); + let query_ctx = QueryContext::arc(); + let query_id = query_ctx + .remote_query_id_value() + .expect("query context must have remote query id"); + let lease = registry_manager.acquire_lease(query_id); + let region_id = RegionId::new(1024, 7); + let dyn_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("host", 0)) as Arc<_>], + physical_lit(true) as _, + )) as Arc; + let captured = capture_remote_dyn_filters_for_pushdown( + RemoteDynFilterProducerId::new(42), + vec![dyn_filter], + ); + assert_eq!(captured.captured_dyn_filters.len(), 1); + + let region_query_ctx = query_context_for_remote_dyn_filter_region( + &query_ctx, + region_id, + Some(&lease), + &captured.captured_dyn_filters, + ); + + let entries = lease.registry().entries(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].subscribers(), vec![Subscriber::new(region_id)]); + assert!( + !entries[0].fanout_started_for_test(), + "fanout must start only after do_get succeeds" + ); + assert!( + region_query_ctx + .extension(INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY) + .is_some(), + "initial RDF registrations must be present in the do_get query context" + ); + } + #[test] fn remote_dyn_filter_registry_cleanup_waits_for_last_query_scoped_stream_drop() { let registry_manager = Arc::new(DynFilterRegistryManager::default()); diff --git a/src/query/src/dist_plan/remote_dyn_filter_registry.rs b/src/query/src/dist_plan/remote_dyn_filter_registry.rs index 4c33ddd960..3d0bda9132 100644 --- a/src/query/src/dist_plan/remote_dyn_filter_registry.rs +++ b/src/query/src/dist_plan/remote_dyn_filter_registry.rs @@ -13,15 +13,29 @@ // limitations under the License. use std::collections::{HashMap, HashSet}; -use std::sync::{Arc, RwLock, Weak}; +use std::future::Future; +use std::sync::{Arc, Mutex, RwLock, Weak}; +use std::time::Duration; +use api::v1::region::{RemoteDynFilterUnregister, RemoteDynFilterUpdate}; +use common_query::request::DynFilterPayload; +use common_runtime::spawn_global; +use common_telemetry::{debug, warn}; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr; use session::query_id::QueryId; use store_api::storage::RegionId; +use tokio::sync::{Notify, watch}; use crate::dist_plan::FilterId; +use crate::region_query::RegionQueryHandlerRef; -/// Routing metadata for a remote dynamic filter subscriber. +const REMOTE_DYN_FILTER_UPDATE_PAYLOAD_MAX_BYTES: usize = 64 * 1024; +const REMOTE_DYN_FILTER_RECONCILE_INTERVAL: Duration = Duration::from_secs(1); +/// Bound best-effort RDF control RPCs so one bad subscriber cannot stall fanout. +const REMOTE_DYN_FILTER_CONTROL_RPC_TIMEOUT: Duration = Duration::from_secs(10); + +/// Region subscribed to a remote dynamic filter. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Subscriber { region_id: RegionId, @@ -53,15 +67,21 @@ pub enum SubscriberRegistration { MissingFilter, } -/// A registered query-local remote dynamic filter entry. -/// -/// The frontend query owns the strong DataFusion filter handle until the query finishes; the -/// registry only keeps a weak reference for later updates. +/// A registered query-local producer filter and its region subscribers. #[derive(Debug)] pub struct DynFilterEntry { filter_id: FilterId, - alive_dyn_filter: Weak, + producer_filter: Weak, subscribers: RwLock>, + state: Mutex, + subscriber_changed: Notify, +} + +#[derive(Debug, Default)] +struct DynFilterEntryState { + last_sent_generation: u64, + unregistered: bool, + fanout_started: bool, } #[derive(Debug)] @@ -70,11 +90,13 @@ struct QueryDynFilterRegistryInner { } impl DynFilterEntry { - pub fn new(filter_id: FilterId, alive_dyn_filter: Arc) -> Self { + pub fn new(filter_id: FilterId, producer_filter: Arc) -> Self { Self { filter_id, - alive_dyn_filter: Arc::downgrade(&alive_dyn_filter), + producer_filter: Arc::downgrade(&producer_filter), subscribers: RwLock::new(HashSet::new()), + state: Mutex::new(DynFilterEntryState::default()), + subscriber_changed: Notify::new(), } } @@ -82,8 +104,8 @@ impl DynFilterEntry { &self.filter_id } - pub fn upgrade_alive_dyn_filter(&self) -> Option> { - self.alive_dyn_filter.upgrade() + pub fn upgrade_producer_filter(&self) -> Option> { + self.producer_filter.upgrade() } pub fn subscribers(&self) -> Vec { @@ -94,19 +116,68 @@ impl DynFilterEntry { let mut subscribers = self.subscribers.write().unwrap(); subscribers.insert(subscriber) } + + fn mark_generation_sent(&self, generation: u64) -> bool { + let mut state = self.state.lock().unwrap(); + if generation <= state.last_sent_generation { + return false; + } + + state.last_sent_generation = generation; + true + } + + fn try_mark_unregistered(&self) -> bool { + let mut state = self.state.lock().unwrap(); + if state.unregistered { + return false; + } + + state.unregistered = true; + true + } + + fn reactivate_for_new_subscriber(&self) { + { + let mut state = self.state.lock().unwrap(); + // Reset generation/unregister state so late subscribers get the current snapshot. + state.last_sent_generation = 0; + state.unregistered = false; + } + self.subscriber_changed.notify_one(); + } + + fn mark_fanout_started(&self) -> bool { + let mut state = self.state.lock().unwrap(); + if state.fanout_started { + return false; + } + + state.fanout_started = true; + true + } + + #[cfg(test)] + pub(crate) fn fanout_started_for_test(&self) -> bool { + self.state.lock().unwrap().fanout_started + } } /// Query-scoped registry that owns all remote dynamic filters for one query. #[derive(Debug)] pub struct QueryDynFilterRegistry { query_id: QueryId, + lifecycle_tx: watch::Sender<()>, inner: RwLock, } impl QueryDynFilterRegistry { pub fn new(query_id: QueryId) -> Self { + // Close-only lifecycle signal; dropping the registry closes it for watchers. + let (lifecycle_tx, _) = watch::channel(()); Self { query_id, + lifecycle_tx, inner: RwLock::new(QueryDynFilterRegistryInner { entries: HashMap::new(), }), @@ -138,14 +209,14 @@ impl QueryDynFilterRegistry { pub fn register_remote_dyn_filter( &self, filter_id: FilterId, - alive_dyn_filter: Arc, + producer_filter: Arc, ) -> EntryRegistration { let mut inner = self.inner.write().unwrap(); if let Some(existing) = inner.entries.get(&filter_id) { return EntryRegistration::Existing(existing.clone()); } - let entry = Arc::new(DynFilterEntry::new(filter_id.clone(), alive_dyn_filter)); + let entry = Arc::new(DynFilterEntry::new(filter_id.clone(), producer_filter)); inner.entries.insert(filter_id, entry.clone()); EntryRegistration::Inserted(entry) } @@ -160,16 +231,380 @@ impl QueryDynFilterRegistry { }; if entry.register_subscriber(subscriber) { + // New subscribers need the current snapshot; existing subscribers may see a duplicate. + entry.reactivate_for_new_subscriber(); SubscriberRegistration::Added } else { SubscriberRegistration::Duplicate } } + + /// Starts missing producer fanout watchers for the registry's entries. + /// + /// Watchers do not hold the registry alive; dropping the registry closes their lifecycle channel. + pub fn ensure_fanout_task(self: &Arc, region_query_handler: RegionQueryHandlerRef) { + for entry in self.entries() { + ensure_entry_fanout_task( + self.query_id, + entry, + region_query_handler.clone(), + self.lifecycle_tx.subscribe(), + ); + } + } + + #[cfg(test)] + async fn fanout_snapshot( + &self, + region_query_handler: &RegionQueryHandlerRef, + entry: &DynFilterEntry, + filter: &DynamicFilterPhysicalExpr, + is_complete: bool, + ) { + let mut lifecycle_rx = self.lifecycle_tx.subscribe(); + fanout_snapshot_for_query( + self.query_id, + region_query_handler, + entry, + filter, + is_complete, + &mut lifecycle_rx, + REMOTE_DYN_FILTER_CONTROL_RPC_TIMEOUT, + ) + .await; + } + + #[cfg(test)] + async fn unregister_all_once(&self, region_query_handler: &RegionQueryHandlerRef) { + for entry in self.entries() { + unregister_entry_once_for_query(region_query_handler, self.query_id, &entry).await; + } + } +} + +fn ensure_entry_fanout_task( + query_id: QueryId, + entry: Arc, + region_query_handler: RegionQueryHandlerRef, + lifecycle_rx: watch::Receiver<()>, +) { + if !entry.mark_fanout_started() { + return; + } + + let _handle = spawn_global(async move { + run_entry_fanout(query_id, entry, region_query_handler, lifecycle_rx).await; + }); +} + +async fn run_entry_fanout( + query_id: QueryId, + entry: Arc, + region_query_handler: RegionQueryHandlerRef, + mut lifecycle_rx: watch::Receiver<()>, +) { + let mut is_complete = false; + // Start reconcile after one interval and skip missed ticks; it is only a coalescing fallback. + let mut reconcile_interval = tokio::time::interval_at( + tokio::time::Instant::now() + REMOTE_DYN_FILTER_RECONCILE_INTERVAL, + REMOTE_DYN_FILTER_RECONCILE_INTERVAL, + ); + reconcile_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + let Some(filter) = entry.upgrade_producer_filter() else { + unregister_entry_once_for_query(®ion_query_handler, query_id, &entry).await; + return; + }; + + if !fanout_snapshot_for_query( + query_id, + ®ion_query_handler, + &entry, + &filter, + is_complete, + &mut lifecycle_rx, + REMOTE_DYN_FILTER_CONTROL_RPC_TIMEOUT, + ) + .await + { + break; + } + + if is_complete { + tokio::select! { + _ = entry.subscriber_changed.notified() => {} + result = lifecycle_rx.changed() => { + if result.is_err() { + break; + } + } + } + continue; + } + + tokio::select! { + _ = filter.wait_update() => {} + _ = filter.wait_complete() => { + is_complete = true; + } + // `wait_update()` can miss an update sent while an RPC is in-flight. + // Re-read periodically to coalesce to the latest generation. + _ = reconcile_interval.tick() => {} + _ = entry.subscriber_changed.notified() => {} + result = lifecycle_rx.changed() => { + if result.is_err() { + break; + } + } + } + } + + unregister_entry_once_for_query(®ion_query_handler, query_id, &entry).await; +} + +async fn fanout_snapshot_for_query( + query_id: QueryId, + region_query_handler: &RegionQueryHandlerRef, + entry: &DynFilterEntry, + filter: &DynamicFilterPhysicalExpr, + is_complete: bool, + lifecycle_rx: &mut watch::Receiver<()>, + control_rpc_timeout: Duration, +) -> bool { + let Some((generation, current)) = current_stable_snapshot(filter, lifecycle_rx).await else { + return true; + }; + + // The entry-global watermark advances before best-effort fanout. A timed-out + // subscriber may miss this generation; later/complete snapshots supersede it, + // and RDF only prunes. + if !is_complete && !entry.mark_generation_sent(generation) { + return true; + } + + if is_complete { + let _ = entry.mark_generation_sent(generation); + } + + let payload = match DynFilterPayload::from_datafusion_expr( + ¤t, + REMOTE_DYN_FILTER_UPDATE_PAYLOAD_MAX_BYTES, + ) { + Ok(DynFilterPayload::Datafusion(payload)) => payload, + Ok(_) => { + warn!("Ignored unsupported remote dynamic filter producer payload"); + return true; + } + Err(error) => { + warn!(error; "Failed to encode remote dynamic filter producer snapshot"); + return true; + } + }; + + fanout_update_for_query( + query_id, + region_query_handler, + entry, + generation, + is_complete, + payload, + lifecycle_rx, + control_rpc_timeout, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +async fn fanout_update_for_query( + query_id: QueryId, + region_query_handler: &RegionQueryHandlerRef, + entry: &DynFilterEntry, + generation: u64, + is_complete: bool, + payload: Vec, + lifecycle_rx: &mut watch::Receiver<()>, + control_rpc_timeout: Duration, +) -> bool { + let query_id = query_id.to_string(); + let filter_id = entry.filter_id().to_string(); + + for subscriber in entry.subscribers() { + let update = RemoteDynFilterUpdate { + filter_id: filter_id.clone(), + payload: payload.clone(), + generation, + is_complete, + }; + + match await_control_rpc_or_lifecycle_close( + lifecycle_rx, + format!( + "update query_id={} filter_id={} region_id={}", + query_id, + filter_id, + subscriber.region_id() + ), + region_query_handler.handle_remote_dyn_filter_update( + subscriber.region_id(), + query_id.clone(), + update, + ), + control_rpc_timeout, + ) + .await + { + ControlRpcResult::Ok(result) => { + if let Err(error) = result { + warn!( + error; + "Failed to fan out remote dynamic filter update, query_id={}, filter_id={}, region_id={}", + query_id, + filter_id, + subscriber.region_id() + ); + } + } + ControlRpcResult::TimedOut => {} + ControlRpcResult::LifecycleClosed => return false, + } + } + + true +} + +async fn unregister_entry_once_for_query( + region_query_handler: &RegionQueryHandlerRef, + query_id: QueryId, + entry: &DynFilterEntry, +) { + if !entry.try_mark_unregistered() { + return; + } + + let query_id = query_id.to_string(); + let filter_id = entry.filter_id().to_string(); + + for subscriber in entry.subscribers() { + let unregister = RemoteDynFilterUnregister { + filter_id: filter_id.clone(), + }; + + let Some(result) = await_control_rpc_timeout( + format!( + "unregister query_id={} filter_id={} region_id={}", + query_id, + filter_id, + subscriber.region_id() + ), + region_query_handler.handle_remote_dyn_filter_unregister( + subscriber.region_id(), + query_id.clone(), + unregister, + ), + ) + .await + else { + continue; + }; + + if let Err(error) = result { + warn!( + error; + "Failed to fan out remote dynamic filter unregister, query_id={}, filter_id={}, region_id={}", + query_id, + filter_id, + subscriber.region_id() + ); + } + } + + debug!("Remote dynamic filter producer unregistered subscribers"); +} + +enum ControlRpcResult { + Ok(T), + TimedOut, + LifecycleClosed, +} + +async fn await_control_rpc_or_lifecycle_close( + lifecycle_rx: &mut watch::Receiver<()>, + operation: String, + rpc: impl Future, + control_rpc_timeout: Duration, +) -> ControlRpcResult { + if lifecycle_rx.has_changed().is_err() { + return ControlRpcResult::LifecycleClosed; + } + + tokio::select! { + biased; + result = lifecycle_rx.changed() => { + if result.is_err() { + debug!("Cancelled remote dynamic filter control RPC after lifecycle close"); + } + ControlRpcResult::LifecycleClosed + } + result = rpc => ControlRpcResult::Ok(result), + _ = tokio::time::sleep(control_rpc_timeout) => { + warn!("Timed out remote dynamic filter control RPC: {}", operation); + ControlRpcResult::TimedOut + } + } +} + +async fn await_control_rpc_timeout( + operation: String, + rpc: impl Future, +) -> Option { + tokio::select! { + result = rpc => Some(result), + _ = tokio::time::sleep(REMOTE_DYN_FILTER_CONTROL_RPC_TIMEOUT) => { + warn!("Timed out remote dynamic filter control RPC: {}", operation); + None + } + } +} + +async fn current_stable_snapshot( + filter: &DynamicFilterPhysicalExpr, + lifecycle_rx: &mut watch::Receiver<()>, +) -> Option<(u64, Arc)> { + loop { + if lifecycle_rx.has_changed().is_err() { + return None; + } + + let before = filter.snapshot_generation(); + let current = match filter.current() { + Ok(current) => current, + Err(error) => { + warn!(error; "Failed to read remote dynamic filter producer snapshot"); + return None; + } + }; + let after = filter.snapshot_generation(); + + if before == after { + return Some((after, current)); + } + + tokio::select! { + biased; + result = lifecycle_rx.changed() => { + if result.is_err() { + return None; + } + } + _ = tokio::task::yield_now() => {} + } + } } /// Stream-scoped lease that keeps a query registry alive. /// -/// Production code owns registries through this lease; the manager only keeps a weak index. +/// Stream leases own registry lifecycle; the manager only keeps a weak index. #[derive(Debug)] pub struct RemoteDynFilterRegistryLease { registry_manager: Arc, @@ -196,6 +631,13 @@ impl RemoteDynFilterRegistryLease { .expect("remote dyn filter registry lease must hold a registry") } + pub fn ensure_fanout_task(&self, region_query_handler: RegionQueryHandlerRef) { + self.registry + .as_ref() + .expect("remote dyn filter registry lease must hold a registry") + .ensure_fanout_task(region_query_handler); + } + #[cfg(test)] pub(crate) fn ptr_eq(&self, other: &Self) -> bool { Arc::ptr_eq( @@ -224,7 +666,7 @@ impl Drop for RemoteDynFilterRegistryLease { /// Query-engine manager for query-scoped remote dynamic filter registries. /// -/// Weak index only; active streams own registries through [`RemoteDynFilterRegistryLease`]. +/// Weak index only; stream leases own registries through [`RemoteDynFilterRegistryLease`]. #[derive(Debug, Default)] pub struct DynFilterRegistryManager { registries: RwLock>>, @@ -326,14 +768,147 @@ impl DynFilterRegistryManager { #[cfg(test)] mod tests { - use std::sync::Barrier; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Barrier, Mutex}; use std::thread; + use std::time::Duration; + use api::v1::region::{RemoteDynFilterUnregister, RemoteDynFilterUpdate}; + use async_trait::async_trait; + use common_query::request::QueryRequest; use datafusion_physical_expr::expressions::{Column, lit}; + use session::ReadPreference; use uuid::Uuid; use super::*; use crate::dist_plan::{FilterFingerprint, RemoteDynFilterProducerId}; + use crate::error::Result as QueryResult; + use crate::region_query::RegionQueryHandler; + + #[derive(Debug, Clone, PartialEq, Eq)] + struct RecordedUpdate { + region_id: RegionId, + query_id: String, + filter_id: String, + generation: u64, + is_complete: bool, + payload: Vec, + } + + #[derive(Debug, Clone, PartialEq, Eq)] + struct RecordedUnregister { + region_id: RegionId, + query_id: String, + filter_id: String, + } + + #[derive(Default)] + struct RecordingRegionQueryHandler { + updates: Mutex>, + unregisters: Mutex>, + block_next_update: AtomicBool, + update_blocked: Notify, + release_update: Notify, + } + + impl RecordingRegionQueryHandler { + fn updates(&self) -> Vec { + self.updates.lock().unwrap().clone() + } + + fn unregisters(&self) -> Vec { + self.unregisters.lock().unwrap().clone() + } + + fn block_next_update(&self) { + self.block_next_update.store(true, Ordering::SeqCst); + } + + async fn wait_for_blocked_update(&self) { + self.update_blocked.notified().await; + } + + fn release_blocked_update(&self) { + self.release_update.notify_one(); + } + + async fn wait_for_update_count(&self, expected: usize) { + for _ in 0..300 { + if self.updates().len() >= expected { + return; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + panic!("timed out waiting for {expected} remote dyn filter updates"); + } + + async fn wait_for_unregister_count(&self, expected: usize) { + for _ in 0..300 { + if self.unregisters().len() >= expected { + return; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + panic!("timed out waiting for {expected} remote dyn filter unregisters"); + } + } + + async fn wait_for_registry_drop(registry: Weak) { + for _ in 0..300 { + if registry.upgrade().is_none() { + return; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + panic!("timed out waiting for remote dyn filter registry drop"); + } + + #[async_trait] + impl RegionQueryHandler for RecordingRegionQueryHandler { + async fn do_get( + &self, + _read_preference: ReadPreference, + _request: QueryRequest, + ) -> QueryResult { + unreachable!("remote dyn filter registry tests should not execute remote queries") + } + + async fn handle_remote_dyn_filter_update( + &self, + region_id: RegionId, + query_id: String, + update: RemoteDynFilterUpdate, + ) -> QueryResult<()> { + let should_block = self.block_next_update.swap(false, Ordering::SeqCst); + self.updates.lock().unwrap().push(RecordedUpdate { + region_id, + query_id, + filter_id: update.filter_id, + generation: update.generation, + is_complete: update.is_complete, + payload: update.payload, + }); + if should_block { + self.update_blocked.notify_one(); + self.release_update.notified().await; + } + Ok(()) + } + + async fn handle_remote_dyn_filter_unregister( + &self, + region_id: RegionId, + query_id: String, + unregister: RemoteDynFilterUnregister, + ) -> QueryResult<()> { + self.unregisters.lock().unwrap().push(RecordedUnregister { + region_id, + query_id, + filter_id: unregister.filter_id, + }); + Ok(()) + } + } fn test_query_id(value: u128) -> QueryId { QueryId::from(Uuid::from_u128(value)) @@ -575,4 +1150,436 @@ mod tests { ); assert_eq!(entry.subscribers().len(), 1); } + + #[tokio::test] + async fn fanout_sends_changed_generations_to_subscribers() { + let query_id = test_query_id(1); + let registry = Arc::new(QueryDynFilterRegistry::new(query_id)); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let entry = match registry.register_remote_dyn_filter(filter_id.clone(), filter.clone()) { + EntryRegistration::Inserted(entry) => entry, + other => panic!("unexpected registration result: {other:?}"), + }; + let subscriber = Subscriber::new(RegionId::new(1024, 7)); + assert_eq!( + registry.register_subscriber(&filter_id, subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + let handler_ref = handler.clone() as RegionQueryHandlerRef; + + registry + .fanout_snapshot(&handler_ref, &entry, filter.as_ref(), false) + .await; + let updates = handler.updates(); + assert_eq!(updates.len(), 1); + assert_eq!(updates[0].region_id, subscriber.region_id()); + assert_eq!(updates[0].query_id, query_id.to_string()); + assert_eq!(updates[0].filter_id, filter_id.to_string()); + assert_eq!(updates[0].generation, filter.snapshot_generation()); + assert!(!updates[0].is_complete); + assert!(!updates[0].payload.is_empty()); + + registry + .fanout_snapshot(&handler_ref, &entry, filter.as_ref(), false) + .await; + assert_eq!(handler.updates().len(), 1); + + filter.update(lit(false) as _).unwrap(); + registry + .fanout_snapshot(&handler_ref, &entry, filter.as_ref(), false) + .await; + let updates = handler.updates(); + assert_eq!(updates.len(), 2); + assert_eq!(updates[1].generation, filter.snapshot_generation()); + + let second_subscriber = Subscriber::new(RegionId::new(1024, 8)); + assert_eq!( + registry.register_subscriber(&filter_id, second_subscriber.clone()), + SubscriberRegistration::Added + ); + registry + .fanout_snapshot(&handler_ref, &entry, filter.as_ref(), false) + .await; + let updates = handler.updates(); + assert_eq!(updates.len(), 4); + assert!( + updates[2..] + .iter() + .any(|update| update.region_id == subscriber.region_id()) + ); + assert!( + updates[2..] + .iter() + .any(|update| update.region_id == second_subscriber.region_id()) + ); + assert_eq!(entry.subscribers().len(), 2); + } + + #[tokio::test] + async fn fanout_task_waits_for_dynamic_filter_notifications() { + let query_id = test_query_id(3); + let manager = Arc::new(DynFilterRegistryManager::default()); + let lease = manager.acquire_lease(query_id); + let registry_weak = Arc::downgrade(lease.registry.as_ref().unwrap()); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let _ = lease + .registry() + .register_remote_dyn_filter(filter_id.clone(), filter.clone()); + let subscriber = Subscriber::new(RegionId::new(1024, 7)); + assert_eq!( + lease + .registry() + .register_subscriber(&filter_id, subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + lease.ensure_fanout_task(handler.clone() as RegionQueryHandlerRef); + + handler.wait_for_update_count(1).await; + let initial_generation = handler.updates()[0].generation; + + filter.update(lit(false) as _).unwrap(); + handler.wait_for_update_count(2).await; + let updates = handler.updates(); + assert!(updates[1].generation > initial_generation); + assert_eq!(updates[1].region_id, subscriber.region_id()); + assert_eq!(updates[1].filter_id, filter_id.to_string()); + + filter.mark_complete(); + handler.wait_for_update_count(3).await; + let updates = handler.updates(); + assert!(updates[2].is_complete); + + drop(lease); + handler.wait_for_unregister_count(1).await; + let unregisters = handler.unregisters(); + assert_eq!(unregisters[0].region_id, subscriber.region_id()); + assert_eq!(unregisters[0].filter_id, filter_id.to_string()); + + wait_for_registry_drop(registry_weak).await; + } + + #[tokio::test] + async fn repeated_ensure_fanout_task_keeps_single_watcher() { + let query_id = test_query_id(6); + let manager = Arc::new(DynFilterRegistryManager::default()); + let lease = manager.acquire_lease(query_id); + let registry_weak = Arc::downgrade(lease.registry.as_ref().unwrap()); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let entry = match lease + .registry() + .register_remote_dyn_filter(filter_id.clone(), filter.clone()) + { + EntryRegistration::Inserted(entry) => entry, + other => panic!("unexpected registration result: {other:?}"), + }; + let subscriber = Subscriber::new(RegionId::new(1024, 7)); + assert_eq!( + lease + .registry() + .register_subscriber(&filter_id, subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + lease.ensure_fanout_task(handler.clone() as RegionQueryHandlerRef); + lease.ensure_fanout_task(handler.clone() as RegionQueryHandlerRef); + + assert!(entry.fanout_started_for_test()); + handler.wait_for_update_count(1).await; + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!(handler.updates().len(), 1); + + filter.update(lit(false) as _).unwrap(); + handler.wait_for_update_count(2).await; + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!(handler.updates().len(), 2); + + drop(lease); + handler.wait_for_unregister_count(1).await; + wait_for_registry_drop(registry_weak).await; + } + + #[tokio::test] + async fn fanout_task_resends_complete_snapshot_to_late_subscriber() { + let query_id = test_query_id(7); + let manager = Arc::new(DynFilterRegistryManager::default()); + let lease = manager.acquire_lease(query_id); + let registry_weak = Arc::downgrade(lease.registry.as_ref().unwrap()); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let _ = lease + .registry() + .register_remote_dyn_filter(filter_id.clone(), filter.clone()); + let first_subscriber = Subscriber::new(RegionId::new(1024, 7)); + assert_eq!( + lease + .registry() + .register_subscriber(&filter_id, first_subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + lease.ensure_fanout_task(handler.clone() as RegionQueryHandlerRef); + handler.wait_for_update_count(1).await; + + filter.mark_complete(); + handler.wait_for_update_count(2).await; + assert!(handler.updates()[1].is_complete); + + let late_subscriber = Subscriber::new(RegionId::new(1024, 8)); + assert_eq!( + lease + .registry() + .register_subscriber(&filter_id, late_subscriber.clone()), + SubscriberRegistration::Added + ); + + handler.wait_for_update_count(4).await; + let updates = handler.updates(); + assert!( + updates[2..].iter().any( + |update| update.region_id == first_subscriber.region_id() && update.is_complete + ) + ); + assert!( + updates[2..] + .iter() + .any(|update| update.region_id == late_subscriber.region_id() + && update.is_complete) + ); + + drop(lease); + handler.wait_for_unregister_count(1).await; + wait_for_registry_drop(registry_weak).await; + } + + #[tokio::test] + async fn fanout_task_unregisters_when_producer_filter_is_dropped() { + let query_id = test_query_id(8); + let manager = Arc::new(DynFilterRegistryManager::default()); + let lease = manager.acquire_lease(query_id); + let registry_weak = Arc::downgrade(lease.registry.as_ref().unwrap()); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let _ = lease + .registry() + .register_remote_dyn_filter(filter_id.clone(), filter.clone()); + let subscriber = Subscriber::new(RegionId::new(1024, 7)); + assert_eq!( + lease + .registry() + .register_subscriber(&filter_id, subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + lease.ensure_fanout_task(handler.clone() as RegionQueryHandlerRef); + handler.wait_for_update_count(1).await; + + drop(filter); + handler.wait_for_unregister_count(1).await; + let unregisters = handler.unregisters(); + assert_eq!(unregisters[0].region_id, subscriber.region_id()); + assert_eq!(unregisters[0].filter_id, filter_id.to_string()); + + drop(lease); + wait_for_registry_drop(registry_weak).await; + } + + #[tokio::test] + async fn reconcile_tick_catches_update_while_fanout_is_in_flight() { + let query_id = test_query_id(4); + let manager = Arc::new(DynFilterRegistryManager::default()); + let lease = manager.acquire_lease(query_id); + let registry_weak = Arc::downgrade(lease.registry.as_ref().unwrap()); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let _ = lease + .registry() + .register_remote_dyn_filter(filter_id.clone(), filter.clone()); + let subscriber = Subscriber::new(RegionId::new(1024, 7)); + assert_eq!( + lease + .registry() + .register_subscriber(&filter_id, subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + handler.block_next_update(); + lease.ensure_fanout_task(handler.clone() as RegionQueryHandlerRef); + + handler.wait_for_blocked_update().await; + let initial_generation = handler.updates()[0].generation; + + // Update before the watcher can subscribe again; reconcile must catch it. + filter.update(lit(false) as _).unwrap(); + handler.release_blocked_update(); + + handler.wait_for_update_count(2).await; + let updates = handler.updates(); + assert!(updates[1].generation > initial_generation); + assert_eq!(updates[1].region_id, subscriber.region_id()); + assert_eq!(updates[1].filter_id, filter_id.to_string()); + + drop(lease); + handler.wait_for_unregister_count(1).await; + wait_for_registry_drop(registry_weak).await; + } + + #[tokio::test] + async fn fanout_task_unregisters_after_lifecycle_close_during_blocked_update() { + let query_id = test_query_id(5); + let manager = Arc::new(DynFilterRegistryManager::default()); + let lease = manager.acquire_lease(query_id); + let registry_weak = Arc::downgrade(lease.registry.as_ref().unwrap()); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let _ = lease + .registry() + .register_remote_dyn_filter(filter_id.clone(), filter.clone()); + let subscriber = Subscriber::new(RegionId::new(1024, 7)); + assert_eq!( + lease + .registry() + .register_subscriber(&filter_id, subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + handler.block_next_update(); + lease.ensure_fanout_task(handler.clone() as RegionQueryHandlerRef); + + handler.wait_for_blocked_update().await; + drop(lease); + + handler.wait_for_unregister_count(1).await; + let unregisters = handler.unregisters(); + assert_eq!(unregisters[0].region_id, subscriber.region_id()); + assert_eq!(unregisters[0].filter_id, filter_id.to_string()); + wait_for_registry_drop(registry_weak).await; + } + + #[tokio::test] + async fn update_timeout_does_not_stop_fanout_for_other_subscribers() { + let query_id = test_query_id(9); + let registry = QueryDynFilterRegistry::new(query_id); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let entry = match registry.register_remote_dyn_filter(filter_id.clone(), filter.clone()) { + EntryRegistration::Inserted(entry) => entry, + other => panic!("unexpected registration result: {other:?}"), + }; + let first_subscriber = Subscriber::new(RegionId::new(1024, 7)); + let second_subscriber = Subscriber::new(RegionId::new(1024, 8)); + assert_eq!( + registry.register_subscriber(&filter_id, first_subscriber.clone()), + SubscriberRegistration::Added + ); + assert_eq!( + registry.register_subscriber(&filter_id, second_subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + let handler_ref = handler.clone() as RegionQueryHandlerRef; + let mut lifecycle_rx = registry.lifecycle_tx.subscribe(); + handler.block_next_update(); + assert!( + fanout_snapshot_for_query( + query_id, + &handler_ref, + &entry, + filter.as_ref(), + false, + &mut lifecycle_rx, + Duration::from_millis(100), + ) + .await + ); + + handler.wait_for_blocked_update().await; + // Fanout is serial and the blocked RPC stays blocked; the second update proves + // timeout continued to the next subscriber. + handler.wait_for_update_count(2).await; + + let initial_updates = handler.updates(); + assert_eq!( + initial_updates.len(), + 2, + "the healthy subscriber must still receive the update after another subscriber times out" + ); + assert!( + initial_updates + .iter() + .any(|update| update.region_id == first_subscriber.region_id()) + ); + assert!( + initial_updates + .iter() + .any(|update| update.region_id == second_subscriber.region_id()) + ); + + filter.update(lit(false) as _).unwrap(); + assert!( + fanout_snapshot_for_query( + query_id, + &handler_ref, + &entry, + filter.as_ref(), + false, + &mut lifecycle_rx, + Duration::from_millis(100), + ) + .await + ); + handler.wait_for_update_count(4).await; + let updates = handler.updates(); + assert!( + updates[2..] + .iter() + .any(|update| update.region_id == first_subscriber.region_id()) + ); + assert!( + updates[2..] + .iter() + .any(|update| update.region_id == second_subscriber.region_id()) + ); + + registry.unregister_all_once(&handler_ref).await; + handler.wait_for_unregister_count(1).await; + } + + #[tokio::test] + async fn unregister_fanout_is_idempotent() { + let query_id = test_query_id(2); + let registry = QueryDynFilterRegistry::new(query_id); + let filter = test_dyn_filter(&["host"]); + let filter_id = test_filter_id(1); + let _ = registry.register_remote_dyn_filter(filter_id.clone(), filter); + let subscriber = Subscriber::new(RegionId::new(1024, 7)); + assert_eq!( + registry.register_subscriber(&filter_id, subscriber.clone()), + SubscriberRegistration::Added + ); + + let handler = Arc::new(RecordingRegionQueryHandler::default()); + let handler_ref = handler.clone() as RegionQueryHandlerRef; + + registry.unregister_all_once(&handler_ref).await; + registry.unregister_all_once(&handler_ref).await; + + let unregisters = handler.unregisters(); + assert_eq!(unregisters.len(), 1); + assert_eq!(unregisters[0].region_id, subscriber.region_id()); + assert_eq!(unregisters[0].query_id, query_id.to_string()); + assert_eq!(unregisters[0].filter_id, filter_id.to_string()); + } }