1use std::sync::Arc;
16
17use api::v1::region::RegionRequestHeader;
18use datafusion::arrow::datatypes::Schema;
19use datafusion::execution::TaskContext;
20use datafusion::physical_expr::expressions::Column;
21use datafusion::physical_plan::PhysicalExpr;
22use datafusion::physical_plan::joins::HashTableLookupExpr;
23use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
24use datafusion_common::{DataFusionError, Result as DataFusionResult};
25use datafusion_expr::LogicalPlan;
26use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
27use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
28use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
29use datafusion_proto::protobuf::PhysicalExprNode;
30use prost::Message;
31use serde::{Deserialize, Serialize};
32use store_api::storage::RegionId;
33
/// Protocol version stamped into every [`DynFilterUpdate`] built via [`DynFilterUpdate::new`].
pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;
36
/// Encoded body of a dynamic-filter update.
///
/// Serialized with serde as an adjacently tagged object, e.g.
/// `{"kind": "datafusion", "payload": "<base64>"}`. Marked `#[non_exhaustive]` so
/// code in other crates must tolerate variants added later.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
pub enum DynFilterPayload {
    /// Bytes expected to be a prost-encoded `datafusion_proto` `PhysicalExprNode`, as produced
    /// by [`DynFilterPayload::from_datafusion_expr`]; serialized through `base64_bytes` as a
    /// base64 string. The bytes are only validated when decoded.
    Datafusion(#[serde(with = "base64_bytes")] Vec<u8>),
}
54
55mod base64_bytes {
56 use base64::Engine;
57 use base64::prelude::BASE64_STANDARD;
58 use serde::de::Error;
59 use serde::{Deserialize, Deserializer, Serializer};
60
61 pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
62 where
63 S: Serializer,
64 {
65 serializer.serialize_str(&BASE64_STANDARD.encode(bytes))
66 }
67
68 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
69 where
70 D: Deserializer<'de>,
71 {
72 let encoded = String::deserialize(deserializer)?;
73 BASE64_STANDARD.decode(encoded).map_err(|err| {
74 D::Error::custom(format!("invalid base64 dynamic filter payload: {err}"))
75 })
76 }
77}
78
79impl DynFilterPayload {
80 pub fn from_datafusion_expr(
85 expr: &Arc<dyn PhysicalExpr>,
86 max_payload_bytes: usize,
87 ) -> DataFusionResult<Self> {
88 validate_supported_payload_expr(expr)?;
89
90 let codec = DefaultPhysicalExtensionCodec {};
91 let proto = serialize_physical_expr(expr, &codec)?;
92 let mut bytes = Vec::new();
93 proto.encode(&mut bytes).map_err(|e| {
94 DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
95 })?;
96
97 validate_payload_size(bytes.len(), max_payload_bytes)?;
98
99 Ok(Self::Datafusion(bytes))
100 }
101
102 pub fn decode_datafusion_expr(
108 &self,
109 task_ctx: &TaskContext,
110 input_schema: &Schema,
111 max_payload_bytes: usize,
112 ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
113 let Self::Datafusion(bytes) = self;
114 validate_payload_size(bytes.len(), max_payload_bytes)?;
115 let codec = DefaultPhysicalExtensionCodec {};
116 let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| {
117 DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
118 })?;
119
120 let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
121 validate_supported_payload_expr(&expr)?;
122 validate_decoded_payload_expr(&expr, input_schema)?;
123 Ok(expr)
124 }
125}
126
127fn validate_payload_size(
128 payload_size_bytes: usize,
129 max_payload_bytes: usize,
130) -> DataFusionResult<()> {
131 if payload_size_bytes > max_payload_bytes {
132 return Err(DataFusionError::Plan(format!(
133 "DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes",
134 payload_size_bytes, max_payload_bytes
135 )));
136 }
137
138 Ok(())
139}
140
141fn validate_supported_payload_expr(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<()> {
142 expr.apply(|node| {
143 if node.as_any().is::<HashTableLookupExpr>() {
144 return Err(DataFusionError::Plan(
145 "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion"
146 .to_string(),
147 ));
148 }
149
150 Ok(TreeNodeRecursion::Continue)
151 })?;
152
153 Ok(())
154}
155
156fn validate_decoded_payload_expr(
163 expr: &Arc<dyn PhysicalExpr>,
164 input_schema: &Schema,
165) -> DataFusionResult<()> {
166 expr.apply(|node| {
167 if let Some(column) = node.as_any().downcast_ref::<Column>() {
168 let Some(field) = input_schema.fields().get(column.index()) else {
169 return Err(DataFusionError::Plan(format!(
170 "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
171 column.name(),
172 column.index(),
173 input_schema.fields().len()
174 )));
175 };
176
177 if field.name() != column.name() {
178 return Err(DataFusionError::Plan(format!(
179 "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'",
180 column.name(),
181 column.index(),
182 field.name()
183 )));
184 }
185 }
186
187 Ok(TreeNodeRecursion::Continue)
188 })?;
189
190 Ok(())
191}
192
/// A single dynamic-filter update message.
///
/// Serializable with serde; in JSON the `payload` field appears as a tagged object whose
/// bytes are base64-encoded.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DynFilterUpdate {
    /// Wire protocol version; [`DynFilterUpdate::new`] sets it to [`DYN_FILTER_PROTOCOL_VERSION`].
    pub protocol_version: u32,
    /// Query identifier (treated as an opaque string by this module).
    pub query_id: String,
    /// Filter identifier (treated as an opaque string by this module).
    pub filter_id: String,
    /// Generation number of this update.
    /// NOTE(review): presumably increases per filter so stale updates can be ignored — confirm.
    pub generation: u64,
    /// Completion flag carried with the update.
    /// NOTE(review): presumably means no further updates follow for this filter — confirm.
    pub is_complete: bool,
    /// Encoded filter expression.
    pub payload: DynFilterPayload,
}
213
214impl DynFilterUpdate {
215 pub fn new(
217 query_id: String,
218 filter_id: String,
219 generation: u64,
220 is_complete: bool,
221 payload: DynFilterPayload,
222 ) -> Self {
223 Self {
224 protocol_version: DYN_FILTER_PROTOCOL_VERSION,
225 query_id,
226 filter_id,
227 generation,
228 is_complete,
229 payload,
230 }
231 }
232}
233
/// A query request addressed to a single region.
#[derive(Clone, Debug)]
pub struct QueryRequest {
    /// Optional region request header.
    /// NOTE(review): presumably carries request context (tracing, tenant, etc.) — confirm.
    pub header: Option<RegionRequestHeader>,

