Keras for R, MNIST 간단 예제

MNIST

데이터 준비

# Load the MNIST digits bundled with keras and unpack the train/test splits.
library(keras)
mnist <- dataset_mnist()
c(x_train, y_train) %<-% mnist$train
c(x_test, y_test) %<-% mnist$test

# MNIST pixel intensities range over 0..255. Rescale them to [0, 1] before
# training: feeding raw 0..255 values into the relu/softmax network below makes
# optimisation unstable, which is the likely reason the unscaled run shown in
# this post ends with a test loss of ~2.38 and never predicts the digit 9.
x_train <- x_train / 255
x_test <- x_test / 255

# Number of target classes (the digits 0-9).
NUM_CATEGORY <- 10

Reshape

# Flatten every 28x28 training image into one 784-element row
# (array_reshape() uses row-major order, as Keras expects).
x_train <- array_reshape(x_train, c(dim(x_train)[1], 28 * 28))

원래 데이터는 (images, 28, 28) 형태의 3차원 배열이다. images 인덱스 차원만 남기고 28×28 픽셀(width와 height)을 길이 784의 벡터로 펼쳤다(flatten).

# Flatten every 28x28 test image into one 784-element row, exactly as done
# for the training images.
x_test <- array_reshape(x_test, c(dim(x_test)[1], 28 * 28))

테스트 데이터도 같은 방식으로 펼친다.

Categorical

# One-hot encode the digit labels into NUM_CATEGORY-column indicator matrices,
# the target format expected by the categorical_crossentropy loss used below.
y_train <- to_categorical(y_train, num_classes = NUM_CATEGORY)
y_test <- to_categorical(y_test, num_classes = NUM_CATEGORY)

Model

# Start from an empty sequential model; layers are appended with layer_*().
model <- keras::keras_model_sequential()

레이어를 쌓자.

# Stack a fully connected 784 -> 256 -> 128 -> NUM_CATEGORY network.
# The layer_*() calls modify `model` in place (summary(model) afterwards shows
# the layers), so the pipe result does not need to be reassigned.
# Dropout zeroes 40% / 30% of hidden activations during training.
model %>%
  layer_dense(units = 256, activation = "relu", input_shape = c(784)) %>%
  layer_dropout(rate = 0.4) %>%
  layer_dense(units = 128, activation = "relu") %>%
  layer_dropout(rate = 0.3) %>%
  # One softmax unit per class, kept in sync with the one-hot label width.
  layer_dense(units = NUM_CATEGORY, activation = "softmax")

모델 요약

# Print the layer-by-layer output shapes and parameter counts.
model %>% summary()
____________________________________________________________________________________________________
Layer (type)                                 Output Shape                            Param #        
====================================================================================================
dense_3 (Dense)                              (None, 256)                             200960         
____________________________________________________________________________________________________
dropout_2 (Dropout)                          (None, 256)                             0              
____________________________________________________________________________________________________
dense_4 (Dense)                              (None, 128)                             32896          
____________________________________________________________________________________________________
dropout_3 (Dropout)                          (None, 128)                             0              
____________________________________________________________________________________________________
dense_5 (Dense)                              (None, 10)                              1290           
====================================================================================================
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
____________________________________________________________________________________________________

Compile

# Configure training: RMSprop optimiser, categorical cross-entropy (matching
# the one-hot labels from to_categorical()), and accuracy as a reported metric.
model %>% compile(
  optimizer = optimizer_rmsprop(),
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)

Training

# Train for 30 epochs with mini-batches of 128 samples, holding out 20% of the
# training data for validation; per-epoch metrics are stored in `history`.
history <- fit(
  model,
  x = x_train,
  y = y_train,
  batch_size = 128,
  epochs = 30,
  validation_split = 0.2
)

같은 설정으로 2 epoch 더 학습시키자.

# Continue training the same model for 2 more epochs with identical settings.
# Note: this reassigns `history`, so only these 2 epochs are kept in it.
history <- fit(
  model,
  x = x_train,
  y = y_train,
  batch_size = 128,
  epochs = 2,
  validation_split = 0.2
)

Evaluation

# Score the trained model on the test set; returns loss and accuracy.
evaluate(model, x_test, y_test, verbose = FALSE)
$loss
[1] 2.377416

$acc
[1] 0.8518

Prediction

