mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-27 16:12:56 +00:00
chore(proxy): remove postgres config parser and md5 support (#9990)
Keeping the `mock` postgres cplane adaptor using "stock" tokio-postgres allows us to remove a lot of dead weight from our actual postgres connection logic.
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -4209,7 +4209,6 @@ dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"hmac",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"rand 0.8.5",
|
||||
"sha2",
|
||||
@@ -4612,6 +4611,7 @@ dependencies = [
|
||||
"tikv-jemalloc-ctl",
|
||||
"tikv-jemallocator",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres2",
|
||||
"tokio-rustls 0.26.0",
|
||||
"tokio-tungstenite",
|
||||
|
||||
@@ -10,7 +10,6 @@ byteorder.workspace = true
|
||||
bytes.workspace = true
|
||||
fallible-iterator.workspace = true
|
||||
hmac.workspace = true
|
||||
md-5 = "0.10"
|
||||
memchr = "2.0"
|
||||
rand.workspace = true
|
||||
sha2.workspace = true
|
||||
|
||||
@@ -1,37 +1,2 @@
|
||||
//! Authentication protocol support.
|
||||
use md5::{Digest, Md5};
|
||||
|
||||
pub mod sasl;
|
||||
|
||||
/// Hashes authentication information in a way suitable for use in response
|
||||
/// to an `AuthenticationMd5Password` message.
|
||||
///
|
||||
/// The resulting string should be sent back to the database in a
|
||||
/// `PasswordMessage` message.
|
||||
#[inline]
|
||||
pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String {
|
||||
let mut md5 = Md5::new();
|
||||
md5.update(password);
|
||||
md5.update(username);
|
||||
let output = md5.finalize_reset();
|
||||
md5.update(format!("{:x}", output));
|
||||
md5.update(salt);
|
||||
format!("md5{:x}", md5.finalize())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn md5() {
|
||||
let username = b"md5_user";
|
||||
let password = b"password";
|
||||
let salt = [0x2a, 0x3d, 0x8f, 0xe0];
|
||||
|
||||
assert_eq!(
|
||||
md5_hash(username, password, salt),
|
||||
"md562af4dd09bbb41884907a838a3233294"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ pub enum Message {
|
||||
AuthenticationCleartextPassword,
|
||||
AuthenticationGss,
|
||||
AuthenticationKerberosV5,
|
||||
AuthenticationMd5Password(AuthenticationMd5PasswordBody),
|
||||
AuthenticationMd5Password,
|
||||
AuthenticationOk,
|
||||
AuthenticationScmCredential,
|
||||
AuthenticationSspi,
|
||||
@@ -191,11 +191,7 @@ impl Message {
|
||||
0 => Message::AuthenticationOk,
|
||||
2 => Message::AuthenticationKerberosV5,
|
||||
3 => Message::AuthenticationCleartextPassword,
|
||||
5 => {
|
||||
let mut salt = [0; 4];
|
||||
buf.read_exact(&mut salt)?;
|
||||
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt })
|
||||
}
|
||||
5 => Message::AuthenticationMd5Password,
|
||||
6 => Message::AuthenticationScmCredential,
|
||||
7 => Message::AuthenticationGss,
|
||||
8 => Message::AuthenticationGssContinue,
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
use crate::authentication::sasl;
|
||||
use hmac::{Hmac, Mac};
|
||||
use md5::Md5;
|
||||
use rand::RngCore;
|
||||
use sha2::digest::FixedOutput;
|
||||
use sha2::{Digest, Sha256};
|
||||
@@ -88,20 +87,3 @@ pub(crate) async fn scram_sha_256_salt(
|
||||
base64::encode(server_key)
|
||||
)
|
||||
}
|
||||
|
||||
/// **Not recommended, as MD5 is not considered to be secure.**
|
||||
///
|
||||
/// Hash password using MD5 with the username as the salt.
|
||||
///
|
||||
/// The client may assume the returned string doesn't contain any
|
||||
/// special characters that would require escaping.
|
||||
pub fn md5(password: &[u8], username: &str) -> String {
|
||||
// salt password with username
|
||||
let mut salted_password = Vec::from(password);
|
||||
salted_password.extend_from_slice(username.as_bytes());
|
||||
|
||||
let mut hash = Md5::new();
|
||||
hash.update(&salted_password);
|
||||
let digest = hash.finalize();
|
||||
format!("md5{:x}", digest)
|
||||
}
|
||||
|
||||
@@ -9,11 +9,3 @@ async fn test_encrypt_scram_sha_256() {
|
||||
"SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA="
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_md5() {
|
||||
assert_eq!(
|
||||
password::md5(b"secret", "foo"),
|
||||
"md54ab2c5d00339c4b2a4e921d2dc4edec7"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -6,11 +6,9 @@ use crate::connect_raw::RawConnection;
|
||||
use crate::tls::MakeTlsConnect;
|
||||
use crate::tls::TlsConnect;
|
||||
use crate::{Client, Connection, Error};
|
||||
use std::borrow::Cow;
|
||||
use std::fmt;
|
||||
use std::str;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use std::{error, fmt, iter, mem};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub use postgres_protocol2::authentication::sasl::ScramKeys;
|
||||
@@ -380,99 +378,6 @@ impl Config {
|
||||
self.max_backend_message_size
|
||||
}
|
||||
|
||||
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
|
||||
match key {
|
||||
"user" => {
|
||||
self.user(value);
|
||||
}
|
||||
"password" => {
|
||||
self.password(value);
|
||||
}
|
||||
"dbname" => {
|
||||
self.dbname(value);
|
||||
}
|
||||
"options" => {
|
||||
self.options(value);
|
||||
}
|
||||
"application_name" => {
|
||||
self.application_name(value);
|
||||
}
|
||||
"sslmode" => {
|
||||
let mode = match value {
|
||||
"disable" => SslMode::Disable,
|
||||
"prefer" => SslMode::Prefer,
|
||||
"require" => SslMode::Require,
|
||||
_ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
|
||||
};
|
||||
self.ssl_mode(mode);
|
||||
}
|
||||
"host" => {
|
||||
for host in value.split(',') {
|
||||
self.host(host);
|
||||
}
|
||||
}
|
||||
"port" => {
|
||||
for port in value.split(',') {
|
||||
let port = if port.is_empty() {
|
||||
5432
|
||||
} else {
|
||||
port.parse()
|
||||
.map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
|
||||
};
|
||||
self.port(port);
|
||||
}
|
||||
}
|
||||
"connect_timeout" => {
|
||||
let timeout = value
|
||||
.parse::<i64>()
|
||||
.map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
|
||||
if timeout > 0 {
|
||||
self.connect_timeout(Duration::from_secs(timeout as u64));
|
||||
}
|
||||
}
|
||||
"target_session_attrs" => {
|
||||
let target_session_attrs = match value {
|
||||
"any" => TargetSessionAttrs::Any,
|
||||
"read-write" => TargetSessionAttrs::ReadWrite,
|
||||
_ => {
|
||||
return Err(Error::config_parse(Box::new(InvalidValue(
|
||||
"target_session_attrs",
|
||||
))));
|
||||
}
|
||||
};
|
||||
self.target_session_attrs(target_session_attrs);
|
||||
}
|
||||
"channel_binding" => {
|
||||
let channel_binding = match value {
|
||||
"disable" => ChannelBinding::Disable,
|
||||
"prefer" => ChannelBinding::Prefer,
|
||||
"require" => ChannelBinding::Require,
|
||||
_ => {
|
||||
return Err(Error::config_parse(Box::new(InvalidValue(
|
||||
"channel_binding",
|
||||
))))
|
||||
}
|
||||
};
|
||||
self.channel_binding(channel_binding);
|
||||
}
|
||||
"max_backend_message_size" => {
|
||||
let limit = value.parse::<usize>().map_err(|_| {
|
||||
Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
|
||||
})?;
|
||||
if limit > 0 {
|
||||
self.max_backend_message_size(limit);
|
||||
}
|
||||
}
|
||||
key => {
|
||||
return Err(Error::config_parse(Box::new(UnknownOption(
|
||||
key.to_string(),
|
||||
))));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Opens a connection to a PostgreSQL database.
|
||||
///
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
@@ -499,17 +404,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Config {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Config, Error> {
|
||||
match UrlParser::parse(s)? {
|
||||
Some(config) => Ok(config),
|
||||
None => Parser::parse(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Omit password from debug output
|
||||
impl fmt::Debug for Config {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
@@ -536,360 +430,3 @@ impl fmt::Debug for Config {
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UnknownOption(String);
|
||||
|
||||
impl fmt::Display for UnknownOption {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "unknown option `{}`", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for UnknownOption {}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InvalidValue(&'static str);
|
||||
|
||||
impl fmt::Display for InvalidValue {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "invalid value for option `{}`", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for InvalidValue {}
|
||||
|
||||
struct Parser<'a> {
|
||||
s: &'a str,
|
||||
it: iter::Peekable<str::CharIndices<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> Parser<'a> {
|
||||
fn parse(s: &'a str) -> Result<Config, Error> {
|
||||
let mut parser = Parser {
|
||||
s,
|
||||
it: s.char_indices().peekable(),
|
||||
};
|
||||
|
||||
let mut config = Config::new();
|
||||
|
||||
while let Some((key, value)) = parser.parameter()? {
|
||||
config.param(key, &value)?;
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn skip_ws(&mut self) {
|
||||
self.take_while(char::is_whitespace);
|
||||
}
|
||||
|
||||
fn take_while<F>(&mut self, f: F) -> &'a str
|
||||
where
|
||||
F: Fn(char) -> bool,
|
||||
{
|
||||
let start = match self.it.peek() {
|
||||
Some(&(i, _)) => i,
|
||||
None => return "",
|
||||
};
|
||||
|
||||
loop {
|
||||
match self.it.peek() {
|
||||
Some(&(_, c)) if f(c) => {
|
||||
self.it.next();
|
||||
}
|
||||
Some(&(i, _)) => return &self.s[start..i],
|
||||
None => return &self.s[start..],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn eat(&mut self, target: char) -> Result<(), Error> {
|
||||
match self.it.next() {
|
||||
Some((_, c)) if c == target => Ok(()),
|
||||
Some((i, c)) => {
|
||||
let m = format!(
|
||||
"unexpected character at byte {}: expected `{}` but got `{}`",
|
||||
i, target, c
|
||||
);
|
||||
Err(Error::config_parse(m.into()))
|
||||
}
|
||||
None => Err(Error::config_parse("unexpected EOF".into())),
|
||||
}
|
||||
}
|
||||
|
||||
fn eat_if(&mut self, target: char) -> bool {
|
||||
match self.it.peek() {
|
||||
Some(&(_, c)) if c == target => {
|
||||
self.it.next();
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn keyword(&mut self) -> Option<&'a str> {
|
||||
let s = self.take_while(|c| match c {
|
||||
c if c.is_whitespace() => false,
|
||||
'=' => false,
|
||||
_ => true,
|
||||
});
|
||||
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn value(&mut self) -> Result<String, Error> {
|
||||
let value = if self.eat_if('\'') {
|
||||
let value = self.quoted_value()?;
|
||||
self.eat('\'')?;
|
||||
value
|
||||
} else {
|
||||
self.simple_value()?
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn simple_value(&mut self) -> Result<String, Error> {
|
||||
let mut value = String::new();
|
||||
|
||||
while let Some(&(_, c)) = self.it.peek() {
|
||||
if c.is_whitespace() {
|
||||
break;
|
||||
}
|
||||
|
||||
self.it.next();
|
||||
if c == '\\' {
|
||||
if let Some((_, c2)) = self.it.next() {
|
||||
value.push(c2);
|
||||
}
|
||||
} else {
|
||||
value.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
if value.is_empty() {
|
||||
return Err(Error::config_parse("unexpected EOF".into()));
|
||||
}
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn quoted_value(&mut self) -> Result<String, Error> {
|
||||
let mut value = String::new();
|
||||
|
||||
while let Some(&(_, c)) = self.it.peek() {
|
||||
if c == '\'' {
|
||||
return Ok(value);
|
||||
}
|
||||
|
||||
self.it.next();
|
||||
if c == '\\' {
|
||||
if let Some((_, c2)) = self.it.next() {
|
||||
value.push(c2);
|
||||
}
|
||||
} else {
|
||||
value.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::config_parse(
|
||||
"unterminated quoted connection parameter value".into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
|
||||
self.skip_ws();
|
||||
let keyword = match self.keyword() {
|
||||
Some(keyword) => keyword,
|
||||
None => return Ok(None),
|
||||
};
|
||||
self.skip_ws();
|
||||
self.eat('=')?;
|
||||
self.skip_ws();
|
||||
let value = self.value()?;
|
||||
|
||||
Ok(Some((keyword, value)))
|
||||
}
|
||||
}
|
||||
|
||||
// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict
|
||||
struct UrlParser<'a> {
|
||||
s: &'a str,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl<'a> UrlParser<'a> {
|
||||
fn parse(s: &'a str) -> Result<Option<Config>, Error> {
|
||||
let s = match Self::remove_url_prefix(s) {
|
||||
Some(s) => s,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let mut parser = UrlParser {
|
||||
s,
|
||||
config: Config::new(),
|
||||
};
|
||||
|
||||
parser.parse_credentials()?;
|
||||
parser.parse_host()?;
|
||||
parser.parse_path()?;
|
||||
parser.parse_params()?;
|
||||
|
||||
Ok(Some(parser.config))
|
||||
}
|
||||
|
||||
fn remove_url_prefix(s: &str) -> Option<&str> {
|
||||
for prefix in &["postgres://", "postgresql://"] {
|
||||
if let Some(stripped) = s.strip_prefix(prefix) {
|
||||
return Some(stripped);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
|
||||
match self.s.find(end) {
|
||||
Some(pos) => {
|
||||
let (head, tail) = self.s.split_at(pos);
|
||||
self.s = tail;
|
||||
Some(head)
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn take_all(&mut self) -> &'a str {
|
||||
mem::take(&mut self.s)
|
||||
}
|
||||
|
||||
fn eat_byte(&mut self) {
|
||||
self.s = &self.s[1..];
|
||||
}
|
||||
|
||||
fn parse_credentials(&mut self) -> Result<(), Error> {
|
||||
let creds = match self.take_until(&['@']) {
|
||||
Some(creds) => creds,
|
||||
None => return Ok(()),
|
||||
};
|
||||
self.eat_byte();
|
||||
|
||||
let mut it = creds.splitn(2, ':');
|
||||
let user = self.decode(it.next().unwrap())?;
|
||||
self.config.user(&user);
|
||||
|
||||
if let Some(password) = it.next() {
|
||||
let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
|
||||
self.config.password(password);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_host(&mut self) -> Result<(), Error> {
|
||||
let host = match self.take_until(&['/', '?']) {
|
||||
Some(host) => host,
|
||||
None => self.take_all(),
|
||||
};
|
||||
|
||||
if host.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for chunk in host.split(',') {
|
||||
let (host, port) = if chunk.starts_with('[') {
|
||||
let idx = match chunk.find(']') {
|
||||
Some(idx) => idx,
|
||||
None => return Err(Error::config_parse(InvalidValue("host").into())),
|
||||
};
|
||||
|
||||
let host = &chunk[1..idx];
|
||||
let remaining = &chunk[idx + 1..];
|
||||
let port = if let Some(port) = remaining.strip_prefix(':') {
|
||||
Some(port)
|
||||
} else if remaining.is_empty() {
|
||||
None
|
||||
} else {
|
||||
return Err(Error::config_parse(InvalidValue("host").into()));
|
||||
};
|
||||
|
||||
(host, port)
|
||||
} else {
|
||||
let mut it = chunk.splitn(2, ':');
|
||||
(it.next().unwrap(), it.next())
|
||||
};
|
||||
|
||||
self.host_param(host)?;
|
||||
let port = self.decode(port.unwrap_or("5432"))?;
|
||||
self.config.param("port", &port)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_path(&mut self) -> Result<(), Error> {
|
||||
if !self.s.starts_with('/') {
|
||||
return Ok(());
|
||||
}
|
||||
self.eat_byte();
|
||||
|
||||
let dbname = match self.take_until(&['?']) {
|
||||
Some(dbname) => dbname,
|
||||
None => self.take_all(),
|
||||
};
|
||||
|
||||
if !dbname.is_empty() {
|
||||
self.config.dbname(&self.decode(dbname)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_params(&mut self) -> Result<(), Error> {
|
||||
if !self.s.starts_with('?') {
|
||||
return Ok(());
|
||||
}
|
||||
self.eat_byte();
|
||||
|
||||
while !self.s.is_empty() {
|
||||
let key = match self.take_until(&['=']) {
|
||||
Some(key) => self.decode(key)?,
|
||||
None => return Err(Error::config_parse("unterminated parameter".into())),
|
||||
};
|
||||
self.eat_byte();
|
||||
|
||||
let value = match self.take_until(&['&']) {
|
||||
Some(value) => {
|
||||
self.eat_byte();
|
||||
value
|
||||
}
|
||||
None => self.take_all(),
|
||||
};
|
||||
|
||||
if key == "host" {
|
||||
self.host_param(value)?;
|
||||
} else {
|
||||
let value = self.decode(value)?;
|
||||
self.config.param(&key, &value)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn host_param(&mut self, s: &str) -> Result<(), Error> {
|
||||
let s = self.decode(s)?;
|
||||
self.config.param("host", &s)
|
||||
}
|
||||
|
||||
fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
|
||||
percent_encoding::percent_decode(s.as_bytes())
|
||||
.decode_utf8()
|
||||
.map_err(|e| Error::config_parse(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ use crate::Error;
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
|
||||
use postgres_protocol2::authentication;
|
||||
use postgres_protocol2::authentication::sasl;
|
||||
use postgres_protocol2::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
|
||||
@@ -174,25 +173,11 @@ where
|
||||
|
||||
authenticate_password(stream, pass).await?;
|
||||
}
|
||||
Some(Message::AuthenticationMd5Password(body)) => {
|
||||
can_skip_channel_binding(config)?;
|
||||
|
||||
let user = config
|
||||
.user
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::config("user missing".into()))?;
|
||||
let pass = config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::config("password missing".into()))?;
|
||||
|
||||
let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
|
||||
authenticate_password(stream, output.as_bytes()).await?;
|
||||
}
|
||||
Some(Message::AuthenticationSasl(body)) => {
|
||||
authenticate_sasl(stream, body, config).await?;
|
||||
}
|
||||
Some(Message::AuthenticationKerberosV5)
|
||||
Some(Message::AuthenticationMd5Password)
|
||||
| Some(Message::AuthenticationKerberosV5)
|
||||
| Some(Message::AuthenticationScmCredential)
|
||||
| Some(Message::AuthenticationGss)
|
||||
| Some(Message::AuthenticationSspi) => {
|
||||
|
||||
@@ -349,7 +349,6 @@ enum Kind {
|
||||
Parse,
|
||||
Encode,
|
||||
Authentication,
|
||||
ConfigParse,
|
||||
Config,
|
||||
Connect,
|
||||
Timeout,
|
||||
@@ -386,7 +385,6 @@ impl fmt::Display for Error {
|
||||
Kind::Parse => fmt.write_str("error parsing response from server")?,
|
||||
Kind::Encode => fmt.write_str("error encoding message to server")?,
|
||||
Kind::Authentication => fmt.write_str("authentication error")?,
|
||||
Kind::ConfigParse => fmt.write_str("invalid connection string")?,
|
||||
Kind::Config => fmt.write_str("invalid configuration")?,
|
||||
Kind::Connect => fmt.write_str("error connecting to server")?,
|
||||
Kind::Timeout => fmt.write_str("timeout waiting for server")?,
|
||||
@@ -482,10 +480,6 @@ impl Error {
|
||||
Error::new(Kind::Authentication, Some(e))
|
||||
}
|
||||
|
||||
pub(crate) fn config_parse(e: Box<dyn error::Error + Sync + Send>) -> Error {
|
||||
Error::new(Kind::ConfigParse, Some(e))
|
||||
}
|
||||
|
||||
pub(crate) fn config(e: Box<dyn error::Error + Sync + Send>) -> Error {
|
||||
Error::new(Kind::Config, Some(e))
|
||||
}
|
||||
|
||||
@@ -13,14 +13,12 @@ pub use crate::query::RowStream;
|
||||
pub use crate::row::{Row, SimpleQueryRow};
|
||||
pub use crate::simple_query::SimpleQueryStream;
|
||||
pub use crate::statement::{Column, Statement};
|
||||
use crate::tls::MakeTlsConnect;
|
||||
pub use crate::tls::NoTls;
|
||||
pub use crate::to_statement::ToStatement;
|
||||
pub use crate::transaction::Transaction;
|
||||
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
|
||||
use crate::types::ToSql;
|
||||
use postgres_protocol2::message::backend::ReadyForQueryBody;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// After executing a query, the connection will be in one of these states
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
@@ -72,24 +70,6 @@ mod transaction;
|
||||
mod transaction_builder;
|
||||
pub mod types;
|
||||
|
||||
/// A convenience function which parses a connection string and connects to the database.
|
||||
///
|
||||
/// See the documentation for [`Config`] for details on the connection string format.
|
||||
///
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
///
|
||||
/// [`Config`]: config/struct.Config.html
|
||||
pub async fn connect<T>(
|
||||
config: &str,
|
||||
tls: T,
|
||||
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
|
||||
where
|
||||
T: MakeTlsConnect<TcpStream>,
|
||||
{
|
||||
let config = config.parse::<Config>()?;
|
||||
config.connect(tls).await
|
||||
}
|
||||
|
||||
/// An asynchronous notification.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Notification {
|
||||
|
||||
@@ -6,7 +6,7 @@ license.workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
testing = []
|
||||
testing = ["dep:tokio-postgres"]
|
||||
|
||||
[dependencies]
|
||||
ahash.workspace = true
|
||||
@@ -55,6 +55,7 @@ parquet.workspace = true
|
||||
parquet_derive.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
postgres-client = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" }
|
||||
postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" }
|
||||
pq_proto.workspace = true
|
||||
prometheus.workspace = true
|
||||
@@ -81,7 +82,7 @@ subtle.workspace = true
|
||||
thiserror.workspace = true
|
||||
tikv-jemallocator.workspace = true
|
||||
tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
|
||||
tokio-postgres = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" }
|
||||
tokio-postgres = { workspace = true, optional = true }
|
||||
tokio-rustls.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
@@ -119,3 +120,4 @@ rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
walkdir.workspace = true
|
||||
rand_distr = "0.4"
|
||||
tokio-postgres.workspace = true
|
||||
|
||||
@@ -66,7 +66,7 @@ pub(super) async fn authenticate(
|
||||
|
||||
Ok(ComputeCredentials {
|
||||
info: creds,
|
||||
keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
|
||||
keys: ComputeCredentialKeys::AuthKeys(postgres_client::config::AuthKeys::ScramSha256(
|
||||
scram_keys,
|
||||
)),
|
||||
})
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use async_trait::async_trait;
|
||||
use postgres_client::config::SslMode;
|
||||
use pq_proto::BeMessage as Be;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use super::ComputeCredentialKeys;
|
||||
|
||||
@@ -11,8 +11,8 @@ pub use console_redirect::ConsoleRedirectBackend;
|
||||
pub(crate) use console_redirect::ConsoleRedirectError;
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
use local::LocalBackend;
|
||||
use postgres_client::config::AuthKeys;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
|
||||
@@ -227,7 +227,7 @@ pub(crate) async fn validate_password_and_exchange(
|
||||
};
|
||||
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
|
||||
tokio_postgres::config::AuthKeys::ScramSha256(keys),
|
||||
postgres_client::config::AuthKeys::ScramSha256(keys),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,11 @@ use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::{CancelToken, NoTls};
|
||||
use pq_proto::CancelKeyData;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_postgres::{CancelToken, NoTls};
|
||||
use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -44,7 +44,7 @@ pub(crate) enum CancelError {
|
||||
IO(#[from] std::io::Error),
|
||||
|
||||
#[error("{0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
|
||||
#[error("rate limit exceeded")]
|
||||
RateLimit,
|
||||
@@ -70,7 +70,7 @@ impl ReportableError for CancelError {
|
||||
impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||
pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
|
||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||
// HACK: We'd rather get the real backend_pid but postgres_client doesn't
|
||||
// expose it and we don't want to do another roundtrip to query
|
||||
// for it. The client will be able to notice that this is not the
|
||||
// actual backend_pid, but backend_pid is not used for anything
|
||||
|
||||
@@ -6,6 +6,8 @@ use std::time::Duration;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use rustls::client::danger::ServerCertVerifier;
|
||||
@@ -13,8 +15,6 @@ use rustls::crypto::ring;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres::{CancelToken, RawConnection};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::parse_endpoint_param;
|
||||
@@ -34,9 +34,9 @@ pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum ConnectionError {
|
||||
/// This error doesn't seem to reveal any secrets; for instance,
|
||||
/// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
|
||||
/// `postgres_client::error::Kind` doesn't contain ip addresses and such.
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
@@ -99,13 +99,13 @@ impl ReportableError for ConnectionError {
|
||||
}
|
||||
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>;
|
||||
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
|
||||
|
||||
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `tokio_postgres` will be replaced with something better.
|
||||
/// Eventually, `postgres_client` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct ConnCfg(Box<tokio_postgres::Config>);
|
||||
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
|
||||
|
||||
/// Creation and initialization routines.
|
||||
impl ConnCfg {
|
||||
@@ -126,7 +126,7 @@ impl ConnCfg {
|
||||
|
||||
pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
|
||||
match self.0.get_hosts() {
|
||||
[tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
|
||||
[postgres_client::config::Host::Tcp(s)] => Ok(s.into()),
|
||||
// we should not have multiple address or unix addresses.
|
||||
_ => Err(WakeComputeError::BadComputeAddress(
|
||||
"invalid compute address".into(),
|
||||
@@ -160,7 +160,7 @@ impl ConnCfg {
|
||||
|
||||
// TODO: This is especially ugly...
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
use postgres_client::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.replication_mode(ReplicationMode::Physical);
|
||||
@@ -182,7 +182,7 @@ impl ConnCfg {
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ConnCfg {
|
||||
type Target = tokio_postgres::Config;
|
||||
type Target = postgres_client::Config;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
@@ -199,7 +199,7 @@ impl std::ops::DerefMut for ConnCfg {
|
||||
impl ConnCfg {
|
||||
/// Establish a raw TCP connection to the compute node.
|
||||
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
|
||||
use tokio_postgres::config::Host;
|
||||
use postgres_client::config::Host;
|
||||
|
||||
// wrap TcpStream::connect with timeout
|
||||
let connect_with_timeout = |host, port| {
|
||||
@@ -224,7 +224,7 @@ impl ConnCfg {
|
||||
})
|
||||
};
|
||||
|
||||
// We can't reuse connection establishing logic from `tokio_postgres` here,
|
||||
// We can't reuse connection establishing logic from `postgres_client` here,
|
||||
// because it has no means for extracting the underlying socket which we
|
||||
// require for our business.
|
||||
let mut connection_error = None;
|
||||
@@ -272,7 +272,7 @@ type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>
|
||||
pub(crate) struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub(crate) stream:
|
||||
tokio_postgres::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub(crate) params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
|
||||
@@ -5,7 +5,6 @@ use std::sync::Arc;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::Client;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
@@ -165,7 +164,7 @@ impl MockControlPlane {
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.ssl_mode(SslMode::Disable);
|
||||
.ssl_mode(postgres_client::config::SslMode::Disable);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
|
||||
@@ -6,8 +6,8 @@ use std::time::Duration;
|
||||
use ::http::header::AUTHORIZATION;
|
||||
use ::http::HeaderName;
|
||||
use futures::TryFutureExt;
|
||||
use postgres_client::config::SslMode;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::super::messages::{ControlPlaneErrorMessage, GetRoleSecret, WakeCompute};
|
||||
|
||||
@@ -84,7 +84,7 @@ pub(crate) trait ReportableError: fmt::Display + Send + 'static {
|
||||
fn get_error_kind(&self) -> ErrorKind;
|
||||
}
|
||||
|
||||
impl ReportableError for tokio_postgres::error::Error {
|
||||
impl ReportableError for postgres_client::error::Error {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
if self.as_db_error().is_some() {
|
||||
ErrorKind::Postgres
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use rustls::pki_types::ServerName;
|
||||
use rustls::ClientConfig;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
|
||||
mod private {
|
||||
use std::future::Future;
|
||||
@@ -12,9 +12,9 @@ mod private {
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use postgres_client::tls::{ChannelBinding, TlsConnect};
|
||||
use rustls::pki_types::ServerName;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
@@ -59,7 +59,7 @@ mod private {
|
||||
|
||||
pub struct RustlsStream<S>(TlsStream<S>);
|
||||
|
||||
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
|
||||
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
|
||||
@@ -31,9 +31,9 @@ impl CouldRetry for io::Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for tokio_postgres::error::DbError {
|
||||
impl CouldRetry for postgres_client::error::DbError {
|
||||
fn could_retry(&self) -> bool {
|
||||
use tokio_postgres::error::SqlState;
|
||||
use postgres_client::error::SqlState;
|
||||
matches!(
|
||||
self.code(),
|
||||
&SqlState::CONNECTION_FAILURE
|
||||
@@ -43,9 +43,9 @@ impl CouldRetry for tokio_postgres::error::DbError {
|
||||
)
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for tokio_postgres::error::DbError {
|
||||
impl ShouldRetryWakeCompute for postgres_client::error::DbError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
use tokio_postgres::error::SqlState;
|
||||
use postgres_client::error::SqlState;
|
||||
// Here are errors that happens after the user successfully authenticated to the database.
|
||||
// TODO: there are pgbouncer errors that should be retried, but they are not listed here.
|
||||
!matches!(
|
||||
@@ -61,21 +61,21 @@ impl ShouldRetryWakeCompute for tokio_postgres::error::DbError {
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for tokio_postgres::Error {
|
||||
impl CouldRetry for postgres_client::Error {
|
||||
fn could_retry(&self) -> bool {
|
||||
if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
io::Error::could_retry(io_err)
|
||||
} else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
tokio_postgres::error::DbError::could_retry(db_err)
|
||||
postgres_client::error::DbError::could_retry(db_err)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for tokio_postgres::Error {
|
||||
impl ShouldRetryWakeCompute for postgres_client::Error {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
tokio_postgres::error::DbError::should_retry_wake_compute(db_err)
|
||||
postgres_client::error::DbError::should_retry_wake_compute(db_err)
|
||||
} else {
|
||||
// likely an IO error. Possible the compute has shutdown and the
|
||||
// cache is stale.
|
||||
|
||||
@@ -8,9 +8,9 @@ use std::fmt::Debug;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use postgres_client::tls::TlsConnect;
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
use tokio_postgres::tls::TlsConnect;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use super::*;
|
||||
@@ -158,8 +158,8 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
let _client_err = postgres_client::Config::new()
|
||||
.channel_binding(postgres_client::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
@@ -175,7 +175,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
postgres_client::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -185,7 +185,7 @@ async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
|
||||
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
postgres_client::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -195,7 +195,7 @@ async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
postgres_client::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -205,7 +205,7 @@ async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Resul
|
||||
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
postgres_client::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -215,7 +215,7 @@ async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
|
||||
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
postgres_client::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -225,14 +225,14 @@ async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
postgres_client::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_failure(
|
||||
intercept: Intercept,
|
||||
channel_binding: tokio_postgres::config::ChannelBinding,
|
||||
channel_binding: postgres_client::config::ChannelBinding,
|
||||
) -> anyhow::Result<()> {
|
||||
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
@@ -241,7 +241,7 @@ async fn connect_failure(
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
let _client_err = postgres_client::Config::new()
|
||||
.channel_binding(channel_binding)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
|
||||
@@ -7,13 +7,13 @@ use std::time::Duration;
|
||||
use anyhow::{bail, Context};
|
||||
use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use postgres_client::config::SslMode;
|
||||
use postgres_client::tls::{MakeTlsConnect, NoTls};
|
||||
use retry::{retry_after, ShouldRetryWakeCompute};
|
||||
use rstest::rstest;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::CouldRetry;
|
||||
@@ -204,7 +204,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
||||
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
let client_err = postgres_client::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Disable)
|
||||
@@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let _conn = tokio_postgres::Config::new()
|
||||
let _conn = postgres_client::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Require)
|
||||
@@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> {
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
|
||||
|
||||
let _conn = tokio_postgres::Config::new()
|
||||
let _conn = postgres_client::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.options("project=generic-project-name")
|
||||
@@ -296,8 +296,8 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
Scram::new(password).await?,
|
||||
));
|
||||
|
||||
let _conn = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
|
||||
let _conn = postgres_client::Config::new()
|
||||
.channel_binding(postgres_client::config::ChannelBinding::Require)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
@@ -320,8 +320,8 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let _conn = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
let _conn = postgres_client::Config::new()
|
||||
.channel_binding(postgres_client::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
@@ -348,7 +348,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
let _client_err = postgres_client::Config::new()
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(&password) // no password will match the mocked secret
|
||||
|
||||
@@ -37,9 +37,9 @@ use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
|
||||
pub(crate) pool:
|
||||
Arc<GlobalConnPool<tokio_postgres::Client, EndpointConnPool<tokio_postgres::Client>>>,
|
||||
Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
|
||||
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
@@ -170,7 +170,7 @@ impl PoolingBackend {
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentials,
|
||||
force_new: bool,
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
) -> Result<Client<postgres_client::Client>, HttpConnError> {
|
||||
let maybe_client = if force_new {
|
||||
debug!("pool: pool is disabled");
|
||||
None
|
||||
@@ -256,7 +256,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
) -> Result<Client<postgres_client::Client>, HttpConnError> {
|
||||
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
|
||||
return Ok(client);
|
||||
}
|
||||
@@ -315,7 +315,7 @@ impl PoolingBackend {
|
||||
));
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
|
||||
let (client, connection) = config.connect(postgres_client::NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
let pid = client.get_process_id();
|
||||
@@ -360,7 +360,7 @@ pub(crate) enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||
#[error("could not connection to postgres in compute")]
|
||||
PostgresConnectionError(#[from] tokio_postgres::Error),
|
||||
PostgresConnectionError(#[from] postgres_client::Error),
|
||||
#[error("could not connection to local-proxy in compute")]
|
||||
LocalProxyConnectionError(#[from] LocalProxyConnError),
|
||||
#[error("could not parse JWT payload")]
|
||||
@@ -479,7 +479,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<tokio_postgres::Client, EndpointConnPool<tokio_postgres::Client>>>,
|
||||
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
@@ -489,7 +489,7 @@ struct TokioMechanism {
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TokioMechanism {
|
||||
type Connection = Client<tokio_postgres::Client>;
|
||||
type Connection = Client<postgres_client::Client>;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
@@ -509,7 +509,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
.connect_timeout(timeout);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let res = config.connect(tokio_postgres::NoTls).await;
|
||||
let res = config.connect(postgres_client::NoTls).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@ use std::task::{ready, Poll};
|
||||
|
||||
use futures::future::poll_fn;
|
||||
use futures::Future;
|
||||
use postgres_client::tls::NoTlsStream;
|
||||
use postgres_client::AsyncMessage;
|
||||
use smallvec::SmallVec;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::AsyncMessage;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
#[cfg(test)]
|
||||
@@ -58,7 +58,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: tokio_postgres::Connection<TcpStream, NoTlsStream>,
|
||||
mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<C> {
|
||||
|
||||
@@ -7,8 +7,8 @@ use std::time::Duration;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use postgres_client::ReadyForQueryStatus;
|
||||
use rand::Rng;
|
||||
use tokio_postgres::ReadyForQueryStatus;
|
||||
use tracing::{debug, info, Span};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
@@ -683,7 +683,7 @@ pub(crate) trait ClientInnerExt: Sync + Send + 'static {
|
||||
fn get_process_id(&self) -> i32;
|
||||
}
|
||||
|
||||
impl ClientInnerExt for tokio_postgres::Client {
|
||||
impl ClientInnerExt for postgres_client::Client {
|
||||
fn is_closed(&self) -> bool {
|
||||
self.is_closed()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use postgres_client::types::{Kind, Type};
|
||||
use postgres_client::Row;
|
||||
use serde_json::{Map, Value};
|
||||
use tokio_postgres::types::{Kind, Type};
|
||||
use tokio_postgres::Row;
|
||||
|
||||
//
|
||||
// Convert json non-string types to strings, so that they can be passed to Postgres
|
||||
@@ -61,7 +61,7 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum JsonConversionError {
|
||||
#[error("internal error compute returned invalid data: {0}")]
|
||||
AsTextError(tokio_postgres::Error),
|
||||
AsTextError(postgres_client::Error),
|
||||
#[error("parse int error: {0}")]
|
||||
ParseIntError(#[from] std::num::ParseIntError),
|
||||
#[error("parse float error: {0}")]
|
||||
|
||||
@@ -22,13 +22,13 @@ use indexmap::IndexMap;
|
||||
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
use parking_lot::RwLock;
|
||||
use postgres_client::tls::NoTlsStream;
|
||||
use postgres_client::types::ToSql;
|
||||
use postgres_client::AsyncMessage;
|
||||
use serde_json::value::RawValue;
|
||||
use signature::Signer;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tokio_postgres::AsyncMessage;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
|
||||
@@ -164,7 +164,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: tokio_postgres::Connection<TcpStream, NoTlsStream>,
|
||||
mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
|
||||
key: SigningKey,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
@@ -280,7 +280,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
)
|
||||
}
|
||||
|
||||
impl ClientInnerCommon<tokio_postgres::Client> {
|
||||
impl ClientInnerCommon<postgres_client::Client> {
|
||||
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
|
||||
if let ClientDataEnum::Local(local_data) = &mut self.data {
|
||||
local_data.jti += 1;
|
||||
|
||||
@@ -11,12 +11,12 @@ use http_body_util::{BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use hyper::{header, HeaderMap, Request, Response, StatusCode};
|
||||
use postgres_client::error::{DbError, ErrorPosition, SqlState};
|
||||
use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tokio::time::{self, Instant};
|
||||
use tokio_postgres::error::{DbError, ErrorPosition, SqlState};
|
||||
use tokio_postgres::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info};
|
||||
use typed_json::json;
|
||||
@@ -361,7 +361,7 @@ pub(crate) enum SqlOverHttpError {
|
||||
#[error("invalid isolation level")]
|
||||
InvalidIsolationLevel,
|
||||
#[error("{0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
#[error("{0}")]
|
||||
JsonConversion(#[from] JsonConversionError),
|
||||
#[error("{0}")]
|
||||
@@ -986,7 +986,7 @@ async fn query_to_json<T: GenericClient>(
|
||||
// Manually drain the stream into a vector to leave row_stream hanging
|
||||
// around to get a command tag. Also check that the response is not too
|
||||
// big.
|
||||
let mut rows: Vec<tokio_postgres::Row> = Vec::new();
|
||||
let mut rows: Vec<postgres_client::Row> = Vec::new();
|
||||
while let Some(row) = row_stream.next().await {
|
||||
let row = row?;
|
||||
*current_size += row.body_len();
|
||||
@@ -1063,13 +1063,13 @@ async fn query_to_json<T: GenericClient>(
|
||||
}
|
||||
|
||||
enum Client {
|
||||
Remote(conn_pool_lib::Client<tokio_postgres::Client>),
|
||||
Local(conn_pool_lib::Client<tokio_postgres::Client>),
|
||||
Remote(conn_pool_lib::Client<postgres_client::Client>),
|
||||
Local(conn_pool_lib::Client<postgres_client::Client>),
|
||||
}
|
||||
|
||||
enum Discard<'a> {
|
||||
Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
|
||||
Local(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
|
||||
Remote(conn_pool_lib::Discard<'a, postgres_client::Client>),
|
||||
Local(conn_pool_lib::Discard<'a, postgres_client::Client>),
|
||||
}
|
||||
|
||||
impl Client {
|
||||
@@ -1080,7 +1080,7 @@ impl Client {
|
||||
}
|
||||
}
|
||||
|
||||
fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
|
||||
fn inner(&mut self) -> (&mut postgres_client::Client, Discard<'_>) {
|
||||
match self {
|
||||
Client::Remote(client) => {
|
||||
let (c, d) = client.inner();
|
||||
|
||||
Reference in New Issue
Block a user