now is better than never

2. [지도학습] 나이브 베이즈 분류기 (Naive Bayes Classifier) 본문

머신러닝 & 딥러닝

2. [지도학습] 나이브 베이즈 분류기 (Naive Bayes Classifier)

김초송 2023. 3. 16. 16:24

데이터 분석 방법

  1. 통계 데이터 분석 : 가설 검정 
    • 귀무가설 : 코로나가 매출에 영향을 미치지 않았다
    • 대립가설 : 코로나(독립변수) 가 매출(종속변수) 에 영향을 미쳤다
  2.  머신러닝 데이터 분석
    • 빅데이터를 학습시켜서 기존 데이터로 새로운 데이터를 분류하거나 미래의 데이터를 예측하는 것

 

나이브 베이즈 알고리즘

  • 확률로 분류하는 머신러닝 방법 (50 % 이상이면 positive, 아니면 negative)
  • 예시
    • 스팸 메일과 정상 메일을 분류
    • 컴퓨터 네트워크에서 발견된 침입이나 비정상적인 행위 탐지
    • 일련의 관찰된 증상에 대한 의학적 질병 진단(분류)
  • 종류
    1.  BernoulliNB : 이산형 데이터 분류
    2. GaussianNB : 연속형 데이터 분류
    3. MultinomialNB : 이산형 데이터 분류

 

- 공식

출처 : https://www.youtube.com/watch?v=NyGYgD7vgCk&list=LL&index=7

  1. 사전확률:  이미 알고 있는 확률 
  2. 우도 : 이미 알고 있는 사건이 발생했다는 조건하에 다른 사건이 발생할 확률 
  3. 사후확률 :  사전확률과 우도확률을 통해서 알게되는 조건부 확률 

 

- R로 구현하기

  1. 데이터를 로드합니다.
  2. 데이터를 확인합니다.
  3. 결측치를 확인합니다.
  4. 훈련 데이터와 테스트 데이터를 분리합니다.
  5. 모델 훈련
  6. 모델 예측
  7. 모델 평가 
  8. 모델 성능개선
#1. 데이터를 로드합니다.
package_version(R.version)

mush <-  read.csv("/content/mushrooms.csv", stringsAsFactors=TRUE)
mush
  • UCI 머신러닝 저장소에서 제공하는 데이터
    23 종의 버섯과 8124개 샘플
  • 버섯의 22개의 특징은 갓모양, 갓색깔, 냄새, 주름크기, 주름색, 줄기모양, 서식지와 같은 특징
#2.  데이터를 확인합니다.
str(mush)  # 전부 factor 형
nrow(mush) # 8124의 행으로 구성 
ncol(mush)  # 23개의 컬럼으로 구성 

prop.table(table(mush$type))

#3.  결측치를 확인합니다.
colSums( is.na(mush) ) 

#4. 훈련 데이터와 테스트 데이터를 분리합니다. ( 훈련 테스트 80%, 테스트 데이터 20%)

.libPaths()
install.packages("caret", lib = .libPaths()[1])
library(caret)
set.seed(1) # 어느 자리에서든 똑같은 기준으로 분리하려고 지정 

k <-  createDataPartition( mush$type, p=0.8, list=F)  # 훈련 데이터 80%, 테스트 20%
k # 80% 에 해당하는 데이터 행의 인덱스 번호 

train_data <- mush[ k , ]
test_data <- mush[ -k, ] 

dim( train_data ) # ( 6500, 23 ) 
dim( test_data )   # ( 1624, 23 )
  • createDataPartition( mush$type, p=0.8, list=F) 
    : list = F -> 행렬로 저장
#5. 모델 훈련
install.packages("e1071", lib = .libPaths()[1])
library(e1071)
model <-  naiveBayes( type ~ . ,  data=train_data ) 

model  #  버섯 데이터에 대한 우도표 확인

#6. 모델 예측
result <-  predict(  model ,  test_data[   , -1 ]  ) # 정답컬럼 빼고 나머지 컬럼으로 예측
result