# Predicted digit for each test image: the column of the highest softmax
# probability, shifted from R's 1-based column index to the 0-9 digit label.
# predict_classes() was deprecated and removed with TensorFlow 2.6, so the
# argmax is taken explicitly; max.col(ties.method = "first") is a vectorised
# row-wise which.max().
max.col(predict(model, x_test), ties.method = "first") - 1
   [1] 7 2 1 0 4 1 4 4 6 7 0 6 7 0 1 5 7 7 3 4 7 6 6 5 4 0 7 4 0 1 3 1 3 6 7 2 7 1 2 1 1 7 4 2 3 5 1
  [48] 2 4 4 6 3 5 5 6 0 4 1 7 5 7 8 4 3 7 4 6 4 3 0 7 0 2 7 1 7 3 7 8 7 7 6 2 7 8 4 7 3 6 1 3 6 4 3
  [95] 1 4 1 7 6 7 6 0 5 4 5 4 2 1 4 4 8 7 3 4 7 4 4 4 4 8 5 4 7 6 7 4 0 5 8 5 6 6 5 7 8 1 0 1 6 4 6
 [142] 7 3 1 7 1 8 2 0 4 4 8 5 5 1 5 6 0 3 4 4 6 5 4 6 5 4 5 1 4 4 7 2 3 2 7 1 8 1 8 1 8 5 0 8 4 2 5
 [189] 0 1 1 1 0 4 0 3 1 6 4 2 3 6 1 1 1 3 4 5 2 7 4 5 7 3 4 0 3 6 5 5 7 2 2 7 1 2 8 4 1 7 3 3 8 7 7
 [236] 7 2 2 4 1 5 8 8 7 2 3 0 6 4 2 4 1 4 5 7 7 2 8 2 0 8 5 7 7 8 1 8 1 8 0 3 0 1 7 7 4 1 8 2 1 2 4
 [283] 7 5 4 2 6 4 1 5 8 2 7 2 0 4 0 0 2 8 4 7 1 2 4 0 2 7 4 3 3 0 0 5 1 7 6 5 2 5 7 7 7 3 0 4 2 0 7
 [330] 1 1 2 1 5 3 3 4 7 8 6 3 4 1 3 8 1 0 5 1 7 1 5 0 6 1 8 5 1 7 4 4 6 7 2 5 0 6 5 6 3 7 2 0 8 8 5
 [377] 4 1 1 4 0 7 3 7 6 1 6 2 1 7 2 8 6 1 7 5 2 5 4 4 2 8 3 8 2 4 5 0 3 1 7 7 3 7 4 7 1 7 2 1 4 2 4
 [424] 2 0 4 4 1 4 8 1 8 4 5 7 7 8 3 7 6 0 0 3 0 8 0 6 4 8 5 3 3 2 3 7 1 2 6 8 0 5 6 6 6 7 8 8 2 7 5
 [471] 8 7 6 1 8 4 1 2 5 8 1 4 7 5 4 0 8 4 7 1 0 5 2 3 7 0 4 4 0 6 3 7 3 2 1 3 1 5 6 5 7 8 2 2 6 8 2
 [518] 6 5 4 8 4 7 1 3 0 3 8 3 1 4 6 4 4 6 4 2 1 8 2 5 4 8 8 4 0 0 2 3 2 7 7 0 6 7 4 4 7 4 6 4 0 4 8
 [565] 0 4 6 0 6 5 5 4 8 3 3 4 3 3 8 7 8 0 2 2 1 7 0 6 5 4 3 3 0 7 6 3 8 0 4 4 6 8 6 8 5 7 3 6 0 2 4
 [612] 0 2 8 3 1 7 7 5 1 0 8 4 6 2 6 7 4 3 2 4 8 2 2 7 2 7 3 5 4 1 8 0 2 0 5 2 1 3 7 6 7 1 2 5 8 0 3
 [659] 7 7 4 0 4 1 8 6 7 7 4 3 4 7 1 4 5 1 7 3 4 7 6 7 1 3 2 8 3 3 6 7 2 4 5 8 5 1 1 4 4 3 1 0 7 7 0
 [706] 7 7 4 4 8 5 5 4 0 8 2 1 6 8 4 8 0 4 0 6 1 7 3 8 6 7 2 6 8 3 1 4 6 2 5 4 8 0 6 2 1 7 3 4 1 0 5
 [753] 4 3 1 1 7 4 4 8 4 8 4 0 2 4 5 1 1 6 4 7 1 4 4 2 4 1 5 5 3 8 3 1 4 5 6 8 4 4 1 5 3 8 0 3 2 5 1
 [800] 2 8 3 4 4 0 8 8 3 3 1 7 3 5 7 6 3 2 6 1 3 6 0 7 2 1 7 1 4 2 4 2 1 7 7 6 1 1 2 4 3 1 7 7 4 7 0
 [847] 7 3 1 3 1 0 7 7 0 3 5 5 2 7 6 6 4 2 8 3 5 2 2 5 6 0 8 2 4 2 8 6 8 8 7 4 7 3 0 6 6 3 2 1 5 2 2
 [894] 7 3 0 0 5 7 8 3 4 4 6 0 2 7 1 4 7 4 7 3 4 8 8 4 7 1 2 1 2 2 3 7 3 2 3 4 1 7 4 0 3 5 5 8 6 5 0
 [941] 6 7 6 6 8 2 7 4 1 1 2 4 6 4 4 5 2 3 3 8 7 8 7 1 1 0 7 1 4 4 5 4 0 6 2 2 3 1 5 1 2 0 8 8 1 2 6
 [988] 7 1 6 2 3 4 0 1 2 2 0 8 7
 [ reached getOption("max.print") -- omitted 9000 entries ]

