1use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::ExplainOptions;
22use api::v1::region::RegionRequestHeader;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::Timezone;
30use common_time::timezone::parse_timezone;
31use datafusion_common::config::ConfigOptions;
32use derive_builder::Builder;
33use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
34
35use crate::protocol_ctx::ProtocolCtx;
36use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle};
37use crate::{MutableInner, ReadPreference};
38
/// Shared, reference-counted handle to a [`QueryContext`].
pub type QueryContextRef = Arc<QueryContext>;
/// Shared, reference-counted handle to a [`ConnInfo`].
pub type ConnInfoRef = Arc<ConnInfo>;

/// Number of open cursors above which `QueryContext::insert_cursor` logs a warning.
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
43
44#[derive(Debug, Builder, Clone)]
45#[builder(pattern = "owned")]
46#[builder(build_fn(skip))]
47pub struct QueryContext {
48 current_catalog: String,
49 snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
53 sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
55 #[builder(default)]
57 mutable_session_data: Arc<RwLock<MutableInner>>,
58 #[builder(default)]
59 mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
60 sql_dialect: Arc<dyn Dialect + Send + Sync>,
61 #[builder(default)]
62 extensions: HashMap<String, String>,
63 #[builder(default)]
65 configuration_parameter: Arc<ConfigurationVariables>,
66 #[builder(default)]
68 channel: Channel,
69 #[builder(default)]
71 process_id: u32,
72 #[builder(default)]
74 conn_info: ConnInfo,
75 #[builder(default)]
77 protocol_ctx: ProtocolCtx,
78}
79
80#[derive(Debug, Builder, Clone, Default)]
82pub struct QueryContextMutableFields {
83 warning: Option<String>,
84 explain_format: Option<String>,
86 explain_options: Option<ExplainOptions>,
88}
89
90impl Display for QueryContext {
91 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
92 write!(
93 f,
94 "QueryContext{{catalog: {}, schema: {}}}",
95 self.current_catalog(),
96 self.current_schema()
97 )
98 }
99}
100
101impl QueryContextBuilder {
102 pub fn current_schema(mut self, schema: String) -> Self {
103 if self.mutable_session_data.is_none() {
104 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
105 }
106
107 self.mutable_session_data
109 .as_mut()
110 .unwrap()
111 .write()
112 .unwrap()
113 .schema = schema;
114 self
115 }
116
117 pub fn timezone(mut self, timezone: Timezone) -> Self {
118 if self.mutable_session_data.is_none() {
119 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
120 }
121
122 self.mutable_session_data
123 .as_mut()
124 .unwrap()
125 .write()
126 .unwrap()
127 .timezone = timezone;
128 self
129 }
130
131 pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
132 self.mutable_query_context_data
133 .get_or_insert_default()
134 .write()
135 .unwrap()
136 .explain_options = explain_options;
137 self
138 }
139
140 pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
141 self.mutable_session_data
142 .get_or_insert_default()
143 .write()
144 .unwrap()
145 .read_preference = read_preference;
146 self
147 }
148}
149
150impl From<&RegionRequestHeader> for QueryContext {
151 fn from(value: &RegionRequestHeader) -> Self {
152 if let Some(ctx) = &value.query_context {
153 ctx.clone().into()
154 } else {
155 QueryContextBuilder::default().build()
156 }
157 }
158}
159
160impl From<api::v1::QueryContext> for QueryContext {
161 fn from(ctx: api::v1::QueryContext) -> Self {
162 let sequences = ctx.snapshot_seqs.as_ref();
163 QueryContextBuilder::default()
164 .current_catalog(ctx.current_catalog)
165 .current_schema(ctx.current_schema)
166 .timezone(parse_timezone(Some(&ctx.timezone)))
167 .extensions(ctx.extensions)
168 .channel(ctx.channel.into())
169 .snapshot_seqs(Arc::new(RwLock::new(
170 sequences
171 .map(|x| x.snapshot_seqs.clone())
172 .unwrap_or_default(),
173 )))
174 .sst_min_sequences(Arc::new(RwLock::new(
175 sequences
176 .map(|x| x.sst_min_sequences.clone())
177 .unwrap_or_default(),
178 )))
179 .explain_options(ctx.explain)
180 .build()
181 }
182}
183
184impl From<QueryContext> for api::v1::QueryContext {
185 fn from(
186 QueryContext {
187 current_catalog,
188 mutable_session_data: mutable_inner,
189 extensions,
190 channel,
191 snapshot_seqs,
192 sst_min_sequences,
193 mutable_query_context_data,
194 ..
195 }: QueryContext,
196 ) -> Self {
197 let explain = mutable_query_context_data.read().unwrap().explain_options;
198 let mutable_inner = mutable_inner.read().unwrap();
199 api::v1::QueryContext {
200 current_catalog,
201 current_schema: mutable_inner.schema.clone(),
202 timezone: mutable_inner.timezone.to_string(),
203 extensions,
204 channel: channel as u32,
205 snapshot_seqs: Some(api::v1::SnapshotSequences {
206 snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
207 sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
208 }),
209 explain,
210 }
211 }
212}
213
214impl From<&QueryContext> for api::v1::QueryContext {
215 fn from(ctx: &QueryContext) -> Self {
216 ctx.clone().into()
217 }
218}
219
220impl QueryContext {
221 pub fn arc() -> QueryContextRef {
222 Arc::new(QueryContextBuilder::default().build())
223 }
224
225 pub fn create_config_options(&self) -> ConfigOptions {
227 let mut config = ConfigOptions::default();
228 config.execution.time_zone = Some(self.timezone().to_string());
229 config
230 }
231
232 pub fn with(catalog: &str, schema: &str) -> QueryContext {
233 QueryContextBuilder::default()
234 .current_catalog(catalog.to_string())
235 .current_schema(schema.to_string())
236 .build()
237 }
238
239 pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
240 QueryContextBuilder::default()
241 .current_catalog(catalog.to_string())
242 .current_schema(schema.to_string())
243 .channel(channel)
244 .build()
245 }
246
247 pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
248 let (catalog, schema) = db_name
249 .map(|db| {
250 let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
251 (catalog, schema)
252 })
253 .unwrap_or_else(|| {
254 (
255 DEFAULT_CATALOG_NAME.to_string(),
256 DEFAULT_SCHEMA_NAME.to_string(),
257 )
258 });
259 QueryContextBuilder::default()
260 .current_catalog(catalog)
261 .current_schema(schema.clone())
262 .build()
263 }
264
265 pub fn current_schema(&self) -> String {
266 self.mutable_session_data.read().unwrap().schema.clone()
267 }
268
269 pub fn set_current_schema(&self, new_schema: &str) {
270 self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
271 }
272
273 pub fn current_catalog(&self) -> &str {
274 &self.current_catalog
275 }
276
277 pub fn set_current_catalog(&mut self, new_catalog: &str) {
278 self.current_catalog = new_catalog.to_string();
279 }
280
281 pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
282 &*self.sql_dialect
283 }
284
285 pub fn get_db_string(&self) -> String {
286 let catalog = self.current_catalog();
287 let schema = self.current_schema();
288 build_db_string(catalog, &schema)
289 }
290
291 pub fn timezone(&self) -> Timezone {
292 self.mutable_session_data.read().unwrap().timezone.clone()
293 }
294
295 pub fn set_timezone(&self, timezone: Timezone) {
296 self.mutable_session_data.write().unwrap().timezone = timezone;
297 }
298
299 pub fn read_preference(&self) -> ReadPreference {
300 self.mutable_session_data.read().unwrap().read_preference
301 }
302
303 pub fn set_read_preference(&self, read_preference: ReadPreference) {
304 self.mutable_session_data.write().unwrap().read_preference = read_preference;
305 }
306
307 pub fn current_user(&self) -> UserInfoRef {
308 self.mutable_session_data.read().unwrap().user_info.clone()
309 }
310
311 pub fn set_current_user(&self, user: UserInfoRef) {
312 self.mutable_session_data.write().unwrap().user_info = user;
313 }
314
315 pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
316 self.extensions.insert(key.into(), value.into());
317 }
318
319 pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
320 self.extensions.get(key.as_ref()).map(|v| v.as_str())
321 }
322
323 pub fn extensions(&self) -> HashMap<String, String> {
324 self.extensions.clone()
325 }
326
327 pub fn quote_style(&self) -> char {
329 if self.sql_dialect().is_delimited_identifier_start('"') {
330 '"'
331 } else if self.sql_dialect().is_delimited_identifier_start('\'') {
332 '\''
333 } else {
334 '`'
335 }
336 }
337
338 pub fn configuration_parameter(&self) -> &ConfigurationVariables {
339 &self.configuration_parameter
340 }
341
342 pub fn channel(&self) -> Channel {
343 self.channel
344 }
345
346 pub fn set_channel(&mut self, channel: Channel) {
347 self.channel = channel;
348 }
349
350 pub fn warning(&self) -> Option<String> {
351 self.mutable_query_context_data
352 .read()
353 .unwrap()
354 .warning
355 .clone()
356 }
357
358 pub fn set_warning(&self, msg: String) {
359 self.mutable_query_context_data.write().unwrap().warning = Some(msg);
360 }
361
362 pub fn explain_format(&self) -> Option<String> {
363 self.mutable_query_context_data
364 .read()
365 .unwrap()
366 .explain_format
367 .clone()
368 }
369
370 pub fn set_explain_format(&self, format: String) {
371 self.mutable_query_context_data
372 .write()
373 .unwrap()
374 .explain_format = Some(format);
375 }
376
377 pub fn explain_verbose(&self) -> bool {
378 self.mutable_query_context_data
379 .read()
380 .unwrap()
381 .explain_options
382 .map(|opts| opts.verbose)
383 .unwrap_or(false)
384 }
385
386 pub fn set_explain_verbose(&self, verbose: bool) {
387 self.mutable_query_context_data
388 .write()
389 .unwrap()
390 .explain_options
391 .get_or_insert_default()
392 .verbose = verbose;
393 }
394
395 pub fn query_timeout(&self) -> Option<Duration> {
396 self.mutable_session_data.read().unwrap().query_timeout
397 }
398
399 pub fn query_timeout_as_millis(&self) -> u128 {
400 let timeout = self.mutable_session_data.read().unwrap().query_timeout;
401 if let Some(t) = timeout {
402 return t.as_millis();
403 }
404 0
405 }
406
407 pub fn set_query_timeout(&self, timeout: Duration) {
408 self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
409 }
410
411 pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
412 let mut guard = self.mutable_session_data.write().unwrap();
413 guard.cursors.insert(name, Arc::new(rb));
414
415 let cursor_count = guard.cursors.len();
416 if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
417 warn!("Current connection has {} open cursors", cursor_count);
418 }
419 }
420
421 pub fn remove_cursor(&self, name: &str) {
422 let mut guard = self.mutable_session_data.write().unwrap();
423 guard.cursors.remove(name);
424 }
425
426 pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
427 let guard = self.mutable_session_data.read().unwrap();
428 let rb = guard.cursors.get(name);
429 rb.cloned()
430 }
431
432 pub fn snapshots(&self) -> HashMap<u64, u64> {
433 self.snapshot_seqs.read().unwrap().clone()
434 }
435
436 pub fn sst_min_sequences(&self) -> HashMap<u64, u64> {
437 self.sst_min_sequences.read().unwrap().clone()
438 }
439
440 pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
441 self.snapshot_seqs.read().unwrap().get(®ion_id).cloned()
442 }
443
444 pub fn set_snapshot(&self, region_id: u64, sequence: u64) {
445 self.snapshot_seqs
446 .write()
447 .unwrap()
448 .insert(region_id, sequence);
449 }
450
451 pub fn auto_string_to_numeric(&self) -> bool {
453 matches!(self.channel, Channel::Mysql)
454 }
455
456 pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
458 self.sst_min_sequences
459 .read()
460 .unwrap()
461 .get(®ion_id)
462 .copied()
463 }
464
465 pub fn process_id(&self) -> u32 {
466 self.process_id
467 }
468
469 pub fn conn_info(&self) -> &ConnInfo {
471 &self.conn_info
472 }
473
474 pub fn protocol_ctx(&self) -> &ProtocolCtx {
475 &self.protocol_ctx
476 }
477
478 pub fn set_protocol_ctx(&mut self, protocol_ctx: ProtocolCtx) {
479 self.protocol_ctx = protocol_ctx;
480 }
481}
482
483impl QueryContextBuilder {
484 pub fn build(self) -> QueryContext {
485 let channel = self.channel.unwrap_or_default();
486 QueryContext {
487 current_catalog: self
488 .current_catalog
489 .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
490 snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
491 sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
492 mutable_session_data: self.mutable_session_data.unwrap_or_default(),
493 mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
494 sql_dialect: self
495 .sql_dialect
496 .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
497 extensions: self.extensions.unwrap_or_default(),
498 configuration_parameter: self
499 .configuration_parameter
500 .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
501 channel,
502 process_id: self.process_id.unwrap_or_default(),
503 conn_info: self.conn_info.unwrap_or_default(),
504 protocol_ctx: self.protocol_ctx.unwrap_or_default(),
505 }
506 }
507
508 pub fn set_extension(mut self, key: String, value: String) -> Self {
509 self.extensions
510 .get_or_insert_with(HashMap::new)
511 .insert(key, value);
512 self
513 }
514}
515
516#[derive(Debug, Clone, Default)]
517pub struct ConnInfo {
518 pub client_addr: Option<SocketAddr>,
519 pub channel: Channel,
520}
521
522impl Display for ConnInfo {
523 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
524 write!(
525 f,
526 "{}[{}]",
527 self.channel,
528 self.client_addr
529 .map(|addr| addr.to_string())
530 .as_deref()
531 .unwrap_or("unknown client addr")
532 )
533 }
534}
535
536impl ConnInfo {
537 pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
538 Self {
539 client_addr,
540 channel,
541 }
542 }
543}
544
/// Protocol channel a query or connection arrived through.
///
/// The explicit discriminants are the values produced by `channel as u32` and
/// accepted by [`From<u32>`]; the two mappings must be kept in sync when a
/// variant is added.
#[derive(Debug, PartialEq, Eq, Hash, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    #[default]
    Unknown = 0,

    Mysql = 1,
    Postgres = 2,
    HttpSql = 3,
    Prometheus = 4,
    Otlp = 5,
    Grpc = 6,
    Influx = 7,
    Opentsdb = 8,
    Loki = 9,
    Elasticsearch = 10,
    Jaeger = 11,
    Log = 12,
    Promql = 13,
}

