mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-01 20:40:37 +00:00
Compare commits
34 Commits
lr-rm-file
...
problame/i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
261f116a2d | ||
|
|
79a6cda45d | ||
|
|
f0a67c4071 | ||
|
|
a1f4eb2815 | ||
|
|
eb1ccd7988 | ||
|
|
f1f0452722 | ||
|
|
40200b7521 | ||
|
|
91bd729be2 | ||
|
|
4b84f23cea | ||
|
|
644c5e243d | ||
|
|
d66ccbae5e | ||
|
|
578a2d5d5f | ||
|
|
c9d1f51a93 | ||
|
|
1339834297 | ||
|
|
746fc530c5 | ||
|
|
94311052cd | ||
|
|
2edbc07733 | ||
|
|
c8c04c0db8 | ||
|
|
4a8e7f8716 | ||
|
|
72a8e090dd | ||
|
|
e0ea465aed | ||
|
|
d3c157eeee | ||
|
|
c600355802 | ||
|
|
57241c1c5a | ||
|
|
b9c30dbd6b | ||
|
|
35f8735a27 | ||
|
|
095130c1b3 | ||
|
|
5a0277476d | ||
|
|
6348833bdc | ||
|
|
dbabd4e4ea | ||
|
|
5b8888ce6b | ||
|
|
e85a631ddb | ||
|
|
95deea4f39 | ||
|
|
9876045444 |
@@ -683,7 +683,6 @@ impl ComputeNode {
|
||||
ComputeMode::Primary => {}
|
||||
ComputeMode::Replica | ComputeMode::Static(..) => {
|
||||
add_standby_signal(pgdata_path)?;
|
||||
remove_logrep_files(pgdata_path).context("remove_logrep_files")?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
@@ -7,7 +8,6 @@ use std::path::Path;
|
||||
use std::process::Child;
|
||||
use std::thread::JoinHandle;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fs, io};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use ini::Ini;
|
||||
@@ -366,25 +366,6 @@ pub fn create_pgdata(pgdata: &str) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove contents of the given directory. It must exist.
|
||||
fn remove_dir_contents<P: AsRef<Path>>(path: P) -> io::Result<()> {
|
||||
for entry in fs::read_dir(path)? {
|
||||
fs::remove_file(entry?.path())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Logical replication slots and snapshot files are currently stored on
|
||||
/// pageserver via logical replication messages, so standby can't write them. So
|
||||
/// we remove them from the basebackup prepared for the standby. In particular
|
||||
/// this removes noise from ls_monitor failing to drop them.
|
||||
pub fn remove_logrep_files<P: AsRef<Path>>(pgdata: P) -> Result<()> {
|
||||
remove_dir_contents(pgdata.as_ref().join("pg_replslot"))?;
|
||||
remove_dir_contents(pgdata.as_ref().join("pg_logical/snapshots"))?;
|
||||
remove_dir_contents(pgdata.as_ref().join("pg_logical/mappings"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update pgbouncer.ini with provided options
|
||||
fn update_pgbouncer_ini(
|
||||
pgbouncer_config: HashMap<String, String>,
|
||||
|
||||
@@ -2222,18 +2222,7 @@ impl Service {
|
||||
|
||||
// unwrap safety: we would have returned above if we didn't find at least one shard to split
|
||||
let old_shard_count = old_shard_count.unwrap();
|
||||
let shard_ident = if let Some(new_stripe_size) = split_req.new_stripe_size {
|
||||
// This ShardIdentity will be used as the template for all children, so this implicitly
|
||||
// applies the new stripe size to the children.
|
||||
let mut shard_ident = shard_ident.unwrap();
|
||||
if shard_ident.count.count() > 1 && shard_ident.stripe_size != new_stripe_size {
|
||||
return Err(ApiError::BadRequest(anyhow::anyhow!("Attempted to change stripe size ({:?}->{new_stripe_size:?}) on a tenant with multiple shards", shard_ident.stripe_size)));
|
||||
}
|
||||
shard_ident.stripe_size = new_stripe_size;
|
||||
shard_ident
|
||||
} else {
|
||||
shard_ident.unwrap()
|
||||
};
|
||||
let shard_ident = shard_ident.unwrap();
|
||||
let policy = policy.unwrap();
|
||||
|
||||
// FIXME: we have dropped self.inner lock, and not yet written anything to the database: another
|
||||
@@ -2325,7 +2314,6 @@ impl Service {
|
||||
*parent_id,
|
||||
TenantShardSplitRequest {
|
||||
new_shard_count: split_req.new_shard_count,
|
||||
new_stripe_size: split_req.new_stripe_size,
|
||||
},
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -585,14 +585,10 @@ async fn handle_tenant(
|
||||
Some(("shard-split", matches)) => {
|
||||
let tenant_id = get_tenant_id(matches, env)?;
|
||||
let shard_count: u8 = matches.get_one::<u8>("shard-count").cloned().unwrap_or(0);
|
||||
let shard_stripe_size: Option<ShardStripeSize> = matches
|
||||
.get_one::<Option<ShardStripeSize>>("shard-stripe-size")
|
||||
.cloned()
|
||||
.unwrap();
|
||||
|
||||
let storage_controller = StorageController::from_env(env);
|
||||
let result = storage_controller
|
||||
.tenant_split(tenant_id, shard_count, shard_stripe_size)
|
||||
.tenant_split(tenant_id, shard_count)
|
||||
.await?;
|
||||
println!(
|
||||
"Split tenant {} into shards {}",
|
||||
@@ -1589,7 +1585,6 @@ fn cli() -> Command {
|
||||
.about("Increase the number of shards in the tenant")
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(Arg::new("shard-count").value_parser(value_parser!(u8)).long("shard-count").action(ArgAction::Set).help("Number of shards in the new tenant (default 1)"))
|
||||
.arg(Arg::new("shard-stripe-size").value_parser(value_parser!(u32)).long("shard-stripe-size").action(ArgAction::Set).help("Sharding stripe size in pages"))
|
||||
)
|
||||
)
|
||||
.subcommand(
|
||||
|
||||
@@ -114,7 +114,7 @@ impl NeonBroker {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
|
||||
#[serde(default)]
|
||||
#[serde(default, deny_unknown_fields)]
|
||||
pub struct PageServerConf {
|
||||
// node id
|
||||
pub id: NodeId,
|
||||
@@ -126,6 +126,9 @@ pub struct PageServerConf {
|
||||
// auth type used for the PG and HTTP ports
|
||||
pub pg_auth_type: AuthType,
|
||||
pub http_auth_type: AuthType,
|
||||
|
||||
pub(crate) virtual_file_io_engine: String,
|
||||
pub(crate) get_vectored_impl: String,
|
||||
}
|
||||
|
||||
impl Default for PageServerConf {
|
||||
@@ -136,6 +139,9 @@ impl Default for PageServerConf {
|
||||
listen_http_addr: String::new(),
|
||||
pg_auth_type: AuthType::Trust,
|
||||
http_auth_type: AuthType::Trust,
|
||||
// FIXME: use the ones exposed by pageserver crate
|
||||
virtual_file_io_engine: "tokio-epoll-uring".to_owned(),
|
||||
get_vectored_impl: "sequential".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,18 +78,31 @@ impl PageServerNode {
|
||||
///
|
||||
/// These all end up on the command line of the `pageserver` binary.
|
||||
fn neon_local_overrides(&self, cli_overrides: &[&str]) -> Vec<String> {
|
||||
let id = format!("id={}", self.conf.id);
|
||||
// FIXME: the paths should be shell-escaped to handle paths with spaces, quotas etc.
|
||||
let pg_distrib_dir_param = format!(
|
||||
"pg_distrib_dir='{}'",
|
||||
self.env.pg_distrib_dir_raw().display()
|
||||
);
|
||||
|
||||
let http_auth_type_param = format!("http_auth_type='{}'", self.conf.http_auth_type);
|
||||
let listen_http_addr_param = format!("listen_http_addr='{}'", self.conf.listen_http_addr);
|
||||
let PageServerConf {
|
||||
id,
|
||||
listen_pg_addr,
|
||||
listen_http_addr,
|
||||
pg_auth_type,
|
||||
http_auth_type,
|
||||
virtual_file_io_engine,
|
||||
get_vectored_impl,
|
||||
} = &self.conf;
|
||||
|
||||
let pg_auth_type_param = format!("pg_auth_type='{}'", self.conf.pg_auth_type);
|
||||
let listen_pg_addr_param = format!("listen_pg_addr='{}'", self.conf.listen_pg_addr);
|
||||
let id = format!("id={}", id);
|
||||
|
||||
let http_auth_type_param = format!("http_auth_type='{}'", http_auth_type);
|
||||
let listen_http_addr_param = format!("listen_http_addr='{}'", listen_http_addr);
|
||||
|
||||
let pg_auth_type_param = format!("pg_auth_type='{}'", pg_auth_type);
|
||||
let listen_pg_addr_param = format!("listen_pg_addr='{}'", listen_pg_addr);
|
||||
let virtual_file_io_engine = format!("virtual_file_io_engine='{virtual_file_io_engine}'");
|
||||
let get_vectored_impl = format!("get_vectored_impl='{get_vectored_impl}'");
|
||||
|
||||
let broker_endpoint_param = format!("broker_endpoint='{}'", self.env.broker.client_url());
|
||||
|
||||
@@ -101,6 +114,8 @@ impl PageServerNode {
|
||||
listen_http_addr_param,
|
||||
listen_pg_addr_param,
|
||||
broker_endpoint_param,
|
||||
virtual_file_io_engine,
|
||||
get_vectored_impl,
|
||||
];
|
||||
|
||||
if let Some(control_plane_api) = &self.env.control_plane_api {
|
||||
@@ -111,7 +126,7 @@ impl PageServerNode {
|
||||
|
||||
// Storage controller uses the same auth as pageserver: if JWT is enabled
|
||||
// for us, we will also need it to talk to them.
|
||||
if matches!(self.conf.http_auth_type, AuthType::NeonJWT) {
|
||||
if matches!(http_auth_type, AuthType::NeonJWT) {
|
||||
let jwt_token = self
|
||||
.env
|
||||
.generate_auth_token(&Claims::new(None, Scope::GenerationsApi))
|
||||
@@ -129,8 +144,7 @@ impl PageServerNode {
|
||||
));
|
||||
}
|
||||
|
||||
if self.conf.http_auth_type != AuthType::Trust || self.conf.pg_auth_type != AuthType::Trust
|
||||
{
|
||||
if *http_auth_type != AuthType::Trust || *pg_auth_type != AuthType::Trust {
|
||||
// Keys are generated in the toplevel repo dir, pageservers' workdirs
|
||||
// are one level below that, so refer to keys with ../
|
||||
overrides.push("auth_validation_public_key_path='../auth_public_key.pem'".to_owned());
|
||||
|
||||
@@ -10,7 +10,7 @@ use pageserver_api::{
|
||||
TenantCreateRequest, TenantShardSplitRequest, TenantShardSplitResponse,
|
||||
TimelineCreateRequest, TimelineInfo,
|
||||
},
|
||||
shard::{ShardStripeSize, TenantShardId},
|
||||
shard::TenantShardId,
|
||||
};
|
||||
use pageserver_client::mgmt_api::ResponseErrorMessageExt;
|
||||
use postgres_backend::AuthType;
|
||||
@@ -496,15 +496,11 @@ impl StorageController {
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
new_shard_count: u8,
|
||||
new_stripe_size: Option<ShardStripeSize>,
|
||||
) -> anyhow::Result<TenantShardSplitResponse> {
|
||||
self.dispatch(
|
||||
Method::PUT,
|
||||
format!("control/v1/tenant/{tenant_id}/shard_split"),
|
||||
Some(TenantShardSplitRequest {
|
||||
new_shard_count,
|
||||
new_stripe_size,
|
||||
}),
|
||||
Some(TenantShardSplitRequest { new_shard_count }),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -198,13 +198,6 @@ pub struct TimelineCreateRequest {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TenantShardSplitRequest {
|
||||
pub new_shard_count: u8,
|
||||
|
||||
// A tenant's stripe size is only meaningful the first time their shard count goes
|
||||
// above 1: therefore during a split from 1->N shards, we may modify the stripe size.
|
||||
//
|
||||
// If this is set while the stripe count is being increased from an already >1 value,
|
||||
// then the request will fail with 400.
|
||||
pub new_stripe_size: Option<ShardStripeSize>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
|
||||
@@ -1151,12 +1151,7 @@ async fn tenant_shard_split_handler(
|
||||
|
||||
let new_shards = state
|
||||
.tenant_manager
|
||||
.shard_split(
|
||||
tenant_shard_id,
|
||||
ShardCount::new(req.new_shard_count),
|
||||
req.new_stripe_size,
|
||||
&ctx,
|
||||
)
|
||||
.shard_split(tenant_shard_id, ShardCount::new(req.new_shard_count), &ctx)
|
||||
.await
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
@@ -2252,7 +2247,7 @@ pub fn make_router(
|
||||
.get("/v1/location_config", |r| {
|
||||
api_handler(r, list_location_config_handler)
|
||||
})
|
||||
.get("/v1/location_config/:tenant_shard_id", |r| {
|
||||
.get("/v1/location_config/:tenant_id", |r| {
|
||||
api_handler(r, get_location_config_handler)
|
||||
})
|
||||
.put(
|
||||
|
||||
@@ -121,7 +121,7 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
|
||||
self.offset
|
||||
}
|
||||
|
||||
const CAPACITY: usize = if BUFFERED { PAGE_SZ } else { 0 };
|
||||
const CAPACITY: usize = if BUFFERED { 64 * 1024 } else { 0 };
|
||||
|
||||
/// Writes the given buffer directly to the underlying `VirtualFile`.
|
||||
/// You need to make sure that the internal buffer is empty, otherwise
|
||||
|
||||
@@ -5,14 +5,12 @@ use crate::config::PageServerConf;
|
||||
use crate::context::RequestContext;
|
||||
use crate::page_cache::{self, PAGE_SZ};
|
||||
use crate::tenant::block_io::{BlockCursor, BlockLease, BlockReader};
|
||||
use crate::virtual_file::{self, VirtualFile};
|
||||
use bytes::BytesMut;
|
||||
use crate::virtual_file::owned_buffers_io::write::OwnedAsyncWriter;
|
||||
use crate::virtual_file::{self, owned_buffers_io, VirtualFile};
|
||||
use camino::Utf8PathBuf;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use std::cmp::min;
|
||||
|
||||
use std::io::{self, ErrorKind};
|
||||
use std::ops::DerefMut;
|
||||
use std::io;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use tracing::*;
|
||||
use utils::id::TimelineId;
|
||||
@@ -22,18 +20,31 @@ pub struct EphemeralFile {
|
||||
|
||||
_tenant_shard_id: TenantShardId,
|
||||
_timeline_id: TimelineId,
|
||||
file: VirtualFile,
|
||||
len: u64,
|
||||
/// An ephemeral file is append-only.
|
||||
/// We keep the last page, which can still be modified, in [`Self::mutable_tail`].
|
||||
/// The other pages, which can no longer be modified, are accessed through the page cache.
|
||||
/// We sandwich the buffered writer between two size-tracking writers.
|
||||
/// This allows us to "elegantly" track in-memory bytes vs flushed bytes,
|
||||
/// enabling [`Self::read_blk`] to determine whether to read from the
|
||||
/// buffered writer's buffer, versus going to the VirtualFile.
|
||||
///
|
||||
/// None <=> IO is ongoing.
|
||||
/// Size is fixed to PAGE_SZ at creation time and must not be changed.
|
||||
mutable_tail: Option<BytesMut>,
|
||||
/// TODO: longer-term, we probably wand to get rid of this in favor
|
||||
/// of a double-buffering scheme. See this commit's commit message
|
||||
/// and git history for what we had before this sandwich, it might be useful.
|
||||
file: owned_buffers_io::util::size_tracking_writer::Writer<
|
||||
owned_buffers_io::write::BufferedWriter<
|
||||
{ Self::TAIL_SZ },
|
||||
owned_buffers_io::util::size_tracking_writer::Writer<
|
||||
owned_buffers_io::util::page_cache_priming_writer::Writer<
|
||||
PAGE_SZ,
|
||||
owned_buffers_io::util::page_cache_priming_writer::VirtualFileAdaptor,
|
||||
&'static crate::page_cache::PageCache,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
}
|
||||
|
||||
impl EphemeralFile {
|
||||
const TAIL_SZ: usize = 64 * 1024;
|
||||
|
||||
pub async fn create(
|
||||
conf: &PageServerConf,
|
||||
tenant_shard_id: TenantShardId,
|
||||
@@ -58,18 +69,27 @@ impl EphemeralFile {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let page_cache_file_id = page_cache::next_file_id();
|
||||
let file = owned_buffers_io::util::page_cache_priming_writer::VirtualFileAdaptor::new(
|
||||
file,
|
||||
page_cache_file_id,
|
||||
);
|
||||
let file =
|
||||
owned_buffers_io::util::page_cache_priming_writer::Writer::new(file, page_cache::get());
|
||||
let file = owned_buffers_io::util::size_tracking_writer::Writer::new(file);
|
||||
let file = owned_buffers_io::write::BufferedWriter::new(file);
|
||||
let file = owned_buffers_io::util::size_tracking_writer::Writer::new(file);
|
||||
|
||||
Ok(EphemeralFile {
|
||||
page_cache_file_id: page_cache::next_file_id(),
|
||||
page_cache_file_id,
|
||||
_tenant_shard_id: tenant_shard_id,
|
||||
_timeline_id: timeline_id,
|
||||
file,
|
||||
len: 0,
|
||||
mutable_tail: Some(BytesMut::zeroed(PAGE_SZ)),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn len(&self) -> u64 {
|
||||
self.len
|
||||
self.file.bytes_written()
|
||||
}
|
||||
|
||||
pub(crate) async fn read_blk(
|
||||
@@ -77,8 +97,22 @@ impl EphemeralFile {
|
||||
blknum: u32,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<BlockLease, io::Error> {
|
||||
let flushed_blknums = 0..self.len / PAGE_SZ as u64;
|
||||
if flushed_blknums.contains(&(blknum as u64)) {
|
||||
let buffered_offset = self.file.bytes_written();
|
||||
let flushed_offset = self.file.as_inner().as_inner().bytes_written();
|
||||
assert!(buffered_offset >= flushed_offset);
|
||||
let read_offset = (blknum as u64) * (PAGE_SZ as u64);
|
||||
|
||||
assert_eq!(
|
||||
flushed_offset % (PAGE_SZ as u64),
|
||||
0,
|
||||
"we need this in the logic below, because it assumes the page isn't spread across flushed part and in-memory buffer"
|
||||
);
|
||||
|
||||
if read_offset < flushed_offset {
|
||||
assert!(
|
||||
read_offset + (PAGE_SZ as u64) <= flushed_offset,
|
||||
"this impl can't deal with pages spread across flushed & buffered part"
|
||||
);
|
||||
let cache = page_cache::get();
|
||||
match cache
|
||||
.read_immutable_buf(self.page_cache_file_id, blknum, ctx)
|
||||
@@ -89,7 +123,15 @@ impl EphemeralFile {
|
||||
// order path before error because error is anyhow::Error => might have many contexts
|
||||
format!(
|
||||
"ephemeral file: read immutable page #{}: {}: {:#}",
|
||||
blknum, self.file.path, e,
|
||||
blknum,
|
||||
self.file
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.path,
|
||||
e,
|
||||
),
|
||||
)
|
||||
})? {
|
||||
@@ -99,6 +141,11 @@ impl EphemeralFile {
|
||||
page_cache::ReadBufResult::NotFound(write_guard) => {
|
||||
let write_guard = self
|
||||
.file
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.read_exact_at_page(write_guard, blknum as u64 * PAGE_SZ as u64)
|
||||
.await?;
|
||||
let read_guard = write_guard.mark_valid();
|
||||
@@ -106,13 +153,31 @@ impl EphemeralFile {
|
||||
}
|
||||
};
|
||||
} else {
|
||||
debug_assert_eq!(blknum as u64, self.len / PAGE_SZ as u64);
|
||||
let read_until_offset = read_offset + (PAGE_SZ as u64);
|
||||
if !(0..buffered_offset).contains(&read_until_offset) {
|
||||
// The blob_io code relies on the reader allowing reads past
|
||||
// the end of what was written, up to end of the current PAGE_SZ chunk.
|
||||
// This is a relict of the past where we would get a pre-zeroed page from the page cache.
|
||||
//
|
||||
// DeltaLayer probably has the same issue, not sure why it needs no special treatment.
|
||||
let nbytes_past_end = read_until_offset.checked_sub(buffered_offset).unwrap();
|
||||
if nbytes_past_end >= (PAGE_SZ as u64) {
|
||||
// TODO: treat this as error. Pre-existing issue before this patch.
|
||||
panic!(
|
||||
"return IO error: read past end of file: read=0x{read_offset:x} buffered=0x{buffered_offset:x} flushed=0x{flushed_offset}"
|
||||
)
|
||||
}
|
||||
}
|
||||
let buffer: &[u8; Self::TAIL_SZ] = self.file.as_inner().inspect_buffer();
|
||||
let read_offset_in_buffer = read_offset
|
||||
.checked_sub(flushed_offset)
|
||||
.expect("would have taken `if` branch instead of this one");
|
||||
|
||||
let read_offset_in_buffer = usize::try_from(read_offset_in_buffer).unwrap();
|
||||
let page = &buffer[read_offset_in_buffer..(read_offset_in_buffer + PAGE_SZ)];
|
||||
Ok(BlockLease::EphemeralFileMutableTail(
|
||||
self.mutable_tail
|
||||
.as_deref()
|
||||
.expect("we're not doing IO, it must be Some()")
|
||||
.try_into()
|
||||
.expect("we ensure that it's always PAGE_SZ"),
|
||||
page.try_into()
|
||||
.expect("the slice above got it as page-size slice"),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -122,137 +187,22 @@ impl EphemeralFile {
|
||||
srcbuf: &[u8],
|
||||
ctx: &RequestContext,
|
||||
) -> Result<u64, io::Error> {
|
||||
struct Writer<'a> {
|
||||
ephemeral_file: &'a mut EphemeralFile,
|
||||
/// The block to which the next [`push_bytes`] will write.
|
||||
blknum: u32,
|
||||
/// The offset inside the block identified by [`blknum`] to which [`push_bytes`] will write.
|
||||
off: usize,
|
||||
}
|
||||
impl<'a> Writer<'a> {
|
||||
fn new(ephemeral_file: &'a mut EphemeralFile) -> io::Result<Writer<'a>> {
|
||||
Ok(Writer {
|
||||
blknum: (ephemeral_file.len / PAGE_SZ as u64) as u32,
|
||||
off: (ephemeral_file.len % PAGE_SZ as u64) as usize,
|
||||
ephemeral_file,
|
||||
})
|
||||
}
|
||||
#[inline(always)]
|
||||
async fn push_bytes(
|
||||
&mut self,
|
||||
src: &[u8],
|
||||
ctx: &RequestContext,
|
||||
) -> Result<(), io::Error> {
|
||||
let mut src_remaining = src;
|
||||
while !src_remaining.is_empty() {
|
||||
let dst_remaining = &mut self
|
||||
.ephemeral_file
|
||||
.mutable_tail
|
||||
.as_deref_mut()
|
||||
.expect("IO is not yet ongoing")[self.off..];
|
||||
let n = min(dst_remaining.len(), src_remaining.len());
|
||||
dst_remaining[..n].copy_from_slice(&src_remaining[..n]);
|
||||
self.off += n;
|
||||
src_remaining = &src_remaining[n..];
|
||||
if self.off == PAGE_SZ {
|
||||
let mutable_tail = std::mem::take(&mut self.ephemeral_file.mutable_tail)
|
||||
.expect("IO is not yet ongoing");
|
||||
let (mutable_tail, res) = self
|
||||
.ephemeral_file
|
||||
.file
|
||||
.write_all_at(mutable_tail, self.blknum as u64 * PAGE_SZ as u64)
|
||||
.await;
|
||||
// TODO: If we panic before we can put the mutable_tail back, subsequent calls will fail.
|
||||
// I.e., the IO isn't retryable if we panic.
|
||||
self.ephemeral_file.mutable_tail = Some(mutable_tail);
|
||||
match res {
|
||||
Ok(_) => {
|
||||
// Pre-warm the page cache with what we just wrote.
|
||||
// This isn't necessary for coherency/correctness, but it's how we've always done it.
|
||||
let cache = page_cache::get();
|
||||
match cache
|
||||
.read_immutable_buf(
|
||||
self.ephemeral_file.page_cache_file_id,
|
||||
self.blknum,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(page_cache::ReadBufResult::Found(_guard)) => {
|
||||
// This function takes &mut self, so, it shouldn't be possible to reach this point.
|
||||
unreachable!("we just wrote blknum {} and this function takes &mut self, so, no concurrent read_blk is possible", self.blknum);
|
||||
}
|
||||
Ok(page_cache::ReadBufResult::NotFound(mut write_guard)) => {
|
||||
let buf: &mut [u8] = write_guard.deref_mut();
|
||||
debug_assert_eq!(buf.len(), PAGE_SZ);
|
||||
buf.copy_from_slice(
|
||||
self.ephemeral_file
|
||||
.mutable_tail
|
||||
.as_deref()
|
||||
.expect("IO is not ongoing"),
|
||||
);
|
||||
let _ = write_guard.mark_valid();
|
||||
// pre-warm successful
|
||||
}
|
||||
Err(e) => {
|
||||
error!("ephemeral_file write_blob failed to get immutable buf to pre-warm page cache: {e:?}");
|
||||
// fail gracefully, it's not the end of the world if we can't pre-warm the cache here
|
||||
}
|
||||
}
|
||||
// Zero the buffer for re-use.
|
||||
// Zeroing is critical for correcntess because the write_blob code below
|
||||
// and similarly read_blk expect zeroed pages.
|
||||
self.ephemeral_file
|
||||
.mutable_tail
|
||||
.as_deref_mut()
|
||||
.expect("IO is not ongoing")
|
||||
.fill(0);
|
||||
// This block is done, move to next one.
|
||||
self.blknum += 1;
|
||||
self.off = 0;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(std::io::Error::new(
|
||||
ErrorKind::Other,
|
||||
// order error before path because path is long and error is short
|
||||
format!(
|
||||
"ephemeral_file: write_blob: write-back full tail blk #{}: {:#}: {}",
|
||||
self.blknum,
|
||||
e,
|
||||
self.ephemeral_file.file.path,
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let pos = self.len;
|
||||
let mut writer = Writer::new(self)?;
|
||||
let pos = self.file.bytes_written();
|
||||
|
||||
// Write the length field
|
||||
if srcbuf.len() < 0x80 {
|
||||
// short one-byte length header
|
||||
let len_buf = [srcbuf.len() as u8];
|
||||
writer.push_bytes(&len_buf, ctx).await?;
|
||||
|
||||
self.file.write_all_borrowed(&len_buf, ctx).await?;
|
||||
} else {
|
||||
let mut len_buf = u32::to_be_bytes(srcbuf.len() as u32);
|
||||
len_buf[0] |= 0x80;
|
||||
writer.push_bytes(&len_buf, ctx).await?;
|
||||
self.file.write_all_borrowed(&len_buf, ctx).await?;
|
||||
}
|
||||
|
||||
// Write the payload
|
||||
writer.push_bytes(srcbuf, ctx).await?;
|
||||
|
||||
if srcbuf.len() < 0x80 {
|
||||
self.len += 1;
|
||||
} else {
|
||||
self.len += 4;
|
||||
}
|
||||
self.len += srcbuf.len() as u64;
|
||||
self.file.write_all_borrowed(srcbuf, ctx).await?;
|
||||
|
||||
Ok(pos)
|
||||
}
|
||||
@@ -273,7 +223,16 @@ impl Drop for EphemeralFile {
|
||||
// We leave them there, [`crate::page_cache::PageCache::find_victim`] will evict them when needed.
|
||||
|
||||
// unlink the file
|
||||
let res = std::fs::remove_file(&self.file.path);
|
||||
let res = std::fs::remove_file(
|
||||
&self
|
||||
.file
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.path,
|
||||
);
|
||||
if let Err(e) = res {
|
||||
if e.kind() != std::io::ErrorKind::NotFound {
|
||||
// just never log the not found errors, we cannot do anything for them; on detach
|
||||
@@ -282,7 +241,14 @@ impl Drop for EphemeralFile {
|
||||
// not found files might also be related to https://github.com/neondatabase/neon/issues/2442
|
||||
error!(
|
||||
"could not remove ephemeral file '{}': {}",
|
||||
self.file.path, e
|
||||
self.file
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.path,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ use futures::stream::StreamExt;
|
||||
use itertools::Itertools;
|
||||
use pageserver_api::key::Key;
|
||||
use pageserver_api::models::ShardParameters;
|
||||
use pageserver_api::shard::{
|
||||
ShardCount, ShardIdentity, ShardNumber, ShardStripeSize, TenantShardId,
|
||||
};
|
||||
use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, TenantShardId};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use std::borrow::Cow;
|
||||
use std::cmp::Ordering;
|
||||
@@ -1441,12 +1439,11 @@ impl TenantManager {
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
new_shard_count: ShardCount,
|
||||
new_stripe_size: Option<ShardStripeSize>,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<Vec<TenantShardId>> {
|
||||
let tenant = get_tenant(tenant_shard_id, true)?;
|
||||
|
||||
// Validate the incoming request
|
||||
// Plan: identify what the new child shards will be
|
||||
if new_shard_count.count() <= tenant_shard_id.shard_count.count() {
|
||||
anyhow::bail!("Requested shard count is not an increase");
|
||||
}
|
||||
@@ -1455,18 +1452,10 @@ impl TenantManager {
|
||||
anyhow::bail!("Requested split is not a power of two");
|
||||
}
|
||||
|
||||
if let Some(new_stripe_size) = new_stripe_size {
|
||||
if tenant.get_shard_stripe_size() != new_stripe_size
|
||||
&& tenant_shard_id.shard_count.count() > 1
|
||||
{
|
||||
// This tenant already has multiple shards, it is illegal to try and change its stripe size
|
||||
anyhow::bail!(
|
||||
"Shard stripe size may not be modified once tenant has multiple shards"
|
||||
);
|
||||
}
|
||||
}
|
||||
let parent_shard_identity = tenant.shard_identity;
|
||||
let parent_tenant_conf = tenant.get_tenant_conf();
|
||||
let parent_generation = tenant.generation;
|
||||
|
||||
// Plan: identify what the new child shards will be
|
||||
let child_shards = tenant_shard_id.split(new_shard_count);
|
||||
tracing::info!(
|
||||
"Shard {} splits into: {}",
|
||||
@@ -1477,10 +1466,6 @@ impl TenantManager {
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let parent_shard_identity = tenant.shard_identity;
|
||||
let parent_tenant_conf = tenant.get_tenant_conf();
|
||||
let parent_generation = tenant.generation;
|
||||
|
||||
// Phase 1: Write out child shards' remote index files, in the parent tenant's current generation
|
||||
if let Err(e) = tenant.split_prepare(&child_shards).await {
|
||||
// If [`Tenant::split_prepare`] fails, we must reload the tenant, because it might
|
||||
@@ -1530,9 +1515,6 @@ impl TenantManager {
|
||||
// Phase 3: Spawn the child shards
|
||||
for child_shard in &child_shards {
|
||||
let mut child_shard_identity = parent_shard_identity;
|
||||
if let Some(new_stripe_size) = new_stripe_size {
|
||||
child_shard_identity.stripe_size = new_stripe_size;
|
||||
}
|
||||
child_shard_identity.count = child_shard.shard_count;
|
||||
child_shard_identity.number = child_shard.shard_number;
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
//! This is similar to PostgreSQL's virtual file descriptor facility in
|
||||
//! src/backend/storage/file/fd.c
|
||||
//!
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{StorageIoOperation, STORAGE_IO_SIZE, STORAGE_IO_TIME_METRIC};
|
||||
|
||||
use crate::page_cache::PageWriteGuard;
|
||||
@@ -34,6 +35,25 @@ pub(crate) use io_engine::IoEngineKind;
|
||||
pub(crate) use metadata::Metadata;
|
||||
pub(crate) use open_options::*;
|
||||
|
||||
use self::owned_buffers_io::write::OwnedAsyncWriter;
|
||||
|
||||
pub(crate) mod owned_buffers_io {
|
||||
//! Abstractions for IO with owned buffers.
|
||||
//!
|
||||
//! Not actually tied to [`crate::virtual_file`] specifically, but, it's the primary
|
||||
//! reason we need this abstraction.
|
||||
//!
|
||||
//! Over time, this could move into the `tokio-epoll-uring` crate, maybe `uring-common`,
|
||||
//! but for the time being we're proving out the primitives in the neon.git repo
|
||||
//! for faster iteration.
|
||||
|
||||
pub(crate) mod write;
|
||||
pub(crate) mod util {
|
||||
pub(crate) mod page_cache_priming_writer;
|
||||
pub(crate) mod size_tracking_writer;
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// A virtual file descriptor. You can use this just like std::fs::File, but internally
|
||||
/// the underlying file is closed if the system is low on file descriptors,
|
||||
@@ -1064,6 +1084,32 @@ impl Drop for VirtualFile {
|
||||
}
|
||||
}
|
||||
|
||||
impl OwnedAsyncWriter for VirtualFile {
|
||||
#[inline(always)]
|
||||
async fn write_all<
|
||||
B: BoundedBuf<Buf = Buf, Bounds = Bounds>,
|
||||
Buf: IoBuf + Send,
|
||||
Bounds: std::ops::RangeBounds<usize>,
|
||||
>(
|
||||
&mut self,
|
||||
buf: B,
|
||||
_: &RequestContext,
|
||||
) -> std::io::Result<(usize, B::Buf)> {
|
||||
let (buf, res) = VirtualFile::write_all(self, buf).await;
|
||||
res.map(move |v| (v, buf))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
async fn write_all_borrowed(
|
||||
&mut self,
|
||||
_buf: &[u8],
|
||||
_ctx: &RequestContext,
|
||||
) -> std::io::Result<usize> {
|
||||
// TODO: ensure this through the type system
|
||||
panic!("this should not happen");
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenFiles {
|
||||
fn new(num_slots: usize) -> OpenFiles {
|
||||
let mut slots = Box::new(Vec::with_capacity(num_slots));
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
use std::ops::Deref;
|
||||
|
||||
use tokio_epoll_uring::BoundedBuf;
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
context::RequestContext,
|
||||
page_cache::{self, PageReadGuard, PAGE_SZ},
|
||||
virtual_file::{owned_buffers_io::write::OwnedAsyncWriter, VirtualFile},
|
||||
};
|
||||
|
||||
pub struct Writer<const N: usize, W: OwnedAsyncWriter, C: Cache<N, W>> {
|
||||
under: W,
|
||||
written: u64,
|
||||
cache: C,
|
||||
}
|
||||
|
||||
pub trait Cache<const PAGE_SZ: usize, W> {
|
||||
async fn fill_cache(
|
||||
&self,
|
||||
under: &W,
|
||||
page_no: u64,
|
||||
contents: &[u8; PAGE_SZ],
|
||||
ctx: &RequestContext,
|
||||
);
|
||||
}
|
||||
|
||||
impl<const N: usize, W, C> Writer<N, W, C>
|
||||
where
|
||||
C: Cache<N, W>,
|
||||
W: OwnedAsyncWriter,
|
||||
{
|
||||
pub fn new(under: W, cache: C) -> Self {
|
||||
Self {
|
||||
under,
|
||||
written: 0,
|
||||
cache,
|
||||
}
|
||||
}
|
||||
pub fn as_inner(&self) -> &W {
|
||||
&self.under
|
||||
}
|
||||
fn invariants(&self) {
|
||||
assert_eq!(
|
||||
self.written % (N as u64),
|
||||
0,
|
||||
"writes must happen in multiples of N"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, W, C> OwnedAsyncWriter for Writer<N, W, C>
|
||||
where
|
||||
W: OwnedAsyncWriter,
|
||||
C: Cache<N, W>,
|
||||
{
|
||||
async fn write_all<
|
||||
B: tokio_epoll_uring::BoundedBuf<Buf = Buf, Bounds = Bounds>,
|
||||
Buf: tokio_epoll_uring::IoBuf + Send,
|
||||
Bounds: std::ops::RangeBounds<usize>,
|
||||
>(
|
||||
&mut self,
|
||||
buf: B,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(usize, B::Buf)> {
|
||||
let buf = buf.slice_full();
|
||||
assert_eq!(buf.bytes_init() % N, 0);
|
||||
self.invariants();
|
||||
let saved_bounds = buf.bounds();
|
||||
let debug_assert_contents_eq = if cfg!(debug_assertions) {
|
||||
Some(buf[..].to_vec())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let res = self.under.write_all(buf, ctx).await;
|
||||
let res = if let Ok((nwritten, buf)) = res {
|
||||
assert_eq!(nwritten % N, 0);
|
||||
let buf = tokio_epoll_uring::Slice::from_buf_bounds(buf, saved_bounds);
|
||||
if let Some(before) = debug_assert_contents_eq {
|
||||
debug_assert_eq!(&before[..], &buf[..]);
|
||||
}
|
||||
assert_eq!(nwritten, buf.bytes_init());
|
||||
for page_no_in_buf in 0..(nwritten / N) {
|
||||
let page: &[u8; N] = (&buf[(page_no_in_buf * N)..((page_no_in_buf + 1) * N)])
|
||||
.try_into()
|
||||
.unwrap();
|
||||
self.cache
|
||||
.fill_cache(
|
||||
&self.under,
|
||||
(self.written / (N as u64)) + (page_no_in_buf as u64),
|
||||
page,
|
||||
ctx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
self.written += nwritten as u64;
|
||||
Ok((nwritten, tokio_epoll_uring::Slice::into_inner(buf)))
|
||||
} else {
|
||||
res
|
||||
};
|
||||
self.invariants();
|
||||
res
|
||||
}
|
||||
|
||||
async fn write_all_borrowed(&mut self, _: &[u8], _: &RequestContext) -> std::io::Result<usize> {
|
||||
// TODO: use type system to ensure this doesn't happen at runtime
|
||||
panic!("don't put the types together this way")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VirtualFileAdaptor {
|
||||
file: VirtualFile,
|
||||
file_id: page_cache::FileId,
|
||||
}
|
||||
|
||||
impl VirtualFileAdaptor {
|
||||
pub fn new(file: VirtualFile, file_id: page_cache::FileId) -> Self {
|
||||
Self { file, file_id }
|
||||
}
|
||||
pub fn as_inner(&self) -> &VirtualFile {
|
||||
&self.file
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> Cache<{ PAGE_SZ }, VirtualFileAdaptor> for &'c crate::page_cache::PageCache {
|
||||
async fn fill_cache(
|
||||
&self,
|
||||
under: &VirtualFileAdaptor,
|
||||
page_no: u64,
|
||||
contents: &[u8; PAGE_SZ],
|
||||
ctx: &RequestContext,
|
||||
) {
|
||||
match self
|
||||
.read_immutable_buf(
|
||||
under.file_id,
|
||||
u32::try_from(page_no).expect("files larger than u32::MAX * 8192 aren't supported"),
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(crate::page_cache::ReadBufResult::Found(guard)) => {
|
||||
debug_assert_eq!(guard.deref(), contents);
|
||||
}
|
||||
Ok(crate::page_cache::ReadBufResult::NotFound(mut guard)) => {
|
||||
guard.copy_from_slice(contents);
|
||||
let guard: PageReadGuard<'_> = guard.mark_valid();
|
||||
drop(guard);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to get immutable buf to pre-warm page cache: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OwnedAsyncWriter for VirtualFileAdaptor {
|
||||
async fn write_all<
|
||||
B: tokio_epoll_uring::BoundedBuf<Buf = Buf, Bounds = Bounds>,
|
||||
Buf: tokio_epoll_uring::IoBuf + Send,
|
||||
Bounds: std::ops::RangeBounds<usize>,
|
||||
>(
|
||||
&mut self,
|
||||
buf: B,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(usize, B::Buf)> {
|
||||
OwnedAsyncWriter::write_all(&mut self.file, buf, ctx).await
|
||||
}
|
||||
|
||||
async fn write_all_borrowed(&mut self, _: &[u8], _: &RequestContext) -> std::io::Result<usize> {
|
||||
// TODO: use type system to ensure this doesn't happen at runtime
|
||||
panic!("don't put the types together this way")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
use crate::{context::RequestContext, virtual_file::owned_buffers_io::write::OwnedAsyncWriter};
|
||||
use tokio_epoll_uring::{BoundedBuf, IoBuf};
|
||||
|
||||
pub struct Writer<W> {
|
||||
dst: W,
|
||||
bytes_amount: u64,
|
||||
}
|
||||
|
||||
impl<W> Writer<W> {
|
||||
pub fn new(dst: W) -> Self {
|
||||
Self {
|
||||
dst,
|
||||
bytes_amount: 0,
|
||||
}
|
||||
}
|
||||
pub fn bytes_written(&self) -> u64 {
|
||||
self.bytes_amount
|
||||
}
|
||||
pub fn as_inner(&self) -> &W {
|
||||
&self.dst
|
||||
}
|
||||
/// Returns the wrapped `VirtualFile` object as well as the number
|
||||
/// of bytes that were written to it through this object.
|
||||
#[allow(dead_code)]
|
||||
pub fn into_inner(self) -> (u64, W) {
|
||||
(self.bytes_amount, self.dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> OwnedAsyncWriter for Writer<W>
|
||||
where
|
||||
W: OwnedAsyncWriter,
|
||||
{
|
||||
#[inline(always)]
|
||||
async fn write_all<
|
||||
B: BoundedBuf<Buf = Buf, Bounds = Bounds>,
|
||||
Buf: IoBuf + Send,
|
||||
Bounds: std::ops::RangeBounds<usize>,
|
||||
>(
|
||||
&mut self,
|
||||
buf: B,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(usize, B::Buf)> {
|
||||
let (nwritten, buf) = self.dst.write_all(buf, ctx).await?;
|
||||
self.bytes_amount += u64::try_from(nwritten).unwrap();
|
||||
Ok((nwritten, buf))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
async fn write_all_borrowed(
|
||||
&mut self,
|
||||
buf: &[u8],
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<usize> {
|
||||
let nwritten = self.dst.write_all_borrowed(buf, ctx).await?;
|
||||
self.bytes_amount += u64::try_from(nwritten).unwrap();
|
||||
Ok(nwritten)
|
||||
}
|
||||
}
|
||||
351
pageserver/src/virtual_file/owned_buffers_io/write.rs
Normal file
351
pageserver/src/virtual_file/owned_buffers_io/write.rs
Normal file
@@ -0,0 +1,351 @@
|
||||
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
|
||||
|
||||
use crate::context::RequestContext;
|
||||
|
||||
/// A trait for doing owned-buffer write IO.
|
||||
/// Think [`tokio::io::AsyncWrite`] but with owned buffers.
|
||||
pub trait OwnedAsyncWriter {
|
||||
async fn write_all<
|
||||
B: BoundedBuf<Buf = Buf, Bounds = Bounds>,
|
||||
Buf: IoBuf + Send,
|
||||
Bounds: std::ops::RangeBounds<usize>,
|
||||
>(
|
||||
&mut self,
|
||||
buf: B,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(usize, B::Buf)>;
|
||||
async fn write_all_borrowed(
|
||||
&mut self,
|
||||
buf: &[u8],
|
||||
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<usize>;
|
||||
}
|
||||
|
||||
/// A wrapper aorund an [`OwnedAsyncWriter`] that batches smaller writers
|
||||
/// into `BUFFER_SIZE`-sized writes.
|
||||
///
|
||||
/// # Passthrough Of Large Writers
|
||||
///
|
||||
/// Buffered writes larger than the `BUFFER_SIZE` cause the internal
|
||||
/// buffer to be flushed, even if it is not full yet. Then, the large
|
||||
/// buffered write is passed through to the unerlying [`OwnedAsyncWriter`].
|
||||
///
|
||||
/// This pass-through is generally beneficial for throughput, but if
|
||||
/// the storage backend of the [`OwnedAsyncWriter`] is a shared resource,
|
||||
/// unlimited large writes may cause latency or fairness issues.
|
||||
///
|
||||
/// In such cases, a different implementation that always buffers in memory
|
||||
/// may be preferable.
|
||||
pub struct BufferedWriter<const BUFFER_SIZE: usize, W> {
|
||||
writer: W,
|
||||
// invariant: always remains Some(buf)
|
||||
// with buf.capacity() == BUFFER_SIZE except
|
||||
// - while IO is ongoing => goes back to Some() once the IO completed successfully
|
||||
// - after an IO error => stays `None` forever
|
||||
// In these exceptional cases, it's `None`.
|
||||
buf: Option<zero_initialized_buffer::Buf<BUFFER_SIZE>>,
|
||||
}
|
||||
|
||||
mod zero_initialized_buffer;
|
||||
|
||||
impl<const BUFFER_SIZE: usize, W> BufferedWriter<BUFFER_SIZE, W>
|
||||
where
|
||||
W: OwnedAsyncWriter,
|
||||
{
|
||||
pub fn new(writer: W) -> Self {
|
||||
Self {
|
||||
writer,
|
||||
buf: Some(zero_initialized_buffer::Buf::default()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_inner(&self) -> &W {
|
||||
&self.writer
|
||||
}
|
||||
|
||||
/// panics if used after an error
|
||||
pub fn inspect_buffer(&self) -> &[u8; BUFFER_SIZE] {
|
||||
self.buf
|
||||
.as_ref()
|
||||
// TODO: can this happen on the EphemeralFile read path?
|
||||
.expect("must not use after an error")
|
||||
.as_zero_padded_slice()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn flush_and_into_inner(mut self, ctx: &RequestContext) -> std::io::Result<W> {
|
||||
self.flush(ctx).await?;
|
||||
let Self { buf, writer } = self;
|
||||
assert!(buf.is_some());
|
||||
Ok(writer)
|
||||
}
|
||||
|
||||
pub async fn write_buffered<B: IoBuf>(
|
||||
&mut self,
|
||||
chunk: Slice<B>,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(usize, B)>
|
||||
where
|
||||
B: IoBuf + Send,
|
||||
{
|
||||
let chunk_len = chunk.len();
|
||||
// avoid memcpy for the middle of the chunk
|
||||
if chunk.len() >= BUFFER_SIZE {
|
||||
self.flush(ctx).await?;
|
||||
// do a big write, bypassing `buf`
|
||||
assert_eq!(
|
||||
self.buf
|
||||
.as_ref()
|
||||
.expect("must not use after an error")
|
||||
.len(),
|
||||
0
|
||||
);
|
||||
let (nwritten, chunk) = self.writer.write_all(chunk, ctx).await?;
|
||||
assert_eq!(nwritten, chunk_len);
|
||||
return Ok((nwritten, chunk));
|
||||
}
|
||||
// in-memory copy the < BUFFER_SIZED tail of the chunk
|
||||
assert!(chunk.len() < BUFFER_SIZE);
|
||||
let mut slice = &chunk[..];
|
||||
while !slice.is_empty() {
|
||||
let buf = self.buf.as_mut().expect("must not use after an error");
|
||||
let need = BUFFER_SIZE - buf.len();
|
||||
let have = slice.len();
|
||||
let n = std::cmp::min(need, have);
|
||||
buf.extend_from_slice(&slice[..n]);
|
||||
slice = &slice[n..];
|
||||
if buf.len() >= BUFFER_SIZE {
|
||||
assert_eq!(buf.len(), BUFFER_SIZE);
|
||||
self.flush(ctx).await?;
|
||||
}
|
||||
}
|
||||
assert!(slice.is_empty(), "by now we should have drained the chunk");
|
||||
Ok((chunk_len, chunk.into_inner()))
|
||||
}
|
||||
|
||||
/// Always goes through the internal buffer.
|
||||
/// Guaranteed to never invoke [`OwnedAsyncWrite::write_all_borrowed`] on the underlying.
|
||||
pub async fn write_all_borrowed(
|
||||
&mut self,
|
||||
mut chunk: &[u8],
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<usize> {
|
||||
let chunk_len = chunk.len();
|
||||
while !chunk.is_empty() {
|
||||
let buf = self.buf.as_mut().expect("must not use after an error");
|
||||
let need = BUFFER_SIZE - buf.len();
|
||||
let have = chunk.len();
|
||||
let n = std::cmp::min(need, have);
|
||||
buf.extend_from_slice(&chunk[..n]);
|
||||
chunk = &chunk[n..];
|
||||
if buf.len() >= BUFFER_SIZE {
|
||||
assert_eq!(buf.len(), BUFFER_SIZE);
|
||||
self.flush(ctx).await?;
|
||||
}
|
||||
}
|
||||
Ok(chunk_len)
|
||||
}
|
||||
|
||||
async fn flush(&mut self, ctx: &RequestContext) -> std::io::Result<()> {
|
||||
let buf = self.buf.take().expect("must not use after an error");
|
||||
if buf.is_empty() {
|
||||
self.buf = Some(buf);
|
||||
return std::io::Result::Ok(());
|
||||
}
|
||||
let buf_len = buf.len();
|
||||
let (nwritten, mut buf) = self.writer.write_all(buf, ctx).await?;
|
||||
assert_eq!(nwritten, buf_len);
|
||||
buf.clear();
|
||||
self.buf = Some(buf);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const BUFFER_SIZE: usize, W: OwnedAsyncWriter> OwnedAsyncWriter
|
||||
for BufferedWriter<BUFFER_SIZE, W>
|
||||
{
|
||||
#[inline(always)]
|
||||
async fn write_all<
|
||||
B: BoundedBuf<Buf = Buf, Bounds = Bounds>,
|
||||
Buf: IoBuf + Send,
|
||||
Bounds: std::ops::RangeBounds<usize>,
|
||||
>(
|
||||
&mut self,
|
||||
buf: B,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(usize, B::Buf)> {
|
||||
let nbytes = buf.bytes_init();
|
||||
if nbytes == 0 {
|
||||
return Ok((0, Slice::into_inner(buf.slice_full())));
|
||||
}
|
||||
let slice = buf.slice(0..nbytes);
|
||||
BufferedWriter::write_buffered(self, slice, ctx).await
|
||||
}
|
||||
#[inline(always)]
|
||||
async fn write_all_borrowed(
|
||||
&mut self,
|
||||
buf: &[u8],
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<usize> {
|
||||
BufferedWriter::write_all_borrowed(self, buf, ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
impl OwnedAsyncWriter for Vec<u8> {
|
||||
async fn write_all<
|
||||
B: BoundedBuf<Buf = Buf, Bounds = Bounds>,
|
||||
Buf: IoBuf + Send,
|
||||
Bounds: std::ops::RangeBounds<usize>,
|
||||
>(
|
||||
&mut self,
|
||||
buf: B,
|
||||
_: &RequestContext,
|
||||
) -> std::io::Result<(usize, B::Buf)> {
|
||||
let nbytes = buf.bytes_init();
|
||||
if nbytes == 0 {
|
||||
return Ok((0, Slice::into_inner(buf.slice_full())));
|
||||
}
|
||||
let buf = buf.slice(0..nbytes);
|
||||
self.extend_from_slice(&buf[..]);
|
||||
Ok((buf.len(), Slice::into_inner(buf)))
|
||||
}
|
||||
|
||||
async fn write_all_borrowed(
|
||||
&mut self,
|
||||
buf: &[u8],
|
||||
_ctx: &RequestContext,
|
||||
) -> std::io::Result<usize> {
|
||||
self.extend_from_slice(buf);
|
||||
Ok(buf.len())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::context::{DownloadBehavior, RequestContext};
|
||||
use crate::task_mgr::TaskKind;
|
||||
|
||||
#[derive(Default)]
|
||||
struct RecorderWriter {
|
||||
writes: Vec<Vec<u8>>,
|
||||
}
|
||||
impl OwnedAsyncWriter for RecorderWriter {
|
||||
async fn write_all<
|
||||
B: BoundedBuf<Buf = Buf, Bounds = Bounds>,
|
||||
Buf: IoBuf + Send,
|
||||
Bounds: std::ops::RangeBounds<usize>,
|
||||
>(
|
||||
&mut self,
|
||||
buf: B,
|
||||
_: &RequestContext,
|
||||
) -> std::io::Result<(usize, B::Buf)> {
|
||||
let nbytes = buf.bytes_init();
|
||||
if nbytes == 0 {
|
||||
self.writes.push(vec![]);
|
||||
return Ok((0, Slice::into_inner(buf.slice_full())));
|
||||
}
|
||||
let buf = buf.slice(0..nbytes);
|
||||
self.writes.push(Vec::from(&buf[..]));
|
||||
Ok((buf.len(), Slice::into_inner(buf)))
|
||||
}
|
||||
|
||||
async fn write_all_borrowed(
|
||||
&mut self,
|
||||
buf: &[u8],
|
||||
_: &RequestContext,
|
||||
) -> std::io::Result<usize> {
|
||||
self.writes.push(Vec::from(buf));
|
||||
Ok(buf.len())
|
||||
}
|
||||
}
|
||||
|
||||
fn test_ctx() -> RequestContext {
|
||||
RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
|
||||
}
|
||||
|
||||
macro_rules! write {
|
||||
($writer:ident, $data:literal) => {{
|
||||
$writer
|
||||
.write_buffered(::bytes::Bytes::from_static($data).slice_full(), &test_ctx())
|
||||
.await?;
|
||||
}};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_buffered_writes_only() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::<2, _>::new(recorder);
|
||||
write!(writer, b"a");
|
||||
write!(writer, b"b");
|
||||
write!(writer, b"c");
|
||||
write!(writer, b"d");
|
||||
write!(writer, b"e");
|
||||
let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
|
||||
assert_eq!(
|
||||
recorder.writes,
|
||||
vec![Vec::from(b"ab"), Vec::from(b"cd"), Vec::from(b"e")]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_passthrough_writes_only() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::<2, _>::new(recorder);
|
||||
write!(writer, b"abc");
|
||||
write!(writer, b"de");
|
||||
write!(writer, b"");
|
||||
write!(writer, b"fghijk");
|
||||
let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
|
||||
assert_eq!(
|
||||
recorder.writes,
|
||||
vec![Vec::from(b"abc"), Vec::from(b"de"), Vec::from(b"fghijk")]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::<2, _>::new(recorder);
|
||||
write!(writer, b"a");
|
||||
write!(writer, b"bc");
|
||||
write!(writer, b"d");
|
||||
write!(writer, b"e");
|
||||
let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
|
||||
assert_eq!(
|
||||
recorder.writes,
|
||||
vec![Vec::from(b"a"), Vec::from(b"bc"), Vec::from(b"de")]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::<2, _>::new(recorder);
|
||||
let ctx = test_ctx();
|
||||
writer.write_all_borrowed(b"abc", &ctx).await?;
|
||||
writer.write_all_borrowed(b"d", &ctx).await?;
|
||||
writer.write_all_borrowed(b"e", &ctx).await?;
|
||||
writer.write_all_borrowed(b"fg", &ctx).await?;
|
||||
writer.write_all_borrowed(b"hi", &ctx).await?;
|
||||
writer.write_all_borrowed(b"j", &ctx).await?;
|
||||
writer.write_all_borrowed(b"klmno", &ctx).await?;
|
||||
|
||||
let recorder = writer.flush_and_into_inner(&ctx).await?;
|
||||
assert_eq!(
|
||||
recorder.writes,
|
||||
{
|
||||
let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
|
||||
expect
|
||||
}
|
||||
.iter()
|
||||
.map(|v| v[..].to_vec())
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
pub struct Buf<const N: usize> {
|
||||
allocation: Box<[u8; N]>,
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl<const N: usize> Default for Buf<N> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allocation: Box::new(
|
||||
// SAFETY: zeroed memory is a valid [u8; N]
|
||||
unsafe { MaybeUninit::zeroed().assume_init() },
|
||||
),
|
||||
written: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Buf<N> {
|
||||
#[inline(always)]
|
||||
fn invariants(&self) {
|
||||
debug_assert!(self.written <= N, "{}", self.written);
|
||||
}
|
||||
|
||||
pub fn as_zero_padded_slice(&self) -> &[u8; N] {
|
||||
&self.allocation
|
||||
}
|
||||
|
||||
/// panics if there's not enough capacity left
|
||||
pub fn extend_from_slice(&mut self, buf: &[u8]) {
|
||||
self.invariants();
|
||||
let can = N - self.written;
|
||||
let want = buf.len();
|
||||
assert!(want <= can, "{:x} {:x}", want, can);
|
||||
self.allocation[self.written..(self.written + want)].copy_from_slice(buf);
|
||||
self.written += want;
|
||||
self.invariants();
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.written
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.invariants();
|
||||
self.written = 0;
|
||||
self.allocation[..].fill(0);
|
||||
self.invariants();
|
||||
}
|
||||
}
|
||||
|
||||
/// SAFETY: the Box<> has a stable location in memory.
|
||||
unsafe impl<const N: usize> tokio_epoll_uring::IoBuf for Buf<N> {
|
||||
fn stable_ptr(&self) -> *const u8 {
|
||||
self.allocation.as_ptr()
|
||||
}
|
||||
|
||||
fn bytes_init(&self) -> usize {
|
||||
self.written
|
||||
}
|
||||
|
||||
fn bytes_total(&self) -> usize {
|
||||
self.written // ?
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,6 @@ use crate::{
|
||||
CachedNodeInfo,
|
||||
},
|
||||
context::RequestMonitoring,
|
||||
error::{ErrorKind, ReportableError, UserFacingError},
|
||||
proxy::connect_compute::ConnectMechanism,
|
||||
};
|
||||
|
||||
@@ -118,30 +117,6 @@ pub enum HttpConnError {
|
||||
WakeCompute(#[from] WakeComputeError),
|
||||
}
|
||||
|
||||
impl ReportableError for HttpConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
||||
HttpConnError::ConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
||||
HttpConnError::AuthError(a) => a.get_error_kind(),
|
||||
HttpConnError::WakeCompute(w) => w.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for HttpConnError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
||||
HttpConnError::ConnectionError(p) => p.to_string(),
|
||||
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
||||
HttpConnError::AuthError(c) => c.to_string_client(),
|
||||
HttpConnError::WakeCompute(c) => c.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
conn_info: ConnInfo,
|
||||
|
||||
@@ -119,12 +119,16 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
}
|
||||
}
|
||||
|
||||
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
|
||||
fn put(
|
||||
pool: &RwLock<Self>,
|
||||
conn_info: &ConnInfo,
|
||||
client: ClientInner<C>,
|
||||
) -> anyhow::Result<()> {
|
||||
let conn_id = client.conn_id;
|
||||
|
||||
if client.is_closed() {
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
let global_max_conn = pool.read().global_pool_size_max_conns;
|
||||
if pool
|
||||
@@ -134,7 +138,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
>= global_max_conn
|
||||
{
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// return connection to the pool
|
||||
@@ -168,6 +172,8 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
} else {
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -647,7 +653,7 @@ impl<C: ClientInnerExt> Client<C> {
|
||||
// return connection to the pool
|
||||
return Some(move || {
|
||||
let _span = current_span.enter();
|
||||
EndpointConnPool::put(&conn_pool, &conn_info, client);
|
||||
let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
|
||||
});
|
||||
}
|
||||
None
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use futures::future::select;
|
||||
use futures::future::try_join;
|
||||
use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use futures::TryFutureExt;
|
||||
use hyper::body::HttpBody;
|
||||
use hyper::header;
|
||||
use hyper::http::HeaderName;
|
||||
@@ -37,13 +37,9 @@ use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ErrorKind;
|
||||
use crate::error::ReportableError;
|
||||
use crate::error::UserFacingError;
|
||||
use crate::metrics::HTTP_CONTENT_LENGTH;
|
||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
@@ -51,7 +47,6 @@ use super::backend::PoolingBackend;
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::json::json_to_pg_text;
|
||||
use super::json::pg_text_row_to_json;
|
||||
use super::json::JsonConversionError;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -122,18 +117,6 @@ pub enum ConnInfoError {
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
ctx: &mut RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
@@ -229,41 +212,17 @@ pub async fn handle(
|
||||
handle.abort();
|
||||
|
||||
let mut response = match result {
|
||||
Ok(r) => {
|
||||
Ok(Ok(r)) => {
|
||||
ctx.set_success();
|
||||
r
|
||||
}
|
||||
Err(e @ SqlOverHttpError::Cancelled(_)) => {
|
||||
let error_kind = e.get_error_kind();
|
||||
ctx.set_error_kind(error_kind);
|
||||
|
||||
let message = format!(
|
||||
"Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections",
|
||||
config.http_config.request_timeout.as_secs_f64()
|
||||
);
|
||||
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%e,
|
||||
msg=message,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
json_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
|
||||
)?
|
||||
}
|
||||
Err(e) => {
|
||||
let error_kind = e.get_error_kind();
|
||||
ctx.set_error_kind(error_kind);
|
||||
// TODO: ctx.set_error_kind(e.get_error_type());
|
||||
|
||||
let mut message = e.to_string_client();
|
||||
let db_error = match &e {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
let mut message = format!("{:?}", e);
|
||||
let db_error = e
|
||||
.downcast_ref::<tokio_postgres::Error>()
|
||||
.and_then(|e| e.as_db_error());
|
||||
fn get<'a, T: serde::Serialize>(
|
||||
db: Option<&'a DbError>,
|
||||
x: impl FnOnce(&'a DbError) -> T,
|
||||
@@ -306,13 +265,10 @@ pub async fn handle(
|
||||
let line = get(db_error, |db| db.line().map(|l| l.to_string()));
|
||||
let routine = get(db_error, |db| db.routine());
|
||||
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%e,
|
||||
msg=message,
|
||||
"forwarding error to user"
|
||||
error!(
|
||||
?code,
|
||||
"sql-over-http per-client task finished with an error: {e:#}"
|
||||
);
|
||||
|
||||
// TODO: this shouldn't always be bad request.
|
||||
json_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
@@ -337,6 +293,21 @@ pub async fn handle(
|
||||
}),
|
||||
)?
|
||||
}
|
||||
Ok(Err(Cancelled())) => {
|
||||
// TODO: when http error classification is done, distinguish between
|
||||
// timeout on sql vs timeout in proxy/cplane
|
||||
// ctx.set_error_kind(crate::error::ErrorKind::RateLimit);
|
||||
|
||||
let message = format!(
|
||||
"Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections",
|
||||
config.http_config.request_timeout.as_secs_f64()
|
||||
);
|
||||
error!(message);
|
||||
json_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
|
||||
)?
|
||||
}
|
||||
};
|
||||
|
||||
response.headers_mut().insert(
|
||||
@@ -346,93 +317,7 @@ pub async fn handle(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SqlOverHttpError {
|
||||
#[error("{0}")]
|
||||
ReadPayload(#[from] ReadPayloadError),
|
||||
#[error("{0}")]
|
||||
ConnectCompute(#[from] HttpConnError),
|
||||
#[error("{0}")]
|
||||
ConnInfo(#[from] ConnInfoError),
|
||||
#[error("request is too large (max is {MAX_REQUEST_SIZE} bytes)")]
|
||||
RequestTooLarge,
|
||||
#[error("response is too large (max is {MAX_RESPONSE_SIZE} bytes)")]
|
||||
ResponseTooLarge,
|
||||
#[error("invalid isolation level")]
|
||||
InvalidIsolationLevel,
|
||||
#[error("{0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
#[error("{0}")]
|
||||
JsonConversion(#[from] JsonConversionError),
|
||||
#[error("{0}")]
|
||||
Cancelled(SqlOverHttpCancel),
|
||||
}
|
||||
|
||||
impl ReportableError for SqlOverHttpError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
|
||||
SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
|
||||
SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
|
||||
SqlOverHttpError::RequestTooLarge => ErrorKind::User,
|
||||
SqlOverHttpError::ResponseTooLarge => ErrorKind::User,
|
||||
SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
|
||||
SqlOverHttpError::Postgres(p) => p.get_error_kind(),
|
||||
SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
|
||||
SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for SqlOverHttpError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
SqlOverHttpError::ReadPayload(p) => p.to_string(),
|
||||
SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
|
||||
SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
|
||||
SqlOverHttpError::RequestTooLarge => self.to_string(),
|
||||
SqlOverHttpError::ResponseTooLarge => self.to_string(),
|
||||
SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
|
||||
SqlOverHttpError::Postgres(p) => p.to_string(),
|
||||
SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
|
||||
SqlOverHttpError::Cancelled(_) => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ReadPayloadError {
|
||||
#[error("could not read the HTTP request body: {0}")]
|
||||
Read(#[from] hyper::Error),
|
||||
#[error("could not parse the HTTP request body: {0}")]
|
||||
Parse(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl ReportableError for ReadPayloadError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
|
||||
ReadPayloadError::Parse(_) => ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SqlOverHttpCancel {
|
||||
#[error("query was cancelled")]
|
||||
Postgres,
|
||||
#[error("query was cancelled while stuck trying to connect to the database")]
|
||||
Connect,
|
||||
}
|
||||
|
||||
impl ReportableError for SqlOverHttpCancel {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
SqlOverHttpCancel::Postgres => ErrorKind::RateLimit,
|
||||
SqlOverHttpCancel::Connect => ErrorKind::ServiceRateLimit,
|
||||
}
|
||||
}
|
||||
}
|
||||
struct Cancelled();
|
||||
|
||||
async fn handle_inner(
|
||||
cancel: CancellationToken,
|
||||
@@ -440,7 +325,7 @@ async fn handle_inner(
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Body>, SqlOverHttpError> {
|
||||
) -> Result<Result<Response<Body>, Cancelled>, anyhow::Error> {
|
||||
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
|
||||
.with_label_values(&[ctx.protocol])
|
||||
.guard();
|
||||
@@ -473,7 +358,7 @@ async fn handle_inner(
|
||||
b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
|
||||
b"ReadCommitted" => IsolationLevel::ReadCommitted,
|
||||
b"RepeatableRead" => IsolationLevel::RepeatableRead,
|
||||
_ => return Err(SqlOverHttpError::InvalidIsolationLevel),
|
||||
_ => bail!("invalid isolation level"),
|
||||
}),
|
||||
None => None,
|
||||
};
|
||||
@@ -491,16 +376,19 @@ async fn handle_inner(
|
||||
// we don't have a streaming request support yet so this is to prevent OOM
|
||||
// from a malicious user sending an extremely large request body
|
||||
if request_content_length > MAX_REQUEST_SIZE {
|
||||
return Err(SqlOverHttpError::RequestTooLarge);
|
||||
return Err(anyhow::anyhow!(
|
||||
"request is too large (max is {MAX_REQUEST_SIZE} bytes)"
|
||||
));
|
||||
}
|
||||
|
||||
let fetch_and_process_request = async {
|
||||
let body = hyper::body::to_bytes(request.into_body()).await?;
|
||||
let body = hyper::body::to_bytes(request.into_body())
|
||||
.await
|
||||
.map_err(anyhow::Error::from)?;
|
||||
info!(length = body.len(), "request payload read");
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
|
||||
}
|
||||
.map_err(SqlOverHttpError::from);
|
||||
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
|
||||
};
|
||||
|
||||
let authenticate_and_connect = async {
|
||||
let keys = backend.authenticate(ctx, &conn_info).await?;
|
||||
@@ -510,9 +398,8 @@ async fn handle_inner(
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.latency_timer.success();
|
||||
Ok::<_, HttpConnError>(client)
|
||||
}
|
||||
.map_err(SqlOverHttpError::from);
|
||||
Ok::<_, anyhow::Error>(client)
|
||||
};
|
||||
|
||||
// Run both operations in parallel
|
||||
let (payload, mut client) = match select(
|
||||
@@ -525,9 +412,7 @@ async fn handle_inner(
|
||||
.await
|
||||
{
|
||||
Either::Left((result, _cancelled)) => result?,
|
||||
Either::Right((_cancelled, _)) => {
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect))
|
||||
}
|
||||
Either::Right((_cancelled, _)) => return Ok(Err(Cancelled())),
|
||||
};
|
||||
|
||||
let mut response = Response::builder()
|
||||
@@ -571,24 +456,20 @@ async fn handle_inner(
|
||||
results
|
||||
}
|
||||
Ok(Err(error)) => {
|
||||
let db_error = match &error {
|
||||
SqlOverHttpError::ConnectCompute(
|
||||
HttpConnError::ConnectionError(e),
|
||||
)
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
let db_error = error
|
||||
.downcast_ref::<tokio_postgres::Error>()
|
||||
.and_then(|e| e.as_db_error());
|
||||
|
||||
// if errored for some other reason, it might not be safe to return
|
||||
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
|
||||
discard.discard();
|
||||
}
|
||||
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
return Ok(Err(Cancelled()));
|
||||
}
|
||||
Err(_timeout) => {
|
||||
discard.discard();
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
return Ok(Err(Cancelled()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -626,7 +507,7 @@ async fn handle_inner(
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(results) => {
|
||||
Ok(Ok(results)) => {
|
||||
info!("commit");
|
||||
let status = transaction.commit().await.map_err(|e| {
|
||||
// if we cannot commit - for now don't return connection to pool
|
||||
@@ -637,14 +518,14 @@ async fn handle_inner(
|
||||
discard.check_idle(status);
|
||||
results
|
||||
}
|
||||
Err(SqlOverHttpError::Cancelled(_)) => {
|
||||
Ok(Err(Cancelled())) => {
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::error!(?err, "could not cancel query");
|
||||
}
|
||||
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
|
||||
discard.discard();
|
||||
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
return Ok(Err(Cancelled()));
|
||||
}
|
||||
Err(err) => {
|
||||
info!("rollback");
|
||||
@@ -660,10 +541,16 @@ async fn handle_inner(
|
||||
};
|
||||
|
||||
if txn_read_only {
|
||||
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
|
||||
response = response.header(
|
||||
TXN_READ_ONLY.clone(),
|
||||
HeaderValue::try_from(txn_read_only.to_string())?,
|
||||
);
|
||||
}
|
||||
if txn_deferrable {
|
||||
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
|
||||
response = response.header(
|
||||
TXN_DEFERRABLE.clone(),
|
||||
HeaderValue::try_from(txn_deferrable.to_string())?,
|
||||
);
|
||||
}
|
||||
if let Some(txn_isolation_level) = txn_isolation_level_raw {
|
||||
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
|
||||
@@ -687,7 +574,7 @@ async fn handle_inner(
|
||||
// moving this later in the stack is going to be a lot of effort and ehhhh
|
||||
metrics.record_egress(len as u64);
|
||||
|
||||
Ok(response)
|
||||
Ok(Ok(response))
|
||||
}
|
||||
|
||||
async fn query_batch(
|
||||
@@ -697,7 +584,7 @@ async fn query_batch(
|
||||
total_size: &mut usize,
|
||||
raw_output: bool,
|
||||
array_mode: bool,
|
||||
) -> Result<Vec<Value>, SqlOverHttpError> {
|
||||
) -> anyhow::Result<Result<Vec<Value>, Cancelled>> {
|
||||
let mut results = Vec::with_capacity(queries.queries.len());
|
||||
let mut current_size = 0;
|
||||
for stmt in queries.queries {
|
||||
@@ -719,12 +606,12 @@ async fn query_batch(
|
||||
return Err(e);
|
||||
}
|
||||
Either::Right((_cancelled, _)) => {
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
return Ok(Err(Cancelled()));
|
||||
}
|
||||
}
|
||||
}
|
||||
*total_size += current_size;
|
||||
Ok(results)
|
||||
Ok(Ok(results))
|
||||
}
|
||||
|
||||
async fn query_to_json<T: GenericClient>(
|
||||
@@ -733,7 +620,7 @@ async fn query_to_json<T: GenericClient>(
|
||||
current_size: &mut usize,
|
||||
raw_output: bool,
|
||||
default_array_mode: bool,
|
||||
) -> Result<(ReadyForQueryStatus, Value), SqlOverHttpError> {
|
||||
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
|
||||
info!("executing query");
|
||||
let query_params = data.params;
|
||||
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
|
||||
@@ -750,7 +637,9 @@ async fn query_to_json<T: GenericClient>(
|
||||
// we don't have a streaming response support yet so this is to prevent OOM
|
||||
// from a malicious query (eg a cross join)
|
||||
if *current_size > MAX_RESPONSE_SIZE {
|
||||
return Err(SqlOverHttpError::ResponseTooLarge);
|
||||
return Err(anyhow::anyhow!(
|
||||
"response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2142,13 +2142,11 @@ class NeonStorageController(MetricsGetter):
|
||||
shards: list[dict[str, Any]] = body["shards"]
|
||||
return shards
|
||||
|
||||
def tenant_shard_split(
|
||||
self, tenant_id: TenantId, shard_count: int, shard_stripe_size: Optional[int] = None
|
||||
) -> list[TenantShardId]:
|
||||
def tenant_shard_split(self, tenant_id: TenantId, shard_count: int) -> list[TenantShardId]:
|
||||
response = self.request(
|
||||
"PUT",
|
||||
f"{self.env.storage_controller_api}/control/v1/tenant/{tenant_id}/shard_split",
|
||||
json={"new_shard_count": shard_count, "new_stripe_size": shard_stripe_size},
|
||||
json={"new_shard_count": shard_count},
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
body = response.json()
|
||||
|
||||
@@ -318,13 +318,6 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
|
||||
assert isinstance(res_json["tenant_shards"], list)
|
||||
return res_json
|
||||
|
||||
def tenant_get_location(self, tenant_id: TenantShardId):
|
||||
res = self.get(
|
||||
f"http://localhost:{self.port}/v1/location_config/{tenant_id}",
|
||||
)
|
||||
self.verbose_error(res)
|
||||
return res.json()
|
||||
|
||||
def tenant_delete(self, tenant_id: Union[TenantId, TenantShardId]):
|
||||
res = self.delete(f"http://localhost:{self.port}/v1/tenant/{tenant_id}")
|
||||
self.verbose_error(res)
|
||||
|
||||
@@ -28,7 +28,7 @@ def platform() -> Optional[str]:
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def pageserver_virtual_file_io_engine() -> Optional[str]:
|
||||
return None
|
||||
return os.getenv("PAGESERVER_VIRTUAL_FILE_IO_ENGINE")
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc: Metafunc):
|
||||
@@ -48,11 +48,8 @@ def pytest_generate_tests(metafunc: Metafunc):
|
||||
|
||||
# A hacky way to parametrize tests only for `pageserver_virtual_file_io_engine=std-fs`
|
||||
# And do not change test name for default `pageserver_virtual_file_io_engine=tokio-epoll-uring` to keep tests statistics
|
||||
if (io_engine := os.getenv("PAGESERVER_VIRTUAL_FILE_IO_ENGINE", "")) not in (
|
||||
"",
|
||||
"tokio-epoll-uring",
|
||||
):
|
||||
metafunc.parametrize("pageserver_virtual_file_io_engine", [io_engine])
|
||||
io_engine = os.getenv("PAGESERVER_VIRTUAL_FILE_IO_ENGINE", "std-fs")
|
||||
metafunc.parametrize("pageserver_virtual_file_io_engine", [io_engine])
|
||||
|
||||
# For performance tests, parametrize also by platform
|
||||
if (
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pytest
|
||||
from fixtures.log_helper import log
|
||||
@@ -9,11 +8,7 @@ from fixtures.neon_fixtures import (
|
||||
)
|
||||
from fixtures.remote_storage import s3_storage
|
||||
from fixtures.types import Lsn, TenantShardId, TimelineId
|
||||
from fixtures.utils import wait_until
|
||||
from fixtures.workload import Workload
|
||||
from pytest_httpserver import HTTPServer
|
||||
from werkzeug.wrappers.request import Request
|
||||
from werkzeug.wrappers.response import Response
|
||||
|
||||
|
||||
def test_sharding_smoke(
|
||||
@@ -315,96 +310,6 @@ def test_sharding_split_smoke(
|
||||
workload.validate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("initial_stripe_size", [None, 65536])
|
||||
def test_sharding_split_stripe_size(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
httpserver: HTTPServer,
|
||||
httpserver_listen_address,
|
||||
initial_stripe_size: int,
|
||||
):
|
||||
"""
|
||||
Check that modifying stripe size inline with a shard split works as expected
|
||||
"""
|
||||
(host, port) = httpserver_listen_address
|
||||
neon_env_builder.control_plane_compute_hook_api = f"http://{host}:{port}/notify"
|
||||
neon_env_builder.num_pageservers = 1
|
||||
|
||||
# Set up fake HTTP notify endpoint: we will use this to validate that we receive
|
||||
# the correct stripe size after split.
|
||||
notifications = []
|
||||
|
||||
def handler(request: Request):
|
||||
log.info(f"Notify request: {request}")
|
||||
notifications.append(request.json)
|
||||
return Response(status=200)
|
||||
|
||||
httpserver.expect_request("/notify", method="PUT").respond_with_handler(handler)
|
||||
|
||||
env = neon_env_builder.init_start(
|
||||
initial_tenant_shard_count=1, initial_tenant_shard_stripe_size=initial_stripe_size
|
||||
)
|
||||
tenant_id = env.initial_tenant
|
||||
|
||||
assert len(notifications) == 1
|
||||
expect: Dict[str, Union[List[Dict[str, int]], str, None, int]] = {
|
||||
"tenant_id": str(env.initial_tenant),
|
||||
"stripe_size": None,
|
||||
"shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}],
|
||||
}
|
||||
assert notifications[0] == expect
|
||||
|
||||
new_stripe_size = 2048
|
||||
env.storage_controller.tenant_shard_split(
|
||||
tenant_id, shard_count=2, shard_stripe_size=new_stripe_size
|
||||
)
|
||||
|
||||
# Check that we ended up with the stripe size that we expected, both on the pageserver
|
||||
# and in the notifications to compute
|
||||
assert len(notifications) == 2
|
||||
expect_after: Dict[str, Union[List[Dict[str, int]], str, None, int]] = {
|
||||
"tenant_id": str(env.initial_tenant),
|
||||
"stripe_size": new_stripe_size,
|
||||
"shards": [
|
||||
{"node_id": int(env.pageservers[0].id), "shard_number": 0},
|
||||
{"node_id": int(env.pageservers[0].id), "shard_number": 1},
|
||||
],
|
||||
}
|
||||
log.info(f"Got notification: {notifications[1]}")
|
||||
assert notifications[1] == expect_after
|
||||
|
||||
# Inspect the stripe size on the pageserver
|
||||
shard_0_loc = (
|
||||
env.pageservers[0].http_client().tenant_get_location(TenantShardId(tenant_id, 0, 2))
|
||||
)
|
||||
assert shard_0_loc["shard_stripe_size"] == new_stripe_size
|
||||
shard_1_loc = (
|
||||
env.pageservers[0].http_client().tenant_get_location(TenantShardId(tenant_id, 1, 2))
|
||||
)
|
||||
assert shard_1_loc["shard_stripe_size"] == new_stripe_size
|
||||
|
||||
# Ensure stripe size survives a pageserver restart
|
||||
env.pageservers[0].stop()
|
||||
env.pageservers[0].start()
|
||||
shard_0_loc = (
|
||||
env.pageservers[0].http_client().tenant_get_location(TenantShardId(tenant_id, 0, 2))
|
||||
)
|
||||
assert shard_0_loc["shard_stripe_size"] == new_stripe_size
|
||||
shard_1_loc = (
|
||||
env.pageservers[0].http_client().tenant_get_location(TenantShardId(tenant_id, 1, 2))
|
||||
)
|
||||
assert shard_1_loc["shard_stripe_size"] == new_stripe_size
|
||||
|
||||
# Ensure stripe size survives a storage controller restart
|
||||
env.storage_controller.stop()
|
||||
env.storage_controller.start()
|
||||
|
||||
def assert_restart_notification():
|
||||
assert len(notifications) == 3
|
||||
assert notifications[2] == expect_after
|
||||
|
||||
wait_until(10, 1, assert_restart_notification)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
# The quantity of data isn't huge, but debug can be _very_ slow, and the things we're
|
||||
# validating in this test don't benefit much from debug assertions.
|
||||
|
||||
Reference in New Issue
Block a user