1mod base64_serde;
16mod initial_remote_dyn_filter_reg;
17
18use std::sync::Arc;
19
20use api::v1::region::RegionRequestHeader;
21use datafusion::execution::TaskContext;
22use datafusion::physical_expr::expressions::Column;
23use datafusion::physical_plan::PhysicalExpr;
24use datafusion::physical_plan::joins::HashTableLookupExpr;
25use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
26use datafusion_common::{DataFusionError, Result as DataFusionResult};
27use datafusion_expr::LogicalPlan;
28use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
29use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
30use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
31use datafusion_proto::protobuf::PhysicalExprNode;
32use prost::Message;
33use serde::{Deserialize, Serialize};
34use store_api::storage::RegionId;
35
36pub use self::initial_remote_dyn_filter_reg::{
38 INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY,
39 INITIAL_REMOTE_DYN_FILTER_REGS_MAX_TOTAL_PROTO_BYTES, InitialDynFilterReg,
40 InitialDynFilterRegs, InitialDynFilterSnapshot,
41};
42
/// Version of the dynamic-filter wire protocol.
///
/// [`DynFilterUpdate::new`] stamps this value into every update it builds.
pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;
44
/// Serialized body of a dynamic filter update.
///
/// Uses serde adjacent tagging: a payload serializes as
/// `{"kind": "<variant>", "payload": <content>}`, with the variant name rendered in
/// snake_case (the `Datafusion` variant becomes `"datafusion"`).
///
/// The enum is `#[non_exhaustive]` so further payload encodings can be added later.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
pub enum DynFilterPayload {
    /// Protobuf-encoded DataFusion `PhysicalExprNode` bytes.
    ///
    /// In serde formats the bytes go through the `base64_serde::bytes` helper, so in JSON
    /// they appear as a standard base64 string.
    Datafusion(#[serde(with = "base64_serde::bytes")] Vec<u8>),
}
62
63impl DynFilterPayload {
64 pub fn from_datafusion_expr(
69 expr: &Arc<dyn PhysicalExpr>,
70 max_payload_bytes: usize,
71 ) -> DataFusionResult<Self> {
72 validate_supported_payload_expr(expr)?;
73
74 let codec = DefaultPhysicalExtensionCodec {};
75 let proto = serialize_physical_expr(expr, &codec)?;
76 let mut bytes = Vec::new();
77 proto.encode(&mut bytes).map_err(|e| {
78 DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
79 })?;
80
81 validate_payload_size(bytes.len(), max_payload_bytes)?;
82
83 Ok(Self::Datafusion(bytes))
84 }
85
86 pub fn decode_datafusion_expr(
92 &self,
93 task_ctx: &TaskContext,
94 input_schema: &datafusion::arrow::datatypes::Schema,
95 max_payload_bytes: usize,
96 ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
97 let Self::Datafusion(bytes) = self;
98 validate_payload_size(bytes.len(), max_payload_bytes)?;
99 let codec = DefaultPhysicalExtensionCodec {};
100 let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| {
101 DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
102 })?;
103
104 let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
105 validate_supported_payload_expr(&expr)?;
106 validate_decoded_payload_expr(&expr, input_schema)?;
107 Ok(expr)
108 }
109}
110
111fn encode_physical_expr_to_bytes(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<Vec<u8>> {
112 let codec = DefaultPhysicalExtensionCodec {};
113 let proto = serialize_physical_expr(expr, &codec)?;
114 let mut bytes = Vec::new();
115 proto.encode(&mut bytes).map_err(|e| {
116 DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
117 })?;
118 Ok(bytes)
119}
120
121pub(crate) fn decode_physical_expr_from_bytes(
122 bytes: &[u8],
123 task_ctx: &TaskContext,
124 input_schema: &datafusion::arrow::datatypes::Schema,
125 max_payload_bytes: usize,
126) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
127 validate_payload_size(bytes.len(), max_payload_bytes)?;
128 let codec = DefaultPhysicalExtensionCodec {};
129 let proto = PhysicalExprNode::decode(bytes).map_err(|e| {
130 DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
131 })?;
132
133 let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
134 validate_supported_payload_expr(&expr)?;
135 validate_decoded_payload_expr(&expr, input_schema)?;
136 Ok(expr)
137}
138
139fn validate_payload_size(
140 payload_size_bytes: usize,
141 max_payload_bytes: usize,
142) -> DataFusionResult<()> {
143 if payload_size_bytes > max_payload_bytes {
144 return Err(DataFusionError::Plan(format!(
145 "DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes",
146 payload_size_bytes, max_payload_bytes
147 )));
148 }
149
150 Ok(())
151}
152
153fn validate_supported_payload_expr(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<()> {
154 expr.apply(|node| {
155 if node.as_any().is::<HashTableLookupExpr>() {
156 return Err(DataFusionError::Plan(
157 "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion"
158 .to_string(),
159 ));
160 }
161
162 Ok(TreeNodeRecursion::Continue)
163 })?;
164
165 Ok(())
166}
167
168fn validate_decoded_payload_expr(
175 expr: &Arc<dyn PhysicalExpr>,
176 input_schema: &datafusion::arrow::datatypes::Schema,
177) -> DataFusionResult<()> {
178 expr.apply(|node| {
179 if let Some(column) = node.as_any().downcast_ref::<Column>() {
180 let Some(field) = input_schema.fields().get(column.index()) else {
181 return Err(DataFusionError::Plan(format!(
182 "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
183 column.name(),
184 column.index(),
185 input_schema.fields().len()
186 )));
187 };
188
189 if field.name() != column.name() {
190 return Err(DataFusionError::Plan(format!(
191 "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'",
192 column.name(),
193 column.index(),
194 field.name()
195 )));
196 }
197 }
198
199 Ok(TreeNodeRecursion::Continue)
200 })?;
201
202 Ok(())
203}
204
/// A versioned dynamic-filter update, addressed by query ID and filter ID.
///
/// Derives `Serialize`/`Deserialize`, so field names and types are part of the wire format.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DynFilterUpdate {
    /// Wire protocol version; [`DynFilterUpdate::new`] sets it to
    /// [`DYN_FILTER_PROTOCOL_VERSION`].
    pub protocol_version: u32,
    /// Identifier of the query this update belongs to.
    pub query_id: String,
    /// Identifier of the dynamic filter within the query.
    pub filter_id: String,
    /// Generation number of this update.
    ///
    /// NOTE(review): presumably increases per filter so receivers can discard stale
    /// updates — confirm against the producer.
    pub generation: u64,
    /// Whether this update is flagged as complete.
    ///
    /// NOTE(review): presumably means no further generations will follow for this
    /// filter — confirm.
    pub is_complete: bool,
    /// The serialized filter expression.
    pub payload: DynFilterPayload,
}
225
226impl DynFilterUpdate {
227 pub fn new(
229 query_id: String,
230 filter_id: String,
231 generation: u64,
232 is_complete: bool,
233 payload: DynFilterPayload,
234 ) -> Self {
235 Self {
236 protocol_version: DYN_FILTER_PROTOCOL_VERSION,
237 query_id,
238 filter_id,
239 generation,
240 is_complete,
241 payload,
242 }
243 }
244}
245
/// A query request aimed at a single region.
///
/// It carries an optional request header, the target region ID, and a DataFusion logical
/// plan.
#[derive(Clone, Debug)]
pub struct QueryRequest {
    /// Optional region request header.
    ///
    /// NOTE(review): presumably carries request-scoped context such as tracing or tenant
    /// info — confirm against `RegionRequestHeader`.
    pub header: Option<RegionRequestHeader>,

