1mod base64_serde;
16mod initial_remote_dyn_filter_reg;
17
18use std::sync::Arc;
19
20use api::v1::region::RegionRequestHeader;
21use datafusion::execution::TaskContext;
22use datafusion::physical_expr::expressions::{Column, InListExpr, lit};
23use datafusion::physical_plan::PhysicalExpr;
24use datafusion::physical_plan::joins::HashTableLookupExpr;
25use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
26use datafusion_common::{DataFusionError, Result as DataFusionResult};
27use datafusion_expr::LogicalPlan;
28use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
29use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
30use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
31use datafusion_proto::protobuf::PhysicalExprNode;
32use prost::Message;
33use serde::{Deserialize, Serialize};
34use snafu::ensure;
35use store_api::storage::RegionId;
36
37pub use self::initial_remote_dyn_filter_reg::{
39 INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, InitialDynFilterReg,
40 InitialDynFilterRegs, InitialDynFilterSnapshot,
41};
42use crate::error::{DynFilterPayloadTooLargeSnafu, Error as CommonQueryError};
43
/// Protocol version stamped into every `DynFilterUpdate` built by `DynFilterUpdate::new`.
pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;
/// Size limit for an encoded remote dynamic-filter payload: 512 KiB.
///
/// NOTE(review): presumably passed as `max_payload_bytes` by callers outside this view.
pub const REMOTE_DYN_FILTER_PAYLOAD_MAX_BYTES: usize = 1024 * 512;
47
/// Encoded body of a dynamic-filter update.
///
/// Serialized with an adjacent tag: `{"kind": "<variant>", "payload": <content>}`, where the
/// variant name is rendered in `snake_case` (e.g. `"datafusion"`).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
pub enum DynFilterPayload {
    /// Prost-encoded DataFusion `PhysicalExprNode` bytes. In serde formats the bytes are
    /// carried as a base64 string via `base64_serde::bytes`.
    Datafusion(#[serde(with = "base64_serde::bytes")] Vec<u8>),
}
65
66impl DynFilterPayload {
67 pub fn from_datafusion_expr(
74 expr: &Arc<dyn PhysicalExpr>,
75 max_payload_bytes: usize,
76 ) -> DataFusionResult<Self> {
77 match encode_remote_dyn_filter_expr(expr, max_payload_bytes, false) {
78 Ok(bytes) => Ok(Self::Datafusion(bytes)),
79 Err(CommonQueryError::DynFilterPayloadTooLarge { .. }) => {
80 encode_remote_dyn_filter_expr(expr, max_payload_bytes, true)
81 .map(Self::Datafusion)
82 .map_err(DataFusionError::from)
83 }
84 Err(error) => Err(DataFusionError::from(error)),
85 }
86 }
87
88 pub fn decode_datafusion_expr(
94 &self,
95 task_ctx: &TaskContext,
96 input_schema: &datafusion::arrow::datatypes::Schema,
97 max_payload_bytes: usize,
98 ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
99 let Self::Datafusion(bytes) = self;
100 validate_payload_size(bytes.len(), max_payload_bytes).map_err(DataFusionError::from)?;
101 let codec = DefaultPhysicalExtensionCodec {};
102 let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| {
103 DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
104 })?;
105
106 let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
107 validate_supported_payload_expr(&expr)?;
108 validate_decoded_payload_expr(&expr, input_schema)?;
109 Ok(expr)
110 }
111}
112
113fn encode_physical_expr_to_bytes(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<Vec<u8>> {
114 let codec = DefaultPhysicalExtensionCodec {};
115 let proto = serialize_physical_expr(expr, &codec)?;
116 let mut bytes = Vec::new();
117 proto.encode(&mut bytes).map_err(|e| {
118 DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
119 })?;
120 Ok(bytes)
121}
122
123fn encode_remote_dyn_filter_expr(
124 expr: &Arc<dyn PhysicalExpr>,
125 max_payload_bytes: usize,
126 bounds_only: bool,
127) -> Result<Vec<u8>, CommonQueryError> {
128 let expr = portable_remote_dyn_filter_expr(Arc::clone(expr), bounds_only)
129 .map_err(CommonQueryError::from)?;
130 let bytes = encode_physical_expr_to_bytes(&expr).map_err(CommonQueryError::from)?;
131 validate_payload_size(bytes.len(), max_payload_bytes)?;
132 Ok(bytes)
133}
134
135fn portable_remote_dyn_filter_expr(
136 expr: Arc<dyn PhysicalExpr>,
137 bounds_only: bool,
138) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
139 expr.transform_up(|node| {
140 if node.as_any().is::<HashTableLookupExpr>()
141 || (bounds_only && node.as_any().is::<InListExpr>())
142 {
143 Ok(Transformed::yes(lit(true)))
144 } else {
145 Ok(Transformed::no(node))
146 }
147 })
148 .map(|transformed| transformed.data)
149}
150
151pub(crate) fn decode_physical_expr_from_bytes(
152 bytes: &[u8],
153 task_ctx: &TaskContext,
154 input_schema: &datafusion::arrow::datatypes::Schema,
155 max_payload_bytes: usize,
156) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
157 validate_payload_size(bytes.len(), max_payload_bytes).map_err(DataFusionError::from)?;
158 let codec = DefaultPhysicalExtensionCodec {};
159 let proto = PhysicalExprNode::decode(bytes).map_err(|e| {
160 DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
161 })?;
162
163 let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
164 validate_supported_payload_expr(&expr)?;
165 validate_decoded_payload_expr(&expr, input_schema)?;
166 Ok(expr)
167}
168
169fn validate_payload_size(
170 payload_size_bytes: usize,
171 max_payload_bytes: usize,
172) -> Result<(), CommonQueryError> {
173 ensure!(
174 payload_size_bytes <= max_payload_bytes,
175 DynFilterPayloadTooLargeSnafu {
176 payload_size_bytes,
177 max_payload_bytes,
178 }
179 );
180
181 Ok(())
182}
183
184fn validate_supported_payload_expr(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<()> {
185 expr.apply(|node| {
186 if node.as_any().is::<HashTableLookupExpr>() {
187 return Err(DataFusionError::Plan(
188 "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion"
189 .to_string(),
190 ));
191 }
192
193 Ok(TreeNodeRecursion::Continue)
194 })?;
195
196 Ok(())
197}
198
199fn validate_decoded_payload_expr(
206 expr: &Arc<dyn PhysicalExpr>,
207 input_schema: &datafusion::arrow::datatypes::Schema,
208) -> DataFusionResult<()> {
209 expr.apply(|node| {
210 if let Some(column) = node.as_any().downcast_ref::<Column>() {
211 let Some(field) = input_schema.fields().get(column.index()) else {
212 return Err(DataFusionError::Plan(format!(
213 "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
214 column.name(),
215 column.index(),
216 input_schema.fields().len()
217 )));
218 };
219
220 if field.name() != column.name() {
221 return Err(DataFusionError::Plan(format!(
222 "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'",
223 column.name(),
224 column.index(),
225 field.name()
226 )));
227 }
228 }
229
230 Ok(TreeNodeRecursion::Continue)
231 })?;
232
233 Ok(())
234}
235
/// A dynamic-filter update message, serializable with serde (it round-trips through JSON).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DynFilterUpdate {
    /// Protocol version. `DynFilterUpdate::new` sets it to `DYN_FILTER_PROTOCOL_VERSION`.
    pub protocol_version: u32,
    /// Identifier of the query this update belongs to. NOTE(review): presumably; it is not
    /// validated here.
    pub query_id: String,
    /// Identifier of the filter being updated. NOTE(review): presumably unique within the
    /// query; it is not validated here.
    pub filter_id: String,
    /// Generation counter of this update. NOTE(review): any ordering semantics (e.g. newer
    /// generations superseding older ones) must be enforced by the consumer — confirm.
    pub generation: u64,
    /// Whether this update is flagged as complete. NOTE(review): presumably signals that no
    /// further updates follow for this filter — confirm with the consumer.
    pub is_complete: bool,
    /// The encoded filter expression.
    pub payload: DynFilterPayload,
}
256
257impl DynFilterUpdate {
258 pub fn new(
260 query_id: String,
261 filter_id: String,
262 generation: u64,
263 is_complete: bool,
264 payload: DynFilterPayload,
265 ) -> Self {
266 Self {
267 protocol_version: DYN_FILTER_PROTOCOL_VERSION,
268 query_id,
269 filter_id,
270 generation,
271 is_complete,
272 payload,
273 }
274 }
275}
276
/// A query request addressed to a single region.
#[derive(Clone, Debug)]
pub struct QueryRequest {
    /// Optional region request header. NOTE(review): its contents are defined by
    /// `api::v1::region::RegionRequestHeader` — confirm what callers populate.
    pub header: Option<RegionRequestHeader>,

