From a93639af6b897caff0d2df7574900cb54da93a30 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 17 Dec 2024 18:46:38 +0000 Subject: [PATCH] feat(proxy): allow proxy to connect to separate compute compared to mock cplane --- proxy/src/control_plane/client/mock.rs | 120 +++++++++++-------------- test_runner/fixtures/neon_fixtures.py | 16 +++- test_runner/regress/test_proxy.py | 21 ++++- 3 files changed, 85 insertions(+), 72 deletions(-) diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 5f8bda0f35..0e5bf8fb54 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -74,44 +74,39 @@ impl MockControlPlane { tokio::spawn(connection); - let secret = if let Some(entry) = get_execute_postgres_query( - &client, - "select rolpassword from pg_catalog.pg_authid where rolname = $1", - &[&&*user_info.user], - "rolpassword", - ) - .await? - { - info!("got a secret: {entry}"); // safe since it's not a prod scenario - let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); - secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) - } else { - warn!("user '{}' does not exist", user_info.user); - None + let secret = { + let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; + if let Some(row) = client.query_opt(query, &[&&*user_info.user]).await? { + let entry: String = row.get("rolpassword"); + + info!("got a secret: {entry}"); // safe since it's not a prod scenario + let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); + secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) + } else { + warn!("user '{}' does not exist", user_info.user); + None + } }; - let allowed_ips = if self.ip_allowlist_check_enabled { - match get_execute_postgres_query( - &client, - "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", - &[&user_info.endpoint.as_str()], - "allowed_ips", - ) - .await? - { - Some(s) => { - info!("got allowed_ips: {s}"); - s.split(',') - .map(|s| { - IpPattern::from_str(s).expect("mocked ip pattern should be correct") - }) - .collect() - } - None => vec![], + let mut allowed_ips = vec![]; + if self.ip_allowlist_check_enabled { + let query = + "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1"; + let row = client + .query_one(query, &[&user_info.endpoint.as_str()]) + .await?; + + let s: Option = row.get("allowed_ips"); + if let Some(s) = s { + info!("got allowed_ips: {s}"); + allowed_ips = s + .split(',') + .map(|s| { + IpPattern::from_str(s).expect("mocked ip pattern should be correct") + }) + .collect(); } - } else { - vec![] - }; + } Ok((secret, allowed_ips)) } @@ -134,11 +129,8 @@ impl MockControlPlane { let connection = tokio::spawn(connection); - let res = client.query( - "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1", - &[&endpoint.as_str()], - ) - .await?; + let query = "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1"; + let res = client.query(query, &[&endpoint.as_str()]).await?; let mut rows = vec![]; for row in res { @@ -161,11 +153,23 @@ impl MockControlPlane { Ok(rows) } - async fn do_wake_compute(&self) -> Result { - let mut config = compute::ConnCfg::new( - self.endpoint.host_str().unwrap_or("localhost").to_owned(), - self.endpoint.port().unwrap_or(5432), - ); + async fn do_wake_compute( + &self, + user_info: &ComputeUserInfo, + ) -> Result { + let (client, connection) = + tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; + tokio::spawn(connection); + + let query = "select host, port from neon_control_plane.endpoints where endpoint_id = $1"; + let row = client + .query_one(query, &[&user_info.endpoint.as_str()]) + .await?; + + let host: String = row.get("host"); + let port: i32 = row.get("port"); + + let mut config = compute::ConnCfg::new(host, port as u16); config.ssl_mode(postgres_client::config::SslMode::Disable); let node = NodeInfo { @@ -182,26 +186,6 @@ impl MockControlPlane { } } -async fn get_execute_postgres_query( - client: &Client, - query: &str, - params: &[&(dyn tokio_postgres::types::ToSql + Sync)], - idx: &str, -) -> Result, GetAuthInfoError> { - let rows = client.query(query, params).await?; - - // We can get at most one row, because `rolname` is unique. - let Some(row) = rows.first() else { - // This means that the user doesn't exist, so there can be no secret. - // However, this is still a *valid* outcome which is very similar - // to getting `404 Not found` from the Neon console. - return Ok(None); - }; - - let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?; - Ok(Some(entry)) -} - impl super::ControlPlaneApi for MockControlPlane { #[tracing::instrument(skip_all)] async fn get_role_secret( @@ -239,9 +223,11 @@ impl super::ControlPlaneApi for MockControlPlane { async fn wake_compute( &self, _ctx: &RequestContext, - _user_info: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result { - self.do_wake_compute().map_ok(Cached::new_uncached).await + self.do_wake_compute(user_info) + .map_ok(Cached::new_uncached) + .await } } diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 13ada1361e..43820cd845 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3274,7 +3274,7 @@ class NeonProxy(PgProtocol): metric_collection_interval: str | None = None, ): host = "127.0.0.1" - domain = "proxy.localtest.me" # resolves to 127.0.0.1 + domain = "ep-test-endpoint.localtest.me" # resolves to 127.0.0.1 super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port) self.domain = domain @@ -3639,7 +3639,19 @@ def static_proxy( vanilla_pg.safe_psql("create user proxy with login superuser password 'password'") vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS neon_control_plane") vanilla_pg.safe_psql( - "CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))" + f""" + CREATE TABLE neon_control_plane.endpoints ( + endpoint_id text PRIMARY KEY, + allowed_ips text, + host text not null default '{host}', + port integer not null default {port} + ) + """ + ) + vanilla_pg.safe_psql( + """ + insert into neon_control_plane.endpoints (endpoint_id) VALUES ('ep-test-endpoint'); + """ ) proxy_port = port_distributor.get_port() diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index d8df2efc78..0f92cdc1ca 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -43,11 +43,16 @@ async def test_http_pool_begin_1(static_proxy: NeonProxy): print(results) -def test_proxy_select_1(static_proxy: NeonProxy): +def test_proxy_select_1(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres): """ A simplest smoke test: check proxy against a local postgres instance. """ + # Establish the default compute for this endpoint + vanilla_pg.safe_psql( + "INSERT INTO neon_control_plane.endpoints (endpoint_id) VALUES ('generic-project-name')" + ) + # no SNI, deprecated `options=project` syntax (before we had several endpoint in project) out = static_proxy.safe_psql("select 1", sslsni=0, options="project=generic-project-name") assert out[0][0] == 1 @@ -61,12 +66,17 @@ def test_proxy_select_1(static_proxy: NeonProxy): assert out[0][0] == 42 -def test_password_hack(static_proxy: NeonProxy): +def test_password_hack(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres): """ Check the PasswordHack auth flow: an alternative to SCRAM auth for clients which can't provide the project/endpoint name via SNI or `options`. """ + # Establish the default compute for this endpoint + vanilla_pg.safe_psql( + "INSERT INTO neon_control_plane.endpoints (endpoint_id) VALUES ('irrelevant')" + ) + user = "borat" password = "password" static_proxy.safe_psql(f"create role {user} with login password '{password}'") @@ -107,7 +117,7 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): @pytest.mark.parametrize("option_name", ["project", "endpoint"]) -def test_proxy_options(static_proxy: NeonProxy, option_name: str): +def test_proxy_options(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres, option_name: str): """ Check that we pass extra `options` to the PostgreSQL server: * `project=...` and `endpoint=...` shouldn't be passed at all @@ -115,6 +125,11 @@ def test_proxy_options(static_proxy: NeonProxy, option_name: str): * everything else should be passed as-is. """ + # Establish the default compute for this endpoint + vanilla_pg.safe_psql( + "INSERT INTO neon_control_plane.endpoints (endpoint_id) VALUES ('irrelevant')" + ) + options = f"{option_name}=irrelevant -cproxytest.option=value" out = static_proxy.safe_psql("show proxytest.option", options=options, sslsni=0) assert out[0][0] == "value"