Rock-Paper-Scissors Recognition with TensorFlow 2.x and the rps Dataset

  • Author: rex
  • Category: Deep Learning
  • Date: September 9, 2020
Summary: recognizing rock, paper, and scissors hand gestures with TensorFlow 2.x and the rps dataset.

1. Downloading and preparing the dataset

Download the training and test data from the following addresses.

Training set:
https://storage.googleapis.com/laurencemoroney-blog.appspot.com/rps.zip
Test set:
https://storage.googleapis.com/laurencemoroney-blog.appspot.com/rps-test-set.zip
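
The two archives can also be fetched from a script instead of by hand. The sketch below is one way to do it with the standard library, assuming the files should land in /tmp, which is where the extraction code in this post looks for them. The helper names local_path_for and download_if_missing are made up for this illustration.

```python
import os
import urllib.request
from urllib.parse import urlparse

TRAIN_URL = "https://storage.googleapis.com/laurencemoroney-blog.appspot.com/rps.zip"
TEST_URL = "https://storage.googleapis.com/laurencemoroney-blog.appspot.com/rps-test-set.zip"


def local_path_for(url, directory="/tmp"):
    """Return the path where the archive at `url` would be stored (hypothetical helper)."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(directory, filename)


def download_if_missing(url, directory="/tmp"):
    """Download `url` into `directory` unless the file is already there (hypothetical helper)."""
    target = local_path_for(url, directory)
    if not os.path.exists(target):
        urllib.request.urlretrieve(url, target)
    return target


# Usage (requires network access):
#   for url in (TRAIN_URL, TEST_URL):
#       print(download_if_missing(url))
```

Skipping files that already exist keeps repeated notebook runs from downloading the archives again.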

Extracting the archives and counting the images

import os
import zipfile

# Extract the training and test archives into /tmp/
for local_zip in ('/tmp/rps.zip', '/tmp/rps-test-set.zip'):
    with zipfile.ZipFile(local_zip, 'r') as zip_ref:
        zip_ref.extractall('/tmp/')

# One subdirectory per class inside the extracted training set
rock_dir = os.path.join('/tmp/rps', 'rock')
paper_dir = os.path.join('/tmp/rps', 'paper')
scissors_dir = os.path.join('/tmp/rps', 'scissors')

print('total training rock images:', len(os.listdir(rock_dir)))
print('total training paper images:', len(os.listdir(paper_dir)))
print('total training scissors images:', len(os.listdir(scissors_dir)))

rock_files = os.listdir(rock_dir)
print(rock_files[:10])

paper_files = os.listdir(paper_dir)
print(paper_files[:10])

scissors_files = os.listdir(scissors_dir)
print(scissors_files[:10])

The output is as follows:

total training rock images: 840
total training paper images: 840
total training scissors images: 840
['rock06ck02-084.png', 'rock01-024.png', 'rock06ck02-069.png', 'rock03-086.png', 'rock06ck02-033.png', 'rock01-058.png', 'rock03-036.png', 'rock01-086.png', 'rock07-k03-010.png', 'rock01-110.png']
['paper01-079.png', 'paper03-059.png', 'paper04-108.png', 'paper02-048.png', 'paper02-007.png', 'paper04-022.png', 'paper01-103.png', 'paper07-043.png', 'paper03-017.png', 'paper06-119.png']
['scissors04-003.png', 'testscissors03-082.png', 'scissors03-102.png', 'scissors02-004.png', 'testscissors02-080.png', 'scissors01-081.png', 'scissors02-053.png', 'scissors02-045.png', 'testscissors01-113.png', 'testscissors03-068.png']
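
The three per-class counts above can be produced for any number of class folders in one pass. A minimal sketch, assuming the layout used here (one subdirectory per class under the dataset root); count_images_per_class is a name made up for this example.

```python
import os


def count_images_per_class(root, extensions=(".png", ".jpg", ".jpeg")):
    """Map each class subdirectory of `root` to its number of image files."""
    counts = {}
    for entry in sorted(os.listdir(root)):
        class_dir = os.path.join(root, entry)
        if not os.path.isdir(class_dir):
            continue  # skip stray files next to the class folders
        counts[entry] = sum(
            1 for name in os.listdir(class_dir)
            if name.lower().endswith(extensions)
        )
    return counts


# Usage, once the archive has been extracted:
#   print(count_images_per_class("/tmp/rps"))
```

On the extracted training set this should report the same 840 images per class as the listing above.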

Previewing some of the images

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

pic_index = 2

next_rock = [os.path.join(rock_dir, fname) 
                for fname in rock_files[pic_index-2:pic_index]]
next_paper = [os.path.join(paper_dir, fname) 
                for fname in paper_files[pic_index-2:pic_index]]
next_scissors = [os.path.join(scissors_dir, fname) 
                for fname in scissors_files[pic_index-2:pic_index]]

# Show the first two images of each class, one figure per image
for i, img_path in enumerate(next_rock+next_paper+next_scissors):
    img = mpimg.imread(img_path)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

[Images: sample pictures from the training set (w3ZrIU.png, w3ZWs1.png)]

2. Training and saving the model

import tensorflow as tf
# ImageDataGenerator ships with TensorFlow 2.x; no separate keras_preprocessing install is needed
from tensorflow.keras.preprocessing.image import ImageDataGenerator

TRAINING_DIR = "/tmp/rps/"
training_datagen = ImageDataGenerator(
      rescale = 1./255,
      rotation_range=40,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest')

VALIDATION_DIR = "/tmp/rps-test-set/"
validation_datagen = ImageDataGenerator(rescale = 1./255)

train_generator = training_datagen.flow_from_directory(
    TRAINING_DIR,
    target_size=(150,150),
    class_mode='categorical'
)

validation_generator = validation_datagen.flow_from_directory(
    VALIDATION_DIR,
    target_size=(150,150),
    class_mode='categorical'
)

model = tf.keras.models.Sequential([
    # Note the input shape is the desired size of the image: 150x150 with 3 color channels
    # This is the first convolution
    tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    # The second convolution
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    # The third convolution
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    # The fourth convolution
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    # Flatten the results to feed into a DNN
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    # 512 neuron hidden layer
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax')
])


model.summary()

model.compile(loss = 'categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

# fit_generator is deprecated since TensorFlow 2.1; model.fit accepts generators directly
history = model.fit(train_generator, epochs=25, validation_data=validation_generator, verbose=1)

model.save("rps.h5")
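
Neither generator sets batch_size, so flow_from_directory falls back to its default of 32. The number of batches per epoch is then the image count divided by 32, rounded up. A small worked example, taking the 2520 training images counted in section 1 (3 classes of 840 each); steps_per_epoch is a helper name chosen for this sketch.

```python
import math


def steps_per_epoch(num_images, batch_size=32):
    """Number of batches needed to see every image once."""
    return math.ceil(num_images / batch_size)


train_steps = steps_per_epoch(3 * 840)  # 2520 training images
print(train_steps)  # 79
```

The same formula applies to the validation generator with its own image count.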

The training output and the model summary are shown below:

Found 2520 images belonging to 3 classes.
Found 372 images belonging to 3 classes.
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_20 (Conv2D)           (None, 148, 148, 64)      1792      
_________________________________________________________________
max_pooling2d_20 (MaxPooling (None, 74, 74, 64)        0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 72, 72, 64)        36928     
_________________________________________________________________
max_pooling2d_21 (MaxPooling (None, 36, 36, 64)        0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 34, 34, 128)       73856     
_________________________________________________________________
max_pooling2d_22 (MaxPooling (None, 17, 17, 128)       0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 15, 15, 128)       147584    
_________________________________________________________________
max_pooling2d_23 (MaxPooling (None, 7, 7, 128)         0         
_________________________________________________________________
flatten_5 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 6272)              0         
_________________________________________________________________
dense_10 (Dense)             (None, 512)               3211776   
_________________________________________________________________
dense_11 (Dense)             (None, 3)                 1539      
=================================================================
Total params: 3,473,475
Trainable params: 3,473,475
Non-trainable params: 0
_________________________________________________________________
Epoch 1/25
79/79 [==============================] - 20s 257ms/step - loss: 1.2037 - acc: 0.3786 - val_loss: 1.0084 - val_acc: 0.6344
Epoch 2/25
79/79 [==============================] - 19s 242ms/step - loss: 0.8781 - acc: 0.6000 - val_loss: 0.3174 - val_acc: 0.9946
Epoch 3/25
79/79 [==============================] - 20s 250ms/step - loss: 0.5636 - acc: 0.7595 - val_loss: 0.1306 - val_acc: 1.0000
Epoch 4/25
79/79 [==============================] - 19s 243ms/step - loss: 0.4033 - acc: 0.8397 - val_loss: 0.2548 - val_acc: 0.8414
...
Epoch 24/25
79/79 [==============================] - 19s 246ms/step - loss: 0.0781 - acc: 0.9758 - val_loss: 0.0176 - val_acc: 0.9892
Epoch 25/25
79/79 [==============================] - 19s 237ms/step - loss: 0.0708 - acc: 0.9810 - val_loss: 0.1145 - val_acc: 0.9543
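
The shapes and parameter counts in the summary can be checked by hand. A Conv2D layer with a 3x3 kernel and the default 'valid' padding shrinks each spatial dimension by 2 and has (3 * 3 * in_channels + 1) * filters parameters. A 2x2 max pooling halves each dimension, rounding down. A Dense layer has (inputs + 1) * units parameters. The sketch below redoes that arithmetic in plain Python; the function names are made up for this illustration.

```python
def conv2d_params(in_channels, filters, kernel=3):
    """Kernel weights plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters


def dense_params(inputs, units):
    """Weights plus one bias per unit."""
    return (inputs + 1) * units


size, channels, total = 150, 3, 0
for filters in (64, 64, 128, 128):
    total += conv2d_params(channels, filters)
    size = (size - 2) // 2   # 3x3 'valid' convolution, then 2x2 max pooling
    channels = filters

flat = size * size * channels      # Flatten output: 7 * 7 * 128
total += dense_params(flat, 512)
total += dense_params(512, 3)

print(flat)   # 6272
print(total)  # 3473475
```

Most of the parameters (3,211,776 of 3,473,475) sit in the first Dense layer, which is where the Dropout layer placed before it acts.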

3. Visualizing the training results

The plotting code is as follows:

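import matplotlib.pyplot as plt

# TF 2.x stores the metric under 'accuracy'; older Keras versions used 'acc'
acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'
acc = history.history[acc_key]
val_acc = history.history['val_' + acc_key]
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

# Accuracy curves
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend(loc=0)

# Loss curves in a second figure
plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend(loc=0)

plt.show()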

The resulting training curves are shown below:

[Figure: training and validation accuracy (w3e2tS.png)]
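
The training log in section 2 shows that validation accuracy does not rise monotonically (it reaches 1.0000 in epoch 3 and drops to 0.8414 in epoch 4), so the last epoch is not necessarily the best one. A minimal sketch for locating the best epoch in a list of per-epoch values, such as the lists stored in history.history; best_epoch is a hypothetical helper, and the sample list holds only the four values visible at the start of the log.

```python
def best_epoch(values):
    """Return (1-based epoch, value) of the highest entry in `values`."""
    index = max(range(len(values)), key=lambda i: values[i])
    return index + 1, values[index]


# Validation accuracy of epochs 1-4 as printed in the training log.
val_acc_first_epochs = [0.6344, 0.9946, 1.0000, 0.8414]
print(best_epoch(val_acc_first_epochs))  # (3, 1.0)
```

Passing the full validation accuracy list from history.history would give the best epoch over all 25 epochs.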

4. Testing the model

The trained model was tested on Google Colab with the following code:

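import numpy as np
from google.colab import files
from tensorflow.keras.preprocessing import image

# Upload one or more images from the local machine
uploaded = files.upload()

for fn in uploaded.keys():
    # Load each uploaded image at the model's input size and add a batch dimension
    path = fn
    img = image.load_img(path, target_size=(150, 150))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    # Note: pixel values stay in the 0-255 range here, while the training
    # generators rescale by 1/255; dividing x by 255.0 would match training.

    images = np.vstack([x])
    classes = model.predict(images, batch_size=10)
    print(fn)
    print(classes)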

The output is:

Saving paper2.png to paper2 (1).png
Saving rock-hires2.png to rock-hires2.png
Saving scissors-hires2.png to scissors-hires2.png
Saving paper-hires2.png to paper-hires2.png
Saving paper-hires1.png to paper-hires1.png
Saving rock-hires1.png to rock-hires1 (1).png
Saving scissors-hires1.png to scissors-hires1.png
Saving scissors9.png to scissors9.png
Saving scissors8.png to scissors8.png
Saving scissors7.png to scissors7 (1).png
Saving rock9.png to rock9.png
Saving rock8.png to rock8.png
Saving rock7.png to rock7.png
Saving paper9.png to paper9.png
Saving paper8.png to paper8.png
Saving paper7.png to paper7.png
Saving scissors6.png to scissors6.png
Saving scissors5.png to scissors5.png
Saving scissors4.png to scissors4.png
Saving paper6.png to paper6.png
Saving paper5.png to paper5.png
Saving paper4.png to paper4.png
Saving rock6.png to rock6.png
Saving rock5.png to rock5.png
Saving rock4.png to rock4.png
Saving scissors3.png to scissors3.png
Saving scissors2.png to scissors2.png
Saving scissors1.png to scissors1.png
Saving paper3.png to paper3.png
Saving paper1.png to paper1.png
Saving rock3.png to rock3.png
Saving rock2.png to rock2.png
Saving rock1.png to rock1.png
scissors-hires1.png
[[ 0.  0.  1.]]
paper4.png
[[  2.77571414e-38   0.00000000e+00   1.00000000e+00]]
paper6.png
[[ 1.  0.  0.]]
scissors1.png
[[ 0.  0.  1.]]
scissors3.png
[[ 0.  0.  1.]]
scissors2.png
[[ 0.  0.  1.]]
scissors9.png
[[ 0.  0.  1.]]
rock7.png
[[ 0.  1.  0.]]
rock9.png
[[ 0.  1.  0.]]
rock8.png
[[ 0.  1.  0.]]
paper-hires1.png
[[ 1.  0.  0.]]
rock4.png
[[ 0.  1.  0.]]
paper5.png
[[ 0.  0.  1.]]
rock2.png
[[ 0.  1.  0.]]
paper9.png
[[ 0.  0.  1.]]
paper1.png
[[ 1.  0.  0.]]
rock6.png
[[ 0.  1.  0.]]
rock5.png
[[ 0.  1.  0.]]
scissors4.png
[[ 0.  0.  1.]]
rock-hires2.png
[[ 0.  1.  0.]]
scissors-hires2.png
[[ 0.  0.  1.]]
paper8.png
[[ 1.  0.  0.]]
paper-hires2.png
[[ 1.  0.  0.]]
rock3.png
[[ 0.  1.  0.]]
rock1.png
[[ 0.  1.  0.]]
paper3.png
[[ 0.  0.  1.]]
paper2.png
[[ 1.  0.  0.]]
scissors5.png
[[ 0.  0.  1.]]
scissors7.png
[[ 0.  0.  1.]]
scissors6.png
[[ 0.  0.  1.]]
rock-hires1.png
[[ 0.  1.  0.]]
scissors8.png
[[ 0.  0.  1.]]
paper7.png
[[ 1.  0.  0.]]
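
Each printed row is a probability vector over the three classes. flow_from_directory assigns class indices in alphanumeric order of the subdirectory names, so here index 0 is paper, 1 is rock, and 2 is scissors (train_generator.class_indices shows the mapping). A sketch that turns such a row into a label; decode_prediction is a name made up for this example, and the sample rows are copied from the output above.

```python
CLASS_NAMES = ["paper", "rock", "scissors"]  # alphanumeric order of the folder names


def decode_prediction(probabilities, class_names=CLASS_NAMES):
    """Return the class name with the highest predicted probability."""
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return class_names[best]


# Rows copied from the output above.
samples = {
    "rock7.png": [0.0, 1.0, 0.0],
    "paper6.png": [1.0, 0.0, 0.0],
    "paper4.png": [2.77571414e-38, 0.0, 1.0],
}
for filename, row in samples.items():
    print(filename, "->", decode_prediction(row))
```

Read this way, every rock and scissors image in the output is classified correctly, while four of the eleven paper images (paper3, paper4, paper5, paper9) come out as scissors.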
