Add arg to override config in zenith_cli

This commit is contained in:
Arthur Petukhovsky
2022-01-25 14:28:55 +00:00
parent be6d1cc360
commit 17419b8a62
4 changed files with 59 additions and 14 deletions

View File

@@ -19,6 +19,7 @@ import subprocess
import time
import filecmp
import tempfile
import toml
from contextlib import closing
from pathlib import Path
@@ -372,7 +373,6 @@ class MockS3Server:
def kill(self):
self.subprocess.kill()
class ZenithEnvBuilder:
"""
Builder object to create a Zenith runtime environment
@@ -502,6 +502,9 @@ class ZenithEnv:
self.port_distributor = config.port_distributor
self.s3_mock_server = config.s3_mock_server
# If specified, this config will be passed for all zenith_cli calls
self.override_config: Optional[Dict] = None
self.postgres = PostgresFactory(self)
self.safekeepers: List[Safekeeper] = []
@@ -599,9 +602,22 @@ sync = false # Disable fsyncs to make the tests go faster
assert type(arguments) == list
if self.override_config is not None:
tmp_config = tempfile.NamedTemporaryFile(mode='w+')
tmp_config.write(toml.dumps(self.override_config))
tmp_config.flush()
log.info('Using overriden config to run next command!')
log.info(f'Config: {toml.dumps(self.override_config)}')
else:
tmp_config = None
bin_zenith = os.path.join(str(zenith_binpath), 'zenith')
args = [bin_zenith] + arguments
args = [bin_zenith]
if tmp_config is not None:
args += ['--override-config', tmp_config.name]
args += arguments
log.info('Running command "{}"'.format(' '.join(args)))
log.info(f'Running in "{self.repo_dir}"')
@@ -637,9 +653,21 @@ sync = false # Disable fsyncs to make the tests go faster
log.info(msg)
raise Exception(msg) from exc
finally:
if tmp_config is not None:
tmp_config.close()
return res
def read_toml_config(self) -> Dict:
"""
Read the config file from the repo directory.
Returns a dictionary of the config file.
"""
with open(os.path.join(str(self.repo_dir), 'config')) as f:
return toml.load(f)
@cached_property
def auth_keys(self) -> AuthKeys:
pub = (Path(self.repo_dir) / 'auth_public_key.pem').read_bytes()