옵티마이저 아무거나 선택하면 안되는 이유, Adam vs AdamW

2023. 8. 21. 14:09광고차단 머신러닝

728x90

 

 

???:"휴먼, 당신의 말은 이해할 수 없습니다.(진짜모름)", 자연어처리를 위한 BERT 선택의 이유와 근

1. 상황 2023.07.24 - [산학협력프로젝트] - ??? : "하... 뭐 고르지?", 기술 선택의 이유와 근거 (1) ??? : "하... 뭐 고르지?", 기술 선택의 이유와 근거 (1) 뭐 먹을까요? 깔깔깔 산학협력 프로젝트, 광고

xpmxf4.tistory.com

위 글에서 저는 다음가 같은 코드를 보여드렸습니다.

model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=2)
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-5)

이 코드는 BERT 모델의 옵티마이저로 AdamW로 채택한 코드입니다.

 

최근에 이 글을 읽은 제 지인이 AdamW 가 뭐냐는 질문이 들어와 이렇게 답을 했습니다.

AdamW는 딥러닝에서 널리 사용되는 옵티마이저 중 하나예요!

뭔 개x리야?

그러더니 이런 표정을 지으시더라고요...ㅎ..

그래서 이번 기회에 AdamW에 대해서 좀 더 자세하게 글을 풀어보도록 하겠습니다!

 

1. 옵티마이저란?

머신러닝과 딥러닝에서, 모델은 데이터를 통해 학습을 합니다.

이때 학습의 중요한 목표는 ('오차') == (예측값과 실제값 사이의 차이)를 줄이는 것이 목표입니다.

이때 '오차'를 나타내는 함수를 '손실 함수'라고 부릅니다.

1. 추정/예측 모델을 정하고, 이 모델이 데이터를 얼마나 잘 추정/예측하는지,
     - 그 정도에 대해 수학적으로 표현하는 함수

2. 손실함수(Loss function)는 예측값과 실제값(레이블)의 차이를 구하는 기준을 의미하는 것으로
머신러닝 모델 학습에서 필수 구성요소라고 할 수 있다.

이때 옵티마이저는 이 손실 함수의 값을 최소화하기 위해

모델의 파라미터를 조절하는 알고리즘입니다.

 

즉, 옵티마이저는 모델이 학습하는 동안 손실 함수를 최소화할 수 있도록

파라미터를 업데이트하는 역할을 합니다.

 

근데 파라미터란?

(사진)

[정의]
파라미터는 모델의 예측 성능을 결정하는 변수들입니다.
이 변수들은 학습 데이터를 통해 조정되며, 학습이 진행됨에 따라 최적의 값을 찾아갑니다.

[딥러닝에서의 파라미터]
딥러닝, 특히 신경망에서, 파라미터는 주로 가중치(weights)와 편향(biases)으로 구성됩니다.
가중치(Weights): 각 연결선에 할당된 값으로, 입력 데이터에 얼마나 강한 영향을 미칠지를 결정합니다.
편향(Biases): 각 뉴런에 추가되는 값으로, 뉴런의 출력을 조정하는 역할을 합니다.

여기까지 되게 이론적인(a.k.a. 못 알아듣겠는) 말로만 설명을 했는 데

예시로 간단하게 설명을 해볼게요!

옵티마이저는 쉽게 비유하자면 고속도로에서의 운전에 적합한

고급 운전 기술을 가진 운전자와 같습니다.

 

운전자는 차량의 속도와 방향을 적절히 조절하면서

목적지에 빠르게 도달하는 방법을 알고 있는 사람과 같다고 보면 됩니다!

2. AdamW 란 무엇인가?

먼저 Adam에 대해서 알아보겠습니다!

Adam은 Adaptative Moment Estimation의 약자로,

딥러닝에서 널리 사용되는 옵티마이저 중 하나입니다.

이는 과거의 gradient(기울기)의 제곱들의 평균 + gradient  들의 평균을 가지고

