use crate::client::SocketConfig; use crate::codec::BackendMessage; use crate::config::{Host, TargetSessionAttrs}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; use postgres_protocol2::message::backend::Message; use std::io; use std::task::Poll; use tokio::net::TcpStream; use tokio::sync::mpsc; pub async fn connect( mut tls: T, config: &Config, ) -> Result<(Client, Connection), Error> where T: MakeTlsConnect, { let hostname = match &config.host { Host::Tcp(host) => host.as_str(), }; let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; match connect_once(&config.host, config.port, tls, config).await { Ok((client, connection)) => Ok((client, connection)), Err(e) => Err(e), } } async fn connect_once( host: &Host, port: u16, tls: T, config: &Config, ) -> Result<(Client, Connection), Error> where T: TlsConnect, { let socket = connect_socket(host, port, config.connect_timeout).await?; let RawConnection { stream, parameters, delayed_notice, process_id, secret_key, } = connect_raw(socket, tls, config).await?; let socket_config = SocketConfig { host: host.clone(), port, connect_timeout: config.connect_timeout, }; let (sender, receiver) = mpsc::unbounded_channel(); let client = Client::new( sender, socket_config, config.ssl_mode, process_id, secret_key, ); // delayed notices are always sent as "Async" messages. let delayed = delayed_notice .into_iter() .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) .collect(); let mut connection = Connection::new(stream, delayed, parameters, receiver); if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { let rows = client.simple_query_raw("SHOW transaction_read_only"); pin_mut!(rows); let rows = future::poll_fn(|cx| { if connection.poll_unpin(cx)?.is_ready() { return Poll::Ready(Err(Error::closed())); } rows.as_mut().poll(cx) }) .await?; pin_mut!(rows); loop { let next = future::poll_fn(|cx| { if connection.poll_unpin(cx)?.is_ready() { return Poll::Ready(Some(Err(Error::closed()))); } rows.as_mut().poll_next(cx) }); match next.await.transpose()? { Some(SimpleQueryMessage::Row(row)) => { if row.try_get(0)? == Some("on") { return Err(Error::connect(io::Error::new( io::ErrorKind::PermissionDenied, "database does not allow writes", ))); } else { break; } } Some(_) => {} None => return Err(Error::unexpected_message()), } } } Ok((client, connection)) }