    /// The region this request targets.
    pub region_id: RegionId,

    /// The DataFusion logical plan carried by this request.
    pub plan: LogicalPlan,
}
246
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::Column;

    use super::*;

    #[test]
    fn dyn_filter_update_sets_protocol_version() {
        let update = DynFilterUpdate::new(
            "query-1".to_string(),
            "filter-1".to_string(),
            3,
            false,
            DynFilterPayload::Datafusion(vec![1, 2, 3]),
        );

        assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION);
        assert!(!update.is_complete);
        assert!(
            matches!(update.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![1, 2, 3])
        );
    }

    #[test]
    fn dyn_filter_update_json_round_trip_preserves_payload_shape() {
        let update = DynFilterUpdate::new(
            "query-2".to_string(),
            "filter-9".to_string(),
            9,
            true,
            DynFilterPayload::Datafusion(vec![9, 8, 7]),
        );

        let json = serde_json::to_string(&update).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();

        assert_eq!(value["generation"], serde_json::json!(9));
        assert!(value.get("epoch").is_none());
        assert_eq!(
            value["payload"],
            serde_json::json!({ "kind": "datafusion", "payload": BASE64_STANDARD.encode([9, 8, 7]) })
        );
        assert_eq!(decoded, update);
        assert!(decoded.is_complete);
        assert!(
            matches!(decoded.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![9, 8, 7])
        );
    }

    #[test]
    fn dyn_filter_payload_json_uses_base64_for_empty_and_padded_payloads() {
        let empty = serde_json::to_value(DynFilterPayload::Datafusion(vec![])).unwrap();
        let one = serde_json::to_value(DynFilterPayload::Datafusion(vec![1])).unwrap();
        let two = serde_json::to_value(DynFilterPayload::Datafusion(vec![1, 2])).unwrap();

        assert_eq!(
            empty,
            serde_json::json!({"kind": "datafusion", "payload": ""})
        );
        assert_eq!(
            one,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1])})
        );
        assert_eq!(
            two,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1, 2])})
        );
    }

    #[test]
    fn dyn_filter_payload_json_rejects_invalid_base64() {
        let err = serde_json::from_value::<DynFilterPayload>(serde_json::json!({
            "kind": "datafusion",
            "payload": "not base64!",
        }))
        .unwrap_err();

        assert!(
            err.to_string()
                .contains("invalid base64 dynamic filter payload")
        );
    }

    #[test]
    fn dyn_filter_payload_round_trips_physical_column_expr() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        let original = expr.as_any().downcast_ref::<Column>().unwrap();
        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();

        assert_eq!(decoded.name(), original.name());
        assert_eq!(decoded.index(), original.index());
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_invalid_bytes() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Internal(_)));
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("name/index mismatch"),
            "expected name/index mismatch error, got: {msg}"
        );
        assert!(msg.contains("service"));
        assert!(msg.contains("host"));
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_out_of_bounds_column_index() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 1));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    #[test]
    fn dyn_filter_payload_rejects_oversized_payload() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    /// The limit is inclusive: a limit equal to the encoded size succeeds and one byte less
    /// fails. This also pins that the pre-encoding size estimate matches the real encoding.
    #[test]
    fn dyn_filter_payload_size_limit_is_inclusive() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let DynFilterPayload::Datafusion(bytes) =
            DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let exact = bytes.len();
        assert!(exact > 0);

        let at_limit = DynFilterPayload::from_datafusion_expr(&expr, exact).unwrap();
        assert_eq!(at_limit, DynFilterPayload::Datafusion(bytes));

        let err = DynFilterPayload::from_datafusion_expr(&expr, exact - 1).unwrap_err();
        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    /// The decode path enforces the size limit before parsing. The bytes here are not a
    /// valid `PhysicalExprNode`, so a `Plan` error (not `Internal`) proves the check ran first.
    #[test]
    fn dyn_filter_payload_decode_rejects_oversized_payload() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let payload = DynFilterPayload::Datafusion(vec![0; 16]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 8)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }
}