diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index a4eb13a10b..8168a4cb35 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use anyhow::bail; +use anyhow::Context; use futures::pin_mut; use futures::StreamExt; use hashbrown::HashMap; @@ -49,7 +50,7 @@ static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); // Convert json non-string types to strings, so that they can be passed to Postgres // as parameters. // -fn json_to_pg_text(json: Vec) -> Result>, serde_json::Error> { +fn json_to_pg_text(json: &[Value]) -> Result>, serde_json::Error> { json.iter() .map(|value| { match value { @@ -230,7 +231,7 @@ pub async fn handle( // Now execute the query and return the result // let result = match payload { - Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode) + Payload::Single(query) => query_to_json(&client, &query, raw_output, array_mode) .await .map(|x| (x, HashMap::default())), Payload::Batch(queries) => { @@ -243,12 +244,18 @@ pub async fn handle( builder = builder.read_only(true); } let transaction = builder.start().await?; - for query in queries { - let result = query_to_json(&transaction, query, raw_output, array_mode).await; + for (idx, query) in queries.into_iter().enumerate() { + let result = query_to_json(&transaction, &query, raw_output, array_mode) + .await + .with_context(|| { + format!("error when executing queries[{}] \"{}\"", idx, query.query) + }); match result { Ok(r) => results.push(r), Err(e) => { - transaction.rollback().await?; + transaction.rollback().await.with_context(|| { + format!("error when rollback queries[{}] \"{}\"", idx, query.query) + })?; return Err(e); } } @@ -278,13 +285,20 @@ pub async fn handle( async fn query_to_json( client: &T, - data: QueryData, + data: &QueryData, raw_output: bool, array_mode: bool, ) -> anyhow::Result { - let query_params = json_to_pg_text(data.params)?; + let query_params = json_to_pg_text(&data.params)?; let row_stream = client - .query_raw_txt::(data.query, query_params) + // TODO: query_raw_txt should be able to accept &str and Vec + .query_raw_txt::<&str, _>( + &data.query, + query_params + .iter() + .map(|x| x.as_ref().map(|y| y.as_str())) + .collect::>>(), + ) .await?; // Manually drain the stream into a vector to leave row_stream hanging @@ -533,22 +547,22 @@ mod tests { #[test] fn test_atomic_types_to_pg_params() { let json = vec![Value::Bool(true), Value::Bool(false)]; - let pg_params = json_to_pg_text(json).unwrap(); + let pg_params = json_to_pg_text(&json).unwrap(); assert_eq!( pg_params, vec![Some("true".to_owned()), Some("false".to_owned())] ); let json = vec![Value::Number(serde_json::Number::from(42))]; - let pg_params = json_to_pg_text(json).unwrap(); + let pg_params = json_to_pg_text(&json).unwrap(); assert_eq!(pg_params, vec![Some("42".to_owned())]); let json = vec![Value::String("foo\"".to_string())]; - let pg_params = json_to_pg_text(json).unwrap(); + let pg_params = json_to_pg_text(&json).unwrap(); assert_eq!(pg_params, vec![Some("foo\"".to_owned())]); let json = vec![Value::Null]; - let pg_params = json_to_pg_text(json).unwrap(); + let pg_params = json_to_pg_text(&json).unwrap(); assert_eq!(pg_params, vec![None]); } @@ -557,7 +571,7 @@ mod tests { // atoms and escaping let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]"; let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]).unwrap(); + let pg_params = json_to_pg_text(&[json]).unwrap(); assert_eq!( pg_params, vec![Some( @@ -568,7 +582,7 @@ mod tests { // nested arrays let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]"; let json: Value = serde_json::from_str(json).unwrap(); - let pg_params = json_to_pg_text(vec![json]).unwrap(); + let pg_params = json_to_pg_text(&[json]).unwrap(); assert_eq!( pg_params, vec![Some(