mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-13 00:12:54 +00:00
99 lines
2.9 KiB
Rust
99 lines
2.9 KiB
Rust
use crate::cancellation::CancelClosure;
|
|
use crate::error::UserFacingError;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::io;
|
|
use std::net::SocketAddr;
|
|
use thiserror::Error;
|
|
use tokio::net::TcpStream;
|
|
use tokio_postgres::NoTls;
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum ConnectionError {
|
|
/// This error doesn't seem to reveal any secrets; for instance,
|
|
/// [`tokio_postgres::error::Kind`] doesn't contain ip addresses and such.
|
|
#[error("Failed to connect to the compute node: {0}")]
|
|
Postgres(#[from] tokio_postgres::Error),
|
|
|
|
#[error("Failed to connect to the compute node")]
|
|
FailedToConnectToCompute,
|
|
|
|
#[error("Failed to fetch compute node version")]
|
|
FailedToFetchPgVersion,
|
|
}
|
|
|
|
// `ConnectionError` implements `UserFacingError` with an empty body, i.e. it
// overrides none of the trait's items.
// NOTE(review): presumably this is what allows the error text to be shown to
// clients — confirm against the trait definition in `crate::error`.
impl UserFacingError for ConnectionError {}
|
|
|
|
/// Compute node connection params.
|
|
#[derive(Serialize, Deserialize, Default)]
|
|
pub struct DatabaseInfo {
|
|
pub host: String,
|
|
pub port: u16,
|
|
pub dbname: String,
|
|
pub user: String,
|
|
pub password: Option<String>,
|
|
}
|
|
|
|
// Manually implement debug to omit personal and sensitive info
|
|
impl std::fmt::Debug for DatabaseInfo {
|
|
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
fmt.debug_struct("DatabaseInfo")
|
|
.field("host", &self.host)
|
|
.field("port", &self.port)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
/// A PostgreSQL version, represented as a plain [`String`].
pub type Version = String;
|
|
|
|
impl DatabaseInfo {
|
|
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
|
|
let host_port = format!("{}:{}", self.host, self.port);
|
|
let socket = TcpStream::connect(host_port).await?;
|
|
let socket_addr = socket.peer_addr()?;
|
|
socket2::SockRef::from(&socket).set_keepalive(true)?;
|
|
|
|
Ok((socket_addr, socket))
|
|
}
|
|
|
|
/// Connect to a corresponding compute node.
|
|
pub async fn connect(self) -> Result<(TcpStream, Version, CancelClosure), ConnectionError> {
|
|
let (socket_addr, mut socket) = self
|
|
.connect_raw()
|
|
.await
|
|
.map_err(|_| ConnectionError::FailedToConnectToCompute)?;
|
|
|
|
// TODO: establish a secure connection to the DB
|
|
let (client, conn) = tokio_postgres::Config::from(self)
|
|
.connect_raw(&mut socket, NoTls)
|
|
.await?;
|
|
|
|
let version = conn
|
|
.parameter("server_version")
|
|
.ok_or(ConnectionError::FailedToFetchPgVersion)?
|
|
.into();
|
|
|
|
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
|
|
|
|
Ok((socket, version, cancel_closure))
|
|
}
|
|
}
|
|
|
|
impl From<DatabaseInfo> for tokio_postgres::Config {
|
|
fn from(db_info: DatabaseInfo) -> Self {
|
|
let mut config = tokio_postgres::Config::new();
|
|
|
|
config
|
|
.host(&db_info.host)
|
|
.port(db_info.port)
|
|
.dbname(&db_info.dbname)
|
|
.user(&db_info.user);
|
|
|
|
if let Some(password) = db_info.password {
|
|
config.password(password);
|
|
}
|
|
|
|
config
|
|
}
|
|
}
|