Introduction to Image Analysis with Python, Part 3 (2018/05/30)

In this part we perform machine learning with the Python library scikit-learn and the MNIST dataset.


Question

Exercise 3.1: Import the MNIST dataset into Python using the scikit-learn library.

Exercise 3.2: Check the number of samples and the pixel size of the imported data.

Exercise 3.3: Save one or two of the imported images from Python to your PC.

Exercise 3.4: Perform machine learning with scikit-learn. (If the running time is too long, you can reduce the number of training samples.)

Exercise 3.5: Calculate the overall accuracy.
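Exercise 3.4 warns that training on the full dataset can be slow. Before downloading MNIST, you can try the whole pipeline on scikit-learn's bundled 8 x 8 digits dataset (load_digits). This dataset is swapped in here as a small stand-in for MNIST and is not part of the original exercises. The sketch below assumes a recent scikit-learn and NumPy. The 1,500-image training split and the fixed seeds are arbitrary illustrative choices.

```python
# Dry run of Exercises 3.4 and 3.5 on scikit-learn's bundled digits dataset
# (load_digits), used as a small stand-in for MNIST: no download is needed.
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics

digits = load_digits()
data = digits.data      # shape (1797, 64): 1,797 images of 8 x 8 pixels
label = digits.target   # integer labels 0-9

# Random split without replacement: 1,500 images for training, the rest for validation
# (split size and seed are arbitrary choices for this sketch).
rng = np.random.default_rng(0)
order = rng.permutation(len(data))
train_idx, valid_idx = order[:1500], order[1500:]

# Same classifier settings as the MNIST script, plus a fixed random_state for repeatability.
clf = RandomForestClassifier(n_estimators=10, n_jobs=2, random_state=0)
clf.fit(data[train_idx], label[train_idx])
predict = clf.predict(data[valid_idx])

ac_score = metrics.accuracy_score(label[valid_idx], predict)
print("score is " + str(ac_score))
```

Moving to MNIST changes only the loading step and the image size (784 pixels, 28 x 28 per image). The rest of the pipeline stays the same.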


Answer

#! /usr/bin/python3
import numpy as np
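from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics

# Exercise 3.1
## fetch_mldata was removed in scikit-learn 0.22 because mldata.org went offline,
## so the same MNIST data (70,000 images) is loaded from OpenML instead.
mnist = fetch_openml('mnist_784', version=1, data_home=".", as_frame=False)
data = np.array(mnist["data"])
label = np.array(mnist["target"]).astype(np.int64)  # OpenML stores labels as strings such as '5'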

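# Exercise 3.2
print(data.shape)   # (number of images, pixels per image): 70,000 images of 28 x 28 = 784 pixels
print(label.shape)  # one label per image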

# Exercise 3.3
from PIL import Image

## each row of data is a flattened 28 x 28 grayscale image: reshape it and convert to 8-bit
test_1 = Image.fromarray(np.uint8(data[0].reshape(28, 28)))
test_rgb_1 = Image.merge("RGB", (test_1, test_1, test_1))  # copy the gray channel into R, G and B
test_rgb_1.save("test_1.png", 'PNG')

test_30000 = Image.fromarray(np.uint8(data[29999].reshape(28, 28)))
test_rgb_30000 = Image.merge("RGB", (test_30000, test_30000, test_30000))
test_rgb_30000.save("test_30000.png", 'PNG')

# Exercise 3.4
## make training and validation data
## the first 60,000 images are the MNIST training set and the last 10,000 the test set;
## lower training_num (e.g. to 10000) if training takes too long
training_num = 60000
validation_num = 10000

## draw random 0-based indices without replacement from each part
training_data_random = np.random.choice(60000, training_num, replace=False)
validation_data_random = 60000 + np.random.choice(10000, validation_num, replace=False)

training_data = data[training_data_random]
training_label = label[training_data_random]
validation_data = data[validation_data_random]
validation_label = label[validation_data_random]

## train a random forest classifier (10 trees, 2 parallel jobs) and predict the validation labels
## other classifiers are listed by the help(sklearn) command
clf = RandomForestClassifier(n_estimators=10, n_jobs=2)
clf.fit(training_data, training_label)
predict = clf.predict(validation_data)

# Exercise 3.5
## overall accuracy = correctly classified validation images / all validation images
ac_score = metrics.accuracy_score(validation_label, predict)
print("score is " + str(ac_score))
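The overall accuracy in Exercise 3.5 is the number of validation images whose predicted label matches the true label, divided by the number of validation images. This is what metrics.accuracy_score returns. The sketch below computes it by hand with NumPy. It also builds a confusion matrix, whose diagonal holds the correct predictions, so the accuracy equals the trace of the matrix divided by its total. The label arrays are hypothetical toy values standing in for validation_label and predict. The helper names overall_accuracy and confusion are illustrative, and both assume integer class labels (cast with .astype(int) if they are floats or strings). scikit-learn's metrics.confusion_matrix(validation_label, predict) produces the same kind of matrix.

```python
import numpy as np

def overall_accuracy(true_label, predicted):
    """Hypothetical helper: fraction of samples whose predicted label equals the true label."""
    true_label = np.asarray(true_label)
    predicted = np.asarray(predicted)
    return float(np.mean(true_label == predicted))

def confusion(true_label, predicted, n_classes):
    """Hypothetical helper: count matrix with rows = true label, columns = predicted label."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # add 1 at (true, predicted) for every sample; repeated pairs are accumulated
    np.add.at(cm, (np.asarray(true_label), np.asarray(predicted)), 1)
    return cm

# Toy labels standing in for validation_label and predict (hypothetical values).
validation_label = np.array([0, 1, 2, 2, 1, 0, 2, 1])
predict = np.array([0, 1, 2, 1, 1, 0, 2, 2])

acc = overall_accuracy(validation_label, predict)
cm = confusion(validation_label, predict, n_classes=3)
print(acc)                                 # 0.75 (6 of 8 correct)
print(cm)
print(np.trace(cm) / cm.sum())             # same value as acc
per_class = np.diag(cm) / cm.sum(axis=1)   # fraction of each true class predicted correctly
print(per_class)
```

The per-class rates (each diagonal entry divided by its row sum) show which digits the classifier confuses most often. A single overall score hides this information.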