Module flowcon.flows.realnvp
Implementations of Real NVP.
Classes
class SimpleRealNVP (features, hidden_features, num_layers, num_blocks_per_layer, use_volume_preserving=False, activation=<function relu>, dropout_probability=0.0, batch_norm_within_layers=False, batch_norm_between_layers=False)
A simplified version of Real NVP for 1-dim inputs.
This implementation uses 1-dim checkerboard masking but doesn't use multi-scaling.
Reference:
L. Dinh et al., Density estimation using Real NVP, ICLR 2017.
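To make the 1-dim checkerboard masking concrete, here is a minimal pure-Python sketch of the alternating mask that SimpleRealNVP.__init__ builds with torch (mask[::2] = -1, then mask *= -1 after each coupling layer). It assumes the nflows-style coupling convention that positive entries mark features the coupling transforms and non-positive entries mark features passed through unchanged. alternating_masks and split_indices are hypothetical helpers for illustration, not part of flowcon.

```python
def alternating_masks(features, num_layers):
    """Hypothetical helper: the +1/-1 mask used by each coupling layer,
    mirroring `mask[::2] = -1` followed by `mask *= -1` after every layer."""
    mask = [1] * features
    mask[::2] = [-1] * len(mask[::2])  # even indices -> -1, odd indices stay +1
    masks = []
    for _ in range(num_layers):
        masks.append(list(mask))       # store a copy for this layer
        mask = [-m for m in mask]      # flip the mask for the next layer
    return masks


def split_indices(mask):
    """Assumed convention: positive entries are transformed by the coupling;
    non-positive entries pass through unchanged and condition the transform."""
    transformed = [i for i, m in enumerate(mask) if m > 0]
    identity = [i for i, m in enumerate(mask) if m <= 0]
    return transformed, identity


for layer, mask in enumerate(alternating_masks(5, 2)):
    print(layer, mask, split_indices(mask))
# 0 [-1, 1, -1, 1, -1] ([1, 3], [0, 2, 4])
# 1 [1, -1, 1, -1, 1] ([0, 2, 4], [1, 3])
```

Because the mask flips after every layer, any num_layers >= 2 transforms every feature at least once; with a single layer, the even-indexed features would never be transformed.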
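Constructor.
The arguments listed below are those of the parent Flow constructor. SimpleRealNVP builds transform (a CompositeTransform of coupling layers) and distribution (a StandardNormal over features) itself from its own parameters, and does not expose embedding_net.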
Args
- transform: A Transform object; it transforms data into noise.
- distribution: A Distribution object; the base distribution of the flow that generates the noise.
- embedding_net: An nn.Module with trainable parameters that encodes the context (condition). It is trained jointly with the flow.
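A flow pairs a transform f that maps data x to noise z with a base distribution over z, so the density follows the change-of-variables formula log p(x) = log p_z(f(x)) + log|det J_f(x)|, where the log-determinants of stacked layers add up. The pure-Python sketch below illustrates this with affine coupling layers and a standard normal base. It is not flowcon's code: affine_coupling_forward, affine_coupling_inverse, standard_normal_log_prob, flow_log_prob and toy_conditioner are hypothetical names, the scale is assumed to be exp(log_scale) (the library's AffineCouplingTransform may parameterize it differently), and a toy function stands in for the ResidualNet conditioner.

```python
import math


def affine_coupling_forward(x, mask, conditioner):
    """Map x -> z. Features with mask <= 0 pass through unchanged and feed the
    conditioner; features with mask > 0 are scaled and shifted.
    Returns (z, log|det J|)."""
    identity = [x[i] for i, m in enumerate(mask) if m <= 0]
    z = list(x)
    logabsdet = 0.0
    for i, m in enumerate(mask):
        if m > 0:
            log_scale, shift = conditioner(identity, i)
            z[i] = x[i] * math.exp(log_scale) + shift
            # Triangular Jacobian: log|det| is the sum of the log-scales.
            logabsdet += log_scale
    return z, logabsdet


def affine_coupling_inverse(z, mask, conditioner):
    """Map z -> x. The pass-through features are identical in z and x, so the
    same scale and shift can be recomputed and undone."""
    identity = [z[i] for i, m in enumerate(mask) if m <= 0]
    x = list(z)
    for i, m in enumerate(mask):
        if m > 0:
            log_scale, shift = conditioner(identity, i)
            x[i] = (z[i] - shift) * math.exp(-log_scale)
    return x


def standard_normal_log_prob(z):
    """log N(z; 0, I), the base-distribution log-density."""
    return sum(-0.5 * v * v - 0.5 * math.log(2.0 * math.pi) for v in z)


def flow_log_prob(x, masks, conditioner):
    """Change of variables: log p(x) = log N(f(x); 0, I) + sum of layer log|det J|."""
    total_logabsdet = 0.0
    for mask in masks:
        x, logabsdet = affine_coupling_forward(x, mask, conditioner)
        total_logabsdet += logabsdet
    return standard_normal_log_prob(x) + total_logabsdet


def toy_conditioner(identity, i):
    """Hypothetical stand-in for the ResidualNet. Any function of the
    pass-through features keeps the coupling invertible; i (the index of the
    feature being transformed) is ignored in this toy version."""
    s = 0.1 * sum(identity)
    return math.tanh(s), 0.5 * s


x = [0.3, -1.2, 0.7, 2.0]
masks = [[-1, 1, -1, 1], [1, -1, 1, -1]]
print(flow_log_prob(x, masks, toy_conditioner))
```

Invertibility never requires inverting the conditioner itself, which is why an unconstrained network such as a ResidualNet can be used inside each coupling layer.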
class SimpleRealNVP(Flow):
    """A simplified version of Real NVP for 1-dim inputs.

    This implementation uses 1-dim checkerboard masking but doesn't use multi-scaling.

    Reference:
    > L. Dinh et al., Density estimation using Real NVP, ICLR 2017.
    """

    def __init__(
        self,
        features,
        hidden_features,
        num_layers,
        num_blocks_per_layer,
        use_volume_preserving=False,
        activation=F.relu,
        dropout_probability=0.0,
        batch_norm_within_layers=False,
        batch_norm_between_layers=False,
    ):

        # Additive couplings have a unit Jacobian determinant (volume-preserving).
        if use_volume_preserving:
            coupling_constructor = AdditiveCouplingTransform
        else:
            coupling_constructor = AffineCouplingTransform

        # -1 at even indices, +1 at odd indices; flipped after every layer below.
        mask = torch.ones(features)
        mask[::2] = -1

        # Conditioner network created by each coupling layer.
        def create_resnet(in_features, out_features):
            return nets.ResidualNet(
                in_features,
                out_features,
                hidden_features=hidden_features,
                num_blocks=num_blocks_per_layer,
                activation=activation,
                dropout_probability=dropout_probability,
                use_batch_norm=batch_norm_within_layers,
            )

        layers = []
        for _ in range(num_layers):
            transform = coupling_constructor(
                mask=mask, transform_net_create_fn=create_resnet
            )
            layers.append(transform)
            mask *= -1
            if batch_norm_between_layers:
                layers.append(BatchNorm(features=features))

        super().__init__(
            transform=CompositeTransform(layers),
            distribution=StandardNormal([features]),
        )
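The layer-stacking loop in the source above can be summarized with the sketch below, which only records layer names as strings instead of constructing transforms. describe_layers is a hypothetical helper, not part of flowcon; it mirrors how use_volume_preserving selects the coupling class and where batch_norm_between_layers inserts BatchNorm.

```python
def describe_layers(num_layers, use_volume_preserving=False,
                    batch_norm_between_layers=False):
    """Hypothetical helper: return, as strings, the layer sequence that the
    loop in SimpleRealNVP.__init__ appends, without building any transforms."""
    coupling = ("AdditiveCouplingTransform" if use_volume_preserving
                else "AffineCouplingTransform")
    layers = []
    for _ in range(num_layers):
        layers.append(coupling)
        # BatchNorm is appended inside the loop, so it follows every coupling,
        # including the last one.
        if batch_norm_between_layers:
            layers.append("BatchNorm")
    return layers


print(describe_layers(2, batch_norm_between_layers=True))
# ['AffineCouplingTransform', 'BatchNorm', 'AffineCouplingTransform', 'BatchNorm']
```

One consequence of this ordering: with use_volume_preserving=True the couplings contribute zero log-determinant, but if batch_norm_between_layers is also set, the BatchNorm layers generally have a nonzero log-determinant, so the flow as a whole is no longer volume-preserving.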
Ancestors
- Flow
- Distribution
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Inherited members