Lecture 5: Convolutional Neural Networks -- Baseline
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model

np.set_printoptions(threshold=np.inf)

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


class BaseLine(Model):
    def __init__(self):
        super(BaseLine, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')   # convolution layer
        self.b1 = BatchNormalization()                                     # BN layer
        self.a1 = Activation('relu')                                       # activation layer
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')   # pooling layer
        self.d1 = Dropout(0.2)                                             # dropout layer
        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)
        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = BaseLine()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + ".index"):
    print("--------------------load the model-----------------")
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# Dump the trained parameters (name, shape, values) to a text file.
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')


def plot_acc_loss_curve(history):
    # Plot the training/validation accuracy and loss curves.
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    # plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    # plt.grid()
    plt.show()


plot_acc_loss_curve(history)
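As a follow-up, here is a minimal sketch of how the weights saved by the ModelCheckpoint callback could be reloaded into a fresh model and scored on the test set. It assumes it runs in the same session as the listing above, so BaseLine, x_test, y_test, and checkpoint_save_path are already defined; this is an illustrative addition, not part of the original script.

# Rebuild an untrained copy of the network and restore the best saved weights.
eval_model = BaseLine()
eval_model.compile(optimizer='adam',
                   loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                   metrics=['sparse_categorical_accuracy'])
eval_model.load_weights(checkpoint_save_path)  # deferred restore into the fresh model
test_loss, test_acc = eval_model.evaluate(x_test, y_test, verbose=2)
print('test accuracy:', test_acc)

Because BaseLine is a subclassed model, load_weights here performs a deferred restore: the checkpointed values are matched to the variables as soon as the model is first run on data, which the evaluate call triggers.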