Skip to main content

common_query/
request.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use api::v1::region::RegionRequestHeader;
18use datafusion::arrow::datatypes::Schema;
19use datafusion::execution::TaskContext;
20use datafusion::physical_expr::expressions::Column;
21use datafusion::physical_plan::PhysicalExpr;
22use datafusion::physical_plan::joins::HashTableLookupExpr;
23use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
24use datafusion_common::{DataFusionError, Result as DataFusionResult};
25use datafusion_expr::LogicalPlan;
26use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
27use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
28use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
29use datafusion_proto::protobuf::PhysicalExprNode;
30use prost::Message;
31use serde::{Deserialize, Serialize};
32use store_api::storage::RegionId;
33
/// Current wire-format version for remote dynamic filter payload updates.
///
/// [`DynFilterUpdate::new`] stamps this value into the `protocol_version`
/// field of every update it builds.
pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;
36
/// Serialized predicate payload for remote dynamic filter updates.
///
/// The payload is adjacently tagged in JSON: `kind` names the encoding and
/// `payload` carries the encoded data, so receivers can reject unsupported
/// encodings before decoding engine-specific bytes. For DataFusion expressions
/// the `payload` bytes are written as a standard base64 string (see the
/// `base64_bytes` helper module). For example, the bytes `[9, 8, 7]` become:
///
/// ```json
/// { "kind": "datafusion", "payload": "CQgH" }
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
pub enum DynFilterPayload {
    /// A serialized DataFusion [`PhysicalExpr`] encoded as a protobuf
    /// [`PhysicalExprNode`].
    ///
    /// The wrapped bytes are the raw protobuf encoding; base64 is applied only
    /// when the value goes through serde.
    Datafusion(#[serde(with = "base64_bytes")] Vec<u8>),
}
54
55mod base64_bytes {
56    use base64::Engine;
57    use base64::prelude::BASE64_STANDARD;
58    use serde::de::Error;
59    use serde::{Deserialize, Deserializer, Serializer};
60
61    pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
62    where
63        S: Serializer,
64    {
65        serializer.serialize_str(&BASE64_STANDARD.encode(bytes))
66    }
67
68    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
69    where
70        D: Deserializer<'de>,
71    {
72        let encoded = String::deserialize(deserializer)?;
73        BASE64_STANDARD.decode(encoded).map_err(|err| {
74            D::Error::custom(format!("invalid base64 dynamic filter payload: {err}"))
75        })
76    }
77}
78
79impl DynFilterPayload {
80    /// Encodes a DataFusion physical expression into a bounded dynamic filter payload.
81    ///
82    /// This rejects expressions that cannot be safely shipped as dynamic filter
83    /// predicates and fails if the serialized payload exceeds `max_payload_bytes`.
84    pub fn from_datafusion_expr(
85        expr: &Arc<dyn PhysicalExpr>,
86        max_payload_bytes: usize,
87    ) -> DataFusionResult<Self> {
88        validate_supported_payload_expr(expr)?;
89
90        let codec = DefaultPhysicalExtensionCodec {};
91        let proto = serialize_physical_expr(expr, &codec)?;
92        let mut bytes = Vec::new();
93        proto.encode(&mut bytes).map_err(|e| {
94            DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
95        })?;
96
97        validate_payload_size(bytes.len(), max_payload_bytes)?;
98
99        Ok(Self::Datafusion(bytes))
100    }
101
102    /// Decodes a DataFusion dynamic filter payload against the provided input schema.
103    ///
104    /// The decoded expression is validated to ensure column indexes stay within the receiver-side
105    /// schema, column names are consistent (defensive check), and the payload stays within
106    /// `max_payload_bytes`.
107    pub fn decode_datafusion_expr(
108        &self,
109        task_ctx: &TaskContext,
110        input_schema: &Schema,
111        max_payload_bytes: usize,
112    ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
113        let Self::Datafusion(bytes) = self;
114        validate_payload_size(bytes.len(), max_payload_bytes)?;
115        let codec = DefaultPhysicalExtensionCodec {};
116        let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| {
117            DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
118        })?;
119
120        let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
121        validate_supported_payload_expr(&expr)?;
122        validate_decoded_payload_expr(&expr, input_schema)?;
123        Ok(expr)
124    }
125}
126
127fn validate_payload_size(
128    payload_size_bytes: usize,
129    max_payload_bytes: usize,
130) -> DataFusionResult<()> {
131    if payload_size_bytes > max_payload_bytes {
132        return Err(DataFusionError::Plan(format!(
133            "DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes",
134            payload_size_bytes, max_payload_bytes
135        )));
136    }
137
138    Ok(())
139}
140
141fn validate_supported_payload_expr(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<()> {
142    expr.apply(|node| {
143        if node.as_any().is::<HashTableLookupExpr>() {
144            return Err(DataFusionError::Plan(
145                "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion"
146                    .to_string(),
147            ));
148        }
149
150        Ok(TreeNodeRecursion::Continue)
151    })?;
152
153    Ok(())
154}
155
156/// Validates decoded dynamic filter physical expressions against the receiver-side schema.
157///
158/// Rejects out-of-bounds column indexes and, as a defensive check, rejects columns whose
159/// name disagrees with the corresponding schema field. DataFusion physical `Column` is
160/// index-authoritative, so a name mismatch usually indicates a coordinator/receiver
161/// schema inconsistency that should be surfaced loudly.
162fn validate_decoded_payload_expr(
163    expr: &Arc<dyn PhysicalExpr>,
164    input_schema: &Schema,
165) -> DataFusionResult<()> {
166    expr.apply(|node| {
167        if let Some(column) = node.as_any().downcast_ref::<Column>() {
168            let Some(field) = input_schema.fields().get(column.index()) else {
169                return Err(DataFusionError::Plan(format!(
170                    "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
171                    column.name(),
172                    column.index(),
173                    input_schema.fields().len()
174                )));
175            };
176
177            if field.name() != column.name() {
178                return Err(DataFusionError::Plan(format!(
179                    "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'",
180                    column.name(),
181                    column.index(),
182                    field.name()
183                )));
184            }
185        }
186
187        Ok(TreeNodeRecursion::Continue)
188    })?;
189
190    Ok(())
191}
192
/// A remote dynamic filter update sent from a query coordinator to region servers.
///
/// `generation` is monotonic within a `query_id`/`filter_id` pair and matches the
/// gRPC field name used by `RemoteDynFilterUpdate`. Receivers use it to ignore
/// stale updates, while `is_complete` marks the final payload for the filter.
///
/// The struct derives `Serialize`/`Deserialize`, so the field names below are
/// also the keys of its serialized form.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DynFilterUpdate {
    /// Protocol version used by this update payload.
    pub protocol_version: u32,
    /// Internal query identifier that owns this dynamic filter lifecycle.
    pub query_id: String,
    /// Identifier of the dynamic filter within the query.
    pub filter_id: String,
    /// Monotonic update generation for this filter.
    pub generation: u64,
    /// Whether this update completes the dynamic filter stream.
    pub is_complete: bool,
    /// Serialized predicate payload carried by this update.
    pub payload: DynFilterPayload,
}
213
214impl DynFilterUpdate {
215    /// Creates a dynamic filter update with the current protocol version.
216    pub fn new(
217        query_id: String,
218        filter_id: String,
219        generation: u64,
220        is_complete: bool,
221        payload: DynFilterPayload,
222    ) -> Self {
223        Self {
224            protocol_version: DYN_FILTER_PROTOCOL_VERSION,
225            query_id,
226            filter_id,
227            generation,
228            is_complete,
229            payload,
230        }
231    }
232}
233
/// The query request to be handled by the RegionServer (Datanode).
#[derive(Clone, Debug)]
pub struct QueryRequest {
    /// The header of this request, often used to carry context for the query.
    /// `None` means every header setting falls back to its default.
    pub header: Option<RegionRequestHeader>,

