mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-18 02:42:56 +00:00
Compare commits
7 Commits
proxy_id
...
sk/sql_ove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
595248532c | ||
|
|
158c051c9a | ||
|
|
143c4954df | ||
|
|
373ae7672b | ||
|
|
1233923c3a | ||
|
|
1a3a7f14dd | ||
|
|
4e9067a8c2 |
33
Cargo.lock
generated
33
Cargo.lock
generated
@@ -297,7 +297,7 @@ dependencies = [
|
||||
"http",
|
||||
"http-body",
|
||||
"lazy_static",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"tracing",
|
||||
]
|
||||
@@ -401,7 +401,7 @@ dependencies = [
|
||||
"hex",
|
||||
"http",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"regex",
|
||||
"ring",
|
||||
"time",
|
||||
@@ -522,7 +522,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"tokio",
|
||||
@@ -544,7 +544,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"tracing",
|
||||
@@ -684,7 +684,7 @@ dependencies = [
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
@@ -1580,7 +1580,7 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2584,7 +2584,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"once_cell",
|
||||
"opentelemetry_api",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"rand",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
@@ -2749,6 +2749,12 @@ dependencies = [
|
||||
"base64 0.13.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831"
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.2.0"
|
||||
@@ -3112,6 +3118,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
"parking_lot",
|
||||
"percent-encoding 1.0.1",
|
||||
"pin-project-lite",
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
@@ -3316,7 +3323,7 @@ dependencies = [
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"rustls 0.20.8",
|
||||
"rustls-pemfile",
|
||||
@@ -3389,7 +3396,7 @@ dependencies = [
|
||||
"http",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"regex",
|
||||
]
|
||||
|
||||
@@ -4332,7 +4339,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"phf",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
@@ -4480,7 +4487,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-timeout",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project",
|
||||
"prost",
|
||||
"prost-derive",
|
||||
@@ -4512,7 +4519,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-timeout",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project",
|
||||
"prost",
|
||||
"rustls-native-certs",
|
||||
@@ -4822,7 +4829,7 @@ checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"serde",
|
||||
]
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ hmac = "0.12.1"
|
||||
hostname = "0.3.1"
|
||||
humantime = "2.1"
|
||||
humantime-serde = "1.1.1"
|
||||
hyper = "0.14"
|
||||
hyper = { version = "0.14", features = ["http2", "tcp", "runtime", "http1"]}
|
||||
hyper-tungstenite = "0.9"
|
||||
itertools = "0.10"
|
||||
jsonwebtoken = "8"
|
||||
@@ -117,6 +117,7 @@ uuid = { version = "1.2", features = ["v4", "serde"] }
|
||||
walkdir = "2.3.2"
|
||||
webpki-roots = "0.23"
|
||||
x509-parser = "0.15"
|
||||
percent-encoding = "1.0"
|
||||
|
||||
## TODO replace this with tracing
|
||||
env_logger = "0.10"
|
||||
|
||||
@@ -23,7 +23,7 @@ hmac.workspace = true
|
||||
hostname.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper-tungstenite.workspace = true
|
||||
hyper.workspace = true
|
||||
hyper = { workspace = true, features = ["http2", "http1", "tcp", "runtime"] }
|
||||
itertools.workspace = true
|
||||
md5.workspace = true
|
||||
metrics.workspace = true
|
||||
@@ -65,6 +65,7 @@ x509-parser.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
tokio-util.workspace = true
|
||||
percent-encoding.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
rcgen.workspace = true
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
use crate::{
|
||||
cancellation::CancelMap, config::ProxyConfig, error::io_error, proxy::handle_ws_client,
|
||||
auth, cancellation::CancelMap, config::ProxyConfig, console, error::io_error,
|
||||
proxy::handle_ws_client,
|
||||
};
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream, StreamExt};
|
||||
use hyper::{
|
||||
server::{accept, conn::AddrIncoming},
|
||||
upgrade::Upgraded,
|
||||
Body, Request, Response, StatusCode,
|
||||
Body, Method, Request, Response, StatusCode,
|
||||
};
|
||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use pin_project_lite::pin_project;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use percent_encoding::percent_decode;
|
||||
use std::collections::HashMap;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::ready,
|
||||
@@ -24,6 +30,7 @@ use tokio::{
|
||||
};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use url::form_urlencoded;
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
|
||||
// TODO: use `std::sync::Exclusive` once it's stabilized.
|
||||
@@ -159,6 +166,7 @@ async fn ws_handler(
|
||||
config: &'static ProxyConfig,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
session_id: uuid::Uuid,
|
||||
cache: Arc<Mutex<ConnectionCache>>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -181,13 +189,170 @@ async fn ws_handler(
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
match handle_sql(config, request, cache).await {
|
||||
Ok(resp) => json_response(StatusCode::OK, resp).map(|mut r| {
|
||||
r.headers_mut().insert(
|
||||
"Access-Control-Allow-Origin",
|
||||
hyper::http::HeaderValue::from_static("*"),
|
||||
);
|
||||
r
|
||||
}),
|
||||
Err(e) => json_response(StatusCode::BAD_REQUEST, format!("error: {e:?}")),
|
||||
}
|
||||
} else if request.uri().path() == "/sleep" {
|
||||
// sleep 15ms
|
||||
tokio::time::sleep(std::time::Duration::from_millis(15)).await;
|
||||
json_response(StatusCode::OK, "done")
|
||||
} else {
|
||||
json_response(StatusCode::OK, "Connect with a websocket client")
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: return different error codes
|
||||
async fn handle_sql(
|
||||
config: &'static ProxyConfig,
|
||||
request: Request<Body>,
|
||||
cache: Arc<Mutex<ConnectionCache>>,
|
||||
) -> anyhow::Result<String> {
|
||||
let get_params = request
|
||||
.uri()
|
||||
.query()
|
||||
.ok_or(anyhow::anyhow!("missing query string"))?;
|
||||
|
||||
let parsed_params: HashMap<String, String> = form_urlencoded::parse(get_params.as_bytes())
|
||||
.into_owned()
|
||||
.collect();
|
||||
|
||||
let sql = parsed_params
|
||||
.get("query")
|
||||
.ok_or(anyhow::anyhow!("missing query"))?;
|
||||
let dbname = parsed_params
|
||||
.get("dbname")
|
||||
.ok_or(anyhow::anyhow!("missing dbname"))?;
|
||||
let username = parsed_params
|
||||
.get("username")
|
||||
.ok_or(anyhow::anyhow!("missing username"))?;
|
||||
let password = parsed_params
|
||||
.get("password")
|
||||
.ok_or(anyhow::anyhow!("missing password"))?;
|
||||
// XXX: does URI includes host too? then Url::parse() should work for both host_str and params
|
||||
let hostname = request
|
||||
.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next())
|
||||
.map(|s| s.to_string())
|
||||
.ok_or(anyhow::anyhow!("missing host header"))?;
|
||||
|
||||
let params = StartupMessageParams::new([
|
||||
("user", username.as_str()),
|
||||
("database", dbname.as_str()),
|
||||
("application_name", "proxy_http_sql"),
|
||||
]);
|
||||
let tls = config.tls_config.as_ref();
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
let creds = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, Some(hostname.as_str()), common_names))
|
||||
.transpose()?;
|
||||
|
||||
let extra = console::ConsoleReqExtra {
|
||||
session_id: uuid::Uuid::new_v4(),
|
||||
application_name: Some("proxy_http_sql"),
|
||||
};
|
||||
|
||||
let node = creds.wake_compute(&extra).await?.expect("msg");
|
||||
let conf = node.value.config;
|
||||
|
||||
let host = match conf.get_hosts().first().expect("no host") {
|
||||
tokio_postgres::config::Host::Tcp(host) => host,
|
||||
tokio_postgres::config::Host::Unix(_) => {
|
||||
return Err(anyhow::anyhow!("unix socket is not supported"));
|
||||
}
|
||||
};
|
||||
|
||||
let conn_string = &format!(
|
||||
"host={} port={} user={} password={} dbname={}",
|
||||
host,
|
||||
conf.get_ports().first().expect("no port"),
|
||||
username,
|
||||
password,
|
||||
dbname
|
||||
);
|
||||
|
||||
ConnectionCache::execute(&cache, conn_string, &hostname, sql).await
|
||||
}
|
||||
|
||||
pub struct ConnectionCache {
|
||||
connections: HashMap<String, tokio_postgres::Client>,
|
||||
}
|
||||
|
||||
impl ConnectionCache {
|
||||
pub fn new() -> Arc<Mutex<Self>> {
|
||||
Arc::new(Mutex::new(Self {
|
||||
connections: HashMap::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn execute(
|
||||
cache: &Arc<Mutex<ConnectionCache>>,
|
||||
conn_string: &str,
|
||||
hostname: &str,
|
||||
sql: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
// TODO: let go mutex when establishing connection
|
||||
let mut cache = cache.lock().await;
|
||||
let cache_key = format!("connstr={}, hostname={}", conn_string, hostname);
|
||||
let client = if let Some(client) = cache.connections.get(&cache_key) {
|
||||
info!("using cached connection {}", conn_string);
|
||||
client
|
||||
} else {
|
||||
info!("!!!! connecting to: {}", conn_string);
|
||||
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
// TODO: remove connection from cache
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
cache.connections.insert(cache_key.clone(), client);
|
||||
cache.connections.get(&cache_key).unwrap()
|
||||
};
|
||||
|
||||
let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string();
|
||||
|
||||
let rows: Vec<HashMap<_, _>> = client
|
||||
.simple_query(&sql)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|el| {
|
||||
if let tokio_postgres::SimpleQueryMessage::Row(row) = el {
|
||||
let mut serilaized_row: HashMap<String, String> = HashMap::new();
|
||||
for i in 0..row.len() {
|
||||
let col = row.columns().get(i).map_or("?", |c| c.name());
|
||||
let val = row.get(i).unwrap_or("?");
|
||||
serilaized_row.insert(col.into(), val.into());
|
||||
}
|
||||
Some(serilaized_row)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(serde_json::to_string(&rows)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
cache: &'static Arc<Mutex<ConnectionCache>>,
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -221,7 +386,7 @@ pub async fn task_main(
|
||||
move |req: Request<Body>| async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
ws_handler(req, config, cancel_map, session_id)
|
||||
ws_handler(req, config, cancel_map, session_id, cache.clone())
|
||||
.instrument(info_span!(
|
||||
"ws-client",
|
||||
session = format_args!("{session_id}")
|
||||
|
||||
@@ -32,6 +32,8 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn};
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
|
||||
use crate::http::websocket::ConnectionCache;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
@@ -53,6 +55,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let args = cli().get_matches();
|
||||
let config = build_config(&args)?;
|
||||
|
||||
let wsconn_cache = Box::leak(Box::new(ConnectionCache::new()));
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
@@ -82,6 +86,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
client_tasks.push(tokio::spawn(http::websocket::task_main(
|
||||
config,
|
||||
wsconn_cache,
|
||||
wss_listener,
|
||||
cancellation_token.clone(),
|
||||
)));
|
||||
|
||||
Reference in New Issue
Block a user