# lit_gpt/rmsnorm.py
1
import torch
2
3
4
class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    Scales the input by the reciprocal square root of the mean of its squares
    along ``dim`` and multiplies the result by a learnable ``weight``.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

    Args:
        size: Number of elements in the learnable ``weight`` vector.
        dim: Dimension along which the mean of squares is computed.
        eps: Constant added to the mean of squares before the reciprocal
            square root is taken.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        # Learnable per-element scale, initialised to ones.
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and apply the learnable scale.

        The computation is carried out in float32 and the result is cast back
        to the dtype ``x`` had on entry.

        Args:
            x: Tensor to normalize.

        Returns:
            A tensor with the same dtype as the input ``x``.
        """
        dtype = x.dtype
        # Compute in float32 regardless of the incoming dtype.
        x = x.float()
        # NOTE: the original RMSNorm paper implementation is not equivalent
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        # eps is added before rsqrt so an all-zero slice stays finite.
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        # ``weight`` has shape (size,), so it broadcasts against the trailing
        # dimension of ``x_normed``.
        return (self.weight * x_normed).to(dtype=dtype)

    def reset_parameters(self) -> None:
        """Reset ``weight`` to all ones in place."""
        torch.nn.init.ones_(self.weight)