module documentation

Weight initialization utilities for torch.Tensor parameters: TensorFlow/JAX-style truncated normal, variance scaling, and LeCun normal initializers.

Function lecun_normal_ LeCun normal initialization: variance scaling with scale=1.0 and mode='fan_in'.
Function trunc_normal_tf_ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution N(mean, std²) with values outside [a, b]...
Function variance_scaling_ Initializes the tensor with variance-scaled values: the variance is scale divided by the fan value selected by mode, drawn from the given distribution.
Function _no_grad_trunc_normal_ Internal helper that fills the tensor in-place with truncated normal values without tracking gradients.
def lecun_normal_(tensor: torch.Tensor): (source)

LeCun normal initialization: applies variance scaling with scale=1.0 and mode='fan_in' to the input tensor.
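
A minimal usage sketch (the import path here is hypothetical; adjust it to wherever this module actually lives):

import torch
import torch.nn as nn
from weight_init import lecun_normal_  # hypothetical import path

layer = nn.Linear(256, 128)
with torch.no_grad():
    lecun_normal_(layer.weight)  # fan_in-scaled init of the weight matrix
    layer.bias.zero_()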

def trunc_normal_tf_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) -> torch.Tensor: (source)

Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution N(mean, std²) with values outside [a, b] redrawn until they are within the bounds. The method used for generating the random values works best when a ≤ mean ≤ b.

NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX implementations, where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0, and the result is subsequently scaled and shifted by the mean and std args.

Args:
    tensor: an n-dimensional torch.Tensor
    mean: the mean of the normal distribution
    std: the standard deviation of the normal distribution
    a: the minimum cutoff value
    b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> trunc_normal_tf_(w)
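
To make the note above concrete, the following sketch reproduces the TF/JAX-style ordering of truncation and scaling using PyTorch's built-in nn.init.trunc_normal_, and contrasts it with the PyTorch-style call; it illustrates the semantics only, not this module's actual implementation (the std value 0.02 is just an example):

import torch
import torch.nn as nn

mean, std, a, b = 0.0, 0.02, -2.0, 2.0

# TF/JAX-style: truncate in standard-normal space, then scale and shift.
w_tf = torch.empty(3, 5)
nn.init.trunc_normal_(w_tf, mean=0.0, std=1.0, a=a, b=b)
w_tf.mul_(std).add_(mean)

# PyTorch-style: the [a, b] bounds apply directly to N(mean, std**2) samples.
w_pt = torch.empty(3, 5)
nn.init.trunc_normal_(w_pt, mean=mean, std=std, a=a, b=b)
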
def variance_scaling_(tensor: torch.Tensor, scale: float = 1.0, mode: str = 'fan_in', distribution: str = 'normal'): (source)

Initializes the tensor with variance-scaled values, in the style of TensorFlow's VarianceScaling initializer: the variance is scale divided by the fan value selected by mode (fan_in, fan_out, or the average of the two), and values are drawn according to distribution.
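
A sketch of the variance-scaling scheme the signature suggests, following the common TensorFlow-style definition; the exact fan handling and supported distribution names below are assumptions, not necessarily this module's code:

import math
import torch
from torch.nn.init import _calculate_fan_in_and_fan_out

def variance_scaling_sketch(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    # Fan values follow the usual convention: receptive field size times
    # input/output channels.
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    else:  # assumed 'fan_avg'
        denom = (fan_in + fan_out) / 2
    variance = scale / denom
    with torch.no_grad():
        if distribution == 'normal':
            return tensor.normal_(0.0, math.sqrt(variance))
        if distribution == 'uniform':
            bound = math.sqrt(3 * variance)
            return tensor.uniform_(-bound, bound)
    raise ValueError(f'unsupported distribution: {distribution}')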

def _no_grad_trunc_normal_(tensor: torch.Tensor, mean: float, std: float, a: float, b: float): (source)

Internal helper that fills the input tensor in-place with values from a truncated normal distribution while gradient tracking is disabled; used by the public truncated normal initializers.
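
For reference, a common way to implement such a helper is the inverse-CDF approach also used by torch.nn.init.trunc_normal_; the sketch below shows that technique and is an assumption about the method, not a copy of this module's code:

import math
import torch

def trunc_normal_sketch(tensor, mean, std, a, b):
    # Map uniform samples through the normal quantile function restricted
    # to [a, b] (inverse-CDF sampling), all without tracking gradients.
    def norm_cdf(x):
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        lo = norm_cdf((a - mean) / std)
        hi = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * lo - 1, 2 * hi - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor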