import torch
import torch.nn as nn
import torch.nn.functional as F
from ...utils.exceptions import assert_, ShapeError
from ...utils import python_utils as pyu
# Public API of this module.
__all__ = ['View', 'AsMatrix', 'Flatten', 'As3D', 'As2D', 'Concatenate',
           'Cat', 'ResizeAndConcatenate', 'PoolCat', 'Sum', 'SplitChannels']
class View(nn.Module):
    """
    Reshape the input to a fixed target shape. Entries of `as_shape` are
    either ints (the target size of that dimension, -1 allowed as in
    `torch.Tensor.view`) or the wildcard 'x', meaning "copy the input's
    size at this position". All 'x' entries must precede all int entries.
    """
    def __init__(self, as_shape):
        super(View, self).__init__()
        self.as_shape = self.validate_as_shape(as_shape)

    def validate_as_shape(self, as_shape):
        """Check that `as_shape` is a run of 'x' wildcards followed by ints."""
        # Every entry is either an int or the wildcard 'x'.
        assert all(isinstance(entry, int) or entry == 'x' for entry in as_shape)
        int_positions = [pos for pos, entry in enumerate(as_shape)
                         if isinstance(entry, int)]
        if int_positions:
            # Once the ints start, nothing after may be a wildcard.
            assert all(isinstance(entry, int)
                       for entry in as_shape[int_positions[0]:])
        return as_shape

    def forward(self, input):
        """Return `input` viewed with wildcards resolved from its own shape."""
        in_shape = list(input.size())
        target_shape = [in_shape[pos] if entry == 'x' else entry
                        for pos, entry in enumerate(self.as_shape)]
        return input.view(*target_shape)
class AsMatrix(View):
    """View a 2D input as itself (both dimensions copied from the input)."""
    def __init__(self):
        super(AsMatrix, self).__init__(as_shape=['x', 'x'])
class Flatten(View):
    """Flatten everything past the batch dimension to a single axis."""
    def __init__(self):
        super(Flatten, self).__init__(as_shape=['x', -1])
class As3D(nn.Module):
    """
    Reshape the input to a 5D volume tensor (B, C, Z, H, W). 5D inputs pass
    through unchanged; 4D image batches gain a z axis (either split off the
    channel axis or added as a new singleton-like axis); 2D matrices get
    singleton spatial dimensions.
    """
    def __init__(self, channel_as_z=False, num_channels_or_num_z_slices=1):
        super(As3D, self).__init__()
        # Whether the leading factor of the channel axis becomes z (True)
        # or stays as the channel axis (False).
        self.channel_as_z = channel_as_z
        self.num_channels_or_num_z_slices = num_channels_or_num_z_slices

    def forward(self, input):
        num_dims = input.dim()
        if num_dims == 5:
            # Already a batch of 3D volumes - pass through.
            return input
        if num_dims == 4:
            # Batch of 2D images: factor the channel axis into (z, c) or (c, z).
            b, c, h, w = input.size()
            assert_(c % self.num_channels_or_num_z_slices == 0,
                    "Number of channels of the 4D image tensor (= {}) must be "
                    "divisible by the set number of channels or number of z slices "
                    "of the 5D volume tensor (= {})."
                    .format(c, self.num_channels_or_num_z_slices),
                    ShapeError)
            c //= self.num_channels_or_num_z_slices
            if self.channel_as_z:
                # Channel axis moves to z.
                return input.view(b, self.num_channels_or_num_z_slices, c, h, w)
            # Channel axis stays; the extra factor becomes the z axis.
            return input.view(b, c, self.num_channels_or_num_z_slices, h, w)
        if num_dims == 2:
            # Matrix to a 5D batch via singleton spatial dims.
            b, c = input.size()
            return input.view(b, c, 1, 1, 1)
        raise NotImplementedError
class As2D(nn.Module):
    """
    Reshape the input to a 4D image tensor (B, C, H, W). 5D volume batches
    have their z axis folded into the channel axis (or required to be
    singleton); 4D inputs pass through; 2D matrices get singleton spatial
    dimensions.
    """
    def __init__(self, z_as_channel=True):
        super(As2D, self).__init__()
        # If True, the z axis of a 5D input is merged into the channel axis;
        # if False, the z axis must be singleton and is simply dropped.
        self.z_as_channel = z_as_channel

    def forward(self, input):
        if input.dim() == 5:
            b, c, _0, _1, _2 = list(input.size())
            if not self.z_as_channel:
                # z must be singleton if we're not folding it into channels.
                assert _0 == 1
            # Fold z into the channel axis (a no-op when _0 == 1).
            return input.view(b, c * _0, _1, _2)
        elif input.dim() == 4:
            # Nothing to do here - input is already 2D
            return input
        elif input.dim() == 2:
            # We make singleton dimensions
            b, c = list(input.size())
            return input.view(b, c, 1, 1)
        else:
            # FIX: previously fell through and returned None silently for
            # unsupported dims (e.g. 3D inputs); fail loudly instead, matching
            # the behavior of As3D.
            raise NotImplementedError
