Compare commits

...

1 Commits

Author SHA1 Message Date
Conrad Ludgate
a93639af6b feat(proxy): allow proxy to connect to separate compute compared to mock cplane 2024-12-17 19:04:57 +00:00
3 changed files with 85 additions and 72 deletions

View File

@@ -74,44 +74,39 @@ impl MockControlPlane {
tokio::spawn(connection);
let secret = if let Some(entry) = get_execute_postgres_query(
&client,
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
&[&&*user_info.user],
"rolpassword",
)
.await?
{
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{}' does not exist", user_info.user);
None
let secret = {
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
if let Some(row) = client.query_opt(query, &[&&*user_info.user]).await? {
let entry: String = row.get("rolpassword");
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{}' does not exist", user_info.user);
None
}
};
let allowed_ips = if self.ip_allowlist_check_enabled {
match get_execute_postgres_query(
&client,
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
&[&user_info.endpoint.as_str()],
"allowed_ips",
)
.await?
{
Some(s) => {
info!("got allowed_ips: {s}");
s.split(',')
.map(|s| {
IpPattern::from_str(s).expect("mocked ip pattern should be correct")
})
.collect()
}
None => vec![],
let mut allowed_ips = vec![];
if self.ip_allowlist_check_enabled {
let query =
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1";
let row = client
.query_one(query, &[&user_info.endpoint.as_str()])
.await?;
let s: Option<String> = row.get("allowed_ips");
if let Some(s) = s {
info!("got allowed_ips: {s}");
allowed_ips = s
.split(',')
.map(|s| {
IpPattern::from_str(s).expect("mocked ip pattern should be correct")
})
.collect();
}
} else {
vec![]
};
}
Ok((secret, allowed_ips))
}
@@ -134,11 +129,8 @@ impl MockControlPlane {
let connection = tokio::spawn(connection);
let res = client.query(
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
&[&endpoint.as_str()],
)
.await?;
let query = "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1";
let res = client.query(query, &[&endpoint.as_str()]).await?;
let mut rows = vec![];
for row in res {
@@ -161,11 +153,23 @@ impl MockControlPlane {
Ok(rows)
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new(
self.endpoint.host_str().unwrap_or("localhost").to_owned(),
self.endpoint.port().unwrap_or(5432),
);
async fn do_wake_compute(
&self,
user_info: &ComputeUserInfo,
) -> Result<NodeInfo, WakeComputeError> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
tokio::spawn(connection);
let query = "select host, port from neon_control_plane.endpoints where endpoint_id = $1";
let row = client
.query_one(query, &[&user_info.endpoint.as_str()])
.await?;
let host: String = row.get("host");
let port: i32 = row.get("port");
let mut config = compute::ConnCfg::new(host, port as u16);
config.ssl_mode(postgres_client::config::SslMode::Disable);
let node = NodeInfo {
@@ -182,26 +186,6 @@ impl MockControlPlane {
}
}
async fn get_execute_postgres_query(
client: &Client,
query: &str,
params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
idx: &str,
) -> Result<Option<String>, GetAuthInfoError> {
let rows = client.query(query, params).await?;
// We can get at most one row, because `rolname` is unique.
let Some(row) = rows.first() else {
// This means that the user doesn't exist, so there can be no secret.
// However, this is still a *valid* outcome which is very similar
// to getting `404 Not found` from the Neon console.
return Ok(None);
};
let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
Ok(Some(entry))
}
impl super::ControlPlaneApi for MockControlPlane {
#[tracing::instrument(skip_all)]
async fn get_role_secret(
@@ -239,9 +223,11 @@ impl super::ControlPlaneApi for MockControlPlane {
async fn wake_compute(
&self,
_ctx: &RequestContext,
_user_info: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute().map_ok(Cached::new_uncached).await
self.do_wake_compute(user_info)
.map_ok(Cached::new_uncached)
.await
}
}

View File

@@ -3274,7 +3274,7 @@ class NeonProxy(PgProtocol):
metric_collection_interval: str | None = None,
):
host = "127.0.0.1"
domain = "proxy.localtest.me" # resolves to 127.0.0.1
domain = "ep-test-endpoint.localtest.me" # resolves to 127.0.0.1
super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port)
self.domain = domain
@@ -3639,7 +3639,19 @@ def static_proxy(
vanilla_pg.safe_psql("create user proxy with login superuser password 'password'")
vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS neon_control_plane")
vanilla_pg.safe_psql(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
f"""
CREATE TABLE neon_control_plane.endpoints (
endpoint_id text PRIMARY KEY,
allowed_ips text,
host text not null default '{host}',
port integer not null default {port}
)
"""
)
vanilla_pg.safe_psql(
"""
insert into neon_control_plane.endpoints (endpoint_id) VALUES ('ep-test-endpoint');
"""
)
proxy_port = port_distributor.get_port()

View File

@@ -43,11 +43,16 @@ async def test_http_pool_begin_1(static_proxy: NeonProxy):
print(results)
def test_proxy_select_1(static_proxy: NeonProxy):
def test_proxy_select_1(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
"""
A simplest smoke test: check proxy against a local postgres instance.
"""
# Establish the default compute for this endpoint
vanilla_pg.safe_psql(
"INSERT INTO neon_control_plane.endpoints (endpoint_id) VALUES ('generic-project-name')"
)
# no SNI, deprecated `options=project` syntax (before we had several endpoint in project)
out = static_proxy.safe_psql("select 1", sslsni=0, options="project=generic-project-name")
assert out[0][0] == 1
@@ -61,12 +66,17 @@ def test_proxy_select_1(static_proxy: NeonProxy):
assert out[0][0] == 42
def test_password_hack(static_proxy: NeonProxy):
def test_password_hack(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
"""
Check the PasswordHack auth flow: an alternative to SCRAM auth for
clients which can't provide the project/endpoint name via SNI or `options`.
"""
# Establish the default compute for this endpoint
vanilla_pg.safe_psql(
"INSERT INTO neon_control_plane.endpoints (endpoint_id) VALUES ('irrelevant')"
)
user = "borat"
password = "password"
static_proxy.safe_psql(f"create role {user} with login password '{password}'")
@@ -107,7 +117,7 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
@pytest.mark.parametrize("option_name", ["project", "endpoint"])
def test_proxy_options(static_proxy: NeonProxy, option_name: str):
def test_proxy_options(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres, option_name: str):
"""
Check that we pass extra `options` to the PostgreSQL server:
* `project=...` and `endpoint=...` shouldn't be passed at all
@@ -115,6 +125,11 @@ def test_proxy_options(static_proxy: NeonProxy, option_name: str):
* everything else should be passed as-is.
"""
# Establish the default compute for this endpoint
vanilla_pg.safe_psql(
"INSERT INTO neon_control_plane.endpoints (endpoint_id) VALUES ('irrelevant')"
)
options = f"{option_name}=irrelevant -cproxytest.option=value"
out = static_proxy.safe_psql("show proxytest.option", options=options, sslsni=0)
assert out[0][0] == "value"