Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
'expose = rocker.extensions:Expose',
'git = rocker.git_extension:Git',
'group_add = rocker.extensions:GroupAdd',
'shm_size = rocker.extensions:ShmSize',
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be in alphabetical order.

'gpus = rocker.extensions:Gpus',
'home = rocker.extensions:HomeDir',
'hostname = rocker.extensions:Hostname',
'ipc = rocker.extensions:Ipc',
Expand Down
49 changes: 49 additions & 0 deletions src/rocker/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,52 @@ def register_arguments(parser, defaults):
default=defaults.get(GroupAdd.get_name(), None),
action='append',
help="Add additional groups to join.")

class ShmSize(RockerExtension):
@staticmethod
def get_name():
return 'shm_size'

def __init__(self):
self.name = ShmSize.get_name()

def get_preamble(self, cliargs):
return ''

def get_docker_args(self, cliargs):
args = ''
shm_size = cliargs.get('shm_size', None)
if shm_size:
args += f' --shm-size={shm_size} '
return args

@staticmethod
def register_arguments(parser, defaults={}):
parser.add_argument('--shm-size',
default=defaults.get('shm_size', None),
help="Set the size of the shared memory for the container (e.g., 512m, 1g).")


class Gpus(RockerExtension):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the cross coupling can you move this to the nvidia_extensions file instead? It's not explicitly nvidia but I'd like to keep them colocated. And that file could potentially be renamed in the future with it's expanded scope. The X11 class is already not nvidia tied.

@staticmethod
def get_name():
return 'gpus'

def __init__(self):
self.name = Gpus.get_name()

def get_preamble(self, cliargs):
return ''

def get_docker_args(self, cliargs):
args = ''
shm_size = cliargs.get('gpus', None)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like some copy and paste here

if shm_size:
args += f' --gpus {shm_size} '
return args

@staticmethod
def register_arguments(parser, defaults={}):
parser.add_argument('--gpus',
default=defaults.get('gpus', None),
help="Set the indices of GPUs to use")