Compare commits

..

10 Commits

Author SHA1 Message Date
Dmitry Ivanov
2e4bf7cee4 [proxy] Immediately log all compute node connection errors. 2023-03-14 01:45:57 +03:00
Dmitry Ivanov
15ed6af5f2 Add descriptions to proxy's python tests. 2023-03-14 01:32:37 +03:00
Dmitry Rodionov
50476a7cc7 test: update to match current interfaces 2023-03-13 17:50:10 +02:00
andres
d7ab69f303 add test for getting branchpoints from an inactive timeline 2023-03-13 17:50:10 +02:00
sharnoff
582620274a Enable file cache handling by vm-informant (#3794)
Enables the VM informant's file cache integration.

See also: https://github.com/neondatabase/autoscaling/pull/47
2023-03-13 07:16:39 -07:00
Vadim Kharitonov
daeaa767c4 Add neondatabase/release team as a default reviewers for storage
releases
2023-03-13 13:40:15 +01:00
Konstantin Knizhnik
f0573f5991 Remove block cursor cache (#3740)
## Describe your changes

Do not pin current block in BlockCursor

## Issue ticket number and link

See #3712 

There are places (see get_reconstruct_data) in our code when thread is
holding read layers lock and then try to read file and so lock page
cache slot. So we have edge in dependency graph layers->page cache slot.
At the same time (as Christian noticed) we can lock page cache slot in
BlockCursor and then try obtain shard lock on layers.
So there is backward edge in dependency graph page cache slot>layers
which forms loop and may cause deadlock.

There are three possible fixes of the problem:
1. Perform compaction under `layers` shared lock. See PR #3732. It fixes
the problem but make it not possible to append any data to pageserver
until compaction is completed.
2. Do not hold `layers` lock while accessing layers (not sure if it is
possible to do because it definitely introduce some new race
conditions).
3. Do not pin current pages in BockCursor (this PR).

My experiments shows that this cache in BlockCursor is not so useful:
the number of hits/misses for cursor cache on pgbench workload (-i -s
10/-c 10 -T 100/-c 10 -S -T 100):
```
hits: 163011
misses: 1023602
```
So number of cache misses is 10x times  larger.
And results for read-only pgbench are mostly the same:
```
with cache: 14581
w/out cache: 14429
```
 
## Checklist before requesting a review
- [ ] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.
2023-03-13 14:07:35 +02:00
Nikita Kalyanov
07dcf679de set content type explicitly (#3799)
I moved management API v2 to ogen and the generated code seems to be
more strict about content type. Let's set it properly as it is json
after all

## Describe your changes

## Issue ticket number and link

## Checklist before requesting a review
- [x] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.
2023-03-13 14:00:01 +02:00
Heikki Linnakangas
e0ee138a8b Add a test for tokio-postgres client to the driver test suite.
It is fully supported. To enable TLS, though, it requires some extra
glue code, and a dependency to a TLS library.
2023-03-13 12:14:37 +02:00
Arthur Petukhovsky
d9a1329834 Make postgres_backend use generic IO type (#3789)
- Support measuring inbound and outbound traffic in MeasuredStream
- Start using MeasuredStream in safekeepers code
2023-03-13 12:18:10 +03:00
34 changed files with 1449 additions and 258 deletions

View File

@@ -118,7 +118,7 @@
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
tags:
- pageserver
@@ -188,6 +188,6 @@
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
tags:
- safekeeper

View File

@@ -26,7 +26,7 @@ EOF
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/${INSTANCE_ID} -o /dev/null; then
# not registered, so register it now
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
# init pageserver
sudo -u pageserver /usr/local/bin/pageserver -c "id=${ID}" -c "pg_distrib_dir='/usr/local'" --init -D /storage/pageserver/data

View File

@@ -25,7 +25,7 @@ EOF
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/${INSTANCE_ID} -o /dev/null; then
# not registered, so register it now
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
# init safekeeper
sudo -u safekeeper /usr/local/bin/safekeeper --id ${ID} --init -D /storage/safekeeper/data
fi

View File

@@ -31,3 +31,4 @@ jobs:
head: releases/${{ steps.date.outputs.date }}
base: release
title: Release ${{ steps.date.outputs.date }}
team_reviewers: release

1
Cargo.lock generated
View File

@@ -4532,6 +4532,7 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"pin-project-lite",
"rand",
"routerify",
"sentry",

View File

@@ -1,7 +1,7 @@
# Note: this file *mostly* just builds on Dockerfile.compute-node
ARG SRC_IMAGE
ARG VM_INFORMANT_VERSION=v0.1.6
ARG VM_INFORMANT_VERSION=v0.1.14
# Pull VM informant and set up inittab
FROM neondatabase/vm-informant:$VM_INFORMANT_VERSION as informant
@@ -11,7 +11,9 @@ RUN set -e \
&& touch /etc/inittab
RUN set -e \
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart'" >> /etc/inittab
&& CONNSTR="dbname=neondb user=cloud_admin sslmode=disable" \
&& ARGS="--auto-restart --pgconnstr=\"$CONNSTR\"" \
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant $ARGS'" >> /etc/inittab
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base

View File

@@ -59,14 +59,14 @@ pub fn is_expected_io_error(e: &io::Error) -> bool {
}
#[async_trait::async_trait]
pub trait Handler {
pub trait Handler<IO> {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
query_string: &str,
) -> Result<(), QueryError>;
@@ -77,7 +77,7 @@ pub trait Handler {
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
@@ -86,7 +86,7 @@ pub trait Handler {
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
@@ -115,12 +115,12 @@ pub enum ProcessMsgResult {
}
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream {
Unencrypted(tokio::net::TcpStream),
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
pub enum MaybeTlsStream<IO> {
Unencrypted(IO),
Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
impl AsyncWrite for MaybeTlsStream {
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -147,7 +147,7 @@ impl AsyncWrite for MaybeTlsStream {
}
}
}
impl AsyncRead for MaybeTlsStream {
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -192,13 +192,13 @@ impl fmt::Display for AuthType {
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of splitted handles, but that would force to to pay splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly {
Full(Framed<MaybeTlsStream>),
WriteOnly(FramedWriter<MaybeTlsStream>),
enum MaybeWriteOnly<IO> {
Full(Framed<MaybeTlsStream<IO>>),
WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
Broken, // temporary value palmed off during the split
}
impl MaybeWriteOnly {
impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
@@ -244,8 +244,8 @@ impl MaybeWriteOnly {
}
}
pub struct PostgresBackend {
framed: MaybeWriteOnly,
pub struct PostgresBackend<IO> {
framed: MaybeWriteOnly<IO>,
pub state: ProtoState,
@@ -255,6 +255,8 @@ pub struct PostgresBackend {
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
@@ -271,7 +273,7 @@ fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
impl PostgresBackend<tokio::net::TcpStream> {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
@@ -288,6 +290,25 @@ impl PostgresBackend {
peer_addr,
})
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
pub fn new_from_io(
socket: IO,
peer_addr: SocketAddr,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
@@ -346,14 +367,14 @@ impl PostgresBackend {
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
CopyDataWriter { pgb: self }
}
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
@@ -369,7 +390,7 @@ impl PostgresBackend {
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
@@ -426,9 +447,9 @@ impl PostgresBackend {
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream,
src: MaybeTlsStream<IO>,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream> {
) -> anyhow::Result<MaybeTlsStream<IO>> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
@@ -466,7 +487,7 @@ impl PostgresBackend {
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
@@ -482,7 +503,7 @@ impl PostgresBackend {
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
@@ -499,7 +520,7 @@ impl PostgresBackend {
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
@@ -565,7 +586,7 @@ impl PostgresBackend {
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
@@ -629,7 +650,7 @@ impl PostgresBackend {
async fn process_message(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
@@ -776,9 +797,9 @@ impl PostgresBackend {
}
}
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
pub struct PostgresBackendReader<IO>(FramedReader<MaybeTlsStream<IO>>);
impl PostgresBackendReader {
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
@@ -812,11 +833,11 @@ impl PostgresBackendReader {
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
pub struct CopyDataWriter<'a, IO> {
pgb: &'a mut PostgresBackend<IO>,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,

View File

@@ -4,6 +4,7 @@ use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
@@ -22,11 +23,11 @@ async fn make_tcp_pair() -> (TcpStream, TcpStream) {
struct TestHandler {}
#[async_trait::async_trait]
impl Handler for TestHandler {
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
// return single col 'hey' for any query
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
_query_string: &str,
) -> Result<(), QueryError> {
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(

View File

@@ -18,6 +18,7 @@ futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
routerify.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@@ -49,6 +49,8 @@ pub mod fs_ext;
pub mod history_buffer;
pub mod measured_stream;
/// use with fail::cfg("$name", "return(2000)")
#[macro_export]
macro_rules! failpoint_sleep_millis_async {

View File

@@ -0,0 +1,77 @@
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pin_project! {
/// This stream tracks all writes and calls user provided
/// callback when the underlying stream is flushed.
pub struct MeasuredStream<S, R, W> {
#[pin]
stream: S,
write_count: usize,
inc_read_count: R,
inc_write_count: W,
}
}
impl<S, R, W> MeasuredStream<S, R, W> {
pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_read_count,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
let filled = buf.filled().len();
this.stream.poll_read(context, buf).map_ok(|()| {
let cnt = buf.filled().len() - filled;
// Increment the read count.
(this.inc_read_count)(cnt);
})
}
}
impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let this = self.project();
this.stream.poll_write(context, buf).map_ok(|cnt| {
// Increment the write count.
*this.write_count += cnt;
cnt
})
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
this.stream.poll_flush(context).map_ok(|()| {
// Call the user provided callback and reset the write count.
(this.inc_write_count)(*this.write_count);
*this.write_count = 0;
})
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_shutdown(context)
}
}

View File

@@ -20,6 +20,7 @@ use pageserver_api::models::{
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
PagestreamNblocksRequest, PagestreamNblocksResponse,
};
use postgres_backend::PostgresBackendTCP;
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
use pq_proto::framed::ConnectionError;
use pq_proto::FeStartupPacket;
@@ -54,7 +55,7 @@ use crate::trace::Tracer;
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
use postgres_ffi::BLCKSZ;
fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Bytes>> + '_ {
fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<Bytes>> + '_ {
async_stream::try_stream! {
loop {
let msg = tokio::select! {
@@ -288,7 +289,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_pagerequests(
&self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
tenant_id: TenantId,
timeline_id: TimelineId,
ctx: RequestContext,
@@ -392,7 +393,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_basebackup(
&self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
tenant_id: TenantId,
timeline_id: TimelineId,
base_lsn: Lsn,
@@ -448,7 +449,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_wal(
&self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
tenant_id: TenantId,
timeline_id: TimelineId,
start_lsn: Lsn,
@@ -659,7 +660,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_basebackup_request(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Option<Lsn>,
@@ -723,10 +724,10 @@ impl PageServerHandler {
}
#[async_trait::async_trait]
impl postgres_backend::Handler for PageServerHandler {
impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackendTCP,
jwt_response: &[u8],
) -> Result<(), QueryError> {
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
@@ -754,7 +755,7 @@ impl postgres_backend::Handler for PageServerHandler {
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackendTCP,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
@@ -762,7 +763,7 @@ impl postgres_backend::Handler for PageServerHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
query_string: &str,
) -> Result<(), QueryError> {
let ctx = self.connection_ctx.attached_child();

View File

@@ -3176,6 +3176,44 @@ mod tests {
}
*/
#[tokio::test]
async fn test_get_branchpoints_from_an_inactive_timeline() -> anyhow::Result<()> {
let (tenant, ctx) =
TenantHarness::create("test_get_branchpoints_from_an_inactive_timeline")?
.load()
.await;
let tline = tenant
.create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx)?
.initialize(&ctx)?;
make_some_layers(tline.as_ref(), Lsn(0x20)).await?;
tenant
.branch_timeline(&tline, NEW_TIMELINE_ID, Some(Lsn(0x40)), &ctx)
.await?;
let newtline = tenant
.get_timeline(NEW_TIMELINE_ID, true)
.expect("Should have a local timeline");
make_some_layers(newtline.as_ref(), Lsn(0x60)).await?;
tline.set_state(TimelineState::Broken);
tenant
.gc_iteration(Some(TIMELINE_ID), 0x10, Duration::ZERO, &ctx)
.await?;
assert_eq!(
newtline.get(*TEST_KEY, Lsn(0x50), &ctx).await?,
TEST_IMG(&format!("foo at {}", Lsn(0x40)))
);
let branchpoints = &tline.gc_info.read().unwrap().retain_lsns;
assert_eq!(branchpoints.len(), 1);
assert_eq!(branchpoints[0], Lsn(0x40));
Ok(())
}
#[tokio::test]
async fn test_retain_data_in_parent_which_is_needed_for_child() -> anyhow::Result<()> {
let (tenant, ctx) =

View File

@@ -51,9 +51,6 @@ where
///
/// A "cursor" for efficiently reading multiple pages from a BlockReader
///
/// A cursor caches the last accessed page, allowing for faster access if the
/// same block is accessed repeatedly.
///
/// You can access the last page with `*cursor`. 'read_blk' returns 'self', so
/// that in many cases you can use a BlockCursor as a drop-in replacement for
/// the underlying BlockReader. For example:
@@ -73,8 +70,6 @@ where
R: BlockReader,
{
reader: R,
/// last accessed page
cache: Option<(u32, R::BlockLease)>,
}
impl<R> BlockCursor<R>
@@ -82,40 +77,13 @@ where
R: BlockReader,
{
pub fn new(reader: R) -> Self {
BlockCursor {
reader,
cache: None,
}
BlockCursor { reader }
}
pub fn read_blk(&mut self, blknum: u32) -> Result<&Self, std::io::Error> {
// Fast return if this is the same block as before
if let Some((cached_blk, _buf)) = &self.cache {
if *cached_blk == blknum {
return Ok(self);
}
}
// Read the block from the underlying reader, and cache it
self.cache = None;
let buf = self.reader.read_blk(blknum)?;
self.cache = Some((blknum, buf));
Ok(self)
pub fn read_blk(&mut self, blknum: u32) -> Result<R::BlockLease, std::io::Error> {
self.reader.read_blk(blknum)
}
}
impl<R> Deref for BlockCursor<R>
where
R: BlockReader,
{
type Target = [u8; PAGE_SZ];
fn deref(&self) -> &<Self as Deref>::Target {
&self.cache.as_ref().unwrap().1
}
}
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
/// An adapter for reading a (virtual) file using the page cache.

View File

@@ -2,9 +2,7 @@
//! used to keep in-memory layers spilled on disk.
use crate::config::PageServerConf;
use crate::page_cache;
use crate::page_cache::PAGE_SZ;
use crate::page_cache::{ReadBufResult, WriteBufResult};
use crate::page_cache::{self, ReadBufResult, WriteBufResult, PAGE_SZ};
use crate::tenant::blob_io::BlobWriter;
use crate::tenant::block_io::BlockReader;
use crate::virtual_file::VirtualFile;
@@ -427,7 +425,6 @@ mod tests {
let actual = cursor.read_blob(pos)?;
assert_eq!(actual, expected);
}
drop(cursor);
// Test a large blob that spans multiple pages
let mut large_data = Vec::new();

View File

@@ -25,11 +25,12 @@ impl CancelMap {
cancel_closure.try_cancel_query().await
}
/// Create a new session, with a new client-facing random cancellation key.
///
/// Use `enable_query_cancellation` to register the Postgres backend's cancellation
/// key with it.
pub fn new_session<'a>(&'a self) -> anyhow::Result<Session<'a>> {
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
@@ -43,9 +44,17 @@ impl CancelMap {
.write()
.try_insert(key, None)
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
info!("registered new query cancellation key {key}");
Ok(Session::new(key, self))
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.write().remove(&key);
info!("dropped query cancellation key {key}");
}
info!("registered new query cancellation key {key}");
let session = Session::new(key, self);
f(session).await
}
#[cfg(test)]
@@ -102,7 +111,7 @@ impl<'a> Session<'a> {
impl Session<'_> {
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map
.0
@@ -113,14 +122,6 @@ impl Session<'_> {
}
}
impl<'a> Drop for Session<'a> {
fn drop(&mut self) {
let key = &self.key;
self.cancel_map.0.write().remove(key);
info!("dropped query cancellation key {key}");
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -131,14 +132,14 @@ mod tests {
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel();
let session = CANCEL_MAP.new_session()?;
let task = tokio::spawn(async move {
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever
});
Ok(())
}));
// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;

View File

@@ -6,9 +6,9 @@ use std::{io, net::SocketAddr};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
use tracing::{error, info};
use tracing::{error, info, warn};
const COULD_NOT_CONNECT: &str = "Could not connect to compute node";
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
#[derive(Debug, Error)]
pub enum ConnectionError {
@@ -131,7 +131,7 @@ impl ConnCfg {
use tokio_postgres::config::Host;
let connect_once = |host, port| {
info!("trying to connect to a compute node at {host}:{port}");
info!("trying to connect to compute node at {host}:{port}");
TcpStream::connect((host, port)).and_then(|socket| async {
let socket_addr = socket.peer_addr()?;
// This prevents load balancer from severing the connection.
@@ -151,7 +151,7 @@ impl ConnCfg {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"couldn't connect: bad compute config, \
"bad compute config, \
ports and hosts entries' count does not match: {:?}",
self.0
),
@@ -170,7 +170,7 @@ impl ConnCfg {
Ok(socket) => return Ok(socket),
Err(err) => {
// We can't throw an error here, as there might be more hosts to try.
error!("failed to connect to a compute node at {host}:{port}: {err}");
warn!("couldn't connect to compute node at {host}:{port}: {err}");
connection_error = Some(err);
}
}
@@ -179,7 +179,7 @@ impl ConnCfg {
Err(connection_error.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!("couldn't connect: bad compute config: {:?}", self.0),
format!("bad compute config: {:?}", self.0),
)
}))
}
@@ -195,12 +195,11 @@ pub struct PostgresConnection {
}
impl ConnCfg {
/// Connect to a corresponding compute node.
pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
async fn do_connect(&self) -> Result<PostgresConnection, ConnectionError> {
// TODO: establish a secure connection to the DB.
let (socket_addr, mut stream) = self.connect_raw().await?;
let (client, connection) = self.0.connect_raw(&mut stream, NoTls).await?;
info!("connected to user's compute node at {socket_addr}");
info!("connected to compute node at {socket_addr}");
// This is very ugly but as of now there's no better way to
// extract the connection parameters from tokio-postgres' connection.
@@ -219,6 +218,16 @@ impl ConnCfg {
Ok(connection)
}
/// Connect to a corresponding compute node.
pub async fn connect(&self) -> Result<PostgresConnection, ConnectionError> {
self.do_connect()
.inspect_err(|err| {
// Immediately log the error we have at our disposal.
error!("couldn't connect to compute node: {err}");
})
.await
}
}
/// Retrieve `options` from a startup message, dropping all proxy-secific flags.

View File

@@ -4,7 +4,7 @@ use crate::{
};
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{self, AuthType, PostgresBackend, QueryError};
use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::future;
use tokio::net::{TcpListener, TcpStream};
@@ -71,10 +71,10 @@ pub type ComputeReady = Result<DatabaseInfo, String>;
// TODO: replace with an http-based protocol.
struct MgmtHandler;
#[async_trait::async_trait]
impl postgres_backend::Handler for MgmtHandler {
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
query: &str,
) -> Result<(), QueryError> {
try_process_query(pgb, query).await.map_err(|e| {
@@ -84,7 +84,7 @@ impl postgres_backend::Handler for MgmtHandler {
}
}
async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
async fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
let span = info_span!("event", session_id = resp.session_id);

View File

@@ -8,7 +8,7 @@ use crate::{
config::{ProxyConfig, TlsConfig},
console::{self, messages::MetricsAuxInfo},
error::io_error,
stream::{MeasuredStream, PqStream, Stream},
stream::{PqStream, Stream},
};
use anyhow::{bail, Context};
use futures::TryFutureExt;
@@ -18,6 +18,7 @@ use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use utils::measured_stream::MeasuredStream;
/// Number of times we should retry the `/proxy_wake_compute` http request.
const NUM_RETRIES_WAKE_COMPUTE: usize = 1;
@@ -133,14 +134,10 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(true).await
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
}
#[tracing::instrument(fields(session_id), skip_all)]
@@ -176,14 +173,10 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(false).await
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.await
}
/// Establish a (most probably, secure) connection with the client.
@@ -361,16 +354,24 @@ async fn proxy_pass(
aux: &MetricsAuxInfo,
) -> anyhow::Result<()> {
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(client, |cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
});
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(compute, |cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
});
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
@@ -389,8 +390,6 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
}
impl<'a, S> Client<'a, S> {
@@ -400,27 +399,28 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
) -> Self {
Self {
stream,
creds,
params,
session_id,
session,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(self, allow_cleartext: bool) -> anyhow::Result<()> {
async fn connect_to_db(
self,
session: cancellation::Session<'_>,
allow_cleartext: bool,
) -> anyhow::Result<()> {
let Self {
mut stream,
mut creds,
params,
session_id,
session,
} = self;
let extra = console::ConsoleReqExtra {

View File

@@ -217,68 +217,3 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
}
}
}
pin_project! {
/// This stream tracks all writes and calls user provided
/// callback when the underlying stream is flushed.
pub struct MeasuredStream<S, W> {
#[pin]
stream: S,
write_count: usize,
inc_write_count: W,
}
}
impl<S, W> MeasuredStream<S, W> {
pub fn new(stream: S, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, W> AsyncRead for MeasuredStream<S, W> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_read(context, buf)
}
}
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, W> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let this = self.project();
this.stream.poll_write(context, buf).map_ok(|cnt| {
// Increment the write count.
*this.write_count += cnt;
cnt
})
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
this.stream.poll_flush(context).map_ok(|()| {
// Call the user provided callback and reset the write count.
(this.inc_write_count)(*this.write_count);
*this.write_count = 0;
})
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_shutdown(context)
}
}

View File

@@ -3,6 +3,7 @@
use anyhow::Context;
use std::str;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span, Instrument};
use crate::auth::check_permission;
@@ -67,11 +68,13 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
}
#[async_trait::async_trait]
impl postgres_backend::Handler for SafekeeperPostgresHandler {
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
for SafekeeperPostgresHandler
{
// tenant_id and timeline_id are passed in connection string params
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
sm: &FeStartupPacket,
) -> Result<(), QueryError> {
if let FeStartupPacket::StartupMessage { params, .. } = sm {
@@ -110,7 +113,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
jwt_response: &[u8],
) -> Result<(), QueryError> {
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
@@ -139,7 +142,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
query_string: &str,
) -> Result<(), QueryError> {
if query_string
@@ -216,9 +219,9 @@ impl SafekeeperPostgresHandler {
///
/// Handle IDENTIFY_SYSTEM replication command
///
async fn handle_identify_system(
async fn handle_identify_system<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
) -> Result<(), QueryError> {
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;

View File

@@ -12,6 +12,7 @@ use anyhow::Context;
use bytes::Bytes;
use postgres_backend::QueryError;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::*;
use utils::id::TenantTimelineId;
@@ -60,9 +61,9 @@ struct AppendResult {
/// Handles command to craft logical message WAL record with given
/// content, and then append it with specified term and lsn. This
/// function is used to test safekeepers in different scenarios.
pub async fn handle_json_ctrl(
pub async fn handle_json_ctrl<IO: AsyncRead + AsyncWrite + Unpin>(
spg: &SafekeeperPostgresHandler,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
append_request: &AppendLogicalMessage,
) -> Result<(), QueryError> {
info!("JSON_CTRL request: {append_request:?}");

View File

@@ -7,7 +7,7 @@ use anyhow::Result;
use metrics::{
core::{AtomicU64, Collector, Desc, GenericGaugeVec, Opts},
proto::MetricFamily,
Gauge, IntGaugeVec,
register_int_counter_vec, Gauge, IntCounterVec, IntGaugeVec,
};
use once_cell::sync::Lazy;
use postgres_ffi::XLogSegNo;
@@ -61,6 +61,14 @@ pub static PERSIST_CONTROL_FILE_SECONDS: Lazy<Histogram> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_persist_control_file_seconds histogram vec")
});
pub static PG_IO_BYTES: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"safekeeper_pg_io_bytes",
"Bytes read from or written to any PostgreSQL connection",
&["direction"]
)
.expect("Failed to register safekeeper_pg_io_bytes gauge")
});
/// Metrics for WalStorage in a single timeline.
#[derive(Clone, Default)]

View File

@@ -20,6 +20,8 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Receiver;
@@ -36,9 +38,9 @@ impl SafekeeperPostgresHandler {
/// Wrapper around handle_start_wal_push_guts handling result. Error is
/// handled here while we're still in walreceiver ttid span; with API
/// extension, this can probably be moved into postgres_backend.
pub async fn handle_start_wal_push(
pub async fn handle_start_wal_push<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
) -> Result<(), QueryError> {
if let Err(end) = self.handle_start_wal_push_guts(pgb).await {
// Log the result and probably send it to the client, closing the stream.
@@ -47,9 +49,9 @@ impl SafekeeperPostgresHandler {
Ok(())
}
pub async fn handle_start_wal_push_guts(
pub async fn handle_start_wal_push_guts<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
) -> Result<(), CopyStreamHandlerEnd> {
// Notify the libpq client that it's allowed to send `CopyData` messages
pgb.write_message(&BeMessage::CopyBothResponse).await?;
@@ -111,17 +113,17 @@ impl SafekeeperPostgresHandler {
}
}
struct NetworkReader<'a> {
struct NetworkReader<'a, IO> {
ttid: TenantTimelineId,
conn_id: ConnectionId,
pgb_reader: &'a mut PostgresBackendReader,
pgb_reader: &'a mut PostgresBackendReader<IO>,
peer_addr: SocketAddr,
// WalAcceptor is spawned when we learn server info from walproposer and
// create timeline; handle is put here.
acceptor_handle: &'a mut Option<JoinHandle<anyhow::Result<()>>>,
}
impl<'a> NetworkReader<'a> {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> NetworkReader<'a, IO> {
async fn run(
self,
msg_tx: Sender<ProposerAcceptorMessage>,
@@ -162,16 +164,16 @@ impl<'a> NetworkReader<'a> {
/// Read next message from walproposer.
/// TODO: Return Ok(None) on graceful termination.
async fn read_message(
pgb_reader: &mut PostgresBackendReader,
async fn read_message<IO: AsyncRead + AsyncWrite + Unpin>(
pgb_reader: &mut PostgresBackendReader<IO>,
) -> Result<ProposerAcceptorMessage, CopyStreamHandlerEnd> {
let copy_data = pgb_reader.read_copy_message().await?;
let msg = ProposerAcceptorMessage::parse(copy_data)?;
Ok(msg)
}
async fn read_network_loop(
pgb_reader: &mut PostgresBackendReader,
async fn read_network_loop<IO: AsyncRead + AsyncWrite + Unpin>(
pgb_reader: &mut PostgresBackendReader<IO>,
msg_tx: Sender<ProposerAcceptorMessage>,
mut next_msg: ProposerAcceptorMessage,
) -> Result<(), CopyStreamHandlerEnd> {
@@ -186,8 +188,8 @@ async fn read_network_loop(
/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(())
/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should
/// tell the error.
async fn network_write(
pgb_writer: &mut PostgresBackend,
async fn network_write<IO: AsyncRead + AsyncWrite + Unpin>(
pgb_writer: &mut PostgresBackend<IO>,
mut reply_rx: Receiver<AcceptorProposerMessage>,
) -> Result<(), CopyStreamHandlerEnd> {
let mut buf = BytesMut::with_capacity(128);

View File

@@ -13,6 +13,8 @@ use postgres_ffi::get_current_timestamp;
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use std::cmp::min;
use std::str;
use std::sync::Arc;
@@ -74,9 +76,9 @@ impl SafekeeperPostgresHandler {
/// Wrapper around handle_start_replication_guts handling result. Error is
/// handled here while we're still in walsender ttid span; with API
/// extension, this can probably be moved into postgres_backend.
pub async fn handle_start_replication(
pub async fn handle_start_replication<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
start_pos: Lsn,
) -> Result<(), QueryError> {
if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await {
@@ -86,9 +88,9 @@ impl SafekeeperPostgresHandler {
Ok(())
}
pub async fn handle_start_replication_guts(
pub async fn handle_start_replication_guts<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
start_pos: Lsn,
) -> Result<(), CopyStreamHandlerEnd> {
let appname = self.appname.clone();
@@ -176,8 +178,8 @@ impl SafekeeperPostgresHandler {
}
/// A half driving sending WAL.
struct WalSender<'a> {
pgb: &'a mut PostgresBackend,
struct WalSender<'a, IO> {
pgb: &'a mut PostgresBackend<IO>,
tli: Arc<Timeline>,
appname: Option<String>,
// Position since which we are sending next chunk.
@@ -194,7 +196,7 @@ struct WalSender<'a> {
send_buf: [u8; MAX_SEND_SIZE],
}
impl WalSender<'_> {
impl<IO: AsyncRead + AsyncWrite + Unpin> WalSender<'_, IO> {
/// Send WAL until
/// - an error occurs
/// - if we are streaming to walproposer, we've streamed until stop_pos
@@ -282,14 +284,14 @@ impl WalSender<'_> {
}
/// A half driving receiving replies.
struct ReplyReader {
reader: PostgresBackendReader,
struct ReplyReader<IO> {
reader: PostgresBackendReader<IO>,
tli: Arc<Timeline>,
replica_id: usize,
feedback: ReplicaState,
}
impl ReplyReader {
impl<IO: AsyncRead + AsyncWrite + Unpin> ReplyReader<IO> {
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
loop {
let msg = self.reader.read_copy_message().await?;

View File

@@ -7,9 +7,10 @@ use postgres_backend::QueryError;
use std::{future, thread};
use tokio::net::TcpStream;
use tracing::*;
use utils::measured_stream::MeasuredStream;
use crate::handler::SafekeeperPostgresHandler;
use crate::SafeKeeperConf;
use crate::{handler::SafekeeperPostgresHandler, metrics::PG_IO_BYTES};
use postgres_backend::{AuthType, PostgresBackend};
/// Accept incoming TCP connections and spawn them into a background thread.
@@ -67,14 +68,29 @@ fn handle_socket(
.build()?;
let local = tokio::task::LocalSet::new();
let read_metrics = PG_IO_BYTES.with_label_values(&["read"]);
let write_metrics = PG_IO_BYTES.with_label_values(&["write"]);
socket.set_nodelay(true)?;
let peer_addr = socket.peer_addr()?;
// TODO: measure cross-az traffic
let socket = MeasuredStream::new(
socket,
|cnt| {
read_metrics.inc_by(cnt as u64);
},
|cnt| {
write_metrics.inc_by(cnt as u64);
},
);
let auth_type = match conf.auth {
None => AuthType::Trust,
Some(_) => AuthType::NeonJWT,
};
let mut conn_handler = SafekeeperPostgresHandler::new(conf, conn_id);
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?;
// libpq protocol between safekeeper and walproposer / pageserver
// We don't use shutdown.
local.block_on(

View File

@@ -0,0 +1 @@
target/

View File

@@ -0,0 +1 @@
target/

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
[package]
name = "rust-neon-example"
version = "0.1.0"
edition = "2021"
publish = false
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
native-tls = "0.2.11"
postgres-native-tls = "0.5.0"
tokio = { version = "1.26", features=["rt", "macros"] }
tokio-postgres = "0.7.7"
# This is not part of the main 'neon' workspace
[workspace]

View File

@@ -0,0 +1,6 @@
FROM rust:1.67
WORKDIR /source
COPY . .
RUN cargo build
CMD ["/source/target/debug/rust-neon-example"]

View File

@@ -0,0 +1,43 @@
use std::env::VarError;
use tokio_postgres;
fn get_env(key: &str) -> String {
match std::env::var(key) {
Ok(val) => val,
Err(VarError::NotPresent) => panic!("{key} env variable not set"),
Err(VarError::NotUnicode(_)) => panic!("{key} is not valid unicode"),
}
}
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), tokio_postgres::Error> {
let host = get_env("NEON_HOST");
let database = get_env("NEON_DATABASE");
let user = get_env("NEON_USER");
let password = get_env("NEON_PASSWORD");
let url = format!("postgresql://{user}:{password}@{host}/{database}");
// Use the native TLS implementation (Neon requires TLS)
let tls_connector =
postgres_native_tls::MakeTlsConnector::new(native_tls::TlsConnector::new().unwrap());
// Connect to the database.
let (client, connection) = tokio_postgres::connect(&url, tls_connector).await?;
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let result = client.query("SELECT 1", &[]).await?;
let value: i32 = result[0].get(0);
assert_eq!(value, 1);
println!("{value}");
Ok(())
}

View File

@@ -13,6 +13,7 @@ from fixtures.utils import subprocess_capture
[
"csharp/npgsql",
"java/jdbc",
"rust/tokio-postgres",
"python/asyncpg",
"python/pg8000",
pytest.param(

View File

@@ -4,10 +4,20 @@ from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
def test_proxy_select_1(static_proxy: NeonProxy):
static_proxy.safe_psql("select 1", options="project=generic-project-name")
"""
A simplest smoke test: check proxy against a local postgres instance.
"""
out = static_proxy.safe_psql("select 1", options="project=generic-project-name")
assert out[0][0] == 1
def test_password_hack(static_proxy: NeonProxy):
"""
Check the PasswordHack auth flow: an alternative to SCRAM auth for
clients which can't provide the project/endpoint name via SNI or `options`.
"""
user = "borat"
password = "password"
static_proxy.safe_psql(
@@ -25,7 +35,11 @@ def test_password_hack(static_proxy: NeonProxy):
@pytest.mark.asyncio
async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
"""
Check the Link auth flow: a lightweight auth method which delegates
all necessary checks to the console by sending client an auth URL.
"""
psql = await PSQL(host=link_proxy.host, port=link_proxy.proxy_port).run("select 42")
@@ -40,16 +54,27 @@ async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProx
assert out == "42"
# Pass extra options to the server.
def test_proxy_options(static_proxy: NeonProxy):
with static_proxy.connect(options="project=irrelevant -cproxytest.option=value") as conn:
with conn.cursor() as cur:
cur.execute("SHOW proxytest.option")
value = cur.fetchall()[0][0]
assert value == "value"
"""
Check that we pass extra `options` to the PostgreSQL server:
* `project=...` shouldn't be passed at all (otherwise postgres will raise an error).
* everything else should be passed as-is.
"""
options = "project=irrelevant -cproxytest.option=value"
out = static_proxy.safe_psql("show proxytest.option", options=options)
assert out[0][0] == "value"
options = "-c proxytest.foo=\\ str project=irrelevant"
out = static_proxy.safe_psql("show proxytest.foo", options=options)
assert out[0][0] == " str"
def test_auth_errors(static_proxy: NeonProxy):
"""
Check that we throw very specific errors in some unsuccessful auth scenarios.
"""
# User does not exist
with pytest.raises(psycopg2.Error) as exprinfo:
static_proxy.connect(user="pinocchio", options="project=irrelevant")
@@ -78,6 +103,10 @@ def test_auth_errors(static_proxy: NeonProxy):
def test_forward_params_to_client(static_proxy: NeonProxy):
"""
Check that we forward all necessary PostgreSQL server params to client.
"""
# A subset of parameters (GUCs) which postgres
# sends to the client during connection setup.
# Unfortunately, `GUC_REPORT` can't be queried.