diff --git a/test_runner/batch_others/test_auth.py b/test_runner/batch_others/test_auth.py index 062e54c3be..6f3dc167fe 100644 --- a/test_runner/batch_others/test_auth.py +++ b/test_runner/batch_others/test_auth.py @@ -22,6 +22,25 @@ class AuthKeys: pub: bytes priv: bytes + def generate_management_token(self): + token = jwt.encode({"scope": "pageserverapi"}, self.priv, algorithm="RS256") + + # jwt.encode can return 'bytes' or 'str', depending on Python version or type + # hinting or something (not sure what). If it returned 'bytes', convert it to 'str' + # explicitly. + if isinstance(token, bytes): + token = token.decode() + + return token + + def generate_tenant_token(self, tenant_id): + token = jwt.encode({"scope": "tenant", "tenant_id": tenant_id}, self.priv, algorithm="RS256") + + if isinstance(token, bytes): + token = token.decode() + + return token + @pytest.fixture def auth_keys(repo_dir: str) -> AuthKeys: @@ -34,9 +53,9 @@ def auth_keys(repo_dir: str) -> AuthKeys: def test_pageserver_auth(pageserver_auth_enabled: ZenithPageserver, auth_keys: AuthKeys): ps = pageserver_auth_enabled - tenant_token = jwt.encode({"scope": "tenant", "tenant_id": ps.initial_tenant}, auth_keys.priv, algorithm="RS256") - invalid_tenant_token = jwt.encode({"scope": "tenant", "tenant_id": uuid4().hex}, auth_keys.priv, algorithm="RS256") - management_token = jwt.encode({"scope": "pageserverapi"}, auth_keys.priv, algorithm="RS256") + tenant_token = auth_keys.generate_tenant_token(ps.initial_tenant) + invalid_tenant_token = auth_keys.generate_tenant_token(uuid4().hex) + management_token = auth_keys.generate_management_token() # this does not invoke auth check and only decodes jwt and checks it for validity # check both tokens @@ -72,7 +91,7 @@ def test_compute_auth_to_pageserver( ps = pageserver_auth_enabled # since we are in progress of refactoring protocols between compute safekeeper and page server # use hardcoded management token in safekeeper - management_token = jwt.encode({"scope": "pageserverapi"}, auth_keys.priv, algorithm="RS256") + management_token = auth_keys.generate_management_token() branch = f"test_compute_auth_to_pageserver{with_wal_acceptors}" zenith_cli.run(["branch", branch, "empty"])