Keras for R, MNIST 간단 예제
MNIST
데이터 준비
# Load keras and fetch the MNIST digits data set.
library(keras)
mnist <- dataset_mnist()

# Pull the image arrays (x) and label vectors (y) out of the train/test lists.
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

# Number of target classes used for one-hot encoding the labels.
NUM_CATEGORY <- 10
Reshape
# Collapse every training image into a single row of 784 values,
# keeping the first (image) dimension as the row index.
x_train <- array_reshape(x_train, c(nrow(x_train), 784))
원래 데이터는 (images,width,height)의 3차원 데이터이다. images 인덱스만 남기고 width와 height를 flatten시켜 이미지마다 길이 784(=28×28)인 벡터로 만들었다.
# Collapse every test image into a single row of 784 values, the same way
# the training images were reshaped.
x_test <- array_reshape(x_test, c(nrow(x_test), 784))

# Rescale the grayscale pixel intensities from 0-255 to 0-1.
# Feeding raw 0-255 values into the dense network makes training unstable
# and hurts accuracy, so both arrays are normalised here, in one place, to
# keep train and test inputs on the same scale.
# NOTE: run this cell only once per session; re-running divides by 255 again.
x_train <- x_train / 255
x_test <- x_test / 255
마찬가지다.
Categorical
# One-hot encode the integer digit labels into NUM_CATEGORY-column matrices.
y_train <- to_categorical(y_train, NUM_CATEGORY)
y_test <- to_categorical(y_test, NUM_CATEGORY)
Model
# Start from an empty sequential model; layers are added in the next step.
model <- keras_model_sequential()
레이어를 쌓자.
# Stack the layers onto the sequential model:
# two ReLU hidden layers, each followed by dropout, then a softmax output.
model %>%
  layer_dense(units = 256, activation = "relu", input_shape = c(784)) %>%
  layer_dropout(rate = 0.4) %>%
  layer_dense(units = 128, activation = "relu") %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 10, activation = "softmax")
