Module flowcon.distributions.discrete
Implementations of discrete distributions.
Classes
class ConditionalIndependentBernoulli (shape, context_encoder=None)
-
An independent Bernoulli whose parameters are functions of a context.
Constructor.
Args
shape
- list, tuple or torch.Size, the shape of the input variables.
context_encoder
- callable or None, encodes the context to the distribution parameters. If None, defaults to the identity function.
Expand source code
class ConditionalIndependentBernoulli(Distribution): """An independent Bernoulli whose parameters are functions of a context.""" def __init__(self, shape, context_encoder=None): """Constructor. Args: shape: list, tuple or torch.Size, the shape of the input variables. context_encoder: callable or None, encodes the context to the distribution parameters. If None, defaults to the identity function. """ super().__init__() self._shape = torch.Size(shape) if context_encoder is None: self._context_encoder = lambda x: x else: self._context_encoder = context_encoder def _compute_params(self, context): """Compute the logits from context.""" if context is None: raise ValueError("Context can't be None.") logits = self._context_encoder(context) if logits.shape[0] != context.shape[0]: raise RuntimeError( "The batch dimension of the parameters is inconsistent with the input." ) return logits.reshape(logits.shape[0], *self._shape) def _log_prob(self, inputs, context): if inputs.shape[1:] != self._shape: raise ValueError( "Expected input of shape {}, got {}".format( self._shape, inputs.shape[1:] ) ) # Compute parameters. logits = self._compute_params(context) assert logits.shape == inputs.shape # Compute log prob. log_prob = -inputs * F.softplus(-logits) - (1.0 - inputs) * F.softplus(logits) log_prob = torchutils.sum_except_batch(log_prob, num_batch_dims=1) return log_prob def _sample(self, num_samples, context): # Compute parameters. logits = self._compute_params(context) probs = torch.sigmoid(logits) probs = torchutils.repeat_rows(probs, num_samples) # Generate samples. context_size = context.shape[0] noise = torch.rand(context_size * num_samples, *self._shape) samples = (noise < probs).float() return torchutils.split_leading_dim(samples, [context_size, num_samples]) def _mean(self, context): logits = self._compute_params(context) return torch.sigmoid(logits)
Ancestors
- Distribution
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Inherited members