mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-02 02:00:38 +00:00
Compare commits
20 Commits
conrad/pro
...
gm/sql_ove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6abe34bfa1 | ||
|
|
b932c14f94 | ||
|
|
2e05a9b652 | ||
|
|
b808160aff | ||
|
|
2997018005 | ||
|
|
8f98cc29fa | ||
|
|
b4ee8e3b73 | ||
|
|
68ea916bf1 | ||
|
|
eea3dd54da | ||
|
|
4071b22519 | ||
|
|
c58edec63a | ||
|
|
fdb9d4373d | ||
|
|
cb88df7ffa | ||
|
|
595248532c | ||
|
|
158c051c9a | ||
|
|
143c4954df | ||
|
|
373ae7672b | ||
|
|
1233923c3a | ||
|
|
1a3a7f14dd | ||
|
|
4e9067a8c2 |
41
Cargo.lock
generated
41
Cargo.lock
generated
@@ -130,6 +130,12 @@ dependencies = [
|
|||||||
"static_assertions",
|
"static_assertions",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "array-init"
|
||||||
|
version = "2.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "asn1-rs"
|
name = "asn1-rs"
|
||||||
version = "0.5.2"
|
version = "0.5.2"
|
||||||
@@ -297,7 +303,7 @@ dependencies = [
|
|||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
@@ -401,7 +407,7 @@ dependencies = [
|
|||||||
"hex",
|
"hex",
|
||||||
"http",
|
"http",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"regex",
|
"regex",
|
||||||
"ring",
|
"ring",
|
||||||
"time",
|
"time",
|
||||||
@@ -522,7 +528,7 @@ dependencies = [
|
|||||||
"http-body",
|
"http-body",
|
||||||
"hyper",
|
"hyper",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"pin-utils",
|
"pin-utils",
|
||||||
"tokio",
|
"tokio",
|
||||||
@@ -544,7 +550,7 @@ dependencies = [
|
|||||||
"http-body",
|
"http-body",
|
||||||
"hyper",
|
"hyper",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"pin-utils",
|
"pin-utils",
|
||||||
"tracing",
|
"tracing",
|
||||||
@@ -684,7 +690,7 @@ dependencies = [
|
|||||||
"matchit",
|
"matchit",
|
||||||
"memchr",
|
"memchr",
|
||||||
"mime",
|
"mime",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"rustversion",
|
"rustversion",
|
||||||
"serde",
|
"serde",
|
||||||
@@ -1580,7 +1586,7 @@ version = "1.1.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
|
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2584,7 +2590,7 @@ dependencies = [
|
|||||||
"futures-util",
|
"futures-util",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry_api",
|
"opentelemetry_api",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"rand",
|
"rand",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
@@ -2749,6 +2755,12 @@ dependencies = [
|
|||||||
"base64 0.13.1",
|
"base64 0.13.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "percent-encoding"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "percent-encoding"
|
name = "percent-encoding"
|
||||||
version = "2.2.0"
|
version = "2.2.0"
|
||||||
@@ -2879,7 +2891,9 @@ name = "postgres-types"
|
|||||||
version = "0.2.4"
|
version = "0.2.4"
|
||||||
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=43e6db254a97fdecbce33d8bc0890accfd74495e#43e6db254a97fdecbce33d8bc0890accfd74495e"
|
source = "git+https://github.com/neondatabase/rust-postgres.git?rev=43e6db254a97fdecbce33d8bc0890accfd74495e#43e6db254a97fdecbce33d8bc0890accfd74495e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"array-init",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
"chrono",
|
||||||
"fallible-iterator",
|
"fallible-iterator",
|
||||||
"postgres-protocol",
|
"postgres-protocol",
|
||||||
]
|
]
|
||||||
@@ -3112,6 +3126,7 @@ dependencies = [
|
|||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
|
"percent-encoding 1.0.1",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"postgres_backend",
|
"postgres_backend",
|
||||||
"pq_proto",
|
"pq_proto",
|
||||||
@@ -3316,7 +3331,7 @@ dependencies = [
|
|||||||
"mime",
|
"mime",
|
||||||
"mime_guess",
|
"mime_guess",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"rustls 0.20.8",
|
"rustls 0.20.8",
|
||||||
"rustls-pemfile",
|
"rustls-pemfile",
|
||||||
@@ -3389,7 +3404,7 @@ dependencies = [
|
|||||||
"http",
|
"http",
|
||||||
"hyper",
|
"hyper",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"regex",
|
"regex",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -4332,7 +4347,7 @@ dependencies = [
|
|||||||
"futures-util",
|
"futures-util",
|
||||||
"log",
|
"log",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"phf",
|
"phf",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"postgres-protocol",
|
"postgres-protocol",
|
||||||
@@ -4480,7 +4495,7 @@ dependencies = [
|
|||||||
"http-body",
|
"http-body",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-timeout",
|
"hyper-timeout",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"pin-project",
|
"pin-project",
|
||||||
"prost",
|
"prost",
|
||||||
"prost-derive",
|
"prost-derive",
|
||||||
@@ -4512,7 +4527,7 @@ dependencies = [
|
|||||||
"http-body",
|
"http-body",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-timeout",
|
"hyper-timeout",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"pin-project",
|
"pin-project",
|
||||||
"prost",
|
"prost",
|
||||||
"rustls-native-certs",
|
"rustls-native-certs",
|
||||||
@@ -4822,7 +4837,7 @@ checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"form_urlencoded",
|
"form_urlencoded",
|
||||||
"idna",
|
"idna",
|
||||||
"percent-encoding",
|
"percent-encoding 2.2.0",
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ hmac = "0.12.1"
|
|||||||
hostname = "0.3.1"
|
hostname = "0.3.1"
|
||||||
humantime = "2.1"
|
humantime = "2.1"
|
||||||
humantime-serde = "1.1.1"
|
humantime-serde = "1.1.1"
|
||||||
hyper = "0.14"
|
hyper = { version = "0.14", features = ["http2", "tcp", "runtime", "http1"]}
|
||||||
hyper-tungstenite = "0.9"
|
hyper-tungstenite = "0.9"
|
||||||
itertools = "0.10"
|
itertools = "0.10"
|
||||||
jsonwebtoken = "8"
|
jsonwebtoken = "8"
|
||||||
@@ -117,16 +117,17 @@ uuid = { version = "1.2", features = ["v4", "serde"] }
|
|||||||
walkdir = "2.3.2"
|
walkdir = "2.3.2"
|
||||||
webpki-roots = "0.23"
|
webpki-roots = "0.23"
|
||||||
x509-parser = "0.15"
|
x509-parser = "0.15"
|
||||||
|
percent-encoding = "1.0"
|
||||||
|
|
||||||
## TODO replace this with tracing
|
## TODO replace this with tracing
|
||||||
env_logger = "0.10"
|
env_logger = "0.10"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
|
|
||||||
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
|
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
|
||||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["with-chrono-0_4", "array-impls"] }
|
||||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["array-impls"] }
|
||||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e", features = ["array-impls"] }
|
||||||
tokio-tar = { git = "https://github.com/neondatabase/tokio-tar.git", rev="404df61437de0feef49ba2ccdbdd94eb8ad6e142" }
|
tokio-tar = { git = "https://github.com/neondatabase/tokio-tar.git", rev="404df61437de0feef49ba2ccdbdd94eb8ad6e142" }
|
||||||
|
|
||||||
## Other git libraries
|
## Other git libraries
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ hmac.workspace = true
|
|||||||
hostname.workspace = true
|
hostname.workspace = true
|
||||||
humantime.workspace = true
|
humantime.workspace = true
|
||||||
hyper-tungstenite.workspace = true
|
hyper-tungstenite.workspace = true
|
||||||
hyper.workspace = true
|
hyper = { workspace = true, features = ["http2", "http1", "tcp", "runtime"] }
|
||||||
itertools.workspace = true
|
itertools.workspace = true
|
||||||
md5.workspace = true
|
md5.workspace = true
|
||||||
metrics.workspace = true
|
metrics.workspace = true
|
||||||
@@ -65,6 +65,7 @@ x509-parser.workspace = true
|
|||||||
|
|
||||||
workspace_hack.workspace = true
|
workspace_hack.workspace = true
|
||||||
tokio-util.workspace = true
|
tokio-util.workspace = true
|
||||||
|
percent-encoding.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
rcgen.workspace = true
|
rcgen.workspace = true
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
//! Other modules should use stuff from this module instead of
|
//! Other modules should use stuff from this module instead of
|
||||||
//! directly relying on deps like `reqwest` (think loose coupling).
|
//! directly relying on deps like `reqwest` (think loose coupling).
|
||||||
|
|
||||||
|
pub mod pg_to_json;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
pub mod websocket;
|
pub mod websocket;
|
||||||
|
|
||||||
|
|||||||
154
proxy/src/http/pg_to_json.rs
Normal file
154
proxy/src/http/pg_to_json.rs
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
|
||||||
|
use anyhow::{anyhow, Context};
|
||||||
|
use chrono::{Utc, DateTime, NaiveDateTime, NaiveTime, NaiveDate};
|
||||||
|
use serde_json::Map;
|
||||||
|
use tokio_postgres::{
|
||||||
|
types::{FromSql, Type},
|
||||||
|
Column, Row,
|
||||||
|
};
|
||||||
|
|
||||||
|
// some type-aliases I use in my project
|
||||||
|
pub type JSONValue = serde_json::Value;
|
||||||
|
pub type RowData = Map<String, JSONValue>;
|
||||||
|
pub type Error = anyhow::Error; // from: https://github.com/dtolnay/anyhow
|
||||||
|
|
||||||
|
pub fn postgres_row_to_json_value(row: &Row) -> Result<JSONValue, Error> {
|
||||||
|
let row_data = postgres_row_to_row_data(row)?;
|
||||||
|
Ok(JSONValue::Object(row_data))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn postgres_row_to_row_data(row: &Row) -> Result<RowData, Error> {
|
||||||
|
let mut result: Map<String, JSONValue> = Map::new();
|
||||||
|
for (i, column) in row.columns().iter().enumerate() {
|
||||||
|
let name = column.name();
|
||||||
|
let json_value = pg_cell_to_json_value(&row, column, i)?;
|
||||||
|
result.insert(name.to_string(), json_value);
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pg_cell_to_json_value(
|
||||||
|
row: &Row,
|
||||||
|
column: &Column,
|
||||||
|
column_i: usize,
|
||||||
|
) -> Result<JSONValue, Error> {
|
||||||
|
let f64_to_json_number = |raw_val: f64| -> Result<JSONValue, Error> {
|
||||||
|
let temp = serde_json::Number::from_f64(raw_val).ok_or(anyhow!("invalid json-float"))?;
|
||||||
|
Ok(JSONValue::Number(temp))
|
||||||
|
};
|
||||||
|
Ok(match *column.type_() {
|
||||||
|
// for rust-postgres <> postgres type-mappings: https://docs.rs/postgres/latest/postgres/types/trait.FromSql.html#types
|
||||||
|
// for postgres types: https://www.postgresql.org/docs/7.4/datatype.html#DATATYPE-TABLE
|
||||||
|
|
||||||
|
Type::TIMESTAMPTZ => get_basic(row, column, column_i, |a: DateTime<Utc>| {
|
||||||
|
Ok(JSONValue::String(a.to_rfc3339()))
|
||||||
|
})?,
|
||||||
|
Type::TIMESTAMP => get_basic(row, column, column_i, |a: NaiveDateTime| {
|
||||||
|
Ok(JSONValue::String(a.format("%Y-%m-%dT%H:%M:%S%.6f").to_string()))
|
||||||
|
})?,
|
||||||
|
|
||||||
|
Type::TIME => get_basic(row, column, column_i, |a: NaiveTime| {
|
||||||
|
Ok(JSONValue::String(a.format("%H:%M:%S%.6f").to_string()))
|
||||||
|
})?,
|
||||||
|
Type::DATE => get_basic(row, column, column_i, |a: NaiveDate| {
|
||||||
|
Ok(JSONValue::String(a.format("%Y-%m-%d").to_string()))
|
||||||
|
})?,
|
||||||
|
// no TIMETZ support?
|
||||||
|
|
||||||
|
// single types
|
||||||
|
Type::BOOL => get_basic(row, column, column_i, |a: bool| Ok(JSONValue::Bool(a)))?,
|
||||||
|
Type::INT2 => get_basic(row, column, column_i, |a: i16| {
|
||||||
|
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||||
|
})?,
|
||||||
|
Type::INT4 => get_basic(row, column, column_i, |a: i32| {
|
||||||
|
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||||
|
})?,
|
||||||
|
Type::INT8 => get_basic(row, column, column_i, |a: i64| {
|
||||||
|
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||||
|
})?,
|
||||||
|
Type::TEXT | Type::VARCHAR => {
|
||||||
|
get_basic(row, column, column_i, |a: String| Ok(JSONValue::String(a)))?
|
||||||
|
}
|
||||||
|
// Type::JSON | Type::JSONB => get_basic(row, column, column_i, |a: JSONValue| Ok(a))?,
|
||||||
|
Type::FLOAT4 => get_basic(row, column, column_i, |a: f32| f64_to_json_number(a.into()))?,
|
||||||
|
Type::FLOAT8 => get_basic(row, column, column_i, f64_to_json_number)?,
|
||||||
|
// these types require a custom StringCollector struct as an intermediary (see struct at bottom)
|
||||||
|
Type::TS_VECTOR => get_basic(row, column, column_i, |a: StringCollector| {
|
||||||
|
Ok(JSONValue::String(a.0))
|
||||||
|
})?,
|
||||||
|
|
||||||
|
// array types
|
||||||
|
Type::BOOL_ARRAY => get_array(row, column, column_i, |a: bool| Ok(JSONValue::Bool(a)))?,
|
||||||
|
Type::INT2_ARRAY => get_array(row, column, column_i, |a: i16| {
|
||||||
|
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||||
|
})?,
|
||||||
|
Type::INT4_ARRAY => get_array(row, column, column_i, |a: i32| {
|
||||||
|
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||||
|
})?,
|
||||||
|
Type::INT8_ARRAY => get_array(row, column, column_i, |a: i64| {
|
||||||
|
Ok(JSONValue::Number(serde_json::Number::from(a)))
|
||||||
|
})?,
|
||||||
|
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY => {
|
||||||
|
get_array(row, column, column_i, |a: String| Ok(JSONValue::String(a)))?
|
||||||
|
}
|
||||||
|
// Type::JSON_ARRAY | Type::JSONB_ARRAY => get_array(row, column, column_i, |a: JSONValue| Ok(a))?,
|
||||||
|
Type::FLOAT4_ARRAY => get_array(row, column, column_i, f64_to_json_number)?,
|
||||||
|
Type::FLOAT8_ARRAY => get_array(row, column, column_i, f64_to_json_number)?,
|
||||||
|
// these types require a custom StringCollector struct as an intermediary (see struct at bottom)
|
||||||
|
Type::TS_VECTOR_ARRAY => get_array(row, column, column_i, |a: StringCollector| {
|
||||||
|
Ok(JSONValue::String(a.0))
|
||||||
|
})?,
|
||||||
|
|
||||||
|
_ => anyhow::bail!(
|
||||||
|
"Cannot convert pg-cell \"{}\" of type \"{}\" to a JSONValue.",
|
||||||
|
column.name(),
|
||||||
|
column.type_().name()
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_basic<'a, T: FromSql<'a>>(
|
||||||
|
row: &'a Row,
|
||||||
|
column: &Column,
|
||||||
|
column_i: usize,
|
||||||
|
val_to_json_val: impl Fn(T) -> Result<JSONValue, Error>,
|
||||||
|
) -> Result<JSONValue, Error> {
|
||||||
|
let raw_val = row
|
||||||
|
.try_get::<_, Option<T>>(column_i)
|
||||||
|
.with_context(|| format!("column_name:{}", column.name()))?;
|
||||||
|
raw_val.map_or(Ok(JSONValue::Null), val_to_json_val)
|
||||||
|
}
|
||||||
|
fn get_array<'a, T: FromSql<'a>>(
|
||||||
|
row: &'a Row,
|
||||||
|
column: &Column,
|
||||||
|
column_i: usize,
|
||||||
|
val_to_json_val: impl Fn(T) -> Result<JSONValue, Error>,
|
||||||
|
) -> Result<JSONValue, Error> {
|
||||||
|
let raw_val_array = row
|
||||||
|
.try_get::<_, Option<Vec<T>>>(column_i)
|
||||||
|
.with_context(|| format!("column_name:{}", column.name()))?;
|
||||||
|
Ok(match raw_val_array {
|
||||||
|
Some(val_array) => {
|
||||||
|
let mut result = vec![];
|
||||||
|
for val in val_array {
|
||||||
|
result.push(val_to_json_val(val)?);
|
||||||
|
}
|
||||||
|
JSONValue::Array(result)
|
||||||
|
}
|
||||||
|
None => JSONValue::Null,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
struct StringCollector(String);
|
||||||
|
impl FromSql<'_> for StringCollector {
|
||||||
|
fn from_sql(
|
||||||
|
_: &Type,
|
||||||
|
raw: &[u8],
|
||||||
|
) -> Result<StringCollector, Box<dyn std::error::Error + Sync + Send>> {
|
||||||
|
let result = std::str::from_utf8(raw)?;
|
||||||
|
Ok(StringCollector(result.to_owned()))
|
||||||
|
}
|
||||||
|
fn accepts(_ty: &Type) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,15 +1,25 @@
|
|||||||
|
use crate::http::pg_to_json::postgres_row_to_json_value;
|
||||||
use crate::{
|
use crate::{
|
||||||
cancellation::CancelMap, config::ProxyConfig, error::io_error, proxy::handle_ws_client,
|
auth, cancellation::CancelMap, config::ProxyConfig, console, error::io_error,
|
||||||
|
proxy::handle_ws_client,
|
||||||
};
|
};
|
||||||
use bytes::{Buf, Bytes};
|
use bytes::{Buf, Bytes};
|
||||||
use futures::{Sink, Stream, StreamExt};
|
use futures::{Sink, Stream, StreamExt, TryStreamExt};
|
||||||
|
use tokio_postgres::Row;
|
||||||
|
use std::collections::HashMap;
|
||||||
use hyper::{
|
use hyper::{
|
||||||
server::{accept, conn::AddrIncoming},
|
server::{accept, conn::AddrIncoming},
|
||||||
upgrade::Upgraded,
|
upgrade::Upgraded,
|
||||||
Body, Request, Response, StatusCode,
|
Body, Method, Request, Response, StatusCode,
|
||||||
};
|
};
|
||||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||||
|
use percent_encoding::percent_decode;
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
|
use pq_proto::StartupMessageParams;
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use tokio_postgres::types::{ToSql};
|
||||||
use std::{
|
use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
future::ready,
|
future::ready,
|
||||||
@@ -24,6 +34,7 @@ use tokio::{
|
|||||||
};
|
};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{error, info, info_span, warn, Instrument};
|
use tracing::{error, info, info_span, warn, Instrument};
|
||||||
|
use url::{Url};
|
||||||
use utils::http::{error::ApiError, json::json_response};
|
use utils::http::{error::ApiError, json::json_response};
|
||||||
|
|
||||||
// TODO: use `std::sync::Exclusive` once it's stabilized.
|
// TODO: use `std::sync::Exclusive` once it's stabilized.
|
||||||
@@ -41,10 +52,10 @@ pin_project! {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl WebSocketRw {
|
impl WebSocketRw {
|
||||||
pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
|
pub fn new(stream: WebSocketStream<Upgraded>, startup_data: Bytes) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream: stream.into(),
|
stream: stream.into(),
|
||||||
bytes: Bytes::new(),
|
bytes: startup_data,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,6 +92,7 @@ impl AsyncRead for WebSocketRw {
|
|||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
buf: &mut ReadBuf<'_>,
|
buf: &mut ReadBuf<'_>,
|
||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
|
|
||||||
if buf.remaining() > 0 {
|
if buf.remaining() > 0 {
|
||||||
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
||||||
let len = std::cmp::min(bytes.len(), buf.remaining());
|
let len = std::cmp::min(bytes.len(), buf.remaining());
|
||||||
@@ -141,25 +153,101 @@ async fn serve_websocket(
|
|||||||
cancel_map: &CancelMap,
|
cancel_map: &CancelMap,
|
||||||
session_id: uuid::Uuid,
|
session_id: uuid::Uuid,
|
||||||
hostname: Option<String>,
|
hostname: Option<String>,
|
||||||
|
startup_data: Vec<u8>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
|
|
||||||
let websocket = websocket.await?;
|
let websocket = websocket.await?;
|
||||||
|
|
||||||
handle_ws_client(
|
handle_ws_client(
|
||||||
config,
|
config,
|
||||||
cancel_map,
|
cancel_map,
|
||||||
session_id,
|
session_id,
|
||||||
WebSocketRw::new(websocket),
|
WebSocketRw::new(websocket, startup_data.into()),
|
||||||
hostname,
|
hostname,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MyObject {
|
||||||
|
data: Vec<u8>,
|
||||||
|
recv_data: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MyObject {
|
||||||
|
fn new(data: Vec<u8>) -> Self {
|
||||||
|
MyObject {
|
||||||
|
data,
|
||||||
|
recv_data: Vec::with_capacity(512),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl AsyncRead for MyObject {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<Result<(), io::Error>> {
|
||||||
|
|
||||||
|
let data = &self.get_mut().data;
|
||||||
|
|
||||||
|
// let len: usize = recv_data.len();
|
||||||
|
// info!("len: {}", len);
|
||||||
|
// let mut i: usize = 0;
|
||||||
|
// let mut zcount = 0;
|
||||||
|
// while len >= i + 5 {
|
||||||
|
// let cmd = recv_data[i];
|
||||||
|
// if cmd == 0x5a { zcount += 1; }
|
||||||
|
// let size = u32::from_be_bytes(recv_data[(i + 1)..(i + 5)].try_into().unwrap());
|
||||||
|
// info!("cmd: {} size: {} buf: {:?}", cmd, size, recv_data);
|
||||||
|
// i += usize::try_from(size).unwrap() + 1;
|
||||||
|
// }
|
||||||
|
// if zcount < 2 {
|
||||||
|
// Poll::Pending
|
||||||
|
|
||||||
|
// } else {
|
||||||
|
// let mut reader = &recv_data[..];
|
||||||
|
// Pin::new(&mut reader).poll_read(cx, buf)
|
||||||
|
// }
|
||||||
|
|
||||||
|
let mut reader = &data[..];
|
||||||
|
Pin::new(&mut reader).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for MyObject {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, io::Error>> {
|
||||||
|
|
||||||
|
// eprintln!("{:?}", buf);
|
||||||
|
let recv_data = &mut self.get_mut().recv_data;
|
||||||
|
recv_data.extend(buf);
|
||||||
|
Poll::Ready(Ok(buf.len()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
|
// ...
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
|
// ...
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async fn ws_handler(
|
async fn ws_handler(
|
||||||
mut request: Request<Body>,
|
mut request: Request<Body>,
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
cancel_map: Arc<CancelMap>,
|
cancel_map: Arc<CancelMap>,
|
||||||
session_id: uuid::Uuid,
|
session_id: uuid::Uuid,
|
||||||
|
cache: Arc<Mutex<ConnectionCache>>,
|
||||||
) -> Result<Response<Body>, ApiError> {
|
) -> Result<Response<Body>, ApiError> {
|
||||||
|
|
||||||
let host = request
|
let host = request
|
||||||
.headers()
|
.headers()
|
||||||
.get("host")
|
.get("host")
|
||||||
@@ -169,25 +257,311 @@ async fn ws_handler(
|
|||||||
|
|
||||||
// Check if the request is a websocket upgrade request.
|
// Check if the request is a websocket upgrade request.
|
||||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||||
|
let startup_data = match request.uri().query() {
|
||||||
|
Some(b64_str) => match base64::decode_config(b64_str, base64::URL_SAFE) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(_) => {
|
||||||
|
eprintln!("invalid WebSocket base64 startup data");
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
info!("{} bytes of startup data received via WebSocket URL query", startup_data.len());
|
||||||
|
|
||||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await
|
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host, startup_data).await {
|
||||||
{
|
|
||||||
error!("error in websocket connection: {e:?}");
|
error!("error in websocket connection: {e:?}");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Return the response so the spawned future can continue.
|
// Return the response so the spawned future can continue.
|
||||||
Ok(response)
|
Ok(response)
|
||||||
|
|
||||||
|
} else if request.uri().path() == "/pg-protocol" && request.method() == Method::POST {
|
||||||
|
let mut body = request.into_body();
|
||||||
|
let mut data = Vec::with_capacity(512);
|
||||||
|
while let Some(chunk) = body.next().await {
|
||||||
|
data.extend(&chunk.map_err(|e| ApiError::InternalServerError(e.into()))?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut my_object = MyObject::new(data);
|
||||||
|
let my_object = tokio::spawn(async move {
|
||||||
|
let result = handle_ws_client(
|
||||||
|
config,
|
||||||
|
&cancel_map,
|
||||||
|
session_id,
|
||||||
|
&mut my_object,
|
||||||
|
host,
|
||||||
|
).await;
|
||||||
|
my_object
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||||
|
|
||||||
|
let response = Response::builder()
|
||||||
|
.header("Content-Type", "application/octet-stream")
|
||||||
|
.header("Access-Control-Allow-Origin", "*")
|
||||||
|
.status(StatusCode::OK)
|
||||||
|
.body(Body::from(my_object.recv_data))
|
||||||
|
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
|
||||||
|
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||||
|
let result = handle_sql(config, request).await;
|
||||||
|
let status_code = match result {
|
||||||
|
Ok(_) => StatusCode::OK,
|
||||||
|
Err(_) => StatusCode::BAD_REQUEST
|
||||||
|
};
|
||||||
|
let json = match result {
|
||||||
|
Ok(r) => serde_json::to_value(r).unwrap(),
|
||||||
|
Err(e) => {
|
||||||
|
let message = format!("{:?}", e);
|
||||||
|
let code = match e.downcast_ref::<tokio_postgres::Error>() {
|
||||||
|
Some(e) => match e.code() {
|
||||||
|
Some(e) => serde_json::to_value(e.code()).unwrap(),
|
||||||
|
None => Value::Null
|
||||||
|
},
|
||||||
|
None => Value::Null
|
||||||
|
};
|
||||||
|
json!({ "message": message, "code": code })
|
||||||
|
}
|
||||||
|
};
|
||||||
|
json_response(status_code, json).map(|mut r| {
|
||||||
|
r.headers_mut().insert(
|
||||||
|
"Access-Control-Allow-Origin",
|
||||||
|
hyper::http::HeaderValue::from_static("*"),
|
||||||
|
);
|
||||||
|
r
|
||||||
|
})
|
||||||
|
|
||||||
|
} else if request.uri().path() == "/sleep" {
|
||||||
|
// sleep 15ms
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(15)).await;
|
||||||
|
json_response(StatusCode::OK, "done")
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
json_response(StatusCode::OK, "Connect with a websocket client")
|
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn boxed_from_json(json: Vec<Value>) -> Result<Vec<Box<dyn ToSql + Sync + Send>>, anyhow::Error> {
|
||||||
|
json.iter().map(|value| {
|
||||||
|
let boxed: Result<Box<dyn ToSql + Sync + Send>, anyhow::Error> = match value {
|
||||||
|
Value::Bool(b) => Ok(Box::new(b.clone())),
|
||||||
|
Value::Number(n) => Ok(Box::new(n.as_f64().unwrap())),
|
||||||
|
Value::String(s) => Ok(Box::new(s.clone())),
|
||||||
|
// TODO: support null (not like this: `Value::Null => Ok(Box::new(None::<bool>)),`)
|
||||||
|
// TODO: support arrays
|
||||||
|
x => Err(anyhow::anyhow!("unsupported param {:?}", x))
|
||||||
|
};
|
||||||
|
boxed
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, anyhow::Error>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
// XXX: return different error codes
|
||||||
|
async fn handle_sql(
|
||||||
|
config: &'static ProxyConfig,
|
||||||
|
request: Request<Body>
|
||||||
|
) -> anyhow::Result<Vec<Value>> {
|
||||||
|
|
||||||
|
let headers = request.headers();
|
||||||
|
|
||||||
|
let connection_string = headers
|
||||||
|
.get("X-Neon-ConnectionString")
|
||||||
|
.ok_or(anyhow::anyhow!("missing connection string"))?
|
||||||
|
.to_str()?;
|
||||||
|
|
||||||
|
let connection_url = Url::parse(connection_string)?;
|
||||||
|
|
||||||
|
let protocol = connection_url.scheme();
|
||||||
|
if protocol != "postgres" && protocol != "postgresql" {
|
||||||
|
return Err(anyhow::anyhow!("connection string must start with postgres: or postgresql:"))
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut url_path = connection_url
|
||||||
|
.path_segments()
|
||||||
|
.ok_or(anyhow::anyhow!("missing database name"))?;
|
||||||
|
|
||||||
|
let dbname = url_path
|
||||||
|
.next()
|
||||||
|
.ok_or(anyhow::anyhow!("invalid database name"))?;
|
||||||
|
|
||||||
|
let username = connection_url.username();
|
||||||
|
|
||||||
|
let password = connection_url
|
||||||
|
.password()
|
||||||
|
.ok_or(anyhow::anyhow!("no password"))?;
|
||||||
|
|
||||||
|
let hostname = connection_url
|
||||||
|
.host_str()
|
||||||
|
.ok_or(anyhow::anyhow!("no host"))?;
|
||||||
|
|
||||||
|
let host_header = request
|
||||||
|
.headers()
|
||||||
|
.get("host")
|
||||||
|
.and_then(|h| h.to_str().ok())
|
||||||
|
.and_then(|h| h.split(':').next());
|
||||||
|
|
||||||
|
match host_header {
|
||||||
|
Some(h) if h == hostname => h,
|
||||||
|
Some(_) => return Err(anyhow::anyhow!("mismatched host header and hostname")),
|
||||||
|
None => return Err(anyhow::anyhow!("no host header"))
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut body = request.into_body();
|
||||||
|
let mut data = Vec::with_capacity(512);
|
||||||
|
while let Some(chunk) = body.next().await {
|
||||||
|
data.extend(&chunk?);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct QueryData {
|
||||||
|
query: String,
|
||||||
|
params: Vec<serde_json::Value>
|
||||||
|
}
|
||||||
|
let query_data: QueryData = serde_json::from_slice(&data)?;
|
||||||
|
|
||||||
|
let credential_params = StartupMessageParams::new([
|
||||||
|
("user", username),
|
||||||
|
("database", dbname),
|
||||||
|
("application_name", "proxy_http_sql"),
|
||||||
|
]);
|
||||||
|
let tls = config.tls_config.as_ref();
|
||||||
|
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||||
|
let creds = config
|
||||||
|
.auth_backend
|
||||||
|
.as_ref()
|
||||||
|
.map(|_| auth::ClientCredentials::parse(&credential_params, Some(hostname), common_names))
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let extra = console::ConsoleReqExtra {
|
||||||
|
session_id: uuid::Uuid::new_v4(),
|
||||||
|
application_name: Some("proxy_http_sql"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let node = creds.wake_compute(&extra).await?.expect("msg");
|
||||||
|
let conf = node.value.config;
|
||||||
|
|
||||||
|
let host = match conf.get_hosts().first().expect("no host") {
|
||||||
|
tokio_postgres::config::Host::Tcp(host) => host,
|
||||||
|
tokio_postgres::config::Host::Unix(_) => {
|
||||||
|
return Err(anyhow::anyhow!("unix socket is not supported"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let conn_string = &format!(
|
||||||
|
"host={} port={} user={} password={} dbname={}",
|
||||||
|
host,
|
||||||
|
conf.get_ports().first().expect("no port"),
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
dbname
|
||||||
|
);
|
||||||
|
|
||||||
|
let (client, connection) = tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = connection.await {
|
||||||
|
eprintln!("connection error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let query = &query_data.query;
|
||||||
|
let query_params = boxed_from_json(query_data.params)?;
|
||||||
|
|
||||||
|
// TODO: find a way to catch the panic and return an error if the number of params is wrong
|
||||||
|
let pg_rows: Vec<Row> = client
|
||||||
|
.query_raw(query, query_params)
|
||||||
|
.await?
|
||||||
|
.try_collect::<Vec<Row>>()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let rows: Result<Vec<serde_json::Value>, anyhow::Error> = pg_rows
|
||||||
|
.iter()
|
||||||
|
.map(postgres_row_to_json_value)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let rows = rows?;
|
||||||
|
|
||||||
|
Ok(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ConnectionCache {
|
||||||
|
connections: HashMap<String, tokio_postgres::Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl ConnectionCache {
|
||||||
|
pub fn new() -> Arc<Mutex<Self>> {
|
||||||
|
Arc::new(Mutex::new(Self {
|
||||||
|
connections: HashMap::new(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn execute(
|
||||||
|
cache: &Arc<Mutex<ConnectionCache>>,
|
||||||
|
conn_string: &str,
|
||||||
|
hostname: &str,
|
||||||
|
sql: &str,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
// TODO: let go mutex when establishing connection
|
||||||
|
let mut cache = cache.lock().await;
|
||||||
|
let cache_key = format!("connstr={}, hostname={}", conn_string, hostname);
|
||||||
|
let client = if let Some(client) = cache.connections.get(&cache_key) {
|
||||||
|
info!("using cached connection {}", conn_string);
|
||||||
|
client
|
||||||
|
} else {
|
||||||
|
info!("!!!! connecting to: {}", conn_string);
|
||||||
|
|
||||||
|
let (client, connection) =
|
||||||
|
tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// TODO: remove connection from cache
|
||||||
|
if let Err(e) = connection.await {
|
||||||
|
eprintln!("connection error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
cache.connections.insert(cache_key.clone(), client);
|
||||||
|
cache.connections.get(&cache_key).unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string();
|
||||||
|
|
||||||
|
let rows: Vec<HashMap<_, _>> = client
|
||||||
|
.simple_query(&sql)
|
||||||
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|el| {
|
||||||
|
if let tokio_postgres::SimpleQueryMessage::Row(row) = el {
|
||||||
|
let mut serialized_row: HashMap<String, String> = HashMap::new();
|
||||||
|
for i in 0..row.len() {
|
||||||
|
let col = row.columns().get(i).map_or("?", |c| c.name());
|
||||||
|
let val = row.get(i).unwrap_or("?");
|
||||||
|
serialized_row.insert(col.into(), val.into());
|
||||||
|
}
|
||||||
|
Some(serialized_row)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(serde_json::to_string(&rows)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
pub async fn task_main(
|
pub async fn task_main(
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
|
cache: &'static Arc<Mutex<ConnectionCache>>,
|
||||||
ws_listener: TcpListener,
|
ws_listener: TcpListener,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
@@ -221,7 +595,7 @@ pub async fn task_main(
|
|||||||
move |req: Request<Body>| async move {
|
move |req: Request<Body>| async move {
|
||||||
let cancel_map = Arc::new(CancelMap::default());
|
let cancel_map = Arc::new(CancelMap::default());
|
||||||
let session_id = uuid::Uuid::new_v4();
|
let session_id = uuid::Uuid::new_v4();
|
||||||
ws_handler(req, config, cancel_map, session_id)
|
ws_handler(req, config, cancel_map, session_id, cache.clone())
|
||||||
.instrument(info_span!(
|
.instrument(info_span!(
|
||||||
"ws-client",
|
"ws-client",
|
||||||
session = format_args!("{session_id}")
|
session = format_args!("{session_id}")
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ use tokio_util::sync::CancellationToken;
|
|||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
use utils::{project_git_version, sentry_init::init_sentry};
|
use utils::{project_git_version, sentry_init::init_sentry};
|
||||||
|
|
||||||
|
use crate::http::websocket::ConnectionCache;
|
||||||
|
|
||||||
project_git_version!(GIT_VERSION);
|
project_git_version!(GIT_VERSION);
|
||||||
|
|
||||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||||
@@ -53,6 +55,8 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let args = cli().get_matches();
|
let args = cli().get_matches();
|
||||||
let config = build_config(&args)?;
|
let config = build_config(&args)?;
|
||||||
|
|
||||||
|
let wsconn_cache = Box::leak(Box::new(ConnectionCache::new()));
|
||||||
|
|
||||||
info!("Authentication backend: {}", config.auth_backend);
|
info!("Authentication backend: {}", config.auth_backend);
|
||||||
|
|
||||||
// Check that we can bind to address before further initialization
|
// Check that we can bind to address before further initialization
|
||||||
@@ -82,6 +86,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
client_tasks.push(tokio::spawn(http::websocket::task_main(
|
client_tasks.push(tokio::spawn(http::websocket::task_main(
|
||||||
config,
|
config,
|
||||||
|
wsconn_cache,
|
||||||
wss_listener,
|
wss_listener,
|
||||||
cancellation_token.clone(),
|
cancellation_token.clone(),
|
||||||
)));
|
)));
|
||||||
|
|||||||
2
vendor/postgres-v14
vendored
2
vendor/postgres-v14
vendored
Submodule vendor/postgres-v14 updated: 3e70693c91...c22aea6714
2
vendor/postgres-v15
vendored
2
vendor/postgres-v15
vendored
Submodule vendor/postgres-v15 updated: 4ad87b0f36...114da43a49
Reference in New Issue
Block a user