1use std::time::Duration;
16
17use base64::Engine;
18use base64::engine::general_purpose;
19use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
20use common_error::ext::BoxedError;
21use humantime::format_duration;
22use serde_json::Value;
23use servers::http::GreptimeQueryOutput;
24use servers::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT;
25use servers::http::result::greptime_result_v1::GreptimedbV1Response;
26use snafu::ResultExt;
27
28use crate::error::{
29 BuildClientSnafu, HttpQuerySqlSnafu, ParseProxyOptsSnafu, Result, SerdeJsonSnafu,
30};
31
32#[derive(Debug, Clone)]
33pub struct DatabaseClient {
34 addr: String,
35 catalog: String,
36 auth_header: Option<String>,
37 timeout: Duration,
38 proxy: Option<reqwest::Proxy>,
39 no_proxy: bool,
40}
41
42pub fn parse_proxy_opts(
43 proxy: Option<String>,
44 no_proxy: bool,
45) -> std::result::Result<Option<reqwest::Proxy>, BoxedError> {
46 if no_proxy {
47 return Ok(None);
48 }
49 proxy
50 .map(|proxy| {
51 reqwest::Proxy::all(proxy)
52 .context(ParseProxyOptsSnafu)
53 .map_err(BoxedError::new)
54 })
55 .transpose()
56}
57
58impl DatabaseClient {
59 pub fn new(
60 addr: String,
61 catalog: String,
62 auth_basic: Option<String>,
63 timeout: Duration,
64 proxy: Option<reqwest::Proxy>,
65 no_proxy: bool,
66 ) -> Self {
67 let auth_header = if let Some(basic) = auth_basic {
68 let encoded = general_purpose::STANDARD.encode(basic);
69 Some(format!("basic {}", encoded))
70 } else {
71 None
72 };
73
74 if no_proxy {
75 common_telemetry::info!("Proxy disabled");
76 } else if let Some(ref proxy) = proxy {
77 common_telemetry::info!("Using proxy: {:?}", proxy);
78 } else {
79 common_telemetry::info!("Using system proxy(if any)");
80 }
81
82 Self {
83 addr,
84 catalog,
85 auth_header,
86 timeout,
87 proxy,
88 no_proxy,
89 }
90 }
91
92 pub fn addr(&self) -> &str {
93 &self.addr
94 }
95
96 pub async fn sql_in_public(&self, sql: &str) -> Result<Option<Vec<Vec<Value>>>> {
97 self.sql(sql, DEFAULT_SCHEMA_NAME).await
98 }
99
100 pub async fn sql(&self, sql: &str, schema: &str) -> Result<Option<Vec<Vec<Value>>>> {
102 let url = format!("http://{}/v1/sql", self.addr);
103 let params = [
104 ("db", format!("{}-{}", self.catalog, schema)),
105 ("sql", sql.to_string()),
106 ];
107 let mut builder = reqwest::Client::builder();
108 if let Some(proxy) = self.proxy.clone() {
109 builder = builder.proxy(proxy);
110 }
111 if self.no_proxy {
112 builder = builder.no_proxy();
113 }
114 let client = builder.build().context(BuildClientSnafu)?;
115 let mut request = client
116 .post(&url)
117 .form(¶ms)
118 .header("Content-Type", "application/x-www-form-urlencoded");
119 if let Some(ref auth) = self.auth_header {
120 request = request.header("Authorization", auth);
121 }
122
123 request = request.header(
124 GREPTIME_DB_HEADER_TIMEOUT,
125 format_duration(self.timeout).to_string(),
126 );
127
128 let response = request.send().await.with_context(|_| HttpQuerySqlSnafu {
129 reason: format!("bad url: {}", url),
130 })?;
131 let response = response
132 .error_for_status()
133 .with_context(|_| HttpQuerySqlSnafu {
134 reason: format!("query failed: {}", sql),
135 })?;
136
137 let text = response.text().await.with_context(|_| HttpQuerySqlSnafu {
138 reason: "cannot get response text".to_string(),
139 })?;
140
141 let body = serde_json::from_str::<GreptimedbV1Response>(&text).context(SerdeJsonSnafu)?;
142 Ok(body.output().first().and_then(|output| match output {
143 GreptimeQueryOutput::Records(records) => Some(records.rows().clone()),
144 GreptimeQueryOutput::AffectedRows(_) => None,
145 }))
146 }
147}
148
149pub(crate) fn split_database(database: &str) -> Result<(String, Option<String>)> {
151 let (catalog, schema) = match database.split_once('-') {
152 Some((catalog, schema)) => (catalog, schema),
153 None => (DEFAULT_CATALOG_NAME, database),
154 };
155
156 if schema == "*" {
157 Ok((catalog.to_string(), None))
158 } else {
159 Ok((catalog.to_string(), Some(schema.to_string())))
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_database() {
        // (input, expected catalog, expected schema)
        let cases = [
            ("catalog-schema", "catalog", Some("schema")),
            ("schema", "greptime", Some("schema")),
            ("catalog-*", "catalog", None),
            ("*", "greptime", None),
            // Only the first `-` separates catalog from schema.
            ("catalog-my-schema", "catalog", Some("my-schema")),
        ];
        for (input, catalog, schema) in cases {
            let result = split_database(input).unwrap();
            assert_eq!(
                result,
                (catalog.to_string(), schema.map(str::to_string)),
                "input: {input}"
            );
        }
    }
}