Module flowcon.transforms.matrix.diagonal
Classes
class TransformDiagonal (N, diag_transformation: Transform = Exp())
-
Applies a transform to the diagonal of a square matrix, leaving all off-diagonal entries unchanged.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class TransformDiagonal(Transform):
    """Transform acting only on the main diagonal of square (N x N) matrices.

    The diagonal is extracted, mapped through ``diag_transformation``, and
    scattered back in place; off-diagonal entries pass through untouched, so
    the log-abs-determinant of the whole map is exactly that of the diagonal
    transform.

    Args:
        N: Side length of the square matrices this transform operates on.
        diag_transformation: Transform applied to the diagonal entries.
            Defaults to a fresh ``Exp()`` per instance.
    """

    def __init__(self, N, diag_transformation: Transform = None):
        super().__init__()
        # Fix: the original used `diag_transformation: Transform = Exp()`,
        # an nn.Module instance evaluated once at def-time and therefore
        # shared by every TransformDiagonal built without an argument
        # (mutable-default-argument pitfall). Build one per instance instead.
        if diag_transformation is None:
            diag_transformation = Exp()
        self.N = N
        # Index arrays for the main diagonal of an N x N matrix.
        self.diag_indices = np.diag_indices(self.N)
        # Frozen (1, N, N) identity-pattern mask; stored as a non-trainable
        # Parameter so it follows the module across devices/dtypes.
        # NOTE(review): unused by forward/inverse here — presumably consumed
        # by subclasses or external callers; kept for compatibility.
        self.diag_mask = nn.Parameter(torch.diag_embed(torch.ones(1, self.N)),
                                      requires_grad=False)
        self.diag_transform = diag_transformation

    def forward(self, inputs, context=None):
        """Transform the diagonal of ``inputs``.

        Returns:
            Tuple ``(outputs, logabsdet)`` where ``outputs`` matches the
            shape of ``inputs`` and ``logabsdet`` is that of the diagonal
            transform alone.
        """
        transformed_diag, logabsdet_diag = self.diag_transform(
            torch.diagonal(inputs, dim1=-2, dim2=-1))
        outputs = torch.diagonal_scatter(inputs, transformed_diag,
                                         dim1=-2, dim2=-1)
        return outputs, logabsdet_diag

    def inverse(self, inputs, context=None):
        """Apply the inverse of the diagonal transform to ``inputs``.

        Returns:
            Tuple ``(outputs, logabsdet)`` mirroring :meth:`forward`.
        """
        transformed_diag, logabsdet_diag = self.diag_transform.inverse(
            torch.diagonal(inputs, dim1=-2, dim2=-1))
        outputs = torch.diagonal_scatter(inputs, transformed_diag,
                                         dim1=-2, dim2=-1)
        return outputs, logabsdet_diag
Ancestors
- Transform
- torch.nn.modules.module.Module
Subclasses
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Methods
def inverse(self, inputs, context=None)
Inherited members
class TransformDiagonalExponential (N, eps=1e-05)
-
Transforms the matrix diagonal with an exponential followed by a small constant shift (eps), yielding strictly positive diagonal entries.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class TransformDiagonalExponential(TransformDiagonal):
    """Diagonal transform mapping each diagonal entry d to exp(d) + eps.

    The exponential guarantees positivity; the small shift ``eps`` keeps the
    result bounded away from zero.

    Args:
        N: Side length of the square matrices.
        eps: Constant added after the exponential (non-trainable).
    """

    def __init__(self, N, eps=1e-5):
        diag_chain = CompositeTransform(
            [Exp(), ScalarShift(eps, trainable=False)])
        super().__init__(N=N, diag_transformation=diag_chain)
Ancestors
- TransformDiagonal
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Inherited members
class TransformDiagonalSoftplus (N, eps=1e-05)
-
Transforms the matrix diagonal with a softplus followed by a small constant shift (eps), yielding strictly positive diagonal entries.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class TransformDiagonalSoftplus(TransformDiagonal):
    """Diagonal transform mapping each diagonal entry d to softplus(d) + eps.

    Softplus guarantees positivity; the small shift ``eps`` keeps the result
    bounded away from zero.

    Args:
        N: Side length of the square matrices.
        eps: Constant added after the softplus (non-trainable).
    """

    def __init__(self, N, eps=1e-5):
        diag_chain = CompositeTransform(
            [Softplus(), ScalarShift(eps, trainable=False)])
        super().__init__(N=N, diag_transformation=diag_chain)
Ancestors
- TransformDiagonal
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Inherited members