CIFAR10 and VGG13 in Practice

  • Author: rex
  • Category: Deep Learning
  • Date: August 8, 2020

Summary: training a VGG13 model on the CIFAR10 dataset.

The CIFAR10 Dataset

[Figure omitted: CIFAR10 overview.]

CIFAR10 consists of 60,000 32×32 color images in 10 classes (6,000 per class), split into 50,000 training images and 10,000 test images.
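
The labels are integers 0–9 stored as a [b, 1] column; below is a minimal sketch for mapping them to human-readable names (the class_names ordering is the dataset's standard one):

# Peek at the dataset and map an integer label to its class name.
from tensorflow.keras import datasets

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

(x, y), (x_test, y_test) = datasets.cifar10.load_data()
print(x.shape, y.shape)            # (50000, 32, 32, 3) (50000, 1)
print(class_names[int(y[0, 0])])   # class name of the first training image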

Code

import tensorflow as tf
from tensorflow.keras import layers, optimizers, datasets, Sequential
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TensorFlow INFO/WARNING logs
tf.random.set_seed(2345)  # fix the random seed for reproducibility

# require a GPU; this script is meant for a GPU runtime (e.g. Colab)
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

conv_layers = [ # 5 units of conv + max pooling
    # unit 1
    layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),

    # unit 2
    layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),

    # unit 3
    layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),

    # unit 4
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),

    # unit 5
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same')

]
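
# Each of the five pooling layers halves the spatial resolution:
# 32 -> 16 -> 8 -> 4 -> 2 -> 1, so a [b, 32, 32, 3] input leaves the
# conv stack with shape [b, 1, 1, 512] and is flattened to [b, 512] below.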



def preprocess(x, y):
    # scale pixels from [0, 255] to [-1, 1]
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1
    y = tf.cast(y, dtype=tf.int32)
    return x, y


(x, y), (x_test, y_test) = datasets.cifar10.load_data()
# drop the extra label dimension: [b, 1] => [b]
y = tf.squeeze(y, axis=1)
y_test = tf.squeeze(y_test, axis=1)
print(x.shape, y.shape, x_test.shape, y_test.shape)


train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(1000).map(preprocess).batch(128)

# the test pipeline is not shuffled
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(64)

sample = next(iter(train_db))
print('sample:', sample[0].shape, sample[1].shape,
      tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))


def main():

    # [b, 32, 32, 3] => [b, 1, 1, 512]
    conv_net = Sequential(conv_layers)

    fc_net = Sequential([
        layers.Dense(256, activation=tf.nn.relu),
        layers.Dense(128, activation=tf.nn.relu),
        layers.Dense(10, activation=None),
    ])

    conv_net.build(input_shape=[None, 32, 32, 3])
    fc_net.build(input_shape=[None, 512])
    conv_net.summary()
    fc_net.summary()
    optimizer = optimizers.Adam(learning_rate=1e-4)

    # list concatenation: [1, 2] + [3, 4] => [1, 2, 3, 4]
    variables = conv_net.trainable_variables + fc_net.trainable_variables

    for epoch in range(50):

        for step, (x, y) in enumerate(train_db):

            with tf.GradientTape() as tape:
                # [b, 32, 32, 3] => [b, 1, 1, 512]
                out = conv_net(x)
                # flatten, => [b, 512]
                out = tf.reshape(out, [-1, 512])
                # [b, 512] => [b, 10]
                logits = fc_net(out)
                # [b] => [b, 10]
                y_onehot = tf.one_hot(y, depth=10)
                # compute cross-entropy on the raw logits; from_logits=True
                # applies the softmax internally, which is numerically stabler
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)

            grads = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(grads, variables))

            if step % 100 == 0:
                print(epoch, step, 'loss:', float(loss))



        total_num = 0
        total_correct = 0
        for x, y in test_db:

            out = conv_net(x)
            out = tf.reshape(out, [-1, 512])
            logits = fc_net(out)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)

            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)

            total_num += x.shape[0]
            total_correct += int(correct)

        acc = total_correct / total_num
        print(epoch, 'acc:', acc)



if __name__ == '__main__':
    main()

Results

Found GPU at: /device:GPU:0
(50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,)
sample: (128, 32, 32, 3) (128,) tf.Tensor(-1.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_10 (Conv2D)           multiple                  1792      
_________________________________________________________________
conv2d_11 (Conv2D)           multiple                  36928     
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 multiple                  0         
_________________________________________________________________
conv2d_12 (Conv2D)           multiple                  73856     
_________________________________________________________________
conv2d_13 (Conv2D)           multiple                  147584    
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 multiple                  0         
_________________________________________________________________
conv2d_14 (Conv2D)           multiple                  295168    
_________________________________________________________________
conv2d_15 (Conv2D)           multiple                  590080    
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 multiple                  0         
_________________________________________________________________
conv2d_16 (Conv2D)           multiple                  1180160   
_________________________________________________________________
conv2d_17 (Conv2D)           multiple                  2359808   
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 multiple                  0         
_________________________________________________________________
conv2d_18 (Conv2D)           multiple                  2359808   
_________________________________________________________________
conv2d_19 (Conv2D)           multiple                  2359808   
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 multiple                  0         
=================================================================
Total params: 9,404,992
Trainable params: 9,404,992
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              multiple                  131328    
_________________________________________________________________
dense_4 (Dense)              multiple                  32896     
_________________________________________________________________
dense_5 (Dense)              multiple                  1290      
=================================================================
Total params: 165,514
Trainable params: 165,514
Non-trainable params: 0
_________________________________________________________________
0 0 loss: 2.302992820739746
0 100 loss: 1.8901972770690918
0 200 loss: 1.7901721000671387
0 300 loss: 1.801344394683838
0 acc: 0.4268
1 0 loss: 1.5000039339065552
1 100 loss: 1.4324939250946045
1 200 loss: 1.3264274597167969
1 300 loss: 1.2084956169128418
1 acc: 0.5523
2 0 loss: 1.1944959163665771
2 100 loss: 1.0641543865203857
2 200 loss: 1.319048285484314
2 300 loss: 1.0663871765136719
2 acc: 0.5734
3 0 loss: 0.9752454161643982
3 100 loss: 1.1654609441757202
3 200 loss: 1.0090887546539307
3 300 loss: 0.8643631935119629
3 acc: 0.6454
4 0 loss: 0.9792678356170654
4 100 loss: 0.8592975735664368
4 200 loss: 0.8445404767990112
4 300 loss: 0.8941479921340942
4 acc: 0.6778
5 0 loss: 0.7622193098068237
5 100 loss: 0.915672779083252
5 200 loss: 0.8174030184745789
5 300 loss: 0.7292886972427368
5 acc: 0.7127
6 0 loss: 0.7300940752029419
6 100 loss: 0.7106158137321472
6 200 loss: 0.7417889833450317
6 300 loss: 0.5376895666122437
6 acc: 0.7307
7 0 loss: 0.44040730595588684
7 100 loss: 0.5800349712371826
7 200 loss: 0.49223780632019043
7 300 loss: 0.5557278990745544
7 acc: 0.74
8 0 loss: 0.4799507260322571
8 100 loss: 0.45104551315307617
8 200 loss: 0.4204782545566559
8 300 loss: 0.44960978627204895
8 acc: 0.7416
9 0 loss: 0.49316203594207764
9 100 loss: 0.4252029061317444
9 200 loss: 0.40155109763145447
9 300 loss: 0.36971157789230347
9 acc: 0.7616
10 0 loss: 0.24532964825630188
10 100 loss: 0.26975929737091064
10 200 loss: 0.2851220369338989
10 300 loss: 0.2935635447502136
10 acc: 0.7717
11 0 loss: 0.20445677638053894
11 100 loss: 0.30031895637512207
11 200 loss: 0.3098926246166229
11 300 loss: 0.25346529483795166
11 acc: 0.7468
12 0 loss: 0.2552114427089691
12 100 loss: 0.3162142038345337
12 200 loss: 0.16345956921577454
12 300 loss: 0.13134022057056427
12 acc: 0.7649
13 0 loss: 0.20834167301654816
13 100 loss: 0.12202440947294235
13 200 loss: 0.1823371797800064
13 300 loss: 0.08800317347049713
13 acc: 0.7698
14 0 loss: 0.10439999401569366
14 100 loss: 0.08311246335506439
14 200 loss: 0.15722894668579102
14 300 loss: 0.13591712713241577
14 acc: 0.7683
15 0 loss: 0.0538744255900383
15 100 loss: 0.08527786284685135
15 200 loss: 0.0652163028717041
15 300 loss: 0.049679212272167206
15 acc: 0.7655
16 0 loss: 0.07588814198970795
16 100 loss: 0.06458132714033127
16 200 loss: 0.03780544176697731
16 300 loss: 0.019340790808200836
16 acc: 0.7666
17 0 loss: 0.09726909548044205
17 100 loss: 0.03869732841849327
17 200 loss: 0.03340745344758034
17 300 loss: 0.12972211837768555
17 acc: 0.7674
18 0 loss: 0.06307273358106613
18 100 loss: 0.05896104499697685
18 200 loss: 0.0825689360499382
18 300 loss: 0.044900115579366684
18 acc: 0.7749
19 0 loss: 0.022179005667567253
19 100 loss: 0.01824028044939041
19 200 loss: 0.01285834051668644
19 300 loss: 0.011381611227989197
19 acc: 0.762
20 0 loss: 0.1490422487258911
20 100 loss: 0.06503628939390182
20 200 loss: 0.029477272182703018
20 300 loss: 0.04929957911372185
20 acc: 0.7681
21 0 loss: 0.09931177645921707
21 100 loss: 0.04310132935643196
21 200 loss: 0.06460382044315338
21 300 loss: 0.0317661389708519
21 acc: 0.7719
22 0 loss: 0.03811441734433174
22 100 loss: 0.07437655329704285
22 200 loss: 0.006810937076807022
22 300 loss: 0.017275065183639526
22 acc: 0.7729
23 0 loss: 0.08091800659894943
23 100 loss: 0.0076401690021157265
23 200 loss: 0.08543173968791962
23 300 loss: 0.012159734033048153
23 acc: 0.7843
24 0 loss: 0.07182680815458298
24 100 loss: 0.012264705263078213
24 200 loss: 0.022202862426638603
24 300 loss: 0.04546601325273514
24 acc: 0.7597
25 0 loss: 0.0698116198182106
25 100 loss: 0.05376395955681801
25 200 loss: 0.04635988920927048
25 300 loss: 0.010879624634981155
25 acc: 0.7772
26 0 loss: 0.04097108915448189
26 100 loss: 0.018222816288471222
26 200 loss: 0.06394268572330475
26 300 loss: 0.06391524523496628
26 acc: 0.7602
27 0 loss: 0.020985107868909836
27 100 loss: 0.014314044266939163
27 200 loss: 0.0071915192529559135
27 300 loss: 0.046115756034851074
27 acc: 0.7714
28 0 loss: 0.014514286071062088
28 100 loss: 0.017840707674622536
28 200 loss: 0.024033505469560623
28 300 loss: 0.02790391445159912
28 acc: 0.7753
29 0 loss: 0.05537789314985275
29 100 loss: 0.03018992766737938
29 200 loss: 0.05451907217502594
29 300 loss: 0.06299859285354614
29 acc: 0.7723
30 0 loss: 0.01716894842684269
30 100 loss: 0.019309524446725845
30 200 loss: 0.04247695580124855
30 300 loss: 0.014780644327402115
30 acc: 0.7724
31 0 loss: 0.05317273736000061
31 100 loss: 0.020048007369041443
31 200 loss: 0.0023788458202034235
31 300 loss: 0.0068132816813886166
31 acc: 0.7785
32 0 loss: 0.020766515284776688
32 100 loss: 0.006842954084277153
32 200 loss: 0.0213004257529974
32 300 loss: 0.011463353410363197
32 acc: 0.7753
33 0 loss: 0.008880363777279854
33 100 loss: 0.02077978104352951
33 200 loss: 0.018707891926169395
33 300 loss: 0.011059397831559181
33 acc: 0.7823
34 0 loss: 0.015089326538145542
34 100 loss: 0.04146172106266022
34 200 loss: 0.04508623853325844
34 300 loss: 0.06830544769763947
34 acc: 0.7722
35 0 loss: 0.05722248926758766
35 100 loss: 0.06661555916070938
35 200 loss: 0.0076147038489580154
35 300 loss: 0.02936476096510887
35 acc: 0.7771
36 0 loss: 0.030617613345384598
36 100 loss: 0.07923397421836853
36 200 loss: 0.0016426900401711464
36 300 loss: 0.03232787549495697
36 acc: 0.7709
37 0 loss: 0.03940636292099953
37 100 loss: 0.0376625694334507
37 200 loss: 0.0548693872988224
37 300 loss: 0.010061042383313179
37 acc: 0.7708
38 0 loss: 0.004660519305616617
38 100 loss: 0.017333801835775375
38 200 loss: 0.0026862535160034895
38 300 loss: 0.03924272581934929
38 acc: 0.781
39 0 loss: 0.006731842644512653
39 100 loss: 0.006652030162513256
39 200 loss: 0.059130921959877014
39 300 loss: 0.024248989298939705
39 acc: 0.7801
40 0 loss: 0.02024053782224655
40 100 loss: 0.004644289612770081
40 200 loss: 0.0051679471507668495
40 300 loss: 0.020759165287017822
40 acc: 0.7875
41 0 loss: 0.023266050964593887
41 100 loss: 0.006671542301774025
41 200 loss: 0.0009782916167750955
41 300 loss: 0.006725775543600321
41 acc: 0.784
42 0 loss: 0.02520041912794113
42 100 loss: 0.047426074743270874
42 200 loss: 0.04175373166799545
42 300 loss: 0.04228482022881508
42 acc: 0.7745
43 0 loss: 0.028878284618258476
43 100 loss: 0.0031117205508053303
43 200 loss: 0.026115767657756805
43 300 loss: 0.04506634175777435
43 acc: 0.7833
44 0 loss: 0.00647880882024765
44 100 loss: 0.03147329390048981
44 200 loss: 0.03646465763449669
44 300 loss: 0.008249848149716854
44 acc: 0.7892
45 0 loss: 0.004242300521582365
45 100 loss: 0.029023541137576103
45 200 loss: 0.02344965562224388
45 300 loss: 0.007129283156245947
45 acc: 0.7913
46 0 loss: 0.0033837100490927696
46 100 loss: 0.05386582016944885
46 200 loss: 0.00831370148807764
46 300 loss: 0.0713522732257843
46 acc: 0.784
47 0 loss: 0.003499329322949052
47 100 loss: 0.0029714852571487427
47 200 loss: 0.023521700873970985
47 300 loss: 0.020276004448533058
47 acc: 0.7833
48 0 loss: 0.010389955714344978
48 100 loss: 0.010127793997526169
48 200 loss: 0.002463790588080883
48 300 loss: 0.001762881875038147
48 acc: 0.7859
49 0 loss: 0.017623843625187874
49 100 loss: 0.004276433493942022
49 200 loss: 0.016966236755251884
49 300 loss: 0.0044584558345377445
49 acc: 0.7668

The training loss falls to near zero within about 15 epochs, while the test accuracy plateaus around 77–79% (peaking at 0.7913 at epoch 45): without regularization or data augmentation, the roughly 9.6M-parameter network clearly overfits CIFAR10.