인공지능/파이토치

파이토치 - MLP 구축과 학습

해피밀세트 2020. 7. 20. 03:06
반응형

 

 

 

 

1) 손글씨 문자를 판별하는 MLP 작성

 

# 라이브러리 불러오기

import torch

from torch import nn

# 네트워크 구성

net = nn.Sequential(

    nn.Linear(6432),

    nn.ReLU(),

    nn.Linear(3216),

    nn.ReLU(),

    nn.Linear(1610)

)

 

 

 

2) 손글씨 문자 데이터의 학습 코드의 나머지 부분

 

# 라이브러리 및 데이터 불러오기

import torch

from torch import nn, optim

from sklearn.datasets import load_digits

digits = load_digits()

# 독립변수, 종속변수 분리

X = digits.data

Y = digits.target

# Numpy의 ndarray를 파이토치의 텐서로 변환

X = torch.tensor(X, dtype=torch.float32)

Y = torch.tensor(Y, dtype=torch.int64)

# 소프트맥스 크로스 엔트로피

loss_fn = nn.CrossEntropyLoss()

# Adam

optimizer = optim.Adam(net.parameters())

# 손실 함수의 로그

losses = []

# 100회 반복

for epoc in range(100):

  # backward 메서드로 계산된 이전 값을 삭제

  optimizer.zero_grad()

 

  # 선형 모델로 y의 예측 값 게산

  y_pred = net(X)

 

  # MSE loss와 w를 사용한 미분 계산

  loss = loss_fn(y_pred, Y)

  loss.backward()

 

  # 경사를 갱신

  optimizer.step()

 

  # 수령 확인을 위해 loss를 기록해 둔다.

  losses.append(loss.item())

  print(loss.item())

 

 

 

3) to 메서드를 이용해 GPU로 전송

 

X = X.to("cuda:0")
Y = Y.to("cuda:0")
net.to("cuda:0")

# 이후 처리는 동일하게 optimizer를 설정해서 학습 루프를 돌린다.

반응형