    /// The region this request targets.
    pub region_id: RegionId,

    /// The DataFusion logical plan carried by the request (presumably executed against
    /// `region_id`).
    pub plan: LogicalPlan,
}
289
/// Unit tests for dynamic-filter payload construction, serde shape, and the encode/decode
/// pipeline.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::{BinaryExpr, Column, InListExpr, lit};
    use datafusion::physical_plan::expressions::col;
    use datafusion::physical_plan::joins::join_hash_map::JoinHashMapU32;
    use datafusion::physical_plan::joins::{HashTableLookupExpr, Map, SeededRandomState};
    use datafusion_expr::Operator;

    use super::*;

    // `DynFilterUpdate::new` stamps the protocol version and stores the other arguments
    // unchanged.
    #[test]
    fn dyn_filter_update_sets_protocol_version() {
        let update = DynFilterUpdate::new(
            "query-1".to_string(),
            "filter-1".to_string(),
            3,
            false,
            DynFilterPayload::Datafusion(vec![1, 2, 3]),
        );

        assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION);
        assert!(!update.is_complete);
        assert!(
            matches!(update.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![1, 2, 3])
        );
    }

    // JSON round trip checks:
    // - `generation` is serialized as a plain number;
    // - no `epoch` field is emitted;
    // - the payload uses the adjacently tagged `{kind, payload}` shape with base64 content.
    #[test]
    fn dyn_filter_update_json_round_trip_preserves_payload_shape() {
        let update = DynFilterUpdate::new(
            "query-2".to_string(),
            "filter-9".to_string(),
            9,
            true,
            DynFilterPayload::Datafusion(vec![9, 8, 7]),
        );

        let json = serde_json::to_string(&update).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();

        assert_eq!(value["generation"], serde_json::json!(9));
        assert!(value.get("epoch").is_none());
        assert_eq!(
            value["payload"],
            serde_json::json!({ "kind": "datafusion", "payload": BASE64_STANDARD.encode([9, 8, 7]) })
        );
        assert_eq!(decoded, update);
        assert!(decoded.is_complete);
        assert!(
            matches!(decoded.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![9, 8, 7])
        );
    }

    // Base64 edge cases: an empty payload encodes as "", and 1- and 2-byte payloads (which
    // need padding) match the standard padded engine.
    #[test]
    fn dyn_filter_payload_json_uses_base64_for_empty_and_padded_payloads() {
        let empty = serde_json::to_value(DynFilterPayload::Datafusion(vec![])).unwrap();
        let one = serde_json::to_value(DynFilterPayload::Datafusion(vec![1])).unwrap();
        let two = serde_json::to_value(DynFilterPayload::Datafusion(vec![1, 2])).unwrap();

        assert_eq!(
            empty,
            serde_json::json!({"kind": "datafusion", "payload": ""})
        );
        assert_eq!(
            one,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1])})
        );
        assert_eq!(
            two,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1, 2])})
        );
    }

    // Deserializing a non-base64 payload string fails with the custom error message.
    #[test]
    fn dyn_filter_payload_json_rejects_invalid_base64() {
        let err = serde_json::from_value::<DynFilterPayload>(serde_json::json!({
            "kind": "datafusion",
            "payload": "not base64!",
        }))
        .unwrap_err();

        assert!(
            err.to_string()
                .contains("invalid base64 dynamic filter payload")
        );
    }

    // A bare column expression survives encode -> decode with the same name and index.
    #[test]
    fn dyn_filter_payload_round_trips_physical_column_expr() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        let original = expr.as_any().downcast_ref::<Column>().unwrap();
        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();

        assert_eq!(decoded.name(), original.name());
        assert_eq!(decoded.index(), original.index());
    }

    // Bytes that are not a valid `PhysicalExprNode` surface as `DataFusionError::Internal`.
    #[test]
    fn dyn_filter_payload_decode_rejects_invalid_bytes() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Internal(_)));
    }

    // A column whose index exists but whose name differs from the schema field is rejected,
    // and the error names both sides.
    #[test]
    fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("name/index mismatch"),
            "expected name/index mismatch error, got: {msg}"
        );
        assert!(msg.contains("service"));
        assert!(msg.contains("host"));
    }

    // A column index past the end of the schema is rejected with `DataFusionError::Plan`.
    #[test]
    fn dyn_filter_payload_decode_rejects_out_of_bounds_column_index() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 1));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    // `bounds AND hash_lookup(...)`: the unserializable `HashTableLookupExpr` is replaced
    // during encoding, while the `>=` bound is preserved through the round trip.
    #[test]
    fn dyn_filter_payload_hash_lookup_fallback_preserves_bounds() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "device_id",
            DataType::Int32,
            false,
        )]));
        let device_id = col("device_id", &schema).unwrap();
        let lower_bound = Arc::new(BinaryExpr::new(
            Arc::clone(&device_id),
            Operator::GtEq,
            lit(10i32),
        )) as Arc<dyn PhysicalExpr>;
        let lookup = Arc::new(HashTableLookupExpr::new(
            vec![Arc::clone(&device_id)],
            SeededRandomState::with_seeds(0, 0, 0, 0),
            Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(0)))),
            "hash_lookup".to_string(),
        )) as Arc<dyn PhysicalExpr>;
        let expr =
            Arc::new(BinaryExpr::new(lower_bound, Operator::And, lookup)) as Arc<dyn PhysicalExpr>;

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        assert!(!contains_expr::<HashTableLookupExpr>(&decoded));
        let decoded_display = decoded.to_string();
        assert!(decoded_display.contains("device_id"));
        assert!(decoded_display.contains(">="));
        assert!(!decoded_display.contains("hash_lookup"));
    }

    // `bounds AND device_id IN (...)`, with the limit set to exactly the bounds-only size:
    // the first attempt is too large, so the fallback drops the `InListExpr` and keeps the
    // `>=` / `<=` bounds.
    #[test]
    fn dyn_filter_payload_oversized_inlist_falls_back_to_bounds() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "device_id",
            DataType::Int32,
            false,
        )]));
        let device_id = col("device_id", &schema).unwrap();
        let lower_bound = Arc::new(BinaryExpr::new(
            Arc::clone(&device_id),
            Operator::GtEq,
            lit(8192i32),
        )) as Arc<dyn PhysicalExpr>;
        let upper_bound = Arc::new(BinaryExpr::new(
            Arc::clone(&device_id),
            Operator::LtEq,
            lit(8255i32),
        )) as Arc<dyn PhysicalExpr>;
        let bounds = Arc::new(BinaryExpr::new(lower_bound, Operator::And, upper_bound))
            as Arc<dyn PhysicalExpr>;
        let in_list = Arc::new(
            InListExpr::try_new(
                Arc::clone(&device_id),
                (8192..8256).map(lit).collect(),
                false,
                &schema,
            )
            .unwrap(),
        ) as Arc<dyn PhysicalExpr>;
        let expr = Arc::new(BinaryExpr::new(Arc::clone(&bounds), Operator::And, in_list))
            as Arc<dyn PhysicalExpr>;
        // Measure both encodings so the limit below admits only the bounds-only form.
        let bounds_only = portable_remote_dyn_filter_expr(Arc::clone(&expr), true).unwrap();
        let bounds_only_size = encode_physical_expr_to_bytes(&bounds_only).unwrap().len();
        let full_size = encode_physical_expr_to_bytes(&expr).unwrap().len();
        assert!(full_size > bounds_only_size);

        let payload = DynFilterPayload::from_datafusion_expr(&expr, bounds_only_size).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, bounds_only_size)
            .unwrap();

        assert!(!contains_expr::<InListExpr>(&decoded));
        let decoded_display = decoded.to_string();
        assert!(decoded_display.contains("device_id"));
        assert!(decoded_display.contains(">="));
        assert!(decoded_display.contains("<="));
    }

    // With a 1-byte limit even the fallback is too large. The error reaches the caller as
    // `DataFusionError::External` wrapping `CommonQueryError::DynFilterPayloadTooLarge`.
    #[test]
    fn dyn_filter_payload_rejects_oversized_payload() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err();

        let DataFusionError::External(error) = err else {
            panic!("expected external common query error, got: {err:?}");
        };
        assert!(matches!(
            error.downcast_ref::<CommonQueryError>(),
            Some(CommonQueryError::DynFilterPayloadTooLarge { .. })
        ));
    }

    // Returns true if any node of `expr`'s tree is of concrete type `T`. The traversal
    // stops at the first match.
    fn contains_expr<T: 'static>(expr: &Arc<dyn PhysicalExpr>) -> bool {
        let mut found = false;
        expr.apply(|node| {
            if node.as_any().is::<T>() {
                found = true;
                Ok(TreeNodeRecursion::Stop)
            } else {
                Ok(TreeNodeRecursion::Continue)
            }
        })
        .unwrap();
        found
    }
}