Tensorflow2.x可视化训练结果

  • 作者:rex
  • 分类: 深度学习
  • 时间:2020年8月22日
  • 76人已阅读
简介使用matplotlib来可视化训练结果

Tensorflow2.x可视化训练结果

综述

在本篇文章中我们将对上一篇文章基于LeNet-5的MNIST手写数字识别中的训练结果利用matplotlib进行可视化的操作。

在进行可视化操作之前你需要学会的技能有:

如果你不会使用这两个库的话,可以点击上面的超链接学习。

获取训练数据

在Tensorflow2.x中当我们使用keras来进行模型的训练时使用了一个叫fit的函数,如下所示:

# 网络定义
network = Sequential([
    # 卷积层1
    layers.Conv2D(filters=6,kernel_size=(5,5),activation="relu",input_shape=(28,28,1),padding="same"),
    layers.MaxPool2D(pool_size=(2,2),strides=2),

    # 卷积层2
    layers.Conv2D(filters=16,kernel_size=(5,5),activation="relu",padding="same"),
    layers.MaxPool2D(pool_size=2,strides=2),

    # 卷积层3
    layers.Conv2D(filters=32,kernel_size=(5,5),activation="relu",padding="same"),

    layers.Flatten(),

    # 全连接层1
    layers.Dense(200,activation="relu"),

    # 全连接层2
    layers.Dense(10,activation="softmax")    
])
network.summary()

# 模型训练 训练30个epoch
network.compile(optimizer='adam',loss="sparse_categorical_crossentropy",metrics=["accuracy"])
network.fit(trainImage,trainLabel,epochs=30,validation_split=0.1)

我们使用了network.fit进行了网络的训练,我们如果需要获得训练的结果的话需要添加如下的写法

# 模型训练 训练30个epoch
history = network.fit(trainImage,trainLabel,epochs=30,validation_split=0.1)
data = history.history

这里的history可以读取获得我们的训练数据,它返回的是一个python的字典格式。为了方便数据的处理,我们可以将其存储为json格式的文件。

# json文件写入
def save_json(data):
    with open('history.json', 'w') as json_file:
        json.dump(data, json_file)

# 保存训练结果
save_json(data)

在进行json文件的写入时记得导入json库,然后完成json文件的保存。

当我们保存完毕后可以获得一个叫history.json的文件,里面保存着我们的训练数据。

{
    "loss": [
        xxx,
    ],
    "accuracy": [
        xxx,
    ],
    "val_loss": [
        xxx,
    ],
    "val_accuracy": [
        xxx,
    ]
}

其中四个字段的解释如下:

loss accuracy val_loss val_accuracy
训练中的loss 训练中的准确率 测试中的loss 测试中的准确率

读取文件完成可视化操作

导入库

import matplotlib.pyplot as plt
import numpy as np
import json

读取history.json文件

# 载入数据
def load_data():

    try:
        with open('history.json','r') as json_file:
            data = json.load(json_file)
            print('data load success!')
    except:
        print('data load failed!')
    return data

解析数据

# 读取数据
def get_data(data):
    # 训练集损失数据
    loss = data['loss']
    np_loss = np.array(loss)
    y_loss = np_loss
    x_epoch = np.arange(30)

    # 训练集准确度数据
    acc = data['accuracy']
    np_acc = np.array(acc)
    y_acc = np_acc
    x_epoch = np.arange(30)

    # 测试集损失数据
    val_loss = data['val_loss']
    np_val_loss = np.array(val_loss)
    y_val_loss = np_val_loss
    x_epoch = np.arange(30)

    # 测试集准确度数据
    val_acc = data['val_accuracy']
    np_val_acc = np.array(val_acc)
    y_val_acc = np_val_acc
    x_epoch = np.arange(30)

    return y_loss,y_acc,y_val_loss,y_val_acc,x_epoch

注意! 为了方便matplotlib进行图形的绘制我们将数据转为numpy的array格式,具体代码如上所示。

绘制结果

这里我们绘制一个两行两列的图像,使用matplotlib来进行绘制

def main():
    data = load_data()
    y_loss,y_acc,y_val_loss,y_val_acc,x_epoch = get_data(data)

    plt.figure(figsize=(10,4))
    # loss图像
    plt.subplot(2,2,1)
    plt.plot(x_epoch,y_loss,label='loss')
    plt.legend()
    plt.xlabel('epoch')
    plt.ylabel('loss')

    # acc图像
    plt.subplot(2,2,2)
    plt.plot(x_epoch,y_acc,label='acc',color='red')
    plt.legend()
    plt.xlabel('epoch')
    plt.ylabel('acc')
    plt.ylim(0.9,1)

    # val_loss图像
    plt.subplot(2,2,3)
    plt.plot(x_epoch,y_val_loss,label='val_loss')
    plt.legend()
    plt.xlabel('epoch')
    plt.ylabel('val_loss')

    # val_acc图像
    plt.subplot(2,2,4)
    plt.plot(x_epoch,y_val_acc,label='val_acc',color='red')
    plt.legend()
    plt.xlabel('epoch')
    plt.ylabel('val_acc')
    plt.ylim(0.9,1)

    plt.show()

结果展示

绘制结果如下所示 结果

文章评论

Top