Compare commits

...

1 Commits

Author SHA1 Message Date
Arseny Sher
51f672b0bb Rename write_message to write_message_noflush in postgres_backend_async.rs
To make it unifrom across the project; proxy stream.rs and older
postgres_backend uses write_message_noflush.
2023-03-01 20:05:56 +04:00
2 changed files with 48 additions and 46 deletions

View File

@@ -233,7 +233,7 @@ impl PostgresBackend {
} }
/// Write message into internal output buffer. /// Write message into internal output buffer.
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?; BeMessage::write(&mut self.buf_out, message)?;
Ok(self) Ok(self)
} }
@@ -383,7 +383,7 @@ impl PostgresBackend {
FeStartupPacket::SslRequest => { FeStartupPacket::SslRequest => {
debug!("SSL requested"); debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?; self.write_message_noflush(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls { if have_tls {
self.start_tls().await?; self.start_tls().await?;
self.state = ProtoState::Encrypted; self.state = ProtoState::Encrypted;
@@ -391,11 +391,11 @@ impl PostgresBackend {
} }
FeStartupPacket::GssEncRequest => { FeStartupPacket::GssEncRequest => {
debug!("GSS requested"); debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?; self.write_message_noflush(&BeMessage::EncryptionResponse(false))?;
} }
FeStartupPacket::StartupMessage { .. } => { FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) { if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse( self.write_message_noflush(&BeMessage::ErrorResponse(
"must connect with TLS", "must connect with TLS",
None, None,
))?; ))?;
@@ -410,15 +410,17 @@ impl PostgresBackend {
match self.auth_type { match self.auth_type {
AuthType::Trust => { AuthType::Trust => {
self.write_message(&BeMessage::AuthenticationOk)? self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)? .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version // The async python driver requires a valid server_version
.write_message(&BeMessage::server_version("14.1"))? .write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?; .write_message_noflush(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established; self.state = ProtoState::Established;
} }
AuthType::NeonJWT => { AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?; self.write_message_noflush(
&BeMessage::AuthenticationCleartextPassword,
)?;
self.state = ProtoState::Authentication; self.state = ProtoState::Authentication;
} }
} }
@@ -441,7 +443,7 @@ impl PostgresBackend {
let (_, jwt_response) = m.split_last().context("protocol violation")?; let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) { if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse( self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(), &e.to_string(),
Some(e.pg_error_code()), Some(e.pg_error_code()),
))?; ))?;
@@ -449,9 +451,9 @@ impl PostgresBackend {
} }
} }
} }
self.write_message(&BeMessage::AuthenticationOk)? self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)? .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?; .write_message_noflush(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established; self.state = ProtoState::Established;
} }
@@ -486,30 +488,30 @@ impl PostgresBackend {
if let Err(e) = handler.process_query(self, query_string).await { if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e); log_query_error(query_string, &e);
let short_error = short_error(&e); let short_error = short_error(&e);
self.write_message(&BeMessage::ErrorResponse( self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error, &short_error,
Some(e.pg_error_code()), Some(e.pg_error_code()),
))?; ))?;
} }
self.write_message(&BeMessage::ReadyForQuery)?; self.write_message_noflush(&BeMessage::ReadyForQuery)?;
} }
FeMessage::Parse(m) => { FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string; *unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?; self.write_message_noflush(&BeMessage::ParseComplete)?;
} }
FeMessage::Describe(_) => { FeMessage::Describe(_) => {
self.write_message(&BeMessage::ParameterDescription)? self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?; .write_message_noflush(&BeMessage::NoData)?;
} }
FeMessage::Bind(_) => { FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?; self.write_message_noflush(&BeMessage::BindComplete)?;
} }
FeMessage::Close(_) => { FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?; self.write_message_noflush(&BeMessage::CloseComplete)?;
} }
FeMessage::Execute(_) => { FeMessage::Execute(_) => {
@@ -517,7 +519,7 @@ impl PostgresBackend {
trace!("got execute {query_string:?}"); trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await { if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e); log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse( self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(), &e.to_string(),
Some(e.pg_error_code()), Some(e.pg_error_code()),
))?; ))?;
@@ -529,7 +531,7 @@ impl PostgresBackend {
} }
FeMessage::Sync => { FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?; self.write_message_noflush(&BeMessage::ReadyForQuery)?;
} }
FeMessage::Terminate => { FeMessage::Terminate => {
@@ -579,7 +581,7 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
// XXX: if the input is large, we should split it into multiple messages. // XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that // Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32. // the length cannot exceed u32.
this.pgb.write_message(&BeMessage::CopyData(buf))?; this.pgb.write_message_noflush(&BeMessage::CopyData(buf))?;
Poll::Ready(Ok(buf.len())) Poll::Ready(Ok(buf.len()))
} }

