module documentation
Undocumented
| Kind | Name | Summary |
| --- | --- | --- |
| Function | lecun | Undocumented |
| Function | trunc | Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution N(mean, std²) with values outside [a, b]... |
| Function | variance | Undocumented |
| Function | _no | Undocumented |
def trunc_normal_tf_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) -> torch.Tensor:
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution N(mean, std²) with values outside [a, b] redrawn until they are within the bounds. The method used for generating the random values works best when a ≤ mean ≤ b.
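The "redraw until within bounds" description can be illustrated with plain rejection sampling. This is a minimal sketch using only the Python standard library; `sample_truncated_normal` and `fill_truncated_normal` are hypothetical helpers, and the function documented here may use a different sampling method internally (the text only says the values are *effectively* drawn this way). A flat list stands in for the tensor.

```python
import random
from typing import List


def sample_truncated_normal(mean: float, std: float, a: float, b: float,
                            rng: random.Random, max_tries: int = 10_000) -> float:
    """Draw one value from N(mean, std**2), redrawing until it falls in [a, b]."""
    for _ in range(max_tries):
        x = rng.gauss(mean, std)  # candidate from the untruncated normal
        if a <= x <= b:           # accept only candidates inside the bounds
            return x
    # Acceptance becomes rare when [a, b] holds little of the probability mass.
    raise RuntimeError("no candidate fell inside [a, b]")


def fill_truncated_normal(n: int, mean: float = 0.0, std: float = 1.0,
                          a: float = -2.0, b: float = 2.0, seed: int = 0) -> List[float]:
    """Return n samples; a hypothetical stand-in for filling an n-element tensor."""
    rng = random.Random(seed)
    return [sample_truncated_normal(mean, std, a, b, rng) for _ in range(n)]


samples = fill_truncated_normal(1000)
```

Rejection sampling is shown only because it mirrors the wording; it slows down sharply when [a, b] lies far from `mean`, since most candidates are then rejected.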
NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX implementation, where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0, and the result is subsequently scaled by the std arg and shifted by the mean arg.
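The order of operations in the NOTE (truncate the standard normal first, then scale and shift) can be sketched with inverse-CDF sampling from the standard library. This is illustrative only, not the library's code: `trunc_normal_tf_sample` is a hypothetical name, and inverse-CDF sampling is simply one standard way to draw from a truncated normal, chosen here for clarity.

```python
import random
from statistics import NormalDist
from typing import List

_STD_NORMAL = NormalDist(mu=0.0, sigma=1.0)


def trunc_normal_tf_sample(n: int, mean: float = 0.0, std: float = 1.0,
                           a: float = -2.0, b: float = 2.0, seed: int = 0) -> List[float]:
    """Truncate N(0, 1) to [a, b] first, then scale by std and shift by mean."""
    rng = random.Random(seed)
    lo, hi = _STD_NORMAL.cdf(a), _STD_NORMAL.cdf(b)  # CDF values at the bounds
    eps = 1e-12
    out = []
    for _ in range(n):
        u = rng.uniform(lo, hi)          # uniform over the truncated CDF range
        u = min(max(u, eps), 1.0 - eps)  # inv_cdf requires 0 < u < 1
        z = _STD_NORMAL.inv_cdf(u)       # z follows N(0, 1) truncated to [a, b]
        z = min(max(z, a), b)            # guard against floating-point round-off
        out.append(mean + std * z)       # scale and shift only after truncation
    return out


weights = trunc_normal_tf_sample(1000, mean=0.0, std=0.02)
```

Because truncation happens before scaling, every value in `weights` lies within `mean + std * a` and `mean + std * b`.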
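Args:
- tensor: an n-dimensional torch.Tensor
- mean: the mean of the normal distribution
- std: the standard deviation of the normal distribution
- a: the minimum cutoff value, applied to the mean=0, std=1.0 distribution before scaling and shifting (see NOTE)
- b: the maximum cutoff value, applied to the mean=0, std=1.0 distribution before scaling and shifting (see NOTE)

Examples:
>>> w = torch.empty(3, 5)
>>> trunc_normal_tf_(w)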
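Because a and b mean different things in the two variants, a little arithmetic makes the practical difference concrete. The hypothetical helper `effective_bounds` below returns the interval the final values are confined to, contrasting the 'tf' behaviour with a variant that applies [a, b] directly to N(mean, std²), as in the `trunc` summary.

```python
from typing import Tuple


def effective_bounds(mean: float, std: float, a: float, b: float,
                     tf_style: bool) -> Tuple[float, float]:
    """Interval that the final values are confined to (hypothetical helper)."""
    if tf_style:
        # 'tf' variant: [a, b] truncates N(0, 1), so the bounds are in units of std
        return (mean + std * a, mean + std * b)
    # non-'tf' variant: [a, b] are absolute cutoffs on N(mean, std**2)
    return (a, b)


print(effective_bounds(0.0, 0.02, -2.0, 2.0, tf_style=False))  # (-2.0, 2.0)
print(effective_bounds(0.0, 0.02, -2.0, 2.0, tf_style=True))   # (-0.04, 0.04)
```

With std=0.02 and the default cutoffs, the non-'tf' bounds sit 100 standard deviations from the mean and almost never trigger, whereas the 'tf' variant keeps every value within two standard deviations.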