본문 바로가기
데이터&AI/LLM

RMSNorm과 Layer Normalization 비교하기

by 일등박사 2024. 6. 28.

오늘의 주제 요약!!

 

LN은 합창단의 각 파트가 아름다운 화음을 만들도록 돕는 합창 지휘자라면, 

RMSNorm은 오케스트라 전체의 연주를 조율하여 감동적인 교향곡을 완성하는 오케스트라 지휘자와 같다!!

 

1. Layer Normalization

 - 신경망의 각 층에서 입력 데이터의 정규화를 수행하는 방법

 -  합창단의 각 파트별 음량을 조절하는 것과 같음!! 소프라노, 알토, 테너, 베이스 각 파트의 음량이 너무 크거나 작으면 전체적인 화음이 깨지게 되기에 LN은 각 파트의 음량을 적절히 조절하여 전체적인 화음을 아름답게 만드는 역할을 함

 

  1. 입력 데이터:
    • 입력 데이터 𝑥=[𝑥1,𝑥2,...,𝑥𝑑] 의 𝑑차원 벡터. 
  2. 평균 계산:  각 샘플의 평균을 계산

3. 분산 계산:샘플의 분산을 계산

4. 정규화: 입력 데이터 x 정규화 진행

( 은 분모가 0이 되는 것을 방지하기 위해 사용하는 작은 값).

 

5. 스케일과 시프트: 정규화된 값에 학습 가능한 파라미터 𝛾𝛽를 사용하여 스케일과 시프트를 적용

 

 

최종적 Layer Normalization산출물!! 

 

𝑦=[𝑦1,𝑦2,...,𝑦𝑑]

 

 

이를 통해 신경망의 각 층에서 입력 데이터의 분포가 조절되어 학습이 안정적이고 효율적으로 진행 가능

 

2. RMSNorm (Root Mean Square Normalization)

 -  Layer Normalization과 유사하게 신경망의 입력을 정규화하는 방법

 -  다만, RMSNorm은 각 특성의 평균을 고려하지 않고, 입력의 제곱평균을 사용하여 정규화를 수행 >> 이에 주로 경량화된 계산을 필요로 하는 모델에서 사용

 -  각 악기군뿐만 아니라 개별 악기의 음량까지 세밀하게 조절하여 오케스트라 전체의 균형을 맞춤. 마치 지휘자가 각 악기의 소리를 듣고 전체적인 연주 흐름에 맞춰 적절한 강약을 조절하는 것

 

  1.  입력 데이터: (LN과 동일) :  𝑥=[𝑥1,𝑥2,...,𝑥𝑑] 의 𝑑차원 벡터.
  2. 평균 계산:  각 평균들의 제곱 평균을 계산

3. 분산 계산 (없음 LN과의 차이, 제곱평균이기에 가능)

4. 정규화: 입력 데이터 x 정규화 진행

( 은 분모가 0이 되는 것을 방지하기 위해 사용하는 작은 값).

 

5. 스케일과 시프트: RMSNorm에서는 일반적으로 시프트 파라미터 𝛽를 사용 X

 

결과물은 LN과 동일한 𝑦=[𝑦1,𝑦2,...,𝑦𝑑] 형식!!

RMSNorm의 요약 수식은!!

 

 

 

 

3. LN vs  RMSNorm

구분 Layer Normalization  RMSNorm
정규화 방식 평균과 분산을 사용하여 정규화 제곱평균을 사용하여 정규화
계산 비용 평균과 분산 계산이 필요하여 비용이 더 높음 제곱평균 계산만 필요하여 비용이 더 낮음
파라미터 𝛾 (스케일 파라미터), 𝛽 (시프트 파라미터) 𝛾 (스케일 파라미터)
배치 민감도 민감하지 않음 민감하지 않음
주 사용 사례 RNN과 같은 순환 신경망 대규모 신경망, 계산 비용이 낮은 모델 (transformer)
장점 각 샘플의 평균과 분산을 고려하여 정규화하여 더 안정적인 학습 가능 계산 비용이 낮고 간단한 계산으로 정규화 수행 가능
단점 평균과 분산 계산으로 인해 계산 비용이 높음 평균을 고려하지 않기 때문에 일부 경우에서 학습 안정성 떨어질 수 있음

 

4. pytorch로 실습해보기

 A. LN (Layer Normalization)

import torch
import torch.nn as nn

# Define a LayerNorm layer
layer_norm = nn.LayerNorm(normalized_shape)

# Apply LayerNorm to the output of a layer
normalized_output = layer_norm(output)

 

B. RMSNorm

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True))
        x_norm = x / (rms + self.eps)
        return self.gamma * x_norm

# Define an RMSNorm layer
rms_norm = RMSNorm(normalized_shape)

# Apply RMSNorm to the output of a layer
normalized_output = rms_norm(output)

 

 

ㅁ 참고1 : LN 관련 논문 https://arxiv.org/abs/1607.06450

ㅁ 참고2 : RMSNorm 관련논문! https://arxiv.org/abs/1910.07467

 

Root Mean Square Layer Normalization

Layer normalization (LayerNorm) has been successfully applied to various deep neural networks to help stabilize training and boost model convergence because of its capability in handling re-centering and re-scaling of both inputs and weight matrix. However

arxiv.org

 

댓글