diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c
index cbb0e2ae6d..a5e0c402fb 100644
--- a/pgxn/neon/pagestore_smgr.c
+++ b/pgxn/neon/pagestore_smgr.c
@@ -439,6 +439,8 @@ readahead_buffer_resize(int newsize, void *extra)
 	newPState->ring_unused = newsize;
 	newPState->ring_receive = newsize;
 	newPState->ring_flush = newsize;
+	newPState->max_shard_no = MyPState->max_shard_no;
+	memcpy(newPState->shard_bitmap, MyPState->shard_bitmap, sizeof(MyPState->shard_bitmap));
 
 	/*
 	 * Copy over the prefetches.
@@ -495,7 +497,11 @@ readahead_buffer_resize(int newsize, void *extra)
 	for (; end >= MyPState->ring_last && end != UINT64_MAX; end -= 1)
 	{
-		prefetch_set_unused(end);
+		PrefetchRequest *slot = GetPrfSlot(end);
+		if (slot->status == PRFS_RECEIVED)
+		{
+			pfree(slot->response);
+		}
 	}
 
 	prfh_destroy(MyPState->prf_hash);
@@ -944,6 +950,9 @@ Retry:
 	Assert(entry == NULL);
 	Assert(slot == NULL);
 
+	/* There should be no buffer overflow */
+	Assert(MyPState->ring_last + readahead_buffer_size >= MyPState->ring_unused);
+
 	/*
 	 * If the prefetch queue is full, we need to make room by clearing the
 	 * oldest slot. If the oldest slot holds a buffer that was already
@@ -958,7 +967,7 @@ Retry:
 	 * a prefetch request kind of goes against the principles of
 	 * prefetching)
 	 */
-	if (MyPState->ring_last + readahead_buffer_size - 1 == MyPState->ring_unused)
+	if (MyPState->ring_last + readahead_buffer_size == MyPState->ring_unused)
 	{
 		uint64		cleanup_index = MyPState->ring_last;
 
diff --git a/test_runner/regress/test_prefetch_buffer_resize.py b/test_runner/regress/test_prefetch_buffer_resize.py
new file mode 100644
index 0000000000..7676b78b0e
--- /dev/null
+++ b/test_runner/regress/test_prefetch_buffer_resize.py
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import random
+
+import pytest
+from fixtures.neon_fixtures import NeonEnvBuilder
+
+
+@pytest.mark.parametrize("shard_count", [None, 4])
+@pytest.mark.timeout(600)
+def test_prefetch(neon_env_builder: NeonEnvBuilder, shard_count: int | None):
+    if shard_count is not None:
+        neon_env_builder.num_pageservers = shard_count
+    env = neon_env_builder.init_start(
+        initial_tenant_shard_count=shard_count,
+    )
+    n_iter = 10
+    n_rec = 100000
+
+    endpoint = env.endpoints.create_start(
+        "main",
+        config_lines=[
+            "shared_buffers=10MB",
+        ],
+    )
+
+    cur = endpoint.connect().cursor()
+
+    cur.execute("CREATE TABLE t(pk integer, filler text default repeat('?', 200))")
+    cur.execute(f"insert into t (pk) values (generate_series(1,{n_rec}))")
+
+    cur.execute("set statement_timeout=0")
+    cur.execute("set effective_io_concurrency=20")
+    cur.execute("set max_parallel_workers_per_gather=0")
+
+    for _ in range(n_iter):
+        buf_size = random.randrange(16, 32)
+        cur.execute(f"set neon.readahead_buffer_size={buf_size}")
+        limit = random.randrange(1, n_rec)
+        cur.execute(f"select sum(pk) from (select pk from t limit {limit}) s")