Module flowcon.transforms.matrix.diagonal

Classes

class TransformDiagonal (N, diag_transformation: Transform = Exp())

Applies a transform to the main diagonal (last two dimensions) of square matrices, leaving the off-diagonal entries unchanged.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class TransformDiagonal(Transform):
    """Apply a transform to the main diagonal of square matrices.

    The diagonal over the last two dimensions is extracted, passed through
    ``diag_transform``, and written back with ``torch.diagonal_scatter``;
    all off-diagonal entries are copied through unchanged.

    Args:
        N: Size of the (N x N) matrices this transform is built for.
        diag_transformation: Transform applied to the extracted diagonal.
            When ``None`` (the default), a fresh ``Exp()`` is created for
            this instance.
    """

    def __init__(self, N, diag_transformation: "Transform | None" = None):
        super().__init__()
        self.N = N
        # Not read by forward/inverse in this class. Kept because subclasses or
        # external code may rely on it. NOTE(review): confirm before removing.
        self.diag_indices = np.diag_indices(self.N)
        # Fixed (requires_grad=False) identity mask of shape (1, N, N).
        self.diag_mask = nn.Parameter(torch.diag_embed(torch.ones(1, self.N)), requires_grad=False)
        # Build the default per instance. A signature default of ``Exp()`` is
        # evaluated once at import time and would be shared by every instance.
        if diag_transformation is None:
            diag_transformation = Exp()
        self.diag_transform = diag_transformation

    def _apply_to_diagonal(self, transform_fn, inputs):
        """Run ``transform_fn`` on the diagonal of ``inputs`` and scatter it back.

        Returns the matrices with the new diagonal and the log-abs-det that
        ``transform_fn`` reported for the diagonal.
        """
        diagonal = torch.diagonal(inputs, dim1=-2, dim2=-1)
        new_diagonal, logabsdet = transform_fn(diagonal)
        outputs = torch.diagonal_scatter(inputs, new_diagonal, dim1=-2, dim2=-1)
        return outputs, logabsdet

    def forward(self, inputs, context=None):
        """Transform the diagonal forward; ``context`` is not used."""
        return self._apply_to_diagonal(self.diag_transform, inputs)

    def inverse(self, inputs, context=None):
        """Invert the diagonal transform; ``context`` is not used."""
        return self._apply_to_diagonal(self.diag_transform.inverse, inputs)

Ancestors

Subclasses

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class TransformDiagonalExponential (N, eps=1e-05)

TransformDiagonal whose diagonal transform is Exp() followed by a fixed (non-trainable) ScalarShift(eps).

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class TransformDiagonalExponential(TransformDiagonal):
    """Diagonal transform built from ``Exp()`` then ``ScalarShift(eps)``.

    The two steps are wrapped in a ``CompositeTransform``, and the shift is
    constructed with ``trainable=False``. Off-diagonal entries pass through
    unchanged (see ``TransformDiagonal``).

    Args:
        N: Size of the (N x N) matrices.
        eps: Constant given to the non-trainable ``ScalarShift``; presumably
            keeps the diagonal bounded away from zero (confirm).
    """

    def __init__(self, N, eps=1e-5):
        exp_then_shift = CompositeTransform([
            Exp(),
            ScalarShift(eps, trainable=False),
        ])
        super().__init__(N=N, diag_transformation=exp_then_shift)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Inherited members

class TransformDiagonalSoftplus (N, eps=1e-05)

TransformDiagonal whose diagonal transform is Softplus() followed by a fixed (non-trainable) ScalarShift(eps).

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class TransformDiagonalSoftplus(TransformDiagonal):
    """Diagonal transform built from ``Softplus()`` then ``ScalarShift(eps)``.

    The two steps are wrapped in a ``CompositeTransform``, and the shift is
    constructed with ``trainable=False``. Off-diagonal entries pass through
    unchanged (see ``TransformDiagonal``).

    Args:
        N: Size of the (N x N) matrices.
        eps: Constant given to the non-trainable ``ScalarShift``; presumably
            keeps the diagonal bounded away from zero (confirm).
    """

    def __init__(self, N, eps=1e-5):
        softplus_then_shift = CompositeTransform([
            Softplus(),
            ScalarShift(eps, trainable=False),
        ])
        super().__init__(N=N, diag_transformation=softplus_then_shift)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Inherited members