1use serde::Serialize;
16use sqlparser::ast::{
17 Insert as SpInsert, ObjectName, Query, SetExpr, Statement, TableObject, UnaryOperator,
18 ValueWithSpan, Values,
19};
20use sqlparser::parser::ParserError;
21use sqlparser_derive::{Visit, VisitMut};
22
23use crate::ast::{Expr, Value};
24use crate::error::{Result, UnsupportedSnafu};
25use crate::statements::query::Query as GtQuery;
26
27#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
28pub struct Insert {
29 pub inner: Statement,
31}
32
/// Returns early from the enclosing function with a `ParseSqlValue` error
/// whose message is the `Debug` rendering of the offending expression.
///
/// Must be invoked in statement position (`parse_fail!(e);`) inside a function
/// returning the crate's `Result`.
macro_rules! parse_fail {
    ($culprit:expr) => {
        return crate::error::ParseSqlValueSnafu {
            msg: format!("{:?}", $culprit),
        }
        .fail();
    };
}
41
42impl Insert {
43 pub fn table_name(&self) -> Result<&ObjectName> {
44 match &self.inner {
45 Statement::Insert(insert) => {
46 let TableObject::TableName(name) = &insert.table else {
47 return UnsupportedSnafu {
48 keyword: "TABLE FUNCTION".to_string(),
49 }
50 .fail();
51 };
52 Ok(name)
53 }
54 _ => unreachable!(),
55 }
56 }
57
58 pub fn columns(&self) -> Vec<&String> {
59 match &self.inner {
60 Statement::Insert(insert) => insert.columns.iter().map(|ident| &ident.value).collect(),
61 _ => unreachable!(),
62 }
63 }
64
65 pub fn values_body(&self) -> Result<Vec<Vec<Value>>> {
67 match &self.inner {
68 Statement::Insert(SpInsert {
69 source:
70 Some(box Query {
71 body: box SetExpr::Values(Values { rows, .. }),
72 ..
73 }),
74 ..
75 }) => sql_exprs_to_values(rows),
76 _ => unreachable!(),
77 }
78 }
79
80 pub fn can_extract_values(&self) -> bool {
83 match &self.inner {
84 Statement::Insert(SpInsert {
85 source:
86 Some(box Query {
87 body: box SetExpr::Values(Values { rows, .. }),
88 ..
89 }),
90 ..
91 }) => rows.iter().all(|es| {
92 es.iter().all(|expr| match expr {
93 Expr::Value(_) => true,
94 Expr::Identifier(ident) => {
95 if ident.quote_style.is_none() {
96 ident.value.to_lowercase() == "default"
97 } else {
98 ident.quote_style == Some('"')
99 }
100 }
101 Expr::UnaryOp { op, expr } => {
102 matches!(op, UnaryOperator::Minus | UnaryOperator::Plus)
103 && matches!(
104 &**expr,
105 Expr::Value(ValueWithSpan {
106 value: Value::Number(_, _),
107 ..
108 })
109 )
110 }
111 _ => false,
112 })
113 }),
114 _ => false,
115 }
116 }
117
118 pub fn has_non_values_query_source(&self) -> bool {
120 match &self.inner {
121 Statement::Insert(SpInsert {
122 source: Some(box query),
123 ..
124 }) => !matches!(&*query.body, SetExpr::Values(_)),
125 _ => false,
126 }
127 }
128
129 pub fn query_body(&self) -> Result<Option<GtQuery>> {
130 Ok(match &self.inner {
131 Statement::Insert(SpInsert {
132 source: Some(box query),
133 ..
134 }) => Some(query.clone().try_into()?),
135 _ => None,
136 })
137 }
138}
139
140fn sql_exprs_to_values(exprs: &[Vec<Expr>]) -> Result<Vec<Vec<Value>>> {
141 let mut values = Vec::with_capacity(exprs.len());
142 for es in exprs.iter() {
143 let mut vs = Vec::with_capacity(es.len());
144 for expr in es.iter() {
145 vs.push(match expr {
146 Expr::Value(v) => v.value.clone(),
147 Expr::Identifier(ident) => {
148 if ident.quote_style.is_none() {
149 if ident.value.to_lowercase() == "default" {
151 Value::Placeholder(ident.value.clone())
152 } else {
153 parse_fail!(expr);
154 }
155 } else {
156 if ident.quote_style == Some('"') {
158 Value::SingleQuotedString(ident.value.clone())
159 } else {
160 parse_fail!(expr);
161 }
162 }
163 }
164 Expr::UnaryOp { op, expr }
165 if matches!(op, UnaryOperator::Minus | UnaryOperator::Plus) =>
166 {
167 if let Expr::Value(ValueWithSpan {
168 value: Value::Number(s, b),
169 ..
170 }) = &**expr
171 {
172 match op {
173 UnaryOperator::Minus => Value::Number(format!("-{s}"), *b),
174 UnaryOperator::Plus => Value::Number(s.clone(), *b),
175 _ => unreachable!(),
176 }
177 } else {
178 parse_fail!(expr);
179 }
180 }
181 _ => {
182 parse_fail!(expr);
183 }
184 });
185 }
186 values.push(vs);
187 }
188 Ok(values)
189}
190
191impl TryFrom<Statement> for Insert {
192 type Error = ParserError;
193
194 fn try_from(value: Statement) -> std::result::Result<Self, Self::Error> {
195 match value {
196 Statement::Insert { .. } => Ok(Insert { inner: value }),
197 unexp => Err(ParserError::ParserError(format!(
198 "Not expected to be {unexp}"
199 ))),
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect and returns the first statement.
    fn parse_first(sql: &str) -> Statement {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap()
            .remove(0)
    }

    /// Parses `sql`, requires it to be an INSERT, and runs `check` against it.
    fn with_insert(sql: &str, check: impl FnOnce(&Insert)) {
        let Statement::Insert(insert) = parse_first(sql) else {
            unreachable!()
        };
        check(&insert);
    }

    #[test]
    fn test_insert_value_with_unary_op() {
        let cases = [
            ("INSERT INTO my_table VALUES(-1)", "-1"),
            ("INSERT INTO my_table VALUES(+1)", "1"),
        ];
        for (sql, digits) in cases {
            with_insert(sql, |insert| {
                let values = insert.values_body().unwrap();
                assert_eq!(values, vec![vec![Value::Number(digits.to_string(), false)]]);
            });
        }
    }

    #[test]
    fn test_insert_value_with_default() {
        with_insert("INSERT INTO my_table VALUES(default)", |insert| {
            let values = insert.values_body().unwrap();
            assert_eq!(values, vec![vec![Value::Placeholder("default".to_owned())]]);
        });
    }

    #[test]
    fn test_insert_value_with_default_uppercase() {
        with_insert("INSERT INTO my_table VALUES(DEFAULT)", |insert| {
            let values = insert.values_body().unwrap();
            assert_eq!(values, vec![vec![Value::Placeholder("DEFAULT".to_owned())]]);
        });
    }

    #[test]
    fn test_insert_value_with_quoted_string() {
        let expected = vec![vec![Value::SingleQuotedString("default".to_owned())]];
        let string_cases = [
            "INSERT INTO my_table VALUES('default')",
            "INSERT INTO my_table VALUES(\"default\")",
        ];
        for sql in string_cases {
            with_insert(sql, |insert| {
                let values = insert.values_body().unwrap();
                assert_eq!(values, expected);
            });
        }

        // Backtick-quoted identifiers are not accepted as values.
        with_insert("INSERT INTO my_table VALUES(`default`)", |insert| {
            assert!(insert.values_body().is_err());
        });
    }

    #[test]
    fn test_insert_select() {
        with_insert("INSERT INTO my_table select * from other_table", |insert| {
            let q = insert.query_body().unwrap().unwrap();
            assert!(insert.has_non_values_query_source());
            assert!(matches!(
                q.inner,
                Query {
                    body: box SetExpr::Select { .. },
                    ..
                }
            ));
        });
    }

    #[test]
    fn test_has_non_values_query_source() {
        let cases = [
            ("INSERT INTO my_table SELECT * FROM other_table", true),
            (
                "INSERT INTO my_table WITH cte AS (SELECT * FROM other_table) SELECT * FROM cte",
                true,
            ),
            (
                "INSERT INTO my_table SELECT * FROM t1 UNION ALL SELECT * FROM t2",
                true,
            ),
            ("INSERT INTO my_table VALUES(1)", false),
            ("INSERT INTO my_table VALUES(now())", false),
            ("INSERT INTO my_table VALUES(1 + 1)", false),
        ];

        for (sql, expected) in cases {
            with_insert(sql, |insert| {
                assert_eq!(insert.has_non_values_query_source(), expected, "{sql}");
            });
        }
    }
}