Source code for inferno.extensions.layers.device

import torch.nn as nn
from ...utils.python_utils import from_iterable, to_iterable
from ...utils.exceptions import assert_, DeviceError

__all__ = ['DeviceTransfer', 'OnDevice']


class DeviceTransfer(nn.Module):
    """Layer to transfer tensors to a specified device."""

    def __init__(self, target_device, device_ordinal=None, non_blocking=False):
        """
        Parameters
        ----------
        target_device : {'cpu', 'cuda'}
            Device to transfer to.
        device_ordinal : int
            Device ordinal if `target_device` is 'cuda'.
        non_blocking : bool
            Whether to use non-blocking (asynchronous) transfers.
        """
        super(DeviceTransfer, self).__init__()
        # Validate arguments
        assert_(target_device in ['cpu', 'cuda'],
                "Target device must either be 'cpu' or 'cuda'.",
                DeviceError)
        if target_device == 'cpu':
            assert_(device_ordinal is None,
                    "'device_ordinal' must be None if target_device is 'cpu'.",
                    DeviceError)
        self.target_device = target_device
        self.device_ordinal = device_ordinal
        self.non_blocking = non_blocking
    def forward(self, *inputs):
        if self.target_device == 'cuda':
            transferred = tuple(input_.cuda(device=self.device_ordinal,
                                            non_blocking=self.non_blocking)
                                for input_ in inputs)
        elif self.target_device == 'cpu':
            transferred = tuple(input_.cpu() for input_ in inputs)
        else:
            raise NotImplementedError
        # Unpack to a single tensor if there is only one input
        return from_iterable(transferred)
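
# A minimal usage sketch (illustrative only; it assumes a CUDA-capable
# machine and is not part of the module's public API). DeviceTransfer
# accepts any number of inputs and moves them all to the target device,
# which also makes it usable as a layer inside an nn.Sequential.
def _demo_device_transfer():
    import torch
    transfer = DeviceTransfer('cuda', device_ordinal=0, non_blocking=True)
    x, y = torch.randn(2, 3), torch.randn(2, 3)
    # Both inputs are transferred together and returned as a tuple
    x_cuda, y_cuda = transfer(x, y)
    assert x_cuda.is_cuda and y_cuda.is_cuda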
class OnDevice(nn.Module):
    """
    Moves a module to a device. The advantage of using this over
    `torch.nn.Module.cuda` is that the inputs are transferred to the same
    device as the module, enabling easy model parallelism.
    """

    def __init__(self, module, target_device, device_ordinal=None, non_blocking=False):
        """
        Parameters
        ----------
        module : torch.nn.Module
            Module to transfer to the device.
        target_device : {'cpu', 'cuda'}
            The device to move `module` to. Must be either 'cpu' or 'cuda'.
        device_ordinal : int
            Ordinal of the GPU device if `target_device` is 'cuda'.
        non_blocking : bool
            Whether to use non-blocking (asynchronous) transfers.
        """
        super(OnDevice, self).__init__()
        # Validate arguments
        assert_(target_device in ['cpu', 'cuda'],
                "Target device must either be 'cpu' or 'cuda'.",
                DeviceError)
        if target_device == 'cpu':
            assert_(device_ordinal is None,
                    "'device_ordinal' must be None if target_device is 'cpu'.",
                    DeviceError)
        self.target_device = target_device
        self.device_ordinal = device_ordinal
        self.non_blocking = non_blocking
        # Transferring is a no-op if the inputs are already on the right device
        self.device_transfer = DeviceTransfer(self.target_device,
                                              device_ordinal=self.device_ordinal,
                                              non_blocking=self.non_blocking)
        self.module = self.transfer_module(module)
    def transfer_module(self, module):
        if self.target_device == 'cuda':
            return module.cuda(device=self.device_ordinal)
        elif self.target_device == 'cpu':
            return module.cpu()
        else:
            raise NotImplementedError
    def forward(self, *inputs):
        # Transfer inputs (no-op if they're already on the right device)
        transferred = to_iterable(self.device_transfer(*inputs))
        output = self.module(*transferred)
        return output
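
# A minimal model-parallelism sketch (illustrative only; it assumes a
# machine with at least two CUDA devices). Each linear layer lives on its
# own GPU, and OnDevice transfers the intermediate activations between
# them, so the caller never moves tensors by hand.
def _demo_on_device():
    import torch
    model = nn.Sequential(
        OnDevice(nn.Linear(16, 32), 'cuda', device_ordinal=0),
        OnDevice(nn.Linear(32, 8), 'cuda', device_ordinal=1))
    x = torch.randn(4, 16)  # starts out on the CPU
    # The first OnDevice moves x to cuda:0; the second moves the
    # intermediate activation from cuda:0 over to cuda:1.
    output = model(x)
    assert output.is_cuda and output.get_device() == 1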