From 727e681fd52053d4aa44ed5bbfc6b8ad83b0750f Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 1 Apr 2026 14:25:52 +0800 Subject: [PATCH] feat: dyn filter update abi Signed-off-by: discord9 --- Cargo.lock | 41 ++++ Cargo.toml | 4 +- src/common/query/Cargo.toml | 3 + src/common/query/src/request.rs | 251 +++++++++++++++++++++++ src/servers/src/grpc/context_auth.rs | 9 +- src/servers/src/grpc/greptime_handler.rs | 22 +- src/servers/src/http/authorize.rs | 23 ++- src/session/Cargo.toml | 1 + src/session/src/context.rs | 66 +++++- src/session/src/lib.rs | 9 +- 10 files changed, 417 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 872095752b..4c7df9083b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2632,10 +2632,13 @@ dependencies = [ "datafusion", "datafusion-common", "datafusion-expr", + "datafusion-proto", "datatypes", "futures-util", "once_cell", + "prost 0.14.1", "serde", + "serde_json", "snafu 0.8.6", "sqlparser", "sqlparser_derive 0.1.1", @@ -4129,6 +4132,43 @@ dependencies = [ "tokio", ] +[[package]] +name = "datafusion-proto" +version = "52.1.0" +source = "git+https://github.com/GreptimeTeam/datafusion.git?rev=02b82535e0160c4545667f36a03e1ff9d1d2e51f#02b82535e0160c4545667f36a03e1ff9d1d2e51f" +dependencies = [ + "arrow 57.3.0", + "chrono", + "datafusion-catalog", + "datafusion-catalog-listing", + "datafusion-common", + "datafusion-datasource", + "datafusion-datasource-arrow", + "datafusion-datasource-csv", + "datafusion-datasource-json", + "datafusion-datasource-parquet", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions-table", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", + "datafusion-proto-common", + "object_store", + "prost 0.14.1", + "rand 0.9.1", +] + +[[package]] +name = "datafusion-proto-common" +version = "52.1.0" +source = "git+https://github.com/GreptimeTeam/datafusion.git?rev=02b82535e0160c4545667f36a03e1ff9d1d2e51f#02b82535e0160c4545667f36a03e1ff9d1d2e51f" +dependencies = [ + "arrow 57.3.0", + "datafusion-common", + "prost 0.14.1", +] + [[package]] name = "datafusion-pruning" version = "52.1.0" @@ -12137,6 +12177,7 @@ dependencies = [ "derive_more", "snafu 0.8.6", "sql", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 227608bf64..9ebcfc8627 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -139,6 +139,7 @@ datafusion-orc = "0.7" datafusion-pg-catalog = "0.15.1" datafusion-physical-expr = "=52.1" datafusion-physical-plan = "=52.1" +datafusion-proto = "=52.1" datafusion-sql = "=52.1" datafusion-substrait = "=52.1" deadpool = "0.12" @@ -251,7 +252,7 @@ tracing-appender = "0.2" tracing-opentelemetry = "0.31.0" tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "fmt"] } typetag = "0.2" -uuid = { version = "1.17", features = ["serde", "v4", "fast-rng"] } +uuid = { version = "1.17", features = ["serde", "v4", "v7", "fast-rng"] } vrl = "0.25" zstd = "0.13" # DO_NOT_REMOVE_THIS: END_OF_EXTERNAL_DEPENDENCIES @@ -341,6 +342,7 @@ datafusion-optimizer = { git = "https://github.com/GreptimeTeam/datafusion.git", datafusion-physical-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-physical-expr-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-physical-plan = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } +datafusion-proto = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-datasource = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-sql = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } datafusion-substrait = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" } diff --git a/src/common/query/Cargo.toml b/src/common/query/Cargo.toml index 48328ea612..13f3c174b6 100644 --- a/src/common/query/Cargo.toml +++ b/src/common/query/Cargo.toml @@ -22,8 +22,10 @@ common-time.workspace = true datafusion.workspace = true datafusion-common.workspace = true datafusion-expr.workspace = true +datafusion-proto.workspace = true datatypes.workspace = true once_cell.workspace = true +prost.workspace = true serde.workspace = true snafu.workspace = true sqlparser.workspace = true @@ -33,4 +35,5 @@ store-api.workspace = true [dev-dependencies] common-base.workspace = true futures-util.workspace = true +serde_json.workspace = true tokio.workspace = true diff --git a/src/common/query/src/request.rs b/src/common/query/src/request.rs index 260a43e79d..c33e209557 100644 --- a/src/common/query/src/request.rs +++ b/src/common/query/src/request.rs @@ -12,10 +12,162 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use api::v1::region::RegionRequestHeader; +use datafusion::arrow::datatypes::Schema; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_plan::PhysicalExpr; +use datafusion::physical_plan::joins::HashTableLookupExpr; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_expr::LogicalPlan; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::protobuf::PhysicalExprNode; +use prost::Message; +use serde::{Deserialize, Serialize}; use store_api::storage::RegionId; +pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[non_exhaustive] +#[serde(tag = "kind", content = "payload", rename_all = "snake_case")] +pub enum DynFilterPayload { + Datafusion(Vec), +} + +impl DynFilterPayload { + pub fn from_datafusion_expr( + expr: &Arc, + max_payload_bytes: usize, + ) -> DataFusionResult { + validate_supported_payload_expr(expr)?; + + let codec = DefaultPhysicalExtensionCodec {}; + let proto = serialize_physical_expr(expr, &codec)?; + let mut bytes = Vec::new(); + proto.encode(&mut bytes).map_err(|e| { + DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}")) + })?; + + validate_payload_size(bytes.len(), max_payload_bytes)?; + + Ok(Self::Datafusion(bytes)) + } + + pub fn decode_datafusion_expr( + &self, + task_ctx: &TaskContext, + input_schema: &Schema, + max_payload_bytes: usize, + ) -> DataFusionResult> { + let Self::Datafusion(bytes) = self; + validate_payload_size(bytes.len(), max_payload_bytes)?; + let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| { + DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}")) + })?; + + let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?; + validate_supported_payload_expr(&expr)?; + validate_decoded_payload_expr(&expr, input_schema)?; + Ok(expr) + } +} + +fn validate_payload_size( + payload_size_bytes: usize, + max_payload_bytes: usize, +) -> DataFusionResult<()> { + if payload_size_bytes > max_payload_bytes { + return Err(DataFusionError::Plan(format!( + "DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes", + payload_size_bytes, max_payload_bytes + ))); + } + + Ok(()) +} + +fn validate_supported_payload_expr(expr: &Arc) -> DataFusionResult<()> { + expr.apply(|node| { + if node.as_any().is::() { + return Err(DataFusionError::Plan( + "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion" + .to_string(), + )); + } + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(()) +} + +fn validate_decoded_payload_expr( + expr: &Arc, + input_schema: &Schema, +) -> DataFusionResult<()> { + expr.apply(|node| { + if let Some(column) = node.as_any().downcast_ref::() { + let Some(field) = input_schema.fields().get(column.index()) else { + return Err(DataFusionError::Plan(format!( + "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}", + column.name(), + column.index(), + input_schema.fields().len() + ))); + }; + + if field.name() != column.name() { + return Err(DataFusionError::Plan(format!( + "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'", + column.name(), + column.index(), + field.name() + ))); + } + } + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(()) +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct DynFilterUpdate { + pub protocol_version: u32, + pub query_id: String, + pub filter_id: String, + pub epoch: u64, + pub is_complete: bool, + pub payload: DynFilterPayload, +} + +impl DynFilterUpdate { + pub fn new( + query_id: String, + filter_id: String, + epoch: u64, + is_complete: bool, + payload: DynFilterPayload, + ) -> Self { + Self { + protocol_version: DYN_FILTER_PROTOCOL_VERSION, + query_id, + filter_id, + epoch, + is_complete, + payload, + } + } +} + /// The query request to be handled by the RegionServer (Datanode). #[derive(Clone, Debug)] pub struct QueryRequest { @@ -28,3 +180,102 @@ pub struct QueryRequest { /// The form of the query: a logical plan. pub plan: LogicalPlan, } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::physical_expr::expressions::Column; + + use super::*; + + #[test] + fn dyn_filter_update_sets_protocol_version() { + let update = DynFilterUpdate::new( + "query-1".to_string(), + "filter-1".to_string(), + 3, + false, + DynFilterPayload::Datafusion(vec![1, 2, 3]), + ); + + assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION); + assert!(!update.is_complete); + assert!( + matches!(update.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![1, 2, 3]) + ); + } + + #[test] + fn dyn_filter_update_json_round_trip_preserves_payload_shape() { + let update = DynFilterUpdate::new( + "query-2".to_string(), + "filter-9".to_string(), + 9, + true, + DynFilterPayload::Datafusion(vec![9, 8, 7]), + ); + + let json = serde_json::to_string(&update).unwrap(); + let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap(); + + assert_eq!(decoded, update); + assert!(decoded.is_complete); + assert!( + matches!(decoded.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![9, 8, 7]) + ); + } + + #[test] + fn dyn_filter_payload_round_trips_physical_column_expr() { + let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]); + let expr: Arc = + Arc::new(Column::new_with_schema("host", &schema).unwrap()); + + let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap(); + let decoded = payload + .decode_datafusion_expr(&TaskContext::default(), &schema, 1024) + .unwrap(); + + let original = expr.as_any().downcast_ref::().unwrap(); + let decoded = decoded.as_any().downcast_ref::().unwrap(); + + assert_eq!(decoded.name(), original.name()); + assert_eq!(decoded.index(), original.index()); + } + + #[test] + fn dyn_filter_payload_decode_rejects_invalid_bytes() { + let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]); + let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]); + + let err = payload + .decode_datafusion_expr(&TaskContext::default(), &schema, 1024) + .unwrap_err(); + + assert!(matches!(err, DataFusionError::Internal(_))); + } + + #[test] + fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() { + let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]); + let mismatched_expr: Arc = Arc::new(Column::new("service", 0)); + + let payload = DynFilterPayload::from_datafusion_expr(&mismatched_expr, 1024).unwrap(); + let err = payload + .decode_datafusion_expr(&TaskContext::default(), &schema, 1024) + .unwrap_err(); + + assert!(matches!(err, DataFusionError::Plan(_))); + } + + #[test] + fn dyn_filter_payload_rejects_oversized_payload() { + let expr: Arc = Arc::new(Column::new("host", 0)); + + let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err(); + + assert!(matches!(err, DataFusionError::Plan(_))); + } +} diff --git a/src/servers/src/grpc/context_auth.rs b/src/servers/src/grpc/context_auth.rs index 39c4fc5c88..0cf71cc6a1 100644 --- a/src/servers/src/grpc/context_auth.rs +++ b/src/servers/src/grpc/context_auth.rs @@ -20,7 +20,10 @@ use auth::{Identity, Password, UserInfoRef, UserProviderRef}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; -use session::context::{Channel, QueryContextBuilder, QueryContextRef}; +use session::context::{ + Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY, + generate_remote_query_id, +}; use snafu::{OptionExt, ResultExt}; use tonic::Status; use tonic::metadata::MetadataMap; @@ -50,6 +53,10 @@ pub fn create_query_context_from_grpc_metadata( .current_catalog(catalog) .current_schema(schema) .channel(Channel::Grpc) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build(), )) } diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index c1f146db6d..40b825c821 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -33,7 +33,10 @@ use common_telemetry::tracing_context::{FutureExt, TracingContext}; use common_telemetry::{debug, error, tracing, warn}; use common_time::timezone::parse_timezone; use futures_util::StreamExt; -use session::context::{Channel, QueryContextBuilder, QueryContextRef}; +use session::context::{ + Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY, + generate_remote_query_id, +}; use session::hints::READ_PREFERENCE_HINT; use snafu::{OptionExt, ResultExt}; use tokio::sync::mpsc; @@ -214,7 +217,11 @@ pub(crate) fn create_query_context( .current_catalog(catalog) .current_schema(schema) .timezone(timezone) - .channel(channel); + .channel(channel) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ); if let Some(x) = extensions .iter() @@ -308,9 +315,16 @@ mod tests { query_context.read_preference(), ReadPreference::Leader )); + let mut extensions = query_context.extensions().into_iter().collect::>(); + extensions.sort_unstable_by(|a, b| a.0.cmp(&b.0)); assert_eq!( - query_context.extensions().into_iter().collect::>(), - vec![("auto_create_table".to_string(), "true".to_string())] + extensions[0], + ("auto_create_table".to_string(), "true".to_string()) + ); + assert_eq!(extensions[1].0, REMOTE_QUERY_ID_EXTENSION_KEY.to_string()); + assert_eq!( + query_context.remote_query_id(), + Some(extensions[1].1.as_str()) ); } } diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index d2bfd9eba2..07ac146899 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -28,7 +28,9 @@ use common_telemetry::warn; use common_time::Timezone; use common_time::timezone::parse_timezone; use headers::Header; -use session::context::QueryContextBuilder; +use session::context::{ + QueryContextBuilder, REMOTE_QUERY_ID_EXTENSION_KEY, generate_remote_query_id, +}; use snafu::{OptionExt, ResultExt, ensure}; use crate::error::{ @@ -64,7 +66,11 @@ pub async fn inner_auth( let query_ctx_builder = QueryContextBuilder::default() .current_catalog(catalog.clone()) .current_schema(schema.clone()) - .timezone(timezone); + .timezone(timezone) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ); let query_ctx = query_ctx_builder.build(); let need_auth = need_auth(&req); @@ -388,6 +394,19 @@ mod tests { assert!(auth_scheme.is_err()); } + #[test] + fn test_inner_auth_assigns_remote_query_id() { + let req = + mock_http_request(None, Some("http://127.0.0.1/v1/sql?db=greptime-public")).unwrap(); + let req = futures::executor::block_on(inner_auth::<()>(None, req)).unwrap(); + let query_ctx = req + .extensions() + .get::() + .unwrap(); + + assert!(query_ctx.remote_query_id().is_some()); + } + #[test] fn test_auth_header() { // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml index 5b8b60f5ab..c7be8c2f7e 100644 --- a/src/session/Cargo.toml +++ b/src/session/Cargo.toml @@ -27,3 +27,4 @@ derive_builder.workspace = true derive_more.workspace = true snafu.workspace = true sql.workspace = true +uuid.workspace = true diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 5f16ea8b5a..7c6350dee1 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -31,6 +31,7 @@ use common_time::timezone::parse_timezone; use datafusion_common::config::ConfigOptions; use derive_builder::Builder; use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; +use uuid::Uuid; use crate::protocol_ctx::ProtocolCtx; use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle}; @@ -40,6 +41,11 @@ pub type QueryContextRef = Arc; pub type ConnInfoRef = Arc; const CURSOR_COUNT_WARNING_LIMIT: usize = 10; +pub const REMOTE_QUERY_ID_EXTENSION_KEY: &str = "remote_query_id"; + +pub fn generate_remote_query_id() -> String { + Uuid::now_v7().to_string() +} #[derive(Debug, Builder, Clone)] #[builder(pattern = "owned")] @@ -152,7 +158,12 @@ impl From<&RegionRequestHeader> for QueryContext { if let Some(ctx) = &value.query_context { ctx.clone().into() } else { - QueryContextBuilder::default().build() + QueryContextBuilder::default() + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) + .build() } } } @@ -219,7 +230,14 @@ impl From<&QueryContext> for api::v1::QueryContext { impl QueryContext { pub fn arc() -> QueryContextRef { - Arc::new(QueryContextBuilder::default().build()) + Arc::new( + QueryContextBuilder::default() + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) + .build(), + ) } /// Create a new datafusion's ConfigOptions instance based on the current QueryContext. @@ -233,6 +251,10 @@ impl QueryContext { QueryContextBuilder::default() .current_catalog(catalog.to_string()) .current_schema(schema.to_string()) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build() } @@ -241,6 +263,10 @@ impl QueryContext { .current_catalog(catalog.to_string()) .current_schema(schema.to_string()) .channel(channel) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build() } @@ -259,6 +285,10 @@ impl QueryContext { QueryContextBuilder::default() .current_catalog(catalog) .current_schema(schema.clone()) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build() } @@ -320,6 +350,10 @@ impl QueryContext { self.extensions.get(key.as_ref()).map(|v| v.as_str()) } + pub fn remote_query_id(&self) -> Option<&str> { + self.extension(REMOTE_QUERY_ID_EXTENSION_KEY) + } + pub fn extensions(&self) -> HashMap { self.extensions.clone() } @@ -483,6 +517,10 @@ impl QueryContext { impl QueryContextBuilder { pub fn build(self) -> QueryContext { let channel = self.channel.unwrap_or_default(); + let mut extensions = self.extensions.unwrap_or_default(); + extensions + .entry(REMOTE_QUERY_ID_EXTENSION_KEY.to_string()) + .or_insert_with(generate_remote_query_id); QueryContext { current_catalog: self .current_catalog @@ -494,7 +532,7 @@ impl QueryContextBuilder { sql_dialect: self .sql_dialect .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})), - extensions: self.extensions.unwrap_or_default(), + extensions, configuration_parameter: self .configuration_parameter .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())), @@ -707,6 +745,9 @@ mod test { assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string()); assert_eq!(100, session.process_id()); + + let query_ctx = session.new_query_context(); + assert!(query_ctx.remote_query_id().is_some()); } #[test] @@ -743,4 +784,23 @@ mod test { assert_eq!(roundtrip_api.channel, api_ctx.channel); assert_eq!(roundtrip_api.snapshot_seqs, api_ctx.snapshot_seqs); } + + #[test] + fn test_query_context_remote_query_id_round_trip() { + let query_id = "0195f4fd-c503-7c54-8b8f-7dfb8f6f9c4a"; + let ctx = QueryContextBuilder::default() + .current_catalog(DEFAULT_CATALOG_NAME.to_string()) + .current_schema("public".to_string()) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + query_id.to_string(), + ) + .build(); + + assert_eq!(ctx.remote_query_id(), Some(query_id)); + + let proto: api::v1::QueryContext = (&ctx).into(); + let restored = QueryContext::from(proto); + assert_eq!(restored.remote_query_id(), Some(query_id)); + } } diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 8d2a3e2141..1294a3368f 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -30,7 +30,10 @@ use common_recordbatch::cursor::RecordBatchStreamCursor; pub use common_session::ReadPreference; use common_time::Timezone; use common_time::timezone::get_timezone; -use context::{ConfigurationVariables, QueryContextBuilder}; +use context::{ + ConfigurationVariables, QueryContextBuilder, REMOTE_QUERY_ID_EXTENSION_KEY, + generate_remote_query_id, +}; use derive_more::Debug; use crate::context::{Channel, ConnInfo, QueryContextRef}; @@ -106,6 +109,10 @@ impl Session { .channel(self.conn_info.channel) .process_id(self.process_id) .conn_info(self.conn_info.clone()) + .set_extension( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + generate_remote_query_id(), + ) .build() .into() }