Module flowcon.transforms.unitvector
Classes
class UnitVector (features)
-
Bijective map from R^d to the surface of the unit sphere in R^(d+1), given by an inverse stereographic projection. Inspired by https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/UnitVector
Initialize internal Module state, shared by both nn.Module and ScriptModule.
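For reference, the map and its volume-change factor can be written out explicitly. The formulas below are read off the class source; the symbols f, J, y and S^d (the unit sphere in R^(d+1)) are notation introduced here and are not part of the library. The point (0, ..., 0, 1) is not in the image of f, so the inverse is undefined there.

```latex
\begin{aligned}
f(x) &= \frac{\bigl(2x,\ \lVert x\rVert^2 - 1\bigr)}{\lVert x\rVert^2 + 1} \in S^d \subset \mathbb{R}^{d+1}, \qquad x \in \mathbb{R}^d, \\
f^{-1}(y) &= \frac{(y_1, \dots, y_d)}{1 - y_{d+1}}, \qquad y \in S^d,\ y_{d+1} \neq 1, \\
\sqrt{\det\bigl(J^\top J\bigr)} &= \left(\frac{2}{\lVert x\rVert^2 + 1}\right)^{d}, \qquad
\log\sqrt{\det\bigl(J^\top J\bigr)} = d\,\bigl(\log 2 - \log(1 + \lVert x\rVert^2)\bigr).
\end{aligned}
```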
class UnitVector(Transform):
    """
    Bijective map from R^d to the surface of a sphere in R^(d+1) with an inverse stereographic projection.
    Inspired by https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/UnitVector
    """

    def __init__(self, features):
        super().__init__()
        self.dim_Rd = features
        self.dim_sphere = torch.nn.Parameter(torch.tensor(features + 1, dtype=torch.float32))

    def forward(self, inputs, context=None):
        """
        R^d to surface of sphere
        """
        assert inputs.shape[-1] == self.dim_Rd
        norm_sq = self.norm_sq(inputs)
        output = torch.concatenate([2 * inputs, norm_sq - 1], dim=-1) / (norm_sq + 1)
        return output, self.forward_logabsdet(inputs).view(-1)

    def inverse(self, inputs, context=None):
        """
        Surface of sphere to R^d
        """
        assert inputs.shape[-1] == self.dim_sphere
        self.check_inverse_range(inputs)
        outputs = (inputs / (1 - inputs[..., -1:]))[..., :-1]
        return outputs, -self.forward_logabsdet(outputs).view(-1)

    def check_inverse_range(self, inverse_inputs):
        if torch.abs(torch.max(self.norm_sq(inverse_inputs)) - 1.) > 1e-4:
            raise InputOutsideDomain()

    def forward_logabsdet(self, inputs):
        """
        sqrt{det{J^T J}} = (2 / (sum^2 + 1))^n.
        :param inputs:
        :param forward_norm_sq:
        :return:
        """
        n = self.dim_Rd
        log_2 = torch.log(torch.tensor(2).to(inputs))
        return n * (log_2 - torch.log1p(self.norm_sq(inputs).squeeze()))

    @staticmethod
    def norm_sq(x):
        return torch.sum(x ** 2, dim=-1, keepdim=True)
Ancestors
- Transform
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Static methods
def norm_sq(x)
Methods
def check_inverse_range(self, inverse_inputs)
def forward(self, inputs, context=None)
-
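Maps R^d to the surface of the unit sphere in R^(d+1). Returns the mapped points together with the log absolute determinant from forward_logabsdet.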
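The forward map can be sketched in a few lines of plain Python, independent of torch and of this library. The function name to_sphere is hypothetical and chosen for illustration; it works on a single point given as a list of floats rather than on a batched tensor.

```python
import math


def to_sphere(x):
    """Inverse stereographic projection of one point x in R^d (list of floats)."""
    norm_sq = sum(v * v for v in x)  # ||x||^2
    scale = 1.0 / (norm_sq + 1.0)
    # First d coordinates: 2x / (||x||^2 + 1); last: (||x||^2 - 1) / (||x||^2 + 1).
    return [2.0 * v * scale for v in x] + [(norm_sq - 1.0) * scale]


y = to_sphere([0.5, -1.0, 2.0])
radius = math.sqrt(sum(v * v for v in y))  # Euclidean norm of the image point
```

The image point has one more coordinate than the input and lies on the unit sphere, so radius equals 1 up to floating-point rounding. The origin maps to the point whose last coordinate is -1.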
def forward_logabsdet(self, inputs)
-
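Log of the volume-change factor sqrt{det{J^T J}} = (2 / (||x||^2 + 1))^n, where ||x||^2 is the squared norm of the inputs and n = features is the input dimension. :param inputs: points in R^d. :return: the log of this factor for each input point.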
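As a sanity check on this formula, the sketch below evaluates the log factor in plain Python and compares it with a finite-difference estimate for d = 1, where the factor is the speed of the curve traced on the circle. The helper names to_sphere and forward_logabsdet_scalar are hypothetical stand-ins for illustration, not the library's tensor methods.

```python
import math


def to_sphere(x):
    """Inverse stereographic projection of one point x in R^d (list of floats)."""
    norm_sq = sum(v * v for v in x)
    scale = 1.0 / (norm_sq + 1.0)
    return [2.0 * v * scale for v in x] + [(norm_sq - 1.0) * scale]


def forward_logabsdet_scalar(x):
    """log sqrt(det(J^T J)) = d * (log 2 - log(1 + ||x||^2)) for one point x."""
    d = len(x)
    norm_sq = sum(v * v for v in x)
    return d * (math.log(2.0) - math.log1p(norm_sq))


# d = 1: approximate the speed |f'(x0)| with a central difference.
x0, h = 0.7, 1e-5
speed = math.dist(to_sphere([x0 + h]), to_sphere([x0 - h])) / (2.0 * h)
log_speed = math.log(speed)
closed_form = forward_logabsdet_scalar([x0])
```

The two values agree to well within the finite-difference error. For d > 1 the same closed form applies because the projection scales lengths equally in every direction by 2 / (||x||^2 + 1).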
def inverse(self, inputs, context=None)
-
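Maps the surface of the unit sphere in R^(d+1) back to R^d. Raises InputOutsideDomain if the largest squared norm of the inputs differs from 1 by more than 1e-4. Returns the mapped points together with the negated forward_logabsdet evaluated at them.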
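The inverse direction can be sketched the same way in plain Python. The names to_sphere and from_sphere are hypothetical. ValueError stands in for the library's InputOutsideDomain, and the check here is applied to a single point rather than to the maximum over a batch.

```python
def to_sphere(x):
    """Inverse stereographic projection of one point x in R^d (list of floats)."""
    norm_sq = sum(v * v for v in x)
    scale = 1.0 / (norm_sq + 1.0)
    return [2.0 * v * scale for v in x] + [(norm_sq - 1.0) * scale]


def from_sphere(y, tol=1e-4):
    """Stereographic projection of a unit vector y in R^(d+1) back to R^d."""
    norm_sq = sum(v * v for v in y)
    if abs(norm_sq - 1.0) > tol:
        # Stand-in for InputOutsideDomain in the library source.
        raise ValueError("input is not on the unit sphere")
    # Divide by (1 - last coordinate), then drop the last coordinate.
    denom = 1.0 - y[-1]
    return [v / denom for v in y[:-1]]


x = [0.5, -1.0, 2.0]
x_back = from_sphere(to_sphere(x))  # round trip should recover x
```

The round trip recovers the original point up to rounding. The point with last coordinate exactly 1 has no preimage, so the division fails there.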