mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 18:32:56 +00:00
Compare commits
20 Commits
proxy_id
...
gm/sql_ove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6abe34bfa1 | ||
|
|
b932c14f94 | ||
|
|
2e05a9b652 | ||
|
|
b808160aff | ||
|
|
2997018005 | ||
|
|
8f98cc29fa | ||
|
|
b4ee8e3b73 | ||
|
|
68ea916bf1 | ||
|
|
eea3dd54da | ||
|
|
4071b22519 | ||
|
|
c58edec63a | ||
|
|
fdb9d4373d | ||
|
|
cb88df7ffa | ||
|
|
595248532c | ||
|
|
158c051c9a | ||
|
|
143c4954df | ||
|
|
373ae7672b | ||
|
|
1233923c3a | ||
|
|
1a3a7f14dd | ||
|
|
4e9067a8c2 |
41
Cargo.lock
generated
41
Cargo.lock
generated
@@ -130,6 +130,12 @@ dependencies = [
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "array-init"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc"
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.5.2"
|
||||
@@ -297,7 +303,7 @@ dependencies = [
|
||||
"http",
|
||||
"http-body",
|
||||
"lazy_static",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"tracing",
|
||||
]
|
||||
@@ -401,7 +407,7 @@ dependencies = [
|
||||
"hex",
|
||||
"http",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"regex",
|
||||
"ring",
|
||||
"time",
|
||||
@@ -522,7 +528,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"tokio",
|
||||
@@ -544,7 +550,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"tracing",
|
||||
@@ -684,7 +690,7 @@ dependencies = [
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
@@ -1580,7 +1586,7 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2584,7 +2590,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"once_cell",
|
||||
"opentelemetry_api",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"rand",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
@@ -2749,6 +2755,12 @@ dependencies = [
|
||||
"base64 0.13.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831"
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.2.0"
|
||||
@@ -2879,7 +2891,9 @@ name = "postgres-types"
|
||||
version = "0.2.4"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=43e6db254a97fdecbce33d8bc0890accfd74495e#43e6db254a97fdecbce33d8bc0890accfd74495e"
|
||||
dependencies = [
|
||||
"array-init",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol",
|
||||
]
|
||||
@@ -3112,6 +3126,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
"parking_lot",
|
||||
"percent-encoding 1.0.1",
|
||||
"pin-project-lite",
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
@@ -3316,7 +3331,7 @@ dependencies = [
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project-lite",
|
||||
"rustls 0.20.8",
|
||||
"rustls-pemfile",
|
||||
@@ -3389,7 +3404,7 @@ dependencies = [
|
||||
"http",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"regex",
|
||||
]
|
||||
|
||||
@@ -4332,7 +4347,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"phf",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
@@ -4480,7 +4495,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-timeout",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project",
|
||||
"prost",
|
||||
"prost-derive",
|
||||
@@ -4512,7 +4527,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-timeout",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"pin-project",
|
||||
"prost",
|
||||
"rustls-native-certs",
|
||||
@@ -4822,7 +4837,7 @@ checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
"percent-encoding",
|
||||
"percent-encoding 2.2.0",
|
||||
"serde",
|
||||
]
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ hmac = "0.12.1"
|
||||
hostname = "0.3.1"
|
||||
humantime = "2.1"
|
||||
humantime-serde = "1.1.1"
|
||||
hyper = "0.14"
|
||||
hyper = { version = "0.14", features = ["http2", "tcp", "runtime", "http1"]}
|
||||
hyper-tungstenite = "0.9"
|
||||
itertools = "0.10"
|
||||
jsonwebtoken = "8"
|
||||
@@ -117,16 +117,17 @@ uuid = { version = "1.2", features = ["v4", "serde"] }
|
||||
walkdir = "2.3.2"
|
||||
webpki-roots = "0.23"
|
||||
x509-parser = "0.15"
|
||||
percent-encoding = "1.0"
|
||||
|
||||
## TODO replace this with tracing
|
||||
env_logger = "0.10"
|
||||
log = "0.4"
|
||||
|
||||
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["with-chrono-0_4", "array-impls"] }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["array-impls"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["array-impls"] }
|
||||
tokio-tar = { git = "https://github.com/neondatabase/tokio-tar.git", rev="404df61437de0feef49ba2ccdbdd94eb8ad6e142" }
|
||||
|
||||
## Other git libraries
|
||||
|
||||
@@ -23,7 +23,7 @@ hmac.workspace = true
|
||||
hostname.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper-tungstenite.workspace = true
|
||||
hyper.workspace = true
|
||||
hyper = { workspace = true, features = ["http2", "http1", "tcp", "runtime"] }
|
||||
itertools.workspace = true
|
||||
md5.workspace = true
|
||||
metrics.workspace = true
|
||||
@@ -65,6 +65,7 @@ x509-parser.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
tokio-util.workspace = true
|
||||
percent-encoding.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
rcgen.workspace = true
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
//! Other modules should use stuff from this module instead of
|
||||
//! directly relying on deps like `reqwest` (think loose coupling).
|
||||
|
||||
pub mod pg_to_json;
|
||||
pub mod server;
|
||||
pub mod websocket;
|
||||
|
||||
|
||||
154
proxy/src/http/pg_to_json.rs
Normal file
154
proxy/src/http/pg_to_json.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use chrono::{Utc, DateTime, NaiveDateTime, NaiveTime, NaiveDate};
|
||||
use serde_json::Map;
|
||||
use tokio_postgres::{
|
||||
types::{FromSql, Type},
|
||||
Column, Row,
|
||||
};
|
||||
|
||||
// some type-aliases I use in my project
|
||||
pub type JSONValue = serde_json::Value;
|
||||
pub type RowData = Map<String, JSONValue>;
|
||||
pub type Error = anyhow::Error; // from: https://github.com/dtolnay/anyhow
|
||||
|
||||
pub fn postgres_row_to_json_value(row: &Row) -> Result<JSONValue, Error> {
|
||||
let row_data = postgres_row_to_row_data(row)?;
|
||||
Ok(JSONValue::Object(row_data))
|
||||
}
|
||||
|
||||
pub fn postgres_row_to_row_data(row: &Row) -> Result<RowData, Error> {
|
||||
let mut result: Map<String, JSONValue> = Map::new();
|
||||
for (i, column) in row.columns().iter().enumerate() {
|
||||
let name = column.name();
|
||||
let json_value = pg_cell_to_json_value(&row, column, i)?;
|
||||
result.insert(name.to_string(), json_value);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn pg_cell_to_json_value(
|
||||
row: &Row,
|
||||
column: &Column,
|
||||
column_i: usize,
|
||||
) -> Result<JSONValue, Error> {
|
||||
let f64_to_json_number = |raw_val: f64| -> Result<JSONValue, Error> {
|
||||
let temp = serde_json::Number::from_f64(raw_val).ok_or(anyhow!("invalid json-float"))?;
|
||||
Ok(JSONValue::Number(temp))
|
||||
};
|
||||
Ok(match *column.type_() {
|
||||
// for rust-postgres <> postgres type-mappings: https://docs.rs/postgres/latest/postgres/types/trait.FromSql.html#types
|
||||
// for postgres types: https://www.postgresql.org/docs/7.4/datatype.html#DATATYPE-TABLE
|
||||
|
||||
Type::TIMESTAMPTZ => get_basic(row, column, column_i, |a: DateTime<Utc>| {
|
||||
Ok(JSONValue::String(a.to_rfc3339()))
|
||||
})?,
|
||||
Type::TIMESTAMP => get_basic(row, column, column_i, |a: NaiveDateTime| {
|
||||
Ok(JSONValue::String(a.format("%Y-%m-%dT%H:%M:%S%.6f").to_string()))
|
||||
})?,
|
||||
|
||||
Type::TIME => get_basic(row, column, column_i, |a: NaiveTime| {
|
||||
Ok(JSONValue::String(a.format("%H:%M:%S%.6f").to_string()))
|
||||
})?,
|
||||
Type::DATE => get_basic(row, column, column_i, |a: NaiveDate| {
|
||||
Ok(JSONValue::String(a.format("%Y-%m-%d").to_string()))
|
||||
})?,
|
||||
// no TIMETZ support?
|
||||
|
||||
// single types
|
||||
Type::BOOL => get_basic(row, column, column_i, |a: bool| Ok(JSONValue::Bool(a)))?,
|
||||
Type::INT2 => get_basic(row, column, column_i, |a: i16| {
|
||||
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||
})?,
|
||||
Type::INT4 => get_basic(row, column, column_i, |a: i32| {
|
||||
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||
})?,
|
||||
Type::INT8 => get_basic(row, column, column_i, |a: i64| {
|
||||
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||
})?,
|
||||
Type::TEXT | Type::VARCHAR => {
|
||||
get_basic(row, column, column_i, |a: String| Ok(JSONValue::String(a)))?
|
||||
}
|
||||
// Type::JSON | Type::JSONB => get_basic(row, column, column_i, |a: JSONValue| Ok(a))?,
|
||||
Type::FLOAT4 => get_basic(row, column, column_i, |a: f32| f64_to_json_number(a.into()))?,
|
||||
Type::FLOAT8 => get_basic(row, column, column_i, f64_to_json_number)?,
|
||||
// these types require a custom StringCollector struct as an intermediary (see struct at bottom)
|
||||
Type::TS_VECTOR => get_basic(row, column, column_i, |a: StringCollector| {
|
||||
Ok(JSONValue::String(a.0))
|
||||
})?,
|
||||
|
||||
// array types
|
||||
Type::BOOL_ARRAY => get_array(row, column, column_i, |a: bool| Ok(JSONValue::Bool(a)))?,
|
||||
Type::INT2_ARRAY => get_array(row, column, column_i, |a: i16| {
|
||||
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||
})?,
|
||||
Type::INT4_ARRAY => get_array(row, column, column_i, |a: i32| {
|
||||
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||
})?,
|
||||
Type::INT8_ARRAY => get_array(row, column, column_i, |a: i64| {
|
||||
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||
})?,
|
||||
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY => {
|
||||
get_array(row, column, column_i, |a: String| Ok(JSONValue::String(a)))?
|
||||
}
|
||||
// Type::JSON_ARRAY | Type::JSONB_ARRAY => get_array(row, column, column_i, |a: JSONValue| Ok(a))?,
|
||||
Type::FLOAT4_ARRAY => get_array(row, column, column_i, f64_to_json_number)?,
|
||||
Type::FLOAT8_ARRAY => get_array(row, column, column_i, f64_to_json_number)?,
|
||||
// these types require a custom StringCollector struct as an intermediary (see struct at bottom)
|
||||
Type::TS_VECTOR_ARRAY => get_array(row, column, column_i, |a: StringCollector| {
|
||||
Ok(JSONValue::String(a.0))
|
||||
})?,
|
||||
|
||||
_ => anyhow::bail!(
|
||||
"Cannot convert pg-cell \"{}\" of type \"{}\" to a JSONValue.",
|
||||
column.name(),
|
||||
column.type_().name()
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
fn get_basic<'a, T: FromSql<'a>>(
|
||||
row: &'a Row,
|
||||
column: &Column,
|
||||
column_i: usize,
|
||||
val_to_json_val: impl Fn(T) -> Result<JSONValue, Error>,
|
||||
) -> Result<JSONValue, Error> {
|
||||
let raw_val = row
|
||||
.try_get::<_, Option<T>>(column_i)
|
||||
.with_context(|| format!("column_name:{}", column.name()))?;
|
||||
raw_val.map_or(Ok(JSONValue::Null), val_to_json_val)
|
||||
}
|
||||
fn get_array<'a, T: FromSql<'a>>(
|
||||
row: &'a Row,
|
||||
column: &Column,
|
||||
column_i: usize,
|
||||
val_to_json_val: impl Fn(T) -> Result<JSONValue, Error>,
|
||||
) -> Result<JSONValue, Error> {
|
||||
let raw_val_array = row
|
||||
.try_get::<_, Option<Vec<T>>>(column_i)
|
||||
.with_context(|| format!("column_name:{}", column.name()))?;
|
||||
Ok(match raw_val_array {
|
||||
Some(val_array) => {
|
||||
let mut result = vec![];
|
||||
for val in val_array {
|
||||
result.push(val_to_json_val(val)?);
|
||||
}
|
||||
JSONValue::Array(result)
|
||||
}
|
||||
None => JSONValue::Null,
|
||||
})
|
||||
}
|
||||
|
||||
struct StringCollector(String);
|
||||
impl FromSql<'_> for StringCollector {
|
||||
fn from_sql(
|
||||
_: &Type,
|
||||
raw: &[u8],
|
||||
) -> Result<StringCollector, Box<dyn std::error::Error + Sync + Send>> {
|
||||
let result = std::str::from_utf8(raw)?;
|
||||
Ok(StringCollector(result.to_owned()))
|
||||
}
|
||||
fn accepts(_ty: &Type) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,25 @@
|
||||
use crate::http::pg_to_json::postgres_row_to_json_value;
|
||||
use crate::{
|
||||
cancellation::CancelMap, config::ProxyConfig, error::io_error, proxy::handle_ws_client,
|
||||
auth, cancellation::CancelMap, config::ProxyConfig, console, error::io_error,
|
||||
proxy::handle_ws_client,
|
||||
};
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream, StreamExt};
|
||||
use futures::{Sink, Stream, StreamExt, TryStreamExt};
|
||||
use tokio_postgres::Row;
|
||||
use std::collections::HashMap;
|
||||
use hyper::{
|
||||
server::{accept, conn::AddrIncoming},
|
||||
upgrade::Upgraded,
|
||||
Body, Request, Response, StatusCode,
|
||||
Body, Method, Request, Response, StatusCode,
|
||||
};
|
||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use percent_encoding::percent_decode;
|
||||
use pin_project_lite::pin_project;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use serde_json::{Value, json};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use tokio_postgres::types::{ToSql};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::ready,
|
||||
@@ -24,6 +34,7 @@ use tokio::{
|
||||
};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use url::{Url};
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
|
||||
// TODO: use `std::sync::Exclusive` once it's stabilized.
|
||||
@@ -41,10 +52,10 @@ pin_project! {
|
||||
}
|
||||
|
||||
impl WebSocketRw {
|
||||
pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
|
||||
pub fn new(stream: WebSocketStream<Upgraded>, startup_data: Bytes) -> Self {
|
||||
Self {
|
||||
stream: stream.into(),
|
||||
bytes: Bytes::new(),
|
||||
bytes: startup_data,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -81,6 +92,7 @@ impl AsyncRead for WebSocketRw {
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
|
||||
if buf.remaining() > 0 {
|
||||
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
||||
let len = std::cmp::min(bytes.len(), buf.remaining());
|
||||
@@ -141,25 +153,101 @@ async fn serve_websocket(
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
hostname: Option<String>,
|
||||
startup_data: Vec<u8>,
|
||||
) -> anyhow::Result<()> {
|
||||
|
||||
let websocket = websocket.await?;
|
||||
|
||||
handle_ws_client(
|
||||
config,
|
||||
cancel_map,
|
||||
session_id,
|
||||
WebSocketRw::new(websocket),
|
||||
WebSocketRw::new(websocket, startup_data.into()),
|
||||
hostname,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct MyObject {
|
||||
data: Vec<u8>,
|
||||
recv_data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl MyObject {
|
||||
fn new(data: Vec<u8>) -> Self {
|
||||
MyObject {
|
||||
data,
|
||||
recv_data: Vec::with_capacity(512),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for MyObject {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<(), io::Error>> {
|
||||
|
||||
let data = &self.get_mut().data;
|
||||
|
||||
// let len: usize = recv_data.len();
|
||||
// info!("len: {}", len);
|
||||
// let mut i: usize = 0;
|
||||
// let mut zcount = 0;
|
||||
// while len >= i + 5 {
|
||||
// let cmd = recv_data[i];
|
||||
// if cmd == 0x5a { zcount += 1; }
|
||||
// let size = u32::from_be_bytes(recv_data[(i + 1)..(i + 5)].try_into().unwrap());
|
||||
// info!("cmd: {} size: {} buf: {:?}", cmd, size, recv_data);
|
||||
// i += usize::try_from(size).unwrap() + 1;
|
||||
// }
|
||||
// if zcount < 2 {
|
||||
// Poll::Pending
|
||||
|
||||
// } else {
|
||||
// let mut reader = &recv_data[..];
|
||||
// Pin::new(&mut reader).poll_read(cx, buf)
|
||||
// }
|
||||
|
||||
let mut reader = &data[..];
|
||||
Pin::new(&mut reader).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for MyObject {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
|
||||
// eprintln!("{:?}", buf);
|
||||
let recv_data = &mut self.get_mut().recv_data;
|
||||
recv_data.extend(buf);
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
// ...
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
// ...
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async fn ws_handler(
|
||||
mut request: Request<Body>,
|
||||
config: &'static ProxyConfig,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
session_id: uuid::Uuid,
|
||||
cache: Arc<Mutex<ConnectionCache>>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
@@ -169,25 +257,311 @@ async fn ws_handler(
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||
let startup_data = match request.uri().query() {
|
||||
Some(b64_str) => match base64::decode_config(b64_str, base64::URL_SAFE) {
|
||||
Ok(x) => x,
|
||||
Err(_) => {
|
||||
eprintln!("invalid WebSocket base64 startup data");
|
||||
vec![]
|
||||
}
|
||||
},
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
info!("{} bytes of startup data received via WebSocket URL query", startup_data.len());
|
||||
|
||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await
|
||||
{
|
||||
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host, startup_data).await {
|
||||
error!("error in websocket connection: {e:?}");
|
||||
}
|
||||
});
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
|
||||
} else if request.uri().path() == "/pg-protocol" && request.method() == Method::POST {
|
||||
let mut body = request.into_body();
|
||||
let mut data = Vec::with_capacity(512);
|
||||
while let Some(chunk) = body.next().await {
|
||||
data.extend(&chunk.map_err(|e| ApiError::InternalServerError(e.into()))?);
|
||||
}
|
||||
|
||||
let mut my_object = MyObject::new(data);
|
||||
let my_object = tokio::spawn(async move {
|
||||
let result = handle_ws_client(
|
||||
config,
|
||||
&cancel_map,
|
||||
session_id,
|
||||
&mut my_object,
|
||||
host,
|
||||
).await;
|
||||
my_object
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
|
||||
let response = Response::builder()
|
||||
.header("Content-Type", "application/octet-stream")
|
||||
.header("Access-Control-Allow-Origin", "*")
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::from(my_object.recv_data))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
|
||||
Ok(response)
|
||||
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
let result = handle_sql(config, request).await;
|
||||
let status_code = match result {
|
||||
Ok(_) => StatusCode::OK,
|
||||
Err(_) => StatusCode::BAD_REQUEST
|
||||
};
|
||||
let json = match result {
|
||||
Ok(r) => serde_json::to_value(r).unwrap(),
|
||||
Err(e) => {
|
||||
let message = format!("{:?}", e);
|
||||
let code = match e.downcast_ref::<tokio_postgres::Error>() {
|
||||
Some(e) => match e.code() {
|
||||
Some(e) => serde_json::to_value(e.code()).unwrap(),
|
||||
None => Value::Null
|
||||
},
|
||||
None => Value::Null
|
||||
};
|
||||
json!({ "message": message, "code": code })
|
||||
}
|
||||
};
|
||||
json_response(status_code, json).map(|mut r| {
|
||||
r.headers_mut().insert(
|
||||
"Access-Control-Allow-Origin",
|
||||
hyper::http::HeaderValue::from_static("*"),
|
||||
);
|
||||
r
|
||||
})
|
||||
|
||||
} else if request.uri().path() == "/sleep" {
|
||||
// sleep 15ms
|
||||
tokio::time::sleep(std::time::Duration::from_millis(15)).await;
|
||||
json_response(StatusCode::OK, "done")
|
||||
|
||||
} else {
|
||||
json_response(StatusCode::OK, "Connect with a websocket client")
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
}
|
||||
}
|
||||
|
||||
fn boxed_from_json(json: Vec<Value>) -> Result<Vec<Box<dyn ToSql + Sync + Send>>, anyhow::Error> {
|
||||
json.iter().map(|value| {
|
||||
let boxed: Result<Box<dyn ToSql + Sync + Send>, anyhow::Error> = match value {
|
||||
Value::Bool(b) => Ok(Box::new(b.clone())),
|
||||
Value::Number(n) => Ok(Box::new(n.as_f64().unwrap())),
|
||||
Value::String(s) => Ok(Box::new(s.clone())),
|
||||
// TODO: support null (not like this: `Value::Null => Ok(Box::new(None::<bool>)),`)
|
||||
// TODO: support arrays
|
||||
x => Err(anyhow::anyhow!("unsupported param {:?}", x))
|
||||
};
|
||||
boxed
|
||||
})
|
||||
.collect::<Result<Vec<_>, anyhow::Error>>()
|
||||
}
|
||||
|
||||
// XXX: return different error codes
|
||||
async fn handle_sql(
|
||||
config: &'static ProxyConfig,
|
||||
request: Request<Body>
|
||||
) -> anyhow::Result<Vec<Value>> {
|
||||
|
||||
let headers = request.headers();
|
||||
|
||||
let connection_string = headers
|
||||
.get("X-Neon-ConnectionString")
|
||||
.ok_or(anyhow::anyhow!("missing connection string"))?
|
||||
.to_str()?;
|
||||
|
||||
let connection_url = Url::parse(connection_string)?;
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(anyhow::anyhow!("connection string must start with postgres: or postgresql:"))
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(anyhow::anyhow!("missing database name"))?;
|
||||
|
||||
let dbname = url_path
|
||||
.next()
|
||||
.ok_or(anyhow::anyhow!("invalid database name"))?;
|
||||
|
||||
let username = connection_url.username();
|
||||
|
||||
let password = connection_url
|
||||
.password()
|
||||
.ok_or(anyhow::anyhow!("no password"))?;
|
||||
|
||||
let hostname = connection_url
|
||||
.host_str()
|
||||
.ok_or(anyhow::anyhow!("no host"))?;
|
||||
|
||||
let host_header = request
|
||||
.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next());
|
||||
|
||||
match host_header {
|
||||
Some(h) if h == hostname => h,
|
||||
Some(_) => return Err(anyhow::anyhow!("mismatched host header and hostname")),
|
||||
None => return Err(anyhow::anyhow!("no host header"))
|
||||
};
|
||||
|
||||
let mut body = request.into_body();
|
||||
let mut data = Vec::with_capacity(512);
|
||||
while let Some(chunk) = body.next().await {
|
||||
data.extend(&chunk?);
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct QueryData {
|
||||
query: String,
|
||||
params: Vec<serde_json::Value>
|
||||
}
|
||||
let query_data: QueryData = serde_json::from_slice(&data)?;
|
||||
|
||||
let credential_params = StartupMessageParams::new([
|
||||
("user", username),
|
||||
("database", dbname),
|
||||
("application_name", "proxy_http_sql"),
|
||||
]);
|
||||
let tls = config.tls_config.as_ref();
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
let creds = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(&credential_params, Some(hostname), common_names))
|
||||
.transpose()?;
|
||||
|
||||
let extra = console::ConsoleReqExtra {
|
||||
session_id: uuid::Uuid::new_v4(),
|
||||
application_name: Some("proxy_http_sql"),
|
||||
};
|
||||
|
||||
let node = creds.wake_compute(&extra).await?.expect("msg");
|
||||
let conf = node.value.config;
|
||||
|
||||
let host = match conf.get_hosts().first().expect("no host") {
|
||||
tokio_postgres::config::Host::Tcp(host) => host,
|
||||
tokio_postgres::config::Host::Unix(_) => {
|
||||
return Err(anyhow::anyhow!("unix socket is not supported"));
|
||||
}
|
||||
};
|
||||
|
||||
let conn_string = &format!(
|
||||
"host={} port={} user={} password={} dbname={}",
|
||||
host,
|
||||
conf.get_ports().first().expect("no port"),
|
||||
username,
|
||||
password,
|
||||
dbname
|
||||
);
|
||||
|
||||
let (client, connection) = tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let query = &query_data.query;
|
||||
let query_params = boxed_from_json(query_data.params)?;
|
||||
|
||||
// TODO: find a way to catch the panic and return an error if the number of params is wrong
|
||||
let pg_rows: Vec<Row> = client
|
||||
.query_raw(query, query_params)
|
||||
.await?
|
||||
.try_collect::<Vec<Row>>()
|
||||
.await?;
|
||||
|
||||
let rows: Result<Vec<serde_json::Value>, anyhow::Error> = pg_rows
|
||||
.iter()
|
||||
.map(postgres_row_to_json_value)
|
||||
.collect();
|
||||
|
||||
let rows = rows?;
|
||||
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
pub struct ConnectionCache {
|
||||
connections: HashMap<String, tokio_postgres::Client>,
|
||||
}
|
||||
|
||||
|
||||
impl ConnectionCache {
|
||||
pub fn new() -> Arc<Mutex<Self>> {
|
||||
Arc::new(Mutex::new(Self {
|
||||
connections: HashMap::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn execute(
|
||||
cache: &Arc<Mutex<ConnectionCache>>,
|
||||
conn_string: &str,
|
||||
hostname: &str,
|
||||
sql: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
// TODO: let go mutex when establishing connection
|
||||
let mut cache = cache.lock().await;
|
||||
let cache_key = format!("connstr={}, hostname={}", conn_string, hostname);
|
||||
let client = if let Some(client) = cache.connections.get(&cache_key) {
|
||||
info!("using cached connection {}", conn_string);
|
||||
client
|
||||
} else {
|
||||
info!("!!!! connecting to: {}", conn_string);
|
||||
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
// TODO: remove connection from cache
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
cache.connections.insert(cache_key.clone(), client);
|
||||
cache.connections.get(&cache_key).unwrap()
|
||||
};
|
||||
|
||||
let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string();
|
||||
|
||||
let rows: Vec<HashMap<_, _>> = client
|
||||
.simple_query(&sql)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|el| {
|
||||
if let tokio_postgres::SimpleQueryMessage::Row(row) = el {
|
||||
let mut serialized_row: HashMap<String, String> = HashMap::new();
|
||||
for i in 0..row.len() {
|
||||
let col = row.columns().get(i).map_or("?", |c| c.name());
|
||||
let val = row.get(i).unwrap_or("?");
|
||||
serialized_row.insert(col.into(), val.into());
|
||||
}
|
||||
Some(serialized_row)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(serde_json::to_string(&rows)?)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
cache: &'static Arc<Mutex<ConnectionCache>>,
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -221,7 +595,7 @@ pub async fn task_main(
|
||||
move |req: Request<Body>| async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
ws_handler(req, config, cancel_map, session_id)
|
||||
ws_handler(req, config, cancel_map, session_id, cache.clone())
|
||||
.instrument(info_span!(
|
||||
"ws-client",
|
||||
session = format_args!("{session_id}")
|
||||
|
||||
@@ -32,6 +32,8 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn};
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
|
||||
use crate::http::websocket::ConnectionCache;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
@@ -53,6 +55,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let args = cli().get_matches();
|
||||
let config = build_config(&args)?;
|
||||
|
||||
let wsconn_cache = Box::leak(Box::new(ConnectionCache::new()));
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
@@ -82,6 +86,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
client_tasks.push(tokio::spawn(http::websocket::task_main(
|
||||
config,
|
||||
wsconn_cache,
|
||||
wss_listener,
|
||||
cancellation_token.clone(),
|
||||
)));
|
||||
|
||||
2
vendor/postgres-v14
vendored
2
vendor/postgres-v14
vendored
Submodule vendor/postgres-v14 updated: 3e70693c91...c22aea6714
2
vendor/postgres-v15
vendored
2
vendor/postgres-v15
vendored
Submodule vendor/postgres-v15 updated: 4ad87b0f36...114da43a49
Reference in New Issue
Block a user