    /// The region this request targets.
    pub region_id: RegionId,

    /// The DataFusion logical plan carried by this request.
    pub plan: LogicalPlan,
}
258
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::Column;

    use super::*;

    // `new` must stamp the current protocol version and keep the other fields as given.
    #[test]
    fn dyn_filter_update_sets_protocol_version() {
        let update = DynFilterUpdate::new(
            "query-1".to_string(),
            "filter-1".to_string(),
            3,
            false,
            DynFilterPayload::Datafusion(vec![1, 2, 3]),
        );

        assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION);
        assert!(!update.is_complete);
        assert!(
            matches!(update.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![1, 2, 3])
        );
    }

    // Pins the JSON wire shape: adjacent `kind`/`payload` tagging with base64 bytes.
    #[test]
    fn dyn_filter_update_json_round_trip_preserves_payload_shape() {
        let update = DynFilterUpdate::new(
            "query-2".to_string(),
            "filter-9".to_string(),
            9,
            true,
            DynFilterPayload::Datafusion(vec![9, 8, 7]),
        );

        let json = serde_json::to_string(&update).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();

        assert_eq!(value["generation"], serde_json::json!(9));
        assert!(value.get("epoch").is_none());
        assert_eq!(
            value["payload"],
            serde_json::json!({ "kind": "datafusion", "payload": BASE64_STANDARD.encode([9, 8, 7]) })
        );
        assert_eq!(decoded, update);
        assert!(decoded.is_complete);
        assert!(
            matches!(decoded.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![9, 8, 7])
        );
    }

    // Empty and padding-requiring lengths must use standard padded base64.
    #[test]
    fn dyn_filter_payload_json_uses_base64_for_empty_and_padded_payloads() {
        let empty = serde_json::to_value(DynFilterPayload::Datafusion(vec![])).unwrap();
        let one = serde_json::to_value(DynFilterPayload::Datafusion(vec![1])).unwrap();
        let two = serde_json::to_value(DynFilterPayload::Datafusion(vec![1, 2])).unwrap();

        assert_eq!(
            empty,
            serde_json::json!({"kind": "datafusion", "payload": ""})
        );
        assert_eq!(
            one,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1])})
        );
        assert_eq!(
            two,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1, 2])})
        );
    }

    #[test]
    fn dyn_filter_payload_json_rejects_invalid_base64() {
        let err = serde_json::from_value::<DynFilterPayload>(serde_json::json!({
            "kind": "datafusion",
            "payload": "not base64!",
        }))
        .unwrap_err();

        assert!(
            err.to_string()
                .contains("invalid base64 dynamic filter payload")
        );
    }

    #[test]
    fn dyn_filter_payload_round_trips_physical_column_expr() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        let original = expr.as_any().downcast_ref::<Column>().unwrap();
        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();

        assert_eq!(decoded.name(), original.name());
        assert_eq!(decoded.index(), original.index());
    }

    // Garbage bytes surface as a protobuf decode failure (Internal), not a panic.
    #[test]
    fn dyn_filter_payload_decode_rejects_invalid_bytes() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Internal(_)));
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("name/index mismatch"),
            "expected name/index mismatch error, got: {msg}"
        );
        assert!(msg.contains("service"));
        assert!(msg.contains("host"));
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_out_of_bounds_column_index() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 1));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    #[test]
    fn dyn_filter_payload_rejects_oversized_payload() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    // The receiver-side limit is enforced independently of the sender-side one.
    #[test]
    fn dyn_filter_payload_decode_rejects_oversized_payload() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    // The crate-internal byte helpers must agree with the public `DynFilterPayload` API.
    #[test]
    fn physical_expr_byte_helpers_round_trip_and_match_payload_encoding() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());

        let bytes = encode_physical_expr_to_bytes(&expr).unwrap();
        let decoded =
            decode_physical_expr_from_bytes(&bytes, &TaskContext::default(), &schema, 1024)
                .unwrap();

        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();
        assert_eq!(decoded.name(), "host");
        assert_eq!(decoded.index(), 0);

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        assert_eq!(payload, DynFilterPayload::Datafusion(bytes));
    }
}