Skip to main content

common_query/
request.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod base64_serde;
16mod initial_remote_dyn_filter_reg;
17
18use std::sync::Arc;
19
20use api::v1::region::RegionRequestHeader;
21use datafusion::execution::TaskContext;
22use datafusion::physical_expr::expressions::Column;
23use datafusion::physical_plan::PhysicalExpr;
24use datafusion::physical_plan::joins::HashTableLookupExpr;
25use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
26use datafusion_common::{DataFusionError, Result as DataFusionResult};
27use datafusion_expr::LogicalPlan;
28use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
29use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
30use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
31use datafusion_proto::protobuf::PhysicalExprNode;
32use prost::Message;
33use serde::{Deserialize, Serialize};
34use store_api::storage::RegionId;
35
// Re-export the initial remote dynamic filter registration items so callers can
// reach them through this module instead of the private submodule.
pub use self::initial_remote_dyn_filter_reg::{
    INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY,
    INITIAL_REMOTE_DYN_FILTER_REGS_MAX_TOTAL_PROTO_BYTES, InitialDynFilterReg,
    InitialDynFilterRegs, InitialDynFilterSnapshot,
};

/// Current wire-format version for remote dynamic filter payload updates.
pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;
44
/// Serialized predicate payload for remote dynamic filter updates.
///
/// The payload is adjacently tagged in JSON (`kind` names the encoding and
/// `payload` carries the data) so receivers can reject unsupported encodings
/// before decoding engine-specific bytes. For DataFusion expressions the
/// `payload` bytes are serialized by `serde_json` as a base64 string, for example:
///
/// ```json
/// { "kind": "datafusion", "payload": "CQgH" }
/// ```
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
pub enum DynFilterPayload {
    /// A serialized DataFusion [`PhysicalExpr`] encoded as a protobuf
    /// [`PhysicalExprNode`].
    ///
    /// The bytes are (de)serialized through the `base64_serde::bytes` helper module.
    Datafusion(#[serde(with = "base64_serde::bytes")] Vec<u8>),
}
62
63impl DynFilterPayload {
64    /// Encodes a DataFusion physical expression into a bounded dynamic filter payload.
65    ///
66    /// This rejects expressions that cannot be safely shipped as dynamic filter
67    /// predicates and fails if the serialized payload exceeds `max_payload_bytes`.
68    pub fn from_datafusion_expr(
69        expr: &Arc<dyn PhysicalExpr>,
70        max_payload_bytes: usize,
71    ) -> DataFusionResult<Self> {
72        validate_supported_payload_expr(expr)?;
73
74        let codec = DefaultPhysicalExtensionCodec {};
75        let proto = serialize_physical_expr(expr, &codec)?;
76        let mut bytes = Vec::new();
77        proto.encode(&mut bytes).map_err(|e| {
78            DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
79        })?;
80
81        validate_payload_size(bytes.len(), max_payload_bytes)?;
82
83        Ok(Self::Datafusion(bytes))
84    }
85
86    /// Decodes a DataFusion dynamic filter payload against the provided input schema.
87    ///
88    /// The decoded expression is validated to ensure column indexes stay within the receiver-side
89    /// schema, column names are consistent (defensive check), and the payload stays within
90    /// `max_payload_bytes`.
91    pub fn decode_datafusion_expr(
92        &self,
93        task_ctx: &TaskContext,
94        input_schema: &datafusion::arrow::datatypes::Schema,
95        max_payload_bytes: usize,
96    ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
97        let Self::Datafusion(bytes) = self;
98        validate_payload_size(bytes.len(), max_payload_bytes)?;
99        let codec = DefaultPhysicalExtensionCodec {};
100        let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| {
101            DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
102        })?;
103
104        let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
105        validate_supported_payload_expr(&expr)?;
106        validate_decoded_payload_expr(&expr, input_schema)?;
107        Ok(expr)
108    }
109}
110
111fn encode_physical_expr_to_bytes(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<Vec<u8>> {
112    let codec = DefaultPhysicalExtensionCodec {};
113    let proto = serialize_physical_expr(expr, &codec)?;
114    let mut bytes = Vec::new();
115    proto.encode(&mut bytes).map_err(|e| {
116        DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
117    })?;
118    Ok(bytes)
119}
120
121pub(crate) fn decode_physical_expr_from_bytes(
122    bytes: &[u8],
123    task_ctx: &TaskContext,
124    input_schema: &datafusion::arrow::datatypes::Schema,
125    max_payload_bytes: usize,
126) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
127    validate_payload_size(bytes.len(), max_payload_bytes)?;
128    let codec = DefaultPhysicalExtensionCodec {};
129    let proto = PhysicalExprNode::decode(bytes).map_err(|e| {
130        DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
131    })?;
132
133    let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
134    validate_supported_payload_expr(&expr)?;
135    validate_decoded_payload_expr(&expr, input_schema)?;
136    Ok(expr)
137}
138
139fn validate_payload_size(
140    payload_size_bytes: usize,
141    max_payload_bytes: usize,
142) -> DataFusionResult<()> {
143    if payload_size_bytes > max_payload_bytes {
144        return Err(DataFusionError::Plan(format!(
145            "DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes",
146            payload_size_bytes, max_payload_bytes
147        )));
148    }
149
150    Ok(())
151}
152
153fn validate_supported_payload_expr(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<()> {
154    expr.apply(|node| {
155        if node.as_any().is::<HashTableLookupExpr>() {
156            return Err(DataFusionError::Plan(
157                "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion"
158                    .to_string(),
159            ));
160        }
161
162        Ok(TreeNodeRecursion::Continue)
163    })?;
164
165    Ok(())
166}
167
168/// Validates decoded dynamic filter physical expressions against the receiver-side schema.
169///
170/// Rejects out-of-bounds column indexes and, as a defensive check, rejects columns whose
171/// name disagrees with the corresponding schema field. DataFusion physical `Column` is
172/// index-authoritative, so a name mismatch usually indicates a coordinator/receiver
173/// schema inconsistency that should be surfaced loudly.
174fn validate_decoded_payload_expr(
175    expr: &Arc<dyn PhysicalExpr>,
176    input_schema: &datafusion::arrow::datatypes::Schema,
177) -> DataFusionResult<()> {
178    expr.apply(|node| {
179        if let Some(column) = node.as_any().downcast_ref::<Column>() {
180            let Some(field) = input_schema.fields().get(column.index()) else {
181                return Err(DataFusionError::Plan(format!(
182                    "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
183                    column.name(),
184                    column.index(),
185                    input_schema.fields().len()
186                )));
187            };
188
189            if field.name() != column.name() {
190                return Err(DataFusionError::Plan(format!(
191                    "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'",
192                    column.name(),
193                    column.index(),
194                    field.name()
195                )));
196            }
197        }
198
199        Ok(TreeNodeRecursion::Continue)
200    })?;
201
202    Ok(())
203}
204
/// A remote dynamic filter update sent from a query coordinator to region servers.
///
/// `generation` is monotonic within a `query_id`/`filter_id` pair and matches the
/// gRPC field name used by `RemoteDynFilterUpdate`. Receivers use it to ignore
/// stale updates, while `is_complete` marks the final payload for the filter.
///
/// Use [`DynFilterUpdate::new`] to build an update stamped with the current
/// [`DYN_FILTER_PROTOCOL_VERSION`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DynFilterUpdate {
    /// Protocol version used by this update payload.
    pub protocol_version: u32,
    /// Internal query identifier that owns this dynamic filter lifecycle.
    pub query_id: String,
    /// Identifier of the dynamic filter within the query.
    pub filter_id: String,
    /// Monotonic update generation for this filter.
    pub generation: u64,
    /// Whether this update completes the dynamic filter stream.
    pub is_complete: bool,
    /// Serialized predicate payload carried by this update.
    pub payload: DynFilterPayload,
}
225
226impl DynFilterUpdate {
227    /// Creates a dynamic filter update with the current protocol version.
228    pub fn new(
229        query_id: String,
230        filter_id: String,
231        generation: u64,
232        is_complete: bool,
233        payload: DynFilterPayload,
234    ) -> Self {
235        Self {
236            protocol_version: DYN_FILTER_PROTOCOL_VERSION,
237            query_id,
238            filter_id,
239            generation,
240            is_complete,
241            payload,
242        }
243    }
244}
245
/// The query request to be handled by the RegionServer (Datanode).
#[derive(Clone, Debug)]
pub struct QueryRequest {
    /// The header of this request, typically carrying context for the query.
    /// `None` means every header setting falls back to its default.
    pub header: Option<RegionRequestHeader>,

