module documentation
Undocumented
Function | lecun | Undocumented |
Function | trunc | Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution N(mean, std²) with values outside [a, b]... |
Function | variance | Undocumented |
Function | _no | Undocumented |
def trunc_normal_tf_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) -> torch.Tensor:
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution N(mean, std²), with values outside [a, b] redrawn until they are within the bounds. The method used for generating the random values works best when a ≤ mean ≤ b.
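NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX implementations, where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0, and the result is subsequently scaled by std and shifted by mean. The bounds are therefore expressed in units of standard deviations, and the returned values lie in [mean + a * std, mean + b * std].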
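The difference between the two bound conventions can be illustrated with a small scalar sketch. This is not the library's implementation: the function names `sample_truncated` and `sample_truncated_tf` are hypothetical, and plain rejection sampling from the standard library is used here only because it mirrors the "redrawn until they are within the bounds" description.

```python
import random


def sample_truncated(rng, mean, std, a, b):
    """Draw from N(mean, std^2), redrawing until the value lies in [a, b].

    Hypothetical helper: the bounds are absolute values.
    """
    while True:
        x = rng.gauss(mean, std)
        if a <= x <= b:
            return x


def sample_truncated_tf(rng, mean, std, a, b):
    """'tf' convention: truncate a standard normal to [a, b], then scale and shift.

    Hypothetical helper: the bounds are in units of standard deviations.
    """
    z = sample_truncated(rng, 0.0, 1.0, a, b)
    return z * std + mean


rng = random.Random(0)
std = 0.02

# Absolute bounds: [-2, 2] is 100 standard deviations wide on each side,
# so the truncation almost never triggers.
plain = [sample_truncated(rng, 0.0, std, -2.0, 2.0) for _ in range(10_000)]

# 'tf' bounds: every value lies in [mean + a * std, mean + b * std] = [-0.04, 0.04].
tf_style = [sample_truncated_tf(rng, 0.0, std, -2.0, 2.0) for _ in range(10_000)]
```

With a small `std` such as 0.02, the default bounds of ±2 have practically no effect under the absolute convention, while under the 'tf' convention they clip the distribution at two standard deviations.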
- Args:
  - tensor: an n-dimensional torch.Tensor
  - mean: the mean of the normal distribution
  - std: the standard deviation of the normal distribution
  - a: the minimum cutoff value
  - b: the maximum cutoff value
- Examples:
  >>> w = torch.empty(3, 5)
  >>> nn.init.trunc_normal_(w)