Module flowcon.transforms.matrix.cholesky
Classes
class CholeskyOuterProduct (N, checkargs=True, eps=1e-06)
-
Base class for all transform objects.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class CholeskyOuterProduct(Transform): def __init__(self, N, checkargs=True, eps=1e-6): super().__init__() self.eps = eps self.N = N self.eye = nn.Parameter(torch.diag_embed(torch.ones(self.N)).unsqueeze(0), requires_grad=False) self.powers = nn.Parameter(torch.arange(self.N, 0, -1).unsqueeze(0), requires_grad=False) self.checkargs = checkargs def forward(self, inputs, context=None): if self.checkargs: self.check_pos_low_triang(inputs) outputs = torch.bmm(inputs, inputs.mT) outputs = 0.5*(outputs + outputs.mT) diagonal = torch.diagonal(inputs, dim1=-2, dim2=-1) logabsdet = self.N * np.log(2.) + (self.powers * diagonal.log()).sum(-1) return outputs, logabsdet def inverse(self, inputs, context=None): inputs_jitter = inputs + self.eye * self.eps if self.checkargs: self.check_pos_def(inputs_jitter) outputs = torch.linalg.cholesky(inputs_jitter, upper=False) diagonal = torch.diagonal(outputs, dim1=-2, dim2=-1) logabsdet = self.N * np.log(2.) + (self.powers * diagonal.log()).sum(1) return outputs, -logabsdet def check_pos_low_triang(self, inputs): assert inputs.shape[-2] == inputs.shape[-1], "input tensor must be mini batch of square matrices" upper_indices = np.triu_indices(self.N, k=1) assert torch.all(inputs[:, upper_indices[0], upper_indices[1]] == 0.), ( "input tensor must be mini batch of lower triangular matrices") assert torch.all(torch.diagonal(inputs, dim1=-2, dim2=-1) > 0), ( 'input tensor must be mini batch of lower triangular matrices with positive diagonal elements') def check_pos_def(self, inputs): assert torch.all(inputs == inputs.mT), "Input matrix is not symmetric." assert torch.all(torch.linalg.eig(inputs)[0].real >= 0), ( "Input matrix is not positive semi-definite in order to perform Cholesky decomposition" )
Ancestors
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Methods
def check_pos_def(self, inputs)
def check_pos_low_triang(self, inputs)
def inverse(self, inputs, context=None)
Inherited members