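Many networks, such as ResNet, repeat the same group of layers many times. Writing those layers out one by one is tedious, so it is easier to define a custom layer and a custom model and reuse them.
The example below is a simple Residual Regressor.
First, import the packages: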
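import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# import mnist from tensorflow.keras rather than the standalone keras package,
# so that all Keras objects come from the same library
from tensorflow.keras.datasets import mnist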
Load the MNIST data for training:
train_X, train_y = mnist.load_data()[0]
train_X = train_X.reshape([-1, 28*28])
train_y = keras.utils.to_categorical(train_y, num_classes=10)
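To make the effect of this preprocessing concrete, here is a minimal NumPy sketch that repeats the same reshape and one-hot steps on a small synthetic batch, so nothing has to be downloaded. The names fake_X, fake_y, flat_X, scaled_X and one_hot_y are hypothetical. The division by 255 is an extra step that the original code does not perform; it is shown only as a common option for uint8 image data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for MNIST: 8 grayscale images of 28x28 uint8 pixels
# and 8 integer labels in 0..9 (hypothetical data, not the real dataset).
fake_X = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
fake_y = rng.integers(0, 10, size=8)

# Same reshape as in the post: flatten each image into a 784-vector.
flat_X = fake_X.reshape([-1, 28 * 28])

# Optional step, NOT in the original code: scale pixels to [0, 1].
scaled_X = flat_X.astype("float32") / 255.0

# One-hot encode the labels; each row of the identity matrix is one class.
one_hot_y = np.eye(10, dtype="float32")[fake_y]

print(flat_X.shape, one_hot_y.shape)
```

The one-hot labels match what categorical_crossentropy expects; with integer labels the usual alternative is sparse_categorical_crossentropy.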
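Next, build a custom layer that stacks n_layers Dense layers with n_neurons neurons each. Because the block adds its input to its output, n_neurons must match the size of the input's last dimension.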
class ResidualBlock(layers.Layer):
    def __init__(self, n_layers, n_neurons, **kwargs):
        super().__init__(**kwargs)
        # n_layers Dense layers, each with n_neurons units
        self.hidden = [layers.Dense(n_neurons, activation='relu',
                                    kernel_initializer='he_normal')
                       for _ in range(n_layers)]

    def call(self, inputs):
        Z = inputs
        for layer in self.hidden:
            Z = layer(Z)
        # skip connection: add the block's input to its output
        return inputs + Z
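The forward pass of such a block can be sketched in plain NumPy to show what the skip connection computes and why the shapes have to agree. This is an illustration under stated assumptions, not the Keras implementation: dense_relu and residual_block are hypothetical helpers, and the weights use a plain normal distribution with He-style scaling rather than Keras's exact he_normal initializer.

```python
import numpy as np

def dense_relu(x, W, b):
    """One fully connected layer followed by ReLU."""
    return np.maximum(x @ W + b, 0.0)

def residual_block(x, params):
    """Apply the stacked dense layers, then add the original input back."""
    z = x
    for W, b in params:
        z = dense_relu(z, W, b)
    # The addition only works if z has the same shape as x.
    return x + z

rng = np.random.default_rng(0)
n_layers, n_neurons = 2, 50

# He-style scaling (std = sqrt(2 / fan_in)); an assumption for illustration.
params = [(rng.normal(0.0, np.sqrt(2.0 / n_neurons), size=(n_neurons, n_neurons)),
           np.zeros(n_neurons))
          for _ in range(n_layers)]

x = rng.normal(size=(4, n_neurons))   # batch of 4 feature vectors
y = residual_block(x, params)

# With all-zero weights the stacked layers output 0, so the block is the identity.
zero_params = [(np.zeros((n_neurons, n_neurons)), np.zeros(n_neurons))
               for _ in range(n_layers)]
y_identity = residual_block(x, zero_params)
```

The identity case is the usual motivation for residual connections: a block can fall back to passing its input through unchanged.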
Then build the custom model:
class ResidualRegressor(keras.Model):
    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden1 = layers.Dense(50, activation='relu',
                                    kernel_initializer='he_normal')
        self.block1 = ResidualBlock(2, 50)
        self.block2 = ResidualBlock(2, 50)
        # softmax output: the model is used here as a 10-class classifier
        self.out = layers.Dense(output_dim, activation='softmax')

    def call(self, inputs):
        Z = self.hidden1(inputs)
        # block1 is applied four times and reuses the same weights each time
        for _ in range(1 + 3):
            Z = self.block1(Z)
        Z = self.block2(Z)
        return self.out(Z)
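Calling the same layer instance several times reuses its weights, so the loop over block1 makes the network deeper without adding parameters. The following back-of-the-envelope sketch counts both. It assumes the 784-dimensional flattened MNIST input used in this post; dense_params and block_params are hypothetical helpers.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

def block_params(n_layers, n_neurons):
    """Parameters of one residual block: n_layers square Dense layers."""
    return n_layers * dense_params(n_neurons, n_neurons)

n_inputs = 28 * 28                        # assumed input size (flattened MNIST)
hidden1 = dense_params(n_inputs, 50)      # first Dense layer
blocks = 2 * block_params(2, 50)          # block1 and block2, each counted once
out = dense_params(50, 10)                # softmax output layer
total = hidden1 + blocks + out

# Dense layers applied in one forward pass:
# hidden1 + block1 four times (2 layers each) + block2 (2 layers) + out
depth = 1 + 4 * 2 + 2 + 1

print(total)  # 49960
print(depth)  # 12
```

If each of the four passes were meant to have its own weights, four separate ResidualBlock instances would be needed, which would add three more blocks' worth of parameters.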
Build the model
- To give the model its input shape, you can put a layers.Input inside a Sequential.
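# on its own, the subclassed model has no input shape until it is first called
model = ResidualRegressor(10)
# wrapping it in a Sequential with an Input layer fixes the input shape up front
model = keras.Sequential([layers.Input(shape=[28 * 28]),
                          ResidualRegressor(10)])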
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
model.fit(train_X, train_y, epochs=10)
Epoch 1/10
60000/60000 [==============================] - 4s 61us/sample - loss: 11.4857 - acc: 0.6385
Epoch 2/10
60000/60000 [==============================] - 4s 58us/sample - loss: 0.8238 - acc: 0.7670
Epoch 3/10
60000/60000 [==============================] - 3s 58us/sample - loss: 0.7337 - acc: 0.7904
Epoch 4/10
60000/60000 [==============================] - 3s 58us/sample - loss: 0.6798 - acc: 0.7996
Epoch 5/10
60000/60000 [==============================] - 3s 58us/sample - loss: 0.6667 - acc: 0.8010
Epoch 6/10
60000/60000 [==============================] - 3s 58us/sample - loss: 0.7053 - acc: 0.8156
Epoch 7/10
60000/60000 [==============================] - 3s 58us/sample - loss: 0.6738 - acc: 0.8362
Epoch 8/10
60000/60000 [==============================] - 3s 58us/sample - loss: 0.6651 - acc: 0.8384
Epoch 9/10
60000/60000 [==============================] - 3s 58us/sample - loss: 0.6663 - acc: 0.8272
Epoch 10/10
60000/60000 [==============================] - 4s 58us/sample - loss: 0.7446 - acc: 0.8041