Module flowcon.transforms.normalization
Implementation of normalization-based transforms.
Classes
class ActNorm (features)
-
Transform that performs activation normalization. Works for 2D and 4D inputs. For 4D inputs (images) normalization is performed per-channel, assuming BxCxHxW input shape.
Reference:
D. P. Kingma and P. Dhariwal, Glow: Generative flow with invertible 1x1 convolutions, NeurIPS 2018.
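A minimal usage sketch, assuming ActNorm is imported from this module; the shapes are illustrative. The first forward pass in training mode triggers the data-dependent initialization:

import torch
from flowcon.transforms.normalization import ActNorm

actnorm = ActNorm(features=3)
actnorm.train()  # data-dependent initialization happens on the first forward pass

# 4D image batch (BxCxHxW): normalization is applied per channel.
images = torch.randn(16, 3, 8, 8)
outputs, logabsdet = actnorm(images)
print(outputs.shape)    # torch.Size([16, 3, 8, 8])
print(logabsdet.shape)  # torch.Size([16])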
Source code
class ActNorm(Transform):
    def __init__(self, features):
        """
        Transform that performs activation normalization. Works for 2D and 4D inputs.
        For 4D inputs (images) normalization is performed per-channel, assuming
        BxCxHxW input shape.

        Reference:
        > D. Kingma et al., Glow: Generative flow with invertible 1x1 convolutions,
          NeurIPS 2018.
        """
        if not check.is_positive_int(features):
            raise TypeError("Number of features must be a positive integer.")
        super().__init__()

        self.register_buffer("initialized", torch.tensor(False, dtype=torch.bool))
        self.log_scale = nn.Parameter(torch.zeros(features))
        self.shift = nn.Parameter(torch.zeros(features))

    @property
    def scale(self):
        return torch.exp(self.log_scale)

    def _broadcastable_scale_shift(self, inputs):
        if inputs.dim() == 4:
            return self.scale.view(1, -1, 1, 1), self.shift.view(1, -1, 1, 1)
        else:
            return self.scale.view(1, -1), self.shift.view(1, -1)

    def forward(self, inputs, context=None):
        if inputs.dim() not in [2, 4]:
            raise ValueError("Expecting inputs to be a 2D or a 4D tensor.")

        if self.training and not self.initialized:
            self._initialize(inputs)

        scale, shift = self._broadcastable_scale_shift(inputs)
        outputs = scale * inputs + shift

        if inputs.dim() == 4:
            batch_size, _, h, w = inputs.shape
            logabsdet = h * w * torch.sum(self.log_scale) * outputs.new_ones(batch_size)
        else:
            batch_size, _ = inputs.shape
            logabsdet = torch.sum(self.log_scale) * outputs.new_ones(batch_size)

        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        if inputs.dim() not in [2, 4]:
            raise ValueError("Expecting inputs to be a 2D or a 4D tensor.")

        scale, shift = self._broadcastable_scale_shift(inputs)
        outputs = (inputs - shift) / scale

        if inputs.dim() == 4:
            batch_size, _, h, w = inputs.shape
            logabsdet = -h * w * torch.sum(self.log_scale) * outputs.new_ones(batch_size)
        else:
            batch_size, _ = inputs.shape
            logabsdet = -torch.sum(self.log_scale) * outputs.new_ones(batch_size)

        return outputs, logabsdet

    def _initialize(self, inputs):
        """Data-dependent initialization, s.t. post-actnorm activations have
        zero mean and unit variance."""
        if inputs.dim() == 4:
            num_channels = inputs.shape[1]
            inputs = inputs.permute(0, 2, 3, 1).reshape(-1, num_channels)

        with torch.no_grad():
            std = inputs.std(dim=0)
            mu = (inputs / std).mean(dim=0)
            self.log_scale.data = -torch.log(std)
            self.shift.data = -mu
            self.initialized.data = torch.tensor(True, dtype=torch.bool)
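A sketch verifying the two properties the code above aims for, under the same import assumption: after the data-dependent initialization the outputs are standardized per feature, and inverse undoes forward with a negated log-determinant:

import torch
from flowcon.transforms.normalization import ActNorm

torch.manual_seed(0)
actnorm = ActNorm(features=5)
x = 3.0 * torch.randn(1024, 5) + 7.0  # shifted and scaled 2D inputs

actnorm.train()
y, logabsdet = actnorm(x)          # first call initializes log_scale and shift
print(y.mean(dim=0).abs().max())   # ~0: per-feature means are near zero
print(y.std(dim=0))                # ~1: per-feature standard deviations

x_rec, inv_logabsdet = actnorm.inverse(y)
print(torch.allclose(x, x_rec, atol=1e-4))        # True
print(torch.allclose(logabsdet, -inv_logabsdet))  # True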
Ancestors
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Instance variables
prop scale
-
Source code
@property
def scale(self):
    return torch.exp(self.log_scale)
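The scale is kept in log-space, so positivity is guaranteed and the Jacobian log-determinant is just a sum over log_scale. A small sketch of that relationship for 2D inputs, assuming the class above:

import torch
from flowcon.transforms.normalization import ActNorm

actnorm = ActNorm(features=4)
assert torch.all(actnorm.scale > 0)  # exp(log_scale) is always positive

x = torch.randn(2, 4)
actnorm.eval()  # eval mode skips the data-dependent initialization
_, logabsdet = actnorm(x)
# For 2D inputs, log|det J| = sum(log_scale), repeated for each batch element.
print(torch.allclose(logabsdet, torch.sum(actnorm.log_scale).expand(2)))  # True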
Methods
def inverse(self, inputs, context=None)
Inherited members
class BatchNorm (features, eps=1e-05, momentum=0.1, affine=True)
-
Transform that performs batch normalization.
Limitations
- It works only for 1-dim inputs.
- Inverse is not available in training mode, only in eval mode (see the sketch after this list).
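A minimal usage sketch, assuming BatchNorm is imported from this module, showing both limitations: 2D inputs only, and inverse restricted to eval mode:

import torch
from flowcon.transforms.normalization import BatchNorm

bn = BatchNorm(features=10)
x = torch.randn(64, 10)

bn.train()
y, logabsdet = bn(x)  # uses batch statistics and updates the running buffers
# bn.inverse(y) here would raise InverseNotAvailable

bn.eval()
y_eval, _ = bn(x)              # uses the accumulated running statistics
x_rec, _ = bn.inverse(y_eval)  # exact inverse in eval mode
print(torch.allclose(x, x_rec, atol=1e-4))  # True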
Source code
class BatchNorm(Transform):
    """Transform that performs batch normalization.

    Limitations:
        * It works only for 1-dim inputs.
        * Inverse is not available in training mode, only in eval mode.
    """

    def __init__(self, features, eps=1e-5, momentum=0.1, affine=True):
        if not check.is_positive_int(features):
            raise TypeError("Number of features must be a positive integer.")
        super().__init__()

        self.momentum = momentum
        self.eps = eps
        constant = np.log(np.exp(1 - eps) - 1)
        self.unconstrained_weight = nn.Parameter(constant * torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

        self.register_buffer("running_mean", torch.zeros(features))
        self.register_buffer("running_var", torch.zeros(features))

    @property
    def weight(self):
        return F.softplus(self.unconstrained_weight) + self.eps

    def forward(self, inputs, context=None):
        if inputs.dim() != 2:
            raise ValueError(
                "Expected 2-dim inputs, got inputs of shape: {}".format(inputs.shape)
            )

        if self.training:
            mean, var = inputs.mean(0), inputs.var(0)
            self.running_mean.mul_(1 - self.momentum).add_(mean.detach() * self.momentum)
            self.running_var.mul_(1 - self.momentum).add_(var.detach() * self.momentum)
        else:
            mean, var = self.running_mean, self.running_var

        outputs = (
            self.weight * ((inputs - mean) / torch.sqrt((var + self.eps))) + self.bias
        )

        logabsdet_ = torch.log(self.weight) - 0.5 * torch.log(var + self.eps)
        logabsdet = torch.sum(logabsdet_) * inputs.new_ones(inputs.shape[0])

        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        if self.training:
            raise InverseNotAvailable(
                "Batch norm inverse is only available in eval mode, not in training mode."
            )
        if inputs.dim() != 2:
            raise ValueError(
                "Expected 2-dim inputs, got inputs of shape: {}".format(inputs.shape)
            )

        outputs = (
            torch.sqrt(self.running_var + self.eps) * ((inputs - self.bias) / self.weight)
            + self.running_mean
        )

        logabsdet_ = -torch.log(self.weight) + 0.5 * torch.log(
            self.running_var + self.eps
        )
        logabsdet = torch.sum(logabsdet_) * inputs.new_ones(inputs.shape[0])

        return outputs, logabsdet
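Because the transform is an elementwise affine map in eval mode, the analytic log|det J| it returns (the sum of log(weight) - 0.5 * log(var + eps)) can be checked against a Jacobian computed with autograd. A sketch under the same import assumption:

import torch
from torch.autograd.functional import jacobian
from flowcon.transforms.normalization import BatchNorm

bn = BatchNorm(features=3)
bn.train()
bn(torch.randn(128, 3))  # populate the running statistics
bn.eval()

x = torch.randn(1, 3)
_, logabsdet = bn(x)
# Single-sample Jacobian of the forward map, flattened to a 3x3 matrix.
J = jacobian(lambda t: bn(t)[0], x).reshape(3, 3)
print(torch.allclose(logabsdet, torch.slogdet(J).logabsdet, atol=1e-5))  # True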
Ancestors
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Instance variables
prop weight
-
Source code
@property
def weight(self):
    return F.softplus(self.unconstrained_weight) + self.eps
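The softplus parameterization keeps weight strictly positive, and the initialization constant log(exp(1 - eps) - 1) is chosen so that the initial weight is approximately 1. A quick check:

import numpy as np
import torch
import torch.nn.functional as F

eps = 1e-5
constant = np.log(np.exp(1 - eps) - 1)
# softplus(constant) = 1 - eps, so the initial weight is (1 - eps) + eps = 1.
w0 = F.softplus(torch.tensor(constant)) + eps
print(w0)  # tensor(1.0000, dtype=torch.float64)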
Methods
def inverse(self, inputs, context=None)
Inherited members