Confusion matrix를 도출하자.

# Confusion matrix on the test set: rows are true digits, columns are the
# model's predicted digits.
digit_levels <- seq_len(NUM_CATEGORY) - 1

# Turn softmax outputs / one-hot rows back into 0-based digit labels.
# predict_classes() was removed with TensorFlow 2.6, hence the explicit argmax;
# max.col(ties.method = "first") is a vectorised row-wise which.max().
predicted <- max.col(predict(model, x_test), ties.method = "first") - 1
original <- max.col(y_test, ties.method = "first") - 1

# Use factors with every digit as a level so a class that is never predicted
# still shows up as an all-zero column; a plain table() silently drops it
# (digit 9 is missing from the predicted columns in the output below).
table(
  original = factor(original, levels = digit_levels),
  predicted = factor(predicted, levels = digit_levels)
)
        predicted
original    0    1    2    3    4    5    6    7    8
       0  960    0    2    1    0    5    4    4    4
       1    0 1111    4    3    0    0    4    2   11
       2   12    0  937    8   14    2   13   19   27
       3    4    0   17  917    1   25    1   20   25
       4    1    0    2    1  958    0    8    5    7
       5    9    1    2   18    5  826   16    5   10
       6   16    3    0    1    8    6  920    1    3
       7    1    5   22    1    5    1    0  990    3
       8   10    1    6   12    8   11    9   18  899
       9   13    6    0   13  535   17    1  377   47
