Module flowcon.transforms.nonlinearities

Implementations of invertible non-linearities.

Classes

class CauchyCDF (location=None, scale=None, features=None)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class CauchyCDF(Transform):
    """Element-wise standard Cauchy CDF: ``x -> atan(x) / pi + 0.5``.

    The constructor accepts ``location``, ``scale`` and ``features`` but does
    not use them; the transform is always the standard (location 0, scale 1)
    Cauchy CDF.
    """

    def __init__(self, location=None, scale=None, features=None):
        super().__init__()

    def forward(self, inputs, context=None):
        """Map ``inputs`` into the unit interval.

        ``context`` is accepted for interface compatibility and ignored.
        Returns ``(outputs, logabsdet)``, where ``logabsdet`` is the log of the
        Cauchy density reduced with ``torchutils.sum_except_batch``.
        """
        squashed = (1 / np.pi) * torch.atan(inputs) + 0.5
        log_density = -np.log(np.pi) - torch.log(1 + inputs ** 2)
        return squashed, torchutils.sum_except_batch(log_density)

    def inverse(self, inputs, context=None):
        """Apply the Cauchy quantile function ``tan(pi * (inputs - 0.5))``.

        Raises:
            InputOutsideDomain: if any element of ``inputs`` lies outside [0, 1].
        """
        if torch.min(inputs) < 0 or torch.max(inputs) > 1:
            raise InputOutsideDomain()

        quantiles = torch.tan(np.pi * (inputs - 0.5))
        # The inverse log-det is the negated forward log-det, evaluated at the
        # recovered pre-image.
        log_density = -np.log(np.pi) - torch.log(1 + quantiles ** 2)
        return quantiles, -torchutils.sum_except_batch(log_density)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class CauchyCDFInverse (location=None, scale=None, features=None)

Creates a transform that is the inverse of a given transform.

Constructor.

Args

transform
An object of type Transform.
Expand source code
class CauchyCDFInverse(InverseTransform):
    """Inverse of :class:`CauchyCDF`, built by wrapping it in ``InverseTransform``.

    All constructor arguments are forwarded unchanged to ``CauchyCDF``.
    """

    def __init__(self, location=None, scale=None, features=None):
        cauchy_cdf = CauchyCDF(location=location, scale=scale, features=features)
        super().__init__(cauchy_cdf)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Inherited members

class CompositeCDFTransform (squashing_transform, cdf_transform)

Composes several transforms into one, in the order they are given.

Constructor.

Args

transforms
an iterable of Transform objects.
Expand source code
class CompositeCDFTransform(CompositeTransform):
    """Sandwich a CDF transform between a squashing transform and its inverse.

    The composed order is: ``squashing_transform``, then ``cdf_transform``,
    then ``InverseTransform(squashing_transform)``.
    """

    def __init__(self, squashing_transform, cdf_transform):
        unsquashing_transform = InverseTransform(squashing_transform)
        stages = [squashing_transform, cdf_transform, unsquashing_transform]
        super().__init__(stages)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Inherited members

class Exp (*args, **kwargs)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class Exp(Transform):
    """Element-wise exponential transform; its inverse is the natural log."""

    def forward(self, inputs, context=None):
        """Return ``exp(inputs)``; the log-det is the sum of ``inputs`` itself.

        ``context`` is ignored.
        """
        # d/dx exp(x) = exp(x), so log|J| per element is simply x.
        log_jacobian = torchutils.sum_except_batch(inputs, num_batch_dims=1)
        return torch.exp(inputs), log_jacobian

    def inverse(self, inputs, context=None):
        """Return ``log(inputs)`` and the negated sum of that log.

        Raises:
            InputOutsideDomain: if any element of ``inputs`` is not strictly
                positive.
        """
        if torch.min(inputs) <= 0.:
            raise InputOutsideDomain()

        logs = torch.log(inputs)
        return logs, -torchutils.sum_except_batch(logs, num_batch_dims=1)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class ExtendedSoftplus (features, shift=None)

Combination of a (shifted and scaled) softplus and the same softplus flipped around the origin

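Softplus(x - shift) - Softplus(-(x + shift)), where the effective shift is Softplus(raw shift) + 0.1 (the scale factor is currently disabled in the implementation)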

Linear outside of origin, flat around origin.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class ExtendedSoftplus(torch.nn.Module):
    """
    Sum of a shifted softplus and the same softplus mirrored through the origin:

        f(x) = Softplus(x - s) - Softplus(-(x + s)),   s = Softplus(shift) + 0.1

    Linear far away from the origin, flat around the origin. No scale factor
    is applied; only the shift is configurable.

    Args:
        features: number of features; used to shape the default shift and to
            reshape a tensor-valued ``shift``.
        shift: raw (pre-softplus) shift. ``None`` creates a learnable parameter
            of shape ``(1, features)`` filled with 3. A tensor is reshaped to
            ``(-1, features)`` and stored as a plain attribute (not registered
            as a parameter or buffer). Any other value (number, list, array)
            is converted to a learnable parameter.
    """

    def __init__(self, features, shift=None):
        # Initialise the Module machinery before assigning any attribute.
        super(ExtendedSoftplus, self).__init__()
        self.features = features
        if shift is None:
            self.shift = torch.nn.Parameter(torch.ones(1, features) * 3, requires_grad=True)
        elif torch.is_tensor(shift):
            self.shift = shift.reshape(-1, features)
        else:
            raw_shift = torch.tensor(shift)
            # Integer/bool inputs (e.g. shift=3) cannot require gradients;
            # promote them to the default floating dtype first.
            if not (raw_shift.is_floating_point() or raw_shift.is_complex()):
                raw_shift = raw_shift.to(torch.get_default_dtype())
            self.shift = torch.nn.Parameter(raw_shift, requires_grad=True)

        self._softplus = torch.nn.Softplus()

    def get_shift(self):
        """Return the effective, strictly positive shift ``Softplus(shift) + 0.1``."""
        return self._softplus(self.shift) + 1e-1

    def softplus(self, x, shift):
        """Right branch: ``Softplus(x - shift)``."""
        return self._softplus((x - shift))

    def softminus(self, x, shift):
        """Left branch: ``-Softplus(-(x + shift))``."""
        return - self._softplus(-(x + shift))

    def diag_jacobian_pos(self, x, shift):
        """Derivative of the right branch: ``sigmoid(x - shift)``.

        Equal to ``exp(x) / (exp(shift) + exp(x))`` but does not overflow to
        inf / inf = NaN for large ``x`` or ``shift``.
        """
        return torch.sigmoid(x - shift)

    def log_diag_jacobian_pos(self, x, shift):
        """Log-derivative of the right branch: ``x - logaddexp(shift, x)``."""
        log_jac = -torch.logaddexp(shift, x) + x
        return log_jac

    def diag_jacobian_neg(self, x, shift):
        """Derivative of the left branch: ``sigmoid(-(shift + x))``."""
        return torch.sigmoid(- (shift + x))

    def log_diag_jacobian_neg(self, x, shift):
        """Log-derivative of the left branch: ``-Softplus(shift + x)``."""
        return - self._softplus((shift + x))

    def forward(self, inputs):
        """Apply the transform element-wise.

        Returns:
            ``(outputs, log_diag_jacobian)``. The second item is the
            element-wise log-derivative (same shape as the broadcast inputs);
            it is NOT summed over features.
        """
        shift = self.get_shift()
        outputs = self.softplus(inputs, shift) + self.softminus(inputs, shift)
        # The derivative is the sum of both branch derivatives; combine them in
        # log space for numerical stability.
        log_diag_jacobian = torch.logaddexp(self.log_diag_jacobian_pos(inputs, shift),
                                            self.log_diag_jacobian_neg(inputs, shift))
        return outputs, log_diag_jacobian

Ancestors

  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def diag_jacobian_neg(self, x, shift)
def diag_jacobian_pos(self, x, shift)
def forward(self, inputs) ‑> Callable[..., Any]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.

def get_shift(self)
def log_diag_jacobian_neg(self, x, shift)
def log_diag_jacobian_pos(self, x, shift)
def softminus(self, x, shift)
def softplus(self, x, shift)
class GatedLinearUnit

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class GatedLinearUnit(Transform):
    """Scale the inputs by a sigmoid gate computed from ``context``.

    ``context`` is required despite its ``None`` default: it holds the gate
    logits. The log-determinant is returned flattened to 1-D.
    NOTE(review): flattening ``log(gate)`` presumably assumes one gate per
    batch element (e.g. context of shape ``(batch, 1)``) — confirm against
    callers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, context=None):
        """Return ``inputs * sigmoid(context)`` and ``log(sigmoid(context))``."""
        gate = torch.sigmoid(context)
        # logsigmoid stays finite where log(sigmoid(.)) underflows to -inf
        # for strongly negative logits.
        log_gate = torch.nn.functional.logsigmoid(context)
        return inputs * gate, log_gate.reshape(-1)

    def inverse(self, inputs, context=None):
        """Return ``inputs / sigmoid(context)`` and ``-log(sigmoid(context))``."""
        gate = torch.sigmoid(context)
        log_gate = torch.nn.functional.logsigmoid(context)
        return inputs / gate, -log_gate.reshape(-1)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class LeakyReLU (negative_slope=0.01)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class LeakyReLU(Transform):
    """Invertible leaky ReLU with a fixed, strictly positive negative slope.

    Args:
        negative_slope: slope applied to negative inputs; must be > 0.

    Raises:
        ValueError: if ``negative_slope`` is not strictly positive.
    """

    def __init__(self, negative_slope=1e-2):
        if negative_slope <= 0:
            raise ValueError("Slope must be positive.")
        super().__init__()
        self.negative_slope = negative_slope
        # Kept as a Parameter so it follows the module across devices and stays
        # in the state dict, but frozen: the transform itself uses the fixed
        # float ``negative_slope``, so training this value would make the
        # log-determinant disagree with the actual transform.
        self.log_negative_slope = torch.nn.Parameter(
            torch.log(torch.as_tensor(self.negative_slope)), requires_grad=False
        )

    def _negative_mask(self, inputs):
        """Return a 0/1 mask of negative entries, on the inputs' device."""
        # Cast directly on the inputs' device; ``.type(torch.Tensor)`` would
        # first move the mask to a CPU float tensor and then copy it back.
        return (inputs < 0).to(self.log_negative_slope.dtype)

    def forward(self, inputs, context=None):
        """Apply the leaky ReLU; log-det is ``log(slope)`` per negative entry.

        ``context`` is ignored.
        """
        outputs = F.leaky_relu(inputs, negative_slope=self.negative_slope)
        logabsdet = self.log_negative_slope * self._negative_mask(inputs)
        logabsdet = torchutils.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        """Undo the leaky ReLU by applying one with slope ``1 / negative_slope``.

        ``context`` is ignored.
        """
        outputs = F.leaky_relu(inputs, negative_slope=(1 / self.negative_slope))
        logabsdet = -self.log_negative_slope * self._negative_mask(inputs)
        logabsdet = torchutils.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class LogTanh (cut_point=1)

Tanh with unbounded output.

Constructed by selecting a cut_point, and replacing values to the right of cut_point with alpha * log(beta * x), and to the left of -cut_point with -alpha * log(-beta * x). alpha and beta are set to match the value and the first derivative of tanh at cut_point.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class LogTanh(Transform):
    """Tanh with unbounded output.

    Inside ``[-cut_point, cut_point]`` this is plain tanh. To the right of
    ``cut_point`` the value is ``alpha * log(beta * x)`` and to the left of
    ``-cut_point`` it is ``-alpha * log(-beta * x)``. ``beta`` is chosen so the
    pieces meet continuously at ``+/- cut_point``.

    NOTE(review): the intent is for ``alpha`` to also match the first
    derivative of tanh at ``cut_point``, which would require
    ``alpha = cut_point * (1 - tanh(cut_point) ** 2)``. The expression used
    below differs from that, so the slope is not matched at the cut point. The
    transform is still invertible and the log-determinants agree with the
    function as implemented. Changing ``alpha`` would alter the outputs of
    existing models, so it is deliberately left untouched here.

    Raises:
        ValueError: if ``cut_point`` is not strictly positive.
    """

    def __init__(self, cut_point=1):
        if cut_point <= 0:
            raise ValueError("Cut point must be positive.")
        super().__init__()

        tanh_at_cut = np.tanh(cut_point)
        self.cut_point = cut_point
        # Image of the cut point; used as the region boundary in ``inverse``.
        self.inv_cut_point = tanh_at_cut

        self.alpha = (1 - np.tanh(tanh_at_cut)) / cut_point
        # Solves alpha * log(beta * cut_point) == tanh(cut_point) for beta.
        log_beta = (tanh_at_cut - self.alpha * np.log(cut_point)) / self.alpha
        self.beta = np.exp(log_beta)

    def forward(self, inputs, context=None):
        """Apply the piecewise transform element-wise; ``context`` is ignored."""
        right = inputs > self.cut_point
        left = inputs < -self.cut_point
        middle = ~(right | left)

        outputs = torch.zeros_like(inputs)
        outputs[middle] = torch.tanh(inputs[middle])
        outputs[right] = self.alpha * torch.log(self.beta * inputs[right])
        outputs[left] = self.alpha * -torch.log(-self.beta * inputs[left])

        # Per-element log-derivative: log(1 - tanh^2) in the middle and
        # log(alpha / |x|) on the logarithmic tails.
        log_jacobian = torch.zeros_like(inputs)
        log_jacobian[middle] = torch.log(1 - outputs[middle] ** 2)
        log_jacobian[right] = torch.log(self.alpha / inputs[right])
        log_jacobian[left] = torch.log(-self.alpha / inputs[left])

        return outputs, torchutils.sum_except_batch(log_jacobian, num_batch_dims=1)

    def inverse(self, inputs, context=None):
        """Invert the piecewise transform element-wise; ``context`` is ignored."""
        right = inputs > self.inv_cut_point
        left = inputs < -self.inv_cut_point
        middle = ~(right | left)

        outputs = torch.zeros_like(inputs)
        # atanh written out explicitly for the central region.
        outputs[middle] = 0.5 * torch.log(
            (1 + inputs[middle]) / (1 - inputs[middle])
        )
        outputs[right] = torch.exp(inputs[right] / self.alpha) / self.beta
        outputs[left] = -torch.exp(-inputs[left] / self.alpha) / self.beta

        log_jacobian = torch.zeros_like(inputs)
        log_jacobian[middle] = -torch.log(1 - inputs[middle] ** 2)
        log_jacobian[right] = (
            -np.log(self.alpha * self.beta) + inputs[right] / self.alpha
        )
        log_jacobian[left] = (
            -np.log(self.alpha * self.beta) - inputs[left] / self.alpha
        )

        return outputs, torchutils.sum_except_batch(log_jacobian, num_batch_dims=1)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class Logit (temperature=1, eps=1e-06)

Creates a transform that is the inverse of a given transform.

Constructor.

Args

transform
An object of type Transform.
Expand source code
class Logit(InverseTransform):
    """Inverse of :class:`Sigmoid`, built by wrapping it in ``InverseTransform``.

    ``temperature`` and ``eps`` are forwarded unchanged to ``Sigmoid``.
    """

    def __init__(self, temperature=1, eps=1e-6):
        sigmoid = Sigmoid(temperature=temperature, eps=eps)
        super().__init__(sigmoid)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Inherited members

class PiecewiseCubicCDF (shape, num_bins=10, tails=None, tail_bound=1.0, min_bin_width=0.001, min_bin_height=0.001)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class PiecewiseCubicCDF(Transform):
    """Element-wise cubic-spline transform with learnable spline parameters.

    Delegates to ``splines.cubic_spline`` when ``tails`` is ``None`` and to
    ``splines.unconstrained_cubic_spline`` otherwise.

    Args:
        shape: shape of the transformed features; an iterable of ints, or a
            single int (treated as ``(shape,)``).
        num_bins: number of spline bins.
        tails: tail behaviour forwarded to the unconstrained spline, or
            ``None`` to use the constrained spline.
        tail_bound: forwarded to the unconstrained spline when ``tails`` is set.
        min_bin_width: forwarded to the spline function.
        min_bin_height: forwarded to the spline function.
    """

    def __init__(
            self,
            shape,
            num_bins=10,
            tails=None,
            tail_bound=1.0,
            min_bin_width=splines.cubic.DEFAULT_MIN_BIN_WIDTH,
            min_bin_height=splines.cubic.DEFAULT_MIN_BIN_HEIGHT,
    ):
        super().__init__()

        # Accept a bare int like PiecewiseRationalQuadraticCDF does.
        if isinstance(shape, int):
            shape = (shape,)

        self.min_bin_width = min_bin_width
        self.min_bin_height = min_bin_height
        self.tail_bound = tail_bound
        self.tails = tails

        self.unnormalized_widths = nn.Parameter(torch.randn(*shape, num_bins))
        self.unnormalized_heights = nn.Parameter(torch.randn(*shape, num_bins))
        self.unnorm_derivatives_left = nn.Parameter(torch.randn(*shape, 1))
        self.unnorm_derivatives_right = nn.Parameter(torch.randn(*shape, 1))

    def _spline(self, inputs, inverse=False):
        """Evaluate the spline (or its inverse) and reduce the log-det.

        The leading dimension of ``inputs`` is taken as the batch size;
        parameters are passed through ``_share_across_batch`` (presumably a
        broadcast over the batch dimension) before the spline call.
        """
        batch_size = inputs.shape[0]

        unnormalized_widths = _share_across_batch(self.unnormalized_widths, batch_size)
        unnormalized_heights = _share_across_batch(
            self.unnormalized_heights, batch_size
        )
        unnorm_derivatives_left = _share_across_batch(
            self.unnorm_derivatives_left, batch_size
        )
        unnorm_derivatives_right = _share_across_batch(
            self.unnorm_derivatives_right, batch_size
        )

        if self.tails is None:
            spline_fn = splines.cubic_spline
            spline_kwargs = {}
        else:
            spline_fn = splines.unconstrained_cubic_spline
            spline_kwargs = {"tails": self.tails, "tail_bound": self.tail_bound}

        outputs, logabsdet = spline_fn(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnorm_derivatives_left=unnorm_derivatives_left,
            unnorm_derivatives_right=unnorm_derivatives_right,
            inverse=inverse,
            min_bin_width=self.min_bin_width,
            min_bin_height=self.min_bin_height,
            **spline_kwargs
        )

        return outputs, torchutils.sum_except_batch(logabsdet)

    def forward(self, inputs, context=None):
        """Apply the spline; ``context`` is ignored."""
        return self._spline(inputs, inverse=False)

    def inverse(self, inputs, context=None):
        """Apply the inverse spline; ``context`` is ignored."""
        return self._spline(inputs, inverse=True)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class PiecewiseLinearCDF (shape, num_bins=10, tails=None, tail_bound=1.0)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class PiecewiseLinearCDF(Transform):
    """Element-wise linear-spline transform with a learnable unnormalized pdf.

    Delegates to ``splines.linear_spline`` when ``tails`` is ``None`` and to
    ``splines.unconstrained_linear_spline`` otherwise.

    Args:
        shape: shape of the transformed features; an iterable of ints, or a
            single int (treated as ``(shape,)``).
        num_bins: number of spline bins.
        tails: tail behaviour forwarded to the unconstrained spline, or
            ``None`` to use the constrained spline.
        tail_bound: forwarded to the unconstrained spline when ``tails`` is set.
    """

    def __init__(self, shape, num_bins=10, tails=None, tail_bound=1.0):
        super().__init__()

        # Accept a bare int like PiecewiseRationalQuadraticCDF does.
        if isinstance(shape, int):
            shape = (shape,)

        self.tail_bound = tail_bound
        self.tails = tails

        self.unnormalized_pdf = nn.Parameter(torch.randn(*shape, num_bins))

    def _spline(self, inputs, inverse=False):
        """Evaluate the spline (or its inverse) and reduce the log-det.

        The leading dimension of ``inputs`` is taken as the batch size.
        """
        batch_size = inputs.shape[0]

        unnormalized_pdf = _share_across_batch(self.unnormalized_pdf, batch_size)

        if self.tails is None:
            outputs, logabsdet = splines.linear_spline(
                inputs=inputs, unnormalized_pdf=unnormalized_pdf, inverse=inverse
            )
        else:
            outputs, logabsdet = splines.unconstrained_linear_spline(
                inputs=inputs,
                unnormalized_pdf=unnormalized_pdf,
                inverse=inverse,
                tails=self.tails,
                tail_bound=self.tail_bound,
            )

        return outputs, torchutils.sum_except_batch(logabsdet)

    def forward(self, inputs, context=None):
        """Apply the spline; ``context`` is ignored."""
        return self._spline(inputs, inverse=False)

    def inverse(self, inputs, context=None):
        """Apply the inverse spline; ``context`` is ignored."""
        return self._spline(inputs, inverse=True)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class PiecewiseQuadraticCDF (shape, num_bins=10, tails=None, tail_bound=1.0, min_bin_width=0.001, min_bin_height=0.001)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class PiecewiseQuadraticCDF(Transform):
    """Element-wise quadratic-spline transform with learnable parameters.

    Delegates to ``splines.quadratic_spline`` when ``tails`` is ``None`` and to
    ``splines.unconstrained_quadratic_spline`` otherwise.

    Args:
        shape: shape of the transformed features; an iterable of ints, or a
            single int (treated as ``(shape,)``).
        num_bins: number of spline bins.
        tails: tail behaviour forwarded to the unconstrained spline, or
            ``None`` to use the constrained spline.
        tail_bound: forwarded to the unconstrained spline when ``tails`` is set.
        min_bin_width: forwarded to the spline function.
        min_bin_height: forwarded to the spline function.
    """

    def __init__(
            self,
            shape,
            num_bins=10,
            tails=None,
            tail_bound=1.0,
            min_bin_width=splines.quadratic.DEFAULT_MIN_BIN_WIDTH,
            min_bin_height=splines.quadratic.DEFAULT_MIN_BIN_HEIGHT,
    ):
        super().__init__()

        # Accept a bare int like PiecewiseRationalQuadraticCDF does.
        if isinstance(shape, int):
            shape = (shape,)

        self.min_bin_width = min_bin_width
        self.min_bin_height = min_bin_height
        self.tail_bound = tail_bound
        self.tails = tails

        self.unnormalized_widths = nn.Parameter(torch.randn(*shape, num_bins))
        # The constrained spline takes num_bins + 1 heights, the variant with
        # tails takes num_bins - 1.
        if tails is None:
            self.unnormalized_heights = nn.Parameter(torch.randn(*shape, num_bins + 1))
        else:
            self.unnormalized_heights = nn.Parameter(torch.randn(*shape, num_bins - 1))

    def _spline(self, inputs, inverse=False):
        """Evaluate the spline (or its inverse) and reduce the log-det.

        The leading dimension of ``inputs`` is taken as the batch size.
        """
        batch_size = inputs.shape[0]

        unnormalized_widths = _share_across_batch(self.unnormalized_widths, batch_size)
        unnormalized_heights = _share_across_batch(
            self.unnormalized_heights, batch_size
        )

        if self.tails is None:
            spline_fn = splines.quadratic_spline
            spline_kwargs = {}
        else:
            spline_fn = splines.unconstrained_quadratic_spline
            spline_kwargs = {"tails": self.tails, "tail_bound": self.tail_bound}

        outputs, logabsdet = spline_fn(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            inverse=inverse,
            min_bin_width=self.min_bin_width,
            min_bin_height=self.min_bin_height,
            **spline_kwargs
        )

        return outputs, torchutils.sum_except_batch(logabsdet)

    def forward(self, inputs, context=None):
        """Apply the spline; ``context`` is ignored."""
        return self._spline(inputs, inverse=False)

    def inverse(self, inputs, context=None):
        """Apply the inverse spline; ``context`` is ignored."""
        return self._spline(inputs, inverse=True)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class PiecewiseRationalQuadraticCDF (shape, num_bins=10, tails=None, tail_bound=1.0, identity_init=False, min_bin_width=0.001, min_bin_height=0.001, min_derivative=0.001)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class PiecewiseRationalQuadraticCDF(Transform):
    """Element-wise rational-quadratic spline transform with learnable parameters.

    Delegates to ``splines.rational_quadratic_spline`` when ``tails`` is
    ``None`` and to ``splines.unconstrained_rational_quadratic_spline``
    otherwise.

    Args:
        shape: shape of the transformed features; an int is treated as
            ``(shape,)``.
        num_bins: number of spline bins.
        tails: tail behaviour forwarded to the unconstrained spline, or
            ``None`` to use the constrained spline.
        tail_bound: forwarded to the unconstrained spline when ``tails`` is set.
        identity_init: if true, widths and heights start at zero and all
            derivatives at a constant derived from ``min_derivative``;
            otherwise every parameter is drawn with ``torch.rand``.
        min_bin_width: forwarded to the spline function.
        min_bin_height: forwarded to the spline function.
        min_derivative: forwarded to the spline function.
    """

    def __init__(
            self,
            shape,
            num_bins=10,
            tails=None,
            tail_bound=1.0,
            identity_init=False,
            min_bin_width=splines.rational_quadratic.DEFAULT_MIN_BIN_WIDTH,
            min_bin_height=splines.rational_quadratic.DEFAULT_MIN_BIN_HEIGHT,
            min_derivative=splines.rational_quadratic.DEFAULT_MIN_DERIVATIVE,
    ):
        super().__init__()

        self.min_bin_width = min_bin_width
        self.min_bin_height = min_bin_height
        self.min_derivative = min_derivative

        self.tail_bound = tail_bound
        self.tails = tails

        if isinstance(shape, int):
            shape = (shape,)

        # Linear tails use num_bins - 1 derivatives; every other setting uses
        # num_bins + 1.
        if self.tails == "linear":
            num_derivatives = num_bins - 1
        else:
            num_derivatives = num_bins + 1

        if identity_init:
            initial_widths = torch.zeros(*shape, num_bins)
            initial_heights = torch.zeros(*shape, num_bins)
            constant = np.log(np.exp(1 - min_derivative) - 1)
            initial_derivatives = constant * torch.ones(*shape, num_derivatives)
        else:
            initial_widths = torch.rand(*shape, num_bins)
            initial_heights = torch.rand(*shape, num_bins)
            initial_derivatives = torch.rand(*shape, num_derivatives)

        self.unnormalized_widths = nn.Parameter(initial_widths)
        self.unnormalized_heights = nn.Parameter(initial_heights)
        self.unnormalized_derivatives = nn.Parameter(initial_derivatives)

    def _spline(self, inputs, inverse=False):
        """Evaluate the spline (or its inverse) and reduce the log-det.

        The leading dimension of ``inputs`` is taken as the batch size.
        """
        batch_size = inputs.shape[0]

        call_kwargs = {
            "inputs": inputs,
            "unnormalized_widths": _share_across_batch(
                self.unnormalized_widths, batch_size
            ),
            "unnormalized_heights": _share_across_batch(
                self.unnormalized_heights, batch_size
            ),
            "unnormalized_derivatives": _share_across_batch(
                self.unnormalized_derivatives, batch_size
            ),
            "inverse": inverse,
            "min_bin_width": self.min_bin_width,
            "min_bin_height": self.min_bin_height,
            "min_derivative": self.min_derivative,
        }

        if self.tails is None:
            spline_fn = splines.rational_quadratic_spline
        else:
            spline_fn = splines.unconstrained_rational_quadratic_spline
            call_kwargs["tails"] = self.tails
            call_kwargs["tail_bound"] = self.tail_bound

        outputs, logabsdet = spline_fn(**call_kwargs)

        return outputs, torchutils.sum_except_batch(logabsdet)

    def forward(self, inputs, context=None):
        """Apply the spline; ``context`` is ignored."""
        return self._spline(inputs, inverse=False)

    def inverse(self, inputs, context=None):
        """Apply the inverse spline; ``context`` is ignored."""
        return self._spline(inputs, inverse=True)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class Sigmoid (temperature=1, eps=1e-06, learn_temperature=False)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class Sigmoid(Transform):
    """Element-wise ``sigmoid(temperature * x)`` transform.

    Args:
        temperature: multiplier applied to the inputs before the sigmoid.
        eps: inputs to ``inverse`` are clamped to ``[eps, 1 - eps]``.
        learn_temperature: if true the temperature is an ``nn.Parameter``;
            otherwise it is registered as a buffer.
    """

    def __init__(self, temperature=1, eps=1e-6, learn_temperature=False):
        super().__init__()
        self.eps = eps
        temperature_tensor = torch.Tensor([temperature])
        if learn_temperature:
            self.temperature = nn.Parameter(temperature_tensor)
        else:
            self.register_buffer('temperature', temperature_tensor)

    def forward(self, inputs, context=None):
        """Return the tempered sigmoid and its log-det; ``context`` is ignored."""
        scaled = self.temperature * inputs
        # log(T * s(z) * (1 - s(z))) written with softplus for stability.
        log_jacobian = (
            torch.log(self.temperature) - F.softplus(-scaled) - F.softplus(scaled)
        )
        return torch.sigmoid(scaled), torchutils.sum_except_batch(log_jacobian)

    def inverse(self, inputs, context=None):
        """Return the tempered logit of ``inputs`` and its log-det.

        Raises:
            InputOutsideDomain: if any element of ``inputs`` lies outside [0, 1].
        """
        if torch.min(inputs) < 0 or torch.max(inputs) > 1:
            raise InputOutsideDomain()

        # Keep the logs finite at exactly 0 and 1.
        clamped = torch.clamp(inputs, self.eps, 1 - self.eps)

        outputs = (1 / self.temperature) * (torch.log(clamped) - torch.log1p(-clamped))
        forward_log_jacobian = (
            torch.log(self.temperature)
            - F.softplus(-self.temperature * outputs)
            - F.softplus(self.temperature * outputs)
        )
        return outputs, -torchutils.sum_except_batch(forward_log_jacobian)

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class Softplus (threshold=20, eps=0.0)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class Softplus(Transform):
    """Element-wise softplus transform with an optional output offset.

    Args:
        threshold: passed to ``torch.nn.Softplus``; also the point above which
            ``inverse`` returns its (offset-corrected) input unchanged.
        eps: constant added to the forward output and subtracted again in
            ``inverse``.
    """

    def __init__(self, threshold=20, eps=0.):
        super().__init__()

        self.eps = eps
        self.softplus = torch.nn.Softplus(beta=1, threshold=threshold)
        self.log_sigmoid = torch.nn.LogSigmoid()

    def forward(self, inputs, context=None):
        """Return ``softplus(inputs) + eps`` and the log-det.

        The log-det is ``logsigmoid(inputs)`` summed over the last dimension
        only. ``context`` is ignored.
        """
        log_jacobian = self.log_sigmoid(inputs).sum(-1)
        return self.softplus(inputs) + self.eps, log_jacobian

    def inverse(self, inputs, context=None):
        """Invert the softplus; ``context`` is ignored."""
        shifted = inputs - self.eps
        # Above the threshold softplus is the identity, so pass values through.
        in_linear_regime = shifted > self.softplus.threshold
        outputs = torch.where(in_linear_regime, shifted, shifted.expm1().log())
        log_jacobian = -torch.log(-torch.expm1(-shifted)).sum(-1)
        return outputs, log_jacobian

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members

class Tanh (*args, **kwargs)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class Tanh(Transform):
    """Element-wise tanh transform; its inverse is atanh on (-1, 1)."""

    def forward(self, inputs, context=None):
        """Return ``tanh(inputs)`` and its log-det; ``context`` is ignored."""
        outputs = torch.tanh(inputs)
        # log(1 - tanh(x)^2) == 2 * (log 2 - x - softplus(-2x)). The naive form
        # returns -inf as soon as tanh saturates to +/-1 in floating point;
        # this form stays finite.
        logabsdet = 2.0 * (
            float(np.log(2.0))
            - inputs
            - torch.nn.functional.softplus(-2.0 * inputs)
        )
        logabsdet = torchutils.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        """Return ``atanh(inputs)`` and its log-det; ``context`` is ignored.

        Raises:
            InputOutsideDomain: if any element of ``inputs`` is not strictly
                inside (-1, 1).
        """
        if torch.min(inputs) <= -1 or torch.max(inputs) >= 1:
            raise InputOutsideDomain()
        # log1p keeps precision for inputs close to zero.
        log_one_plus = torch.log1p(inputs)
        log_one_minus = torch.log1p(-inputs)
        outputs = 0.5 * (log_one_plus - log_one_minus)
        # -log(1 - y^2) == -(log(1 + y) + log(1 - y))
        logabsdet = -(log_one_plus + log_one_minus)
        logabsdet = torchutils.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet

Ancestors

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Inherited members