|
a |
|
b/lit_gpt/rmsnorm.py |
|
|
1 |
import torch |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    Scales the input by the reciprocal of its root-mean-square along ``dim``
    and applies a learnable per-feature gain. Unlike LayerNorm, no mean is
    subtracted and no bias is added.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        """Create the norm.

        Args:
            size: number of features (length of the learnable gain vector).
            dim: dimension(s) over which the mean square is taken.
            eps: small constant added to the mean square for numerical safety.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable gain, initialized to ones so the module starts as a pure
        # (gain-free) RMS normalization.
        self.weight = torch.nn.Parameter(torch.ones(size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the RMS-normalized, gain-scaled input in its original dtype."""
        original_dtype = x.dtype
        # Do the reduction and scaling in float32 for numerical stability,
        # then cast back to the caller's dtype at the very end.
        xf = x.float()
        # NOTE: the original RMSNorm paper implementation is not equivalent
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=self.dim, keepdim=True) + self.eps)
        return (self.weight * (xf * inv_rms)).to(dtype=original_dtype)

    def reset_parameters(self) -> None:
        """Reset the learnable gain back to all ones."""
        torch.nn.init.ones_(self.weight)