계산하는 2 가지의 방법을 사용합니다.

즉, 과거의 기울기 정보들을 사용, 현재의 파라미터를 업데이트합니다.

 

하지만 옵티마이저에게 들어오는 기울기 값 중 하나가 값이 과하다면,

모델의 예측은 과한 기울기에 편향된 결과를 가지고 오게 됩니다.

이를 과적합(오버피팅)이라고 부릅니다.

 

Adam 옵티마이저는 일부 연구에서 일정한 오버피팅 문제점을 가지고 있다는 것이 지적되었습니다.

이러한 문제를 해결하기 위해 제안된 것이 AdamW입니다.

AdamW Adam 원리를 기반으로 하되, 가중치 감소(Weight Decay) 보다 효과적으로 적용하여

모델의 일반화 성능을 향상시킵니다.

 

3. 오버피팅을 간단하게 알아보자면?

게임에서 최고의 점수를 얻기 위해 특정 전략을 사용하고 있다고 해볼까요?

 

전략은 처음에는 작동하며, 우리는 높은 점수를 얻습니다.

 

그러나, 게임의 다른 레벨에서는 전략이 효과적이지 않습니다.

여기서 게임의 레벨은 데이터의 다양성, 전략은 모델의 파라미터를 의미합니다.

 

Adam 특정 레벨에서만 작동하는 전략을 고수하는 경향이 있을 있습니다.

반면, AdamW 다양한 레벨에서도 작동하는 전략을 찾아가려고 합니다.

4. AdamW 실제 사용 코드

이는 제 BERT 모델학습 시에 AdamW를 사용한 코드입니다.

 

5. 실제로 Adam, AdamW 성능 비교해보기

성능을 비교하기 위해 코드를 짜고 테스트를 한번 돌려봤습니다!

DataSet 이라든지, 엑셀 데이터(1,000개) 를 가져오는 작업은 

이전 글과 동일하기에 넣지 않고, 순수 비교하는 코드만 넣었습니다!

# Adam 옵티마이저와 AdamW 옵티마이저 설정
optimizers = {
    'Adam': Adam(model.parameters(), lr=1e-5),
    'AdamW': AdamW(model.parameters(), lr=1e-5)
}

losses = {
    'Adam': [],
    'AdamW': []
}

# 옵티마이저별 학습
for opt_name, optimizer in optimizers.items():
    # 모델 초기 상태 저장
    original_state = model.state_dict().copy()
    model.train()

    for epoch in range(5):
        epoch_loss = 0.0

        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()

        # 에포크의 평균 손실 저장
        losses[opt_name].append(epoch_loss / len(dataloader))
        print(f'Optimizer: {opt_name}, Epoch {epoch + 1}, Loss {epoch_loss / len(dataloader)}')

    # 모델을 초기 상태로 복원
    model.load_state_dict(original_state)

# 손실 그래프를 통한 비교
plt.plot(losses['Adam'], label='Adam')
plt.plot(losses['AdamW'], label='AdamW')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Optimizer Comparison')
plt.show()

코드에 대한 간략한 설명을 하자면 BERT 모델의 옵티마이저로

Adam, AdamW 각각을 채용했을 때의 오차율을 matplotlib 라이브러리로

나타내는 코드입니다!

Adam vs AdamW 오차율 그래프

그래프를 보시면 AdamW 의 성능이 훨씬 좋죠?

확실히 성능을 개선한 모델이 맞다고 느껴지네요!

 

마무리

오늘만 포스팅을 마치겠습니다!

질문이 있거나 틀린 부분이 있다면 언제든 지적해 주세요!

감사합니다 :)

출처

http://www.ktword.co.kr/test/view/view.php?nav=2&no=6261&sh=%EC%86%90%EC%8B%A4+%ED%95%A8%EC%88%98

http://dmqm.korea.ac.kr/activity/seminar/326

728x90