인공지능/파이토치

파이토치 - 선형 회귀 모델

해피밀세트 2020. 7. 20. 02:25
반응형

 

 

 

 

1. 파이토치로 선형 회귀 모델 만들기 (직접 만들기)

 

y = 1 + 2x1 + 3x2

 

1) 테스트 데이터 생성 및 파라미터 학습을 위한 변수 정의

 

# 라이브러리 불러오기

import torch

 

# 참(True)의 계수

w_true = torch.Tensor([123])

w_true

# X 데이터 준비. 절편을 회귀 계수에 포함시키기 위해 X의 최초 차원에 1을 추가해 둔다.

X = torch.cat([torch.ones(1001), torch.randn(1002)], 1)

X

# 참의 계수와 각 X의 내적을 행렬과 벡터의 곱으로 모아서 계산

y = torch.mv(X, w_true) + torch.randn(100) * 0.5

y

# 기울기 하강으로 최적화하기 위해 파라미터 Tensor를 난수로 초기화해서 생성

w = torch.randn(3, requires_grad=True)

w

# 학습률

gamma = 0.1

 

 

 

 

2) 경사 하강법으로 파라미터 최적화

 

# 손실 함수의 로그

losses = []

 

# 100회 반복

for epoc in range(100):

  # 전회의 backward 메서드로 계산된 경사값을 초기화

  w.grad = None

 

  #선형 모델로 y 예측값을 계산

  y_pred = torch.mv(X, w)

 

  # MSE loss와 w에 의한 미분을 계산

  loss = torch.mean((y - y_pred)**2)

  loss.backward()

 

  # 경사를 갱신한다.

  # w를 그대로 대입해서 갱신하면 다른 텐서가 돼서 계산 
    그래프가 망가진다. 따라서 data만 갱신한다.

  w.data = w.data - gamma * w.grad.data

 

  # 수렴 확인을 위한 loss를 기록해 둔다.

  losses.append(loss.item())

  print(loss.item())

 

 

 

3) matpotlib 그래프 그리기

 

%matplotlib inline
from matplotlib import pyplot as plt
plt.plot(losses)

 

 

 

4) 회귀 계수의 확인

 

w

 


 

2. 파이토치로 선형 회귀 모델 만들기 (nn, optim 모듈 사용)

 

1) 선형 회귀 모델의 구축과 최적화 준비

 

# 라이브러리 불러오기

from torch import nn, optim

# Linear층을 작성. 이번에는 절편은 회귀 계수에 포함하므로 입력 차원을 3으로 하고 bias(절편)을 False로 한다.

net = nn.Linear(in_features=3, out_features=1, bias=False)

# SGD의 최적화기상에서 정의한 네트워크의 파라미터를 전달해서 초기화

optimizer = optim.SGD(net.parameters(), lr=0.1)

# MSE loss 클래스

loss_fn = nn.MSELoss()

 

 

 

2) 최적화 루프 (반복 루프) 돌리기

 

# 손실 함수 로그

losses = []

 

# 100회 반복

for epoc in range(100):

  # 전회의 backward 메서드로 계산된 경사값을 초기화

  optimizer.zero_grad()

 

  # 선형 모델으로 y 예측값을 계산

  y_pred = net(X)

 

  # MSE loss 계산

  # y_pred는 (n,1)과 같은 shape를 지니고 있으므로 (n,)으로 변경할 필요가 있다.

  loss = loss_fn(y_pred.view_as(y), y)

 

  # loss의 w를 사용한 미분계산

  loss.backward()

 

  # 경사를 갱신한다.

  optimizer.step()

 

  # 수렴 확인을 위한 loss를 기록해 둔다.

  losses.append(loss.item())

  print(loss.item())

 

 

 

3) 수렴한 모델의 파라미터 확인

 

list(net.parameters())

 

반응형