본문 바로가기
교재 리뷰/파이토치 첫걸음

파이토치 첫걸음 - 4. 인공 신경망

by 펄서까투리 2019. 11. 18.

# 세줄요약 #

  1. 인공뉴런(Perceptron)은 들어온 입력값에 가중치(Weight)를 곱하고 편차(Bias)를 더해준 뒤 모두 다 더한 값을 활성화함수(Activation function)를 통해 변형하여 전달하는 단위를 의미하고, 이러한 뉴런들이 모인 네트워크를 인공신경망(Artificial Neural Network; ANN)이라 부른다(hidden layer가 2개 이상이면, Deep Neural Network; DNN).
  2. 인공신경망(ANN)에 들어온 입력값이 여러개의 은닉층(hidden layer)을 거치며 최종결과값인 예측값 y*를 구하는 과정을 순전파(forward propagation)라 한다(y* = w3 × σ(w2 × σ(w1 × χ + b1) + b2) + b3 ; w = weight, b = bias, σ = activation function).
  3. 예측값 y*과 정답 y의 차이로 계산된 손실을 연쇄법칙(경사하강법)을 이용하여 입력층까지 다시 전달하는 과정을 역전파(backward propagation)라 한다.

 

인공신경망. 2개의 은닉층(hidden layer)을 가진 구조를 가지고 있다. [출처: https://www.digitaltrends.com/cool-tech/what-is-an-artificial-neural-network/]

 

 

다양한 종류의 활성화 함수들. [출처: https://towardsdatascience.com/complete-guide-of-activation-functions-34076e95d044]

 

#code block#
######################
# Training ANN model #
#######E##############

#Library Import
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import matplotlib.pyplot as plt

# Data Create
num_data = 1000
num_epoch = 10000

noise = init.normal_(torch.FloatTensor(num_data,1),std=1)
x = init.uniform_(torch.Tensor(num_data,1),-15,15)
y = (x**2) + 3 
y_noise = y + noise

# Model Build
model = nn.Sequential(
          nn.Linear(1,6),
          nn.ReLU(),
          nn.Linear(6,10),
          nn.ReLU(),
          nn.Linear(10,6),
          nn.ReLU(),
          nn.Linear(6,1),
      )

loss_func = nn.L1Loss()
optimizer = optim.SGD(model.parameters(),lr=0.0002)

# Model Training
loss_array = []
for i in range(num_epoch):
    optimizer.zero_grad()
    output = model(x)
    
    loss = loss_func(output,y_noise)
    loss.backward()
    optimizer.step()
    
    loss_array.append(loss)
    
# Loss Graph
plt.plot(loss_array)
plt.show()

# Result Visualization
plt.figure(figsize=(10,10))
plt.scatter(x.detach().numpy(),y_noise,label="Original Data")
plt.scatter(x.detach().numpy(),output.detach().numpy(),label="Model Output")
plt.legend()
plt.show()

"""
출처: 파이토치 첫걸음 / 최건호 / 한빛미디어
"""

 

* 출처: 파이토치 첫걸음 / 최건호 / 한빛미디어

728x90
728x90

댓글