Skip to main content

sql/statements/
insert.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use serde::Serialize;
16use sqlparser::ast::{
17    Insert as SpInsert, ObjectName, Query, SetExpr, Statement, TableObject, UnaryOperator,
18    ValueWithSpan, Values,
19};
20use sqlparser::parser::ParserError;
21use sqlparser_derive::{Visit, VisitMut};
22
23use crate::ast::{Expr, Value};
24use crate::error::{Result, UnsupportedSnafu};
25use crate::statements::query::Query as GtQuery;
26
/// An `INSERT` statement, stored as the underlying sqlparser AST node.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct Insert {
    /// The wrapped statement. Can only be the `sqlparser::ast::Statement::Insert`
    /// variant: `TryFrom<Statement>` rejects every other variant, and the accessor
    /// methods treat any other variant as unreachable. NOTE(review): the field is
    /// `pub`, so direct construction can bypass that check.
    pub inner: Statement,
}
32
// Returns early from the enclosing function with a `ParseSqlValue` error whose
// message is the `Debug` rendering of `$expr`.
//
// The expansion is a `return` statement including its trailing semicolon, so the
// macro is meant to be invoked in statement position (`parse_fail!(e);`).
macro_rules! parse_fail {
    ($expr: expr) => {
        return crate::error::ParseSqlValueSnafu {
            msg: format!("{:?}", $expr),
        }
        .fail();
    };
}
41
42impl Insert {
43    pub fn table_name(&self) -> Result<&ObjectName> {
44        match &self.inner {
45            Statement::Insert(insert) => {
46                let TableObject::TableName(name) = &insert.table else {
47                    return UnsupportedSnafu {
48                        keyword: "TABLE FUNCTION".to_string(),
49                    }
50                    .fail();
51                };
52                Ok(name)
53            }
54            _ => unreachable!(),
55        }
56    }
57
58    pub fn columns(&self) -> Vec<&String> {
59        match &self.inner {
60            Statement::Insert(insert) => insert.columns.iter().map(|ident| &ident.value).collect(),
61            _ => unreachable!(),
62        }
63    }
64
65    /// Extracts the literal insert statement body if possible
66    pub fn values_body(&self) -> Result<Vec<Vec<Value>>> {
67        match &self.inner {
68            Statement::Insert(SpInsert {
69                source:
70                    Some(box Query {
71                        body: box SetExpr::Values(Values { rows, .. }),
72                        ..
73                    }),
74                ..
75            }) => sql_exprs_to_values(rows),
76            _ => unreachable!(),
77        }
78    }
79
80    /// Returns true when the insert statement can extract literal values.
81    /// The rules is the same as function `values_body()`.
82    pub fn can_extract_values(&self) -> bool {
83        match &self.inner {
84            Statement::Insert(SpInsert {
85                source:
86                    Some(box Query {
87                        body: box SetExpr::Values(Values { rows, .. }),
88                        ..
89                    }),
90                ..
91            }) => rows.iter().all(|es| {
92                es.iter().all(|expr| match expr {
93                    Expr::Value(_) => true,
94                    Expr::Identifier(ident) => {
95                        if ident.quote_style.is_none() {
96                            ident.value.to_lowercase() == "default"
97                        } else {
98                            ident.quote_style == Some('"')
99                        }
100                    }
101                    Expr::UnaryOp { op, expr } => {
102                        matches!(op, UnaryOperator::Minus | UnaryOperator::Plus)
103                            && matches!(
104                                &**expr,
105                                Expr::Value(ValueWithSpan {
106                                    value: Value::Number(_, _),
107                                    ..
108                                })
109                            )
110                    }
111                    _ => false,
112                })
113            }),
114            _ => false,
115        }
116    }
117
118    /// Returns true when the insert source is a query rather than `VALUES`.
119    pub fn has_non_values_query_source(&self) -> bool {
120        match &self.inner {
121            Statement::Insert(SpInsert {
122                source: Some(box query),
123                ..
124            }) => !matches!(&*query.body, SetExpr::Values(_)),
125            _ => false,
126        }
127    }
128
129    pub fn query_body(&self) -> Result<Option<GtQuery>> {
130        Ok(match &self.inner {
131            Statement::Insert(SpInsert {
132                source: Some(box query),
133                ..
134            }) => Some(query.clone().try_into()?),
135            _ => None,
136        })
137    }
138}
139
140fn sql_exprs_to_values(exprs: &[Vec<Expr>]) -> Result<Vec<Vec<Value>>> {
141    let mut values = Vec::with_capacity(exprs.len());
142    for es in exprs.iter() {
143        let mut vs = Vec::with_capacity(es.len());
144        for expr in es.iter() {
145            vs.push(match expr {
146                Expr::Value(v) => v.value.clone(),
147                Expr::Identifier(ident) => {
148                    if ident.quote_style.is_none() {
149                        // Special processing for `default` value
150                        if ident.value.to_lowercase() == "default" {
151                            Value::Placeholder(ident.value.clone())
152                        } else {
153                            parse_fail!(expr);
154                        }
155                    } else {
156                        // Identifiers with double quotes, we treat them as strings.
157                        if ident.quote_style == Some('"') {
158                            Value::SingleQuotedString(ident.value.clone())
159                        } else {
160                            parse_fail!(expr);
161                        }
162                    }
163                }
164                Expr::UnaryOp { op, expr }
165                    if matches!(op, UnaryOperator::Minus | UnaryOperator::Plus) =>
166                {
167                    if let Expr::Value(ValueWithSpan {
168                        value: Value::Number(s, b),
169                        ..
170                    }) = &**expr
171                    {
172                        match op {
173                            UnaryOperator::Minus => Value::Number(format!("-{s}"), *b),
174                            UnaryOperator::Plus => Value::Number(s.clone(), *b),
175                            _ => unreachable!(),
176                        }
177                    } else {
178                        parse_fail!(expr);
179                    }
180                }
181                _ => {
182                    parse_fail!(expr);
183                }
184            });
185        }
186        values.push(vs);
187    }
188    Ok(values)
189}
190
191impl TryFrom<Statement> for Insert {
192    type Error = ParserError;
193
194    fn try_from(value: Statement) -> std::result::Result<Self, Self::Error> {
195        match value {
196            Statement::Insert { .. } => Ok(Insert { inner: value }),
197            unexp => Err(ParserError::ParserError(format!(
198                "Not expected to be {unexp}"
199            ))),
200        }
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect and hands the first statement,
    /// which must be an `INSERT`, to `check`.
    fn with_insert(sql: &str, check: impl FnOnce(&Insert)) {
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .remove(0);
        match stmt {
            Statement::Insert(insert) => check(&insert),
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_insert_value_with_unary_op() {
        // insert "-1"
        with_insert("INSERT INTO my_table VALUES(-1)", |insert| {
            let values = insert.values_body().unwrap();
            assert_eq!(values, vec![vec![Value::Number("-1".to_string(), false)]]);
        });

        // insert "+1"
        with_insert("INSERT INTO my_table VALUES(+1)", |insert| {
            let values = insert.values_body().unwrap();
            assert_eq!(values, vec![vec![Value::Number("1".to_string(), false)]]);
        });
    }

    #[test]
    fn test_insert_value_with_default() {
        // insert "default"
        with_insert("INSERT INTO my_table VALUES(default)", |insert| {
            let values = insert.values_body().unwrap();
            assert_eq!(values, vec![vec![Value::Placeholder("default".to_owned())]]);
        });
    }

    #[test]
    fn test_insert_value_with_default_uppercase() {
        // insert "DEFAULT"
        with_insert("INSERT INTO my_table VALUES(DEFAULT)", |insert| {
            let values = insert.values_body().unwrap();
            assert_eq!(values, vec![vec![Value::Placeholder("DEFAULT".to_owned())]]);
        });
    }

    #[test]
    fn test_insert_value_with_quoted_string() {
        let expected = vec![vec![Value::SingleQuotedString("default".to_owned())]];

        // insert 'default'
        with_insert("INSERT INTO my_table VALUES('default')", |insert| {
            assert_eq!(insert.values_body().unwrap(), expected);
        });

        // insert "default". Treating double-quoted identifiers as strings.
        with_insert("INSERT INTO my_table VALUES(\"default\")", |insert| {
            assert_eq!(insert.values_body().unwrap(), expected);
        });

        // Backquoted identifiers are not accepted as values.
        with_insert("INSERT INTO my_table VALUES(`default`)", |insert| {
            assert!(insert.values_body().is_err());
        });
    }

    #[test]
    fn test_insert_select() {
        with_insert("INSERT INTO my_table select * from other_table", |insert| {
            let q = insert.query_body().unwrap().unwrap();
            assert!(insert.has_non_values_query_source());
            assert!(matches!(
                q.inner,
                Query {
                    body: box SetExpr::Select { .. },
                    ..
                }
            ));
        });
    }

    #[test]
    fn test_has_non_values_query_source() {
        let cases = [
            ("INSERT INTO my_table SELECT * FROM other_table", true),
            (
                "INSERT INTO my_table WITH cte AS (SELECT * FROM other_table) SELECT * FROM cte",
                true,
            ),
            (
                "INSERT INTO my_table SELECT * FROM t1 UNION ALL SELECT * FROM t2",
                true,
            ),
            ("INSERT INTO my_table VALUES(1)", false),
            ("INSERT INTO my_table VALUES(now())", false),
            ("INSERT INTO my_table VALUES(1 + 1)", false),
        ];

        for (sql, expected) in cases {
            with_insert(sql, |insert| {
                assert_eq!(insert.has_non_values_query_source(), expected, "{sql}");
            });
        }
    }
}