Python/[코칭스터디 9기] 인공지능 AI 기초 다지기
[인공지능(AI) 기초 다지기] 5. 딥러닝 핵심 기초 (10)
김초송
2023. 3. 16. 18:22
7 - 2) MNIST Introduction
- MNIST
- 손으로 쓰여진 데이터셋
- 28 x 28 image = reshape input image into [batch_size by 784]
-> view 함수 사용 - 1 channel gray image
- 0 ~ 9 digits
- Torchvision
- 파이토치의 다양하고 유명한 데이터셋,
딥러닝 모델 아키텍처,
데이터에 적용하는 트랜스폼(전처리)를 쓸 수 있는 패키지
# Reading Data
import torchvision.datasets as dsets
from torchvision import transforms
mnist_train = dsets.MNIST(root="MNIST_data/", train=True,
transform=transforms.ToTensor(), download=True)
mnist_test = dsets.MNIST(root="MNIST_data/", train=False,
transform=transforms.ToTensor(), download=True)
data_loader = torch.utils.DataLoader(DataLoader=mnist_train, batch_size=100,
shuffle=True, drop_last=True)
- root : MNIST 데이터의 위치 경로
- train : True = train set, False = test set
- transform : MNIST 이미지를 불러올 때 어떤 transform들을 적용할 것인지
일반적으로 pytorch 에서는 0에서 1 사이 값, 순서는 채널 높이 너비 (channel, height, width)
일반적으로 이미지는 0에서 255 사이 값, 순서는 높이 너비 채널 (height, width, channel)
-> ToTensor()로 순서와 크기를 조정해줌 - download : root에 MNIST 데이터가 없다면 다운을 받겠다는 의미
- batch_size : MNIST train data 를 몇 개씩 불러올건지?
- shuffle = True : 무작위
- drop_last = True: batch_size 로 자를 때 맨 마지막에 남는 데이터들을 사용하지 않음
for epoch in range(training_epochs):
for X, Y in data_loader:
X = X.view(-1, 28 * 28).to(device)
- X : MNIST 이미지
- Y : Label = 0 ~ 9 숫자
- X.view(-1, 28 * 28 ) -> 784
batch_size, 1, 28, 28 -> batch_size, 784
- Epoch : 트레이닝 셋 전체가 학습에 사용하면 1 epoch
MNIST 이미지 60,000 장이 전체가 학습에 사용이 되면 1 epoch 이 돌았다고 표현함 - Batch Size : 대용량 데이터를 한 번에 학습시키지 못하기 때문에 나눠서 학습 = 자르는 크기
60,000 장 이미지의 100 batch size라면 600개의 batch 를 얻음 - Iterations : batch 를 몇 번 학습에 사용했는지
1,000 개 트레이닝 셋, 500 batch size = 2 개의 batch = 2 iterations to complete 1 epoch
- Softmax (Classifier)
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
# MNIST data image of shape 28 * 28 = 784
linear = torch.nn.Linear(784, 10, bias=True).to(device)
#parameters
training_epochs = 15
batch_size = 100
# define cost/Loss & optimizer
criterion = torch.nn.CrossEntropyLoss().to(device) # Softmax is iternally computed
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)
# 학습 코드
for epoch in range(training_epochs):
avg_cost = 0
total_batch = len(data_loader)
for X, Y in data_loader:
# reshape input image into [batch_size by 784]
# label is not one-hot encoded
X = X.view(-1, 28 * 28).to(device)
optimizer.zero_grad()
hypothesis = linear(X)
cost = criterion(hypothesis, Y)
cost.backward()
optimizer.step()
avg_cost += cost / total_batch
print("Epoch: ", "%04d" % (epoch+1), "cost =", "{:.9f}".format(avg_cost))
- Linear input = 784
Linear output = 10 : 0 - 9 label 을 가지기 때문 - PyTorch 는 cross entropy loss 가 자동으로 softmax 계산 -> softmax 별도 선언 X
- linear.parameters() = weight, bias
- 15 번 동안 MNIST 전체를 가지고 반복해서 학습
- data_loader : batch_size 만큼 MNIST 데이터와 label 불러옴 = X, Y
- criterion = cross entropy 를 계산해서 cost 를 구함
- backward 로 gradient 계산 -> 업데이트 (step)
- epoch 을 다 돌고나면 MNIST Classifier 를 직접 학습시킨 것
참고 : https://wikidocs.net/60324
05-05 소프트맥스 회귀로 MNIST 데이터 분류하기
이번 챕터에서는 MNIST 데이터에 대해서 이해하고, 파이토치(PyTorch)로 소프트맥스 회귀를 구현하여 MNIST 데이터를 분류하는 실습을 진행해봅시다. MNIST 데이…
wikidocs.net
- Test
# Test model using test sets
with torch.no_grad():
X_test = mnist_test.test_data.view(-1, 28 * 28).float().to(device)
Y_test = mnist_test.test_labels.to(device)
prediction = linear(X_test)
correct_prediction = torch.argmax(prediction, 1) == Y_test
accuracy = correct_prediction.float().mean()
print("Accuracy: ", accuracy.item())
- no_grad : 아래 코드 범위 안에서는 gradient 계산 안 함
-> 테스트할 때 사용하면 실수 방지 할 수 있음 - argmax : 예측된 결과의 label 구함 -> 실제 label 과 예측값 비교
- Visualization
- 실제 테스트 이미지를 눈으로 확인
import matplotlib.pyplot as plt
import random
r = random.randint(0, len(mnist_test) - 1)
X_single_data = mnist_test.test_data[r:r + 1].view(-1, 28 * 28).float().to(device)
Y_single_data = mnist_test.test_labels[r:r + 1].to(device)
print("Label: ", Y_single_data.item())
single_prediction = linear(X_single_data)
print("Prediction: ", torch.argmax(single_prediction, 1).item())
plt.imshow(mnist_test.test_data[r:r + 1].view(28, 28), cmap="Greys", interpolation="nearest")
plt.show()
