From cb88df7ffa080363ba0a8429511ec046d80b2dce Mon Sep 17 00:00:00 2001 From: George MacKerron Date: Thu, 4 May 2023 22:00:47 +0100 Subject: [PATCH] Broken/incomplete attempts to implement new http API for serverless driver --- proxy/src/http/websocket.rs | 128 ++++++++++++++++++++++++++++-------- 1 file changed, 99 insertions(+), 29 deletions(-) diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index 1988348adb..6ebcf1688b 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -12,9 +12,11 @@ use hyper::{ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; use pin_project_lite::pin_project; use pq_proto::StartupMessageParams; +use serde_json::Value; use tokio::sync::Mutex; use percent_encoding::percent_decode; +use tokio_postgres::types::ToSql; use std::collections::HashMap; use std::{ convert::Infallible, @@ -30,7 +32,7 @@ use tokio::{ }; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; -use url::form_urlencoded; +use url::{form_urlencoded, Url}; use utils::http::{error::ApiError, json::json_response}; // TODO: use `std::sync::Exclusive` once it's stabilized. @@ -215,39 +217,71 @@ async fn handle_sql( request: Request, cache: Arc>, ) -> anyhow::Result { - let get_params = request - .uri() - .query() - .ok_or(anyhow::anyhow!("missing query string"))?; - let parsed_params: HashMap = form_urlencoded::parse(get_params.as_bytes()) - .into_owned() - .collect(); + let headers = request.headers(); - let sql = parsed_params - .get("query") - .ok_or(anyhow::anyhow!("missing query"))?; - let dbname = parsed_params - .get("dbname") - .ok_or(anyhow::anyhow!("missing dbname"))?; - let username = parsed_params - .get("username") - .ok_or(anyhow::anyhow!("missing username"))?; - let password = parsed_params - .get("password") - .ok_or(anyhow::anyhow!("missing password"))?; - // XXX: does URI includes host too? then Url::parse() should work for both host_str and params - let hostname = request + let connection_string = headers + .get("X-Neon-ConnectionString") + .ok_or(anyhow::anyhow!("missing connection string"))? + .to_str()?; + + let connection_url = Url::parse(connection_string)?; + + let mut url_path = connection_url + .path_segments() + .ok_or(anyhow::anyhow!("missing database name"))?; + + let dbname = url_path + .next() + .ok_or(anyhow::anyhow!("invalid database name"))?; + + let username = match connection_url.username() { + "" => return Err(anyhow::anyhow!("empty username")), + s => Ok(s) + }?; + + let maybe_empty_password = connection_url + .password() + .ok_or(anyhow::anyhow!("no password"))?; + + let password = match maybe_empty_password { + "" => return Err(anyhow::anyhow!("empty password")), + s => Ok(s) + }?; + + let hostname = connection_url + .host_str() + .ok_or(anyhow::anyhow!("no host"))?; + + let host_header = request .headers() .get("host") .and_then(|h| h.to_str().ok()) - .and_then(|h| h.split(':').next()) - .map(|s| s.to_string()) - .ok_or(anyhow::anyhow!("missing host header"))?; + .and_then(|h| h.split(':').next()); + + match host_header { + Some(h) if h == hostname => Ok(h), + Some(_) => return Err(anyhow::anyhow!("mismatched host header and hostname")), + None => return Err(anyhow::anyhow!("no host header")) + }; + + let body = request.into_body(); + let mut data = Vec::with_capacity(512); + while let Some(chunk) = body.next().await { + data.extend(&chunk?); + } + + #[derive(serde::Deserialize)] + struct QueryData { + query: String, + params: Vec + } + + let queryData: QueryData = serde_json::from_slice(&data)?; let params = StartupMessageParams::new([ - ("user", username.as_str()), - ("database", dbname.as_str()), + ("user", username), + ("database", dbname), ("application_name", "proxy_http_sql"), ]); let tls = config.tls_config.as_ref(); @@ -255,7 +289,7 @@ async fn handle_sql( let creds = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, Some(hostname.as_str()), common_names)) + .map(|_| auth::ClientCredentials::parse(¶ms, Some(hostname), common_names)) .transpose()?; let extra = console::ConsoleReqExtra { @@ -282,7 +316,43 @@ async fn handle_sql( dbname ); - ConnectionCache::execute(&cache, conn_string, &hostname, sql).await + let (client, connection) = tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?; + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + // let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string(); + + let rows: Vec> = client + .query(&queryData.query, queryData.params.iter().map(|x| match x { + Value::Null => None, + Value::Bool(boolean) => boolean, + Value::Number(number) => number, + Value::String(string) => string, + _ => return Err(anyhow::anyhow!("unsupported param type")) + }).collect()) + .await? + .into_iter() + .filter_map(|el| { + if let tokio_postgres::SimpleQueryMessage::Row(row) = el { + let mut serilaized_row: HashMap = HashMap::new(); + for i in 0..row.len() { + let col = row.columns().get(i).map_or("?", |c| c.name()); + let val = row.get(i).unwrap_or("?"); + serilaized_row.insert(col.into(), val.into()); + } + Some(serilaized_row) + } else { + None + } + }) + .collect(); + + Ok(serde_json::to_string(&rows)?) + + // ConnectionCache::execute(&cache, conn_string, &hostname, sql).await } pub struct ConnectionCache {