#7. 모델 평가 
sum( result == test_data[  , 1 ] )   /  length(result )   # 0.937807881773399
  • naiveBayes( 정답컬럼 ~  .   ,  data = 훈련 데이터 프레임명)
    data : 정답 컬럼 외의 다른 모든 컬럼들
# 이원교차표 확인
install.packages("gmodels", lib = .libPaths()[1])
library(gmodels) 

CrossTable(  x=test_data[   , 1 ],  y = result )   #  CrossTable(x=실제값, y=예측값)
  • FN ( 실제로는 독버섯인데 식용으로 예측 ) 이 많이 깨문에 0에 가깝게 만들도록 모델 개선
#8. 모델 성능개선
library( e1071 ) 
model2 <- naiveBayes( type ~  .,  data=train_data,  laplace=0.0001 ) 
result2 <-  predict(  model2,  test_data[   , -1 ] ) 
sum( result2 == test_data[  , 1] )  / length(result2)    # 0.995073891625616

library(gmodels) 
CrossTable(  x=test_data[   , 1 ],  y = result2 )   #  CrossTable(x=실제값, y=예측값) 

options(scipen=999)  # 0.00001 이런식으로 소숫점이 제대로 다 보이게 하려 셋팅
library(e1071)
y <- 0                           
jumpby <- 0.00001

for (i in 1:10) {
    y <- y + jumpby
    model2 <- naiveBayes( type~. , data=train_data, laplace=y)
    result2 <- predict( model2, test_data[  , -1] )
    a <- sum(result2 == test_data[  , 1]) / length(result2) * 100
    print ( paste(y, '일때 정확도 ', a))  
    }

 

- 라플라스 추정기

  • 나이브 베이즈 알고리즘의 성능을 높이는 방법
  • 나플라스 값(= 아주아주 작은 값) 을 추가해서 성능을 높임
  • 하이퍼 파라미터
    • KNN -> k 값
    • naivebayes -> laplace 값
model <- naiveBayes( type ~ ., data=train_data, laplace=0.0001 )

 

- 정상 메일과 스팸 메일을 분류하는 나이브 베이즈 모델 생성

  • 비아그라를 포함하는 메일이 스팸일 확률 -> 정확하게 분류가 안 될 수도 있음
    P( 스팸 | 비아그라 )
  • 비아그라, 돈, 식료품, 구독 취소를 포함하는 메일이 스팸일 확률
    P( 스팸 | 비아그라 ∩  돈 ∩ 식료품 ∩ 구독취소 )
  • 비아그라와 구독 취소를 포함하고 돈과 식료품을 포함하지 않는 메일이 스팸일 확률
    P( 스팸 | 비아그라 ∩ ¬돈  ∩ ¬ 식료품 ∩ 구독취소)
    =  P(비아그라|스팸) * P(¬ 돈|스팸) * P(¬ 식료품 | 스팸) * P(구독취소|스팸) * P(스팸) = 스팸의 전체 우도
  • 여기서 라플라스를 더할때 분자에는 1 씩, 분모에는 독립변수 수만큼 더함

 

나이브 베이즈 요약

  1. 나이브 베이즈 모델에서 하이퍼 파라미터에 해당하는 파라미터 이름은?
    -> 라플라스값 (laplace)

  2. 나이브 베이즈 수학 공식을 기술하시오
    -> P(A|B) = P(B|A) * P(A) / P(B)

  3. 나이브 베이즈 공식에서 사전확률과 사후확률은 무엇인가?
    • 사전확률 : 학습할 데이터로 이미 알고 있는 확률
    • 사후확률 : 사전확률과 우도확률을 통해서 알게 되는 조건부 확률

  4. 나이브 베이즈 모델로 예측할 수 있는 것은? 답 : 1
    1) 분류
    2) 수치예측
    3) 추천 시스템 구현
    4) 시계열 예측 
        
  5. 데이터 정규화의 종류 2가지
    1) 최대최소 정규화
    2) 표준화

 

나중에 참고해보기: https://mole-starseeker.tistory.com/78

본 내용은 아이티윌 '빅데이터&머신러닝 전문가 양성 과정' 을 수강하며 작성한 내용입니다.