머신러닝 & 딥러닝

[딥러닝] 10. CNN

김초송 2023. 5. 10. 15:30

CNN (Convolutional Neural Network)

  • = 합성곱 신경망
  • convolutional 층과 pooling 층을 포함하는 신경망

 

기존 신경망

  • = 완전 연결 계층 (fully connected)
  • 입력층 -> Affine-> Relu -> Affine -> Relu -> softmax
  • 입력층에서 이미지 데이터를 flatten 함
  • (1, 784)
  • 문제점 : 이미지 형상이 무시됨
    글자가 이동하거나 회전, 확대/축소 등 조금이라도 변형되면 새로운 학습 데이터로 처리해야 함

출처: 모두의 연구소

CNN

  • 입력층 -> convolution -> Relu -> pooling -> 완전 연결 계층
  • convolutional layer (합성곱층) 에서 이미지 특징 (feature) 추출
    -> 이를 기반으로 neural network 로 분류
  • convolutional layer = 특징을 추출하는 기능을 하는 필터(Filter)
                                      +
    이 필터의 값을 비선형 값으로 바꾸어 주는 Activiation 함수
  • 한 장의 사진에서 다양한 비슷한 사진을 만들어내는 작업이 일어남 (이미지 증식)
  • (1, 28, 28, 3) = (사진수, 28, 28, 색조(RGB)) -> 4차원

출처: 모두의 연구소
출처 : https://halfundecided.medium.com/딥러닝-머신러닝-cnn-convolutional-neural-networks-쉽게-이해하기-836869f88375

 

# 신경망 모델 구현
model = Sequential()
model.add(Conv2D(100, kernel_size=(5, 5), input_shape=(32, 32, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
# model.add(Conv2D(10, kernel_size=(5, 5),  activation='relu'))
# model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
model.add(Flatten())
model.add(BatchNormalization())
model.add(Dropout(0.2))
model.add(Dense(500, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.2))
model.add(Dense(100, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.2))
model.add(Dense(2, activation='softmax'))

# 5. 모델을 설정
model.compile(optimizer='Adam', 
              loss='categorical_crossentropy', 
              metrics=['acc'])  # 학습과정에서 정확도를 보려고 

#6. 모델을 훈련시킵니다. 

callbacks = [EarlyStopping(monitor='val_acc', patience=20, verbose=1, restore_best_weights=True)] # restore_best_weights : 가장 높은 성능일 때 가중치
history = model.fit(x_train, y_train, 
                    epochs = 200,  
                    batch_size = 100,
                    validation_data=(x_test, y_test),
                    callbacks=callbacks)

왜 CNN 2 + ANN 2 보다 CNN 1 + ANN 3 이 더 성능이 좋은지,,,?ㅠ