Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion tests/modules/layers/test_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm, Fp32LayerNorm
import torch.nn.functional as F

from tests.test_utils import gpu_test

from torchmultimodal.modules.layers.normalizations import (
Fp32GroupNorm,
Fp32LayerNorm,
RMSNorm,
)


def test_fp32layernorm():
Expand All @@ -20,3 +28,45 @@ def test_fp32groupnorm():
norm = Fp32GroupNorm(2, 4)
output = norm(x)
assert output.dtype == torch.float16


def test_rms_norm_fp32return():
"""verify type is returned as fp32"""
dims = 512
x = torch.empty(dims, dtype=torch.float16)
norm = RMSNorm(
dims,
)
output = norm(x)
assert output.dtype == torch.float32


@gpu_test(1)
def test_rms_norm_core_algo():
"""compare RMSNorm with RMSNorm using F.norm version"""

dims = 1024
x = torch.empty(dims, dtype=torch.float16, device="cuda")
x_clone = x.clone().detach()

class RMSNormFunctional(torch.nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.scale = dim**0.5
self.weights = torch.nn.Parameter(torch.ones(dim))
self.eps = eps

def forward(self, x):
return F.normalize(x, p=2, dim=-1, eps=self.eps) * self.scale * self.weights

base_norm = RMSNorm(
dims,
).to("cuda")
backup_norm = RMSNormFunctional(
dims,
).to("cuda")

output_base_rms = base_norm(x)
output_backup_rms = backup_norm(x_clone)

assert torch.allclose(output_base_rms, output_backup_rms)
23 changes: 23 additions & 0 deletions torchmultimodal/modules/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import Any

import torch
from torch import nn, Tensor


Expand Down Expand Up @@ -45,3 +46,25 @@ def forward(self, x: Tensor) -> Tensor:
self.eps,
)
return output.type_as(x)


class RMSNorm(nn.Module):
"""Root Mean Square layer normalization
as proposed in: https://arxiv.org/abs/1910.07467

params:
dim = model size
eps = epsilon
"""

def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.scale = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
x_normed = self._norm(x.float()).type_as(x)
return x_normed * self.scale