    /// The id of the region to be queried.
    pub region_id: RegionId,

    /// The query itself, expressed as a DataFusion logical plan.
    pub plan: LogicalPlan,
}
246
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::Column;

    use super::*;

    /// `DynFilterUpdate::new` stamps the current protocol version and keeps the
    /// caller-provided fields untouched.
    #[test]
    fn dyn_filter_update_sets_protocol_version() {
        let update = DynFilterUpdate::new(
            "query-1".to_string(),
            "filter-1".to_string(),
            3,
            false,
            DynFilterPayload::Datafusion(vec![1, 2, 3]),
        );

        assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION);
        assert!(!update.is_complete);
        assert_eq!(update.payload, DynFilterPayload::Datafusion(vec![1, 2, 3]));
    }

    /// The JSON form uses the `generation` key (not `epoch`) and the tagged
    /// base64 payload shape, and round-trips losslessly.
    #[test]
    fn dyn_filter_update_json_round_trip_preserves_payload_shape() {
        let update = DynFilterUpdate::new(
            "query-2".to_string(),
            "filter-9".to_string(),
            9,
            true,
            DynFilterPayload::Datafusion(vec![9, 8, 7]),
        );

        let json = serde_json::to_string(&update).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();

        assert_eq!(value["generation"], serde_json::json!(9));
        assert!(value.get("epoch").is_none());
        assert_eq!(
            value["payload"],
            serde_json::json!({ "kind": "datafusion", "payload": BASE64_STANDARD.encode([9, 8, 7]) })
        );
        assert_eq!(decoded, update);
        assert!(decoded.is_complete);
        assert_eq!(decoded.payload, DynFilterPayload::Datafusion(vec![9, 8, 7]));
    }

    /// Empty payloads and payloads whose base64 form needs padding all serialize
    /// as base64 strings.
    #[test]
    fn dyn_filter_payload_json_uses_base64_for_empty_and_padded_payloads() {
        let empty = serde_json::to_value(DynFilterPayload::Datafusion(vec![])).unwrap();
        let one = serde_json::to_value(DynFilterPayload::Datafusion(vec![1])).unwrap();
        let two = serde_json::to_value(DynFilterPayload::Datafusion(vec![1, 2])).unwrap();

        assert_eq!(
            empty,
            serde_json::json!({"kind": "datafusion", "payload": ""})
        );
        assert_eq!(
            one,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1])})
        );
        assert_eq!(
            two,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1, 2])})
        );
    }

    /// A payload string that is not valid base64 is rejected during
    /// deserialization with a descriptive error.
    #[test]
    fn dyn_filter_payload_json_rejects_invalid_base64() {
        let err = serde_json::from_value::<DynFilterPayload>(serde_json::json!({
            "kind": "datafusion",
            "payload": "not base64!",
        }))
        .unwrap_err();

        assert!(
            err.to_string()
                .contains("invalid base64 dynamic filter payload")
        );
    }

    /// A plain column expression survives an encode/decode round trip.
    #[test]
    fn dyn_filter_payload_round_trips_physical_column_expr() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        let original = expr.as_any().downcast_ref::<Column>().unwrap();
        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();

        assert_eq!(decoded.name(), original.name());
        assert_eq!(decoded.index(), original.index());
    }

    /// Bytes that are not a valid `PhysicalExprNode` surface as an `Internal`
    /// error.
    #[test]
    fn dyn_filter_payload_decode_rejects_invalid_bytes() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Internal(_)));
    }

    /// A column whose name disagrees with the schema field at its index is
    /// rejected, and the error names both sides.
    #[test]
    fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("name/index mismatch"),
            "expected name/index mismatch error, got: {msg}"
        );
        assert!(msg.contains("service"));
        assert!(msg.contains("host"));
    }

    /// A column index past the end of the schema is rejected by the
    /// out-of-bounds branch specifically; the message check keeps this test from
    /// passing on some other `Plan` error.
    #[test]
    fn dyn_filter_payload_decode_rejects_out_of_bounds_column_index() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 1));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        let msg = err.to_string();
        assert!(matches!(err, DataFusionError::Plan(_)));
        assert!(
            msg.contains("out-of-bounds index 1"),
            "expected out-of-bounds error, got: {msg}"
        );
    }

    /// Encoding fails when the serialized expression exceeds the limit.
    #[test]
    fn dyn_filter_payload_rejects_oversized_payload() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err();

        let msg = err.to_string();
        assert!(matches!(err, DataFusionError::Plan(_)));
        assert!(
            msg.contains("exceeds the configured limit of 1 bytes"),
            "expected size limit error, got: {msg}"
        );
    }

    /// Decoding enforces the receiver's own size limit, independent of the
    /// limit the sender encoded with.
    #[test]
    fn dyn_filter_payload_decode_rejects_oversized_payload() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1)
            .unwrap_err();

        let msg = err.to_string();
        assert!(matches!(err, DataFusionError::Plan(_)));
        assert!(
            msg.contains("exceeds the configured limit of 1 bytes"),
            "expected size limit error, got: {msg}"
        );
    }
}