diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs
index 1988348adb..6ebcf1688b 100644
--- a/proxy/src/http/websocket.rs
+++ b/proxy/src/http/websocket.rs
@@ -12,9 +12,11 @@ use hyper::{
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use pin_project_lite::pin_project;
use pq_proto::StartupMessageParams;
+use serde_json::Value;
use tokio::sync::Mutex;
use percent_encoding::percent_decode;
+use tokio_postgres::types::ToSql;
use std::collections::HashMap;
use std::{
convert::Infallible,
@@ -30,7 +32,7 @@ use tokio::{
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
-use url::form_urlencoded;
+use url::{form_urlencoded, Url};
use utils::http::{error::ApiError, json::json_response};
// TODO: use `std::sync::Exclusive` once it's stabilized.
@@ -215,39 +217,71 @@ async fn handle_sql(
request: Request
,
cache: Arc>,
) -> anyhow::Result {
- let get_params = request
- .uri()
- .query()
- .ok_or(anyhow::anyhow!("missing query string"))?;
- let parsed_params: HashMap = form_urlencoded::parse(get_params.as_bytes())
- .into_owned()
- .collect();
+ let headers = request.headers();
- let sql = parsed_params
- .get("query")
- .ok_or(anyhow::anyhow!("missing query"))?;
- let dbname = parsed_params
- .get("dbname")
- .ok_or(anyhow::anyhow!("missing dbname"))?;
- let username = parsed_params
- .get("username")
- .ok_or(anyhow::anyhow!("missing username"))?;
- let password = parsed_params
- .get("password")
- .ok_or(anyhow::anyhow!("missing password"))?;
- // XXX: does URI includes host too? then Url::parse() should work for both host_str and params
- let hostname = request
+ let connection_string = headers
+ .get("X-Neon-ConnectionString")
+ .ok_or(anyhow::anyhow!("missing connection string"))?
+ .to_str()?;
+
+ let connection_url = Url::parse(connection_string)?;
+
+ let mut url_path = connection_url
+ .path_segments()
+ .ok_or(anyhow::anyhow!("missing database name"))?;
+
+ let dbname = url_path
+ .next()
+ .ok_or(anyhow::anyhow!("invalid database name"))?;
+
+ let username = match connection_url.username() {
+ "" => return Err(anyhow::anyhow!("empty username")),
+ s => Ok(s)
+ }?;
+
+ let maybe_empty_password = connection_url
+ .password()
+ .ok_or(anyhow::anyhow!("no password"))?;
+
+ let password = match maybe_empty_password {
+ "" => return Err(anyhow::anyhow!("empty password")),
+ s => Ok(s)
+ }?;
+
+ let hostname = connection_url
+ .host_str()
+ .ok_or(anyhow::anyhow!("no host"))?;
+
+ let host_header = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
- .and_then(|h| h.split(':').next())
- .map(|s| s.to_string())
- .ok_or(anyhow::anyhow!("missing host header"))?;
+ .and_then(|h| h.split(':').next());
+
+ match host_header {
+ Some(h) if h == hostname => Ok(h),
+ Some(_) => return Err(anyhow::anyhow!("mismatched host header and hostname")),
+ None => return Err(anyhow::anyhow!("no host header"))
+ };
+
+ let body = request.into_body();
+ let mut data = Vec::with_capacity(512);
+ while let Some(chunk) = body.next().await {
+ data.extend(&chunk?);
+ }
+
+ #[derive(serde::Deserialize)]
+ struct QueryData {
+ query: String,
+ params: Vec
+ }
+
+ let queryData: QueryData = serde_json::from_slice(&data)?;
let params = StartupMessageParams::new([
- ("user", username.as_str()),
- ("database", dbname.as_str()),
+ ("user", username),
+ ("database", dbname),
("application_name", "proxy_http_sql"),
]);
let tls = config.tls_config.as_ref();
@@ -255,7 +289,7 @@ async fn handle_sql(
let creds = config
.auth_backend
.as_ref()
- .map(|_| auth::ClientCredentials::parse(¶ms, Some(hostname.as_str()), common_names))
+ .map(|_| auth::ClientCredentials::parse(¶ms, Some(hostname), common_names))
.transpose()?;
let extra = console::ConsoleReqExtra {
@@ -282,7 +316,43 @@ async fn handle_sql(
dbname
);
- ConnectionCache::execute(&cache, conn_string, &hostname, sql).await
+ let (client, connection) = tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
+ tokio::spawn(async move {
+ if let Err(e) = connection.await {
+ eprintln!("connection error: {}", e);
+ }
+ });
+
+ // let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string();
+
+ let rows: Vec> = client
+ .query(&queryData.query, queryData.params.iter().map(|x| match x {
+ Value::Null => None,
+ Value::Bool(boolean) => boolean,
+ Value::Number(number) => number,
+ Value::String(string) => string,
+ _ => return Err(anyhow::anyhow!("unsupported param type"))
+ }).collect())
+ .await?
+ .into_iter()
+ .filter_map(|el| {
+ if let tokio_postgres::SimpleQueryMessage::Row(row) = el {
+ let mut serilaized_row: HashMap = HashMap::new();
+ for i in 0..row.len() {
+ let col = row.columns().get(i).map_or("?", |c| c.name());
+ let val = row.get(i).unwrap_or("?");
+ serilaized_row.insert(col.into(), val.into());
+ }
+ Some(serilaized_row)
+ } else {
+ None
+ }
+ })
+ .collect();
+
+ Ok(serde_json::to_string(&rows)?)
+
+ // ConnectionCache::execute(&cache, conn_string, &hostname, sql).await
}
pub struct ConnectionCache {