class Concatenate(nn.Module):
    """Joins all its input tensors with `torch.cat` along a configurable axis."""
    def __init__(self, dim=1):
        super(Concatenate, self).__init__()
        # Axis to concatenate along; defaults to the channel axis.
        self.dim = dim

    def forward(self, *inputs):
        return torch.cat(tuple(inputs), dim=self.dim)
class ResizeAndConcatenate(nn.Module):
    """
    Resize input tensors spatially (to a specified target size) before concatenating
    them along the channel dimension. The downsampling mode can be specified
    ('average' or 'max'), but the upsampling is always 'nearest'.
    """
    # Maps user-facing mode aliases to the canonical names used in
    # `torch.nn.functional.adaptive_{avg,max}_pool{2,3}d`.
    POOL_MODE_MAPPING = {'avg': 'avg',
                         'average': 'avg',
                         'mean': 'avg',
                         'max': 'max'}

    def __init__(self, target_size, pool_mode='average'):
        """
        Parameters
        ----------
        target_size : int or tuple
            Spatial size every input is resized to before concatenation.
        pool_mode : str
            One of 'avg', 'average', 'mean' (adaptive average pooling)
            or 'max' (adaptive max pooling).
        """
        super(ResizeAndConcatenate, self).__init__()
        self.target_size = target_size
        assert_(pool_mode in self.POOL_MODE_MAPPING.keys(),
                "`pool_mode` must be one of {}, got {} instead."
                .format(self.POOL_MODE_MAPPING.keys(), pool_mode),
                ValueError)
        # FIX: store the canonical mode. Previously the raw alias was kept, so
        # e.g. pool_mode='average' (the default!) made forward() look up
        # F.adaptive_average_pool2d, which does not exist (AttributeError);
        # only the literal 'avg' and 'max' modes ever worked.
        self.pool_mode = self.POOL_MODE_MAPPING[pool_mode]

    def forward(self, *inputs):
        """Resize all inputs to `target_size` and concatenate along channels."""
        dim = inputs[0].dim()
        assert_(dim in [4, 5],
                'Input tensors must either be 4 or 5 '
                'dimensional, but inputs[0] is {}D.'.format(dim),
                ShapeError)
        # Get resize function (2 spatial dims for 4D inputs, 3 for 5D).
        spatial_dim = {4: 2, 5: 3}[dim]
        resize_function = getattr(F, 'adaptive_{}_pool{}d'.format(self.pool_mode,
                                                                  spatial_dim))
        target_size = pyu.as_tuple_of_len(self.target_size, spatial_dim)
        # Do the resizing
        resized_inputs = []
        for input_num, input in enumerate(inputs):
            # Make sure the dim checks out
            assert_(input.dim() == dim,
                    "Expected inputs[{}] to be a {}D tensor, got a {}D "
                    "tensor instead.".format(input_num, dim, input.dim()),
                    ShapeError)
            resized_inputs.append(resize_function(input, target_size))
        # Concatenate along the channel axis
        concatenated = torch.cat(tuple(resized_inputs), 1)
        # Done
        return concatenated
class Cat(Concatenate):
    """Short alias for `Concatenate` — identical behavior, friendlier name."""
class PoolCat(ResizeAndConcatenate):
    """Short alias for `ResizeAndConcatenate` — identical behavior."""
class Sum(nn.Module):
    """Sum all inputs elementwise (all inputs must share one shape)."""
    def forward(self, *inputs):
        """Return the elementwise sum of the input tensors."""
        # FIX: dropped the trailing `.squeeze(0)`. `sum(0)` already removes the
        # stacking dimension, so the squeeze was a no-op for batch sizes > 1
        # but silently dropped the batch axis whenever it happened to be 1.
        return torch.stack(inputs, dim=0).sum(0)
class SplitChannels(nn.Module):
    """Split input at a given index along the channel axis (dim 1)."""
    def __init__(self, channel_index):
        super(SplitChannels, self).__init__()
        # Either an int split position or the string 'half'.
        self.channel_index = channel_index

    def forward(self, input):
        if isinstance(self.channel_index, int):
            where = self.channel_index
        elif self.channel_index == 'half':
            # Split down the middle of the channel axis.
            where = input.size(1) // 2
        else:
            raise NotImplementedError
        assert where < input.size(1)
        # Left chunk gets channels [0, where); right chunk gets the rest.
        return input[:, :where, ...], input[:, where:, ...]