simplify error handling for query encoding

This commit is contained in:
Conrad Ludgate
2025-05-21 13:37:57 +01:00
parent f3c9d0adf4
commit 13d41b51a2
18 changed files with 246 additions and 313 deletions

View File

@@ -536,7 +536,8 @@ mod tests {
use control_plane::AuthSecret;
use fallible_iterator::FallibleIterator;
use once_cell::sync::Lazy;
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
use postgres_protocol::CSafeStr;
use postgres_protocol::authentication::sasl::{ChannelBinding, SCRAM_SHA_256, ScramSha256};
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
@@ -714,15 +715,15 @@ mod tests {
// server should offer scram
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSasl(a) => {
let options: Vec<&str> = a.mechanisms().collect().unwrap();
assert_eq!(options, ["SCRAM-SHA-256"]);
let options: Vec<&CSafeStr> = a.mechanisms().collect().unwrap();
assert_eq!(options, [SCRAM_SHA_256]);
}
_ => panic!("wrong message"),
}
// client sends client-first-message
let mut write = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
frontend::sasl_initial_response(SCRAM_SHA_256, scram.message(), &mut write);
client.write_all(&write).await.unwrap();
// server response with server-first-message
@@ -735,7 +736,7 @@ mod tests {
// client response with client-final-message
write.clear();
frontend::sasl_response(scram.message(), &mut write).unwrap();
frontend::sasl_response(scram.message(), &mut write);
client.write_all(&write).await.unwrap();
// server response with server-final-message
@@ -800,7 +801,7 @@ mod tests {
// client responds with password
write.clear();
frontend::password_message(b"my-secret-password", &mut write).unwrap();
frontend::password_message(c"my-secret-password".into(), &mut write);
client.write_all(&write).await.unwrap();
});
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
@@ -853,8 +854,10 @@ mod tests {
// client responds with password
let mut write = BytesMut::new();
frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
.unwrap();
frontend::password_message(
c"endpoint=my-endpoint;my-secret-password".into(),
&mut write,
);
client.write_all(&write).await.unwrap();
});

View File

@@ -3,7 +3,6 @@
use std::io;
use std::sync::Arc;
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -174,8 +173,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
match sasl.method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
scram::SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
scram::SCRAM_SHA_256_PLUS => {
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus)
}
_ => {}
}

View File

@@ -9,7 +9,7 @@ use std::fmt::Debug;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_client::tls::TlsConnect;
use postgres_protocol::message::frontend;
use postgres_protocol::{authentication::sasl::SCRAM_SHA_256, message::frontend};
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_util::codec::{Decoder, Encoder};
@@ -60,8 +60,7 @@ async fn proxy_mitm(
params: startup.params.into(),
},
&mut buf,
)
.unwrap();
);
end_server.send(buf.freeze()).await.unwrap();
// proxy messages between end_client and end_server
@@ -90,7 +89,7 @@ async fn proxy_mitm(
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
let mut buf = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
frontend::sasl_initial_response(SCRAM_SHA_256, &new_message, &mut buf);
end_server.send(buf.freeze()).await.unwrap();
continue;

View File

@@ -21,8 +21,8 @@ pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
pub(crate) const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
pub(crate) const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
/// A list of supported SCRAM methods.
pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];