impl From<u32> for Channel {
    /// Maps a numeric channel id back to its variant.
    ///
    /// `0` and every value without a matching variant map to
    /// [`Channel::Unknown`].
    fn from(value: u32) -> Self {
        match value {
            1 => Self::Mysql,
            2 => Self::Postgres,
            3 => Self::HttpSql,
            4 => Self::Prometheus,
            5 => Self::Otlp,
            6 => Self::Grpc,
            7 => Self::Influx,
            8 => Self::Opentsdb,
            9 => Self::Loki,
            10 => Self::Elasticsearch,
            11 => Self::Jaeger,
            12 => Self::Log,
            13 => Self::Promql,
            _ => Self::Unknown,
        }
    }
}
586
587impl Channel {
588 pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
589 match self {
590 Channel::Mysql => Arc::new(MySqlDialect {}),
591 Channel::Postgres => Arc::new(PostgreSqlDialect {}),
592 _ => Arc::new(GenericDialect {}),
593 }
594 }
595}
596
597impl Display for Channel {
598 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
599 write!(f, "{}", self.as_ref())
600 }
601}
602
603impl AsRef<str> for Channel {
604 fn as_ref(&self) -> &str {
605 match self {
606 Channel::Mysql => "mysql",
607 Channel::Postgres => "postgres",
608 Channel::HttpSql => "httpsql",
609 Channel::Prometheus => "prometheus",
610 Channel::Otlp => "otlp",
611 Channel::Grpc => "grpc",
612 Channel::Influx => "influx",
613 Channel::Opentsdb => "opentsdb",
614 Channel::Loki => "loki",
615 Channel::Elasticsearch => "elasticsearch",
616 Channel::Jaeger => "jaeger",
617 Channel::Log => "log",
618 Channel::Promql => "promql",
619 Channel::Unknown => "unknown",
620 }
621 }
622}
623
624#[derive(Default, Debug)]
625pub struct ConfigurationVariables {
626 postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
627 pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
628 pg_intervalstyle_format: ArcSwap<PGIntervalStyle>,
629 allow_query_fallback: ArcSwap<bool>,
630}
631
632impl Clone for ConfigurationVariables {
633 fn clone(&self) -> Self {
634 Self {
635 postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
636 pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
637 pg_intervalstyle_format: ArcSwap::new(self.pg_intervalstyle_format.load().clone()),
638 allow_query_fallback: ArcSwap::new(self.allow_query_fallback.load().clone()),
639 }
640 }
641}
642
643impl ConfigurationVariables {
644 pub fn new() -> Self {
645 Self::default()
646 }
647
648 pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
649 let _ = self.postgres_bytea_output.swap(Arc::new(value));
650 }
651
652 pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
653 self.postgres_bytea_output.load().clone()
654 }
655
656 pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
657 self.pg_datestyle_format.load().clone()
658 }
659
660 pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
661 self.pg_datestyle_format.swap(Arc::new((style, order)));
662 }
663
664 pub fn pg_intervalstyle_format(&self) -> Arc<PGIntervalStyle> {
665 self.pg_intervalstyle_format.load().clone()
666 }
667
668 pub fn set_pg_intervalstyle_format(&self, value: PGIntervalStyle) {
669 self.pg_intervalstyle_format.swap(Arc::new(value));
670 }
671
672 pub fn allow_query_fallback(&self) -> bool {
673 **self.allow_query_fallback.load()
674 }
675
676 pub fn set_allow_query_fallback(&self, allow: bool) {
677 self.allow_query_fallback.swap(Arc::new(allow));
678 }
679}
680
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use common_catalog::consts::DEFAULT_CATALOG_NAME;

    use super::*;
    use crate::Session;
    use crate::context::Channel;

    /// A freshly created session exposes the user, connection info and process
    /// id it was created with.
    #[test]
    fn test_session() {
        let session = Session::new(
            Some("127.0.0.1:9000".parse().unwrap()),
            Channel::Mysql,
            Default::default(),
            100,
        );
        assert_eq!(session.user_info().username(), "greptime");

        assert_eq!(session.conn_info().channel, Channel::Mysql);
        let client_addr = session.conn_info().client_addr.as_ref().unwrap();
        assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
        assert_eq!(client_addr.port(), 9000);

        assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string());
        assert_eq!(100, session.process_id());
    }

    /// The db string includes the catalog only when it is not the default one.
    #[test]
    fn test_context_db_string() {
        let context = QueryContext::with("a0b1c2d3", "test");
        assert_eq!("a0b1c2d3-test", context.get_db_string());

        let context = QueryContext::with(DEFAULT_CATALOG_NAME, "test");
        assert_eq!("test", context.get_db_string());
    }

    /// Converting the wire type to a `QueryContext` and back keeps every field.
    #[test]
    fn test_api_query_context_roundtrip_with_sequences() {
        let api_ctx = api::v1::QueryContext {
            current_catalog: "c1".to_string(),
            current_schema: "s1".to_string(),
            timezone: "UTC".to_string(),
            extensions: HashMap::from([("flow.return_region_seq".to_string(), "true".to_string())]),
            channel: Channel::Grpc as u32,
            snapshot_seqs: Some(api::v1::SnapshotSequences {
                snapshot_seqs: HashMap::from([(1, 100)]),
                sst_min_sequences: HashMap::from([(1, 90)]),
            }),
            explain: None,
        };

        let session_ctx: QueryContext = api_ctx.clone().into();
        let roundtrip_api: api::v1::QueryContext = session_ctx.into();

        assert_eq!(roundtrip_api.current_catalog, api_ctx.current_catalog);
        assert_eq!(roundtrip_api.current_schema, api_ctx.current_schema);
        assert_eq!(roundtrip_api.timezone, api_ctx.timezone);
        assert_eq!(roundtrip_api.extensions, api_ctx.extensions);
        assert_eq!(roundtrip_api.channel, api_ctx.channel);
        assert_eq!(roundtrip_api.snapshot_seqs, api_ctx.snapshot_seqs);
        assert_eq!(roundtrip_api.explain, api_ctx.explain);
    }
}
746}