1mod base64_serde;
16mod initial_remote_dyn_filter_reg;
17
18use std::sync::Arc;
19
20use api::v1::region::RegionRequestHeader;
21use datafusion::execution::TaskContext;
22use datafusion::physical_expr::expressions::{Column, InListExpr, lit};
23use datafusion::physical_plan::PhysicalExpr;
24use datafusion::physical_plan::joins::HashTableLookupExpr;
25use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
26use datafusion_common::{DataFusionError, Result as DataFusionResult};
27use datafusion_expr::LogicalPlan;
28use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
29use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
30use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
31use datafusion_proto::protobuf::PhysicalExprNode;
32use prost::Message;
33use serde::{Deserialize, Serialize};
34use snafu::ensure;
35use store_api::storage::RegionId;
36
37pub use self::initial_remote_dyn_filter_reg::{
39 INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY,
40 INITIAL_REMOTE_DYN_FILTER_REGS_MAX_TOTAL_PROTO_BYTES, InitialDynFilterReg,
41 InitialDynFilterRegs, InitialDynFilterSnapshot,
42};
43use crate::error::{DynFilterPayloadTooLargeSnafu, Error as CommonQueryError};
44
45pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;
46
47#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
57#[non_exhaustive]
58#[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
59pub enum DynFilterPayload {
60 Datafusion(#[serde(with = "base64_serde::bytes")] Vec<u8>),
63}
64
65impl DynFilterPayload {
66 pub fn from_datafusion_expr(
73 expr: &Arc<dyn PhysicalExpr>,
74 max_payload_bytes: usize,
75 ) -> DataFusionResult<Self> {
76 match encode_remote_dyn_filter_expr(expr, max_payload_bytes, false) {
77 Ok(bytes) => Ok(Self::Datafusion(bytes)),
78 Err(CommonQueryError::DynFilterPayloadTooLarge { .. }) => {
79 encode_remote_dyn_filter_expr(expr, max_payload_bytes, true)
80 .map(Self::Datafusion)
81 .map_err(DataFusionError::from)
82 }
83 Err(error) => Err(DataFusionError::from(error)),
84 }
85 }
86
87 pub fn decode_datafusion_expr(
93 &self,
94 task_ctx: &TaskContext,
95 input_schema: &datafusion::arrow::datatypes::Schema,
96 max_payload_bytes: usize,
97 ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
98 let Self::Datafusion(bytes) = self;
99 validate_payload_size(bytes.len(), max_payload_bytes).map_err(DataFusionError::from)?;
100 let codec = DefaultPhysicalExtensionCodec {};
101 let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| {
102 DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
103 })?;
104
105 let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
106 validate_supported_payload_expr(&expr)?;
107 validate_decoded_payload_expr(&expr, input_schema)?;
108 Ok(expr)
109 }
110}
111
112fn encode_physical_expr_to_bytes(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<Vec<u8>> {
113 let codec = DefaultPhysicalExtensionCodec {};
114 let proto = serialize_physical_expr(expr, &codec)?;
115 let mut bytes = Vec::new();
116 proto.encode(&mut bytes).map_err(|e| {
117 DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
118 })?;
119 Ok(bytes)
120}
121
122fn encode_remote_dyn_filter_expr(
123 expr: &Arc<dyn PhysicalExpr>,
124 max_payload_bytes: usize,
125 bounds_only: bool,
126) -> Result<Vec<u8>, CommonQueryError> {
127 let expr = portable_remote_dyn_filter_expr(Arc::clone(expr), bounds_only)
128 .map_err(CommonQueryError::from)?;
129 let bytes = encode_physical_expr_to_bytes(&expr).map_err(CommonQueryError::from)?;
130 validate_payload_size(bytes.len(), max_payload_bytes)?;
131 Ok(bytes)
132}
133
134fn portable_remote_dyn_filter_expr(
135 expr: Arc<dyn PhysicalExpr>,
136 bounds_only: bool,
137) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
138 expr.transform_up(|node| {
139 if node.as_any().is::<HashTableLookupExpr>()
140 || (bounds_only && node.as_any().is::<InListExpr>())
141 {
142 Ok(Transformed::yes(lit(true)))
143 } else {
144 Ok(Transformed::no(node))
145 }
146 })
147 .map(|transformed| transformed.data)
148}
149
150pub(crate) fn decode_physical_expr_from_bytes(
151 bytes: &[u8],
152 task_ctx: &TaskContext,
153 input_schema: &datafusion::arrow::datatypes::Schema,
154 max_payload_bytes: usize,
155) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
156 validate_payload_size(bytes.len(), max_payload_bytes).map_err(DataFusionError::from)?;
157 let codec = DefaultPhysicalExtensionCodec {};
158 let proto = PhysicalExprNode::decode(bytes).map_err(|e| {
159 DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
160 })?;
161
162 let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
163 validate_supported_payload_expr(&expr)?;
164 validate_decoded_payload_expr(&expr, input_schema)?;
165 Ok(expr)
166}
167
168fn validate_payload_size(
169 payload_size_bytes: usize,
170 max_payload_bytes: usize,
171) -> Result<(), CommonQueryError> {
172 ensure!(
173 payload_size_bytes <= max_payload_bytes,
174 DynFilterPayloadTooLargeSnafu {
175 payload_size_bytes,
176 max_payload_bytes,
177 }
178 );
179
180 Ok(())
181}
182
183fn validate_supported_payload_expr(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<()> {
184 expr.apply(|node| {
185 if node.as_any().is::<HashTableLookupExpr>() {
186 return Err(DataFusionError::Plan(
187 "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion"
188 .to_string(),
189 ));
190 }
191
192 Ok(TreeNodeRecursion::Continue)
193 })?;
194
195 Ok(())
196}
197
198fn validate_decoded_payload_expr(
205 expr: &Arc<dyn PhysicalExpr>,
206 input_schema: &datafusion::arrow::datatypes::Schema,
207) -> DataFusionResult<()> {
208 expr.apply(|node| {
209 if let Some(column) = node.as_any().downcast_ref::<Column>() {
210 let Some(field) = input_schema.fields().get(column.index()) else {
211 return Err(DataFusionError::Plan(format!(
212 "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
213 column.name(),
214 column.index(),
215 input_schema.fields().len()
216 )));
217 };
218
219 if field.name() != column.name() {
220 return Err(DataFusionError::Plan(format!(
221 "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'",
222 column.name(),
223 column.index(),
224 field.name()
225 )));
226 }
227 }
228
229 Ok(TreeNodeRecursion::Continue)
230 })?;
231
232 Ok(())
233}
234
235#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
241pub struct DynFilterUpdate {
242 pub protocol_version: u32,
244 pub query_id: String,
246 pub filter_id: String,
248 pub generation: u64,
250 pub is_complete: bool,
252 pub payload: DynFilterPayload,
254}
255
256impl DynFilterUpdate {
257 pub fn new(
259 query_id: String,
260 filter_id: String,
261 generation: u64,
262 is_complete: bool,
263 payload: DynFilterPayload,
264 ) -> Self {
265 Self {
266 protocol_version: DYN_FILTER_PROTOCOL_VERSION,
267 query_id,
268 filter_id,
269 generation,
270 is_complete,
271 payload,
272 }
273 }
274}
275
276#[derive(Clone, Debug)]
278pub struct QueryRequest {
279 pub header: Option<RegionRequestHeader>,
281
282 pub region_id: RegionId,
284
285 pub plan: LogicalPlan,
287}
288
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::{BinaryExpr, Column, InListExpr, lit};
    use datafusion::physical_plan::expressions::col;
    use datafusion::physical_plan::joins::join_hash_map::JoinHashMapU32;
    use datafusion::physical_plan::joins::{HashTableLookupExpr, Map, SeededRandomState};
    use datafusion_expr::Operator;

    use super::*;

    /// `DynFilterUpdate::new` stamps the current protocol version and keeps the
    /// caller-supplied fields unchanged.
    #[test]
    fn dyn_filter_update_sets_protocol_version() {
        let update = DynFilterUpdate::new(
            "query-1".to_string(),
            "filter-1".to_string(),
            3,
            false,
            DynFilterPayload::Datafusion(vec![1, 2, 3]),
        );

        assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION);
        assert!(!update.is_complete);
        assert!(
            matches!(update.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![1, 2, 3])
        );
    }

    /// JSON round trip keeps every field. It also pins the wire shape: a
    /// `generation` key (no `epoch` key) and an adjacently tagged, base64 payload.
    #[test]
    fn dyn_filter_update_json_round_trip_preserves_payload_shape() {
        let update = DynFilterUpdate::new(
            "query-2".to_string(),
            "filter-9".to_string(),
            9,
            true,
            DynFilterPayload::Datafusion(vec![9, 8, 7]),
        );

        let json = serde_json::to_string(&update).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();

        assert_eq!(value["generation"], serde_json::json!(9));
        // The counter must be serialized under `generation` only.
        assert!(value.get("epoch").is_none());
        assert_eq!(
            value["payload"],
            serde_json::json!({ "kind": "datafusion", "payload": BASE64_STANDARD.encode([9, 8, 7]) })
        );
        assert_eq!(decoded, update);
        assert!(decoded.is_complete);
        assert!(
            matches!(decoded.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![9, 8, 7])
        );
    }

    /// Payload bytes serialize as standard base64. Lengths 0, 1 and 2 cover the
    /// empty string and both padding variants.
    #[test]
    fn dyn_filter_payload_json_uses_base64_for_empty_and_padded_payloads() {
        let empty = serde_json::to_value(DynFilterPayload::Datafusion(vec![])).unwrap();
        let one = serde_json::to_value(DynFilterPayload::Datafusion(vec![1])).unwrap();
        let two = serde_json::to_value(DynFilterPayload::Datafusion(vec![1, 2])).unwrap();

        assert_eq!(
            empty,
            serde_json::json!({"kind": "datafusion", "payload": ""})
        );
        assert_eq!(
            one,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1])})
        );
        assert_eq!(
            two,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1, 2])})
        );
    }

    /// A non-base64 payload string fails deserialization with a descriptive error.
    #[test]
    fn dyn_filter_payload_json_rejects_invalid_base64() {
        let err = serde_json::from_value::<DynFilterPayload>(serde_json::json!({
            "kind": "datafusion",
            "payload": "not base64!",
        }))
        .unwrap_err();

        assert!(
            err.to_string()
                .contains("invalid base64 dynamic filter payload")
        );
    }

    /// A bare column survives encode + decode with the same name and index.
    #[test]
    fn dyn_filter_payload_round_trips_physical_column_expr() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        let original = expr.as_any().downcast_ref::<Column>().unwrap();
        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();

        assert_eq!(decoded.name(), original.name());
        assert_eq!(decoded.index(), original.index());
    }

    /// Bytes that are not a valid `PhysicalExprNode` map to `DataFusionError::Internal`.
    #[test]
    fn dyn_filter_payload_decode_rejects_invalid_bytes() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Internal(_)));
    }

    /// A column whose name disagrees with the schema field at its index is
    /// rejected. The error names both the payload column and the schema field.
    #[test]
    fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("name/index mismatch"),
            "expected name/index mismatch error, got: {msg}"
        );
        assert!(msg.contains("service"));
        assert!(msg.contains("host"));
    }

    /// A column index past the end of the schema is rejected as a plan error.
    #[test]
    fn dyn_filter_payload_decode_rejects_out_of_bounds_column_index() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        // Index 1 is out of range for a one-field schema.
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 1));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    /// `bound AND hash_lookup` encodes with the hash lookup stripped. The
    /// `>=` bound survives the round trip.
    #[test]
    fn dyn_filter_payload_hash_lookup_fallback_preserves_bounds() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "device_id",
            DataType::Int32,
            false,
        )]));
        let device_id = col("device_id", &schema).unwrap();
        let lower_bound = Arc::new(BinaryExpr::new(
            Arc::clone(&device_id),
            Operator::GtEq,
            lit(10i32),
        )) as Arc<dyn PhysicalExpr>;
        // An empty hash table is enough: only the node type matters for stripping.
        let lookup = Arc::new(HashTableLookupExpr::new(
            vec![Arc::clone(&device_id)],
            SeededRandomState::with_seeds(0, 0, 0, 0),
            Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(0)))),
            "hash_lookup".to_string(),
        )) as Arc<dyn PhysicalExpr>;
        let expr =
            Arc::new(BinaryExpr::new(lower_bound, Operator::And, lookup)) as Arc<dyn PhysicalExpr>;

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        assert!(!contains_expr::<HashTableLookupExpr>(&decoded));
        let decoded_display = decoded.to_string();
        assert!(decoded_display.contains("device_id"));
        assert!(decoded_display.contains(">="));
        assert!(!decoded_display.contains("hash_lookup"));
    }

    /// With the limit set to exactly the bounds-only encoding size, the full
    /// expression (bounds AND a 64-element IN list) is too large. Encoding must
    /// then fall back to the bounds-only form.
    #[test]
    fn dyn_filter_payload_oversized_inlist_falls_back_to_bounds() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "device_id",
            DataType::Int32,
            false,
        )]));
        let device_id = col("device_id", &schema).unwrap();
        let lower_bound = Arc::new(BinaryExpr::new(
            Arc::clone(&device_id),
            Operator::GtEq,
            lit(8192i32),
        )) as Arc<dyn PhysicalExpr>;
        let upper_bound = Arc::new(BinaryExpr::new(
            Arc::clone(&device_id),
            Operator::LtEq,
            lit(8255i32),
        )) as Arc<dyn PhysicalExpr>;
        let bounds = Arc::new(BinaryExpr::new(lower_bound, Operator::And, upper_bound))
            as Arc<dyn PhysicalExpr>;
        let in_list = Arc::new(
            InListExpr::try_new(
                Arc::clone(&device_id),
                (8192..8256).map(lit).collect(),
                false,
                &schema,
            )
            .unwrap(),
        ) as Arc<dyn PhysicalExpr>;
        let expr = Arc::new(BinaryExpr::new(Arc::clone(&bounds), Operator::And, in_list))
            as Arc<dyn PhysicalExpr>;
        // Measure both encodings so the limit lands between them.
        let bounds_only = portable_remote_dyn_filter_expr(Arc::clone(&expr), true).unwrap();
        let bounds_only_size = encode_physical_expr_to_bytes(&bounds_only).unwrap().len();
        let full_size = encode_physical_expr_to_bytes(&expr).unwrap().len();
        assert!(full_size > bounds_only_size);

        let payload = DynFilterPayload::from_datafusion_expr(&expr, bounds_only_size).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, bounds_only_size)
            .unwrap();

        assert!(!contains_expr::<InListExpr>(&decoded));
        let decoded_display = decoded.to_string();
        assert!(decoded_display.contains("device_id"));
        assert!(decoded_display.contains(">="));
        assert!(decoded_display.contains("<="));
    }

    /// When even the bounds-only form exceeds the limit, the error surfaces as
    /// `DataFusionError::External` wrapping `DynFilterPayloadTooLarge`.
    #[test]
    fn dyn_filter_payload_rejects_oversized_payload() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err();

        let DataFusionError::External(error) = err else {
            panic!("expected external common query error, got: {err:?}");
        };
        assert!(matches!(
            error.downcast_ref::<CommonQueryError>(),
            Some(CommonQueryError::DynFilterPayloadTooLarge { .. })
        ));
    }

    /// Returns `true` if any node of `expr`, including the root, has concrete type `T`.
    fn contains_expr<T: 'static>(expr: &Arc<dyn PhysicalExpr>) -> bool {
        let mut found = false;
        expr.apply(|node| {
            if node.as_any().is::<T>() {
                found = true;
                Ok(TreeNodeRecursion::Stop)
            } else {
                Ok(TreeNodeRecursion::Continue)
            }
        })
        .unwrap();
        found
    }
}