    /// The id of the region to be queried.
    pub region_id: RegionId,

    /// The form of the query: a logical plan.
    pub plan: LogicalPlan,
}
258
/// Unit tests for the dynamic filter wire types: JSON shape, base64 handling,
/// protobuf round-trips, receiver-side validation, and payload size limits.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::Column;

    use super::*;

    /// Single-column schema shared by the decode tests.
    fn host_schema() -> Schema {
        Schema::new(vec![Field::new("host", DataType::Utf8, false)])
    }

    #[test]
    fn dyn_filter_update_sets_protocol_version() {
        let update = DynFilterUpdate::new(
            "query-1".to_string(),
            "filter-1".to_string(),
            3,
            false,
            DynFilterPayload::Datafusion(vec![1, 2, 3]),
        );

        assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION);
        assert_eq!(update.query_id, "query-1");
        assert_eq!(update.filter_id, "filter-1");
        assert_eq!(update.generation, 3);
        assert!(!update.is_complete);
        assert_eq!(update.payload, DynFilterPayload::Datafusion(vec![1, 2, 3]));
    }

    #[test]
    fn dyn_filter_update_json_round_trip_preserves_payload_shape() {
        let update = DynFilterUpdate::new(
            "query-2".to_string(),
            "filter-9".to_string(),
            9,
            true,
            DynFilterPayload::Datafusion(vec![9, 8, 7]),
        );

        let json = serde_json::to_string(&update).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();

        assert_eq!(value["generation"], serde_json::json!(9));
        assert!(value.get("epoch").is_none());
        assert_eq!(
            value["payload"],
            serde_json::json!({ "kind": "datafusion", "payload": BASE64_STANDARD.encode([9, 8, 7]) })
        );
        assert_eq!(decoded, update);
        assert!(decoded.is_complete);
        assert_eq!(decoded.payload, DynFilterPayload::Datafusion(vec![9, 8, 7]));
    }

    #[test]
    fn dyn_filter_payload_json_uses_base64_for_empty_and_padded_payloads() {
        let empty = serde_json::to_value(DynFilterPayload::Datafusion(vec![])).unwrap();
        let one = serde_json::to_value(DynFilterPayload::Datafusion(vec![1])).unwrap();
        let two = serde_json::to_value(DynFilterPayload::Datafusion(vec![1, 2])).unwrap();

        assert_eq!(
            empty,
            serde_json::json!({"kind": "datafusion", "payload": ""})
        );
        assert_eq!(
            one,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1])})
        );
        assert_eq!(
            two,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1, 2])})
        );
    }

    #[test]
    fn dyn_filter_payload_json_rejects_invalid_base64() {
        let err = serde_json::from_value::<DynFilterPayload>(serde_json::json!({
            "kind": "datafusion",
            "payload": "not base64!",
        }))
        .unwrap_err();

        assert!(
            err.to_string()
                .contains("invalid base64 dynamic filter payload")
        );
    }

    #[test]
    fn dyn_filter_payload_round_trips_physical_column_expr() {
        let schema = host_schema();
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        let original = expr.as_any().downcast_ref::<Column>().unwrap();
        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();

        assert_eq!(decoded.name(), original.name());
        assert_eq!(decoded.index(), original.index());
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_invalid_bytes() {
        let schema = host_schema();
        let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Internal(_)));
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() {
        let schema = host_schema();
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("name/index mismatch"),
            "expected name/index mismatch error, got: {msg}"
        );
        assert!(msg.contains("service"));
        assert!(msg.contains("host"));
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_out_of_bounds_column_index() {
        let schema = host_schema();
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 1));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));

        // `Plan` is also produced by the name-mismatch path, so pin the message
        // to make sure the bounds check is the one that fired.
        let msg = err.to_string();
        assert!(
            msg.contains("out-of-bounds index 1"),
            "expected out-of-bounds error, got: {msg}"
        );
    }

    #[test]
    fn dyn_filter_payload_rejects_oversized_payload() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds the configured limit of 1 bytes"),
            "expected size limit error, got: {msg}"
        );
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_oversized_payload() {
        let schema = host_schema();
        // The size check runs before parsing, so the byte content is irrelevant.
        let payload = DynFilterPayload::Datafusion(vec![0; 8]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 4)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds the configured limit of 4 bytes"),
            "expected size limit error, got: {msg}"
        );
    }

    #[test]
    fn dyn_filter_payload_size_limit_is_inclusive() {
        let schema = host_schema();
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let DynFilterPayload::Datafusion(bytes) = &payload;
        let exact = bytes.len();
        assert!(exact > 0);

        // A payload exactly at the limit is accepted on both sides...
        assert!(DynFilterPayload::from_datafusion_expr(&expr, exact).is_ok());
        assert!(
            payload
                .decode_datafusion_expr(&TaskContext::default(), &schema, exact)
                .is_ok()
        );

        // ...and one byte less of headroom is rejected on both sides.
        assert!(DynFilterPayload::from_datafusion_expr(&expr, exact - 1).is_err());
        assert!(
            payload
                .decode_datafusion_expr(&TaskContext::default(), &schema, exact - 1)
                .is_err()
        );
    }
}