mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-08 21:20:38 +00:00
Compare commits
10 Commits
hack/compu
...
proxy/remo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
101e770632 | ||
|
|
747ffa50d6 | ||
|
|
e8400d9d93 | ||
|
|
17627e8023 | ||
|
|
4fb5cdbdb8 | ||
|
|
70b503f83b | ||
|
|
21c15c4285 | ||
|
|
0a524e09a5 | ||
|
|
cbe24f7c35 | ||
|
|
64add503c8 |
14
Cargo.lock
generated
14
Cargo.lock
generated
@@ -2654,16 +2654,6 @@ dependencies = [
|
|||||||
"windows-sys 0.45.0",
|
"windows-sys 0.45.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pbkdf2"
|
|
||||||
version = "0.12.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f0ca0b5a68607598bf3bad68f32227a8164f6254833f84eafaac409cd6746c31"
|
|
||||||
dependencies = [
|
|
||||||
"digest",
|
|
||||||
"hmac",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "peeking_take_while"
|
name = "peeking_take_while"
|
||||||
version = "0.1.2"
|
version = "0.1.2"
|
||||||
@@ -3040,6 +3030,7 @@ dependencies = [
|
|||||||
"chrono",
|
"chrono",
|
||||||
"clap",
|
"clap",
|
||||||
"consumption_metrics",
|
"consumption_metrics",
|
||||||
|
"fallible-iterator",
|
||||||
"futures",
|
"futures",
|
||||||
"git-version",
|
"git-version",
|
||||||
"hashbrown 0.13.2",
|
"hashbrown 0.13.2",
|
||||||
@@ -3057,9 +3048,9 @@ dependencies = [
|
|||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"parking_lot 0.12.1",
|
"parking_lot 0.12.1",
|
||||||
"pbkdf2",
|
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"postgres-native-tls",
|
"postgres-native-tls",
|
||||||
|
"postgres-protocol",
|
||||||
"postgres_backend",
|
"postgres_backend",
|
||||||
"pq_proto",
|
"pq_proto",
|
||||||
"prometheus",
|
"prometheus",
|
||||||
@@ -3083,6 +3074,7 @@ dependencies = [
|
|||||||
"thiserror",
|
"thiserror",
|
||||||
"tls-listener",
|
"tls-listener",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-native-tls",
|
||||||
"tokio-postgres",
|
"tokio-postgres",
|
||||||
"tokio-postgres-rustls",
|
"tokio-postgres-rustls",
|
||||||
"tokio-rustls 0.23.4",
|
"tokio-rustls 0.23.4",
|
||||||
|
|||||||
@@ -88,7 +88,6 @@ opentelemetry = "0.19.0"
|
|||||||
opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
|
opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
|
||||||
opentelemetry-semantic-conventions = "0.11.0"
|
opentelemetry-semantic-conventions = "0.11.0"
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
pbkdf2 = "0.12.1"
|
|
||||||
pin-project-lite = "0.2"
|
pin-project-lite = "0.2"
|
||||||
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
|
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
|
||||||
prost = "0.11"
|
prost = "0.11"
|
||||||
|
|||||||
@@ -29,9 +29,9 @@ metrics.workspace = true
|
|||||||
once_cell.workspace = true
|
once_cell.workspace = true
|
||||||
opentelemetry.workspace = true
|
opentelemetry.workspace = true
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
pbkdf2.workspace = true
|
|
||||||
pin-project-lite.workspace = true
|
pin-project-lite.workspace = true
|
||||||
postgres_backend.workspace = true
|
postgres_backend.workspace = true
|
||||||
|
postgres-protocol.workspace = true
|
||||||
pq_proto.workspace = true
|
pq_proto.workspace = true
|
||||||
prometheus.workspace = true
|
prometheus.workspace = true
|
||||||
rand.workspace = true
|
rand.workspace = true
|
||||||
@@ -65,10 +65,13 @@ webpki-roots.workspace = true
|
|||||||
x509-parser.workspace = true
|
x509-parser.workspace = true
|
||||||
native-tls.workspace = true
|
native-tls.workspace = true
|
||||||
postgres-native-tls.workspace = true
|
postgres-native-tls.workspace = true
|
||||||
|
tokio-native-tls = "0.3.1"
|
||||||
|
|
||||||
workspace_hack.workspace = true
|
workspace_hack.workspace = true
|
||||||
tokio-util.workspace = true
|
tokio-util.workspace = true
|
||||||
|
|
||||||
|
fallible-iterator = "0.2.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
rcgen.workspace = true
|
rcgen.workspace = true
|
||||||
rstest.workspace = true
|
rstest.workspace = true
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use std::fmt;
|
|||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, sync::Arc};
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
|
|
||||||
use crate::{auth, console};
|
use crate::{auth, console, pg_client};
|
||||||
use crate::{compute, config};
|
use crate::{compute, config};
|
||||||
|
|
||||||
use super::sql_over_http::MAX_RESPONSE_SIZE;
|
use super::sql_over_http::MAX_RESPONSE_SIZE;
|
||||||
@@ -41,8 +41,10 @@ impl fmt::Display for ConnInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PgConn =
|
||||||
|
pg_client::connection::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>;
|
||||||
struct ConnPoolEntry {
|
struct ConnPoolEntry {
|
||||||
conn: tokio_postgres::Client,
|
conn: PgConn,
|
||||||
_last_access: std::time::Instant,
|
_last_access: std::time::Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,12 +80,8 @@ impl GlobalConnPool {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get(
|
pub async fn get(&self, conn_info: &ConnInfo, force_new: bool) -> anyhow::Result<PgConn> {
|
||||||
&self,
|
let mut client: Option<PgConn> = None;
|
||||||
conn_info: &ConnInfo,
|
|
||||||
force_new: bool,
|
|
||||||
) -> anyhow::Result<tokio_postgres::Client> {
|
|
||||||
let mut client: Option<tokio_postgres::Client> = None;
|
|
||||||
|
|
||||||
if !force_new {
|
if !force_new {
|
||||||
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
|
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
|
||||||
@@ -114,11 +112,7 @@ impl GlobalConnPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn put(
|
pub async fn put(&self, conn_info: &ConnInfo, client: PgConn) -> anyhow::Result<()> {
|
||||||
&self,
|
|
||||||
conn_info: &ConnInfo,
|
|
||||||
client: tokio_postgres::Client,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
|
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
|
||||||
|
|
||||||
// return connection to the pool
|
// return connection to the pool
|
||||||
@@ -191,7 +185,7 @@ struct TokioMechanism<'a> {
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ConnectMechanism for TokioMechanism<'_> {
|
impl ConnectMechanism for TokioMechanism<'_> {
|
||||||
type Connection = tokio_postgres::Client;
|
type Connection = PgConn;
|
||||||
type ConnectError = tokio_postgres::Error;
|
type ConnectError = tokio_postgres::Error;
|
||||||
type Error = anyhow::Error;
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
@@ -213,7 +207,7 @@ impl ConnectMechanism for TokioMechanism<'_> {
|
|||||||
async fn connect_to_compute(
|
async fn connect_to_compute(
|
||||||
config: &config::ProxyConfig,
|
config: &config::ProxyConfig,
|
||||||
conn_info: &ConnInfo,
|
conn_info: &ConnInfo,
|
||||||
) -> anyhow::Result<tokio_postgres::Client> {
|
) -> anyhow::Result<PgConn> {
|
||||||
let tls = config.tls_config.as_ref();
|
let tls = config.tls_config.as_ref();
|
||||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||||
|
|
||||||
@@ -251,7 +245,7 @@ async fn connect_to_compute_once(
|
|||||||
node_info: &console::CachedNodeInfo,
|
node_info: &console::CachedNodeInfo,
|
||||||
conn_info: &ConnInfo,
|
conn_info: &ConnInfo,
|
||||||
timeout: time::Duration,
|
timeout: time::Duration,
|
||||||
) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
|
) -> Result<PgConn, tokio_postgres::Error> {
|
||||||
let mut config = (*node_info.config).clone();
|
let mut config = (*node_info.config).clone();
|
||||||
|
|
||||||
let (client, connection) = config
|
let (client, connection) = config
|
||||||
@@ -263,11 +257,13 @@ async fn connect_to_compute_once(
|
|||||||
.connect(tokio_postgres::NoTls)
|
.connect(tokio_postgres::NoTls)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
let stream = connection.stream.into_inner();
|
||||||
if let Err(e) = connection.await {
|
|
||||||
error!("connection error: {}", e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(client)
|
// tokio::spawn(async move {
|
||||||
|
// if let Err(e) = connection.await {
|
||||||
|
// error!("connection error: {}", e);
|
||||||
|
// }
|
||||||
|
// });
|
||||||
|
|
||||||
|
Ok(PgConn::new(stream))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
|
use std::io::ErrorKind;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::bail;
|
use anyhow::bail;
|
||||||
|
use bytes::BufMut;
|
||||||
|
use fallible_iterator::FallibleIterator;
|
||||||
use futures::pin_mut;
|
use futures::pin_mut;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use hashbrown::HashMap;
|
use hashbrown::HashMap;
|
||||||
@@ -8,16 +11,28 @@ use hyper::body::HttpBody;
|
|||||||
use hyper::http::HeaderName;
|
use hyper::http::HeaderName;
|
||||||
use hyper::http::HeaderValue;
|
use hyper::http::HeaderValue;
|
||||||
use hyper::{Body, HeaderMap, Request};
|
use hyper::{Body, HeaderMap, Request};
|
||||||
|
use postgres_protocol::message::backend::DataRowBody;
|
||||||
|
use postgres_protocol::message::backend::ReadyForQueryBody;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use serde_json::Map;
|
use serde_json::Map;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use tokio::io::AsyncRead;
|
||||||
|
use tokio::io::AsyncWrite;
|
||||||
use tokio_postgres::types::Kind;
|
use tokio_postgres::types::Kind;
|
||||||
use tokio_postgres::types::Type;
|
use tokio_postgres::types::Type;
|
||||||
use tokio_postgres::GenericClient;
|
use tokio_postgres::GenericClient;
|
||||||
use tokio_postgres::IsolationLevel;
|
use tokio_postgres::IsolationLevel;
|
||||||
use tokio_postgres::Row;
|
use tokio_postgres::Row;
|
||||||
|
use tokio_postgres::RowStream;
|
||||||
|
use tokio_postgres::Statement;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
|
use crate::pg_client;
|
||||||
|
use crate::pg_client::codec::FrontendMessage;
|
||||||
|
use crate::pg_client::connection;
|
||||||
|
use crate::pg_client::connection::RequestMessages;
|
||||||
|
use crate::pg_client::prepare::TypeinfoPreparedQueries;
|
||||||
|
|
||||||
use super::conn_pool::ConnInfo;
|
use super::conn_pool::ConnInfo;
|
||||||
use super::conn_pool::GlobalConnPool;
|
use super::conn_pool::GlobalConnPool;
|
||||||
|
|
||||||
@@ -230,30 +245,35 @@ pub async fn handle(
|
|||||||
// Now execute the query and return the result
|
// Now execute the query and return the result
|
||||||
//
|
//
|
||||||
let result = match payload {
|
let result = match payload {
|
||||||
Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode)
|
Payload::Single(query) => query_raw_txt_as_json(&mut client, query, raw_output, array_mode)
|
||||||
.await
|
.await
|
||||||
.map(|x| (x, HashMap::default())),
|
.map(|x| (x, HashMap::default())),
|
||||||
Payload::Batch(queries) => {
|
Payload::Batch(queries) => {
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
let mut builder = client.build_transaction();
|
|
||||||
if let Some(isolation_level) = txn_isolation_level {
|
client
|
||||||
builder = builder.isolation_level(isolation_level);
|
.start_tx(txn_isolation_level, Some(txn_read_only))
|
||||||
}
|
.await?;
|
||||||
if txn_read_only {
|
|
||||||
builder = builder.read_only(true);
|
|
||||||
}
|
|
||||||
let transaction = builder.start().await?;
|
|
||||||
for query in queries {
|
for query in queries {
|
||||||
let result = query_to_json(&transaction, query, raw_output, array_mode).await;
|
let result =
|
||||||
|
query_raw_txt_as_json(&mut client, query, raw_output, array_mode).await;
|
||||||
match result {
|
match result {
|
||||||
Ok(r) => results.push(r),
|
// TODO: check this tag to see if the client has executed a commit during the non-interactive transactions...
|
||||||
|
Ok((r, _ready_tag)) => results.push(r),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
transaction.rollback().await?;
|
let tag = client.rollback().await?;
|
||||||
|
if allow_pool && tag.status() == b'I' {
|
||||||
|
// return connection to the pool
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
let _ = conn_pool.put(&conn_info, client).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
transaction.commit().await?;
|
let ready_tag = client.commit().await?;
|
||||||
let mut headers = HashMap::default();
|
let mut headers = HashMap::default();
|
||||||
headers.insert(
|
headers.insert(
|
||||||
TXN_READ_ONLY.clone(),
|
TXN_READ_ONLY.clone(),
|
||||||
@@ -262,11 +282,11 @@ pub async fn handle(
|
|||||||
if let Some(txn_isolation_level_raw) = txn_isolation_level_raw {
|
if let Some(txn_isolation_level_raw) = txn_isolation_level_raw {
|
||||||
headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level_raw);
|
headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level_raw);
|
||||||
}
|
}
|
||||||
Ok((json!({ "results": results }), headers))
|
Ok(((json!({ "results": results }), ready_tag), headers))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if allow_pool {
|
if allow_pool && ready_tag.status() == b'I' {
|
||||||
// return connection to the pool
|
// return connection to the pool
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
let _ = conn_pool.put(&conn_info, client).await;
|
let _ = conn_pool.put(&conn_info, client).await;
|
||||||
@@ -351,6 +371,99 @@ async fn query_to_json<T: GenericClient>(
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn query_raw_txt_as_json<'a, St, T>(
|
||||||
|
conn: &mut connection::Connection<St, T>,
|
||||||
|
data: QueryData,
|
||||||
|
raw_output: bool,
|
||||||
|
array_mode: bool,
|
||||||
|
) -> anyhow::Result<(Value, ReadyForQueryBody)>
|
||||||
|
where
|
||||||
|
St: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
{
|
||||||
|
let params = json_to_pg_text(data.params)?;
|
||||||
|
let params = params.into_iter();
|
||||||
|
|
||||||
|
let stmt_name = conn.statement_name();
|
||||||
|
let row_description = conn.prepare(&stmt_name, &data.query).await?;
|
||||||
|
|
||||||
|
let mut fields = vec![];
|
||||||
|
let mut columns = vec![];
|
||||||
|
let mut it = row_description.fields();
|
||||||
|
while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? {
|
||||||
|
fields.push(json!({
|
||||||
|
"name": Value::String(field.name().to_owned()),
|
||||||
|
"dataTypeID": Value::Number(field.type_oid().into()),
|
||||||
|
"tableID": field.table_oid(),
|
||||||
|
"columnID": field.column_id(),
|
||||||
|
"dataTypeSize": field.type_size(),
|
||||||
|
"dataTypeModifier": field.type_modifier(),
|
||||||
|
"format": "text",
|
||||||
|
}));
|
||||||
|
|
||||||
|
let type_ = match Type::from_oid(field.type_oid()) {
|
||||||
|
Some(t) => t,
|
||||||
|
None => TypeinfoPreparedQueries::get_type(conn, field.type_oid()).await?,
|
||||||
|
};
|
||||||
|
|
||||||
|
columns.push(Column {
|
||||||
|
name: field.name().to_string(),
|
||||||
|
type_,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.execute("", &stmt_name, params)?;
|
||||||
|
conn.sync().await?;
|
||||||
|
|
||||||
|
let mut rows = vec![];
|
||||||
|
|
||||||
|
let mut row_stream = conn.stream_query_results().await?;
|
||||||
|
|
||||||
|
let mut curret_size = 0;
|
||||||
|
while let Some(row) = row_stream.next().await.transpose()? {
|
||||||
|
// let row = row.map_err(Error::db)?;
|
||||||
|
|
||||||
|
curret_size += row.buffer().len();
|
||||||
|
if curret_size > MAX_RESPONSE_SIZE {
|
||||||
|
return Err(anyhow::anyhow!("response too large"));
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.push(pg_text_row_to_json2(&row, &columns, raw_output, array_mode).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
let command_tag = row_stream.tag();
|
||||||
|
let command_tag = command_tag.tag()?;
|
||||||
|
let mut command_tag_split = command_tag.split(' ');
|
||||||
|
let command_tag_name = command_tag_split.next().unwrap_or_default();
|
||||||
|
let command_tag_count = if command_tag_name == "INSERT" {
|
||||||
|
// INSERT returns OID first and then number of rows
|
||||||
|
command_tag_split.nth(1)
|
||||||
|
} else {
|
||||||
|
// other commands return number of rows (if any)
|
||||||
|
command_tag_split.next()
|
||||||
|
}
|
||||||
|
.and_then(|s| s.parse::<i64>().ok());
|
||||||
|
|
||||||
|
let ready_tag = conn.wait_for_ready().await?;
|
||||||
|
|
||||||
|
// resulting JSON format is based on the format of node-postgres result
|
||||||
|
Ok((
|
||||||
|
json!({
|
||||||
|
"command": command_tag_name,
|
||||||
|
"rowCount": command_tag_count,
|
||||||
|
"rows": rows,
|
||||||
|
"fields": fields,
|
||||||
|
"rowAsArray": array_mode,
|
||||||
|
}),
|
||||||
|
ready_tag,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Column {
|
||||||
|
name: String,
|
||||||
|
type_: Type,
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Convert postgres row with text-encoded values to JSON object
|
// Convert postgres row with text-encoded values to JSON object
|
||||||
//
|
//
|
||||||
@@ -370,7 +483,7 @@ pub fn pg_text_row_to_json(
|
|||||||
} else {
|
} else {
|
||||||
pg_text_to_json(pg_value, column.type_())?
|
pg_text_to_json(pg_value, column.type_())?
|
||||||
};
|
};
|
||||||
Ok((name.to_string(), json_value))
|
Ok((name, json_value))
|
||||||
});
|
});
|
||||||
|
|
||||||
if array_mode {
|
if array_mode {
|
||||||
@@ -380,7 +493,55 @@ pub fn pg_text_row_to_json(
|
|||||||
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
|
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
|
||||||
Ok(Value::Array(arr))
|
Ok(Value::Array(arr))
|
||||||
} else {
|
} else {
|
||||||
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
let obj = iter
|
||||||
|
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
|
||||||
|
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||||
|
Ok(Value::Object(obj))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Convert postgres row with text-encoded values to JSON object
|
||||||
|
//
|
||||||
|
fn pg_text_row_to_json2(
|
||||||
|
row: &DataRowBody,
|
||||||
|
columns: &[Column],
|
||||||
|
raw_output: bool,
|
||||||
|
array_mode: bool,
|
||||||
|
) -> Result<Value, anyhow::Error> {
|
||||||
|
let ranges: Vec<Option<std::ops::Range<usize>>> = row.ranges().collect()?;
|
||||||
|
let iter = std::iter::zip(ranges, columns)
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, (range, column))| {
|
||||||
|
let name = &column.name;
|
||||||
|
let pg_value = range
|
||||||
|
.map(|r| {
|
||||||
|
std::str::from_utf8(&row.buffer()[r])
|
||||||
|
.map_err(|e| pg_client::error::Error::from_sql(e.into(), i))
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
// let pg_value = row.as_text(i)?;
|
||||||
|
let json_value = if raw_output {
|
||||||
|
match pg_value {
|
||||||
|
Some(v) => Value::String(v.to_string()),
|
||||||
|
None => Value::Null,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
pg_text_to_json(pg_value, &column.type_)?
|
||||||
|
};
|
||||||
|
Ok((name, json_value))
|
||||||
|
});
|
||||||
|
|
||||||
|
if array_mode {
|
||||||
|
// drop keys and aggregate into array
|
||||||
|
let arr = iter
|
||||||
|
.map(|r| r.map(|(_key, val)| val))
|
||||||
|
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
|
||||||
|
Ok(Value::Array(arr))
|
||||||
|
} else {
|
||||||
|
let obj = iter
|
||||||
|
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
|
||||||
|
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||||
Ok(Value::Object(obj))
|
Ok(Value::Object(obj))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -391,16 +552,16 @@ pub fn pg_text_row_to_json(
|
|||||||
pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
|
pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
|
||||||
if let Some(val) = pg_value {
|
if let Some(val) = pg_value {
|
||||||
if let Kind::Array(elem_type) = pg_type.kind() {
|
if let Kind::Array(elem_type) = pg_type.kind() {
|
||||||
return pg_array_parse(val, elem_type);
|
return pg_array_parse(val, &elem_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
match *pg_type {
|
match pg_type {
|
||||||
Type::BOOL => Ok(Value::Bool(val == "t")),
|
&Type::BOOL => Ok(Value::Bool(val == "t")),
|
||||||
Type::INT2 | Type::INT4 => {
|
&Type::INT2 | &Type::INT4 => {
|
||||||
let val = val.parse::<i32>()?;
|
let val = val.parse::<i32>()?;
|
||||||
Ok(Value::Number(serde_json::Number::from(val)))
|
Ok(Value::Number(serde_json::Number::from(val)))
|
||||||
}
|
}
|
||||||
Type::FLOAT4 | Type::FLOAT8 => {
|
&Type::FLOAT4 | &Type::FLOAT8 => {
|
||||||
let fval = val.parse::<f64>()?;
|
let fval = val.parse::<f64>()?;
|
||||||
let num = serde_json::Number::from_f64(fval);
|
let num = serde_json::Number::from_f64(fval);
|
||||||
if let Some(num) = num {
|
if let Some(num) = num {
|
||||||
@@ -412,7 +573,7 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value,
|
|||||||
Ok(Value::String(val.to_string()))
|
Ok(Value::String(val.to_string()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
|
&Type::JSON | &Type::JSONB => Ok(serde_json::from_str(val)?),
|
||||||
_ => Ok(Value::String(val.to_string())),
|
_ => Ok(Value::String(val.to_string())),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ pub mod scram;
|
|||||||
pub mod stream;
|
pub mod stream;
|
||||||
pub mod url;
|
pub mod url;
|
||||||
pub mod waiters;
|
pub mod waiters;
|
||||||
|
pub mod pg_client;
|
||||||
|
|
||||||
/// Handle unix signals appropriately.
|
/// Handle unix signals appropriately.
|
||||||
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
|
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
|
||||||
|
|||||||
43
proxy/src/pg_client/codec.rs
Normal file
43
proxy/src/pg_client/codec.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use fallible_iterator::FallibleIterator;
|
||||||
|
use postgres_protocol::message::backend::{self, Message};
|
||||||
|
use std::io;
|
||||||
|
use tokio_util::codec::{Decoder, Encoder};
|
||||||
|
|
||||||
|
pub struct FrontendMessage(pub Bytes);
|
||||||
|
pub struct BackendMessages(pub BytesMut);
|
||||||
|
|
||||||
|
impl BackendMessages {
|
||||||
|
pub fn empty() -> BackendMessages {
|
||||||
|
BackendMessages(BytesMut::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FallibleIterator for BackendMessages {
|
||||||
|
type Item = backend::Message;
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn next(&mut self) -> io::Result<Option<backend::Message>> {
|
||||||
|
backend::Message::parse(&mut self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PostgresCodec;
|
||||||
|
|
||||||
|
impl Encoder<FrontendMessage> for PostgresCodec {
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
|
||||||
|
dst.extend_from_slice(&item.0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Decoder for PostgresCodec {
|
||||||
|
type Item = Message;
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, io::Error> {
|
||||||
|
Message::parse(src)
|
||||||
|
}
|
||||||
|
}
|
||||||
369
proxy/src/pg_client/connection.rs
Normal file
369
proxy/src/pg_client/connection.rs
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
use super::codec::{BackendMessages, FrontendMessage, PostgresCodec};
|
||||||
|
use super::error::Error;
|
||||||
|
use super::prepare::TypeinfoPreparedQueries;
|
||||||
|
use bytes::{BufMut, BytesMut};
|
||||||
|
use futures::channel::mpsc;
|
||||||
|
use futures::{Sink, StreamExt};
|
||||||
|
use futures::{SinkExt, Stream};
|
||||||
|
use hashbrown::HashMap;
|
||||||
|
use postgres_protocol::message::backend::{
|
||||||
|
BackendKeyDataBody, CommandCompleteBody, DataRowBody, ErrorResponseBody, Message,
|
||||||
|
ReadyForQueryBody, RowDescriptionBody,
|
||||||
|
};
|
||||||
|
use postgres_protocol::message::frontend;
|
||||||
|
use postgres_protocol::Oid;
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::future::poll_fn;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{ready, Context, Poll};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
|
||||||
|
use tokio_postgres::types::Type;
|
||||||
|
use tokio_postgres::IsolationLevel;
|
||||||
|
use tokio_util::codec::Framed;
|
||||||
|
|
||||||
|
pub enum RequestMessages {
|
||||||
|
Single(FrontendMessage),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Request {
|
||||||
|
pub messages: RequestMessages,
|
||||||
|
pub sender: mpsc::Sender<BackendMessages>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Response {
|
||||||
|
sender: mpsc::Sender<BackendMessages>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A connection to a PostgreSQL database.
|
||||||
|
pub struct RawConnection<S, T> {
|
||||||
|
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||||
|
pending_responses: VecDeque<Message>,
|
||||||
|
pub buf: BytesMut,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> RawConnection<S, T> {
|
||||||
|
pub fn new(
|
||||||
|
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||||
|
buf: BytesMut,
|
||||||
|
) -> RawConnection<S, T> {
|
||||||
|
RawConnection {
|
||||||
|
stream,
|
||||||
|
pending_responses: VecDeque::new(),
|
||||||
|
buf,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send(&mut self) -> Result<(), Error> {
|
||||||
|
poll_fn(|cx| self.poll_send(cx)).await?;
|
||||||
|
let request = FrontendMessage(self.buf.split().freeze());
|
||||||
|
self.stream.start_send_unpin(request).map_err(Error::io)?;
|
||||||
|
poll_fn(|cx| self.poll_flush(cx)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn next_message(&mut self) -> Result<Message, Error> {
|
||||||
|
match self.pending_responses.pop_front() {
|
||||||
|
Some(message) => Ok(message),
|
||||||
|
None => poll_fn(|cx| self.poll_read(cx)).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
|
||||||
|
let message = match ready!(self.stream.poll_next_unpin(cx)?) {
|
||||||
|
Some(message) => message,
|
||||||
|
None => return Poll::Ready(Err(Error::closed())),
|
||||||
|
};
|
||||||
|
Poll::Ready(Ok(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
Pin::new(&mut self.stream).poll_close(cx).map_err(Error::io)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||||
|
self.pending_responses.push_back(msg);
|
||||||
|
};
|
||||||
|
self.stream.poll_ready_unpin(cx).map_err(Error::io)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||||
|
self.pending_responses.push_back(msg);
|
||||||
|
};
|
||||||
|
self.stream.poll_flush_unpin(cx).map_err(Error::io)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Connection<S, T> {
|
||||||
|
stmt_counter: usize,
|
||||||
|
pub typeinfo: Option<TypeinfoPreparedQueries>,
|
||||||
|
pub typecache: HashMap<Oid, Type>,
|
||||||
|
pub raw: RawConnection<S, T>,
|
||||||
|
// key: BackendKeyDataBody,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Connection<S, T> {
|
||||||
|
pub fn new(stream: MaybeTlsStream<S, T>) -> Connection<S, T> {
|
||||||
|
Connection {
|
||||||
|
stmt_counter: 0,
|
||||||
|
typeinfo: None,
|
||||||
|
typecache: HashMap::new(),
|
||||||
|
raw: RawConnection::new(Framed::new(stream, PostgresCodec), BytesMut::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_tx(
|
||||||
|
&mut self,
|
||||||
|
isolation_level: Option<IsolationLevel>,
|
||||||
|
read_only: Option<bool>,
|
||||||
|
) -> Result<ReadyForQueryBody, Error> {
|
||||||
|
let mut query = "START TRANSACTION".to_string();
|
||||||
|
let mut first = true;
|
||||||
|
|
||||||
|
if let Some(level) = isolation_level {
|
||||||
|
first = false;
|
||||||
|
|
||||||
|
query.push_str(" ISOLATION LEVEL ");
|
||||||
|
let level = match level {
|
||||||
|
IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
|
||||||
|
IsolationLevel::ReadCommitted => "READ COMMITTED",
|
||||||
|
IsolationLevel::RepeatableRead => "REPEATABLE READ",
|
||||||
|
IsolationLevel::Serializable => "SERIALIZABLE",
|
||||||
|
_ => return Err(Error::unexpected_message()),
|
||||||
|
};
|
||||||
|
query.push_str(level);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(read_only) = read_only {
|
||||||
|
if !first {
|
||||||
|
query.push(',');
|
||||||
|
}
|
||||||
|
first = false;
|
||||||
|
|
||||||
|
let s = if read_only {
|
||||||
|
" READ ONLY"
|
||||||
|
} else {
|
||||||
|
" READ WRITE"
|
||||||
|
};
|
||||||
|
query.push_str(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.execute_simple(&query).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn rollback(&mut self) -> Result<ReadyForQueryBody, Error> {
|
||||||
|
self.execute_simple("ROLLBACK").await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn commit(&mut self) -> Result<ReadyForQueryBody, Error> {
|
||||||
|
self.execute_simple("COMMIT").await
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub async fn auth_sasl_scram<'a, I>(
|
||||||
|
// mut raw: RawConnection<S, T>,
|
||||||
|
// params: I,
|
||||||
|
// password: &[u8],
|
||||||
|
// ) -> Result<Self, Error>
|
||||||
|
// where
|
||||||
|
// I: IntoIterator<Item = (&'a str, &'a str)>,
|
||||||
|
// {
|
||||||
|
// // send a startup message
|
||||||
|
// frontend::startup_message(params, &mut raw.buf).unwrap();
|
||||||
|
// raw.send().await?;
|
||||||
|
|
||||||
|
// // expect sasl authentication message
|
||||||
|
// let Message::AuthenticationSasl(body) = raw.next_message().await? else { return Err(Error::expecting("sasl authentication")) };
|
||||||
|
// // expect support for SCRAM_SHA_256
|
||||||
|
// if body
|
||||||
|
// .mechanisms()
|
||||||
|
// .find(|&x| Ok(x == authentication::sasl::SCRAM_SHA_256))?
|
||||||
|
// .is_none()
|
||||||
|
// {
|
||||||
|
// return Err(Error::expecting("SCRAM-SHA-256 auth"));
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // initiate SCRAM_SHA_256 authentication without channel binding
|
||||||
|
// let auth = authentication::sasl::ChannelBinding::unrequested();
|
||||||
|
// let mut scram = authentication::sasl::ScramSha256::new(password, auth);
|
||||||
|
|
||||||
|
// frontend::sasl_initial_response(
|
||||||
|
// authentication::sasl::SCRAM_SHA_256,
|
||||||
|
// scram.message(),
|
||||||
|
// &mut raw.buf,
|
||||||
|
// )
|
||||||
|
// .unwrap();
|
||||||
|
// raw.send().await?;
|
||||||
|
|
||||||
|
// // expect sasl continue
|
||||||
|
// let Message::AuthenticationSaslContinue(b) = raw.next_message().await? else { return Err(Error::expecting("auth continue")) };
|
||||||
|
// scram.update(b.data()).unwrap();
|
||||||
|
|
||||||
|
// // continue sasl
|
||||||
|
// frontend::sasl_response(scram.message(), &mut raw.buf).unwrap();
|
||||||
|
// raw.send().await?;
|
||||||
|
|
||||||
|
// // expect sasl final
|
||||||
|
// let Message::AuthenticationSaslFinal(b) = raw.next_message().await? else { return Err(Error::expecting("auth final")) };
|
||||||
|
// scram.finish(b.data()).unwrap();
|
||||||
|
|
||||||
|
// // expect auth ok
|
||||||
|
// let Message::AuthenticationOk = raw.next_message().await? else { return Err(Error::expecting("auth ok")) };
|
||||||
|
|
||||||
|
// // expect connection accepted
|
||||||
|
// let key = loop {
|
||||||
|
// match raw.next_message().await? {
|
||||||
|
// Message::BackendKeyData(key) => break key,
|
||||||
|
// Message::ParameterStatus(_) => {}
|
||||||
|
// _ => return Err(Error::expecting("backend ready")),
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
|
||||||
|
// let Message::ReadyForQuery(b) = raw.next_message().await? else { return Err(Error::expecting("ready for query")) };
|
||||||
|
// // assert_eq!(b.status(), b'I');
|
||||||
|
|
||||||
|
// Ok(Self { raw, key })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// pub fn prepare_and_execute(
|
||||||
|
// &mut self,
|
||||||
|
// portal: &str,
|
||||||
|
// name: &str,
|
||||||
|
// query: &str,
|
||||||
|
// params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
|
||||||
|
// ) -> std::io::Result<()> {
|
||||||
|
// self.prepare(name, query)?;
|
||||||
|
// self.execute(portal, name, params)
|
||||||
|
// }
|
||||||
|
|
||||||
|
pub fn statement_name(&mut self) -> String {
|
||||||
|
self.stmt_counter += 1;
|
||||||
|
format!("s{}", self.stmt_counter)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_simple(&mut self, query: &str) -> Result<ReadyForQueryBody, Error> {
|
||||||
|
frontend::query(query, &mut self.raw.buf)?;
|
||||||
|
self.raw.send().await?;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match self.raw.next_message().await? {
|
||||||
|
Message::ReadyForQuery(q) => return Ok(q),
|
||||||
|
Message::CommandComplete(_)
|
||||||
|
| Message::EmptyQueryResponse
|
||||||
|
| Message::RowDescription(_)
|
||||||
|
| Message::DataRow(_) => {}
|
||||||
|
_ => return Err(Error::unexpected_message()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn prepare(&mut self, name: &str, query: &str) -> Result<RowDescriptionBody, Error> {
|
||||||
|
frontend::parse(name, query, std::iter::empty(), &mut self.raw.buf)?;
|
||||||
|
frontend::describe(b'S', name, &mut self.raw.buf)?;
|
||||||
|
self.sync().await?;
|
||||||
|
self.wait_for_prepare().await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execute(
|
||||||
|
&mut self,
|
||||||
|
portal: &str,
|
||||||
|
name: &str,
|
||||||
|
params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
|
||||||
|
) -> std::io::Result<()> {
|
||||||
|
frontend::bind(
|
||||||
|
portal,
|
||||||
|
name,
|
||||||
|
std::iter::empty(), // all parameters use the default format (text)
|
||||||
|
params,
|
||||||
|
|param, buf| match param {
|
||||||
|
Some(param) => {
|
||||||
|
buf.put_slice(param.as_ref().as_bytes());
|
||||||
|
Ok(postgres_protocol::IsNull::No)
|
||||||
|
}
|
||||||
|
None => Ok(postgres_protocol::IsNull::Yes),
|
||||||
|
},
|
||||||
|
Some(0), // all text
|
||||||
|
&mut self.raw.buf,
|
||||||
|
)
|
||||||
|
.map_err(|e| match e {
|
||||||
|
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||||
|
frontend::BindError::Serialization(io) => io,
|
||||||
|
})?;
|
||||||
|
frontend::execute(portal, 0, &mut self.raw.buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn sync(&mut self) -> Result<(), Error> {
|
||||||
|
frontend::sync(&mut self.raw.buf);
|
||||||
|
self.raw.send().await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn wait_for_prepare(&mut self) -> Result<RowDescriptionBody, Error> {
|
||||||
|
let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||||
|
let Message::ParameterDescription(_) = self.raw.next_message().await? else { return Err(Error::expecting("param description")) };
|
||||||
|
let Message::RowDescription(desc) = self.raw.next_message().await? else { return Err(Error::expecting("row description")) };
|
||||||
|
|
||||||
|
self.wait_for_ready().await?;
|
||||||
|
|
||||||
|
Ok(desc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stream_query_results(&mut self) -> Result<RowStream<'_, S, T>, Error> {
|
||||||
|
// let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||||
|
let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) };
|
||||||
|
Ok(RowStream::Stream(&mut self.raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn wait_for_ready(&mut self) -> Result<ReadyForQueryBody, Error> {
|
||||||
|
loop {
|
||||||
|
match self.raw.next_message().await.unwrap() {
|
||||||
|
Message::ReadyForQuery(b) => break Ok(b),
|
||||||
|
_ => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum RowStream<'a, S, T> {
|
||||||
|
Stream(&'a mut RawConnection<S, T>),
|
||||||
|
Complete(Option<CommandCompleteBody>),
|
||||||
|
}
|
||||||
|
impl<S, T> Unpin for RowStream<'_, S, T> {}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Stream
|
||||||
|
for RowStream<'_, S, T>
|
||||||
|
{
|
||||||
|
// this is horrible - first result is for transport/protocol errors errors
|
||||||
|
// second result is for sql errors.
|
||||||
|
type Item = Result<Result<DataRowBody, ErrorResponseBody>, Error>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
match &mut *self {
|
||||||
|
RowStream::Stream(raw) => match ready!(raw.poll_read(cx)?) {
|
||||||
|
Message::DataRow(row) => Poll::Ready(Some(Ok(Ok(row)))),
|
||||||
|
Message::CommandComplete(tag) => {
|
||||||
|
*self = Self::Complete(Some(tag));
|
||||||
|
Poll::Ready(None)
|
||||||
|
}
|
||||||
|
Message::EmptyQueryResponse | Message::PortalSuspended => {
|
||||||
|
*self = Self::Complete(None);
|
||||||
|
Poll::Ready(None)
|
||||||
|
}
|
||||||
|
Message::ErrorResponse(error) => {
|
||||||
|
*self = Self::Complete(None);
|
||||||
|
Poll::Ready(Some(Ok(Err(error))))
|
||||||
|
}
|
||||||
|
_ => Poll::Ready(Some(Err(Error::expecting("command completion")))),
|
||||||
|
},
|
||||||
|
RowStream::Complete(_) => Poll::Ready(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, T> RowStream<'_, S, T> {
|
||||||
|
pub fn tag(self) -> Option<CommandCompleteBody> {
|
||||||
|
match self {
|
||||||
|
RowStream::Stream(_) => panic!("should not get tag unless row stream is exhausted"),
|
||||||
|
RowStream::Complete(tag) => tag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
447
proxy/src/pg_client/error.rs
Normal file
447
proxy/src/pg_client/error.rs
Normal file
@@ -0,0 +1,447 @@
|
|||||||
|
use std::{error, fmt, io};
|
||||||
|
|
||||||
|
use fallible_iterator::FallibleIterator;
|
||||||
|
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
|
||||||
|
use tokio_native_tls::native_tls;
|
||||||
|
use tokio_postgres::error::{ErrorPosition, SqlState};
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
enum Kind {
|
||||||
|
Io,
|
||||||
|
Tls,
|
||||||
|
UnexpectedMessage,
|
||||||
|
FromSql(usize),
|
||||||
|
Closed,
|
||||||
|
Db,
|
||||||
|
Parse,
|
||||||
|
Encode,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ErrorInner {
|
||||||
|
kind: Kind,
|
||||||
|
cause: Option<Box<dyn error::Error + Sync + Send>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error communicating with the Postgres server.
|
||||||
|
pub struct Error(ErrorInner);
|
||||||
|
|
||||||
|
impl fmt::Debug for Error {
|
||||||
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
fmt.debug_struct("Error")
|
||||||
|
.field("kind", &self.0.kind)
|
||||||
|
.field("cause", &self.0.cause)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Error {
|
||||||
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match &self.0.kind {
|
||||||
|
Kind::Io => fmt.write_str("error communicating with the server")?,
|
||||||
|
Kind::Tls => fmt.write_str("error establishing tls")?,
|
||||||
|
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
|
||||||
|
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
|
||||||
|
Kind::Closed => fmt.write_str("connection closed")?,
|
||||||
|
Kind::Db => fmt.write_str("db error")?,
|
||||||
|
Kind::Parse => fmt.write_str("error parsing response from server")?,
|
||||||
|
Kind::Encode => fmt.write_str("error encoding message to server")?,
|
||||||
|
};
|
||||||
|
if let Some(ref cause) = self.0.cause {
|
||||||
|
write!(fmt, ": {}", cause)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl error::Error for Error {
|
||||||
|
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
|
||||||
|
self.0.cause.as_ref().map(|e| &**e as _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<io::Error> for Error {
|
||||||
|
fn from(value: io::Error) -> Self {
|
||||||
|
Self::io(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
/// Consumes the error, returning its cause.
|
||||||
|
pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
|
||||||
|
self.0.cause
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the source of this error if it was a `DbError`.
|
||||||
|
///
|
||||||
|
/// This is a simple convenience method.
|
||||||
|
pub fn as_db_error(&self) -> Option<&DbError> {
|
||||||
|
error::Error::source(self).and_then(|e| e.downcast_ref::<DbError>())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determines if the error was associated with closed connection.
|
||||||
|
pub fn is_closed(&self) -> bool {
|
||||||
|
self.0.kind == Kind::Closed
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the SQLSTATE error code associated with the error.
|
||||||
|
///
|
||||||
|
/// This is a convenience method that downcasts the cause to a `DbError` and returns its code.
|
||||||
|
pub fn code(&self) -> Option<&SqlState> {
|
||||||
|
self.as_db_error().map(DbError::code)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
|
||||||
|
Error(ErrorInner { kind, cause })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::needless_pass_by_value)]
|
||||||
|
pub(crate) fn db(error: ErrorResponseBody) -> Error {
|
||||||
|
match DbError::parse(&mut error.fields()) {
|
||||||
|
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
|
||||||
|
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
|
||||||
|
Error::new(Kind::FromSql(idx), Some(e))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn closed() -> Error {
|
||||||
|
Error::new(Kind::Closed, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn unexpected_message() -> Error {
|
||||||
|
Error::new(Kind::UnexpectedMessage, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn expecting(expected: &str) -> Error {
|
||||||
|
Error::new(Kind::UnexpectedMessage, Some(expected.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn parse(e: io::Error) -> Error {
|
||||||
|
Error::new(Kind::Parse, Some(Box::new(e)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn encode(e: io::Error) -> Error {
|
||||||
|
Error::new(Kind::Encode, Some(Box::new(e)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn io(e: io::Error) -> Error {
|
||||||
|
Error::new(Kind::Io, Some(Box::new(e)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn tls(e: native_tls::Error) -> Error {
|
||||||
|
Error::new(Kind::Tls, Some(Box::new(e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The severity of a Postgres error or notice.
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||||
|
pub enum Severity {
|
||||||
|
/// PANIC
|
||||||
|
Panic,
|
||||||
|
/// FATAL
|
||||||
|
Fatal,
|
||||||
|
/// ERROR
|
||||||
|
Error,
|
||||||
|
/// WARNING
|
||||||
|
Warning,
|
||||||
|
/// NOTICE
|
||||||
|
Notice,
|
||||||
|
/// DEBUG
|
||||||
|
Debug,
|
||||||
|
/// INFO
|
||||||
|
Info,
|
||||||
|
/// LOG
|
||||||
|
Log,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Severity {
|
||||||
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
let s = match *self {
|
||||||
|
Severity::Panic => "PANIC",
|
||||||
|
Severity::Fatal => "FATAL",
|
||||||
|
Severity::Error => "ERROR",
|
||||||
|
Severity::Warning => "WARNING",
|
||||||
|
Severity::Notice => "NOTICE",
|
||||||
|
Severity::Debug => "DEBUG",
|
||||||
|
Severity::Info => "INFO",
|
||||||
|
Severity::Log => "LOG",
|
||||||
|
};
|
||||||
|
fmt.write_str(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Severity {
|
||||||
|
fn from_str(s: &str) -> Option<Severity> {
|
||||||
|
match s {
|
||||||
|
"PANIC" => Some(Severity::Panic),
|
||||||
|
"FATAL" => Some(Severity::Fatal),
|
||||||
|
"ERROR" => Some(Severity::Error),
|
||||||
|
"WARNING" => Some(Severity::Warning),
|
||||||
|
"NOTICE" => Some(Severity::Notice),
|
||||||
|
"DEBUG" => Some(Severity::Debug),
|
||||||
|
"INFO" => Some(Severity::Info),
|
||||||
|
"LOG" => Some(Severity::Log),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Postgres error or notice.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct DbError {
|
||||||
|
severity: String,
|
||||||
|
parsed_severity: Option<Severity>,
|
||||||
|
code: SqlState,
|
||||||
|
message: String,
|
||||||
|
detail: Option<String>,
|
||||||
|
hint: Option<String>,
|
||||||
|
position: Option<ErrorPosition>,
|
||||||
|
where_: Option<String>,
|
||||||
|
schema: Option<String>,
|
||||||
|
table: Option<String>,
|
||||||
|
column: Option<String>,
|
||||||
|
datatype: Option<String>,
|
||||||
|
constraint: Option<String>,
|
||||||
|
file: Option<String>,
|
||||||
|
line: Option<u32>,
|
||||||
|
routine: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DbError {
|
||||||
|
pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
|
||||||
|
let mut severity = None;
|
||||||
|
let mut parsed_severity = None;
|
||||||
|
let mut code = None;
|
||||||
|
let mut message = None;
|
||||||
|
let mut detail = None;
|
||||||
|
let mut hint = None;
|
||||||
|
let mut normal_position = None;
|
||||||
|
let mut internal_position = None;
|
||||||
|
let mut internal_query = None;
|
||||||
|
let mut where_ = None;
|
||||||
|
let mut schema = None;
|
||||||
|
let mut table = None;
|
||||||
|
let mut column = None;
|
||||||
|
let mut datatype = None;
|
||||||
|
let mut constraint = None;
|
||||||
|
let mut file = None;
|
||||||
|
let mut line = None;
|
||||||
|
let mut routine = None;
|
||||||
|
|
||||||
|
while let Some(field) = fields.next()? {
|
||||||
|
match field.type_() {
|
||||||
|
b'S' => severity = Some(field.value().to_owned()),
|
||||||
|
b'C' => code = Some(SqlState::from_code(field.value())),
|
||||||
|
b'M' => message = Some(field.value().to_owned()),
|
||||||
|
b'D' => detail = Some(field.value().to_owned()),
|
||||||
|
b'H' => hint = Some(field.value().to_owned()),
|
||||||
|
b'P' => {
|
||||||
|
normal_position = Some(field.value().parse::<u32>().map_err(|_| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"`P` field did not contain an integer",
|
||||||
|
)
|
||||||
|
})?);
|
||||||
|
}
|
||||||
|
b'p' => {
|
||||||
|
internal_position = Some(field.value().parse::<u32>().map_err(|_| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"`p` field did not contain an integer",
|
||||||
|
)
|
||||||
|
})?);
|
||||||
|
}
|
||||||
|
b'q' => internal_query = Some(field.value().to_owned()),
|
||||||
|
b'W' => where_ = Some(field.value().to_owned()),
|
||||||
|
b's' => schema = Some(field.value().to_owned()),
|
||||||
|
b't' => table = Some(field.value().to_owned()),
|
||||||
|
b'c' => column = Some(field.value().to_owned()),
|
||||||
|
b'd' => datatype = Some(field.value().to_owned()),
|
||||||
|
b'n' => constraint = Some(field.value().to_owned()),
|
||||||
|
b'F' => file = Some(field.value().to_owned()),
|
||||||
|
b'L' => {
|
||||||
|
line = Some(field.value().parse::<u32>().map_err(|_| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"`L` field did not contain an integer",
|
||||||
|
)
|
||||||
|
})?);
|
||||||
|
}
|
||||||
|
b'R' => routine = Some(field.value().to_owned()),
|
||||||
|
b'V' => {
|
||||||
|
parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"`V` field contained an invalid value",
|
||||||
|
)
|
||||||
|
})?);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(DbError {
|
||||||
|
severity: severity
|
||||||
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
|
||||||
|
parsed_severity,
|
||||||
|
code: code
|
||||||
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
|
||||||
|
message: message
|
||||||
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
|
||||||
|
detail,
|
||||||
|
hint,
|
||||||
|
position: match normal_position {
|
||||||
|
Some(position) => Some(ErrorPosition::Original(position)),
|
||||||
|
None => match internal_position {
|
||||||
|
Some(position) => Some(ErrorPosition::Internal {
|
||||||
|
position,
|
||||||
|
query: internal_query.ok_or_else(|| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"`q` field missing but `p` field present",
|
||||||
|
)
|
||||||
|
})?,
|
||||||
|
}),
|
||||||
|
None => None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
where_,
|
||||||
|
schema,
|
||||||
|
table,
|
||||||
|
column,
|
||||||
|
datatype,
|
||||||
|
constraint,
|
||||||
|
file,
|
||||||
|
line,
|
||||||
|
routine,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The field contents are ERROR, FATAL, or PANIC (in an error message),
|
||||||
|
/// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a
|
||||||
|
/// localized translation of one of these.
|
||||||
|
pub fn severity(&self) -> &str {
|
||||||
|
&self.severity
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+)
|
||||||
|
pub fn parsed_severity(&self) -> Option<Severity> {
|
||||||
|
self.parsed_severity
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The SQLSTATE code for the error.
|
||||||
|
pub fn code(&self) -> &SqlState {
|
||||||
|
&self.code
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The primary human-readable error message.
|
||||||
|
///
|
||||||
|
/// This should be accurate but terse (typically one line).
|
||||||
|
pub fn message(&self) -> &str {
|
||||||
|
&self.message
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An optional secondary error message carrying more detail about the
|
||||||
|
/// problem.
|
||||||
|
///
|
||||||
|
/// Might run to multiple lines.
|
||||||
|
pub fn detail(&self) -> Option<&str> {
|
||||||
|
self.detail.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An optional suggestion what to do about the problem.
|
||||||
|
///
|
||||||
|
/// This is intended to differ from `detail` in that it offers advice
|
||||||
|
/// (potentially inappropriate) rather than hard facts. Might run to
|
||||||
|
/// multiple lines.
|
||||||
|
pub fn hint(&self) -> Option<&str> {
|
||||||
|
self.hint.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An optional error cursor position into either the original query string
|
||||||
|
/// or an internally generated query.
|
||||||
|
pub fn position(&self) -> Option<&ErrorPosition> {
|
||||||
|
self.position.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An indication of the context in which the error occurred.
|
||||||
|
///
|
||||||
|
/// Presently this includes a call stack traceback of active procedural
|
||||||
|
/// language functions and internally-generated queries. The trace is one
|
||||||
|
/// entry per line, most recent first.
|
||||||
|
pub fn where_(&self) -> Option<&str> {
|
||||||
|
self.where_.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If the error was associated with a specific database object, the name
|
||||||
|
/// of the schema containing that object, if any. (PostgreSQL 9.3+)
|
||||||
|
pub fn schema(&self) -> Option<&str> {
|
||||||
|
self.schema.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If the error was associated with a specific table, the name of the
|
||||||
|
/// table. (Refer to the schema name field for the name of the table's
|
||||||
|
/// schema.) (PostgreSQL 9.3+)
|
||||||
|
pub fn table(&self) -> Option<&str> {
|
||||||
|
self.table.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If the error was associated with a specific table column, the name of
|
||||||
|
/// the column.
|
||||||
|
///
|
||||||
|
/// (Refer to the schema and table name fields to identify the table.)
|
||||||
|
/// (PostgreSQL 9.3+)
|
||||||
|
pub fn column(&self) -> Option<&str> {
|
||||||
|
self.column.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If the error was associated with a specific data type, the name of the
|
||||||
|
/// data type. (Refer to the schema name field for the name of the data
|
||||||
|
/// type's schema.) (PostgreSQL 9.3+)
|
||||||
|
pub fn datatype(&self) -> Option<&str> {
|
||||||
|
self.datatype.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If the error was associated with a specific constraint, the name of the
|
||||||
|
/// constraint.
|
||||||
|
///
|
||||||
|
/// Refer to fields listed above for the associated table or domain.
|
||||||
|
/// (For this purpose, indexes are treated as constraints, even if they
|
||||||
|
/// weren't created with constraint syntax.) (PostgreSQL 9.3+)
|
||||||
|
pub fn constraint(&self) -> Option<&str> {
|
||||||
|
self.constraint.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The file name of the source-code location where the error was reported.
|
||||||
|
pub fn file(&self) -> Option<&str> {
|
||||||
|
self.file.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The line number of the source-code location where the error was
|
||||||
|
/// reported.
|
||||||
|
pub fn line(&self) -> Option<u32> {
|
||||||
|
self.line
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The name of the source-code routine reporting the error.
|
||||||
|
pub fn routine(&self) -> Option<&str> {
|
||||||
|
self.routine.as_deref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for DbError {
|
||||||
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(fmt, "{}: {}", self.severity, self.message)?;
|
||||||
|
if let Some(detail) = &self.detail {
|
||||||
|
write!(fmt, "\nDETAIL: {}", detail)?;
|
||||||
|
}
|
||||||
|
if let Some(hint) = &self.hint {
|
||||||
|
write!(fmt, "\nHINT: {}", hint)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl error::Error for DbError {}
|
||||||
5
proxy/src/pg_client/mod.rs
Normal file
5
proxy/src/pg_client/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
|
||||||
|
pub mod codec;
|
||||||
|
pub mod connection;
|
||||||
|
pub mod error;
|
||||||
|
pub mod prepare;
|
||||||
293
proxy/src/pg_client/prepare.rs
Normal file
293
proxy/src/pg_client/prepare.rs
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
use fallible_iterator::FallibleIterator;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use postgres_protocol::message::backend::{DataRowRanges, Message};
|
||||||
|
use postgres_protocol::message::frontend;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tokio_postgres::types::{Field, Kind, Oid, ToSql, Type};
|
||||||
|
|
||||||
|
use super::connection::Connection;
|
||||||
|
use super::error::Error;
|
||||||
|
|
||||||
|
const TYPEINFO_QUERY: &str = "\
|
||||||
|
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
|
||||||
|
FROM pg_catalog.pg_type t
|
||||||
|
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
|
||||||
|
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
|
||||||
|
WHERE t.oid = $1
|
||||||
|
";
|
||||||
|
|
||||||
|
const TYPEINFO_ENUM_QUERY: &str = "\
|
||||||
|
SELECT enumlabel
|
||||||
|
FROM pg_catalog.pg_enum
|
||||||
|
WHERE enumtypid = $1
|
||||||
|
ORDER BY enumsortorder
|
||||||
|
";
|
||||||
|
|
||||||
|
const TYPEINFO_COMPOSITE_QUERY: &str = "\
|
||||||
|
SELECT attname, atttypid
|
||||||
|
FROM pg_catalog.pg_attribute
|
||||||
|
WHERE attrelid = $1
|
||||||
|
AND NOT attisdropped
|
||||||
|
AND attnum > 0
|
||||||
|
ORDER BY attnum
|
||||||
|
";
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TypeinfoPreparedQueries {
|
||||||
|
query: String,
|
||||||
|
enum_query: String,
|
||||||
|
composite_query: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_is_null(x: tokio_postgres::types::IsNull) -> postgres_protocol::IsNull {
|
||||||
|
match x {
|
||||||
|
tokio_postgres::types::IsNull::Yes => postgres_protocol::IsNull::Yes,
|
||||||
|
tokio_postgres::types::IsNull::No => postgres_protocol::IsNull::No,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_column<'a, T: tokio_postgres::types::FromSql<'a>>(
|
||||||
|
buffer: &'a [u8],
|
||||||
|
type_: &Type,
|
||||||
|
ranges: &mut DataRowRanges<'a>,
|
||||||
|
) -> Result<T, Error> {
|
||||||
|
let range = ranges.next()?;
|
||||||
|
match range {
|
||||||
|
Some(range) => T::from_sql_nullable(type_, range.map(|r| &buffer[r])),
|
||||||
|
None => T::from_sql_null(type_),
|
||||||
|
}
|
||||||
|
.map_err(|e| Error::from_sql(e, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TypeinfoPreparedQueries {
|
||||||
|
pub async fn new<
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
>(
|
||||||
|
c: &mut Connection<S, T>,
|
||||||
|
) -> Result<Self, Error> {
|
||||||
|
if let Some(ti) = &c.typeinfo {
|
||||||
|
return Ok(ti.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let query = c.statement_name();
|
||||||
|
let enum_query = c.statement_name();
|
||||||
|
let composite_query = c.statement_name();
|
||||||
|
|
||||||
|
frontend::parse(&query, TYPEINFO_QUERY, [Type::OID.oid()], &mut c.raw.buf)?;
|
||||||
|
frontend::parse(
|
||||||
|
&enum_query,
|
||||||
|
TYPEINFO_ENUM_QUERY,
|
||||||
|
[Type::OID.oid()],
|
||||||
|
&mut c.raw.buf,
|
||||||
|
)?;
|
||||||
|
c.sync().await?;
|
||||||
|
frontend::parse(
|
||||||
|
&composite_query,
|
||||||
|
TYPEINFO_COMPOSITE_QUERY,
|
||||||
|
[Type::OID.oid()],
|
||||||
|
&mut c.raw.buf,
|
||||||
|
)?;
|
||||||
|
c.sync().await?;
|
||||||
|
|
||||||
|
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||||
|
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||||
|
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||||
|
c.wait_for_ready().await?;
|
||||||
|
|
||||||
|
Ok(c.typeinfo
|
||||||
|
.insert(TypeinfoPreparedQueries {
|
||||||
|
query,
|
||||||
|
enum_query,
|
||||||
|
composite_query,
|
||||||
|
})
|
||||||
|
.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_type_rec<
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
>(
|
||||||
|
c: &mut Connection<S, T>,
|
||||||
|
oid: Oid,
|
||||||
|
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + '_>> {
|
||||||
|
Box::pin(Self::get_type(c, oid))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_type<
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
>(
|
||||||
|
c: &mut Connection<S, T>,
|
||||||
|
oid: Oid,
|
||||||
|
) -> Result<Type, Error> {
|
||||||
|
if let Some(type_) = Type::from_oid(oid) {
|
||||||
|
return Ok(type_);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(type_) = c.typecache.get(&oid) {
|
||||||
|
return Ok(type_.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let queries = Self::new(c).await?;
|
||||||
|
|
||||||
|
frontend::bind(
|
||||||
|
"",
|
||||||
|
&queries.query,
|
||||||
|
[1], // the only parameter is in binary format
|
||||||
|
[oid],
|
||||||
|
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
|
||||||
|
Some(1), // binary return type
|
||||||
|
&mut c.raw.buf,
|
||||||
|
)
|
||||||
|
.map_err(|e| match e {
|
||||||
|
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||||
|
frontend::BindError::Serialization(io) => io,
|
||||||
|
})?;
|
||||||
|
frontend::execute("", 0, &mut c.raw.buf)?;
|
||||||
|
|
||||||
|
c.sync().await?;
|
||||||
|
|
||||||
|
let mut stream = c.stream_query_results().await?;
|
||||||
|
|
||||||
|
let Some(row) = stream.next().await.transpose()? else {
|
||||||
|
todo!()
|
||||||
|
};
|
||||||
|
|
||||||
|
let row = row.map_err(Error::db)?;
|
||||||
|
let b = row.buffer();
|
||||||
|
let mut ranges = row.ranges();
|
||||||
|
|
||||||
|
let name: String = read_column(b, &Type::NAME, &mut ranges)?;
|
||||||
|
let type_: i8 = read_column(b, &Type::CHAR, &mut ranges)?;
|
||||||
|
let elem_oid: Oid = read_column(b, &Type::OID, &mut ranges)?;
|
||||||
|
let rngsubtype: Option<Oid> = read_column(b, &Type::OID, &mut ranges)?;
|
||||||
|
let basetype: Oid = read_column(b, &Type::OID, &mut ranges)?;
|
||||||
|
let schema: String = read_column(b, &Type::NAME, &mut ranges)?;
|
||||||
|
let relid: Oid = read_column(b, &Type::OID, &mut ranges)?;
|
||||||
|
|
||||||
|
{
|
||||||
|
// should be none
|
||||||
|
let None = stream.next().await.transpose()? else {
|
||||||
|
todo!()
|
||||||
|
};
|
||||||
|
drop(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
let kind = if type_ == b'e' as i8 {
|
||||||
|
let variants = Self::get_enum_variants(c, oid).await?;
|
||||||
|
Kind::Enum(variants)
|
||||||
|
} else if type_ == b'p' as i8 {
|
||||||
|
Kind::Pseudo
|
||||||
|
} else if basetype != 0 {
|
||||||
|
let type_ = Self::get_type_rec(c, basetype).await?;
|
||||||
|
Kind::Domain(type_)
|
||||||
|
} else if elem_oid != 0 {
|
||||||
|
let type_ = Self::get_type_rec(c, elem_oid).await?;
|
||||||
|
Kind::Array(type_)
|
||||||
|
} else if relid != 0 {
|
||||||
|
let fields = Self::get_composite_fields(c, relid).await?;
|
||||||
|
Kind::Composite(fields)
|
||||||
|
} else if let Some(rngsubtype) = rngsubtype {
|
||||||
|
let type_ = Self::get_type_rec(c, rngsubtype).await?;
|
||||||
|
Kind::Range(type_)
|
||||||
|
} else {
|
||||||
|
Kind::Simple
|
||||||
|
};
|
||||||
|
|
||||||
|
let type_ = Type::new(name, oid, kind, schema);
|
||||||
|
c.typecache.insert(oid, type_.clone());
|
||||||
|
|
||||||
|
Ok(type_)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_enum_variants<
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
>(
|
||||||
|
c: &mut Connection<S, T>,
|
||||||
|
oid: Oid,
|
||||||
|
) -> Result<Vec<String>, Error> {
|
||||||
|
let queries = Self::new(c).await?;
|
||||||
|
|
||||||
|
frontend::bind(
|
||||||
|
"",
|
||||||
|
&queries.enum_query,
|
||||||
|
[1], // the only parameter is in binary format
|
||||||
|
[oid],
|
||||||
|
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
|
||||||
|
Some(1), // binary return type
|
||||||
|
&mut c.raw.buf,
|
||||||
|
)
|
||||||
|
.map_err(|e| match e {
|
||||||
|
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||||
|
frontend::BindError::Serialization(io) => io,
|
||||||
|
})?;
|
||||||
|
frontend::execute("", 0, &mut c.raw.buf)?;
|
||||||
|
|
||||||
|
c.sync().await?;
|
||||||
|
|
||||||
|
let mut stream = c.stream_query_results().await?;
|
||||||
|
let mut variants = Vec::new();
|
||||||
|
while let Some(row) = stream.next().await.transpose()? {
|
||||||
|
let row = row.map_err(Error::db)?;
|
||||||
|
|
||||||
|
let variant: String = read_column(row.buffer(), &Type::NAME, &mut row.ranges())?;
|
||||||
|
variants.push(variant);
|
||||||
|
}
|
||||||
|
|
||||||
|
c.wait_for_ready().await?;
|
||||||
|
|
||||||
|
Ok(variants)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_composite_fields<
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||||
|
>(
|
||||||
|
c: &mut Connection<S, T>,
|
||||||
|
oid: Oid,
|
||||||
|
) -> Result<Vec<Field>, Error> {
|
||||||
|
let queries = Self::new(c).await?;
|
||||||
|
|
||||||
|
frontend::bind(
|
||||||
|
"",
|
||||||
|
&queries.composite_query,
|
||||||
|
[1], // the only parameter is in binary format
|
||||||
|
[oid],
|
||||||
|
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
|
||||||
|
Some(1), // binary return type
|
||||||
|
&mut c.raw.buf,
|
||||||
|
)
|
||||||
|
.map_err(|e| match e {
|
||||||
|
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||||
|
frontend::BindError::Serialization(io) => io,
|
||||||
|
})?;
|
||||||
|
frontend::execute("", 0, &mut c.raw.buf)?;
|
||||||
|
|
||||||
|
c.sync().await?;
|
||||||
|
|
||||||
|
let mut stream = c.stream_query_results().await?;
|
||||||
|
let mut fields = Vec::new();
|
||||||
|
while let Some(row) = stream.next().await.transpose()? {
|
||||||
|
let row = row.map_err(Error::db)?;
|
||||||
|
|
||||||
|
let mut ranges = row.ranges();
|
||||||
|
let name: String = read_column(row.buffer(), &Type::NAME, &mut ranges)?;
|
||||||
|
let oid: Oid = read_column(row.buffer(), &Type::OID, &mut ranges)?;
|
||||||
|
fields.push((name, oid));
|
||||||
|
}
|
||||||
|
|
||||||
|
c.wait_for_ready().await?;
|
||||||
|
|
||||||
|
let mut output_fields = Vec::with_capacity(fields.len());
|
||||||
|
for (name, oid) in fields {
|
||||||
|
let type_ = Self::get_type_rec(c, oid).await?;
|
||||||
|
output_fields.push(Field::new(name, type_))
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(output_fields)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -99,9 +99,8 @@ struct Scram(scram::ServerSecret);
|
|||||||
|
|
||||||
impl Scram {
|
impl Scram {
|
||||||
fn new(password: &str) -> anyhow::Result<Self> {
|
fn new(password: &str) -> anyhow::Result<Self> {
|
||||||
let salt = rand::random::<[u8; 16]>();
|
let secret =
|
||||||
let secret = scram::ServerSecret::build(password, &salt, 256)
|
scram::ServerSecret::build(password).context("failed to generate scram secret")?;
|
||||||
.context("failed to generate scram secret")?;
|
|
||||||
Ok(Scram(secret))
|
Ok(Scram(secret))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,6 @@ mod messages;
|
|||||||
mod secret;
|
mod secret;
|
||||||
mod signature;
|
mod signature;
|
||||||
|
|
||||||
#[cfg(any(test, doc))]
|
|
||||||
mod password;
|
|
||||||
|
|
||||||
pub use exchange::Exchange;
|
pub use exchange::Exchange;
|
||||||
pub use key::ScramKey;
|
pub use key::ScramKey;
|
||||||
pub use secret::ServerSecret;
|
pub use secret::ServerSecret;
|
||||||
@@ -57,27 +54,21 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
|
||||||
|
|
||||||
use crate::sasl::{Mechanism, Step};
|
use crate::sasl::{Mechanism, Step};
|
||||||
|
|
||||||
use super::{password::SaltedPassword, Exchange, ServerSecret};
|
use super::{Exchange, ServerSecret};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn happy_path() {
|
fn snapshot() {
|
||||||
let iterations = 4096;
|
let iterations = 4096;
|
||||||
let salt_base64 = "QSXCR+Q6sek8bf92";
|
let salt = "QSXCR+Q6sek8bf92";
|
||||||
let pw = SaltedPassword::new(
|
let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
|
||||||
b"pencil",
|
let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
|
||||||
base64::decode(salt_base64).unwrap().as_slice(),
|
let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",);
|
||||||
iterations,
|
let secret = ServerSecret::parse(&secret).unwrap();
|
||||||
);
|
|
||||||
|
|
||||||
let secret = ServerSecret {
|
|
||||||
iterations,
|
|
||||||
salt_base64: salt_base64.to_owned(),
|
|
||||||
stored_key: pw.client_key().sha256(),
|
|
||||||
server_key: pw.server_key(),
|
|
||||||
doomed: false,
|
|
||||||
};
|
|
||||||
const NONCE: [u8; 18] = [
|
const NONCE: [u8; 18] = [
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||||
];
|
];
|
||||||
@@ -115,4 +106,40 @@ mod tests {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_round_trip_test(client_password: &str) {
|
||||||
|
let secret = ServerSecret::build("pencil").unwrap();
|
||||||
|
let mut exchange = Exchange::new(&secret, rand::random, None);
|
||||||
|
|
||||||
|
let mut client =
|
||||||
|
ScramSha256::new(client_password.as_bytes(), ChannelBinding::unsupported());
|
||||||
|
|
||||||
|
let client_first = std::str::from_utf8(client.message()).unwrap();
|
||||||
|
exchange = match exchange.exchange(client_first).unwrap() {
|
||||||
|
Step::Continue(exchange, message) => {
|
||||||
|
client.update(message.as_bytes()).unwrap();
|
||||||
|
exchange
|
||||||
|
}
|
||||||
|
Step::Success(_, _) => panic!("expected continue, got success"),
|
||||||
|
Step::Failure(f) => panic!("{f}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let client_final = std::str::from_utf8(client.message()).unwrap();
|
||||||
|
match exchange.exchange(client_final).unwrap() {
|
||||||
|
Step::Success(_, message) => client.finish(message.as_bytes()).unwrap(),
|
||||||
|
Step::Continue(_, _) => panic!("expected success, got continue"),
|
||||||
|
Step::Failure(f) => panic!("{f}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn round_trip() {
|
||||||
|
run_round_trip_test("pencil")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "password doesn't match")]
|
||||||
|
fn failure() {
|
||||||
|
run_round_trip_test("eraser")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
/// Faithfully taken from PostgreSQL.
|
/// Faithfully taken from PostgreSQL.
|
||||||
pub const SCRAM_KEY_LEN: usize = 32;
|
pub const SCRAM_KEY_LEN: usize = 32;
|
||||||
|
|
||||||
/// One of the keys derived from the [password](super::password::SaltedPassword).
|
/// One of the keys derived from the user's password.
|
||||||
/// We use the same structure for all keys, i.e.
|
/// We use the same structure for all keys, i.e.
|
||||||
/// `ClientKey`, `StoredKey`, and `ServerKey`.
|
/// `ClientKey`, `StoredKey`, and `ServerKey`.
|
||||||
#[derive(Default, PartialEq, Eq)]
|
#[derive(Default, PartialEq, Eq)]
|
||||||
|
|||||||
@@ -1,74 +0,0 @@
|
|||||||
//! Password hashing routines.
|
|
||||||
|
|
||||||
use super::key::ScramKey;
|
|
||||||
|
|
||||||
pub const SALTED_PASSWORD_LEN: usize = 32;
|
|
||||||
|
|
||||||
/// Salted hashed password is essential for [key](super::key) derivation.
|
|
||||||
#[repr(transparent)]
|
|
||||||
pub struct SaltedPassword {
|
|
||||||
bytes: [u8; SALTED_PASSWORD_LEN],
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SaltedPassword {
|
|
||||||
/// See `scram-common.c : scram_SaltedPassword` for details.
|
|
||||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
|
|
||||||
pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
|
|
||||||
pbkdf2::pbkdf2_hmac_array::<sha2::Sha256, 32>(password, salt, iterations).into()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Derive `ClientKey` from a salted hashed password.
|
|
||||||
pub fn client_key(&self) -> ScramKey {
|
|
||||||
super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Derive `ServerKey` from a salted hashed password.
|
|
||||||
pub fn server_key(&self) -> ScramKey {
|
|
||||||
super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword {
|
|
||||||
#[inline(always)]
|
|
||||||
fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self {
|
|
||||||
Self { bytes }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::SaltedPassword;
|
|
||||||
|
|
||||||
fn legacy_pbkdf2_impl(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
|
|
||||||
let one = 1_u32.to_be_bytes(); // magic
|
|
||||||
|
|
||||||
let mut current = super::super::hmac_sha256(password, [salt, &one]);
|
|
||||||
let mut result = current;
|
|
||||||
for _ in 1..iterations {
|
|
||||||
current = super::super::hmac_sha256(password, [current.as_ref()]);
|
|
||||||
// TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094
|
|
||||||
for (i, x) in current.iter().enumerate() {
|
|
||||||
result[i] ^= x;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn pbkdf2() {
|
|
||||||
let password = "a-very-secure-password";
|
|
||||||
let salt = "such-a-random-salt";
|
|
||||||
let iterations = 4096;
|
|
||||||
let output = [
|
|
||||||
203, 18, 206, 81, 4, 154, 193, 100, 147, 41, 211, 217, 177, 203, 69, 210, 194, 211,
|
|
||||||
101, 1, 248, 156, 96, 0, 8, 223, 30, 87, 158, 41, 20, 42,
|
|
||||||
];
|
|
||||||
|
|
||||||
let actual = SaltedPassword::new(password.as_bytes(), salt.as_bytes(), iterations);
|
|
||||||
let expected = legacy_pbkdf2_impl(password.as_bytes(), salt.as_bytes(), iterations);
|
|
||||||
|
|
||||||
assert_eq!(actual.bytes, output);
|
|
||||||
assert_eq!(actual.bytes, expected.bytes);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
use super::base64_decode_array;
|
use super::base64_decode_array;
|
||||||
use super::key::ScramKey;
|
use super::key::ScramKey;
|
||||||
|
|
||||||
/// Server secret is produced from [password](super::password::SaltedPassword)
|
/// Server secret is produced from user's password,
|
||||||
/// and is used throughout the authentication process.
|
/// and is used throughout the authentication process.
|
||||||
pub struct ServerSecret {
|
pub struct ServerSecret {
|
||||||
/// Number of iterations for `PBKDF2` function.
|
/// Number of iterations for `PBKDF2` function.
|
||||||
@@ -58,21 +58,10 @@ impl ServerSecret {
|
|||||||
/// Build a new server secret from the prerequisites.
|
/// Build a new server secret from the prerequisites.
|
||||||
/// XXX: We only use this function in tests.
|
/// XXX: We only use this function in tests.
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
|
pub fn build(password: &str) -> Option<Self> {
|
||||||
// TODO: implement proper password normalization required by the RFC
|
Self::parse(&postgres_protocol::password::scram_sha_256(
|
||||||
if !password.is_ascii() {
|
password.as_bytes(),
|
||||||
return None;
|
))
|
||||||
}
|
|
||||||
|
|
||||||
let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);
|
|
||||||
|
|
||||||
Some(Self {
|
|
||||||
iterations,
|
|
||||||
salt_base64: base64::encode(salt),
|
|
||||||
stored_key: password.client_key().sha256(),
|
|
||||||
server_key: password.server_key(),
|
|
||||||
doomed: false,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,20 +91,4 @@ mod tests {
|
|||||||
assert_eq!(base64::encode(parsed.stored_key), stored_key);
|
assert_eq!(base64::encode(parsed.stored_key), stored_key);
|
||||||
assert_eq!(base64::encode(parsed.server_key), server_key);
|
assert_eq!(base64::encode(parsed.server_key), server_key);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn build_scram_secret() {
|
|
||||||
let salt = b"salt";
|
|
||||||
let secret = ServerSecret::build("password", salt, 4096).unwrap();
|
|
||||||
assert_eq!(secret.iterations, 4096);
|
|
||||||
assert_eq!(secret.salt_base64, base64::encode(salt));
|
|
||||||
assert_eq!(
|
|
||||||
base64::encode(secret.stored_key.as_ref()),
|
|
||||||
"lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
base64::encode(secret.server_key.as_ref()),
|
|
||||||
"ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user