test_runner: initial gRPC protocol support

This commit is contained in:
Erik Grinaker
2025-06-06 16:56:33 +02:00
parent 396a16a3b2
commit e74a957045
5 changed files with 95 additions and 43 deletions

View File

@@ -18,7 +18,7 @@ use clap::Parser;
use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::ComputeMode;
use control_plane::broker::StorageBroker;
use control_plane::endpoint::ComputeControlPlane;
use control_plane::endpoint::{ComputeControlPlane, PageserverProtocol};
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
use control_plane::local_env;
use control_plane::local_env::{
@@ -664,6 +664,10 @@ struct EndpointStartCmdArgs {
#[clap(short = 't', long, value_parser= humantime::parse_duration, help = "timeout until we fail the command")]
#[arg(default_value = "90s")]
start_timeout: Duration,
/// If enabled, use gRPC (and the communicator) to talk to Pageservers.
#[clap(long)]
grpc: bool,
}
#[derive(clap::Args)]
@@ -682,6 +686,10 @@ struct EndpointReconfigureCmdArgs {
#[clap(long)]
safekeepers: Option<String>,
/// If enabled, use gRPC (and communicator) to talk to Pageservers.
#[clap(long)]
grpc: bool,
}
#[derive(clap::Args)]
@@ -1452,14 +1460,22 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
let (pageservers, stripe_size) = if let Some(pageserver_id) = pageserver_id {
let conf = env.get_pageserver_conf(pageserver_id).unwrap();
let parsed = parse_host_port(&conf.listen_pg_addr).expect("Bad config");
(
vec![(parsed.0, parsed.1.unwrap_or(5432))],
// If caller is telling us what pageserver to use, this is not a tenant which is
// full managed by storage controller, therefore not sharded.
DEFAULT_STRIPE_SIZE,
)
// Use gRPC if requested.
let (protocol, host, port) = if args.grpc {
let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
let (host, port) = parse_host_port(grpc_addr).expect("bad config");
(PageserverProtocol::Grpc, host, port.unwrap_or(51051))
} else {
let (host, port) = parse_host_port(&conf.listen_pg_addr).expect("bad config");
(PageserverProtocol::Libpq, host, port.unwrap_or(5432))
};
// If caller is telling us what pageserver to use, this is not a tenant which is
// fully managed by storage controller, therefore not sharded.
(vec![(protocol, host, port)], DEFAULT_STRIPE_SIZE)
} else {
// TODO: plumb Pageserver gRPC ports through storage-controller.
assert!(!args.grpc, "gRPC not supported with storage-controller yet");
// Look up the currently attached location of the tenant, and its striping metadata,
// to pass these on to postgres.
let storage_controller = StorageController::from_env(env);
@@ -1478,6 +1494,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
}
anyhow::Ok((
PageserverProtocol::Libpq,
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
shard.listen_pg_port,
@@ -1536,12 +1553,20 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.get(endpoint_id.as_str())
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id {
let pageserver = PageServerNode::from_env(env, env.get_pageserver_conf(ps_id)?);
vec![(
pageserver.pg_connection_config.host().clone(),
pageserver.pg_connection_config.port(),
)]
let conf = env.get_pageserver_conf(ps_id)?;
// Use gRPC if requested.
let (protocol, host, port) = if args.grpc {
let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
let (host, port) = parse_host_port(grpc_addr).expect("bad config");
(PageserverProtocol::Grpc, host, port.unwrap_or(51051))
} else {
let (host, port) = parse_host_port(&conf.listen_pg_addr).expect("bad config");
(PageserverProtocol::Libpq, host, port.unwrap_or(5432))
};
vec![(protocol, host, port)]
} else {
// TODO: plumb gRPC ports through storage-controller.
assert!(!args.grpc, "gRPC not supported with storage-controller yet");
let storage_controller = StorageController::from_env(env);
storage_controller
.tenant_locate(endpoint.tenant_id)
@@ -1550,6 +1575,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.into_iter()
.map(|shard| {
(
PageserverProtocol::Libpq,
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported malformed host"),
shard.listen_pg_port,

View File

@@ -37,6 +37,7 @@
//! ```
//!
use std::collections::BTreeMap;
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
@@ -74,7 +75,6 @@ use utils::id::{NodeId, TenantId, TimelineId};
use crate::local_env::LocalEnv;
use crate::postgresql_conf::PostgresConf;
use crate::storage_controller::StorageController;
// contents of a endpoint.json file
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
@@ -331,7 +331,7 @@ pub enum EndpointStatus {
RunningNoPidfile,
}
impl std::fmt::Display for EndpointStatus {
impl Display for EndpointStatus {
fn fmt(&self, writer: &mut std::fmt::Formatter) -> std::fmt::Result {
let s = match self {
Self::Running => "running",
@@ -343,6 +343,28 @@ impl std::fmt::Display for EndpointStatus {
}
}
#[derive(Clone, Copy, Debug)]
pub enum PageserverProtocol {
Libpq,
Grpc,
}
impl PageserverProtocol {
/// Returns the URL scheme for the protocol, used in connstrings.
pub fn scheme(&self) -> &'static str {
match self {
Self::Libpq => "postgresql",
Self::Grpc => "grpc",
}
}
}
impl Display for PageserverProtocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.scheme())
}
}
impl Endpoint {
fn from_dir_entry(entry: std::fs::DirEntry, env: &LocalEnv) -> Result<Endpoint> {
if !entry.file_type()?.is_dir() {
@@ -606,10 +628,10 @@ impl Endpoint {
}
}
fn build_pageserver_connstr(pageservers: &[(Host, u16)]) -> String {
fn build_pageserver_connstr(pageservers: &[(PageserverProtocol, Host, u16)]) -> String {
pageservers
.iter()
.map(|(host, port)| format!("postgresql://no_user@{host}:{port}"))
.map(|(scheme, host, port)| format!("{scheme}://no_user@{host}:{port}"))
.collect::<Vec<_>>()
.join(",")
}
@@ -654,7 +676,7 @@ impl Endpoint {
endpoint_storage_addr: String,
safekeepers_generation: Option<SafekeeperGeneration>,
safekeepers: Vec<NodeId>,
pageservers: Vec<(Host, u16)>,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
remote_ext_base_url: Option<&String>,
shard_stripe_size: usize,
create_test_user: bool,
@@ -939,10 +961,12 @@ impl Endpoint {
pub async fn reconfigure(
&self,
mut pageservers: Vec<(Host, u16)>,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
stripe_size: Option<ShardStripeSize>,
safekeepers: Option<Vec<NodeId>>,
) -> Result<()> {
anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided");
let (mut spec, compute_ctl_config) = {
let config_path = self.endpoint_path().join("config.json");
let file = std::fs::File::open(config_path)?;
@@ -954,25 +978,7 @@ impl Endpoint {
let postgresql_conf = self.read_postgresql_conf()?;
spec.cluster.postgresql_conf = Some(postgresql_conf);
// If we weren't given explicit pageservers, query the storage controller
if pageservers.is_empty() {
let storage_controller = StorageController::from_env(&self.env);
let locate_result = storage_controller.tenant_locate(self.tenant_id).await?;
pageservers = locate_result
.shards
.into_iter()
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
shard.listen_pg_port,
)
})
.collect::<Vec<_>>();
}
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstr.is_empty());
spec.pageserver_connstring = Some(pageserver_connstr);
if stripe_size.is_some() {
spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use control_plane::endpoint::{ComputeControlPlane, EndpointStatus};
use control_plane::endpoint::{ComputeControlPlane, EndpointStatus, PageserverProtocol};
use control_plane::local_env::LocalEnv;
use futures::StreamExt;
use hyper::StatusCode;
@@ -428,7 +428,8 @@ impl ComputeHook {
.expect("Unknown pageserver");
let (pg_host, pg_port) = parse_host_port(&ps_conf.listen_pg_addr)
.expect("Unable to parse listen_pg_addr");
(pg_host, pg_port.unwrap_or(5432))
// TODO: plumb gRPC through storage-controller.
(PageserverProtocol::Libpq, pg_host, pg_port.unwrap_or(5432))
})
.collect::<Vec<_>>();

View File

@@ -564,6 +564,7 @@ class NeonLocalCli(AbstractNeonCli):
basebackup_request_tries: int | None = None,
timeout: str | None = None,
env: dict[str, str] | None = None,
grpc: bool = False,
) -> subprocess.CompletedProcess[str]:
args = [
"endpoint",
@@ -583,6 +584,8 @@ class NeonLocalCli(AbstractNeonCli):
args.append(endpoint_id)
if pageserver_id is not None:
args.extend(["--pageserver-id", str(pageserver_id)])
if grpc:
args.extend(["--grpc"])
if allow_multiple:
args.extend(["--allow-multiple"])
if create_test_user:
@@ -599,6 +602,7 @@ class NeonLocalCli(AbstractNeonCli):
endpoint_id: str,
tenant_id: TenantId | None = None,
pageserver_id: int | None = None,
grpc: bool = False,
safekeepers: list[int] | None = None,
check_return_code=True,
) -> subprocess.CompletedProcess[str]:
@@ -607,6 +611,8 @@ class NeonLocalCli(AbstractNeonCli):
args.extend(["--tenant-id", str(tenant_id)])
if pageserver_id is not None:
args.extend(["--pageserver-id", str(pageserver_id)])
if grpc:
args.extend(["--grpc"])
if safekeepers is not None:
args.extend(["--safekeepers", (",".join(map(str, safekeepers)))])
return self.raw_cli(args, check_return_code=check_return_code)

View File

@@ -4176,6 +4176,7 @@ class Endpoint(PgProtocol, LogUtils):
pageserver_id: int | None = None,
allow_multiple: bool = False,
update_catalog: bool = False,
grpc: bool = False,
) -> Self:
"""
Create a new Postgres endpoint.
@@ -4209,9 +4210,12 @@ class Endpoint(PgProtocol, LogUtils):
# set small 'max_replication_write_lag' to enable backpressure
# and make tests more stable.
config_lines = ["max_replication_write_lag=15MB"] + config_lines
config_lines += ["max_replication_write_lag=15MB"]
config_lines = ["neon.enable_new_communicator=true"] + config_lines
# If gRPC is enabled, use the new communicator too.
#
# NB: the communicator is enabled by default, so force it to false otherwise.
config_lines += [f"neon.enable_new_communicator={str(grpc).lower()}"]
# Delete file cache if it exists (and we're recreating the endpoint)
if USE_LFC:
@@ -4264,6 +4268,7 @@ class Endpoint(PgProtocol, LogUtils):
basebackup_request_tries: int | None = None,
timeout: str | None = None,
env: dict[str, str] | None = None,
grpc: bool = False,
) -> Self:
"""
Start the Postgres instance.
@@ -4288,6 +4293,7 @@ class Endpoint(PgProtocol, LogUtils):
basebackup_request_tries=basebackup_request_tries,
timeout=timeout,
env=env,
grpc=grpc,
)
self._running.release(1)
self.log_config_value("shared_buffers")
@@ -4358,14 +4364,14 @@ class Endpoint(PgProtocol, LogUtils):
def is_running(self):
return self._running._value > 0
def reconfigure(self, pageserver_id: int | None = None, safekeepers: list[int] | None = None):
def reconfigure(self, pageserver_id: int | None = None, grpc: bool = False, safekeepers: list[int] | None = None):
assert self.endpoint_id is not None
# If `safekeepers` is not None, they are remember them as active and use
# in the following commands.
if safekeepers is not None:
self.active_safekeepers = safekeepers
self.env.neon_cli.endpoint_reconfigure(
self.endpoint_id, self.tenant_id, pageserver_id, self.active_safekeepers
self.endpoint_id, self.tenant_id, pageserver_id, grpc, self.active_safekeepers
)
def respec(self, **kwargs: Any) -> None:
@@ -4500,6 +4506,7 @@ class Endpoint(PgProtocol, LogUtils):
pageserver_id: int | None = None,
allow_multiple: bool = False,
basebackup_request_tries: int | None = None,
grpc: bool = False,
) -> Self:
"""
Create an endpoint, apply config, and start Postgres.
@@ -4514,11 +4521,13 @@ class Endpoint(PgProtocol, LogUtils):
lsn=lsn,
pageserver_id=pageserver_id,
allow_multiple=allow_multiple,
grpc=grpc,
).start(
remote_ext_base_url=remote_ext_base_url,
pageserver_id=pageserver_id,
allow_multiple=allow_multiple,
basebackup_request_tries=basebackup_request_tries,
grpc=grpc,
)
return self
@@ -4602,6 +4611,7 @@ class EndpointFactory:
remote_ext_base_url: str | None = None,
pageserver_id: int | None = None,
basebackup_request_tries: int | None = None,
grpc: bool = False,
) -> Endpoint:
ep = Endpoint(
self.env,
@@ -4622,6 +4632,7 @@ class EndpointFactory:
remote_ext_base_url=remote_ext_base_url,
pageserver_id=pageserver_id,
basebackup_request_tries=basebackup_request_tries,
grpc=grpc,
)
def create(
@@ -4634,6 +4645,7 @@ class EndpointFactory:
config_lines: list[str] | None = None,
pageserver_id: int | None = None,
update_catalog: bool = False,
grpc: bool = False,
) -> Endpoint:
ep = Endpoint(
self.env,
@@ -4656,6 +4668,7 @@ class EndpointFactory:
config_lines=config_lines,
pageserver_id=pageserver_id,
update_catalog=update_catalog,
grpc=grpc,
)
def stop_all(self, fail_on_error=True) -> Self: