This commit is contained in:
Conrad Ludgate
2024-04-24 11:43:05 +01:00
parent dd7c4b79e3
commit e4570fb31f
2 changed files with 33 additions and 20 deletions

View File

@@ -11,12 +11,7 @@ use crate::{
use futures::TryFutureExt;
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use std::{
io,
net::{IpAddr, SocketAddr},
str::FromStr,
time::Duration,
};
use std::{io, net::SocketAddr, time::Duration};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::{
@@ -187,14 +182,10 @@ impl ConnCfg {
// wrap TcpStream::connect with timeout
let connect_with_timeout = |host, port| async move {
let addrs = if let Ok(ip) = IpAddr::from_str(host) {
vec![ip]
} else {
dns.resolve(host)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.collect()
};
let addrs = dns
.resolve(host)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let timeout = timeout / addrs.len() as u32;
@@ -318,7 +309,8 @@ impl ConnCfg {
client.set_socket_config(SocketConfig {
host: tokio_postgres::config::Host::Tcp(host.to_owned()),
socket_addr,
port: socket_addr.port(),
socket_addr: tokio_postgres::SocketAddr::Tcp(socket_addr),
connect_timeout: None,
keepalive: None,
});

View File

@@ -1,11 +1,11 @@
//! Async dns resolvers
use std::{
net::{IpAddr, SocketAddr},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
};
use hickory_resolver::error::ResolveError;
use hickory_resolver::{error::ResolveError, proto::rr::RData};
use tokio::time::Instant;
use tracing::trace;
@@ -30,15 +30,34 @@ impl Dns {
Self { resolver }
}
pub async fn resolve(&self, name: &str) -> Result<impl Iterator<Item = IpAddr>, ResolveError> {
pub async fn resolve(&self, name: &str) -> Result<Vec<IpAddr>, ResolveError> {
let start = Instant::now();
// try to parse the host as a regular IP address first
if let Ok(addr) = name.parse::<Ipv4Addr>() {
return Ok(vec![IpAddr::V4(addr)]);
}
if let Ok(addr) = name.parse::<Ipv6Addr>() {
return Ok(vec![IpAddr::V6(addr)]);
}
let res = self.resolver.lookup_ip(name).await;
let resolve_duration = start.elapsed();
trace!(duration = ?resolve_duration, addr = %name, "resolve host complete");
Ok(res?.into_iter())
Ok(res?
.as_lookup()
.records()
.iter()
.filter_map(|r| r.data())
.filter_map(|rdata| match rdata {
RData::A(ip) => Some(IpAddr::from(ip.0)),
RData::AAAA(ip) => Some(IpAddr::from(ip.0)),
_ => None,
})
.collect())
}
}
@@ -47,7 +66,9 @@ impl reqwest::dns::Resolve for Dns {
let this = self.clone();
Box::pin(async move {
match this.resolve(name.as_str()).await {
Ok(iter) => Ok(Box::new(iter.map(|ip| SocketAddr::new(ip, 0))) as Box<_>),
Ok(iter) => {
Ok(Box::new(iter.into_iter().map(|ip| SocketAddr::new(ip, 0))) as Box<_>)
}
Err(e) => Err(e.into()),
}
})