1use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::ExplainOptions;
22use api::v1::region::RegionRequestHeader;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::Timezone;
30use common_time::timezone::parse_timezone;
31use datafusion_common::config::ConfigOptions;
32use derive_builder::Builder;
33use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
34
35pub use crate::hints::REMOTE_QUERY_ID_EXTENSION_KEY;
36use crate::protocol_ctx::ProtocolCtx;
37use crate::query_id::QueryId;
38use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle};
39use crate::{MutableInner, ReadPreference};
40
/// Shared, reference-counted handle to a [`QueryContext`].
pub type QueryContextRef = Arc<QueryContext>;
/// Shared, reference-counted handle to a [`ConnInfo`].
pub type ConnInfoRef = Arc<ConnInfo>;

/// Open-cursor count above which `QueryContext::insert_cursor` logs a warning.
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
45
46pub fn generate_remote_query_id() -> String {
47 generate_remote_query_id_value().to_string()
48}
49
50pub fn generate_remote_query_id_value() -> QueryId {
51 QueryId::new()
52}
53
/// Execution context of a query: target catalog/schema, session-scoped state,
/// snapshot sequences, and protocol/connection metadata.
///
/// `Clone` is shallow for the `Arc`-wrapped fields. Clones share the session data,
/// the mutable query fields, both sequence maps and the configuration variables.
/// `current_catalog`, `extensions`, `channel` and the remaining plain fields are
/// copied.
///
/// Construct it with [`QueryContextBuilder`]. Its `build` is hand-written
/// (`build_fn(skip)`) and substitutes a default for every unset field.
#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
#[builder(build_fn(skip))]
pub struct QueryContext {
    /// Current catalog; `build` falls back to `DEFAULT_CATALOG_NAME`.
    current_catalog: String,
    /// Snapshot sequence keyed by region id (`get_snapshot` / `set_snapshot`).
    snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
    /// Minimum SST sequence keyed by region id (`sst_min_sequence`).
    sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
    /// Session-scoped mutable state: schema, timezone, read preference, user info,
    /// query timeout and open cursors.
    #[builder(default)]
    mutable_session_data: Arc<RwLock<MutableInner>>,
    /// Mutable per-query fields: warning and explain format/options.
    #[builder(default)]
    mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
    /// SQL dialect; `build` falls back to `GreptimeDbDialect`.
    sql_dialect: Arc<dyn Dialect + Send + Sync>,
    /// Free-form string key/value extensions. `build` always ensures the entry
    /// `REMOTE_QUERY_ID_EXTENSION_KEY` is present.
    #[builder(default)]
    extensions: HashMap<String, String>,
    /// Configuration variables, behind an `Arc`.
    #[builder(default)]
    configuration_parameter: Arc<ConfigurationVariables>,
    /// Protocol channel the query arrived on.
    #[builder(default)]
    channel: Channel,
    /// Process id exposed through `process_id()`.
    #[builder(default)]
    process_id: u32,
    /// Connection info; stored independently of `channel`.
    #[builder(default)]
    conn_info: ConnInfo,
    /// Protocol-specific context.
    #[builder(default)]
    protocol_ctx: ProtocolCtx,
}
89
/// Mutable per-query fields of a [`QueryContext`], stored behind its
/// `mutable_query_context_data` lock.
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
    /// Warning message set through `QueryContext::set_warning`.
    warning: Option<String>,
    /// Explain output format set through `QueryContext::set_explain_format`.
    explain_format: Option<String>,
    /// Explain options; serialized as `api::v1::QueryContext::explain`.
    explain_options: Option<ExplainOptions>,
}
99
100impl Display for QueryContext {
101 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102 write!(
103 f,
104 "QueryContext{{catalog: {}, schema: {}}}",
105 self.current_catalog(),
106 self.current_schema()
107 )
108 }
109}
110
111impl QueryContextBuilder {
112 pub fn current_schema(mut self, schema: String) -> Self {
113 if self.mutable_session_data.is_none() {
114 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
115 }
116
117 self.mutable_session_data
119 .as_mut()
120 .unwrap()
121 .write()
122 .unwrap()
123 .schema = schema;
124 self
125 }
126
127 pub fn timezone(mut self, timezone: Timezone) -> Self {
128 if self.mutable_session_data.is_none() {
129 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
130 }
131
132 self.mutable_session_data
133 .as_mut()
134 .unwrap()
135 .write()
136 .unwrap()
137 .timezone = timezone;
138 self
139 }
140
141 pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
142 self.mutable_query_context_data
143 .get_or_insert_default()
144 .write()
145 .unwrap()
146 .explain_options = explain_options;
147 self
148 }
149
150 pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
151 self.mutable_session_data
152 .get_or_insert_default()
153 .write()
154 .unwrap()
155 .read_preference = read_preference;
156 self
157 }
158}
159
160impl From<&RegionRequestHeader> for QueryContext {
161 fn from(value: &RegionRequestHeader) -> Self {
162 if let Some(ctx) = &value.query_context {
163 ctx.clone().into()
164 } else {
165 QueryContextBuilder::default()
166 .set_extension(
167 REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
168 generate_remote_query_id(),
169 )
170 .build()
171 }
172 }
173}
174
175impl From<api::v1::QueryContext> for QueryContext {
176 fn from(ctx: api::v1::QueryContext) -> Self {
177 let sequences = ctx.snapshot_seqs.as_ref();
178 QueryContextBuilder::default()
179 .current_catalog(ctx.current_catalog)
180 .current_schema(ctx.current_schema)
181 .timezone(parse_timezone(Some(&ctx.timezone)))
182 .extensions(ctx.extensions)
183 .channel(ctx.channel.into())
184 .snapshot_seqs(Arc::new(RwLock::new(
185 sequences
186 .map(|x| x.snapshot_seqs.clone())
187 .unwrap_or_default(),
188 )))
189 .sst_min_sequences(Arc::new(RwLock::new(
190 sequences
191 .map(|x| x.sst_min_sequences.clone())
192 .unwrap_or_default(),
193 )))
194 .explain_options(ctx.explain)
195 .build()
196 }
197}
198
199impl From<QueryContext> for api::v1::QueryContext {
200 fn from(
201 QueryContext {
202 current_catalog,
203 mutable_session_data: mutable_inner,
204 extensions,
205 channel,
206 snapshot_seqs,
207 sst_min_sequences,
208 mutable_query_context_data,
209 ..
210 }: QueryContext,
211 ) -> Self {
212 let explain = mutable_query_context_data.read().unwrap().explain_options;
213 let mutable_inner = mutable_inner.read().unwrap();
214 api::v1::QueryContext {
215 current_catalog,
216 current_schema: mutable_inner.schema.clone(),
217 timezone: mutable_inner.timezone.to_string(),
218 extensions,
219 channel: channel as u32,
220 snapshot_seqs: Some(api::v1::SnapshotSequences {
221 snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
222 sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
223 }),
224 explain,
225 }
226 }
227}
228
229impl From<&QueryContext> for api::v1::QueryContext {
230 fn from(ctx: &QueryContext) -> Self {
231 ctx.clone().into()
232 }
233}
234
235impl QueryContext {
236 pub fn arc() -> QueryContextRef {
237 Arc::new(
238 QueryContextBuilder::default()
239 .set_extension(
240 REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
241 generate_remote_query_id(),
242 )
243 .build(),
244 )
245 }
246
247 pub fn create_config_options(&self) -> ConfigOptions {
249 let mut config = ConfigOptions::default();
250 config.execution.time_zone = Some(self.timezone().to_string());
251 config
252 }
253
254 pub fn with(catalog: &str, schema: &str) -> QueryContext {
255 QueryContextBuilder::default()
256 .current_catalog(catalog.to_string())
257 .current_schema(schema.to_string())
258 .set_extension(
259 REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
260 generate_remote_query_id(),
261 )
262 .build()
263 }
264
265 pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
266 QueryContextBuilder::default()
267 .current_catalog(catalog.to_string())
268 .current_schema(schema.to_string())
269 .channel(channel)
270 .set_extension(
271 REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
272 generate_remote_query_id(),
273 )
274 .build()
275 }
276
277 pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
278 let (catalog, schema) = db_name
279 .map(|db| {
280 let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
281 (catalog, schema)
282 })
283 .unwrap_or_else(|| {
284 (
285 DEFAULT_CATALOG_NAME.to_string(),
286 DEFAULT_SCHEMA_NAME.to_string(),
287 )
288 });
289 QueryContextBuilder::default()
290 .current_catalog(catalog)
291 .current_schema(schema.clone())
292 .set_extension(
293 REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
294 generate_remote_query_id(),
295 )
296 .build()
297 }
298
299 pub fn current_schema(&self) -> String {
300 self.mutable_session_data.read().unwrap().schema.clone()
301 }
302
303 pub fn set_current_schema(&self, new_schema: &str) {
304 self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
305 }
306
307 pub fn current_catalog(&self) -> &str {
308 &self.current_catalog
309 }
310
311 pub fn set_current_catalog(&mut self, new_catalog: &str) {
312 self.current_catalog = new_catalog.to_string();
313 }
314
315 pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
316 &*self.sql_dialect
317 }
318
319 pub fn get_db_string(&self) -> String {
320 let catalog = self.current_catalog();
321 let schema = self.current_schema();
322 build_db_string(catalog, &schema)
323 }
324
325 pub fn timezone(&self) -> Timezone {
326 self.mutable_session_data.read().unwrap().timezone.clone()
327 }
328
329 pub fn set_timezone(&self, timezone: Timezone) {
330 self.mutable_session_data.write().unwrap().timezone = timezone;
331 }
332
333 pub fn read_preference(&self) -> ReadPreference {
334 self.mutable_session_data.read().unwrap().read_preference
335 }
336
337 pub fn set_read_preference(&self, read_preference: ReadPreference) {
338 self.mutable_session_data.write().unwrap().read_preference = read_preference;
339 }
340
341 pub fn current_user(&self) -> UserInfoRef {
342 self.mutable_session_data.read().unwrap().user_info.clone()
343 }
344
345 pub fn set_current_user(&self, user: UserInfoRef) {
346 self.mutable_session_data.write().unwrap().user_info = user;
347 }
348
349 pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
350 self.extensions.insert(key.into(), value.into());
351 }
352
353 pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
354 self.extensions.get(key.as_ref()).map(|v| v.as_str())
355 }
356
357 pub fn remote_query_id(&self) -> Option<&str> {
358 self.extension(REMOTE_QUERY_ID_EXTENSION_KEY)
359 }
360
361 pub fn remote_query_id_value(&self) -> Option<QueryId> {
362 self.remote_query_id()
363 .and_then(|query_id| query_id.parse().ok())
364 }
365
366 pub fn extensions(&self) -> HashMap<String, String> {
367 self.extensions.clone()
368 }
369
370 pub fn quote_style(&self) -> char {
372 if self.sql_dialect().is_delimited_identifier_start('"') {
373 '"'
374 } else if self.sql_dialect().is_delimited_identifier_start('\'') {
375 '\''
376 } else {
377 '`'
378 }
379 }
380
381 pub fn configuration_parameter(&self) -> &ConfigurationVariables {
382 &self.configuration_parameter
383 }
384
385 pub fn channel(&self) -> Channel {
386 self.channel
387 }
388
389 pub fn set_channel(&mut self, channel: Channel) {
390 self.channel = channel;
391 }
392
393 pub fn warning(&self) -> Option<String> {
394 self.mutable_query_context_data
395 .read()
396 .unwrap()
397 .warning
398 .clone()
399 }
400
401 pub fn set_warning(&self, msg: String) {
402 self.mutable_query_context_data.write().unwrap().warning = Some(msg);
403 }
404
405 pub fn explain_format(&self) -> Option<String> {
406 self.mutable_query_context_data
407 .read()
408 .unwrap()
409 .explain_format
410 .clone()
411 }
412
413 pub fn set_explain_format(&self, format: String) {
414 self.mutable_query_context_data
415 .write()
416 .unwrap()
417 .explain_format = Some(format);
418 }
419
420 pub fn explain_verbose(&self) -> bool {
421 self.mutable_query_context_data
422 .read()
423 .unwrap()
424 .explain_options
425 .map(|opts| opts.verbose)
426 .unwrap_or(false)
427 }
428
429 pub fn set_explain_verbose(&self, verbose: bool) {
430 self.mutable_query_context_data
431 .write()
432 .unwrap()
433 .explain_options
434 .get_or_insert_default()
435 .verbose = verbose;
436 }
437
438 pub fn query_timeout(&self) -> Option<Duration> {
439 self.mutable_session_data.read().unwrap().query_timeout
440 }
441
442 pub fn query_timeout_as_millis(&self) -> u128 {
443 let timeout = self.mutable_session_data.read().unwrap().query_timeout;
444 if let Some(t) = timeout {
445 return t.as_millis();
446 }
447 0
448 }
449
450 pub fn set_query_timeout(&self, timeout: Duration) {
451 self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
452 }
453
454 pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
455 let mut guard = self.mutable_session_data.write().unwrap();
456 guard.cursors.insert(name, Arc::new(rb));
457
458 let cursor_count = guard.cursors.len();
459 if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
460 warn!("Current connection has {} open cursors", cursor_count);
461 }
462 }
463
464 pub fn remove_cursor(&self, name: &str) {
465 let mut guard = self.mutable_session_data.write().unwrap();
466 guard.cursors.remove(name);
467 }
468
469 pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
470 let guard = self.mutable_session_data.read().unwrap();
471 let rb = guard.cursors.get(name);
472 rb.cloned()
473 }
474
475 pub fn snapshots(&self) -> HashMap<u64, u64> {
476 self.snapshot_seqs.read().unwrap().clone()
477 }
478
479 pub fn sst_min_sequences(&self) -> HashMap<u64, u64> {
480 self.sst_min_sequences.read().unwrap().clone()
481 }
482
483 pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
484 self.snapshot_seqs.read().unwrap().get(®ion_id).cloned()
485 }
486
487 pub fn set_snapshot(&self, region_id: u64, sequence: u64) {
488 self.snapshot_seqs
489 .write()
490 .unwrap()
491 .insert(region_id, sequence);
492 }
493
494 pub fn auto_string_to_numeric(&self) -> bool {
496 matches!(self.channel, Channel::Mysql)
497 }
498
499 pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
501 self.sst_min_sequences
502 .read()
503 .unwrap()
504 .get(®ion_id)
505 .copied()
506 }
507
508 pub fn process_id(&self) -> u32 {
509 self.process_id
510 }
511
512 pub fn conn_info(&self) -> &ConnInfo {
514 &self.conn_info
515 }
516
517 pub fn protocol_ctx(&self) -> &ProtocolCtx {
518 &self.protocol_ctx
519 }
520
521 pub fn set_protocol_ctx(&mut self, protocol_ctx: ProtocolCtx) {
522 self.protocol_ctx = protocol_ctx;
523 }
524}
525
526impl QueryContextBuilder {
527 pub fn build(self) -> QueryContext {
528 let channel = self.channel.unwrap_or_default();
529 let mut extensions = self.extensions.unwrap_or_default();
530 extensions
531 .entry(REMOTE_QUERY_ID_EXTENSION_KEY.to_string())
532 .or_insert_with(generate_remote_query_id);
533 QueryContext {
534 current_catalog: self
535 .current_catalog
536 .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
537 snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
538 sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
539 mutable_session_data: self.mutable_session_data.unwrap_or_default(),
540 mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
541 sql_dialect: self
542 .sql_dialect
543 .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
544 extensions,
545 configuration_parameter: self
546 .configuration_parameter
547 .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
548 channel,
549 process_id: self.process_id.unwrap_or_default(),
550 conn_info: self.conn_info.unwrap_or_default(),
551 protocol_ctx: self.protocol_ctx.unwrap_or_default(),
552 }
553 }
554
555 pub fn set_extension(mut self, key: String, value: String) -> Self {
556 self.extensions
557 .get_or_insert_with(HashMap::new)
558 .insert(key, value);
559 self
560 }
561}
562
/// Connection-level information attached to a query context.
#[derive(Debug, Clone, Default)]
pub struct ConnInfo {
    /// Client socket address, if known; `Display` prints `unknown client addr`
    /// when absent.
    pub client_addr: Option<SocketAddr>,
    /// Protocol channel of the connection.
    pub channel: Channel,
}
568
569impl Display for ConnInfo {
570 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
571 write!(
572 f,
573 "{}[{}]",
574 self.channel,
575 self.client_addr
576 .map(|addr| addr.to_string())
577 .as_deref()
578 .unwrap_or("unknown client addr")
579 )
580 }
581}
582
583impl ConnInfo {
584 pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
585 Self {
586 client_addr,
587 channel,
588 }
589 }
590}
591
/// Protocol channel through which a request arrived.
///
/// The discriminants are the `u32` values exchanged in
/// `api::v1::QueryContext::channel`.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    #[default]
    Unknown = 0,

    Mysql = 1,
    Postgres = 2,
    HttpSql = 3,
    Prometheus = 4,
    Otlp = 5,
    Grpc = 6,
    Influx = 7,
    Opentsdb = 8,
    Loki = 9,
    Elasticsearch = 10,
    Jaeger = 11,
    Log = 12,
    Promql = 13,
}

