크크루쿠쿠

DL03. Back-propagation 역전파 본문

DeepLearning/공부

DL03. Back-propagation 역전파

JH_KIM 2021. 1. 9. 01:08

역전파 알고리즘 Backpropagation

 

전에 본 예시들처럼 간단하다면 상관X

But 인공지능 신경망이 이렇게 복잡하다면? → loss에 대한 gradient 값을 계산 불가능

 

농구의 자유투 연습을 생각해보자

자유투를 던지는 과정 → 순전파 과정 (forward propagation)

공이 도착한 위치를 보고 던지는 위치 수정 → Backpropagation

 

즉 loss를 구한 다음 그 loss를 뒤로 전파해가면서 변수들을 갱신해주는 것

 

그렇담 어떻게?

Chain rule

역전파 방식을 사용하기 위해선 이 chain rule을 이용해야한다.

 

 

 

이런 방식으로 뒤로 미분값을 계속 곱해줌으로써 모든 parameter의 loss에 대한 미분값을 알 수 있음.

 

 

예시

 

y=w*x의 경우를 예시로 들어보자

 

x=1,y=2 dataset과 w=1이라 가정

 

 

그뒤 나오는 loss 1과 그것을 이용하여 back propagation 해줌으로써

chain rule을 사용해 loss의 미분값을 알아냄

 

 

예시 코드

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import pdb
 
x_data = [1.02.03.0]
y_data = [2.04.06.0]
= torch.tensor([1.0], requires_grad=True)
 
 
# our model forward pass
def forward(x):
    return x * w
 
# Loss function
def loss(y_pred, y_val):
    return (y_pred - y_val) ** 2
 
# Before training
print("Prediction (before training)",  4, forward(4).item())
 
# Training loop
for epoch in range(10):
    for x_val, y_val in zip(x_data, y_data):
        y_pred = forward(x_val) # 1) Forward pass
        l = loss(y_pred, y_val) # 2) Compute loss
        l.backward() # 3) 역전파 + w 업데이트 자동으로 해줌
        print("\tgrad: ", x_val, y_val, w.grad.item()) # x,y,미분값
        w.data = w.data - 0.01 * w.grad.item()
 
        # Manually zero the gradients after updating weights
        w.grad.data.zero_()
                # 다 끝난뒤 0으로 초기화 시켜주기
 
    print(f"Epoch: {epoch} | Loss: {l.item()}"#1epoch 후 출력
 
# After training
print("Prediction (after training)",  4, forward(4).item())
cs

'DeepLearning > 공부' 카테고리의 다른 글

Prompt Engineering Guide 1. Introduction  (0) 2023.03.27
NLP 01.Text Classification  (0) 2021.01.21
DL04. Linear Regression in the PyTorch way  (0) 2021.01.13
DL02. Gradient Descent  (0) 2021.01.06
DL01.Linear Model  (0) 2021.01.06
Comments