diff --git a/setup.py b/setup.py index 2cc1d9d8..41d904f0 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ 'env = rocker.extensions:Environment', 'expose = rocker.extensions:Expose', 'git = rocker.git_extension:Git', + 'gpus = rocker.nvidia_extension:Gpus', 'group_add = rocker.extensions:GroupAdd', 'home = rocker.extensions:HomeDir', 'hostname = rocker.extensions:Hostname', @@ -62,6 +63,7 @@ 'privileged = rocker.extensions:Privileged', 'pulse = rocker.extensions:PulseAudio', 'rmw = rocker.rmw_extension:RMW', + 'shm_size = rocker.extensions:ShmSize', 'ssh = rocker.ssh_extension:Ssh', 'ulimit = rocker.ulimit_extension:Ulimit', 'user = rocker.extensions:User', diff --git a/src/rocker/extensions.py b/src/rocker/extensions.py index 3e22b71a..bee93639 100644 --- a/src/rocker/extensions.py +++ b/src/rocker/extensions.py @@ -468,3 +468,27 @@ def register_arguments(parser, defaults): default=defaults.get(GroupAdd.get_name(), None), action='append', help="Add additional groups to join.") + +class ShmSize(RockerExtension): + @staticmethod + def get_name(): + return 'shm_size' + + def __init__(self): + self.name = ShmSize.get_name() + + def get_preamble(self, cliargs): + return '' + + def get_docker_args(self, cliargs): + args = '' + shm_size = cliargs.get('shm_size', None) + if shm_size: + args += f' --shm-size {shm_size} ' + return args + + @staticmethod + def register_arguments(parser, defaults={}): + parser.add_argument('--shm-size', + default=defaults.get('shm_size', None), + help="Set the size of the shared memory for the container (e.g., 512m, 1g).") \ No newline at end of file diff --git a/src/rocker/nvidia_extension.py b/src/rocker/nvidia_extension.py index 3644c220..e4e47e32 100644 --- a/src/rocker/nvidia_extension.py +++ b/src/rocker/nvidia_extension.py @@ -152,12 +152,15 @@ def get_snippet(self, cliargs): def get_docker_args(self, cliargs): force_flag = cliargs.get('nvidia', None) + gpus_ids_flag = cliargs.get('gpus', None) + if gpus_ids_flag is None: + 
gpus_ids_flag = 'all' if force_flag == 'runtime': return " --runtime=nvidia" if force_flag == 'gpus': - return " --gpus all" + return f" --gpus {gpus_ids_flag}" if get_docker_version() >= Version("19.03"): - return " --gpus all" + return f" --gpus {gpus_ids_flag}" return " --runtime=nvidia" @staticmethod @@ -236,3 +239,30 @@ def register_arguments(parser, defaults): action='store_true', default=defaults.get('cuda', None), help="Install cuda and nvidia-cuda-dev into the container") + +class Gpus(RockerExtension): + @staticmethod + def get_name(): + return 'gpus' + + def __init__(self): + self.name = Gpus.get_name() + + def get_preamble(self, cliargs): + return '' + + def get_docker_args(self, cliargs): + # The gpu ids will be set in the nvidia extension, if the nvidia argument is passed. + if cliargs.get('nvidia', None): + return '' + args = '' + gpus = cliargs.get('gpus', None) + if gpus: + args += f' --gpus {gpus} ' + return args + + @staticmethod + def register_arguments(parser, defaults={}): + parser.add_argument('--gpus', + default=defaults.get('gpus', None), + help="Set the indices of GPUs to use") \ No newline at end of file diff --git a/test/test_extension.py b/test/test_extension.py index d264c35c..1ad88d85 100644 --- a/test/test_extension.py +++ b/test/test_extension.py @@ -617,3 +617,35 @@ def test_group_add_extension(self): args = p.get_docker_args(mock_cliargs) self.assertIn('--group-add sudo', args) self.assertIn('--group-add docker', args) + +class ShmSizeExtensionTest(unittest.TestCase): + + def setUp(self): + # Work around interference between empy Interpreter + # stdout proxy and test runner. empy installs a proxy on stdout + # to be able to capture the information. + # And the test runner creates a new stdout object for each test. 
+ # This breaks empy as it assumes that the proxy persists + # between instances of the Interpreter class + # empy will error with the exception + # "em.Error: interpreter stdout proxy lost" + em.Interpreter._wasProxyInstalled = False + + @pytest.mark.docker + def test_shm_size_extension(self): + plugins = list_plugins() + shm_size_plugin = plugins['shm_size'] + self.assertEqual(shm_size_plugin.get_name(), 'shm_size') + + p = shm_size_plugin() + self.assertTrue(plugin_load_parser_correctly(shm_size_plugin)) + + mock_cliargs = {} + self.assertEqual(p.get_snippet(mock_cliargs), '') + self.assertEqual(p.get_preamble(mock_cliargs), '') + args = p.get_docker_args(mock_cliargs) + self.assertNotIn('--shm-size', args) + + mock_cliargs = {'shm_size': '12g'} + args = p.get_docker_args(mock_cliargs) + self.assertIn('--shm-size 12g', args) \ No newline at end of file diff --git a/test/test_nvidia.py b/test/test_nvidia.py index 6fb7e63b..9cfd2f51 100644 --- a/test/test_nvidia.py +++ b/test/test_nvidia.py @@ -328,3 +328,41 @@ def test_cuda_env_subs(self): with self.assertRaises(SystemExit) as cm: p.get_environment_subs(mock_cliargs) self.assertEqual(cm.exception.code, 1) + +class GpusExtensionTest(unittest.TestCase): + + def setUp(self): + # Work around interference between empy Interpreter + # stdout proxy and test runner. empy installs a proxy on stdout + # to be able to capture the information. + # And the test runner creates a new stdout object for each test. 
+ # This breaks empy as it assumes that the proxy persists + # between instances of the Interpreter class + # empy will error with the exception + # "em.Error: interpreter stdout proxy lost" + em.Interpreter._wasProxyInstalled = False + + @pytest.mark.docker + def test_gpus_extension(self): + plugins = list_plugins() + gpus_plugin = plugins['gpus'] + self.assertEqual(gpus_plugin.get_name(), 'gpus') + + p = gpus_plugin() + self.assertTrue(plugin_load_parser_correctly(gpus_plugin)) + + # Test when no GPUs are specified + mock_cliargs = {} + self.assertEqual(p.get_snippet(mock_cliargs), '') + self.assertEqual(p.get_preamble(mock_cliargs), '') + args = p.get_docker_args(mock_cliargs) + self.assertNotIn('--gpus', args) + + # Test when GPUs are specified + mock_cliargs = {'gpus': 'all'} + args = p.get_docker_args(mock_cliargs) + self.assertIn('--gpus all', args) + + mock_cliargs = {'gpus': '0,1'} + args = p.get_docker_args(mock_cliargs) + self.assertIn('--gpus 0,1', args)