diff --git a/src/datanode/src/sql/insert.rs b/src/datanode/src/sql/insert.rs index b7fbfe39bc..cbc6f201b7 100644 --- a/src/datanode/src/sql/insert.rs +++ b/src/datanode/src/sql/insert.rs @@ -34,7 +34,7 @@ impl SqlHandler { stmt: Insert, ) -> Result { let columns = stmt.columns(); - let values = stmt.values(); + let values = stmt.values().context(ParseSqlValueSnafu)?; //TODO(dennis): table name may be in the form of `catalog.schema.table`, // but we don't process it right now. let table_name = stmt.table_name(); diff --git a/src/sql/src/statements/insert.rs b/src/sql/src/statements/insert.rs index ad9064fd0b..5ab89c7a3f 100644 --- a/src/sql/src/statements/insert.rs +++ b/src/sql/src/statements/insert.rs @@ -1,7 +1,8 @@ -use sqlparser::ast::{SetExpr, Statement, Values}; +use sqlparser::ast::{SetExpr, Statement, UnaryOperator, Values}; use sqlparser::parser::ParserError; use crate::ast::{Expr, Value}; +use crate::error::{self, Result}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Insert { @@ -27,34 +28,59 @@ impl Insert { } } - pub fn values(&self) -> Vec> { - match &self.inner { + pub fn values(&self) -> Result>> { + let values = match &self.inner { Statement::Insert { source, .. } => match &source.body { - SetExpr::Values(Values(values)) => values - .iter() - .map(|v| { - v.iter() - .map(|expr| match expr { - Expr::Value(v) => v.clone(), - Expr::Identifier(ident) => { - Value::SingleQuotedString(ident.value.clone()) - } - _ => unreachable!(), - }) - .collect::>() - }) - .collect(), + SetExpr::Values(Values(exprs)) => sql_exprs_to_values(exprs)?, _ => unreachable!(), }, _ => unreachable!(), - } + }; + Ok(values) } } +fn sql_exprs_to_values(exprs: &Vec>) -> Result>> { + let mut values = Vec::with_capacity(exprs.len()); + for es in exprs.iter() { + let mut vs = Vec::with_capacity(es.len()); + for expr in es.iter() { + vs.push(match expr { + Expr::Value(v) => v.clone(), + Expr::Identifier(ident) => Value::SingleQuotedString(ident.value.clone()), + Expr::UnaryOp { op, expr } + if matches!(op, UnaryOperator::Minus | UnaryOperator::Plus) => + { + if let Expr::Value(Value::Number(s, b)) = &**expr { + match op { + UnaryOperator::Minus => Value::Number(format!("-{}", s), *b), + UnaryOperator::Plus => Value::Number(s.to_string(), *b), + _ => unreachable!(), + } + } else { + return error::ParseSqlValueSnafu { + msg: format!("{:?}", expr), + } + .fail(); + } + } + _ => { + return error::ParseSqlValueSnafu { + msg: format!("{:?}", expr), + } + .fail() + } + }); + } + values.push(vs); + } + Ok(values) +} + impl TryFrom for Insert { type Error = ParserError; - fn try_from(value: Statement) -> Result { + fn try_from(value: Statement) -> std::result::Result { match value { Statement::Insert { .. } => Ok(Insert { inner: value }), unexp => Err(ParserError::ParserError(format!( @@ -78,7 +104,37 @@ mod tests { let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); assert_eq!(1, stmts.len()); let insert = stmts.pop().unwrap(); - let r: Result = insert.try_into(); - r.unwrap(); + let _stmt: Statement = insert.try_into().unwrap(); + } + + #[test] + fn test_insert_value_with_unary_op() { + use crate::statements::statement::Statement; + + // insert "-1" + let sql = "INSERT INTO my_table VALUES(-1)"; + let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + .unwrap() + .remove(0); + match stmt { + Statement::Insert(insert) => { + let values = insert.values().unwrap(); + assert_eq!(values, vec![vec![Value::Number("-1".to_string(), false)]]); + } + _ => unreachable!(), + } + + // insert "+1" + let sql = "INSERT INTO my_table VALUES(+1)"; + let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + .unwrap() + .remove(0); + match stmt { + Statement::Insert(insert) => { + let values = insert.values().unwrap(); + assert_eq!(values, vec![vec![Value::Number("1".to_string(), false)]]); + } + _ => unreachable!(), + } } }