LS0tCnRpdGxlOiAiTU5JU1QiCm91dHB1dDogaHRtbF9ub3RlYm9vawotLS0KCiMg642w7J207YSwIOykgOu5hAoKYGBge3J9CmxpYnJhcnkoa2VyYXMpCm1uaXN0PC1kYXRhc2V0X21uaXN0KCkKYGBgCgpgYGB7cn0KYyh4X3RyYWluLHlfdHJhaW4pICU8LSUgbW5pc3QkdHJhaW4KYyh4X3Rlc3QseV90ZXN0KSAlPC0lIG1uaXN0JHRlc3QKYGBgCmBgYHtyfQpOVU1fQ0FURUdPUlk9MTAKYGBgCgoKIyMgUmVzaGFwZQoKYGBge3J9CnhfdHJhaW48LXhfdHJhaW4lPiVhcnJheV9yZXNoYXBlKGMobnJvdyh4X3RyYWluKSw3ODQpKQpgYGAK7JuQ656YIOuNsOydtO2EsOuKlCAoaW1hZ2VzLHdpZHRoLGhlaWdodCnsnZggM+ywqOybkCDrjbDsnbTthLDsnbTri6QuIGltYWdlcyDsnbjrjbHsiqTrp4wg64Ko6riw6rOgIHdpZHRo7JmAIGhlaWdodOulvCBmbGF0dGVybuyLnOy8sOuLpC4KYGBge3J9CnhfdGVzdDwteF90ZXN0JT4lYXJyYXlfcmVzaGFwZShjKG5yb3coeF90ZXN0KSw3ODQpKQpgYGAK66eI7LCs6rCA7KeA64ukLgoKIyMgQ2F0ZWdvcmljYWwKYGBge3J9CnlfdHJhaW48LXlfdHJhaW4lPiUKICB0b19jYXRlZ29yaWNhbChOVU1fQ0FURUdPUlkpCnlfdGVzdDwteV90ZXN0JT4lCiAgdG9fY2F0ZWdvcmljYWwoTlVNX0NBVEVHT1JZKQpgYGAKCiMgTW9kZWwKCmBgYHtyfQptb2RlbDwta2VyYXNfbW9kZWxfc2VxdWVudGlhbCgpCmBgYArroIjsnbTslrTrpbwg7IyT7J6QLgpgYGB7cn0KbW9kZWwlPiUKICBsYXllcl9kZW5zZSgyNTYsCiAgICAgICAgICAgICAgYWN0aXZhdGlvbj0ncmVsdScsCiAgICAgICAgICAgICAgaW5wdXRfc2hhcGU9Yyg3ODQpKSU+JQogIGxheWVyX2Ryb3BvdXQoMC40KSU+JQogIGxheWVyX2RlbnNlKDEyOCwKICAgICAgICAgICAgICBhY3RpdmF0aW9uPSdyZWx1JyklPiUKICBsYXllcl9kcm9wb3V0KDAuMyklPiUKICBsYXllcl9kZW5zZSgxMCwKICAgICAgICAgICAgICBhY3RpdmF0aW9uPSdzb2Z0bWF4JykKYGBgCgrrqqjrjbjsmpTslb0KYGBge3J9CnN1bW1hcnkobW9kZWwpCmBgYAojIyBDb21waWxlCgpgYGB7cn0KbW9kZWwlPiVjb21waWxlKAogIGxvc3M9J2NhdGVnb3JpY2FsX2Nyb3NzZW50cm9weScsCiAgb3B0aW1pemVyPW9wdGltaXplcl9ybXNwcm9wKCksCiAgbWV0cmljcz1jKCdhY2N1cmFjeScpCikKYGBgCgojIFRyYWluaW5nCgpgYGB7cn0KaGlzdG9yeTwtbW9kZWwlPiUKICBmaXQoCiAgICB4X3RyYWluLHlfdHJhaW4sCiAgICBlcG9jaHM9MzAsCiAgICBiYXRjaF9zaXplPTEyOCwKICAgIHZhbGlkYXRpb25fc3BsaXQ9MC4yCiAgKQpgYGAKCuuNlCDtlZnsirXsi5ztgqTsnpAuCgpgYGB7cn0KaGlzdG9yeTwtbW9kZWwlPiUKICBmaXQoCiAgICB4X3RyYWluLHlfdHJhaW4sCiAgICBlcG9jaHM9MiwKICAgIGJhdGNoX3NpemU9MTI4LAogICAgdmFsaWRhdGlvbl9zcGxpdD0wLjIKICApCmBgYAoKIyBFdmFsdWF0aW9uCgpgYGB7cn0KbW9kZWwlPiVldmFsdWF0ZSh4X3Rlc3QseV90ZXN0LAogICAgICAg
ICAgICAgICAgIHZlcmJvc2U9RkFMU0UpCmBgYAoKIyBQcmVkaWN0aW9uCmBgYHtyfQptb2RlbCU+JQogIHByZWRpY3RfY2xhc3Nlcyh4X3Rlc3QpCmBgYAoKQ29uZnVzaW9uIG1hdHJpeOulvCDrj4TstpztlZjsnpAuCgpgYGB7cn0KcHJlZGljdGVkPW1vZGVsJT4lCiAgcHJlZGljdF9jbGFzc2VzKHhfdGVzdCkKb3JpZ2luYWw9eV90ZXN0JT4lYXBwbHkoMSx3aGljaC5tYXgpLTEKdGFibGUob3JpZ2luYWwscHJlZGljdGVkKQpgYGAKCg==

댓글

이 블로그의 인기 게시물

Bradley-Terry Model: paired comparison models

xlwings tutorial - 데이터 계산하여 붙여 넣기

R에서 csv 파일 읽는 법