Skip to main content

cli/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use base64::Engine;
18use base64::engine::general_purpose;
19use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
20use common_error::ext::BoxedError;
21use humantime::format_duration;
22use serde_json::Value;
23use servers::http::GreptimeQueryOutput;
24use servers::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT;
25use servers::http::result::greptime_result_v1::GreptimedbV1Response;
26use snafu::ResultExt;
27
28use crate::error::{
29    BuildClientSnafu, HttpQuerySqlSnafu, ParseProxyOptsSnafu, Result, SerdeJsonSnafu,
30};
31
32#[derive(Debug, Clone)]
33pub struct DatabaseClient {
34    addr: String,
35    catalog: String,
36    auth_header: Option<String>,
37    timeout: Duration,
38    proxy: Option<reqwest::Proxy>,
39    no_proxy: bool,
40}
41
42pub fn parse_proxy_opts(
43    proxy: Option<String>,
44    no_proxy: bool,
45) -> std::result::Result<Option<reqwest::Proxy>, BoxedError> {
46    if no_proxy {
47        return Ok(None);
48    }
49    proxy
50        .map(|proxy| {
51            reqwest::Proxy::all(proxy)
52                .context(ParseProxyOptsSnafu)
53                .map_err(BoxedError::new)
54        })
55        .transpose()
56}
57
58impl DatabaseClient {
59    pub fn new(
60        addr: String,
61        catalog: String,
62        auth_basic: Option<String>,
63        timeout: Duration,
64        proxy: Option<reqwest::Proxy>,
65        no_proxy: bool,
66    ) -> Self {
67        let auth_header = if let Some(basic) = auth_basic {
68            let encoded = general_purpose::STANDARD.encode(basic);
69            Some(format!("basic {}", encoded))
70        } else {
71            None
72        };
73
74        if no_proxy {
75            common_telemetry::info!("Proxy disabled");
76        } else if let Some(ref proxy) = proxy {
77            common_telemetry::info!("Using proxy: {:?}", proxy);
78        } else {
79            common_telemetry::info!("Using system proxy(if any)");
80        }
81
82        Self {
83            addr,
84            catalog,
85            auth_header,
86            timeout,
87            proxy,
88            no_proxy,
89        }
90    }
91
92    pub fn addr(&self) -> &str {
93        &self.addr
94    }
95
96    pub async fn sql_in_public(&self, sql: &str) -> Result<Option<Vec<Vec<Value>>>> {
97        self.sql(sql, DEFAULT_SCHEMA_NAME).await
98    }
99
100    /// Execute sql query.
101    pub async fn sql(&self, sql: &str, schema: &str) -> Result<Option<Vec<Vec<Value>>>> {
102        let url = format!("http://{}/v1/sql", self.addr);
103        let params = [
104            ("db", format!("{}-{}", self.catalog, schema)),
105            ("sql", sql.to_string()),
106        ];
107        let mut builder = reqwest::Client::builder();
108        if let Some(proxy) = self.proxy.clone() {
109            builder = builder.proxy(proxy);
110        }
111        if self.no_proxy {
112            builder = builder.no_proxy();
113        }
114        let client = builder.build().context(BuildClientSnafu)?;
115        let mut request = client
116            .post(&url)
117            .form(&params)
118            .header("Content-Type", "application/x-www-form-urlencoded");
119        if let Some(ref auth) = self.auth_header {
120            request = request.header("Authorization", auth);
121        }
122
123        request = request.header(
124            GREPTIME_DB_HEADER_TIMEOUT,
125            format_duration(self.timeout).to_string(),
126        );
127
128        let response = request.send().await.with_context(|_| HttpQuerySqlSnafu {
129            reason: format!("bad url: {}", url),
130        })?;
131        let response = response
132            .error_for_status()
133            .with_context(|_| HttpQuerySqlSnafu {
134                reason: format!("query failed: {}", sql),
135            })?;
136
137        let text = response.text().await.with_context(|_| HttpQuerySqlSnafu {
138            reason: "cannot get response text".to_string(),
139        })?;
140
141        let body = serde_json::from_str::<GreptimedbV1Response>(&text).context(SerdeJsonSnafu)?;
142        Ok(body.output().first().and_then(|output| match output {
143            GreptimeQueryOutput::Records(records) => Some(records.rows().clone()),
144            GreptimeQueryOutput::AffectedRows(_) => None,
145        }))
146    }
147}
148
149/// Split at `-`.
150pub(crate) fn split_database(database: &str) -> Result<(String, Option<String>)> {
151    let (catalog, schema) = match database.split_once('-') {
152        Some((catalog, schema)) => (catalog, schema),
153        None => (DEFAULT_CATALOG_NAME, database),
154    };
155
156    if schema == "*" {
157        Ok((catalog.to_string(), None))
158    } else {
159        Ok((catalog.to_string(), Some(schema.to_string())))
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_database() {
        // (input, (expected catalog, expected schema))
        let cases = [
            ("catalog-schema", ("catalog", Some("schema"))),
            ("schema", ("greptime", Some("schema"))),
            ("catalog-*", ("catalog", None)),
            ("*", ("greptime", None)),
        ];
        for (input, (catalog, schema)) in cases {
            let expected = (catalog.to_string(), schema.map(str::to_string));
            assert_eq!(split_database(input).unwrap(), expected, "input: {input}");
        }
    }
}