From 3a0f37c06bf0ede72ffb5756c7e30f14a294c206 Mon Sep 17 00:00:00 2001 From: discord9 Date: Mon, 20 Apr 2026 13:11:42 +0800 Subject: [PATCH] feat: init reg dyn filter --- src/common/query/Cargo.toml | 1 + src/common/query/src/request.rs | 44 ++- .../request/initial_remote_dyn_filter_reg.rs | 154 ++++++++++ src/datanode/src/region_server.rs | 160 +++++++++- .../src/region_server/registrations.rs | 124 ++++++++ src/query/src/dist_plan.rs | 1 + src/query/src/dist_plan/dyn_filter_bridge.rs | 282 ++++++++++++++++++ src/query/src/dist_plan/merge_scan.rs | 165 +--------- 8 files changed, 773 insertions(+), 158 deletions(-) create mode 100644 src/common/query/src/request/initial_remote_dyn_filter_reg.rs create mode 100644 src/datanode/src/region_server/registrations.rs create mode 100644 src/query/src/dist_plan/dyn_filter_bridge.rs diff --git a/src/common/query/Cargo.toml b/src/common/query/Cargo.toml index b28e56fe8e..9706098826 100644 --- a/src/common/query/Cargo.toml +++ b/src/common/query/Cargo.toml @@ -28,6 +28,7 @@ datatypes.workspace = true once_cell.workspace = true prost.workspace = true serde.workspace = true +serde_json.workspace = true snafu.workspace = true sqlparser.workspace = true sqlparser_derive = "0.1" diff --git a/src/common/query/src/request.rs b/src/common/query/src/request.rs index 22437bde71..792c4d0947 100644 --- a/src/common/query/src/request.rs +++ b/src/common/query/src/request.rs @@ -12,26 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod initial_remote_dyn_filter_reg; + use std::sync::Arc; use api::v1::region::RegionRequestHeader; -use datafusion::arrow::datatypes::Schema; use datafusion::execution::TaskContext; use datafusion::physical_expr::expressions::Column; -use datafusion::physical_plan::PhysicalExpr; use datafusion::physical_plan::joins::HashTableLookupExpr; +use datafusion::physical_plan::PhysicalExpr; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_expr::LogicalPlan; -use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use datafusion_proto::physical_plan::from_proto::parse_physical_expr; use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use datafusion_proto::protobuf::PhysicalExprNode; use prost::Message; use serde::{Deserialize, Serialize}; use store_api::storage::RegionId; /// Current wire-format version for remote dynamic filter payload updates. +pub use self::initial_remote_dyn_filter_reg::{ + InitialDynFilterReg, InitialDynFilterRegs, + INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, +}; + pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1; /// Serialized predicate payload for remote dynamic filter updates. @@ -107,7 +113,7 @@ impl DynFilterPayload { pub fn decode_datafusion_expr( &self, task_ctx: &TaskContext, - input_schema: &Schema, + input_schema: &datafusion::arrow::datatypes::Schema, max_payload_bytes: usize, ) -> DataFusionResult> { let Self::Datafusion(bytes) = self; @@ -124,6 +130,34 @@ impl DynFilterPayload { } } +fn encode_physical_expr_to_bytes(expr: &Arc) -> DataFusionResult> { + let codec = DefaultPhysicalExtensionCodec {}; + let proto = serialize_physical_expr(expr, &codec)?; + let mut bytes = Vec::new(); + proto.encode(&mut bytes).map_err(|e| { + DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}")) + })?; + Ok(bytes) +} + +pub(crate) fn decode_physical_expr_from_bytes( + bytes: &[u8], + task_ctx: &TaskContext, + input_schema: &datafusion::arrow::datatypes::Schema, + max_payload_bytes: usize, +) -> DataFusionResult> { + validate_payload_size(bytes.len(), max_payload_bytes)?; + let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalExprNode::decode(bytes).map_err(|e| { + DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}")) + })?; + + let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?; + validate_supported_payload_expr(&expr)?; + validate_decoded_payload_expr(&expr, input_schema)?; + Ok(expr) +} + fn validate_payload_size( payload_size_bytes: usize, max_payload_bytes: usize, @@ -161,7 +195,7 @@ fn validate_supported_payload_expr(expr: &Arc) -> DataFusionRe /// schema inconsistency that should be surfaced loudly. fn validate_decoded_payload_expr( expr: &Arc, - input_schema: &Schema, + input_schema: &datafusion::arrow::datatypes::Schema, ) -> DataFusionResult<()> { expr.apply(|node| { if let Some(column) = node.as_any().downcast_ref::() { diff --git a/src/common/query/src/request/initial_remote_dyn_filter_reg.rs b/src/common/query/src/request/initial_remote_dyn_filter_reg.rs new file mode 100644 index 0000000000..8c71289377 --- /dev/null +++ b/src/common/query/src/request/initial_remote_dyn_filter_reg.rs @@ -0,0 +1,154 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use datafusion::arrow::datatypes::Schema; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::PhysicalExpr; +use datafusion_common::Result as DataFusionResult; +use serde::{Deserialize, Serialize}; + +use crate::request::{decode_physical_expr_from_bytes, encode_physical_expr_to_bytes}; + +pub const INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY: &str = + "initial_remote_dyn_filter_registrations"; + +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct InitialDynFilterRegs { + #[serde(rename = "registrations")] + pub regs: Vec, +} + +impl InitialDynFilterRegs { + pub fn new(regs: Vec) -> Self { + Self { regs } + } + + pub fn is_empty(&self) -> bool { + self.regs.is_empty() + } + + pub fn to_extension_value(&self) -> serde_json::Result { + serde_json::to_string(self) + } + + pub fn from_extension_value(value: &str) -> serde_json::Result { + serde_json::from_str(value) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct InitialDynFilterReg { + pub filter_id: String, + pub child_exprs_datafusion_proto: Vec>, +} + +impl InitialDynFilterReg { + pub fn new(filter_id: impl Into, child_exprs_datafusion_proto: Vec>) -> Self { + Self { + filter_id: filter_id.into(), + child_exprs_datafusion_proto, + } + } + + pub fn from_filter_id_and_children( + filter_id: impl Into, + children: &[Arc], + ) -> DataFusionResult { + let child_exprs_datafusion_proto = children + .iter() + .map(encode_physical_expr_to_bytes) + .collect::>>()?; + + Ok(Self::new(filter_id, child_exprs_datafusion_proto)) + } + + pub fn decode_children( + &self, + task_ctx: &TaskContext, + input_schema: &Schema, + max_payload_bytes: usize, + ) -> DataFusionResult>> { + self.child_exprs_datafusion_proto + .iter() + .map(|expr_bytes| { + decode_physical_expr_from_bytes( + expr_bytes, + task_ctx, + input_schema, + max_payload_bytes, + ) + }) + .collect::>>() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::physical_expr::expressions::Column; + use datafusion::physical_plan::PhysicalExpr; + use datafusion_common::DataFusionError; + + use super::*; + + #[test] + fn initial_dyn_filter_regs_json_round_trip() { + let regs = InitialDynFilterRegs::new(vec![ + InitialDynFilterReg::new("filter-a", vec![vec![1, 2, 3]]), + InitialDynFilterReg::new("filter-b", vec![vec![4, 5]]), + ]); + + let encoded = regs.to_extension_value().unwrap(); + let decoded = InitialDynFilterRegs::from_extension_value(&encoded).unwrap(); + + assert_eq!(decoded, regs); + } + + #[test] + fn initial_dyn_filter_reg_round_trips_child_exprs() { + let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]); + let child: Arc = + Arc::new(Column::new_with_schema("host", &schema).unwrap()); + let reg = InitialDynFilterReg::from_filter_id_and_children("filter-1", &[child]).unwrap(); + + let decoded = reg + .decode_children(&TaskContext::default(), &schema, 1024) + .unwrap(); + let decoded = decoded[0].as_any().downcast_ref::().unwrap(); + + assert_eq!(reg.filter_id, "filter-1"); + assert_eq!(decoded.name(), "host"); + assert_eq!(decoded.index(), 0); + } + + #[test] + fn initial_dyn_filter_reg_decode_rejects_column_name_index_mismatch() { + let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]); + let reg = InitialDynFilterReg::from_filter_id_and_children( + "filter-1", + &[Arc::new(Column::new("service", 0)) as Arc], + ) + .unwrap(); + + let err = reg + .decode_children(&TaskContext::default(), &schema, 1024) + .unwrap_err(); + + assert!(matches!(err, DataFusionError::Plan(_))); + } +} diff --git a/src/datanode/src/region_server.rs b/src/datanode/src/region_server.rs index aa2e627ca2..2ae09356fb 100644 --- a/src/datanode/src/region_server.rs +++ b/src/datanode/src/region_server.rs @@ -13,6 +13,7 @@ // limitations under the License. mod catalog; +mod registrations; use std::collections::HashMap; use std::fmt::Debug; @@ -97,6 +98,10 @@ use crate::error::{ }; use crate::event_listener::RegionServerEventListenerRef; use crate::region_server::catalog::{NameAwareCatalogList, NameAwareDataSourceInjectorBuilder}; +use crate::region_server::registrations::{ + RegisteredDynFilter, initial_dyn_filter_regs_from_query_ctx, + register_initial_dyn_filter_regs, remove_initial_dyn_filter_regs_for_region, +}; #[derive(Clone)] pub struct RegionServer { @@ -274,6 +279,18 @@ impl RegionServer { common_telemetry::info!("Handle remote read for region: {}", region_id); } + let initial_dyn_filter_regs = initial_dyn_filter_regs_from_query_ctx(&query_ctx); + if query_ctx.explain_verbose() { + common_telemetry::info!( + "Initial remote dyn filter registrations for region {}: {}", + region_id, + initial_dyn_filter_regs + .as_ref() + .map(|regs| regs.regs.len()) + .unwrap_or(0) + ); + } + let decoder = self .inner .query_engine @@ -286,7 +303,20 @@ impl RegionServer { .await .context(DecodeLogicalPlanSnafu)?; - let stream = self + let query_id = query_ctx.remote_query_id().map(ToOwned::to_owned); + if let (Some(query_id), Some(regs)) = ( + query_id.as_deref(), + initial_dyn_filter_regs.as_ref(), + ) { + register_initial_dyn_filter_regs( + &self.inner.initial_remote_dyn_filter_registrations, + query_id, + region_id, + regs, + ); + } + + let result = self .inner .handle_read( QueryRequest { @@ -296,8 +326,17 @@ impl RegionServer { }, query_ctx.clone(), ) - .await?; + .await; + if result.is_err() && let Some(query_id) = query_id.as_deref() { + remove_initial_dyn_filter_regs_for_region( + &self.inner.initial_remote_dyn_filter_registrations, + query_id, + region_id, + ); + } + + let stream = result?; Ok(wrap_flow_region_watermark_stream( stream, region_id, &query_ctx, )) @@ -1056,6 +1095,10 @@ struct RegionServerInner { /// server with a concrete engine; acceptable for now to fetch Mito-specific /// info (e.g., list SSTs). Consider a diagnostics trait later. mito_engine: RwLock>, + /// TODO(remote-dyn-filter): Reap this query-scoped placeholder registry on query finish/cancel + /// and later fold it into the real remote dyn filter runtime state lifecycle. + initial_remote_dyn_filter_registrations: + DashMap>, } struct RegionServerParallelism { @@ -1123,6 +1166,7 @@ impl RegionServerInner { parallelism, topic_stats_reporter: RwLock::new(None), mito_engine: RwLock::new(None), + initial_remote_dyn_filter_registrations: DashMap::new(), } } @@ -1857,6 +1901,10 @@ mod tests { RemoteDynFilterRequest, RemoteDynFilterUnregister, RemoteDynFilterUpdate, remote_dyn_filter_request, }; + use common_query::request::{ + INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, InitialDynFilterReg, + InitialDynFilterRegs, + }; use common_error::ext::ErrorExt; use common_recordbatch::RecordBatches; use common_recordbatch::adapter::{RecordBatchMetrics, RegionWatermarkEntry}; @@ -2015,6 +2063,114 @@ mod tests { assert!(pinned.as_ref().get_ref().metrics().is_none()); } + #[test] + fn initial_dyn_filter_regs_can_be_read_from_query_context() { + let mut query_ctx = QueryContext::with("greptime", "public"); + query_ctx.set_extension( + INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, + InitialDynFilterRegs::new(vec![InitialDynFilterReg::new( + "filter-1", + vec![vec![1, 2, 3]], + )]) + .to_extension_value() + .unwrap(), + ); + + let regs = initial_dyn_filter_regs_from_query_ctx(&Arc::new(query_ctx)).unwrap(); + + assert_eq!(regs.regs.len(), 1); + assert_eq!(regs.regs[0].filter_id, "filter-1"); + } + + #[test] + fn register_initial_dyn_filter_regs_creates_query_scoped_entries() { + let regs_by_query = DashMap::>::new(); + let regs = InitialDynFilterRegs::new(vec![ + InitialDynFilterReg::new("filter-1", vec![vec![1, 2, 3]]), + InitialDynFilterReg::new("filter-2", vec![vec![4, 5, 6]]), + ]); + let query_id = "query-1"; + let region_id = RegionId::new(1024, 7); + + register_initial_dyn_filter_regs( + ®s_by_query, + query_id, + region_id, + ®s, + ); + + let query_regs = regs_by_query.get(query_id).unwrap(); + assert_eq!(query_regs.len(), 2); + let registered = query_regs.get("filter-1").unwrap(); + assert_eq!(registered.filter_id, "filter-1"); + assert_eq!(registered.child_exprs_datafusion_proto, vec![vec![1, 2, 3]]); + assert_eq!(registered.subscriber_regions, vec![region_id]); + } + + #[test] + fn register_initial_dyn_filter_regs_ignores_duplicate_filter_entry() { + let regs_by_query = DashMap::>::new(); + let regs = InitialDynFilterRegs::new(vec![ + InitialDynFilterReg::new("filter-1", vec![vec![1, 2, 3]]), + ]); + let query_id = "query-1"; + let region_id = RegionId::new(1024, 7); + + register_initial_dyn_filter_regs( + ®s_by_query, + query_id, + region_id, + ®s, + ); + register_initial_dyn_filter_regs( + ®s_by_query, + query_id, + region_id, + ®s, + ); + + let query_regs = regs_by_query.get(query_id).unwrap(); + assert_eq!(query_regs.len(), 1); + let registered = query_regs.get("filter-1").unwrap(); + assert_eq!(registered.subscriber_regions, vec![region_id]); + } + + #[test] + fn remove_initial_dyn_filter_regs_for_region_removes_region_entries() { + let regs_by_query = DashMap::>::new(); + let query_id = "query-1"; + let region_id = RegionId::new(1024, 7); + let other_query_id = "query-2"; + let other_region_id = RegionId::new(1024, 8); + + register_initial_dyn_filter_regs( + ®s_by_query, + query_id, + region_id, + &InitialDynFilterRegs::new(vec![ + InitialDynFilterReg::new("filter-1", vec![vec![1, 2, 3]]), + ]), + ); + register_initial_dyn_filter_regs( + ®s_by_query, + other_query_id, + other_region_id, + &InitialDynFilterRegs::new(vec![ + InitialDynFilterReg::new("filter-2", vec![vec![4, 5, 6]]), + ]), + ); + + remove_initial_dyn_filter_regs_for_region( + ®s_by_query, + query_id, + region_id, + ); + + assert!(regs_by_query.get(query_id).is_none()); + let other_query_regs = regs_by_query.get(other_query_id).unwrap(); + assert_eq!(other_query_regs.len(), 1); + } + #[tokio::test] async fn test_region_registering() { common_telemetry::init_default_ut_logging(); diff --git a/src/datanode/src/region_server/registrations.rs b/src/datanode/src/region_server/registrations.rs new file mode 100644 index 0000000000..dee4cf9be2 --- /dev/null +++ b/src/datanode/src/region_server/registrations.rs @@ -0,0 +1,124 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use common_query::request::{ + InitialDynFilterRegs, INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, +}; +use common_telemetry::warn; +use dashmap::DashMap; +use session::context::QueryContextRef; +use store_api::storage::RegionId; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct RegisteredDynFilter { + pub(super) filter_id: String, + pub(super) child_exprs_datafusion_proto: Vec>, + pub(super) subscriber_regions: Vec, +} + +impl RegisteredDynFilter { + fn new( + filter_id: String, + child_exprs_datafusion_proto: Vec>, + region_id: RegionId, + ) -> Self { + Self { + filter_id, + child_exprs_datafusion_proto, + subscriber_regions: vec![region_id], + } + } +} + +pub(super) fn initial_dyn_filter_regs_from_query_ctx( + query_ctx: &QueryContextRef, +) -> Option { + let registrations = + query_ctx.extension(INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY)?; + match InitialDynFilterRegs::from_extension_value(registrations) { + Ok(registrations) => Some(registrations), + Err(error) => { + warn!(error; "Failed to decode initial remote dyn filter registrations from query context"); + None + } + } +} + +pub(super) fn register_initial_dyn_filter_regs( + regs_by_query: &DashMap>, + query_id: &str, + region_id: RegionId, + regs: &InitialDynFilterRegs, +) { + if regs.is_empty() { + return; + } + + let query_regs = regs_by_query + .entry(query_id.to_string()) + .or_insert_with(DashMap::new); + + for reg in ®s.regs { + if query_regs.contains_key(®.filter_id) { + warn!( + query_id, + filter_id = reg.filter_id, + region_id = %region_id, + "Duplicate initial remote dyn filter registration ignored" + ); + continue; + } + + query_regs.insert( + reg.filter_id.clone(), + RegisteredDynFilter::new( + reg.filter_id.clone(), + reg.child_exprs_datafusion_proto.clone(), + region_id, + ), + ); + } +} + +pub(super) fn remove_initial_dyn_filter_regs_for_region( + regs_by_query: &DashMap>, + query_id: &str, + region_id: RegionId, +) { + let should_remove_query = { + let Some(query_regs) = regs_by_query.get(query_id) else { + return; + }; + + let filter_ids_to_remove = query_regs + .iter() + .filter_map(|registered| { + registered + .subscriber_regions + .contains(®ion_id) + .then(|| registered.filter_id.clone()) + }) + .collect::>(); + + for filter_id in filter_ids_to_remove { + query_regs.remove(&filter_id); + } + + query_regs.is_empty() + }; + + if should_remove_query { + regs_by_query.remove(query_id); + } +} diff --git a/src/query/src/dist_plan.rs b/src/query/src/dist_plan.rs index 4c0a17542b..41735048d4 100644 --- a/src/query/src/dist_plan.rs +++ b/src/query/src/dist_plan.rs @@ -20,6 +20,7 @@ mod merge_sort; mod planner; mod predicate_extractor; mod region_pruner; +mod dyn_filter_bridge; mod remote_dyn_filter_registry; pub use analyzer::{DistPlannerAnalyzer, DistPlannerOptions}; diff --git a/src/query/src/dist_plan/dyn_filter_bridge.rs b/src/query/src/dist_plan/dyn_filter_bridge.rs new file mode 100644 index 0000000000..d735687f5d --- /dev/null +++ b/src/query/src/dist_plan/dyn_filter_bridge.rs @@ -0,0 +1,282 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::sync::Arc; + +use common_query::request::{ + InitialDynFilterReg, InitialDynFilterRegs, + INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, +}; +use datafusion::execution::TaskContext; +use datafusion_common::Result; +use datafusion_physical_expr::expressions::{lit, Column, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::PhysicalExpr; +use session::context::{QueryContext, QueryContextRef}; +use session::query_id::QueryId; +use store_api::storage::RegionId; +use uuid::Uuid; + +use super::filter_id::build_remote_dyn_filter_id; +use crate::dist_plan::{FilterId, QueryDynFilterRegistry, Subscriber}; +use crate::query_engine::QueryEngineState; + +#[derive(Debug, Clone)] +pub(crate) struct CapturedDynFilter { + pub(crate) producer_local_ordinal: usize, + pub(crate) alive_dyn_filter: Arc, +} + +pub(crate) fn capture_remote_dyn_filters( + parent_filters: Vec>, +) -> Vec { + parent_filters + .into_iter() + .enumerate() + .filter_map(|(producer_local_ordinal, filter)| { + downcast_dynamic_filter(filter).map(|alive_dyn_filter| CapturedDynFilter { + producer_local_ordinal, + alive_dyn_filter, + }) + }) + .collect() +} + +fn downcast_dynamic_filter( + expr: Arc, +) -> Option> { + (expr as Arc) + .downcast::() + .ok() +} + +fn query_engine_state_from_task_context(context: &TaskContext) -> Option> { + let query_engine_state: Option> = + context.session_config().get_extension(); + query_engine_state +} + +pub(crate) fn register_dyn_filters_for_region( + registry: &QueryDynFilterRegistry, + region_id: RegionId, + captured_dyn_filters: &[CapturedDynFilter], +) { + for captured_dyn_filter in captured_dyn_filters { + let Ok((filter_id, _children)) = + filter_id_and_children_for_filter(region_id, captured_dyn_filter) + else { + continue; + }; + + let _ = registry.register_remote_dyn_filter( + filter_id.clone(), + captured_dyn_filter.alive_dyn_filter.clone(), + ); + let _ = registry.register_subscriber(&filter_id, Subscriber::new(region_id)); + } +} + +pub(crate) fn bridge_dyn_filters_for_region( + context: &TaskContext, + query_ctx: &QueryContextRef, + region_id: RegionId, + captured_dyn_filters: &[CapturedDynFilter], +) { + if captured_dyn_filters.is_empty() { + return; + } + + let Some(query_engine_state) = query_engine_state_from_task_context(context) else { + return; + }; + let Some(registry) = query_engine_state.get_or_init_remote_dyn_filter_registry(query_ctx) + else { + return; + }; + + register_dyn_filters_for_region(®istry, region_id, captured_dyn_filters); +} + +fn filter_id_and_children_for_filter( + region_id: RegionId, + captured_dyn_filter: &CapturedDynFilter, +) -> Result<( + FilterId, + Vec>, +)> { + let children = captured_dyn_filter + .alive_dyn_filter + .children() + .into_iter() + .cloned() + .collect::>(); + let filter_id = build_remote_dyn_filter_id( + region_id, + captured_dyn_filter.producer_local_ordinal, + &children, + )?; + + Ok((filter_id, children)) +} + +fn build_initial_dyn_filter_regs_for_region( + region_id: RegionId, + captured_dyn_filters: &[CapturedDynFilter], +) -> InitialDynFilterRegs { + InitialDynFilterRegs::new( + captured_dyn_filters + .iter() + .filter_map(|captured_dyn_filter| { + let Ok((filter_id, children)) = + filter_id_and_children_for_filter(region_id, captured_dyn_filter) + else { + return None; + }; + + match InitialDynFilterReg::from_filter_id_and_children( + filter_id.to_string(), + &children, + ) { + Ok(registration) => Some(registration), + Err(error) => { + common_telemetry::warn!(error; "Failed to encode initial remote dyn filter registration"); + None + } + } + }) + .collect(), + ) +} + +pub(crate) fn query_context_with_initial_dyn_filter_regs( + query_ctx: &QueryContextRef, + region_id: RegionId, + captured_dyn_filters: &[CapturedDynFilter], +) -> QueryContext { + let mut region_query_ctx = query_ctx.as_ref().clone(); + let regs = build_initial_dyn_filter_regs_for_region(region_id, captured_dyn_filters); + if regs.is_empty() { + return region_query_ctx; + } + + match regs.to_extension_value() { + Ok(serialized) => region_query_ctx.set_extension( + INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, + serialized, + ), + Err(error) => { + common_telemetry::warn!(error; "Failed to serialize initial remote dyn filter registrations"); + } + } + + region_query_ctx +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_query_id(value: u128) -> QueryId { + QueryId::from(Uuid::from_u128(value)) + } + + #[test] + fn capture_remote_dyn_filters_preserves_parent_filter_ordinals() { + let parent_filters = vec![ + Arc::new(Column::new("service", 0)) as Arc, + Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("host", 1)) as Arc<_>], + lit(true) as _, + )) as Arc, + Arc::new(Column::new("zone", 2)) as Arc, + Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("pod", 3)) as Arc<_>], + lit(true) as _, + )) as Arc, + ]; + + let captured = capture_remote_dyn_filters(parent_filters); + + assert_eq!(captured.len(), 2); + assert_eq!(captured[0].producer_local_ordinal, 1); + assert_eq!(captured[1].producer_local_ordinal, 3); + } + + #[test] + fn register_dyn_filters_for_region_reuses_existing_entry() { + let registry = QueryDynFilterRegistry::new(test_query_id(1)); + let captured_dyn_filters = vec![CapturedDynFilter { + producer_local_ordinal: 2, + alive_dyn_filter: Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("host", 0)) as Arc<_>], + lit(true) as _, + )), + }]; + let region_id = RegionId::new(1024, 7); + + register_dyn_filters_for_region(®istry, region_id, &captured_dyn_filters); + register_dyn_filters_for_region(®istry, region_id, &captured_dyn_filters); + + assert_eq!(registry.entry_count(), 1); + let entry = registry.entries().pop().unwrap(); + assert_eq!(entry.filter_id().producer_ordinal(), 2); + assert_eq!(entry.subscribers().len(), 1); + assert_eq!(entry.subscribers()[0].region_id(), region_id); + } + + #[test] + fn query_context_includes_region_initial_dyn_filter_regs() { + let captured_dyn_filters = vec![CapturedDynFilter { + producer_local_ordinal: 2, + alive_dyn_filter: Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("host", 0)) as Arc<_>], + lit(true) as _, + )), + }]; + let region_id = RegionId::new(1024, 7); + let query_ctx = QueryContext::arc(); + + let region_query_ctx = query_context_with_initial_dyn_filter_regs( + &query_ctx, + region_id, + &captured_dyn_filters, + ); + let extension = region_query_ctx + .extension(INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY) + .unwrap(); + let regs = InitialDynFilterRegs::from_extension_value(extension).unwrap(); + let decoded_children = regs.regs[0] + .decode_children( + &TaskContext::default(), + &arrow_schema::Schema::new(vec![arrow_schema::Field::new( + "host", + arrow_schema::DataType::Utf8, + false, + )]), + 1024, + ) + .unwrap(); + let expected_filter_id = build_remote_dyn_filter_id( + region_id, + captured_dyn_filters[0].producer_local_ordinal, + &[Arc::new(Column::new("host", 0)) as Arc<_>], + ) + .unwrap(); + + assert_eq!(regs.regs.len(), 1); + assert_eq!(regs.regs[0].filter_id, expected_filter_id.to_string()); + assert_eq!(decoded_children.len(), 1); + assert!(decoded_children[0].as_any().is::()); + } +} diff --git a/src/query/src/dist_plan/merge_scan.rs b/src/query/src/dist_plan/merge_scan.rs index 98a9a7831c..043c4b5e96 100644 --- a/src/query/src/dist_plan/merge_scan.rs +++ b/src/query/src/dist_plan/merge_scan.rs @@ -40,9 +40,7 @@ use datafusion::physical_plan::{ use datafusion_common::{Column as ColumnExpr, DataFusionError, Result}; use datafusion_expr::{Expr, Extension, LogicalPlan, UserDefinedLogicalNodeCore}; use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr}; -use datafusion_physical_expr::{ - Distribution, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, -}; +use datafusion_physical_expr::{Distribution, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; use futures_util::StreamExt; use greptime_proto::v1::region::RegionRequestHeader; use meter_core::data::ReadItem; @@ -53,12 +51,13 @@ use table::table_name::TableName; use tokio::time::Instant; use tracing::{Instrument, Span}; -use super::filter_id::build_remote_dyn_filter_id; use crate::dist_plan::analyzer::AliasMapping; use crate::dist_plan::analyzer::utils::patch_batch_timezone; -use crate::dist_plan::{QueryDynFilterRegistry, Subscriber}; +use crate::dist_plan::dyn_filter_bridge::{ + CapturedDynFilter, bridge_dyn_filters_for_region, capture_remote_dyn_filters, + query_context_with_initial_dyn_filter_regs, +}; use crate::metrics::{MERGE_SCAN_ERRORS_TOTAL, MERGE_SCAN_POLL_ELAPSED, MERGE_SCAN_REGIONS}; -use crate::query_engine::QueryEngineState; use crate::region_query::RegionQueryHandlerRef; #[derive(Debug, Hash, PartialOrd, PartialEq, Eq, Clone)] @@ -140,90 +139,6 @@ impl MergeScanLogicalPlan { } } -#[derive(Debug, Clone)] -struct CapturedRemoteDynFilter { - producer_local_ordinal: usize, - alive_dyn_filter: Arc, -} - -fn capture_remote_dyn_filters( - parent_filters: Vec>, -) -> Vec { - parent_filters - .into_iter() - .enumerate() - .filter_map(|(producer_local_ordinal, filter)| { - downcast_dynamic_filter(filter).map(|alive_dyn_filter| CapturedRemoteDynFilter { - producer_local_ordinal, - alive_dyn_filter, - }) - }) - .collect() -} - -fn downcast_dynamic_filter( - expr: Arc, -) -> Option> { - (expr as Arc) - .downcast::() - .ok() -} - -fn query_engine_state_from_task_context(context: &TaskContext) -> Option> { - let query_engine_state: Option> = - context.session_config().get_extension(); - query_engine_state -} - -fn register_remote_dyn_filters_for_region( - registry: &QueryDynFilterRegistry, - region_id: RegionId, - captured_remote_dyn_filters: &[CapturedRemoteDynFilter], -) { - for captured_dyn_filter in captured_remote_dyn_filters { - let children = captured_dyn_filter - .alive_dyn_filter - .children() - .into_iter() - .cloned() - .collect::>(); - let Ok(filter_id) = build_remote_dyn_filter_id( - region_id, - captured_dyn_filter.producer_local_ordinal, - &children, - ) else { - continue; - }; - - let _ = registry.register_remote_dyn_filter( - filter_id.clone(), - captured_dyn_filter.alive_dyn_filter.clone(), - ); - let _ = registry.register_subscriber(&filter_id, Subscriber::new(region_id)); - } -} - -fn bridge_remote_dyn_filters_for_region( - context: &TaskContext, - query_ctx: &QueryContextRef, - region_id: RegionId, - captured_remote_dyn_filters: &[CapturedRemoteDynFilter], -) { - if captured_remote_dyn_filters.is_empty() { - return; - } - - let Some(query_engine_state) = query_engine_state_from_task_context(context) else { - return; - }; - let Some(registry) = query_engine_state.get_or_init_remote_dyn_filter_registry(query_ctx) - else { - return; - }; - - register_remote_dyn_filters_for_region(®istry, region_id, captured_remote_dyn_filters); -} - #[derive(Clone)] pub struct MergeScanExec { table: TableName, @@ -238,7 +153,7 @@ pub struct MergeScanExec { /// Metrics for each partition partition_metrics: Arc>>, query_ctx: QueryContextRef, - captured_remote_dyn_filters: Arc>>, + captured_remote_dyn_filters: Arc>>, target_partition: usize, partition_cols: AliasMapping, } @@ -381,7 +296,7 @@ impl MergeScanExec { .step_by(target_partition) .copied() { - bridge_remote_dyn_filters_for_region( + bridge_dyn_filters_for_region( context.as_ref(), &query_ctx, region_id, @@ -394,11 +309,16 @@ impl MergeScanExec { region_id = %region_id, partition = partition )); + let region_query_ctx = query_context_with_initial_dyn_filter_regs( + &query_ctx, + region_id, + &captured_remote_dyn_filters, + ); let request = QueryRequest { header: Some(RegionRequestHeader { tracing_context: tracing_context.to_w3c(), dbname: dbname.clone(), - query_context: Some(query_ctx.as_ref().into()), + query_context: Some((®ion_query_ctx).into()), }), region_id, plan: plan.clone(), @@ -572,7 +492,7 @@ impl MergeScanExec { }) } - fn captured_remote_dyn_filters(&self) -> Vec { + fn captured_remote_dyn_filters(&self) -> Vec { self.captured_remote_dyn_filters.lock().unwrap().clone() } @@ -855,60 +775,3 @@ impl MergeScanMetric { self.greptime_exec_cost.add(metrics); } } - -#[cfg(test)] -mod tests { - use datafusion_physical_expr::expressions::lit; - use session::query_id::QueryId; - use uuid::Uuid; - - use super::*; - - fn test_query_id(value: u128) -> QueryId { - QueryId::from(Uuid::from_u128(value)) - } - - #[test] - fn capture_remote_dyn_filters_preserves_parent_filter_ordinals() { - let parent_filters = vec![ - Arc::new(Column::new("service", 0)) as Arc, - Arc::new(DynamicFilterPhysicalExpr::new( - vec![Arc::new(Column::new("host", 1)) as Arc<_>], - lit(true) as _, - )) as Arc, - Arc::new(Column::new("zone", 2)) as Arc, - Arc::new(DynamicFilterPhysicalExpr::new( - vec![Arc::new(Column::new("pod", 3)) as Arc<_>], - lit(true) as _, - )) as Arc, - ]; - - let captured = capture_remote_dyn_filters(parent_filters); - - assert_eq!(captured.len(), 2); - assert_eq!(captured[0].producer_local_ordinal, 1); - assert_eq!(captured[1].producer_local_ordinal, 3); - } - - #[test] - fn register_remote_dyn_filters_for_region_reuses_existing_entry() { - let registry = QueryDynFilterRegistry::new(test_query_id(1)); - let captured_remote_dyn_filters = vec![CapturedRemoteDynFilter { - producer_local_ordinal: 2, - alive_dyn_filter: Arc::new(DynamicFilterPhysicalExpr::new( - vec![Arc::new(Column::new("host", 0)) as Arc<_>], - lit(true) as _, - )), - }]; - let region_id = RegionId::new(1024, 7); - - register_remote_dyn_filters_for_region(®istry, region_id, &captured_remote_dyn_filters); - register_remote_dyn_filters_for_region(®istry, region_id, &captured_remote_dyn_filters); - - assert_eq!(registry.entry_count(), 1); - let entry = registry.entries().pop().unwrap(); - assert_eq!(entry.filter_id().producer_ordinal(), 2); - assert_eq!(entry.subscribers().len(), 1); - assert_eq!(entry.subscribers()[0].region_id(), region_id); - } -}