Skip to main content

session/
lib.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15pub mod context;
16pub mod hints;
17pub mod protocol_ctx;
18pub mod query_id;
19pub mod session_config;
20pub mod table_name;
21
22use std::collections::{HashMap, VecDeque};
23use std::net::SocketAddr;
24use std::sync::{Arc, RwLock};
25use std::time::Duration;
26
27use auth::UserInfoRef;
28use common_catalog::build_db_string;
29use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
30use common_recordbatch::cursor::RecordBatchStreamCursor;
31pub use common_session::ReadPreference;
32use common_time::Timezone;
33use common_time::timezone::get_timezone;
34use context::{ConfigurationVariables, QueryContextBuilder};
35use derive_more::Debug;
36
37use crate::context::{Channel, ConnInfo, QueryContextRef};
38
/// Upper bound on how many warning messages a single session keeps, in the
/// spirit of MySQL's `max_error_count`.
const MAX_WARNINGS: usize = 64;
41
42/// Session for persistent connection such as MySQL, PostgreSQL etc.
43#[derive(Debug)]
44pub struct Session {
45    catalog: RwLock<String>,
46    mutable_inner: Arc<RwLock<MutableInner>>,
47    conn_info: ConnInfo,
48    configuration_variables: Arc<ConfigurationVariables>,
49    // the process id to use when killing the query
50    process_id: u32,
51}
52
53pub type SessionRef = Arc<Session>;
54
55/// A container for mutable items in query context
56#[derive(Debug)]
57pub(crate) struct MutableInner {
58    schema: String,
59    user_info: UserInfoRef,
60    timezone: Timezone,
61    query_timeout: Option<Duration>,
62    read_preference: ReadPreference,
63    #[debug(skip)]
64    pub(crate) cursors: HashMap<String, Arc<RecordBatchStreamCursor>>,
65    /// Warning messages for MySQL SHOW WARNINGS support
66    warnings: VecDeque<String>,
67}
68
69impl Default for MutableInner {
70    fn default() -> Self {
71        Self {
72            schema: DEFAULT_SCHEMA_NAME.into(),
73            user_info: auth::userinfo_by_name(None),
74            timezone: get_timezone(None).clone(),
75            query_timeout: None,
76            read_preference: ReadPreference::Leader,
77            cursors: HashMap::with_capacity(0),
78            warnings: VecDeque::new(),
79        }
80    }
81}
82
83impl Session {
84    pub fn new(
85        addr: Option<SocketAddr>,
86        channel: Channel,
87        configuration_variables: ConfigurationVariables,
88        process_id: u32,
89    ) -> Self {
90        Session {
91            catalog: RwLock::new(DEFAULT_CATALOG_NAME.into()),
92            conn_info: ConnInfo::new(addr, channel),
93            configuration_variables: Arc::new(configuration_variables),
94            mutable_inner: Arc::new(RwLock::new(MutableInner::default())),
95            process_id,
96        }
97    }
98
99    pub fn new_query_context(&self) -> QueryContextRef {
100        QueryContextBuilder::default()
101            // catalog is not allowed for update in query context so we use
102            // string here
103            .current_catalog(self.catalog.read().unwrap().clone())
104            .mutable_session_data(self.mutable_inner.clone())
105            .sql_dialect(self.conn_info.channel.dialect())
106            .configuration_parameter(self.configuration_variables.clone())
107            .channel(self.conn_info.channel)
108            .process_id(self.process_id)
109            .conn_info(self.conn_info.clone())
110            .build()
111            .into()
112    }
113
114    pub fn conn_info(&self) -> &ConnInfo {
115        &self.conn_info
116    }
117
118    pub fn timezone(&self) -> Timezone {
119        self.mutable_inner.read().unwrap().timezone.clone()
120    }
121
122    pub fn read_preference(&self) -> ReadPreference {
123        self.mutable_inner.read().unwrap().read_preference
124    }
125
126    pub fn set_timezone(&self, tz: Timezone) {
127        let mut inner = self.mutable_inner.write().unwrap();
128        inner.timezone = tz;
129    }
130
131    pub fn set_read_preference(&self, read_preference: ReadPreference) {
132        self.mutable_inner.write().unwrap().read_preference = read_preference;
133    }
134
135    pub fn user_info(&self) -> UserInfoRef {
136        self.mutable_inner.read().unwrap().user_info.clone()
137    }
138
139    pub fn set_user_info(&self, user_info: UserInfoRef) {
140        self.mutable_inner.write().unwrap().user_info = user_info;
141    }
142
143    pub fn set_catalog(&self, catalog: String) {
144        *self.catalog.write().unwrap() = catalog;
145    }
146
147    pub fn catalog(&self) -> String {
148        self.catalog.read().unwrap().clone()
149    }
150
151    pub fn schema(&self) -> String {
152        self.mutable_inner.read().unwrap().schema.clone()
153    }
154
155    pub fn set_schema(&self, schema: String) {
156        self.mutable_inner.write().unwrap().schema = schema;
157    }
158
159    pub fn get_db_string(&self) -> String {
160        build_db_string(&self.catalog(), &self.schema())
161    }
162
163    pub fn process_id(&self) -> u32 {
164        self.process_id
165    }
166
167    pub fn warnings_count(&self) -> usize {
168        self.mutable_inner.read().unwrap().warnings.len()
169    }
170
171    pub fn warnings(&self) -> Vec<String> {
172        self.mutable_inner
173            .read()
174            .unwrap()
175            .warnings
176            .iter()
177            .cloned()
178            .collect()
179    }
180
181    /// Add a warning message. If the limit is reached, discard the oldest warning.
182    pub fn add_warning(&self, warning: String) {
183        let mut inner = self.mutable_inner.write().unwrap();
184        if inner.warnings.len() >= MAX_WARNINGS {
185            inner.warnings.pop_front();
186        }
187        inner.warnings.push_back(warning);
188    }
189
190    pub fn clear_warnings(&self) {
191        let mut inner = self.mutable_inner.write().unwrap();
192        if inner.warnings.is_empty() {
193            return;
194        }
195        inner.warnings.clear();
196    }
197}