모델요약
# Print the layer-by-layer summary (output shapes and parameter counts).
model %>% summary()
____________________________________________________________________________________________________
Layer (type) Output Shape Param #
====================================================================================================
dense_3 (Dense) (None, 256) 200960
____________________________________________________________________________________________________
dropout_2 (Dropout) (None, 256) 0
____________________________________________________________________________________________________
dense_4 (Dense) (None, 128) 32896
____________________________________________________________________________________________________
dropout_3 (Dropout) (None, 128) 0
____________________________________________________________________________________________________
dense_5 (Dense) (None, 10) 1290
====================================================================================================
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
____________________________________________________________________________________________________
Compile
# Configure training: categorical cross-entropy loss for the one-hot labels,
# the RMSprop optimizer with its default settings, and accuracy as the metric.
compile(
  model,
  loss = "categorical_crossentropy",
  optimizer = optimizer_rmsprop(),
  metrics = "accuracy"
)
Training
# Train for 30 epochs in mini-batches of 128, holding out 20% of the
# training data for validation; the per-epoch metrics are kept in `history`.
history <- fit(
  model,
  x = x_train,
  y = y_train,
  epochs = 30,
  batch_size = 128,
  validation_split = 0.2
)
더 학습시키자.
# Continue training the same model for 2 more epochs with the same batch
# size and validation split. Note that `history` is overwritten and now
# holds only the metrics of this additional run.
history <- fit(
  model,
  x = x_train,
  y = y_train,
  epochs = 2,
  batch_size = 128,
  validation_split = 0.2
)
Evaluation
# Compute the loss and accuracy on the held-out test set, without the
# progress output.
evaluate(model, x_test, y_test, verbose = FALSE)
$loss
[1] 2.377416
$acc
[1] 0.8518
Prediction
# Predict the digit class of every test image.
# predict_classes() is deprecated and has been removed from recent
# keras/TensorFlow releases, so take the class probabilities from predict()
# and pick the most probable column per row. which.max() is 1-based, hence
# the "- 1L" to get back to digit labels 0-9.
(model %>% predict(x_test) %>% apply(1, which.max)) - 1L
[1] 7 2 1 0 4 1 4 4 6 7 0 6 7 0 1 5 7 7 3 4 7 6 6 5 4 0 7 4 0 1 3 1 3 6 7 2 7 1 2 1 1 7 4 2 3 5 1
[48] 2 4 4 6 3 5 5 6 0 4 1 7 5 7 8 4 3 7 4 6 4 3 0 7 0 2 7 1 7 3 7 8 7 7 6 2 7 8 4 7 3 6 1 3 6 4 3
[95] 1 4 1 7 6 7 6 0 5 4 5 4 2 1 4 4 8 7 3 4 7 4 4 4 4 8 5 4 7 6 7 4 0 5 8 5 6 6 5 7 8 1 0 1 6 4 6
[142] 7 3 1 7 1 8 2 0 4 4 8 5 5 1 5 6 0 3 4 4 6 5 4 6 5 4 5 1 4 4 7 2 3 2 7 1 8 1 8 1 8 5 0 8 4 2 5
[189] 0 1 1 1 0 4 0 3 1 6 4 2 3 6 1 1 1 3 4 5 2 7 4 5 7 3 4 0 3 6 5 5 7 2 2 7 1 2 8 4 1 7 3 3 8 7 7
[236] 7 2 2 4 1 5 8 8 7 2 3 0 6 4 2 4 1 4 5 7 7 2 8 2 0 8 5 7 7 8 1 8 1 8 0 3 0 1 7 7 4 1 8 2 1 2 4
[283] 7 5 4 2 6 4 1 5 8 2 7 2 0 4 0 0 2 8 4 7 1 2 4 0 2 7 4 3 3 0 0 5 1 7 6 5 2 5 7 7 7 3 0 4 2 0 7
[330] 1 1 2 1 5 3 3 4 7 8 6 3 4 1 3 8 1 0 5 1 7 1 5 0 6 1 8 5 1 7 4 4 6 7 2 5 0 6 5 6 3 7 2 0 8 8 5
[377] 4 1 1 4 0 7 3 7 6 1 6 2 1 7 2 8 6 1 7 5 2 5 4 4 2 8 3 8 2 4 5 0 3 1 7 7 3 7 4 7 1 7 2 1 4 2 4
[424] 2 0 4 4 1 4 8 1 8 4 5 7 7 8 3 7 6 0 0 3 0 8 0 6 4 8 5 3 3 2 3 7 1 2 6 8 0 5 6 6 6 7 8 8 2 7 5
[471] 8 7 6 1 8 4 1 2 5 8 1 4 7 5 4 0 8 4 7 1 0 5 2 3 7 0 4 4 0 6 3 7 3 2 1 3 1 5 6 5 7 8 2 2 6 8 2
[518] 6 5 4 8 4 7 1 3 0 3 8 3 1 4 6 4 4 6 4 2 1 8 2 5 4 8 8 4 0 0 2 3 2 7 7 0 6 7 4 4 7 4 6 4 0 4 8
[565] 0 4 6 0 6 5 5 4 8 3 3 4 3 3 8 7 8 0 2 2 1 7 0 6 5 4 3 3 0 7 6 3 8 0 4 4 6 8 6 8 5 7 3 6 0 2 4
[612] 0 2 8 3 1 7 7 5 1 0 8 4 6 2 6 7 4 3 2 4 8 2 2 7 2 7 3 5 4 1 8 0 2 0 5 2 1 3 7 6 7 1 2 5 8 0 3
[659] 7 7 4 0 4 1 8 6 7 7 4 3 4 7 1 4 5 1 7 3 4 7 6 7 1 3 2 8 3 3 6 7 2 4 5 8 5 1 1 4 4 3 1 0 7 7 0
[706] 7 7 4 4 8 5 5 4 0 8 2 1 6 8 4 8 0 4 0 6 1 7 3 8 6 7 2 6 8 3 1 4 6 2 5 4 8 0 6 2 1 7 3 4 1 0 5
[753] 4 3 1 1 7 4 4 8 4 8 4 0 2 4 5 1 1 6 4 7 1 4 4 2 4 1 5 5 3 8 3 1 4 5 6 8 4 4 1 5 3 8 0 3 2 5 1
[800] 2 8 3 4 4 0 8 8 3 3 1 7 3 5 7 6 3 2 6 1 3 6 0 7 2 1 7 1 4 2 4 2 1 7 7 6 1 1 2 4 3 1 7 7 4 7 0
[847] 7 3 1 3 1 0 7 7 0 3 5 5 2 7 6 6 4 2 8 3 5 2 2 5 6 0 8 2 4 2 8 6 8 8 7 4 7 3 0 6 6 3 2 1 5 2 2
[894] 7 3 0 0 5 7 8 3 4 4 6 0 2 7 1 4 7 4 7 3 4 8 8 4 7 1 2 1 2 2 3 7 3 2 3 4 1 7 4 0 3 5 5 8 6 5 0
[941] 6 7 6 6 8 2 7 4 1 1 2 4 6 4 4 5 2 3 3 8 7 8 7 1 1 0 7 1 4 4 5 4 0 6 2 2 3 1 5 1 2 0 8 8 1 2 6
[988] 7 1 6 2 3 4 0 1 2 2 0 8 7
[ reached getOption("max.print") -- omitted 9000 entries ]
Confusion matrix를 도출하자.
# Build the confusion matrix of true vs. predicted digits on the test set.