View File

@@ -64,7 +64,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
_ = task_mgr::shutdown_watcher() => { _ = task_mgr::shutdown_watcher() => {
// We were requested to shut down. // We were requested to shut down.
let msg = format!("pageserver is shutting down"); let msg = format!("pageserver is shutting down");
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None)); let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None));
Err(QueryError::Other(anyhow::anyhow!(msg))) Err(QueryError::Other(anyhow::anyhow!(msg)))
} }
@@ -80,13 +80,13 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
FeMessage::Terminate => { FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY"; let msg = "client terminated connection with Terminate message during COPY";
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break; break;
} }
m => { m => {
let msg = format!("unexpected message {m:?}"); let msg = format!("unexpected message {m:?}");
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?; pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None))?;
Err(io::Error::new(io::ErrorKind::Other, msg))?; Err(io::Error::new(io::ErrorKind::Other, msg))?;
break; break;
} }
@@ -97,7 +97,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
Ok(None) => { Ok(None) => {
let msg = "client closed connection during COPY"; let msg = "client closed connection during COPY";
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
pgb.flush().await?; pgb.flush().await?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
} }
@@ -311,7 +311,7 @@ impl PageServerHandler {
let timeline = tenant.get_timeline(timeline_id, true)?; let timeline = tenant.get_timeline(timeline_id, true)?;
// switch client to COPYBOTH // switch client to COPYBOTH
pgb.write_message(&BeMessage::CopyBothResponse)?; pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
pgb.flush().await?; pgb.flush().await?;
let metrics = PageRequestMetrics::new(&tenant_id, &timeline_id); let metrics = PageRequestMetrics::new(&tenant_id, &timeline_id);
@@ -380,7 +380,7 @@ impl PageServerHandler {
}) })
}); });
pgb.write_message(&BeMessage::CopyData(&response.serialize()))?; pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
pgb.flush().await?; pgb.flush().await?;
} }
Ok(()) Ok(())
@@ -416,7 +416,7 @@ impl PageServerHandler {
// Import basebackup provided via CopyData // Import basebackup provided via CopyData
info!("importing basebackup"); info!("importing basebackup");
pgb.write_message(&BeMessage::CopyInResponse)?; pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.flush().await?; pgb.flush().await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb)); let mut copyin_stream = Box::pin(copyin_stream(pgb));
@@ -468,7 +468,7 @@ impl PageServerHandler {
// Import wal provided via CopyData // Import wal provided via CopyData
info!("importing wal"); info!("importing wal");
pgb.write_message(&BeMessage::CopyInResponse)?; pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.flush().await?; pgb.flush().await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb)); let mut copyin_stream = Box::pin(copyin_stream(pgb));
let mut reader = tokio_util::io::StreamReader::new(&mut copyin_stream); let mut reader = tokio_util::io::StreamReader::new(&mut copyin_stream);
@@ -678,7 +678,7 @@ impl PageServerHandler {
} }
// switch client to COPYOUT // switch client to COPYOUT
pgb.write_message(&BeMessage::CopyOutResponse)?; pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
pgb.flush().await?; pgb.flush().await?;
// Send a tarball of the latest layer on the timeline // Send a tarball of the latest layer on the timeline
@@ -695,7 +695,7 @@ impl PageServerHandler {
.await?; .await?;
} }
pgb.write_message(&BeMessage::CopyDone)?; pgb.write_message_noflush(&BeMessage::CopyDone)?;
pgb.flush().await?; pgb.flush().await?;
info!("basebackup complete"); info!("basebackup complete");
@@ -812,7 +812,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
// Check that the timeline exists // Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, None, false, ctx) self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, None, false, ctx)
.await?; .await?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} }
// return pair of prev_lsn and last_lsn // return pair of prev_lsn and last_lsn
else if query_string.starts_with("get_last_record_rlsn ") { else if query_string.starts_with("get_last_record_rlsn ") {
@@ -835,15 +835,15 @@ impl postgres_backend_async::Handler for PageServerHandler {
let end_of_timeline = timeline.get_last_record_rlsn(); let end_of_timeline = timeline.get_last_record_rlsn();
pgb.write_message(&BeMessage::RowDescription(&[ pgb.write_message_noflush(&BeMessage::RowDescription(&[
RowDescriptor::text_col(b"prev_lsn"), RowDescriptor::text_col(b"prev_lsn"),
RowDescriptor::text_col(b"last_lsn"), RowDescriptor::text_col(b"last_lsn"),
]))? ]))?
.write_message(&BeMessage::DataRow(&[ .write_message_noflush(&BeMessage::DataRow(&[
Some(end_of_timeline.prev.to_string().as_bytes()), Some(end_of_timeline.prev.to_string().as_bytes()),
Some(end_of_timeline.last.to_string().as_bytes()), Some(end_of_timeline.last.to_string().as_bytes()),
]))? ]))?
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} }
// same as basebackup, but result includes relational data as well // same as basebackup, but result includes relational data as well
else if query_string.starts_with("fullbackup ") { else if query_string.starts_with("fullbackup ") {
@@ -884,7 +884,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
// Check that the timeline exists // Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, prev_lsn, true, ctx) self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, prev_lsn, true, ctx)
.await?; .await?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("import basebackup ") { } else if query_string.starts_with("import basebackup ") {
// Import the `base` section (everything but the wal) of a basebackup. // Import the `base` section (everything but the wal) of a basebackup.
// Assumes the tenant already exists on this pageserver. // Assumes the tenant already exists on this pageserver.
@@ -929,10 +929,10 @@ impl postgres_backend_async::Handler for PageServerHandler {
) )
.await .await
{ {
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?, Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => { Err(e) => {
error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}"); error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}");
pgb.write_message(&BeMessage::ErrorResponse( pgb.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(), &e.to_string(),
Some(e.pg_error_code()), Some(e.pg_error_code()),
))? ))?
@@ -965,10 +965,10 @@ impl postgres_backend_async::Handler for PageServerHandler {
.handle_import_wal(pgb, tenant_id, timeline_id, start_lsn, end_lsn, ctx) .handle_import_wal(pgb, tenant_id, timeline_id, start_lsn, end_lsn, ctx)
.await .await
{ {
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?, Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => { Err(e) => {
error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}"); error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}");
pgb.write_message(&BeMessage::ErrorResponse( pgb.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(), &e.to_string(),
Some(e.pg_error_code()), Some(e.pg_error_code()),
))? ))?
@@ -977,7 +977,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
} else if query_string.to_ascii_lowercase().starts_with("set ") { } else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'" // important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect // on connect
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("show ") { } else if query_string.starts_with("show ") {
// show <tenant_id> // show <tenant_id>
let (_, params_raw) = query_string.split_at("show ".len()); let (_, params_raw) = query_string.split_at("show ".len());
@@ -993,7 +993,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
self.check_permission(Some(tenant_id))?; self.check_permission(Some(tenant_id))?;
let tenant = get_active_tenant_with_timeout(tenant_id, &ctx).await?; let tenant = get_active_tenant_with_timeout(tenant_id, &ctx).await?;
pgb.write_message(&BeMessage::RowDescription(&[ pgb.write_message_noflush(&BeMessage::RowDescription(&[
RowDescriptor::int8_col(b"checkpoint_distance"), RowDescriptor::int8_col(b"checkpoint_distance"),
RowDescriptor::int8_col(b"checkpoint_timeout"), RowDescriptor::int8_col(b"checkpoint_timeout"),
RowDescriptor::int8_col(b"compaction_target_size"), RowDescriptor::int8_col(b"compaction_target_size"),
@@ -1004,7 +1004,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
RowDescriptor::int8_col(b"image_creation_threshold"), RowDescriptor::int8_col(b"image_creation_threshold"),
RowDescriptor::int8_col(b"pitr_interval"), RowDescriptor::int8_col(b"pitr_interval"),
]))? ]))?
.write_message(&BeMessage::DataRow(&[ .write_message_noflush(&BeMessage::DataRow(&[
Some(tenant.get_checkpoint_distance().to_string().as_bytes()), Some(tenant.get_checkpoint_distance().to_string().as_bytes()),
Some( Some(
tenant tenant
@@ -1027,7 +1027,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
Some(tenant.get_image_creation_threshold().to_string().as_bytes()), Some(tenant.get_image_creation_threshold().to_string().as_bytes()),
Some(tenant.get_pitr_interval().as_secs().to_string().as_bytes()), Some(tenant.get_pitr_interval().as_secs().to_string().as_bytes()),
]))? ]))?
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else { } else {
return Err(QueryError::Other(anyhow::anyhow!( return Err(QueryError::Other(anyhow::anyhow!(
"unknown command {query_string}" "unknown command {query_string}"