From e4570fb31f341186aa122ef87d984f9129ecf673 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 24 Apr 2024 11:43:05 +0100 Subject: [PATCH] fmt --- proxy/src/compute.rs | 22 +++++++--------------- proxy/src/dns.rs | 31 ++++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 6e14982844..ba56120ae2 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -11,12 +11,7 @@ use crate::{ use futures::TryFutureExt; use itertools::Itertools; use pq_proto::StartupMessageParams; -use std::{ - io, - net::{IpAddr, SocketAddr}, - str::FromStr, - time::Duration, -}; +use std::{io, net::SocketAddr, time::Duration}; use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::{ @@ -187,14 +182,10 @@ impl ConnCfg { // wrap TcpStream::connect with timeout let connect_with_timeout = |host, port| async move { - let addrs = if let Ok(ip) = IpAddr::from_str(host) { - vec![ip] - } else { - dns.resolve(host) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))? - .collect() - }; + let addrs = dns + .resolve(host) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; let timeout = timeout / addrs.len() as u32; @@ -318,7 +309,8 @@ impl ConnCfg { client.set_socket_config(SocketConfig { host: tokio_postgres::config::Host::Tcp(host.to_owned()), - socket_addr, + port: socket_addr.port(), + socket_addr: tokio_postgres::SocketAddr::Tcp(socket_addr), connect_timeout: None, keepalive: None, }); diff --git a/proxy/src/dns.rs b/proxy/src/dns.rs index 2cf4027a61..ffe3bf9b95 100644 --- a/proxy/src/dns.rs +++ b/proxy/src/dns.rs @@ -1,11 +1,11 @@ //! Async dns resolvers use std::{ - net::{IpAddr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, }; -use hickory_resolver::error::ResolveError; +use hickory_resolver::{error::ResolveError, proto::rr::RData}; use tokio::time::Instant; use tracing::trace; @@ -30,15 +30,34 @@ impl Dns { Self { resolver } } - pub async fn resolve(&self, name: &str) -> Result, ResolveError> { + pub async fn resolve(&self, name: &str) -> Result, ResolveError> { let start = Instant::now(); + // try to parse the host as a regular IP address first + if let Ok(addr) = name.parse::() { + return Ok(vec![IpAddr::V4(addr)]); + } + + if let Ok(addr) = name.parse::() { + return Ok(vec![IpAddr::V6(addr)]); + } + let res = self.resolver.lookup_ip(name).await; let resolve_duration = start.elapsed(); trace!(duration = ?resolve_duration, addr = %name, "resolve host complete"); - Ok(res?.into_iter()) + Ok(res? + .as_lookup() + .records() + .iter() + .filter_map(|r| r.data()) + .filter_map(|rdata| match rdata { + RData::A(ip) => Some(IpAddr::from(ip.0)), + RData::AAAA(ip) => Some(IpAddr::from(ip.0)), + _ => None, + }) + .collect()) } } @@ -47,7 +66,9 @@ impl reqwest::dns::Resolve for Dns { let this = self.clone(); Box::pin(async move { match this.resolve(name.as_str()).await { - Ok(iter) => Ok(Box::new(iter.map(|ip| SocketAddr::new(ip, 0))) as Box<_>), + Ok(iter) => { + Ok(Box::new(iter.into_iter().map(|ip| SocketAddr::new(ip, 0))) as Box<_>) + } Err(e) => Err(e.into()), } })