Mirror of https://github.com/neondatabase/neon.git, synced 2026-01-18 19:02:56 +00:00
Compare commits
10 Commits: conrad/pro… ... proxy/remo…
| Author | SHA1 | Date |
|---|---|---|
| | 101e770632 | |
| | 747ffa50d6 | |
| | e8400d9d93 | |
| | 17627e8023 | |
| | 4fb5cdbdb8 | |
| | 70b503f83b | |
| | 21c15c4285 | |
| | 0a524e09a5 | |
| | cbe24f7c35 | |
| | 64add503c8 | |
14 Cargo.lock (generated)
@@ -2654,16 +2654,6 @@ dependencies = [
  "windows-sys 0.45.0",
 ]
 
-[[package]]
-name = "pbkdf2"
-version = "0.12.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f0ca0b5a68607598bf3bad68f32227a8164f6254833f84eafaac409cd6746c31"
-dependencies = [
- "digest",
- "hmac",
-]
-
 [[package]]
 name = "peeking_take_while"
 version = "0.1.2"
@@ -3040,6 +3030,7 @@ dependencies = [
  "chrono",
  "clap",
  "consumption_metrics",
+ "fallible-iterator",
  "futures",
  "git-version",
  "hashbrown 0.13.2",
@@ -3057,9 +3048,9 @@ dependencies = [
  "once_cell",
  "opentelemetry",
  "parking_lot 0.12.1",
- "pbkdf2",
  "pin-project-lite",
  "postgres-native-tls",
+ "postgres-protocol",
  "postgres_backend",
  "pq_proto",
  "prometheus",
@@ -3083,6 +3074,7 @@ dependencies = [
  "thiserror",
  "tls-listener",
  "tokio",
+ "tokio-native-tls",
  "tokio-postgres",
  "tokio-postgres-rustls",
  "tokio-rustls 0.23.4",
@@ -88,7 +88,6 @@ opentelemetry = "0.19.0"
 opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
 opentelemetry-semantic-conventions = "0.11.0"
 parking_lot = "0.12"
-pbkdf2 = "0.12.1"
 pin-project-lite = "0.2"
 prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
 prost = "0.11"
@@ -29,9 +29,9 @@ metrics.workspace = true
 once_cell.workspace = true
 opentelemetry.workspace = true
 parking_lot.workspace = true
-pbkdf2.workspace = true
 pin-project-lite.workspace = true
 postgres_backend.workspace = true
+postgres-protocol.workspace = true
 pq_proto.workspace = true
 prometheus.workspace = true
 rand.workspace = true
@@ -65,10 +65,13 @@ webpki-roots.workspace = true
 x509-parser.workspace = true
 native-tls.workspace = true
 postgres-native-tls.workspace = true
+tokio-native-tls = "0.3.1"
 
 workspace_hack.workspace = true
 tokio-util.workspace = true
 
+fallible-iterator = "0.2.0"
+
 [dev-dependencies]
 rcgen.workspace = true
 rstest.workspace = true
@@ -6,7 +6,7 @@ use std::fmt;
 use std::{collections::HashMap, sync::Arc};
 use tokio::time;
 
-use crate::{auth, console};
+use crate::{auth, console, pg_client};
 use crate::{compute, config};
 
 use super::sql_over_http::MAX_RESPONSE_SIZE;
@@ -41,8 +41,10 @@ impl fmt::Display for ConnInfo {
     }
 }
 
+type PgConn =
+    pg_client::connection::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>;
 struct ConnPoolEntry {
-    conn: tokio_postgres::Client,
+    conn: PgConn,
     _last_access: std::time::Instant,
 }
 
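The pool entry now holds the proxy's own `PgConn` instead of a `tokio_postgres::Client`. For orientation, here is a minimal std-only sketch of the pool shape this implies; the names, the placeholder `Conn`, and the hostname are ours, not the proxy's actual API:

use std::collections::HashMap;
use std::time::Instant;

struct Conn; // stand-in for pg_client::connection::Connection<…>

struct ConnPoolEntry {
    conn: Conn,
    last_access: Instant, // lets an eviction pass drop stale entries
}

#[derive(Default)]
struct GlobalPool {
    by_hostname: HashMap<String, Vec<ConnPoolEntry>>,
}

impl GlobalPool {
    fn put(&mut self, hostname: &str, conn: Conn) {
        self.by_hostname.entry(hostname.to_owned()).or_default().push(ConnPoolEntry {
            conn,
            last_access: Instant::now(),
        });
    }

    fn get(&mut self, hostname: &str) -> Option<Conn> {
        self.by_hostname.get_mut(hostname)?.pop().map(|e| e.conn)
    }
}

fn main() {
    let mut pool = GlobalPool::default();
    pool.put("ep-example-123456", Conn); // hostname is made up
    assert!(pool.get("ep-example-123456").is_some());
    assert!(pool.get("ep-other").is_none());
}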
@@ -78,12 +80,8 @@ impl GlobalConnPool {
         })
     }
 
-    pub async fn get(
-        &self,
-        conn_info: &ConnInfo,
-        force_new: bool,
-    ) -> anyhow::Result<tokio_postgres::Client> {
-        let mut client: Option<tokio_postgres::Client> = None;
+    pub async fn get(&self, conn_info: &ConnInfo, force_new: bool) -> anyhow::Result<PgConn> {
+        let mut client: Option<PgConn> = None;
 
         if !force_new {
             let pool = self.get_endpoint_pool(&conn_info.hostname).await;
@@ -114,11 +112,7 @@ impl GlobalConnPool {
         }
     }
 
-    pub async fn put(
-        &self,
-        conn_info: &ConnInfo,
-        client: tokio_postgres::Client,
-    ) -> anyhow::Result<()> {
+    pub async fn put(&self, conn_info: &ConnInfo, client: PgConn) -> anyhow::Result<()> {
         let pool = self.get_endpoint_pool(&conn_info.hostname).await;
 
         // return connection to the pool
@@ -191,7 +185,7 @@ struct TokioMechanism<'a> {
 
 #[async_trait]
 impl ConnectMechanism for TokioMechanism<'_> {
-    type Connection = tokio_postgres::Client;
+    type Connection = PgConn;
     type ConnectError = tokio_postgres::Error;
     type Error = anyhow::Error;
 
@@ -213,7 +207,7 @@ impl ConnectMechanism for TokioMechanism<'_> {
 async fn connect_to_compute(
     config: &config::ProxyConfig,
     conn_info: &ConnInfo,
-) -> anyhow::Result<tokio_postgres::Client> {
+) -> anyhow::Result<PgConn> {
     let tls = config.tls_config.as_ref();
     let common_names = tls.and_then(|tls| tls.common_names.clone());
 
@@ -251,7 +245,7 @@ async fn connect_to_compute_once(
     node_info: &console::CachedNodeInfo,
     conn_info: &ConnInfo,
     timeout: time::Duration,
-) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
+) -> Result<PgConn, tokio_postgres::Error> {
     let mut config = (*node_info.config).clone();
 
     let (client, connection) = config
@@ -263,11 +257,13 @@ async fn connect_to_compute_once(
         .connect(tokio_postgres::NoTls)
         .await?;
 
-    tokio::spawn(async move {
-        if let Err(e) = connection.await {
-            error!("connection error: {}", e);
-        }
-    });
+    let stream = connection.stream.into_inner();
 
-    Ok(client)
+    // tokio::spawn(async move {
+    //     if let Err(e) = connection.await {
+    //         error!("connection error: {}", e);
+    //     }
+    // });
+
+    Ok(PgConn::new(stream))
 }
@@ -1,6 +1,9 @@
 use std::io::ErrorKind;
 use std::sync::Arc;
 
 use anyhow::bail;
+use bytes::BufMut;
+use fallible_iterator::FallibleIterator;
 use futures::pin_mut;
+use futures::StreamExt;
 use hashbrown::HashMap;
@@ -8,16 +11,28 @@ use hyper::body::HttpBody;
 use hyper::http::HeaderName;
 use hyper::http::HeaderValue;
 use hyper::{Body, HeaderMap, Request};
+use postgres_protocol::message::backend::DataRowBody;
+use postgres_protocol::message::backend::ReadyForQueryBody;
 use serde_json::json;
 use serde_json::Map;
 use serde_json::Value;
+use tokio::io::AsyncRead;
+use tokio::io::AsyncWrite;
 use tokio_postgres::types::Kind;
 use tokio_postgres::types::Type;
 use tokio_postgres::GenericClient;
 use tokio_postgres::IsolationLevel;
 use tokio_postgres::Row;
 use tokio_postgres::RowStream;
 use tokio_postgres::Statement;
 use url::Url;
 
+use crate::pg_client;
+use crate::pg_client::codec::FrontendMessage;
+use crate::pg_client::connection;
+use crate::pg_client::connection::RequestMessages;
+use crate::pg_client::prepare::TypeinfoPreparedQueries;
+
 use super::conn_pool::ConnInfo;
 use super::conn_pool::GlobalConnPool;
@@ -230,30 +245,35 @@ pub async fn handle(
     // Now execute the query and return the result
     //
     let result = match payload {
-        Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode)
+        Payload::Single(query) => query_raw_txt_as_json(&mut client, query, raw_output, array_mode)
             .await
             .map(|x| (x, HashMap::default())),
         Payload::Batch(queries) => {
             let mut results = Vec::new();
-            let mut builder = client.build_transaction();
-            if let Some(isolation_level) = txn_isolation_level {
-                builder = builder.isolation_level(isolation_level);
-            }
-            if txn_read_only {
-                builder = builder.read_only(true);
-            }
-            let transaction = builder.start().await?;
+
+            client
+                .start_tx(txn_isolation_level, Some(txn_read_only))
+                .await?;
 
             for query in queries {
-                let result = query_to_json(&transaction, query, raw_output, array_mode).await;
+                let result =
+                    query_raw_txt_as_json(&mut client, query, raw_output, array_mode).await;
                 match result {
-                    Ok(r) => results.push(r),
+                    // TODO: check this tag to see if the client has executed a commit during the non-interactive transactions...
+                    Ok((r, _ready_tag)) => results.push(r),
                     Err(e) => {
-                        transaction.rollback().await?;
+                        let tag = client.rollback().await?;
+                        if allow_pool && tag.status() == b'I' {
+                            // return connection to the pool
+                            tokio::task::spawn(async move {
+                                let _ = conn_pool.put(&conn_info, client).await;
+                            });
+                        }
                         return Err(e);
                     }
                 }
             }
-            transaction.commit().await?;
+            let ready_tag = client.commit().await?;
             let mut headers = HashMap::default();
             headers.insert(
                 TXN_READ_ONLY.clone(),
@@ -262,11 +282,11 @@ pub async fn handle(
             if let Some(txn_isolation_level_raw) = txn_isolation_level_raw {
                 headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level_raw);
             }
-            Ok((json!({ "results": results }), headers))
+            Ok(((json!({ "results": results }), ready_tag), headers))
         }
     };
 
-    if allow_pool {
+    if allow_pool && ready_tag.status() == b'I' {
         // return connection to the pool
         tokio::task::spawn(async move {
            let _ = conn_pool.put(&conn_info, client).await;
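The `ready_tag.status() == b'I'` guards key off the transaction-status byte of the protocol's ReadyForQuery message, whose semantics are fixed by the PostgreSQL wire protocol: 'I' means idle, 'T' means inside a transaction block, 'E' means inside a failed transaction block. Only idle connections may be returned to the pool. A tiny self-contained illustration (the helper name is ours):

fn can_return_to_pool(status: u8) -> bool {
    match status {
        b'I' => true,         // idle: no open transaction, safe to reuse
        b'T' | b'E' => false, // a transaction is still open (or aborted): must not pool
        _ => false,           // anything else is unexpected; be conservative
    }
}

fn main() {
    assert!(can_return_to_pool(b'I'));
    assert!(!can_return_to_pool(b'T'));
    assert!(!can_return_to_pool(b'E'));
}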
@@ -351,6 +371,99 @@ async fn query_to_json<T: GenericClient>(
     }))
 }
 
+async fn query_raw_txt_as_json<'a, St, T>(
+    conn: &mut connection::Connection<St, T>,
+    data: QueryData,
+    raw_output: bool,
+    array_mode: bool,
+) -> anyhow::Result<(Value, ReadyForQueryBody)>
+where
+    St: AsyncRead + AsyncWrite + Unpin + Send,
+    T: AsyncRead + AsyncWrite + Unpin + Send,
+{
+    let params = json_to_pg_text(data.params)?;
+    let params = params.into_iter();
+
+    let stmt_name = conn.statement_name();
+    let row_description = conn.prepare(&stmt_name, &data.query).await?;
+
+    let mut fields = vec![];
+    let mut columns = vec![];
+    let mut it = row_description.fields();
+    while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? {
+        fields.push(json!({
+            "name": Value::String(field.name().to_owned()),
+            "dataTypeID": Value::Number(field.type_oid().into()),
+            "tableID": field.table_oid(),
+            "columnID": field.column_id(),
+            "dataTypeSize": field.type_size(),
+            "dataTypeModifier": field.type_modifier(),
+            "format": "text",
+        }));
+
+        let type_ = match Type::from_oid(field.type_oid()) {
+            Some(t) => t,
+            None => TypeinfoPreparedQueries::get_type(conn, field.type_oid()).await?,
+        };
+
+        columns.push(Column {
+            name: field.name().to_string(),
+            type_,
+        });
+    }
+
+    conn.execute("", &stmt_name, params)?;
+    conn.sync().await?;
+
+    let mut rows = vec![];
+
+    let mut row_stream = conn.stream_query_results().await?;
+
+    let mut current_size = 0;
+    while let Some(row) = row_stream.next().await.transpose()? {
+        // let row = row.map_err(Error::db)?;
+
+        current_size += row.buffer().len();
+        if current_size > MAX_RESPONSE_SIZE {
+            return Err(anyhow::anyhow!("response too large"));
+        }
+
+        rows.push(pg_text_row_to_json2(&row, &columns, raw_output, array_mode).unwrap());
+    }
+
+    let command_tag = row_stream.tag();
+    let command_tag = command_tag.tag()?;
+    let mut command_tag_split = command_tag.split(' ');
+    let command_tag_name = command_tag_split.next().unwrap_or_default();
+    let command_tag_count = if command_tag_name == "INSERT" {
+        // INSERT returns OID first and then number of rows
+        command_tag_split.nth(1)
+    } else {
+        // other commands return number of rows (if any)
+        command_tag_split.next()
+    }
+    .and_then(|s| s.parse::<i64>().ok());
+
+    let ready_tag = conn.wait_for_ready().await?;
+
+    // resulting JSON format is based on the format of node-postgres result
+    Ok((
+        json!({
+            "command": command_tag_name,
+            "rowCount": command_tag_count,
+            "rows": rows,
+            "fields": fields,
+            "rowAsArray": array_mode,
+        }),
+        ready_tag,
+    ))
+}
+
+struct Column {
+    name: String,
+    type_: Type,
+}
+
 //
 // Convert postgres row with text-encoded values to JSON object
 //
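The CommandComplete tag parsing above ("INSERT" reports an OID before the row count; other commands report the count first, if any) is easy to check in isolation. A std-only sketch of the same logic (the function name is ours):

fn row_count_from_tag(tag: &str) -> Option<i64> {
    let mut parts = tag.split(' ');
    let name = parts.next().unwrap_or_default();
    let count = if name == "INSERT" {
        // INSERT returns "INSERT <oid> <rows>": skip the OID
        parts.nth(1)
    } else {
        // other commands return "<NAME> <rows>" (if any)
        parts.next()
    };
    count.and_then(|s| s.parse::<i64>().ok())
}

fn main() {
    assert_eq!(row_count_from_tag("INSERT 0 5"), Some(5));
    assert_eq!(row_count_from_tag("SELECT 42"), Some(42));
    assert_eq!(row_count_from_tag("BEGIN"), None);
}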
@@ -370,7 +483,7 @@ pub fn pg_text_row_to_json(
         } else {
             pg_text_to_json(pg_value, column.type_())?
         };
-        Ok((name.to_string(), json_value))
+        Ok((name, json_value))
     });
 
     if array_mode {
@@ -380,7 +493,55 @@ pub fn pg_text_row_to_json(
             .collect::<Result<Vec<Value>, anyhow::Error>>()?;
         Ok(Value::Array(arr))
     } else {
-        let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
+        let obj = iter
+            .map(|r| r.map(|(key, val)| (key.to_owned(), val)))
+            .collect::<Result<Map<String, Value>, anyhow::Error>>()?;
         Ok(Value::Object(obj))
     }
 }
+
+//
+// Convert postgres row with text-encoded values to JSON object
+//
+fn pg_text_row_to_json2(
+    row: &DataRowBody,
+    columns: &[Column],
+    raw_output: bool,
+    array_mode: bool,
+) -> Result<Value, anyhow::Error> {
+    let ranges: Vec<Option<std::ops::Range<usize>>> = row.ranges().collect()?;
+    let iter = std::iter::zip(ranges, columns)
+        .enumerate()
+        .map(|(i, (range, column))| {
+            let name = &column.name;
+            let pg_value = range
+                .map(|r| {
+                    std::str::from_utf8(&row.buffer()[r])
+                        .map_err(|e| pg_client::error::Error::from_sql(e.into(), i))
+                })
+                .transpose()?;
+            // let pg_value = row.as_text(i)?;
+            let json_value = if raw_output {
+                match pg_value {
+                    Some(v) => Value::String(v.to_string()),
+                    None => Value::Null,
+                }
+            } else {
+                pg_text_to_json(pg_value, &column.type_)?
+            };
+            Ok((name, json_value))
+        });
+
+    if array_mode {
+        // drop keys and aggregate into array
+        let arr = iter
+            .map(|r| r.map(|(_key, val)| val))
+            .collect::<Result<Vec<Value>, anyhow::Error>>()?;
+        Ok(Value::Array(arr))
+    } else {
+        let obj = iter
+            .map(|r| r.map(|(key, val)| (key.to_owned(), val)))
+            .collect::<Result<Map<String, Value>, anyhow::Error>>()?;
+        Ok(Value::Object(obj))
+    }
+}
@@ -391,16 +552,16 @@ pub fn pg_text_row_to_json(
 pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
     if let Some(val) = pg_value {
         if let Kind::Array(elem_type) = pg_type.kind() {
-            return pg_array_parse(val, elem_type);
+            return pg_array_parse(val, &elem_type);
         }
 
-        match *pg_type {
-            Type::BOOL => Ok(Value::Bool(val == "t")),
-            Type::INT2 | Type::INT4 => {
+        match pg_type {
+            &Type::BOOL => Ok(Value::Bool(val == "t")),
+            &Type::INT2 | &Type::INT4 => {
                 let val = val.parse::<i32>()?;
                 Ok(Value::Number(serde_json::Number::from(val)))
             }
-            Type::FLOAT4 | Type::FLOAT8 => {
+            &Type::FLOAT4 | &Type::FLOAT8 => {
                 let fval = val.parse::<f64>()?;
                 let num = serde_json::Number::from_f64(fval);
                 if let Some(num) = num {
@@ -412,7 +573,7 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value,
                     Ok(Value::String(val.to_string()))
                 }
             }
-            Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
+            &Type::JSON | &Type::JSONB => Ok(serde_json::from_str(val)?),
             _ => Ok(Value::String(val.to_string())),
         }
     } else {
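For reference, the text-to-JSON mapping above restated as a standalone sketch (serde_json only; `pg_type` is reduced to a plain string tag here, and NaN/Infinity fall back to strings, as the `from_f64` branch implies):

use serde_json::Value;

fn text_to_json(val: Option<&str>, pg_type: &str) -> Result<Value, Box<dyn std::error::Error>> {
    let Some(val) = val else { return Ok(Value::Null) };
    Ok(match pg_type {
        "bool" => Value::Bool(val == "t"),
        "int2" | "int4" => Value::Number(val.parse::<i32>()?.into()),
        "float4" | "float8" => serde_json::Number::from_f64(val.parse::<f64>()?)
            .map(Value::Number)
            // NaN/Infinity have no JSON number form; fall back to a string
            .unwrap_or_else(|| Value::String(val.to_string())),
        "json" | "jsonb" => serde_json::from_str(val)?,
        _ => Value::String(val.to_string()),
    })
}

fn main() {
    assert_eq!(text_to_json(Some("t"), "bool").unwrap(), Value::Bool(true));
    assert_eq!(text_to_json(Some("42"), "int4").unwrap(), Value::from(42));
    assert_eq!(text_to_json(Some(r#"{"a":1}"#), "jsonb").unwrap()["a"], Value::from(1));
    assert_eq!(text_to_json(None, "text").unwrap(), Value::Null);
}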
@@ -22,6 +22,7 @@ pub mod scram;
 pub mod stream;
 pub mod url;
 pub mod waiters;
+pub mod pg_client;
 
 /// Handle unix signals appropriately.
 pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
43 proxy/src/pg_client/codec.rs (new file)
@@ -0,0 +1,43 @@
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::{self, Message};
use std::io;
use tokio_util::codec::{Decoder, Encoder};

pub struct FrontendMessage(pub Bytes);
pub struct BackendMessages(pub BytesMut);

impl BackendMessages {
    pub fn empty() -> BackendMessages {
        BackendMessages(BytesMut::new())
    }
}

impl FallibleIterator for BackendMessages {
    type Item = backend::Message;
    type Error = io::Error;

    fn next(&mut self) -> io::Result<Option<backend::Message>> {
        backend::Message::parse(&mut self.0)
    }
}

pub struct PostgresCodec;

impl Encoder<FrontendMessage> for PostgresCodec {
    type Error = io::Error;

    fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
        dst.extend_from_slice(&item.0);
        Ok(())
    }
}

impl Decoder for PostgresCodec {
    type Item = Message;
    type Error = io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, io::Error> {
        Message::parse(src)
    }
}
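The decode half of `PostgresCodec` is just `Message::parse` over the buffered bytes. A quick standalone check, hand-assembling a ReadyForQuery message (tag 'Z', 4-byte big-endian length that includes itself, 1-byte status) with the bytes and postgres-protocol crates this branch already depends on:

use bytes::{BufMut, BytesMut};
use postgres_protocol::message::backend::Message;

fn main() -> std::io::Result<()> {
    // ReadyForQuery on the wire: 'Z', length 5 (4 for the length field + 1 body byte), status.
    let mut buf = BytesMut::new();
    buf.put_u8(b'Z');
    buf.put_i32(5);
    buf.put_u8(b'I'); // 'I' = idle

    match Message::parse(&mut buf)? {
        Some(Message::ReadyForQuery(body)) => assert_eq!(body.status(), b'I'),
        _ => panic!("expected a complete ReadyForQuery message"),
    }
    Ok(())
}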
369 proxy/src/pg_client/connection.rs (new file)
@@ -0,0 +1,369 @@
use super::codec::{BackendMessages, FrontendMessage, PostgresCodec};
use super::error::Error;
use super::prepare::TypeinfoPreparedQueries;
use bytes::{BufMut, BytesMut};
use futures::channel::mpsc;
use futures::{Sink, StreamExt};
use futures::{SinkExt, Stream};
use hashbrown::HashMap;
use postgres_protocol::message::backend::{
    BackendKeyDataBody, CommandCompleteBody, DataRowBody, ErrorResponseBody, Message,
    ReadyForQueryBody, RowDescriptionBody,
};
use postgres_protocol::message::frontend;
use postgres_protocol::Oid;
use std::collections::VecDeque;
use std::future::poll_fn;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
use tokio_postgres::types::Type;
use tokio_postgres::IsolationLevel;
use tokio_util::codec::Framed;

pub enum RequestMessages {
    Single(FrontendMessage),
}

pub struct Request {
    pub messages: RequestMessages,
    pub sender: mpsc::Sender<BackendMessages>,
}

pub struct Response {
    sender: mpsc::Sender<BackendMessages>,
}

/// A connection to a PostgreSQL database.
pub struct RawConnection<S, T> {
    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    pending_responses: VecDeque<Message>,
    pub buf: BytesMut,
}

impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> RawConnection<S, T> {
    pub fn new(
        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
        buf: BytesMut,
    ) -> RawConnection<S, T> {
        RawConnection {
            stream,
            pending_responses: VecDeque::new(),
            buf,
        }
    }

    pub async fn send(&mut self) -> Result<(), Error> {
        poll_fn(|cx| self.poll_send(cx)).await?;
        let request = FrontendMessage(self.buf.split().freeze());
        self.stream.start_send_unpin(request).map_err(Error::io)?;
        poll_fn(|cx| self.poll_flush(cx)).await
    }

    pub async fn next_message(&mut self) -> Result<Message, Error> {
        match self.pending_responses.pop_front() {
            Some(message) => Ok(message),
            None => poll_fn(|cx| self.poll_read(cx)).await,
        }
    }

    fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
        let message = match ready!(self.stream.poll_next_unpin(cx)?) {
            Some(message) => message,
            None => return Poll::Ready(Err(Error::closed())),
        };
        Poll::Ready(Ok(message))
    }

    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Pin::new(&mut self.stream).poll_close(cx).map_err(Error::io)
    }

    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        if let Poll::Ready(msg) = self.poll_read(cx)? {
            self.pending_responses.push_back(msg);
        };
        self.stream.poll_ready_unpin(cx).map_err(Error::io)
    }

    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        if let Poll::Ready(msg) = self.poll_read(cx)? {
            self.pending_responses.push_back(msg);
        };
        self.stream.poll_flush_unpin(cx).map_err(Error::io)
    }
}

pub struct Connection<S, T> {
    stmt_counter: usize,
    pub typeinfo: Option<TypeinfoPreparedQueries>,
    pub typecache: HashMap<Oid, Type>,
    pub raw: RawConnection<S, T>,
    // key: BackendKeyDataBody,
}

impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Connection<S, T> {
    pub fn new(stream: MaybeTlsStream<S, T>) -> Connection<S, T> {
        Connection {
            stmt_counter: 0,
            typeinfo: None,
            typecache: HashMap::new(),
            raw: RawConnection::new(Framed::new(stream, PostgresCodec), BytesMut::new()),
        }
    }

    pub async fn start_tx(
        &mut self,
        isolation_level: Option<IsolationLevel>,
        read_only: Option<bool>,
    ) -> Result<ReadyForQueryBody, Error> {
        let mut query = "START TRANSACTION".to_string();
        let mut first = true;

        if let Some(level) = isolation_level {
            first = false;

            query.push_str(" ISOLATION LEVEL ");
            let level = match level {
                IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
                IsolationLevel::ReadCommitted => "READ COMMITTED",
                IsolationLevel::RepeatableRead => "REPEATABLE READ",
                IsolationLevel::Serializable => "SERIALIZABLE",
                _ => return Err(Error::unexpected_message()),
            };
            query.push_str(level);
        }

        if let Some(read_only) = read_only {
            if !first {
                query.push(',');
            }
            first = false;

            let s = if read_only {
                " READ ONLY"
            } else {
                " READ WRITE"
            };
            query.push_str(s);
        }

        self.execute_simple(&query).await
    }

    pub async fn rollback(&mut self) -> Result<ReadyForQueryBody, Error> {
        self.execute_simple("ROLLBACK").await
    }

    pub async fn commit(&mut self) -> Result<ReadyForQueryBody, Error> {
        self.execute_simple("COMMIT").await
    }

    // pub async fn auth_sasl_scram<'a, I>(
    //     mut raw: RawConnection<S, T>,
    //     params: I,
    //     password: &[u8],
    // ) -> Result<Self, Error>
    // where
    //     I: IntoIterator<Item = (&'a str, &'a str)>,
    // {
    //     // send a startup message
    //     frontend::startup_message(params, &mut raw.buf).unwrap();
    //     raw.send().await?;

    //     // expect sasl authentication message
    //     let Message::AuthenticationSasl(body) = raw.next_message().await? else { return Err(Error::expecting("sasl authentication")) };
    //     // expect support for SCRAM_SHA_256
    //     if body
    //         .mechanisms()
    //         .find(|&x| Ok(x == authentication::sasl::SCRAM_SHA_256))?
    //         .is_none()
    //     {
    //         return Err(Error::expecting("SCRAM-SHA-256 auth"));
    //     }

    //     // initiate SCRAM_SHA_256 authentication without channel binding
    //     let auth = authentication::sasl::ChannelBinding::unrequested();
    //     let mut scram = authentication::sasl::ScramSha256::new(password, auth);

    //     frontend::sasl_initial_response(
    //         authentication::sasl::SCRAM_SHA_256,
    //         scram.message(),
    //         &mut raw.buf,
    //     )
    //     .unwrap();
    //     raw.send().await?;

    //     // expect sasl continue
    //     let Message::AuthenticationSaslContinue(b) = raw.next_message().await? else { return Err(Error::expecting("auth continue")) };
    //     scram.update(b.data()).unwrap();

    //     // continue sasl
    //     frontend::sasl_response(scram.message(), &mut raw.buf).unwrap();
    //     raw.send().await?;

    //     // expect sasl final
    //     let Message::AuthenticationSaslFinal(b) = raw.next_message().await? else { return Err(Error::expecting("auth final")) };
    //     scram.finish(b.data()).unwrap();

    //     // expect auth ok
    //     let Message::AuthenticationOk = raw.next_message().await? else { return Err(Error::expecting("auth ok")) };

    //     // expect connection accepted
    //     let key = loop {
    //         match raw.next_message().await? {
    //             Message::BackendKeyData(key) => break key,
    //             Message::ParameterStatus(_) => {}
    //             _ => return Err(Error::expecting("backend ready")),
    //         }
    //     };

    //     let Message::ReadyForQuery(b) = raw.next_message().await? else { return Err(Error::expecting("ready for query")) };
    //     // assert_eq!(b.status(), b'I');

    //     Ok(Self { raw, key })
    // }

    // pub fn prepare_and_execute(
    //     &mut self,
    //     portal: &str,
    //     name: &str,
    //     query: &str,
    //     params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
    // ) -> std::io::Result<()> {
    //     self.prepare(name, query)?;
    //     self.execute(portal, name, params)
    // }

    pub fn statement_name(&mut self) -> String {
        self.stmt_counter += 1;
        format!("s{}", self.stmt_counter)
    }

    async fn execute_simple(&mut self, query: &str) -> Result<ReadyForQueryBody, Error> {
        frontend::query(query, &mut self.raw.buf)?;
        self.raw.send().await?;

        loop {
            match self.raw.next_message().await? {
                Message::ReadyForQuery(q) => return Ok(q),
                Message::CommandComplete(_)
                | Message::EmptyQueryResponse
                | Message::RowDescription(_)
                | Message::DataRow(_) => {}
                _ => return Err(Error::unexpected_message()),
            }
        }
    }

    pub async fn prepare(&mut self, name: &str, query: &str) -> Result<RowDescriptionBody, Error> {
        frontend::parse(name, query, std::iter::empty(), &mut self.raw.buf)?;
        frontend::describe(b'S', name, &mut self.raw.buf)?;
        self.sync().await?;
        self.wait_for_prepare().await
    }

    pub fn execute(
        &mut self,
        portal: &str,
        name: &str,
        params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
    ) -> std::io::Result<()> {
        frontend::bind(
            portal,
            name,
            std::iter::empty(), // all parameters use the default format (text)
            params,
            |param, buf| match param {
                Some(param) => {
                    buf.put_slice(param.as_ref().as_bytes());
                    Ok(postgres_protocol::IsNull::No)
                }
                None => Ok(postgres_protocol::IsNull::Yes),
            },
            Some(0), // all text
            &mut self.raw.buf,
        )
        .map_err(|e| match e {
            frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
            frontend::BindError::Serialization(io) => io,
        })?;
        frontend::execute(portal, 0, &mut self.raw.buf)
    }

    pub async fn sync(&mut self) -> Result<(), Error> {
        frontend::sync(&mut self.raw.buf);
        self.raw.send().await
    }

    pub async fn wait_for_prepare(&mut self) -> Result<RowDescriptionBody, Error> {
        let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
        let Message::ParameterDescription(_) = self.raw.next_message().await? else { return Err(Error::expecting("param description")) };
        let Message::RowDescription(desc) = self.raw.next_message().await? else { return Err(Error::expecting("row description")) };

        self.wait_for_ready().await?;

        Ok(desc)
    }

    pub async fn stream_query_results(&mut self) -> Result<RowStream<'_, S, T>, Error> {
        // let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
        let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) };
        Ok(RowStream::Stream(&mut self.raw))
    }

    pub async fn wait_for_ready(&mut self) -> Result<ReadyForQueryBody, Error> {
        loop {
            match self.raw.next_message().await.unwrap() {
                Message::ReadyForQuery(b) => break Ok(b),
                _ => continue,
            }
        }
    }
}

pub enum RowStream<'a, S, T> {
    Stream(&'a mut RawConnection<S, T>),
    Complete(Option<CommandCompleteBody>),
}
impl<S, T> Unpin for RowStream<'_, S, T> {}

impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Stream
    for RowStream<'_, S, T>
{
    // this is horrible - the first result is for transport/protocol errors,
    // the second result is for sql errors.
    type Item = Result<Result<DataRowBody, ErrorResponseBody>, Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match &mut *self {
            RowStream::Stream(raw) => match ready!(raw.poll_read(cx)?) {
                Message::DataRow(row) => Poll::Ready(Some(Ok(Ok(row)))),
                Message::CommandComplete(tag) => {
                    *self = Self::Complete(Some(tag));
                    Poll::Ready(None)
                }
                Message::EmptyQueryResponse | Message::PortalSuspended => {
                    *self = Self::Complete(None);
                    Poll::Ready(None)
                }
                Message::ErrorResponse(error) => {
                    *self = Self::Complete(None);
                    Poll::Ready(Some(Ok(Err(error))))
                }
                _ => Poll::Ready(Some(Err(Error::expecting("command completion")))),
            },
            RowStream::Complete(_) => Poll::Ready(None),
        }
    }
}

impl<S, T> RowStream<'_, S, T> {
    pub fn tag(self) -> Option<CommandCompleteBody> {
        match self {
            RowStream::Stream(_) => panic!("should not get tag unless row stream is exhausted"),
            RowStream::Complete(tag) => tag,
        }
    }
}
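Note that `start_tx` builds the `START TRANSACTION` statement textually instead of going through `tokio_postgres`'s transaction builder. A std-only sketch of the same string construction and the output it produces:

fn start_tx_sql(isolation: Option<&str>, read_only: Option<bool>) -> String {
    let mut q = String::from("START TRANSACTION");
    let mut first = true;
    if let Some(level) = isolation {
        q.push_str(" ISOLATION LEVEL ");
        q.push_str(level);
        first = false;
    }
    if let Some(ro) = read_only {
        // clauses after the first are comma-separated
        if !first {
            q.push(',');
        }
        q.push_str(if ro { " READ ONLY" } else { " READ WRITE" });
    }
    q
}

fn main() {
    assert_eq!(
        start_tx_sql(Some("SERIALIZABLE"), Some(true)),
        "START TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY"
    );
    assert_eq!(start_tx_sql(None, None), "START TRANSACTION");
}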
447 proxy/src/pg_client/error.rs (new file)
@@ -0,0 +1,447 @@
use std::{error, fmt, io};

use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
use tokio_native_tls::native_tls;
use tokio_postgres::error::{ErrorPosition, SqlState};

#[derive(Debug, PartialEq)]
enum Kind {
    Io,
    Tls,
    UnexpectedMessage,
    FromSql(usize),
    Closed,
    Db,
    Parse,
    Encode,
}

struct ErrorInner {
    kind: Kind,
    cause: Option<Box<dyn error::Error + Sync + Send>>,
}

/// An error communicating with the Postgres server.
pub struct Error(ErrorInner);

impl fmt::Debug for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Error")
            .field("kind", &self.0.kind)
            .field("cause", &self.0.cause)
            .finish()
    }
}

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.0.kind {
            Kind::Io => fmt.write_str("error communicating with the server")?,
            Kind::Tls => fmt.write_str("error establishing tls")?,
            Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
            Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
            Kind::Closed => fmt.write_str("connection closed")?,
            Kind::Db => fmt.write_str("db error")?,
            Kind::Parse => fmt.write_str("error parsing response from server")?,
            Kind::Encode => fmt.write_str("error encoding message to server")?,
        };
        if let Some(ref cause) = self.0.cause {
            write!(fmt, ": {}", cause)?;
        }
        Ok(())
    }
}

impl error::Error for Error {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        self.0.cause.as_ref().map(|e| &**e as _)
    }
}

impl From<io::Error> for Error {
    fn from(value: io::Error) -> Self {
        Self::io(value)
    }
}

impl Error {
    /// Consumes the error, returning its cause.
    pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
        self.0.cause
    }

    /// Returns the source of this error if it was a `DbError`.
    ///
    /// This is a simple convenience method.
    pub fn as_db_error(&self) -> Option<&DbError> {
        error::Error::source(self).and_then(|e| e.downcast_ref::<DbError>())
    }

    /// Determines if the error was associated with closed connection.
    pub fn is_closed(&self) -> bool {
        self.0.kind == Kind::Closed
    }

    /// Returns the SQLSTATE error code associated with the error.
    ///
    /// This is a convenience method that downcasts the cause to a `DbError` and returns its code.
    pub fn code(&self) -> Option<&SqlState> {
        self.as_db_error().map(DbError::code)
    }

    fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
        Error(ErrorInner { kind, cause })
    }

    #[allow(clippy::needless_pass_by_value)]
    pub(crate) fn db(error: ErrorResponseBody) -> Error {
        match DbError::parse(&mut error.fields()) {
            Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
            Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
        }
    }

    pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
        Error::new(Kind::FromSql(idx), Some(e))
    }

    pub(crate) fn closed() -> Error {
        Error::new(Kind::Closed, None)
    }

    pub(crate) fn unexpected_message() -> Error {
        Error::new(Kind::UnexpectedMessage, None)
    }

    pub(crate) fn expecting(expected: &str) -> Error {
        Error::new(Kind::UnexpectedMessage, Some(expected.into()))
    }

    pub(crate) fn parse(e: io::Error) -> Error {
        Error::new(Kind::Parse, Some(Box::new(e)))
    }

    pub(crate) fn encode(e: io::Error) -> Error {
        Error::new(Kind::Encode, Some(Box::new(e)))
    }

    pub(crate) fn io(e: io::Error) -> Error {
        Error::new(Kind::Io, Some(Box::new(e)))
    }

    pub(crate) fn tls(e: native_tls::Error) -> Error {
        Error::new(Kind::Tls, Some(Box::new(e)))
    }
}

/// The severity of a Postgres error or notice.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Severity {
    /// PANIC
    Panic,
    /// FATAL
    Fatal,
    /// ERROR
    Error,
    /// WARNING
    Warning,
    /// NOTICE
    Notice,
    /// DEBUG
    Debug,
    /// INFO
    Info,
    /// LOG
    Log,
}

impl fmt::Display for Severity {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match *self {
            Severity::Panic => "PANIC",
            Severity::Fatal => "FATAL",
            Severity::Error => "ERROR",
            Severity::Warning => "WARNING",
            Severity::Notice => "NOTICE",
            Severity::Debug => "DEBUG",
            Severity::Info => "INFO",
            Severity::Log => "LOG",
        };
        fmt.write_str(s)
    }
}

impl Severity {
    fn from_str(s: &str) -> Option<Severity> {
        match s {
            "PANIC" => Some(Severity::Panic),
            "FATAL" => Some(Severity::Fatal),
            "ERROR" => Some(Severity::Error),
            "WARNING" => Some(Severity::Warning),
            "NOTICE" => Some(Severity::Notice),
            "DEBUG" => Some(Severity::Debug),
            "INFO" => Some(Severity::Info),
            "LOG" => Some(Severity::Log),
            _ => None,
        }
    }
}

/// A Postgres error or notice.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbError {
    severity: String,
    parsed_severity: Option<Severity>,
    code: SqlState,
    message: String,
    detail: Option<String>,
    hint: Option<String>,
    position: Option<ErrorPosition>,
    where_: Option<String>,
    schema: Option<String>,
    table: Option<String>,
    column: Option<String>,
    datatype: Option<String>,
    constraint: Option<String>,
    file: Option<String>,
    line: Option<u32>,
    routine: Option<String>,
}

impl DbError {
    pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
        let mut severity = None;
        let mut parsed_severity = None;
        let mut code = None;
        let mut message = None;
        let mut detail = None;
        let mut hint = None;
        let mut normal_position = None;
        let mut internal_position = None;
        let mut internal_query = None;
        let mut where_ = None;
        let mut schema = None;
        let mut table = None;
        let mut column = None;
        let mut datatype = None;
        let mut constraint = None;
        let mut file = None;
        let mut line = None;
        let mut routine = None;

        while let Some(field) = fields.next()? {
            match field.type_() {
                b'S' => severity = Some(field.value().to_owned()),
                b'C' => code = Some(SqlState::from_code(field.value())),
                b'M' => message = Some(field.value().to_owned()),
                b'D' => detail = Some(field.value().to_owned()),
                b'H' => hint = Some(field.value().to_owned()),
                b'P' => {
                    normal_position = Some(field.value().parse::<u32>().map_err(|_| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "`P` field did not contain an integer",
                        )
                    })?);
                }
                b'p' => {
                    internal_position = Some(field.value().parse::<u32>().map_err(|_| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "`p` field did not contain an integer",
                        )
                    })?);
                }
                b'q' => internal_query = Some(field.value().to_owned()),
                b'W' => where_ = Some(field.value().to_owned()),
                b's' => schema = Some(field.value().to_owned()),
                b't' => table = Some(field.value().to_owned()),
                b'c' => column = Some(field.value().to_owned()),
                b'd' => datatype = Some(field.value().to_owned()),
                b'n' => constraint = Some(field.value().to_owned()),
                b'F' => file = Some(field.value().to_owned()),
                b'L' => {
                    line = Some(field.value().parse::<u32>().map_err(|_| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "`L` field did not contain an integer",
                        )
                    })?);
                }
                b'R' => routine = Some(field.value().to_owned()),
                b'V' => {
                    parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "`V` field contained an invalid value",
                        )
                    })?);
                }
                _ => {}
            }
        }

        Ok(DbError {
            severity: severity
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
            parsed_severity,
            code: code
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
            message: message
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
            detail,
            hint,
            position: match normal_position {
                Some(position) => Some(ErrorPosition::Original(position)),
                None => match internal_position {
                    Some(position) => Some(ErrorPosition::Internal {
                        position,
                        query: internal_query.ok_or_else(|| {
                            io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "`q` field missing but `p` field present",
                            )
                        })?,
                    }),
                    None => None,
                },
            },
            where_,
            schema,
            table,
            column,
            datatype,
            constraint,
            file,
            line,
            routine,
        })
    }

    /// The field contents are ERROR, FATAL, or PANIC (in an error message),
    /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a
    /// localized translation of one of these.
    pub fn severity(&self) -> &str {
        &self.severity
    }

    /// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+)
    pub fn parsed_severity(&self) -> Option<Severity> {
        self.parsed_severity
    }

    /// The SQLSTATE code for the error.
    pub fn code(&self) -> &SqlState {
        &self.code
    }

    /// The primary human-readable error message.
    ///
    /// This should be accurate but terse (typically one line).
    pub fn message(&self) -> &str {
        &self.message
    }

    /// An optional secondary error message carrying more detail about the
    /// problem.
    ///
    /// Might run to multiple lines.
    pub fn detail(&self) -> Option<&str> {
        self.detail.as_deref()
    }

    /// An optional suggestion what to do about the problem.
    ///
    /// This is intended to differ from `detail` in that it offers advice
    /// (potentially inappropriate) rather than hard facts. Might run to
    /// multiple lines.
    pub fn hint(&self) -> Option<&str> {
        self.hint.as_deref()
    }

    /// An optional error cursor position into either the original query string
    /// or an internally generated query.
    pub fn position(&self) -> Option<&ErrorPosition> {
        self.position.as_ref()
    }

    /// An indication of the context in which the error occurred.
    ///
    /// Presently this includes a call stack traceback of active procedural
    /// language functions and internally-generated queries. The trace is one
    /// entry per line, most recent first.
    pub fn where_(&self) -> Option<&str> {
        self.where_.as_deref()
    }

    /// If the error was associated with a specific database object, the name
    /// of the schema containing that object, if any. (PostgreSQL 9.3+)
    pub fn schema(&self) -> Option<&str> {
        self.schema.as_deref()
    }

    /// If the error was associated with a specific table, the name of the
    /// table. (Refer to the schema name field for the name of the table's
    /// schema.) (PostgreSQL 9.3+)
    pub fn table(&self) -> Option<&str> {
        self.table.as_deref()
    }

    /// If the error was associated with a specific table column, the name of
    /// the column.
    ///
    /// (Refer to the schema and table name fields to identify the table.)
    /// (PostgreSQL 9.3+)
    pub fn column(&self) -> Option<&str> {
        self.column.as_deref()
    }

    /// If the error was associated with a specific data type, the name of the
    /// data type. (Refer to the schema name field for the name of the data
    /// type's schema.) (PostgreSQL 9.3+)
    pub fn datatype(&self) -> Option<&str> {
        self.datatype.as_deref()
    }

    /// If the error was associated with a specific constraint, the name of the
    /// constraint.
    ///
    /// Refer to fields listed above for the associated table or domain.
    /// (For this purpose, indexes are treated as constraints, even if they
    /// weren't created with constraint syntax.) (PostgreSQL 9.3+)
    pub fn constraint(&self) -> Option<&str> {
        self.constraint.as_deref()
    }

    /// The file name of the source-code location where the error was reported.
    pub fn file(&self) -> Option<&str> {
        self.file.as_deref()
    }

    /// The line number of the source-code location where the error was
    /// reported.
    pub fn line(&self) -> Option<u32> {
        self.line
    }

    /// The name of the source-code routine reporting the error.
    pub fn routine(&self) -> Option<&str> {
        self.routine.as_deref()
    }
}

impl fmt::Display for DbError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(fmt, "{}: {}", self.severity, self.message)?;
        if let Some(detail) = &self.detail {
            write!(fmt, "\nDETAIL: {}", detail)?;
        }
        if let Some(hint) = &self.hint {
            write!(fmt, "\nHINT: {}", hint)?;
        }
        Ok(())
    }
}

impl error::Error for DbError {}
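`DbError::parse` walks an ErrorResponse's sequence of (field-type byte, NUL-terminated string) pairs, closed by a zero byte. A standalone sketch that hand-assembles such a message and iterates its fields with the postgres-protocol 0.6 API used here (the field values are made up):

use bytes::{BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::Message;

fn main() -> std::io::Result<()> {
    // Body: (type byte, NUL-terminated string) pairs, terminated by a zero byte.
    let mut body = BytesMut::new();
    for (tag, val) in [(b'S', "ERROR"), (b'C', "42601"), (b'M', "syntax error")] {
        body.put_u8(tag);
        body.put_slice(val.as_bytes());
        body.put_u8(0);
    }
    body.put_u8(0);

    // Frame it: tag 'E', i32 length (includes itself), then the body.
    let mut buf = BytesMut::new();
    buf.put_u8(b'E');
    buf.put_i32(4 + body.len() as i32);
    buf.put_slice(&body);

    let Some(Message::ErrorResponse(err)) = Message::parse(&mut buf)? else {
        panic!("expected ErrorResponse");
    };
    let mut fields = err.fields();
    while let Some(field) = fields.next()? {
        println!("{} = {}", field.type_() as char, field.value());
    }
    Ok(())
}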
5 proxy/src/pg_client/mod.rs (new file)
@@ -0,0 +1,5 @@

pub mod codec;
pub mod connection;
pub mod error;
pub mod prepare;
293 proxy/src/pg_client/prepare.rs (new file)
@@ -0,0 +1,293 @@
use fallible_iterator::FallibleIterator;
use futures::StreamExt;
use postgres_protocol::message::backend::{DataRowRanges, Message};
use postgres_protocol::message::frontend;
use std::future::Future;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::types::{Field, Kind, Oid, ToSql, Type};

use super::connection::Connection;
use super::error::Error;

const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";

const TYPEINFO_ENUM_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
";

const TYPEINFO_COMPOSITE_QUERY: &str = "\
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
";

#[derive(Clone)]
pub struct TypeinfoPreparedQueries {
    query: String,
    enum_query: String,
    composite_query: String,
}

fn map_is_null(x: tokio_postgres::types::IsNull) -> postgres_protocol::IsNull {
    match x {
        tokio_postgres::types::IsNull::Yes => postgres_protocol::IsNull::Yes,
        tokio_postgres::types::IsNull::No => postgres_protocol::IsNull::No,
    }
}

fn read_column<'a, T: tokio_postgres::types::FromSql<'a>>(
    buffer: &'a [u8],
    type_: &Type,
    ranges: &mut DataRowRanges<'a>,
) -> Result<T, Error> {
    let range = ranges.next()?;
    match range {
        Some(range) => T::from_sql_nullable(type_, range.map(|r| &buffer[r])),
        None => T::from_sql_null(type_),
    }
    .map_err(|e| Error::from_sql(e, 0))
}

impl TypeinfoPreparedQueries {
    pub async fn new<
        S: AsyncRead + AsyncWrite + Unpin + Send,
        T: AsyncRead + AsyncWrite + Unpin + Send,
    >(
        c: &mut Connection<S, T>,
    ) -> Result<Self, Error> {
        if let Some(ti) = &c.typeinfo {
            return Ok(ti.clone());
        }

        let query = c.statement_name();
        let enum_query = c.statement_name();
        let composite_query = c.statement_name();

        frontend::parse(&query, TYPEINFO_QUERY, [Type::OID.oid()], &mut c.raw.buf)?;
        frontend::parse(
            &enum_query,
            TYPEINFO_ENUM_QUERY,
            [Type::OID.oid()],
            &mut c.raw.buf,
        )?;
        c.sync().await?;
        frontend::parse(
            &composite_query,
            TYPEINFO_COMPOSITE_QUERY,
            [Type::OID.oid()],
            &mut c.raw.buf,
        )?;
        c.sync().await?;

        let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
        let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
        let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
        c.wait_for_ready().await?;

        Ok(c.typeinfo
            .insert(TypeinfoPreparedQueries {
                query,
                enum_query,
                composite_query,
            })
            .clone())
    }

    fn get_type_rec<
        S: AsyncRead + AsyncWrite + Unpin + Send,
        T: AsyncRead + AsyncWrite + Unpin + Send,
    >(
        c: &mut Connection<S, T>,
        oid: Oid,
    ) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + '_>> {
        Box::pin(Self::get_type(c, oid))
    }

    pub async fn get_type<
        S: AsyncRead + AsyncWrite + Unpin + Send,
        T: AsyncRead + AsyncWrite + Unpin + Send,
    >(
        c: &mut Connection<S, T>,
        oid: Oid,
    ) -> Result<Type, Error> {
        if let Some(type_) = Type::from_oid(oid) {
            return Ok(type_);
        }

        if let Some(type_) = c.typecache.get(&oid) {
            return Ok(type_.clone());
        }

        let queries = Self::new(c).await?;

        frontend::bind(
            "",
            &queries.query,
            [1], // the only parameter is in binary format
            [oid],
            |param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
            Some(1), // binary return type
            &mut c.raw.buf,
        )
        .map_err(|e| match e {
            frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
            frontend::BindError::Serialization(io) => io,
        })?;
        frontend::execute("", 0, &mut c.raw.buf)?;

        c.sync().await?;

        let mut stream = c.stream_query_results().await?;

        let Some(row) = stream.next().await.transpose()? else {
            todo!()
        };

        let row = row.map_err(Error::db)?;
        let b = row.buffer();
        let mut ranges = row.ranges();

        let name: String = read_column(b, &Type::NAME, &mut ranges)?;
        let type_: i8 = read_column(b, &Type::CHAR, &mut ranges)?;
        let elem_oid: Oid = read_column(b, &Type::OID, &mut ranges)?;
        let rngsubtype: Option<Oid> = read_column(b, &Type::OID, &mut ranges)?;
        let basetype: Oid = read_column(b, &Type::OID, &mut ranges)?;
        let schema: String = read_column(b, &Type::NAME, &mut ranges)?;
        let relid: Oid = read_column(b, &Type::OID, &mut ranges)?;

        {
            // should be none
            let None = stream.next().await.transpose()? else {
                todo!()
            };
            drop(stream);
        }

        let kind = if type_ == b'e' as i8 {
            let variants = Self::get_enum_variants(c, oid).await?;
            Kind::Enum(variants)
        } else if type_ == b'p' as i8 {
            Kind::Pseudo
        } else if basetype != 0 {
            let type_ = Self::get_type_rec(c, basetype).await?;
            Kind::Domain(type_)
        } else if elem_oid != 0 {
            let type_ = Self::get_type_rec(c, elem_oid).await?;
            Kind::Array(type_)
        } else if relid != 0 {
            let fields = Self::get_composite_fields(c, relid).await?;
            Kind::Composite(fields)
        } else if let Some(rngsubtype) = rngsubtype {
            let type_ = Self::get_type_rec(c, rngsubtype).await?;
            Kind::Range(type_)
        } else {
            Kind::Simple
        };

        let type_ = Type::new(name, oid, kind, schema);
        c.typecache.insert(oid, type_.clone());

        Ok(type_)
    }

    async fn get_enum_variants<
        S: AsyncRead + AsyncWrite + Unpin + Send,
        T: AsyncRead + AsyncWrite + Unpin + Send,
    >(
        c: &mut Connection<S, T>,
        oid: Oid,
    ) -> Result<Vec<String>, Error> {
        let queries = Self::new(c).await?;

        frontend::bind(
            "",
            &queries.enum_query,
            [1], // the only parameter is in binary format
            [oid],
            |param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
            Some(1), // binary return type
            &mut c.raw.buf,
        )
        .map_err(|e| match e {
            frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
            frontend::BindError::Serialization(io) => io,
        })?;
        frontend::execute("", 0, &mut c.raw.buf)?;

        c.sync().await?;

        let mut stream = c.stream_query_results().await?;
        let mut variants = Vec::new();
        while let Some(row) = stream.next().await.transpose()? {
            let row = row.map_err(Error::db)?;

            let variant: String = read_column(row.buffer(), &Type::NAME, &mut row.ranges())?;
            variants.push(variant);
        }

        c.wait_for_ready().await?;

        Ok(variants)
    }

    async fn get_composite_fields<
        S: AsyncRead + AsyncWrite + Unpin + Send,
        T: AsyncRead + AsyncWrite + Unpin + Send,
    >(
        c: &mut Connection<S, T>,
        oid: Oid,
    ) -> Result<Vec<Field>, Error> {
        let queries = Self::new(c).await?;

        frontend::bind(
            "",
            &queries.composite_query,
            [1], // the only parameter is in binary format
            [oid],
            |param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
            Some(1), // binary return type
            &mut c.raw.buf,
        )
        .map_err(|e| match e {
            frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
            frontend::BindError::Serialization(io) => io,
        })?;
        frontend::execute("", 0, &mut c.raw.buf)?;

        c.sync().await?;

        let mut stream = c.stream_query_results().await?;
        let mut fields = Vec::new();
        while let Some(row) = stream.next().await.transpose()? {
            let row = row.map_err(Error::db)?;

            let mut ranges = row.ranges();
            let name: String = read_column(row.buffer(), &Type::NAME, &mut ranges)?;
            let oid: Oid = read_column(row.buffer(), &Type::OID, &mut ranges)?;
            fields.push((name, oid));
        }

        c.wait_for_ready().await?;

        let mut output_fields = Vec::with_capacity(fields.len());
        for (name, oid) in fields {
            let type_ = Self::get_type_rec(c, oid).await?;
            output_fields.push(Field::new(name, type_))
        }

        Ok(output_fields)
    }
}
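`get_type` classifies a `pg_catalog.pg_type` row by `typtype`, `typelem`, `typbasetype`, `typrelid`, and `rngsubtype`, in that priority order. The decision tree restated as a pure, runnable function (the enum is a local stand-in for the crate's `Kind`, which carries payloads):

#[derive(Debug, PartialEq)]
enum KindSketch {
    Enum,
    Pseudo,
    Domain,
    Array,
    Composite,
    Range,
    Simple,
}

fn classify(typtype: u8, elem_oid: u32, basetype: u32, relid: u32, rngsubtype: Option<u32>) -> KindSketch {
    if typtype == b'e' {
        KindSketch::Enum
    } else if typtype == b'p' {
        KindSketch::Pseudo
    } else if basetype != 0 {
        KindSketch::Domain // domain over a base type
    } else if elem_oid != 0 {
        KindSketch::Array // array with an element type
    } else if relid != 0 {
        KindSketch::Composite // row type backed by a relation
    } else if rngsubtype.is_some() {
        KindSketch::Range
    } else {
        KindSketch::Simple
    }
}

fn main() {
    assert_eq!(classify(b'b', 0, 0, 0, None), KindSketch::Simple);
    assert_eq!(classify(b'e', 0, 0, 0, None), KindSketch::Enum);
    assert_eq!(classify(b'b', 23, 0, 0, None), KindSketch::Array);
}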
@@ -99,9 +99,8 @@ struct Scram(scram::ServerSecret);
 
 impl Scram {
     fn new(password: &str) -> anyhow::Result<Self> {
-        let salt = rand::random::<[u8; 16]>();
-        let secret = scram::ServerSecret::build(password, &salt, 256)
-            .context("failed to generate scram secret")?;
+        let secret =
+            scram::ServerSecret::build(password).context("failed to generate scram secret")?;
         Ok(Scram(secret))
     }
 
@@ -12,9 +12,6 @@ mod messages;
 mod secret;
 mod signature;
 
-#[cfg(any(test, doc))]
-mod password;
-
 pub use exchange::Exchange;
 pub use key::ScramKey;
 pub use secret::ServerSecret;
@@ -57,27 +54,21 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {

#[cfg(test)]
mod tests {
    use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};

    use crate::sasl::{Mechanism, Step};

    use super::{password::SaltedPassword, Exchange, ServerSecret};
    use super::{Exchange, ServerSecret};

    #[test]
    fn happy_path() {
    fn snapshot() {
        let iterations = 4096;
        let salt_base64 = "QSXCR+Q6sek8bf92";
        let pw = SaltedPassword::new(
            b"pencil",
            base64::decode(salt_base64).unwrap().as_slice(),
            iterations,
        );
        let salt = "QSXCR+Q6sek8bf92";
        let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
        let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
        let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}");
        let secret = ServerSecret::parse(&secret).unwrap();
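
        // Password "pencil", salt "QSXCR+Q6sek8bf92" and 4096 iterations are
        // the example vector from RFC 5802; the stored/server keys appear to
        // be pinned snapshots of the derived values.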

        let secret = ServerSecret {
            iterations,
            salt_base64: salt_base64.to_owned(),
            stored_key: pw.client_key().sha256(),
            server_key: pw.server_key(),
            doomed: false,
        };
        const NONCE: [u8; 18] = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        ];
@@ -115,4 +106,40 @@ mod tests {
            ]
        );
    }

    fn run_round_trip_test(client_password: &str) {
        let secret = ServerSecret::build("pencil").unwrap();
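        // The server-side secret is always built for "pencil"; a mismatching
        // client password should surface as a failure in the final exchange.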
        let mut exchange = Exchange::new(&secret, rand::random, None);

        let mut client =
            ScramSha256::new(client_password.as_bytes(), ChannelBinding::unsupported());

        let client_first = std::str::from_utf8(client.message()).unwrap();
        exchange = match exchange.exchange(client_first).unwrap() {
            Step::Continue(exchange, message) => {
                client.update(message.as_bytes()).unwrap();
                exchange
            }
            Step::Success(_, _) => panic!("expected continue, got success"),
            Step::Failure(f) => panic!("{f}"),
        };

        let client_final = std::str::from_utf8(client.message()).unwrap();
        match exchange.exchange(client_final).unwrap() {
            Step::Success(_, message) => client.finish(message.as_bytes()).unwrap(),
            Step::Continue(_, _) => panic!("expected success, got continue"),
            Step::Failure(f) => panic!("{f}"),
        };
    }

    #[test]
    fn round_trip() {
        run_round_trip_test("pencil")
    }

    #[test]
    #[should_panic(expected = "password doesn't match")]
    fn failure() {
        run_round_trip_test("eraser")
    }
}

@@ -3,7 +3,7 @@
/// Faithfully taken from PostgreSQL.
pub const SCRAM_KEY_LEN: usize = 32;

/// One of the keys derived from the [password](super::password::SaltedPassword).
/// One of the keys derived from the user's password.
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Default, PartialEq, Eq)]

@@ -1,74 +0,0 @@
//! Password hashing routines.

use super::key::ScramKey;

pub const SALTED_PASSWORD_LEN: usize = 32;

/// Salted hashed password is essential for [key](super::key) derivation.
#[repr(transparent)]
pub struct SaltedPassword {
    bytes: [u8; SALTED_PASSWORD_LEN],
}

impl SaltedPassword {
    /// See `scram-common.c : scram_SaltedPassword` for details.
    /// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
    pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
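        // PBKDF2 over HMAC-SHA-256 via the RustCrypto `pbkdf2` crate; the
        // requested 32 bytes equal one HMAC-SHA-256 output, so a single
        // PBKDF2 block suffices.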
        pbkdf2::pbkdf2_hmac_array::<sha2::Sha256, 32>(password, salt, iterations).into()
    }

    /// Derive `ClientKey` from a salted hashed password.
    pub fn client_key(&self) -> ScramKey {
        super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
    }

    /// Derive `ServerKey` from a salted hashed password.
    pub fn server_key(&self) -> ScramKey {
        super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into()
    }
}

impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword {
    #[inline(always)]
    fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self {
        Self { bytes }
    }
}

#[cfg(test)]
mod tests {
    use super::SaltedPassword;

    fn legacy_pbkdf2_impl(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
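        // Textbook PBKDF2 (RFC 2898): U_1 = HMAC(P, salt || INT(1)),
        // U_i = HMAC(P, U_{i-1}), result = U_1 xor U_2 xor ... xor U_c.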
        let one = 1_u32.to_be_bytes(); // magic: big-endian INT(1) block index

        let mut current = super::super::hmac_sha256(password, [salt, &one]);
        let mut result = current;
        for _ in 1..iterations {
            current = super::super::hmac_sha256(password, [current.as_ref()]);
            // TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094
            for (i, x) in current.iter().enumerate() {
                result[i] ^= x;
            }
        }

        result.into()
    }

    #[test]
    fn pbkdf2() {
        let password = "a-very-secure-password";
        let salt = "such-a-random-salt";
        let iterations = 4096;
        let output = [
            203, 18, 206, 81, 4, 154, 193, 100, 147, 41, 211, 217, 177, 203, 69, 210, 194, 211,
            101, 1, 248, 156, 96, 0, 8, 223, 30, 87, 158, 41, 20, 42,
        ];

        let actual = SaltedPassword::new(password.as_bytes(), salt.as_bytes(), iterations);
        let expected = legacy_pbkdf2_impl(password.as_bytes(), salt.as_bytes(), iterations);

        assert_eq!(actual.bytes, output);
        assert_eq!(actual.bytes, expected.bytes);
    }
}

@@ -3,7 +3,7 @@
use super::base64_decode_array;
use super::key::ScramKey;

/// Server secret is produced from [password](super::password::SaltedPassword)
/// Server secret is produced from the user's password,
/// and is used throughout the authentication process.
pub struct ServerSecret {
    /// Number of iterations for `PBKDF2` function.

@@ -58,21 +58,10 @@ impl ServerSecret {
    /// Build a new server secret from the prerequisites.
    /// XXX: We only use this function in tests.
    #[cfg(test)]
    pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
        // TODO: implement proper password normalization required by the RFC
        if !password.is_ascii() {
            return None;
        }

        let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);

        Some(Self {
            iterations,
            salt_base64: base64::encode(salt),
            stored_key: password.client_key().sha256(),
            server_key: password.server_key(),
            doomed: false,
        })
    pub fn build(password: &str) -> Option<Self> {
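        // scram_sha_256 from postgres-protocol picks a random salt, runs
        // PBKDF2 with the default iteration count, and renders the
        // "SCRAM-SHA-256$<iterations>:<salt>$<StoredKey>:<ServerKey>" string,
        // which parse() turns back into a ServerSecret.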
        Self::parse(&postgres_protocol::password::scram_sha_256(
            password.as_bytes(),
        ))
    }
}

@@ -102,20 +91,4 @@ mod tests {
        assert_eq!(base64::encode(parsed.stored_key), stored_key);
        assert_eq!(base64::encode(parsed.server_key), server_key);
    }

    #[test]
    fn build_scram_secret() {
        let salt = b"salt";
        let secret = ServerSecret::build("password", salt, 4096).unwrap();
        assert_eq!(secret.iterations, 4096);
        assert_eq!(secret.salt_base64, base64::encode(salt));
        assert_eq!(
            base64::encode(secret.stored_key.as_ref()),
            "lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
        );
        assert_eq!(
            base64::encode(secret.server_key.as_ref()),
            "ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
        );
    }
}