diff --git a/Cargo.lock b/Cargo.lock
index 6081085ba0..01e87bed93 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2623,6 +2623,7 @@ version = "1.0.0"
 dependencies = [
  "api",
  "async-trait",
+ "base64 0.22.1",
  "bytes",
  "common-base",
  "common-error",
diff --git a/src/common/query/Cargo.toml b/src/common/query/Cargo.toml
index 13f3c174b6..b28e56fe8e 100644
--- a/src/common/query/Cargo.toml
+++ b/src/common/query/Cargo.toml
@@ -13,6 +13,7 @@ workspace = true
 [dependencies]
 api.workspace = true
 async-trait.workspace = true
+base64.workspace = true
 bytes.workspace = true
 common-base.workspace = true
 common-error.workspace = true
diff --git a/src/common/query/src/request.rs b/src/common/query/src/request.rs
index c33e209557..3ab47e911b 100644
--- a/src/common/query/src/request.rs
+++ b/src/common/query/src/request.rs
@@ -31,16 +31,56 @@ use prost::Message;
 use serde::{Deserialize, Serialize};
 use store_api::storage::RegionId;
 
+/// Current wire-format version for remote dynamic filter payload updates.
 pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;
 
+/// Serialized predicate payload for remote dynamic filter updates.
+///
+/// The payload is tagged in JSON so receivers can reject unsupported encodings
+/// before decoding engine-specific bytes. For DataFusion expressions the
+/// `payload` bytes are serialized by `serde_json` as a base64 string, for example:
+///
+/// ```json
+/// { "kind": "datafusion", "payload": "CQgH" }
+/// ```
 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
 #[non_exhaustive]
 #[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
 pub enum DynFilterPayload {
-    Datafusion(Vec<u8>),
+    /// A serialized DataFusion [`PhysicalExpr`] encoded as a protobuf
+    /// [`PhysicalExprNode`].
+    Datafusion(#[serde(with = "base64_bytes")] Vec<u8>),
+}
+
+mod base64_bytes {
+    use base64::Engine;
+    use base64::prelude::BASE64_STANDARD;
+    use serde::de::Error;
+    use serde::{Deserialize, Deserializer, Serializer};
+
+    pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(&BASE64_STANDARD.encode(bytes))
+    }
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let encoded = String::deserialize(deserializer)?;
+        BASE64_STANDARD.decode(encoded).map_err(|err| {
+            D::Error::custom(format!("invalid base64 dynamic filter payload: {err}"))
+        })
+    }
 }
 
 impl DynFilterPayload {
+    /// Encodes a DataFusion physical expression into a bounded dynamic filter payload.
+    ///
+    /// This rejects expressions that cannot be safely shipped as dynamic filter
+    /// predicates and fails if the serialized payload exceeds `max_payload_bytes`.
     pub fn from_datafusion_expr(
         expr: &Arc<dyn PhysicalExpr>,
         max_payload_bytes: usize,
@@ -114,22 +154,13 @@ fn validate_decoded_payload_expr(
 ) -> DataFusionResult<()> {
     expr.apply(|node| {
         if let Some(column) = node.as_any().downcast_ref::<Column>() {
-            let Some(field) = input_schema.fields().get(column.index()) else {
+            if input_schema.fields().get(column.index()).is_none() {
                 return Err(DataFusionError::Plan(format!(
                     "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
                     column.name(),
                     column.index(),
                     input_schema.fields().len()
                 )));
-            };
-
-            if field.name() != column.name() {
-                return Err(DataFusionError::Plan(format!(
-                    "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'",
-                    column.name(),
-                    column.index(),
-                    field.name()
-                )));
-            }
+            }
         }
 
@@ -139,21 +170,33 @@
     Ok(())
 }
 
+/// A remote dynamic filter update sent from a query coordinator to region servers.
+///
+/// `generation` is monotonic within a `query_id`/`filter_id` pair and matches the
+/// gRPC field name used by `RemoteDynFilterUpdate`. Receivers use it to ignore
+/// stale updates while `is_complete` marks the final payload for the filter.
 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
 pub struct DynFilterUpdate {
+    /// Protocol version used by this update payload.
     pub protocol_version: u32,
+    /// Internal query identifier that owns this dynamic filter lifecycle.
     pub query_id: String,
+    /// Identifier of the dynamic filter within the query.
     pub filter_id: String,
-    pub epoch: u64,
+    /// Monotonic update generation for this filter.
+    pub generation: u64,
+    /// Whether this update completes the dynamic filter stream.
     pub is_complete: bool,
+    /// Serialized predicate payload carried by this update.
     pub payload: DynFilterPayload,
 }
 
 impl DynFilterUpdate {
+    /// Creates a dynamic filter update with the current protocol version.
     pub fn new(
         query_id: String,
         filter_id: String,
-        epoch: u64,
+        generation: u64,
         is_complete: bool,
         payload: DynFilterPayload,
     ) -> Self {
@@ -161,7 +204,7 @@ impl DynFilterUpdate {
             protocol_version: DYN_FILTER_PROTOCOL_VERSION,
             query_id,
             filter_id,
-            epoch,
+            generation,
             is_complete,
             payload,
         }
@@ -185,6 +228,8 @@ pub struct QueryRequest {
 mod tests {
     use std::sync::Arc;
 
+    use base64::Engine;
+    use base64::prelude::BASE64_STANDARD;
     use datafusion::arrow::datatypes::{DataType, Field, Schema};
     use datafusion::physical_expr::expressions::Column;
 
@@ -218,8 +263,15 @@
         );
 
         let json = serde_json::to_string(&update).unwrap();
+        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
         let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();
 
+        assert_eq!(value["generation"], serde_json::json!(9));
+        assert!(value.get("epoch").is_none());
+        assert_eq!(
+            value["payload"],
+            serde_json::json!({ "kind": "datafusion", "payload": BASE64_STANDARD.encode([9, 8, 7]) })
+        );
         assert_eq!(decoded, update);
         assert!(decoded.is_complete);
         assert!(
@@ -227,6 +279,40 @@
         );
     }
 
+    #[test]
+    fn dyn_filter_payload_json_uses_base64_for_empty_and_padded_payloads() {
+        let empty = serde_json::to_value(DynFilterPayload::Datafusion(vec![])).unwrap();
+        let one = serde_json::to_value(DynFilterPayload::Datafusion(vec![1])).unwrap();
+        let two = serde_json::to_value(DynFilterPayload::Datafusion(vec![1, 2])).unwrap();
+
+        assert_eq!(
+            empty,
+            serde_json::json!({"kind": "datafusion", "payload": ""})
+        );
+        assert_eq!(
+            one,
+            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1])})
+        );
+        assert_eq!(
+            two,
+            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1, 2])})
+        );
+    }
+
+    #[test]
+    fn dyn_filter_payload_json_rejects_invalid_base64() {
+        let err = serde_json::from_value::<DynFilterPayload>(serde_json::json!({
+            "kind": "datafusion",
+            "payload": "not base64!",
+        }))
+        .unwrap_err();
+
+        assert!(
+            err.to_string()
+                .contains("invalid base64 dynamic filter payload")
+        );
+    }
+
     #[test]
     fn dyn_filter_payload_round_trips_physical_column_expr() {
         let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
@@ -258,11 +344,26 @@
     }
 
     #[test]