impl From<u32> for Channel {
    /// Decodes a wire channel code. Codes with no matching variant decode to
    /// [`Channel::Unknown`].
    fn from(value: u32) -> Self {
        // Lookup is by discriminant, so the order of this list does not matter.
        const ALL: [Channel; 14] = [
            Channel::Unknown,
            Channel::Mysql,
            Channel::Postgres,
            Channel::HttpSql,
            Channel::Prometheus,
            Channel::Otlp,
            Channel::Grpc,
            Channel::Influx,
            Channel::Opentsdb,
            Channel::Loki,
            Channel::Elasticsearch,
            Channel::Jaeger,
            Channel::Log,
            Channel::Promql,
        ];
        ALL.iter()
            .copied()
            .find(|channel| *channel as u32 == value)
            .unwrap_or(Channel::Unknown)
    }
}
633
634impl Channel {
635 pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
636 match self {
637 Channel::Mysql => Arc::new(MySqlDialect {}),
638 Channel::Postgres => Arc::new(PostgreSqlDialect {}),
639 _ => Arc::new(GenericDialect {}),
640 }
641 }
642}
643
644impl Display for Channel {
645 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
646 write!(f, "{}", self.as_ref())
647 }
648}
649
650impl AsRef<str> for Channel {
651 fn as_ref(&self) -> &str {
652 match self {
653 Channel::Mysql => "mysql",
654 Channel::Postgres => "postgres",
655 Channel::HttpSql => "httpsql",
656 Channel::Prometheus => "prometheus",
657 Channel::Otlp => "otlp",
658 Channel::Grpc => "grpc",
659 Channel::Influx => "influx",
660 Channel::Opentsdb => "opentsdb",
661 Channel::Loki => "loki",
662 Channel::Elasticsearch => "elasticsearch",
663 Channel::Jaeger => "jaeger",
664 Channel::Log => "log",
665 Channel::Promql => "promql",
666 Channel::Unknown => "unknown",
667 }
668 }
669}
670
/// Configuration variables of a query context.
///
/// Each value is held in an [`ArcSwap`], so it can be read and replaced through
/// `&self`. `Clone` is implemented by hand below.
#[derive(Default, Debug)]
pub struct ConfigurationVariables {
    /// PostgreSQL bytea output setting (presumably `bytea_output` — confirm).
    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
    /// PostgreSQL date style: output style plus date field order.
    pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
    /// PostgreSQL interval output style.
    pg_intervalstyle_format: ArcSwap<PGIntervalStyle>,
    /// Flag read by `allow_query_fallback()`; defaults to `false`.
    allow_query_fallback: ArcSwap<bool>,
}
678
679impl Clone for ConfigurationVariables {
680 fn clone(&self) -> Self {
681 Self {
682 postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
683 pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
684 pg_intervalstyle_format: ArcSwap::new(self.pg_intervalstyle_format.load().clone()),
685 allow_query_fallback: ArcSwap::new(self.allow_query_fallback.load().clone()),
686 }
687 }
688}
689
690impl ConfigurationVariables {
691 pub fn new() -> Self {
692 Self::default()
693 }
694
695 pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
696 let _ = self.postgres_bytea_output.swap(Arc::new(value));
697 }
698
699 pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
700 self.postgres_bytea_output.load().clone()
701 }
702
703 pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
704 self.pg_datestyle_format.load().clone()
705 }
706
707 pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
708 self.pg_datestyle_format.swap(Arc::new((style, order)));
709 }
710
711 pub fn pg_intervalstyle_format(&self) -> Arc<PGIntervalStyle> {
712 self.pg_intervalstyle_format.load().clone()
713 }
714
715 pub fn set_pg_intervalstyle_format(&self, value: PGIntervalStyle) {
716 self.pg_intervalstyle_format.swap(Arc::new(value));
717 }
718
719 pub fn allow_query_fallback(&self) -> bool {
720 **self.allow_query_fallback.load()
721 }
722
723 pub fn set_allow_query_fallback(&self, allow: bool) {
724 self.allow_query_fallback.swap(Arc::new(allow));
725 }
726}
727
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};

    use super::*;
    use crate::Session;
    use crate::context::Channel;

    #[test]
    fn test_session() {
        let session = Session::new(
            Some("127.0.0.1:9000".parse().unwrap()),
            Channel::Mysql,
            Default::default(),
            100,
        );
        assert_eq!(session.user_info().username(), "greptime");

        assert_eq!(session.conn_info().channel, Channel::Mysql);
        let client_addr = session.conn_info().client_addr.as_ref().unwrap();
        assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
        assert_eq!(client_addr.port(), 9000);

        assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string());
        assert_eq!(100, session.process_id());

        let query_ctx = session.new_query_context();
        assert!(query_ctx.remote_query_id().is_some());
    }

    #[test]
    fn test_context_db_string() {
        let context = QueryContext::with("a0b1c2d3", "test");
        assert_eq!("a0b1c2d3-test", context.get_db_string());

        let context = QueryContext::with(DEFAULT_CATALOG_NAME, "test");
        assert_eq!("test", context.get_db_string());
    }

    #[test]
    fn test_with_db_name_defaults() {
        let ctx = QueryContext::with_db_name(None);
        assert_eq!(ctx.current_catalog(), DEFAULT_CATALOG_NAME);
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);
        assert!(ctx.remote_query_id().is_some());
    }

    #[test]
    fn test_channel_u32_round_trip() {
        for code in 0..=13u32 {
            assert_eq!(Channel::from(code) as u32, code);
        }
        assert_eq!(Channel::from(14), Channel::Unknown);
        assert_eq!(Channel::from(u32::MAX), Channel::Unknown);
    }

    #[test]
    fn test_snapshot_sequences() {
        let ctx = QueryContext::arc();
        assert_eq!(ctx.get_snapshot(1), None);
        ctx.set_snapshot(1, 42);
        assert_eq!(ctx.get_snapshot(1), Some(42));
        assert_eq!(ctx.snapshots(), HashMap::from([(1, 42)]));
        assert_eq!(ctx.sst_min_sequence(1), None);
    }

    #[test]
    fn test_query_timeout_as_millis() {
        let ctx = QueryContext::arc();
        ctx.set_query_timeout(std::time::Duration::from_millis(1500));
        assert_eq!(ctx.query_timeout_as_millis(), 1500);
    }

    #[test]
    fn test_api_query_context_roundtrip_with_sequences() {
        let api_ctx = api::v1::QueryContext {
            current_catalog: "c1".to_string(),
            current_schema: "s1".to_string(),
            timezone: "UTC".to_string(),
            extensions: HashMap::from([("flow.return_region_seq".to_string(), "true".to_string())]),
            channel: Channel::Grpc as u32,
            snapshot_seqs: Some(api::v1::SnapshotSequences {
                snapshot_seqs: HashMap::from([(1, 100)]),
                sst_min_sequences: HashMap::from([(1, 90)]),
            }),
            explain: None,
        };

        let session_ctx: QueryContext = api_ctx.clone().into();
        let roundtrip_api: api::v1::QueryContext = session_ctx.into();

        assert_eq!(roundtrip_api.current_catalog, api_ctx.current_catalog);
        assert_eq!(roundtrip_api.current_schema, api_ctx.current_schema);
        assert_eq!(roundtrip_api.timezone, api_ctx.timezone);
        assert_eq!(
            roundtrip_api.extensions.get("flow.return_region_seq"),
            Some(&"true".to_string())
        );
        assert!(
            roundtrip_api
                .extensions
                .contains_key(REMOTE_QUERY_ID_EXTENSION_KEY)
        );
        assert_eq!(roundtrip_api.channel, api_ctx.channel);
        assert_eq!(roundtrip_api.snapshot_seqs, api_ctx.snapshot_seqs);
    }

    #[test]
    fn test_query_context_remote_query_id_round_trip() {
        let query_id = "0195f4fd-c503-7c54-8b8f-7dfb8f6f9c4a";
        let ctx = QueryContextBuilder::default()
            .current_catalog(DEFAULT_CATALOG_NAME.to_string())
            .current_schema("public".to_string())
            .set_extension(
                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                query_id.to_string(),
            )
            .build();

        assert_eq!(ctx.remote_query_id(), Some(query_id));
        assert_eq!(ctx.remote_query_id_value().unwrap().to_string(), query_id);

        let proto: api::v1::QueryContext = (&ctx).into();
        let restored = QueryContext::from(proto);
        assert_eq!(restored.remote_query_id(), Some(query_id));
        assert_eq!(
            restored.remote_query_id_value().unwrap().to_string(),
            query_id
        );
    }
}