# Digit labels 0..(k - 1), where k is the number of one-hot columns in y_test.
class_labels <- seq_len(ncol(y_test)) - 1L

# predict_classes() is deprecated and has been removed from recent
# keras/TensorFlow releases; take the row-wise argmax of the predicted
# probabilities instead (which.max() is 1-based, hence "- 1L").
predicted <- (model %>% predict(x_test) %>% apply(1, which.max)) - 1L

# Recover the integer labels from the one-hot encoded test labels.
original <- apply(y_test, 1, which.max) - 1L

# Use factors with the full label set so that a class which is never
# predicted still shows up as an (all-zero) column and the table stays square.
table(
  original = factor(original, levels = class_labels),
  predicted = factor(predicted, levels = class_labels)
)
predicted
original 0 1 2 3 4 5 6 7 8
0 960 0 2 1 0 5 4 4 4
1 0 1111 4 3 0 0 4 2 11
2 12 0 937 8 14 2 13 19 27
3 4 0 17 917 1 25 1 20 25
4 1 0 2 1 958 0 8 5 7
5 9 1 2 18 5 826 16 5 10
6 16 3 0 1 8 6 920 1 3
7 1 5 22 1 5 1 0 990 3
8 10 1 6 12 8 11 9 18 899
9 13 6 0 13 535 17 1 377 47
LS0tCnRpdGxlOiAiTU5JU1QiCm91dHB1dDogaHRtbF9ub3RlYm9vawotLS0KCiMg642w7J207YSwIOykgOu5hAoKYGBge3J9CmxpYnJhcnkoa2VyYXMpCm1uaXN0PC1kYXRhc2V0X21uaXN0KCkKYGBgCgpgYGB7cn0KYyh4X3RyYWluLHlfdHJhaW4pICU8LSUgbW5pc3QkdHJhaW4KYyh4X3Rlc3QseV90ZXN0KSAlPC0lIG1uaXN0JHRlc3QKYGBgCmBgYHtyfQpOVU1fQ0FURUdPUlk9MTAKYGBgCgoKIyMgUmVzaGFwZQoKYGBge3J9CnhfdHJhaW48LXhfdHJhaW4lPiVhcnJheV9yZXNoYXBlKGMobnJvdyh4X3RyYWluKSw3ODQpKQpgYGAK7JuQ656YIOuNsOydtO2EsOuKlCAoaW1hZ2VzLHdpZHRoLGhlaWdodCnsnZggM+ywqOybkCDrjbDsnbTthLDsnbTri6QuIGltYWdlcyDsnbjrjbHsiqTrp4wg64Ko6riw6rOgIHdpZHRo7JmAIGhlaWdodOulvCBmbGF0dGVybuyLnOy8sOuLpC4KYGBge3J9CnhfdGVzdDwteF90ZXN0JT4lYXJyYXlfcmVzaGFwZShjKG5yb3coeF90ZXN0KSw3ODQpKQpgYGAK66eI7LCs6rCA7KeA64ukLgoKIyMgQ2F0ZWdvcmljYWwKYGBge3J9CnlfdHJhaW48LXlfdHJhaW4lPiUKICB0b19jYXRlZ29yaWNhbChOVU1fQ0FURUdPUlkpCnlfdGVzdDwteV90ZXN0JT4lCiAgdG9fY2F0ZWdvcmljYWwoTlVNX0NBVEVHT1JZKQpgYGAKCiMgTW9kZWwKCmBgYHtyfQptb2RlbDwta2VyYXNfbW9kZWxfc2VxdWVudGlhbCgpCmBgYArroIjsnbTslrTrpbwg7IyT7J6QLgpgYGB7cn0KbW9kZWwlPiUKICBsYXllcl9kZW5zZSgyNTYsCiAgICAgICAgICAgICAgYWN0aXZhdGlvbj0ncmVsdScsCiAgICAgICAgICAgICAgaW5wdXRfc2hhcGU9Yyg3ODQpKSU+JQogIGxheWVyX2Ryb3BvdXQoMC40KSU+JQogIGxheWVyX2RlbnNlKDEyOCwKICAgICAgICAgICAgICBhY3RpdmF0aW9uPSdyZWx1JyklPiUKICBsYXllcl9kcm9wb3V0KDAuMyklPiUKICBsYXllcl9kZW5zZSgxMCwKICAgICAgICAgICAgICBhY3RpdmF0aW9uPSdzb2Z0bWF4JykKYGBgCgrrqqjrjbjsmpTslb0KYGBge3J9CnN1bW1hcnkobW9kZWwpCmBgYAojIyBDb21waWxlCgpgYGB7cn0KbW9kZWwlPiVjb21waWxlKAogIGxvc3M9J2NhdGVnb3JpY2FsX2Nyb3NzZW50cm9weScsCiAgb3B0aW1pemVyPW9wdGltaXplcl9ybXNwcm9wKCksCiAgbWV0cmljcz1jKCdhY2N1cmFjeScpCikKYGBgCgojIFRyYWluaW5nCgpgYGB7cn0KaGlzdG9yeTwtbW9kZWwlPiUKICBmaXQoCiAgICB4X3RyYWluLHlfdHJhaW4sCiAgICBlcG9jaHM9MzAsCiAgICBiYXRjaF9zaXplPTEyOCwKICAgIHZhbGlkYXRpb25fc3BsaXQ9MC4yCiAgKQpgYGAKCuuNlCDtlZnsirXsi5ztgqTsnpAuCgpgYGB7cn0KaGlzdG9yeTwtbW9kZWwlPiUKICBmaXQoCiAgICB4X3RyYWluLHlfdHJhaW4sCiAgICBlcG9jaHM9MiwKICAgIGJhdGNoX3NpemU9MTI4LAogICAgdmFsaWRhdGlvbl9zcGxpdD0wLjIKICApCmBgYAoKIyBFdmFsdWF0aW9uCgpgYGB7cn0KbW9kZWwlPiVldmFsdWF0ZSh4X3Rlc3QseV90ZXN0LAogICAgICAg
ICAgICAgICAgIHZlcmJvc2U9RkFMU0UpCmBgYAoKIyBQcmVkaWN0aW9uCmBgYHtyfQptb2RlbCU+JQogIHByZWRpY3RfY2xhc3Nlcyh4X3Rlc3QpCmBgYAoKQ29uZnVzaW9uIG1hdHJpeOulvCDrj4TstpztlZjsnpAuCgpgYGB7cn0KcHJlZGljdGVkPW1vZGVsJT4lCiAgcHJlZGljdF9jbGFzc2VzKHhfdGVzdCkKb3JpZ2luYWw9eV90ZXN0JT4lYXBwbHkoMSx3aGljaC5tYXgpLTEKdGFibGUob3JpZ2luYWwscHJlZGljdGVkKQpgYGAKCg==
댓글
댓글 쓰기