Module flowcon.transforms.linear
Implementations of linear transforms.
Classes
class Linear (features, using_cache=False)
-
Abstract base class for linear transforms that parameterize a weight matrix.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class Linear(Transform):
    """Abstract base class for linear transforms that parameterize a weight matrix."""

    def __init__(self, features, using_cache=False):
        if not check.is_positive_int(features):
            raise TypeError("Number of features must be a positive integer.")
        super().__init__()

        self.features = features
        self.bias = nn.Parameter(torch.zeros(features))

        # Caching flag and values.
        self.using_cache = using_cache
        self.cache = LinearCache()

    def forward(self, inputs, context=None):
        if not self.training and self.using_cache:
            self._check_forward_cache()
            outputs = F.linear(inputs, self.cache.weight, self.bias)
            logabsdet = self.cache.logabsdet * outputs.new_ones(outputs.shape[0])
            return outputs, logabsdet
        else:
            return self.forward_no_cache(inputs)

    def _check_forward_cache(self):
        if self.cache.weight is None and self.cache.logabsdet is None:
            self.cache.weight, self.cache.logabsdet = self.weight_and_logabsdet()
        elif self.cache.weight is None:
            self.cache.weight = self.weight()
        elif self.cache.logabsdet is None:
            self.cache.logabsdet = self.logabsdet()

    def inverse(self, inputs, context=None):
        if not self.training and self.using_cache:
            self._check_inverse_cache()
            outputs = F.linear(inputs - self.bias, self.cache.inverse)
            logabsdet = (-self.cache.logabsdet) * outputs.new_ones(outputs.shape[0])
            return outputs, logabsdet
        else:
            return self.inverse_no_cache(inputs)

    def _check_inverse_cache(self):
        if self.cache.inverse is None and self.cache.logabsdet is None:
            (
                self.cache.inverse,
                self.cache.logabsdet,
            ) = self.weight_inverse_and_logabsdet()
        elif self.cache.inverse is None:
            self.cache.inverse = self.weight_inverse()
        elif self.cache.logabsdet is None:
            self.cache.logabsdet = self.logabsdet()

    def train(self, mode=True):
        if mode:
            # If training again, invalidate cache.
            self.cache.invalidate()
        return super().train(mode)

    def use_cache(self, mode=True):
        if not check.is_bool(mode):
            raise TypeError("Mode must be boolean.")
        self.using_cache = mode

    def weight_and_logabsdet(self):
        # To be overridden by subclasses if it is more efficient to compute the weight matrix
        # and its logabsdet together.
        return self.weight(), self.logabsdet()

    def weight_inverse_and_logabsdet(self):
        # To be overridden by subclasses if it is more efficient to compute the weight matrix
        # inverse and weight matrix logabsdet together.
        return self.weight_inverse(), self.logabsdet()

    def forward_no_cache(self, inputs):
        """Applies `forward` method without using the cache."""
        raise NotImplementedError()

    def inverse_no_cache(self, inputs):
        """Applies `inverse` method without using the cache."""
        raise NotImplementedError()

    def weight(self):
        """Returns the weight matrix."""
        raise NotImplementedError()

    def weight_inverse(self):
        """Returns the inverse weight matrix."""
        raise NotImplementedError()

    def logabsdet(self):
        """Returns the log absolute determinant of the weight matrix."""
        raise NotImplementedError()
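As a usage sketch (not taken from the module itself): the cache is only consulted in evaluation mode and only after use_cache(True); switching back to training mode invalidates it. The snippet below assumes the import path shown in the module header and uses NaiveLinear, documented further down, as the concrete subclass.

    import torch
    from flowcon.transforms.linear import NaiveLinear

    transform = NaiveLinear(features=4)
    x = torch.randn(8, 4)

    # Training mode: the cache is bypassed and forward_no_cache() is called.
    y, logabsdet = transform(x)

    # Evaluation mode with caching enabled: weight, inverse and logabsdet are
    # computed once and reused on subsequent calls.
    transform.eval()
    transform.use_cache(True)
    y, logabsdet = transform(x)                  # fills cache.weight and cache.logabsdet
    x_rec, neg_logabsdet = transform.inverse(y)  # fills cache.inverse

    transform.train()  # invalidates the cache before further training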
Ancestors
- Transform
- torch.nn.modules.module.Module
Subclasses
- NaiveLinear
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Methods
def forward_no_cache(self, inputs)
-
Applies `forward` method without using the cache.
def inverse(self, inputs, context=None)
def inverse_no_cache(self, inputs)
-
Applies `inverse` method without using the cache.
def logabsdet(self)
-
Returns the log absolute determinant of the weight matrix.
def train(self, mode=True)
-
Set the module in training mode.
This has an effect only on certain modules (e.g. Dropout, BatchNorm); see the documentation of particular modules for details of their behavior in training/evaluation mode. For a Linear transform, switching back to training mode also invalidates the cache.
Args
mode : bool
- whether to set training mode (True) or evaluation mode (False). Default: True.
Returns
Module
- self
def use_cache(self, mode=True)
def weight(self)
-
Returns the weight matrix.
def weight_and_logabsdet(self)
def weight_inverse(self)
-
Returns the inverse weight matrix.
def weight_inverse_and_logabsdet(self)
Inherited members
class LinearCache
-
Helper class to store the cache of a linear transform.
The cache consists of: the weight matrix, its inverse and its log absolute determinant.
Expand source code
class LinearCache:
    """Helper class to store the cache of a linear transform.

    The cache consists of: the weight matrix, its inverse and its log absolute determinant.
    """

    def __init__(self):
        self.weight = None
        self.inverse = None
        self.logabsdet = None

    def invalidate(self):
        self.weight = None
        self.inverse = None
        self.logabsdet = None
Methods
def invalidate(self)
class NaiveLinear (features, orthogonal_initialization=True, using_cache=False)
-
A general linear transform that uses an unconstrained weight matrix.
This transform explicitly computes the log absolute determinant in the forward direction and uses a linear solver in the inverse direction.
Both forward and inverse directions have a cost of O(D^3), where D is the dimension of the input.
Constructor.
Args
features
- int, number of input features.
orthogonal_initialization
- bool, if True initialize weights to be a random orthogonal matrix.
Raises
TypeError
- if `features` is not a positive integer.
Expand source code
class NaiveLinear(Linear):
    """A general linear transform that uses an unconstrained weight matrix.

    This transform explicitly computes the log absolute determinant in the forward direction
    and uses a linear solver in the inverse direction.

    Both forward and inverse directions have a cost of O(D^3), where D is the dimension
    of the input.
    """

    def __init__(self, features, orthogonal_initialization=True, using_cache=False):
        """Constructor.

        Args:
            features: int, number of input features.
            orthogonal_initialization: bool, if True initialize weights to be a random
                orthogonal matrix.

        Raises:
            TypeError: if `features` is not a positive integer.
        """
        super().__init__(features, using_cache)

        if orthogonal_initialization:
            self._weight = nn.Parameter(torchutils.random_orthogonal(features))
        else:
            self._weight = nn.Parameter(torch.empty(features, features))
            stdv = 1.0 / np.sqrt(features)
            init.uniform_(self._weight, -stdv, stdv)

    def forward_no_cache(self, inputs):
        """Cost:
            output = O(D^2N)
            logabsdet = O(D^3)
        where:
            D = num of features
            N = num of inputs
        """
        batch_size = inputs.shape[0]
        outputs = F.linear(inputs, self._weight, self.bias)
        logabsdet = torchutils.logabsdet(self._weight)
        logabsdet = logabsdet * outputs.new_ones(batch_size)
        return outputs, logabsdet

    def inverse_no_cache(self, inputs):
        """Cost:
            output = O(D^3 + D^2N)
            logabsdet = O(D^3)
        where:
            D = num of features
            N = num of inputs
        """
        batch_size = inputs.shape[0]
        outputs = inputs - self.bias
        # LU-decompose the weights and solve for the outputs.
        lu, lu_pivots = torch.linalg.lu_factor(self._weight)
        outputs = torch.linalg.lu_solve(lu, lu_pivots, outputs.t()).t()
        # The linear-system solver returns the LU decomposition of the weights, which we
        # can use to obtain the log absolute determinant directly.
        logabsdet = -torch.sum(torch.log(torch.abs(torch.diag(lu))))
        logabsdet = logabsdet * outputs.new_ones(batch_size)
        return outputs, logabsdet

    def weight(self):
        """Cost:
            weight = O(1)
        """
        return self._weight

    def weight_inverse(self):
        """Cost:
            inverse = O(D^3)
        where:
            D = num of features
        """
        return torch.inverse(self._weight)

    def weight_inverse_and_logabsdet(self):
        """Cost:
            inverse = O(D^3)
            logabsdet = O(D)
        where:
            D = num of features
        """
        # If both weight inverse and logabsdet are needed, it's cheaper to compute both together.
        identity = torch.eye(self.features, self.features)
        # LU-decompose the weights and solve for the outputs.
        lu, lu_pivots = torch.lu(self._weight)
        weight_inv = torch.lu_solve(identity, lu, lu_pivots)
        logabsdet = torch.sum(torch.log(torch.abs(torch.diag(lu))))
        return weight_inv, logabsdet

    def logabsdet(self):
        """Cost:
            logabsdet = O(D^3)
        where:
            D = num of features
        """
        return torchutils.logabsdet(self._weight)
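An illustrative consistency check (a sketch, not part of the library): forward followed by inverse should recover the inputs, and the two directions should report logabsdet values of opposite sign.

    import torch
    from flowcon.transforms.linear import NaiveLinear

    torch.manual_seed(0)
    transform = NaiveLinear(features=3)          # random orthogonal weight by default
    x = torch.randn(5, 3)

    y, logabsdet_fwd = transform(x)              # explicit logabsdet, O(D^3)
    x_rec, logabsdet_inv = transform.inverse(y)  # LU solve, O(D^3 + D^2 N)

    print(torch.allclose(x, x_rec, atol=1e-5))
    print(torch.allclose(logabsdet_fwd, -logabsdet_inv, atol=1e-5))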
Ancestors
- Linear
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Methods
def forward_no_cache(self, inputs)
-
Cost
output = O(D^2 N), logabsdet = O(D^3), where D = num of features and N = num of inputs.
def inverse_no_cache(self, inputs)
-
Cost
output = O(D^3 + D^2 N), logabsdet = O(D^3), where D = num of features and N = num of inputs.
def logabsdet(self)
-
Cost
logabsdet = O(D^3), where D = num of features.
def weight(self)
-
Cost: weight = O(1)
def weight_inverse(self)
-
Cost
inverse = O(D^3), where D = num of features.
def weight_inverse_and_logabsdet(self)
-
Cost
inverse = O(D^3), logabsdet = O(D), where D = num of features.
Inherited members
class ScalarScale (scale=1.0, trainable=True, eps=0.0001)
-
Transform that multiplies all inputs by a single scalar scale, stored as a trainable log-scale parameter with a small eps added for numerical stability.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class ScalarScale(Transform):
    def __init__(self, scale=1., trainable=True, eps=1e-4):
        super().__init__()
        assert np.all(scale > 1e-6), "Scale too small.."
        self._scale = nn.Parameter(
            torch.log(torch.tensor(scale, dtype=torch.get_default_dtype())),
            requires_grad=trainable,
        )
        self.eps = eps

    @property
    def scale(self):
        return torch.exp(self._scale) + self.eps

    def forward(self, inputs, context=None):
        outputs = self.scale * inputs
        logabsdet = inputs.new_ones(inputs.shape[0]) * torch.log(self.scale).sum() * np.sum(inputs.shape[1:])
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        outputs = inputs * (1. / self.scale)
        logabsdet = -inputs.new_ones(inputs.shape[0]) * torch.log(self.scale).sum() * np.sum(inputs.shape[1:])
        return outputs, logabsdet
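As an illustrative sketch (not part of the generated documentation), the per-sample log absolute determinant returned by ScalarScale equals log(scale) multiplied by the number of scaled features, and the inverse negates it.

    import torch
    from flowcon.transforms.linear import ScalarScale

    transform = ScalarScale(scale=2.0, trainable=False)
    x = torch.randn(4, 10)

    y, logabsdet = transform(x)

    # Each of the 10 features is multiplied by scale (= exp(log 2) + eps), so
    # the per-sample log |det J| is 10 * log(scale).
    expected = 10 * torch.log(transform.scale)
    print(torch.allclose(logabsdet, expected))

    x_rec, neg_logabsdet = transform.inverse(y)
    print(torch.allclose(x, x_rec))
    print(torch.allclose(neg_logabsdet, -expected))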
Ancestors
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Instance variables
prop scale
-
Expand source code
@property
def scale(self):
    return torch.exp(self._scale) + self.eps
Methods
def inverse(self, inputs, context=None)
Inherited members
class ScalarShift (shift=0.0, trainable=True)
-
Transform that adds a single (optionally trainable) scalar shift to all inputs.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class ScalarShift(Transform):
    def __init__(self, shift=0., trainable=True):
        super().__init__()
        self.shift = nn.Parameter(
            torch.tensor(shift, dtype=torch.get_default_dtype()),
            requires_grad=trainable,
        )

    def forward(self, inputs, context=None):
        outputs = inputs + self.shift
        return outputs, inputs.new_zeros(inputs.shape[0])

    def inverse(self, inputs, context=None):
        outputs = inputs - self.shift
        return outputs, inputs.new_zeros(inputs.shape[0])
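For completeness, a brief sketch under the same assumed import path: ScalarShift is volume preserving, so both directions return a zero logabsdet, and it composes naturally with ScalarScale into a simple affine map.

    import torch
    from flowcon.transforms.linear import ScalarScale, ScalarShift

    scale = ScalarScale(scale=2.0, trainable=False)
    shift = ScalarShift(shift=3.0, trainable=False)

    x = torch.randn(4, 10)

    h, logabsdet_scale = scale(x)   # h ~= 2 * x, contributes log |det J|
    y, logabsdet_shift = shift(h)   # y ~= 2 * x + 3, contributes zero

    print(torch.all(logabsdet_shift == 0))

    h_rec, _ = shift.inverse(y)
    x_rec, _ = scale.inverse(h_rec)
    print(torch.allclose(x, x_rec))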
Ancestors
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Methods
def inverse(self, inputs, context=None)
Inherited members