diff --git a/Cargo.lock b/Cargo.lock index a647568f28..6695b3cac9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1249,6 +1249,7 @@ dependencies = [ "clap", "comfy-table", "compute_api", + "futures", "git-version", "hex", "hyper", @@ -1269,6 +1270,8 @@ dependencies = [ "tar", "thiserror", "tokio", + "tokio-postgres", + "tokio-util", "toml", "tracing", "url", diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index fb36ee8aa1..d1487d0c53 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -9,6 +9,7 @@ anyhow.workspace = true camino.workspace = true clap.workspace = true comfy-table.workspace = true +futures.workspace = true git-version.workspace = true nix.workspace = true once_cell.workspace = true @@ -24,6 +25,8 @@ tar.workspace = true thiserror.workspace = true toml.workspace = true tokio.workspace = true +tokio-postgres.workspace = true +tokio-util.workspace = true url.workspace = true pageserver = { path = "../pageserver" } pageserver_api.workspace = true diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index f639c6cdc6..5dd6370dac 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -560,7 +560,7 @@ async fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::Local let mut cplane = ComputeControlPlane::load(env.clone())?; println!("Importing timeline into pageserver ..."); - pageserver.timeline_import(tenant_id, timeline_id, base, pg_wal, pg_version)?; + pageserver.timeline_import(tenant_id, timeline_id, base, pg_wal, pg_version).await?; env.register_branch_mapping(name.to_string(), tenant_id, timeline_id)?; println!("Creating endpoint for imported timeline ..."); diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index f230973cd0..d812ff12e3 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -16,6 +16,7 @@ use std::time::Duration; use anyhow::{bail, Context}; use camino::Utf8PathBuf; +use futures::{SinkExt, StreamExt}; use pageserver::client::mgmt_api; use pageserver_api::models::{self, LocationConfig, TenantInfo, TimelineInfo}; use pageserver_api::shard::TenantShardId; @@ -282,7 +283,12 @@ impl PageServerNode { background_process::stop_process(immediate, "pageserver", &self.pid_file()) } - pub fn page_server_psql_client(&self) -> anyhow::Result { + pub async fn page_server_psql_client( + &self, + ) -> anyhow::Result<( + tokio_postgres::Client, + tokio_postgres::Connection, + )> { let mut config = self.pg_connection_config.clone(); if self.conf.pg_auth_type == AuthType::NeonJWT { let token = self @@ -290,7 +296,7 @@ impl PageServerNode { .generate_auth_token(&Claims::new(None, Scope::PageServerApi))?; config = config.set_password(Some(token)); } - Ok(config.connect_no_tls()?) + Ok(config.connect_no_tls().await?) } pub async fn check_status(&self) -> mgmt_api::Result<()> { @@ -514,7 +520,7 @@ impl PageServerNode { /// * `timeline_id` - id to assign to imported timeline /// * `base` - (start lsn of basebackup, path to `base.tar` file) /// * `pg_wal` - if there's any wal to import: (end lsn, path to `pg_wal.tar`) - pub fn timeline_import( + pub async fn timeline_import( &self, tenant_id: TenantId, timeline_id: TimelineId, @@ -522,17 +528,25 @@ impl PageServerNode { pg_wal: Option<(Lsn, PathBuf)>, pg_version: u32, ) -> anyhow::Result<()> { - let mut client = self.page_server_psql_client()?; + let (client, conn) = self.page_server_psql_client().await?; + // The connection object performs the actual communication with the database, + // so spawn it off to run on its own. + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("connection error: {}", e); + } + }); + tokio::pin!(client); // Init base reader let (start_lsn, base_tarfile_path) = base; - let base_tarfile = File::open(base_tarfile_path)?; - let mut base_reader = BufReader::new(base_tarfile); + let base_tarfile = tokio::fs::File::open(base_tarfile_path).await?; + let mut base_tarfile = tokio_util::io::ReaderStream::new(base_tarfile); // Init wal reader if necessary let (end_lsn, wal_reader) = if let Some((end_lsn, wal_tarfile_path)) = pg_wal { - let wal_tarfile = File::open(wal_tarfile_path)?; - let wal_reader = BufReader::new(wal_tarfile); + let wal_tarfile = tokio::fs::File::open(wal_tarfile_path).await?; + let wal_reader = tokio_util::io::ReaderStream::new(wal_tarfile); (end_lsn, Some(wal_reader)) } else { (start_lsn, None) @@ -542,16 +556,25 @@ impl PageServerNode { let import_cmd = format!( "import basebackup {tenant_id} {timeline_id} {start_lsn} {end_lsn} {pg_version}" ); - let mut writer = client.copy_in(&import_cmd)?; - io::copy(&mut base_reader, &mut writer)?; - writer.finish()?; + let writer = client.copy_in(&import_cmd).await?; + let mut writer = std::pin::pin!(writer); + let mut writer = + writer.sink_map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{e}"))); + let mut base_tarfile = std::pin::pin!(base_tarfile); + writer.send_all(&mut base_tarfile).await?; + writer.into_inner().finish().await?; // Import wal if necessary if let Some(mut wal_reader) = wal_reader { let import_cmd = format!("import wal {tenant_id} {timeline_id} {start_lsn} {end_lsn}"); - let mut writer = client.copy_in(&import_cmd)?; - io::copy(&mut wal_reader, &mut writer)?; - writer.finish()?; + + let writer = client.copy_in(&import_cmd).await?; + let mut writer = std::pin::pin!(writer); + let mut writer = writer + .sink_map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{e}"))); + let mut wal_reader = std::pin::pin!(wal_reader); + writer.send_all(&mut wal_reader).await?; + writer.into_inner().finish().await?; } Ok(()) diff --git a/libs/postgres_connection/src/lib.rs b/libs/postgres_connection/src/lib.rs index 35cb1a2691..d793abc8d0 100644 --- a/libs/postgres_connection/src/lib.rs +++ b/libs/postgres_connection/src/lib.rs @@ -4,6 +4,7 @@ use anyhow::{bail, Context}; use itertools::Itertools; use std::borrow::Cow; use std::fmt; +use tokio_postgres::tls::NoTlsStream; use url::Host; /// Parses a string of format either `host:port` or `host` into a corresponding pair. @@ -163,8 +164,18 @@ impl PgConnectionConfig { } /// Connect using postgres protocol with TLS disabled. - pub fn connect_no_tls(&self) -> Result { - postgres::Config::from(self.to_tokio_postgres_config()).connect(postgres::NoTls) + pub async fn connect_no_tls( + &self, + ) -> Result< + ( + tokio_postgres::Client, + tokio_postgres::Connection, + ), + postgres::Error, + > { + self.to_tokio_postgres_config() + .connect(postgres::NoTls) + .await } }