Skip to main content

session/
context.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::ExplainOptions;
22use api::v1::region::RegionRequestHeader;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::Timezone;
30use common_time::timezone::parse_timezone;
31use datafusion_common::config::ConfigOptions;
32use derive_builder::Builder;
33use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
34
35pub use crate::hints::REMOTE_QUERY_ID_EXTENSION_KEY;
36use crate::protocol_ctx::ProtocolCtx;
37use crate::query_id::QueryId;
38use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle};
39use crate::{MutableInner, ReadPreference};
40
/// Shared, reference-counted handle to a [`QueryContext`].
pub type QueryContextRef = Arc<QueryContext>;
/// Shared, reference-counted handle to a [`ConnInfo`].
pub type ConnInfoRef = Arc<ConnInfo>;

/// Number of open cursors on one context's session data above which
/// `QueryContext::insert_cursor` logs a warning (the insert itself still succeeds).
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
45
46pub fn generate_remote_query_id() -> String {
47    generate_remote_query_id_value().to_string()
48}
49
/// Generates a new remote query id as a typed [`QueryId`] (via `QueryId::new`).
///
/// Use [`generate_remote_query_id`] when the string form is needed.
pub fn generate_remote_query_id_value() -> QueryId {
    QueryId::new()
}
53
54#[derive(Debug, Builder, Clone)]
55#[builder(pattern = "owned")]
56#[builder(build_fn(skip))]
57pub struct QueryContext {
58    current_catalog: String,
59    /// mapping of RegionId to SequenceNumber, for snapshot read, meaning that the read should only
60    /// container data that was committed before(and include) the given sequence number
61    /// this field will only be filled if extensions contains a pair of "snapshot_read" and "true"
62    snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
63    /// Mappings of the RegionId to the minimal sequence of SST file to scan.
64    sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
65    // we use Arc<RwLock>> for modifiable fields
66    #[builder(default)]
67    mutable_session_data: Arc<RwLock<MutableInner>>,
68    #[builder(default)]
69    mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
70    sql_dialect: Arc<dyn Dialect + Send + Sync>,
71    #[builder(default)]
72    extensions: HashMap<String, String>,
73    /// The configuration parameter are used to store the parameters that are set by the user
74    #[builder(default)]
75    configuration_parameter: Arc<ConfigurationVariables>,
76    /// Track which protocol the query comes from.
77    #[builder(default)]
78    channel: Channel,
79    /// Process id for managing on-going queries
80    #[builder(default)]
81    process_id: u32,
82    /// Connection information
83    #[builder(default)]
84    conn_info: ConnInfo,
85    /// Protocol specific context
86    #[builder(default)]
87    protocol_ctx: ProtocolCtx,
88}
89
/// Per-query mutable state: these fields hold data that is only valid for the current query
/// context, and are read and written through `QueryContext`'s warning/explain accessors.
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
    /// Warning attached to the query, set by `QueryContext::set_warning`.
    warning: Option<String>,
    // TODO: remove this when format is supported in datafusion
    /// Explain format, set by `QueryContext::set_explain_format`.
    explain_format: Option<String>,
    /// Explain options to control the verbose analyze output.
    explain_options: Option<ExplainOptions>,
}
99
100impl Display for QueryContext {
101    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102        write!(
103            f,
104            "QueryContext{{catalog: {}, schema: {}}}",
105            self.current_catalog(),
106            self.current_schema()
107        )
108    }
109}
110
111impl QueryContextBuilder {
112    pub fn current_schema(mut self, schema: String) -> Self {
113        if self.mutable_session_data.is_none() {
114            self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
115        }
116
117        // safe for unwrap because previous none check
118        self.mutable_session_data
119            .as_mut()
120            .unwrap()
121            .write()
122            .unwrap()
123            .schema = schema;
124        self
125    }
126
127    pub fn timezone(mut self, timezone: Timezone) -> Self {
128        if self.mutable_session_data.is_none() {
129            self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
130        }
131
132        self.mutable_session_data
133            .as_mut()
134            .unwrap()
135            .write()
136            .unwrap()
137            .timezone = timezone;
138        self
139    }
140
141    pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
142        self.mutable_query_context_data
143            .get_or_insert_default()
144            .write()
145            .unwrap()
146            .explain_options = explain_options;
147        self
148    }
149
150    pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
151        self.mutable_session_data
152            .get_or_insert_default()
153            .write()
154            .unwrap()
155            .read_preference = read_preference;
156        self
157    }
158}
159
160impl From<&RegionRequestHeader> for QueryContext {
161    fn from(value: &RegionRequestHeader) -> Self {
162        if let Some(ctx) = &value.query_context {
163            ctx.clone().into()
164        } else {
165            QueryContextBuilder::default()
166                .set_extension(
167                    REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
168                    generate_remote_query_id(),
169                )
170                .build()
171        }
172    }
173}
174
175impl From<api::v1::QueryContext> for QueryContext {
176    fn from(ctx: api::v1::QueryContext) -> Self {
177        let sequences = ctx.snapshot_seqs.as_ref();
178        QueryContextBuilder::default()
179            .current_catalog(ctx.current_catalog)
180            .current_schema(ctx.current_schema)
181            .timezone(parse_timezone(Some(&ctx.timezone)))
182            .extensions(ctx.extensions)
183            .channel(ctx.channel.into())
184            .snapshot_seqs(Arc::new(RwLock::new(
185                sequences
186                    .map(|x| x.snapshot_seqs.clone())
187                    .unwrap_or_default(),
188            )))
189            .sst_min_sequences(Arc::new(RwLock::new(
190                sequences
191                    .map(|x| x.sst_min_sequences.clone())
192                    .unwrap_or_default(),
193            )))
194            .explain_options(ctx.explain)
195            .build()
196    }
197}
198
199impl From<QueryContext> for api::v1::QueryContext {
200    fn from(
201        QueryContext {
202            current_catalog,
203            mutable_session_data: mutable_inner,
204            extensions,
205            channel,
206            snapshot_seqs,
207            sst_min_sequences,
208            mutable_query_context_data,
209            ..
210        }: QueryContext,
211    ) -> Self {
212        let explain = mutable_query_context_data.read().unwrap().explain_options;
213        let mutable_inner = mutable_inner.read().unwrap();
214        api::v1::QueryContext {
215            current_catalog,
216            current_schema: mutable_inner.schema.clone(),
217            timezone: mutable_inner.timezone.to_string(),
218            extensions,
219            channel: channel as u32,
220            snapshot_seqs: Some(api::v1::SnapshotSequences {
221                snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
222                sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
223            }),
224            explain,
225        }
226    }
227}
228
229impl From<&QueryContext> for api::v1::QueryContext {
230    fn from(ctx: &QueryContext) -> Self {
231        ctx.clone().into()
232    }
233}
234
235impl QueryContext {
236    pub fn arc() -> QueryContextRef {
237        Arc::new(
238            QueryContextBuilder::default()
239                .set_extension(
240                    REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
241                    generate_remote_query_id(),
242                )
243                .build(),
244        )
245    }
246
247    /// Create a new  datafusion's ConfigOptions instance based on the current QueryContext.
248    pub fn create_config_options(&self) -> ConfigOptions {
249        let mut config = ConfigOptions::default();
250        config.execution.time_zone = Some(self.timezone().to_string());
251        config
252    }
253
254    pub fn with(catalog: &str, schema: &str) -> QueryContext {
255        QueryContextBuilder::default()
256            .current_catalog(catalog.to_string())
257            .current_schema(schema.to_string())
258            .set_extension(
259                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
260                generate_remote_query_id(),
261            )
262            .build()
263    }
264
265    pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
266        QueryContextBuilder::default()
267            .current_catalog(catalog.to_string())
268            .current_schema(schema.to_string())
269            .channel(channel)
270            .set_extension(
271                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
272                generate_remote_query_id(),
273            )
274            .build()
275    }
276
277    pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
278        let (catalog, schema) = db_name
279            .map(|db| {
280                let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
281                (catalog, schema)
282            })
283            .unwrap_or_else(|| {
284                (
285                    DEFAULT_CATALOG_NAME.to_string(),
286                    DEFAULT_SCHEMA_NAME.to_string(),
287                )
288            });
289        QueryContextBuilder::default()
290            .current_catalog(catalog)
291            .current_schema(schema.clone())
292            .set_extension(
293                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
294                generate_remote_query_id(),
295            )
296            .build()
297    }
298
299    pub fn current_schema(&self) -> String {
300        self.mutable_session_data.read().unwrap().schema.clone()
301    }
302
303    pub fn set_current_schema(&self, new_schema: &str) {
304        self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
305    }
306
307    pub fn current_catalog(&self) -> &str {
308        &self.current_catalog
309    }
310
311    pub fn set_current_catalog(&mut self, new_catalog: &str) {
312        self.current_catalog = new_catalog.to_string();
313    }
314
315    pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
316        &*self.sql_dialect
317    }
318
319    pub fn get_db_string(&self) -> String {
320        let catalog = self.current_catalog();
321        let schema = self.current_schema();
322        build_db_string(catalog, &schema)
323    }
324
325    pub fn timezone(&self) -> Timezone {
326        self.mutable_session_data.read().unwrap().timezone.clone()
327    }
328
329    pub fn set_timezone(&self, timezone: Timezone) {
330        self.mutable_session_data.write().unwrap().timezone = timezone;
331    }
332
333    pub fn read_preference(&self) -> ReadPreference {
334        self.mutable_session_data.read().unwrap().read_preference
335    }
336
337    pub fn set_read_preference(&self, read_preference: ReadPreference) {
338        self.mutable_session_data.write().unwrap().read_preference = read_preference;
339    }
340
341    pub fn current_user(&self) -> UserInfoRef {
342        self.mutable_session_data.read().unwrap().user_info.clone()
343    }
344
345    pub fn set_current_user(&self, user: UserInfoRef) {
346        self.mutable_session_data.write().unwrap().user_info = user;
347    }
348
349    pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
350        self.extensions.insert(key.into(), value.into());
351    }
352
353    pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
354        self.extensions.get(key.as_ref()).map(|v| v.as_str())
355    }
356
357    pub fn remote_query_id(&self) -> Option<&str> {
358        self.extension(REMOTE_QUERY_ID_EXTENSION_KEY)
359    }
360
361    pub fn remote_query_id_value(&self) -> Option<QueryId> {
362        self.remote_query_id()
363            .and_then(|query_id| query_id.parse().ok())
364    }
365
366    pub fn extensions(&self) -> HashMap<String, String> {
367        self.extensions.clone()
368    }
369
370    /// Default to double quote and fallback to back quote
371    pub fn quote_style(&self) -> char {
372        if self.sql_dialect().is_delimited_identifier_start('"') {
373            '"'
374        } else if self.sql_dialect().is_delimited_identifier_start('\'') {
375            '\''
376        } else {
377            '`'
378        }
379    }
380
381    pub fn configuration_parameter(&self) -> &ConfigurationVariables {
382        &self.configuration_parameter
383    }
384
385    pub fn channel(&self) -> Channel {
386        self.channel
387    }
388
389    pub fn set_channel(&mut self, channel: Channel) {
390        self.channel = channel;
391    }
392
393    pub fn warning(&self) -> Option<String> {
394        self.mutable_query_context_data
395            .read()
396            .unwrap()
397            .warning
398            .clone()
399    }
400
401    pub fn set_warning(&self, msg: String) {
402        self.mutable_query_context_data.write().unwrap().warning = Some(msg);
403    }
404
405    pub fn explain_format(&self) -> Option<String> {
406        self.mutable_query_context_data
407            .read()
408            .unwrap()
409            .explain_format
410            .clone()
411    }
412
413    pub fn set_explain_format(&self, format: String) {
414        self.mutable_query_context_data
415            .write()
416            .unwrap()
417            .explain_format = Some(format);
418    }
419
420    pub fn explain_verbose(&self) -> bool {
421        self.mutable_query_context_data
422            .read()
423            .unwrap()
424            .explain_options
425            .map(|opts| opts.verbose)
426            .unwrap_or(false)
427    }
428
429    pub fn set_explain_verbose(&self, verbose: bool) {
430        self.mutable_query_context_data
431            .write()
432            .unwrap()
433            .explain_options
434            .get_or_insert_default()
435            .verbose = verbose;
436    }
437
438    pub fn query_timeout(&self) -> Option<Duration> {
439        self.mutable_session_data.read().unwrap().query_timeout
440    }
441
442    pub fn query_timeout_as_millis(&self) -> u128 {
443        let timeout = self.mutable_session_data.read().unwrap().query_timeout;
444        if let Some(t) = timeout {
445            return t.as_millis();
446        }
447        0
448    }
449
450    pub fn set_query_timeout(&self, timeout: Duration) {
451        self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
452    }
453
454    pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
455        let mut guard = self.mutable_session_data.write().unwrap();
456        guard.cursors.insert(name, Arc::new(rb));
457
458        let cursor_count = guard.cursors.len();
459        if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
460            warn!("Current connection has {} open cursors", cursor_count);
461        }
462    }
463
464    pub fn remove_cursor(&self, name: &str) {
465        let mut guard = self.mutable_session_data.write().unwrap();
466        guard.cursors.remove(name);
467    }
468
469    pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
470        let guard = self.mutable_session_data.read().unwrap();
471        let rb = guard.cursors.get(name);
472        rb.cloned()
473    }
474
475    pub fn snapshots(&self) -> HashMap<u64, u64> {
476        self.snapshot_seqs.read().unwrap().clone()
477    }
478
479    pub fn sst_min_sequences(&self) -> HashMap<u64, u64> {
480        self.sst_min_sequences.read().unwrap().clone()
481    }
482
483    pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
484        self.snapshot_seqs.read().unwrap().get(&region_id).cloned()
485    }
486
487    pub fn set_snapshot(&self, region_id: u64, sequence: u64) {
488        self.snapshot_seqs
489            .write()
490            .unwrap()
491            .insert(region_id, sequence);
492    }
493
494    /// Returns `true` if the session can cast strings to numbers in MySQL style.
495    pub fn auto_string_to_numeric(&self) -> bool {
496        matches!(self.channel, Channel::Mysql)
497    }
498
499    /// Finds the minimal sequence of SST files to scan of a Region.
500    pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
501        self.sst_min_sequences
502            .read()
503            .unwrap()
504            .get(&region_id)
505            .copied()
506    }
507
508    pub fn process_id(&self) -> u32 {
509        self.process_id
510    }
511
512    /// Get client information
513    pub fn conn_info(&self) -> &ConnInfo {
514        &self.conn_info
515    }
516
517    pub fn protocol_ctx(&self) -> &ProtocolCtx {
518        &self.protocol_ctx
519    }
520
521    pub fn set_protocol_ctx(&mut self, protocol_ctx: ProtocolCtx) {
522        self.protocol_ctx = protocol_ctx;
523    }
524}
525
526impl QueryContextBuilder {
527    pub fn build(self) -> QueryContext {
528        let channel = self.channel.unwrap_or_default();
529        let mut extensions = self.extensions.unwrap_or_default();
530        extensions
531            .entry(REMOTE_QUERY_ID_EXTENSION_KEY.to_string())
532            .or_insert_with(generate_remote_query_id);
533        QueryContext {
534            current_catalog: self
535                .current_catalog
536                .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
537            snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
538            sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
539            mutable_session_data: self.mutable_session_data.unwrap_or_default(),
540            mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
541            sql_dialect: self
542                .sql_dialect
543                .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
544            extensions,
545            configuration_parameter: self
546                .configuration_parameter
547                .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
548            channel,
549            process_id: self.process_id.unwrap_or_default(),
550            conn_info: self.conn_info.unwrap_or_default(),
551            protocol_ctx: self.protocol_ctx.unwrap_or_default(),
552        }
553    }
554
555    pub fn set_extension(mut self, key: String, value: String) -> Self {
556        self.extensions
557            .get_or_insert_with(HashMap::new)
558            .insert(key, value);
559        self
560    }
561}
562
/// Client connection information: the client's socket address (if known) and the channel
/// it connected through.
#[derive(Debug, Clone, Default)]
pub struct ConnInfo {
    /// Remote address of the client; `None` when unknown (rendered as "unknown client addr").
    pub client_addr: Option<SocketAddr>,
    /// Channel the connection came through.
    pub channel: Channel,
}
568
569impl Display for ConnInfo {
570    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
571        write!(
572            f,
573            "{}[{}]",
574            self.channel,
575            self.client_addr
576                .map(|addr| addr.to_string())
577                .as_deref()
578                .unwrap_or("unknown client addr")
579        )
580    }
581}
582
583impl ConnInfo {
584    pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
585        Self {
586            client_addr,
587            channel,
588        }
589    }
590}
591
/// The protocol a query came from (see `QueryContext::channel` and `ConnInfo::channel`).
///
/// The discriminants form a wire encoding:
/// - `From<QueryContext> for api::v1::QueryContext` serializes a channel with `as u32`;
/// - `From<u32>` maps values back, turning anything unrecognized into `Unknown`.
///
/// Changing a discriminant therefore changes that encoding.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    /// Default when the source is not known; also the fallback for unrecognized `u32` values.
    #[default]
    Unknown = 0,

    Mysql = 1,
    Postgres = 2,
    HttpSql = 3,
    Prometheus = 4,
    Otlp = 5,
    Grpc = 6,
    Influx = 7,
    Opentsdb = 8,
    Loki = 9,
    Elasticsearch = 10,
    Jaeger = 11,
    Log = 12,
    Promql = 13,
}
612
613impl From<u32> for Channel {
614    fn from(value: u32) -> Self {
615        match value {
616            1 => Self::Mysql,
617            2 => Self::Postgres,
618            3 => Self::HttpSql,
619            4 => Self::Prometheus,
620            5 => Self::Otlp,
621            6 => Self::Grpc,
622            7 => Self::Influx,
623            8 => Self::Opentsdb,
624            9 => Self::Loki,
625            10 => Self::Elasticsearch,
626            11 => Self::Jaeger,
627            12 => Self::Log,
628            13 => Self::Promql,
629            _ => Self::Unknown,
630        }
631    }
632}
633
634impl Channel {
635    pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
636        match self {
637            Channel::Mysql => Arc::new(MySqlDialect {}),
638            Channel::Postgres => Arc::new(PostgreSqlDialect {}),
639            _ => Arc::new(GenericDialect {}),
640        }
641    }
642}
643
644impl Display for Channel {
645    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
646        write!(f, "{}", self.as_ref())
647    }
648}
649
650impl AsRef<str> for Channel {
651    fn as_ref(&self) -> &str {
652        match self {
653            Channel::Mysql => "mysql",
654            Channel::Postgres => "postgres",
655            Channel::HttpSql => "httpsql",
656            Channel::Prometheus => "prometheus",
657            Channel::Otlp => "otlp",
658            Channel::Grpc => "grpc",
659            Channel::Influx => "influx",
660            Channel::Opentsdb => "opentsdb",
661            Channel::Loki => "loki",
662            Channel::Elasticsearch => "elasticsearch",
663            Channel::Jaeger => "jaeger",
664            Channel::Log => "log",
665            Channel::Promql => "promql",
666            Channel::Unknown => "unknown",
667        }
668    }
669}
670
/// Configuration parameters set by the user, stored on `QueryContext` as
/// `configuration_parameter`.
///
/// Each value lives in an `ArcSwap`, so the setters replace it through `&self`.
#[derive(Default, Debug)]
pub struct ConfigurationVariables {
    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
    // Date style and date order are stored (and replaced) together.
    pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
    pg_intervalstyle_format: ArcSwap<PGIntervalStyle>,
    allow_query_fallback: ArcSwap<bool>,
}
678
679impl Clone for ConfigurationVariables {
680    fn clone(&self) -> Self {
681        Self {
682            postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
683            pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
684            pg_intervalstyle_format: ArcSwap::new(self.pg_intervalstyle_format.load().clone()),
685            allow_query_fallback: ArcSwap::new(self.allow_query_fallback.load().clone()),
686        }
687    }
688}
689
690impl ConfigurationVariables {
691    pub fn new() -> Self {
692        Self::default()
693    }
694
695    pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
696        let _ = self.postgres_bytea_output.swap(Arc::new(value));
697    }
698
699    pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
700        self.postgres_bytea_output.load().clone()
701    }
702
703    pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
704        self.pg_datestyle_format.load().clone()
705    }
706
707    pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
708        self.pg_datestyle_format.swap(Arc::new((style, order)));
709    }
710
711    pub fn pg_intervalstyle_format(&self) -> Arc<PGIntervalStyle> {
712        self.pg_intervalstyle_format.load().clone()
713    }
714
715    pub fn set_pg_intervalstyle_format(&self, value: PGIntervalStyle) {
716        self.pg_intervalstyle_format.swap(Arc::new(value));
717    }
718
719    pub fn allow_query_fallback(&self) -> bool {
720        **self.allow_query_fallback.load()
721    }
722
723    pub fn set_allow_query_fallback(&self, allow: bool) {
724        self.allow_query_fallback.swap(Arc::new(allow));
725    }
726}
727
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use common_catalog::consts::DEFAULT_CATALOG_NAME;

    use super::*;
    use crate::Session;
    use crate::context::Channel;

    #[test]
    fn test_session() {
        let session = Session::new(
            Some("127.0.0.1:9000".parse().unwrap()),
            Channel::Mysql,
            Default::default(),
            100,
        );
        // test user_info
        assert_eq!(session.user_info().username(), "greptime");

        // test channel
        assert_eq!(session.conn_info().channel, Channel::Mysql);
        let client_addr = session.conn_info().client_addr.as_ref().unwrap();
        assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
        assert_eq!(client_addr.port(), 9000);

        assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string());
        assert_eq!(100, session.process_id());

        let query_ctx = session.new_query_context();
        assert!(query_ctx.remote_query_id().is_some());
    }

    #[test]
    fn test_context_db_string() {
        let context = QueryContext::with("a0b1c2d3", "test");
        assert_eq!("a0b1c2d3-test", context.get_db_string());

        let context = QueryContext::with(DEFAULT_CATALOG_NAME, "test");
        assert_eq!("test", context.get_db_string());
    }

    #[test]
    fn test_api_query_context_roundtrip_with_sequences() {
        let api_ctx = api::v1::QueryContext {
            current_catalog: "c1".to_string(),
            current_schema: "s1".to_string(),
            timezone: "UTC".to_string(),
            extensions: HashMap::from([("flow.return_region_seq".to_string(), "true".to_string())]),
            channel: Channel::Grpc as u32,
            snapshot_seqs: Some(api::v1::SnapshotSequences {
                snapshot_seqs: HashMap::from([(1, 100)]),
                sst_min_sequences: HashMap::from([(1, 90)]),
            }),
            explain: None,
        };

        let session_ctx: QueryContext = api_ctx.clone().into();
        let roundtrip_api: api::v1::QueryContext = session_ctx.into();

        assert_eq!(roundtrip_api.current_catalog, api_ctx.current_catalog);
        assert_eq!(roundtrip_api.current_schema, api_ctx.current_schema);
        assert_eq!(roundtrip_api.timezone, api_ctx.timezone);
        assert_eq!(
            roundtrip_api.extensions.get("flow.return_region_seq"),
            Some(&"true".to_string())
        );
        assert!(
            roundtrip_api
                .extensions
                .contains_key(REMOTE_QUERY_ID_EXTENSION_KEY)
        );
        assert_eq!(roundtrip_api.channel, api_ctx.channel);
        assert_eq!(roundtrip_api.snapshot_seqs, api_ctx.snapshot_seqs);
    }

    #[test]
    fn test_query_context_remote_query_id_round_trip() {
        let query_id = "0195f4fd-c503-7c54-8b8f-7dfb8f6f9c4a";
        let ctx = QueryContextBuilder::default()
            .current_catalog(DEFAULT_CATALOG_NAME.to_string())
            .current_schema("public".to_string())
            .set_extension(
                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                query_id.to_string(),
            )
            .build();

        assert_eq!(ctx.remote_query_id(), Some(query_id));
        assert_eq!(ctx.remote_query_id_value().unwrap().to_string(), query_id);

        let proto: api::v1::QueryContext = (&ctx).into();
        let restored = QueryContext::from(proto);
        assert_eq!(restored.remote_query_id(), Some(query_id));
        assert_eq!(
            restored.remote_query_id_value().unwrap().to_string(),
            query_id
        );
    }

    #[test]
    fn test_channel_u32_round_trip() {
        // Every known discriminant decodes back to itself; anything else is `Unknown`.
        for value in 0u32..=13 {
            assert_eq!(Channel::from(value) as u32, value);
        }
        assert_eq!(Channel::from(14), Channel::Unknown);
        assert_eq!(Channel::from(u32::MAX), Channel::Unknown);
    }

    #[test]
    fn test_conn_info_display_without_client_addr() {
        let conn_info = ConnInfo::new(None, Channel::Postgres);
        assert_eq!("postgres[unknown client addr]", conn_info.to_string());
    }

    #[test]
    fn test_query_timeout_as_millis() {
        let ctx = QueryContext::arc();
        ctx.set_query_timeout(std::time::Duration::from_millis(1500));
        assert_eq!(
            ctx.query_timeout(),
            Some(std::time::Duration::from_millis(1500))
        );
        assert_eq!(ctx.query_timeout_as_millis(), 1500);
    }

    #[test]
    fn test_explain_verbose_toggle() {
        let ctx = QueryContext::arc();
        assert!(!ctx.explain_verbose());
        ctx.set_explain_verbose(true);
        assert!(ctx.explain_verbose());
        ctx.set_explain_verbose(false);
        assert!(!ctx.explain_verbose());
    }
}