-    fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() {
+    fn dyn_filter_payload_decode_accepts_column_name_mismatch_when_index_is_valid() {
         let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
-        let mismatched_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));
+        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));
 
-        let payload = DynFilterPayload::from_datafusion_expr(&mismatched_expr, 1024).unwrap();
+        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
+        let decoded = payload
+            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
+            .unwrap();
+
+        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();
+
+        assert_eq!(decoded.index(), 0);
+    }
+
+    #[test]
+    fn dyn_filter_payload_decode_rejects_out_of_bounds_column_index() {
+        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
+        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 1));
+
+        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
         let err = payload
             .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
             .unwrap_err();
diff --git a/src/datanode/src/region_server.rs b/src/datanode/src/region_server.rs
index 9bf028fef7..7563701a0b 100644
--- a/src/datanode/src/region_server.rs
+++ b/src/datanode/src/region_server.rs
@@ -23,6 +23,7 @@ use std::time::Duration;
 
 use api::region::RegionResponse;
 use api::v1::meta::TopicStat;
+use api::v1::region::remote_dyn_filter_request::Action;
 use api::v1::region::sync_request::ManifestInfo;
 use api::v1::region::{
     ListMetadataRequest, RegionResponse as RegionResponseV1, RemoteDynFilterRequest, SyncRequest,
@@ -711,11 +712,11 @@
             .as_ref()
             .context(error::MissingRequiredFieldSnafu { name: "action" })?
         {
-            api::v1::region::remote_dyn_filter_request::Action::Update(update) => {
+            Action::Update(update) => {
                 self.handle_remote_dyn_filter_update(&request.query_id, update)
                     .await
             }
-            api::v1::region::remote_dyn_filter_request::Action::Unregister(unregister) => {
+            Action::Unregister(unregister) => {
                 self.handle_remote_dyn_filter_unregister(&request.query_id, unregister)
                     .await
             }
diff --git a/src/servers/src/grpc/context_auth.rs b/src/servers/src/grpc/context_auth.rs
index 0cf71cc6a1..39c4fc5c88 100644
--- a/src/servers/src/grpc/context_auth.rs
+++ b/src/servers/src/grpc/context_auth.rs
@@ -20,10 +20,7 @@ use auth::{Identity, Password, UserInfoRef, UserProviderRef};
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_catalog::parse_catalog_and_schema_from_db_string;
 use common_error::ext::ErrorExt;
-use session::context::{
-    Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY,
-    generate_remote_query_id,
-};
+use session::context::{Channel, QueryContextBuilder, QueryContextRef};
 use snafu::{OptionExt, ResultExt};
 use tonic::Status;
 use tonic::metadata::MetadataMap;
@@ -53,10 +50,6 @@ pub fn create_query_context_from_grpc_metadata(
             .current_catalog(catalog)
             .current_schema(schema)
             .channel(Channel::Grpc)
-            .set_extension(
-                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
-                generate_remote_query_id(),
-            )
             .build(),
     ))
 }
diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs
index 3763725bdc..7ae881adea 100644
--- a/src/servers/src/grpc/greptime_handler.rs
+++ b/src/servers/src/grpc/greptime_handler.rs
@@ -33,10 +33,7 @@ use common_telemetry::tracing_context::{FutureExt, TracingContext};
 use common_telemetry::{debug, error, tracing, warn};
 use common_time::timezone::parse_timezone;
 use futures_util::StreamExt;
-use session::context::{
-    Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY,
-    generate_remote_query_id,
-};
+use session::context::{Channel, QueryContextBuilder, QueryContextRef};
 use session::hints::{READ_PREFERENCE_HINT, is_reserved_extension_key};
 use snafu::{OptionExt, ResultExt};
 use tokio::sync::mpsc;
@@ -217,11 +214,7 @@ pub(crate) fn create_query_context(
         .current_catalog(catalog)
         .current_schema(schema)
         .timezone(timezone)
-        .channel(channel)
-        .set_extension(
-            REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
-            generate_remote_query_id(),
-        );
+        .channel(channel);
 
     if let Some(x) = extensions
         .iter()
@@ -293,6 +286,7 @@ impl Drop for RequestTimer {
 mod tests {
     use chrono::FixedOffset;
     use common_time::Timezone;
+    use session::hints::REMOTE_QUERY_ID_EXTENSION_KEY;
 
     use super::*;
diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs
index 07ac146899..b55d8d78e8 100644
--- a/src/servers/src/http/authorize.rs
+++ b/src/servers/src/http/authorize.rs
@@ -28,9 +28,7 @@ use common_telemetry::warn;
 use common_time::Timezone;
 use common_time::timezone::parse_timezone;
 use headers::Header;
-use session::context::{
-    QueryContextBuilder, REMOTE_QUERY_ID_EXTENSION_KEY, generate_remote_query_id,
-};
+use session::context::QueryContextBuilder;
 use snafu::{OptionExt, ResultExt, ensure};
 
 use crate::error::{
@@ -66,11 +64,7 @@ pub async fn inner_auth(
     let query_ctx_builder = QueryContextBuilder::default()
         .current_catalog(catalog.clone())
         .current_schema(schema.clone())
-        .timezone(timezone)
-        .set_extension(
-            REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
-            generate_remote_query_id(),
-        );
+        .timezone(timezone);
 
     let query_ctx = query_ctx_builder.build();
     let need_auth = need_auth(&req);
diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs
index cba78a060e..2fce3c022d 100644
--- a/src/session/src/lib.rs
+++ b/src/session/src/lib.rs
@@ -31,10 +31,7 @@ use common_recordbatch::cursor::RecordBatchStreamCursor;
 pub use common_session::ReadPreference;
 use common_time::Timezone;
 use common_time::timezone::get_timezone;
-use context::{
-    ConfigurationVariables, QueryContextBuilder, REMOTE_QUERY_ID_EXTENSION_KEY,
-    generate_remote_query_id,
-};
+use context::{ConfigurationVariables, QueryContextBuilder};
 use derive_more::Debug;
 
 use crate::context::{Channel, ConnInfo, QueryContextRef};
@@ -110,10 +107,6 @@ impl Session {
             .channel(self.conn_info.channel)
             .process_id(self.process_id)
             .conn_info(self.conn_info.clone())
-            .set_extension(
-                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
-                generate_remote_query_id(),
-            )
             .build()
             .into()
     }