Module flowcon.transforms.no_analytic_inv.base

Classes

class MonotonicTransform (num_iterations=20, num_newton_iterations=1, lim=10, ratio_multiplier=1.5)

Elementwise inverse for monotonic (elementwise) transformations, computed with Newton's root-finding method. The initial guess for Newton's method is obtained with the bisection method.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class MonotonicTransform(Transform, ABC):
    """
    Elementwise inverse for monotonic (elementwise) transformations.

    The inverse is computed numerically: a bracketing bisection search provides
    an initial guess, which is then refined with Newton's root-finding method.

    The forward function used by the searches must return a pair
    ``(outputs, logabsdet)``. The bracketing and bisection logic treats the
    transformation as elementwise *increasing* (a larger input is assumed to
    give a larger output).
    """

    def __init__(self, num_iterations=20, num_newton_iterations=1, lim=10, ratio_multiplier=1.5):
        """
        Args:
            num_iterations: maximum number of bisection steps.
            num_newton_iterations: stored on the instance; see the note in
                ``newton_inverse`` (it is currently not consulted there).
            lim: half-width of the initial search bracket ``[-lim, lim]``.
            ratio_multiplier: extra growth factor applied whenever the bracket
                has to be enlarged.
        """
        self.num_iterations = num_iterations
        self.num_newton_iterations = num_newton_iterations
        self.lim = lim
        # Absolute tolerance on |forward(x) - z| for stopping bisection early.
        self.atol = 1e-7
        self.ratio_multiplier = ratio_multiplier
        super().__init__()

    def newton_inverse(self, z, context=None, forward_function=None):
        """
        Invert the transformation with bisection followed by Newton steps.

        Args:
            z: target outputs to invert.
            context: optional conditioning input passed to the forward function.
            forward_function: callable ``(inputs, context) -> (outputs, logabsdet)``;
                defaults to ``self.forward``.

        Returns:
            Tuple ``(x, neg_logabsdet)`` where ``neg_logabsdet`` is the negated
            forward log-abs-determinant evaluated at ``x``, flattened to 1-D.
        """
        if forward_function is None:
            forward_function = self.forward

        # Gradients are required for df/dx even when the caller disabled them.
        with torch.enable_grad():
            x_guess, _ = self.bisection_inverse(z, context=context,
                                                forward_function=forward_function)
            x_guess = x_guess.requires_grad_(True)
            # NOTE(review): the step count is hard-coded to 2 and
            # ``self.num_newton_iterations`` (default 1) is ignored. Honouring
            # it would change results for existing callers, so it is left
            # as-is here -- confirm the intended behaviour.
            for _ in range(2):
                residual = forward_function(x_guess, context=context)[0] - z
                # ``torchutils.gradient`` is a project helper; presumably it
                # returns d(residual)/d(x_guess) -- TODO confirm.
                df_dx = torchutils.gradient(residual, x_guess).view(residual.shape)
                # The small constant guards against division by zero.
                x_guess = x_guess - residual / (df_dx + 1e-7)
        logabsdet = self.forward_logabsdet(x_guess, context=context,
                                           forward_function=forward_function)
        return x_guess, -logabsdet.reshape(-1)

    def bisection_inverse(self, z, context=None, forward_function=None):
        """
        Invert the transformation elementwise by bisection.

        The bracket starts at ``[-lim, lim]`` and is enlarged until the forward
        values at its ends enclose every entry of ``z``. It is then halved up
        to ``num_iterations`` times, stopping early once the largest residual
        ``|forward(x) - z|`` is at most ``atol``.

        Args:
            z: target outputs to invert.
            context: optional conditioning input passed to the forward function.
            forward_function: callable ``(inputs, context) -> (outputs, logabsdet)``;
                defaults to ``self.forward``.

        Returns:
            Tuple ``(x, neg_logabsdet)`` where ``neg_logabsdet`` is the negated,
            squeezed forward log-abs-determinant evaluated at ``x``.
        """
        if forward_function is None:
            forward_function = self.forward

        x_max = torch.ones_like(z) * self.lim
        x_min = -torch.ones_like(z) * self.lim

        z_max, _ = forward_function(x_max, context)
        z_min, _ = forward_function(x_min, context)

        idx_maxdiff, idx_mindiff, maxdiff, mindiff = self.calc_diffs(z, z_max, z_min)

        # Grow the upper bound while some target still exceeds forward(x_max).
        # NOTE(review): the factor equals z / z_max at the worst element; if
        # z_max there is zero or differs in sign from z the bound may not grow
        # -- confirm this cannot happen for the transforms in use.
        while maxdiff > 0:
            worst_z_max = z_max.flatten()[idx_maxdiff]
            ratio = (maxdiff + worst_z_max) / worst_z_max
            x_max = x_max * self.ratio_multiplier * ratio
            z_max, _ = forward_function(x_max, context)
            idx_maxdiff, idx_mindiff, maxdiff, mindiff = self.calc_diffs(z, z_max, z_min)
        x_max = x_max + 1

        # Same for the lower bound while some target is below forward(x_min).
        while mindiff < 0:
            worst_z_min = z_min.flatten()[idx_mindiff]
            ratio = (mindiff + worst_z_min) / worst_z_min
            x_min = x_min * self.ratio_multiplier * ratio
            z_min, _ = forward_function(x_min, context)
            idx_maxdiff, idx_mindiff, maxdiff, mindiff = self.calc_diffs(z, z_max, z_min)
        x_min = x_min - 1

        x = (x_max + x_min) / 2
        for _ in range(self.num_iterations):
            z_middle, _ = forward_function(x, context)
            # Convergence is judged on the residual in output space. Comparing
            # the candidate input itself with ``z`` (as was done previously)
            # skipped the search entirely whenever ``z`` matched the initial
            # midpoint, e.g. for ``z == 0``.
            if (z_middle - z).abs().max() <= self.atol:
                break
            # Keep the old bound only where the midpoint lies strictly on the
            # other side of the target; otherwise move the bound to the midpoint.
            x_max = torch.where(z_middle < z, x_max, x)
            x_min = torch.where(z_middle > z, x_min, x)
            x = (x_max + x_min) / 2

        return x, -self.forward_logabsdet(x, context=context, forward_function=forward_function).squeeze()

    def calc_diffs(self, z, z_max, z_min):
        """
        Locate the targets that most violate the current bracket.

        Returns:
            Tuple ``(idx_maxdiff, idx_mindiff, maxdiff, mindiff)``:
            ``maxdiff`` is the largest entry of ``z - z_max`` (positive when a
            target lies above the upper bound's output) and ``mindiff`` the
            smallest entry of ``z - z_min`` (negative when a target lies below
            the lower bound's output); the indices are their flat positions.
        """
        diff = z - z_max
        idx_maxdiff = torch.argmax(diff)
        maxdiff = diff.flatten()[idx_maxdiff]
        diff = z - z_min
        idx_mindiff = torch.argmin(diff)
        mindiff = diff.flatten()[idx_mindiff]
        return idx_maxdiff, idx_mindiff, maxdiff, mindiff

    def forward_logabsdet(self, inputs, context=None, forward_function=None):
        """Return only the log-abs-determinant produced by the forward function."""
        if forward_function is None:
            forward_function = self.forward
        _, logabsdet = forward_function(inputs=inputs, context=context)
        return logabsdet

    def inverse(self, inputs, context=None, forward_function=None):
        """Invert the transformation; delegates to ``newton_inverse``."""
        if forward_function is None:
            forward_function = self.forward
        return self.newton_inverse(inputs, context=context, forward_function=forward_function)

Ancestors

  • Transform
  • torch.nn.modules.module.Module
  • abc.ABC

Subclasses

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def bisection_inverse(self, z, context=None, forward_function=None)
def calc_diffs(self, z, z_max, z_min)
def forward_logabsdet(self, inputs, context=None, forward_function=None)
def inverse(self, inputs, context=None, forward_function=None)
def newton_inverse(self, z, context=None, forward_function=None)

Inherited members