Module flowcon.transforms.base

Basic definitions for the transforms module.

Classes

class CompositeTransform (transforms)

Composes several transforms into one, in the order they are given.

Constructor.

Args

transforms: an iterable of Transform objects.
Source code
class CompositeTransform(Transform):
    """Composes several transforms into one, in the order they are given."""

    def __init__(self, transforms):
        """Constructor.

        Args:
            transforms: an iterable of `Transform` objects.
        """
        super().__init__()
        self._transforms = nn.ModuleList(transforms)

    @staticmethod
    def _cascade(inputs, funcs, context):
        # Apply each function in turn, accumulating the log absolute
        # determinant of the Jacobian over the whole chain.
        batch_size = inputs.shape[0]
        outputs = inputs
        total_logabsdet = inputs.new_zeros(batch_size)
        for func in funcs:
            outputs, logabsdet = func(outputs, context)
            total_logabsdet += logabsdet
        return outputs, total_logabsdet

    def forward(self, inputs, context=None):
        funcs = self._transforms
        return self._cascade(inputs, funcs, context)

    def inverse(self, inputs, context=None):
        funcs = (transform.inverse for transform in self._transforms[::-1])
        return self._cascade(inputs, funcs, context)
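
A minimal usage sketch: ScaleByTwo is a toy transform written here for illustration and is not part of flowcon; only CompositeTransform and Transform are assumed from this module.

import math

import torch
from flowcon.transforms.base import CompositeTransform, Transform


class ScaleByTwo(Transform):
    """Toy transform: y = 2 * x, so log|det J| = D * log(2) per example."""

    def forward(self, inputs, context=None):
        logabsdet = inputs.new_full((inputs.shape[0],), inputs.shape[1] * math.log(2.0))
        return 2.0 * inputs, logabsdet

    def inverse(self, inputs, context=None):
        logabsdet = inputs.new_full((inputs.shape[0],), -inputs.shape[1] * math.log(2.0))
        return inputs / 2.0, logabsdet


composite = CompositeTransform([ScaleByTwo(), ScaleByTwo()])
x = torch.randn(8, 3)
y, logabsdet = composite(x)                      # y = 4 * x, logabsdet = 6 * log(2)
x_back, inv_logabsdet = composite.inverse(y)
assert torch.allclose(x, x_back)
assert torch.allclose(logabsdet, -inv_logabsdet)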

Ancestors

  • Transform
  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Applies the inverses of the component transforms in reverse order.

class InputOutsideDomain (*args, **kwargs)

Exception to be thrown when the input to a transform is not within its domain.

Source code
class InputOutsideDomain(Exception):
    """Exception to be thrown when the input to a transform is not within its domain."""

    pass

Ancestors

  • builtins.Exception
  • builtins.BaseException
class InverseNotAvailable (*args, **kwargs)

Exception to be thrown when a transform does not have an inverse.

Source code
class InverseNotAvailable(Exception):
    """Exception to be thrown when a transform does not have an inverse."""

    pass

Ancestors

  • builtins.Exception
  • builtins.BaseException
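
A short sketch of where this exception surfaces: Transform.inverse (see below) raises InverseNotAvailable by default, so inverting a transform that never overrides it fails. Shift is a toy class written here for illustration only.

import torch
from flowcon.transforms.base import InverseNotAvailable, Transform


class Shift(Transform):
    def forward(self, inputs, context=None):
        # Adding a constant is volume preserving, so log|det J| = 0.
        return inputs + 1.0, inputs.new_zeros(inputs.shape[0])


try:
    Shift().inverse(torch.randn(4, 2))
except InverseNotAvailable:
    print("Shift does not define an inverse.")
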
class InverseTransform (transform)

Creates a transform that is the inverse of a given transform.

Constructor.

Args

transform: An object of type Transform.
Source code
class InverseTransform(Transform):
    """Creates a transform that is the inverse of a given transform."""

    def __init__(self, transform):
        """Constructor.

        Args:
            transform: An object of type `Transform`.
        """
        super().__init__()
        self._transform = transform

    def forward(self, inputs, context=None):
        return self._transform.inverse(inputs, context)

    def inverse(self, inputs, context=None):
        return self._transform(inputs, context)
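
A minimal sketch of swapping a transform's two directions. ScaleByTwo is the same toy transform as in the CompositeTransform example above, redefined here so the snippet is self-contained; it is not part of flowcon.

import math

import torch
from flowcon.transforms.base import InverseTransform, Transform


class ScaleByTwo(Transform):
    def forward(self, inputs, context=None):
        return 2.0 * inputs, inputs.new_full((inputs.shape[0],), inputs.shape[1] * math.log(2.0))

    def inverse(self, inputs, context=None):
        return inputs / 2.0, inputs.new_full((inputs.shape[0],), -inputs.shape[1] * math.log(2.0))


halve = InverseTransform(ScaleByTwo())  # forward now divides by two
x = torch.randn(8, 3)
y, logabsdet = halve(x)                 # delegates to ScaleByTwo.inverse
assert torch.allclose(y, x / 2.0)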

Ancestors

  • Transform
  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def inverse(self, inputs, context=None)

Applies the wrapped transform's forward pass.

class MultiscaleCompositeTransform (num_transforms, split_dim=1)

A multiscale composite transform as described in the RealNVP paper.

Splits the outputs along the given dimension after every transform, outputs one half, and passes the other half to further transforms. No splitting is done before the last transform. For size d along the split dimension, the emitted half has ceil(d/2) dimensions and the half passed on has floor(d/2).

Note: Inputs can be of arbitrary shape, but outputs are always flattened.

Reference:

L. Dinh et al., Density estimation using Real NVP, ICLR 2017.

Constructor.

Args

num_transforms: int, total number of transforms to be added.
split_dim: dimension along which to split.
Source code
class MultiscaleCompositeTransform(Transform):
    """A multiscale composite transform as described in the RealNVP paper.

    Splits the outputs along the given dimension after every transform, outputs one half, and
    passes the other half to further transforms. No splitting is done before the last transform.

    Note: Inputs could be of arbitrary shape, but outputs will always be flattened.

    Reference:
    > L. Dinh et al., Density estimation using Real NVP, ICLR 2017.
    """

    def __init__(self, num_transforms, split_dim=1):
        """Constructor.

        Args:
            num_transforms: int, total number of transforms to be added.
            split_dim: dimension along which to split.
        """
        if not check.is_positive_int(split_dim):
            raise TypeError("Split dimension must be a positive integer.")

        super().__init__()
        self._transforms = nn.ModuleList()
        self._output_shapes = []
        self._num_transforms = num_transforms
        self._split_dim = split_dim

    def add_transform(self, transform, transform_output_shape):
        """Add a transform. Must be called exactly `num_transforms` times.

        Parameters:
            transform: the `Transform` object to be added.
            transform_output_shape: tuple, shape of transform's outputs, excl. the first batch
                dimension.

        Returns:
            Input shape for the next transform, or None if adding the last transform.
        """
        assert len(self._transforms) <= self._num_transforms

        if len(self._transforms) == self._num_transforms:
            raise RuntimeError(
                "Adding more than {} transforms is not allowed.".format(
                    self._num_transforms
                )
            )

        if (self._split_dim - 1) >= len(transform_output_shape):
            raise ValueError("No split_dim in output shape")

        if transform_output_shape[self._split_dim - 1] < 2:
            raise ValueError(
                "Size of dimension {} must be at least 2.".format(self._split_dim)
            )

        self._transforms.append(transform)

        if len(self._transforms) != self._num_transforms:  # Unless last transform.
            output_shape = list(transform_output_shape)
            output_shape[self._split_dim - 1] = (
                output_shape[self._split_dim - 1] + 1
            ) // 2
            output_shape = tuple(output_shape)

            hidden_shape = list(transform_output_shape)
            hidden_shape[self._split_dim - 1] = hidden_shape[self._split_dim - 1] // 2
            hidden_shape = tuple(hidden_shape)
        else:
            # No splitting for last transform.
            output_shape = transform_output_shape
            hidden_shape = None

        self._output_shapes.append(output_shape)
        return hidden_shape

    def forward(self, inputs, context=None):
        if self._split_dim >= inputs.dim():
            raise ValueError("No split_dim in inputs.")
        if self._num_transforms != len(self._transforms):
            raise RuntimeError(
                "Expecting exactly {} transform(s) "
                "to be added.".format(self._num_transforms)
            )

        batch_size = inputs.shape[0]

        def cascade():
            hiddens = inputs

            for i, transform in enumerate(self._transforms[:-1]):
                transform_outputs, logabsdet = transform(hiddens, context)
                outputs, hiddens = torch.chunk(
                    transform_outputs, chunks=2, dim=self._split_dim
                )
                assert outputs.shape[1:] == self._output_shapes[i]
                yield outputs, logabsdet

            # Don't do the splitting for the last transform.
            outputs, logabsdet = self._transforms[-1](hiddens, context)
            yield outputs, logabsdet

        all_outputs = []
        total_logabsdet = inputs.new_zeros(batch_size)

        for outputs, logabsdet in cascade():
            all_outputs.append(outputs.reshape(batch_size, -1))
            total_logabsdet += logabsdet

        all_outputs = torch.cat(all_outputs, dim=-1)
        return all_outputs, total_logabsdet

    def inverse(self, inputs, context=None):
        if inputs.dim() != 2:
            raise ValueError("Expecting NxD inputs")
        if self._num_transforms != len(self._transforms):
            raise RuntimeError(
                "Expecting exactly {} transform(s) "
                "to be added.".format(self._num_transforms)
            )

        batch_size = inputs.shape[0]

        rev_inv_transforms = [transform.inverse for transform in self._transforms[::-1]]

        # The flattened size of each output chunk gives the boundaries for
        # splitting the NxD inputs back into per-transform pieces.
        split_indices = np.cumsum([np.prod(shape) for shape in self._output_shapes])
        split_indices = np.insert(split_indices, 0, 0)

        split_inputs = []
        for i in range(len(self._output_shapes)):
            flat_input = inputs[:, split_indices[i] : split_indices[i + 1]]
            split_inputs.append(flat_input.view(-1, *self._output_shapes[i]))
        rev_split_inputs = split_inputs[::-1]

        total_logabsdet = inputs.new_zeros(batch_size)

        # We don't do the splitting for the last (here first) transform.
        hiddens, logabsdet = rev_inv_transforms[0](rev_split_inputs[0], context)
        total_logabsdet += logabsdet

        for inv_transform, input_chunk in zip(
            rev_inv_transforms[1:], rev_split_inputs[1:]
        ):
            tmp_concat_inputs = torch.cat([input_chunk, hiddens], dim=self._split_dim)
            hiddens, logabsdet = inv_transform(tmp_concat_inputs, context)
            total_logabsdet += logabsdet

        outputs = hiddens

        return outputs, total_logabsdet
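
A minimal sketch of a two-level multiscale chain on 4-dimensional inputs. Negate is a toy volume-preserving transform written for illustration; it is not part of flowcon.

import torch
from flowcon.transforms.base import MultiscaleCompositeTransform, Transform


class Negate(Transform):
    def forward(self, inputs, context=None):
        return -inputs, inputs.new_zeros(inputs.shape[0])

    def inverse(self, inputs, context=None):
        return -inputs, inputs.new_zeros(inputs.shape[0])


multiscale = MultiscaleCompositeTransform(num_transforms=2)
next_shape = multiscale.add_transform(Negate(), (4,))          # half is emitted, half continues
assert next_shape == (2,)
assert multiscale.add_transform(Negate(), next_shape) is None  # last transform: no split

x = torch.randn(8, 4)
y, logabsdet = multiscale(x)       # y has shape (8, 4): the flattened, concatenated chunks
x_back, _ = multiscale.inverse(y)
assert torch.allclose(x, x_back)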

Ancestors

  • Transform
  • torch.nn.modules.module.Module

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def add_transform(self, transform, transform_output_shape)

Add a transform. Must be called exactly num_transforms times.

Parameters

transform: the Transform object to be added.
transform_output_shape: tuple, shape of the transform's outputs, excluding the leading batch dimension.

Returns

Input shape for the next transform, or None if adding the last transform.

def inverse(self, inputs, context=None)

Inverts the multiscale transform; expects flattened NxD inputs.

class Transform (*args, **kwargs)

Base class for all transform objects.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Source code
class Transform(nn.Module):
    """Base class for all transform objects."""

    def forward(self, inputs, context=None):
        raise NotImplementedError()

    def inverse(self, inputs, context=None):
        raise InverseNotAvailable()
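
Concrete transforms subclass this class and implement forward (and, where an inverse exists, inverse), each returning a pair (outputs, logabsdet) with logabsdet of shape (batch_size,). A hypothetical sketch, not part of flowcon:

import torch
from flowcon.transforms.base import Transform


class AffineShift(Transform):
    """Hypothetical transform: y = x + shift, which is volume preserving."""

    def __init__(self, shift):
        super().__init__()
        self.register_buffer("shift", torch.as_tensor(shift))

    def forward(self, inputs, context=None):
        # log|det J| of a pure shift is zero for every example.
        return inputs + self.shift, inputs.new_zeros(inputs.shape[0])

    def inverse(self, inputs, context=None):
        return inputs - self.shift, inputs.new_zeros(inputs.shape[0])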

Ancestors

  • torch.nn.modules.module.Module

Subclasses

  • CompositeTransform
  • InverseTransform
  • MultiscaleCompositeTransform

Class variables

var call_super_init : bool
var dump_patches : bool
var training : bool

Methods

def forward(self, inputs, context=None) -> Callable[..., Any]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

def inverse(self, inputs, context=None)

Inverts the transform. Raises InverseNotAvailable unless overridden by a subclass.