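How can residual connections be handled with the Keras functional API in Python?
Keras ships as part of the TensorFlow package and can be accessed with the following code.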
import tensorflow
from tensorflow import keras
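Models created with the functional API are more flexible than models created with the Sequential API. The functional API can handle models with non-linear topologies, share layers between branches, and handle multiple inputs and outputs. A deep learning model is usually a directed acyclic graph (DAG) of layers, and the functional API is a way to build such graphs of layers.
We use Google Colaboratory to run the code below. Google Colab (Colaboratory) runs Python code in the browser, requires no setup, and provides free access to GPUs (graphics processing units). Colaboratory is built on top of Jupyter Notebook. The code snippet follows.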
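To make the "shared layers" and "multiple inputs and outputs" points concrete, here is a minimal sketch that a Sequential model could not express. All names (`input_a`, `input_b`, `shared_dense`, `score`, `label`, `two_in_two_out`) and the layer sizes are hypothetical choices for illustration; only the Keras calls themselves are real API.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Two hypothetical inputs, each a vector of 16 features.
input_a = keras.Input(shape=(16,), name="input_a")
input_b = keras.Input(shape=(16,), name="input_b")

# One Dense layer object applied to both inputs: the weights are shared.
shared_dense = layers.Dense(8, activation="relu", name="shared_dense")
encoded_a = shared_dense(input_a)
encoded_b = shared_dense(input_b)

# Merge the two branches (8 + 8 = 16 features), then produce two outputs.
merged = layers.concatenate([encoded_a, encoded_b])
score = layers.Dense(1, name="score")(merged)
label = layers.Dense(3, activation="softmax", name="label")(merged)

model = keras.Model(inputs=[input_a, input_b], outputs=[score, label], name="two_in_two_out")
model.summary()
```

Because `shared_dense` is a single layer object called twice, its 136 parameters (16 x 8 weights + 8 biases) are counted only once in the model total.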
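If you want to confirm that the free GPU is actually available before training, one quick check is TensorFlow's device listing. This assumes a GPU runtime has been selected in Colab (Runtime > Change runtime type); on a CPU-only runtime the list is simply empty and the code still runs.

```python
import tensorflow as tf

# Physical GPUs visible to TensorFlow; an empty list means CPU only.
gpus = tf.config.list_physical_devices("GPU")
print("TensorFlow version:", tf.__version__)
print("Number of GPUs available:", len(gpus))
```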
Example
from tensorflow import keras
from tensorflow.keras import layers

print("Toy ResNet model for CIFAR10")
print("Layers generated for the model")
# Input: 32x32 RGB images, as in CIFAR10
inputs = keras.Input(shape=(32, 32, 3), name="img")
# Block 1: two convolutions followed by max pooling
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
block_1_output = layers.MaxPooling2D(3)(x)
# Block 2: padding="same" keeps the 9x9x64 shape, so the block input can be added back (residual connection)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_2_output = layers.add([x, block_1_output])
# Block 3: a second residual block built the same way
x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_3_output = layers.add([x, block_2_output])
# Classification head
x = layers.Conv2D(64, 3, activation="relu")(block_3_output)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
# 10 output logits, one per CIFAR10 class
outputs = layers.Dense(10)(x)
model = keras.Model(inputs, outputs, name="toy_resnet")
print("More information about the model")
model.summary()
Code source: https://www.tensorflow.org/guide/keras/functional
Output
Toy ResNet model for CIFAR10
Layers generated for the model
More information about the model
Model: "toy_resnet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
img (InputLayer)                [(None, 32, 32, 3)]  0
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 30, 30, 32)   896         img[0][0]
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 28, 28, 64)   18496       conv2d_32[0][0]
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  (None, 9, 9, 64)     0           conv2d_33[0][0]
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 9, 9, 64)     36928       max_pooling2d_8[0][0]
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 9, 9, 64)     36928       conv2d_34[0][0]
__________________________________________________________________________________________________
add_12 (Add)                    (None, 9, 9, 64)     0           conv2d_35[0][0]
                                                                 max_pooling2d_8[0][0]
__________________________________________________________________________________________________
conv2d_36 (Conv2D)              (None, 9, 9, 64)     36928       add_12[0][0]
__________________________________________________________________________________________________
conv2d_37 (Conv2D)              (None, 9, 9, 64)     36928       conv2d_36[0][0]
__________________________________________________________________________________________________
add_13 (Add)                    (None, 9, 9, 64)     0           conv2d_37[0][0]
                                                                 add_12[0][0]
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 7, 7, 64)     36928       add_13[0][0]
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 64)           0           conv2d_38[0][0]
__________________________________________________________________________________________________
dense_40 (Dense)                (None, 256)          16640       global_average_pooling2d_1[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 256)          0           dense_40[0][0]
__________________________________________________________________________________________________
dense_41 (Dense)                (None, 10)           2570        dropout_2[0][0]
==================================================================================================
Total params: 223,242
Trainable params: 223,242
Non-trainable params: 0
__________________________________________________________________________________________________
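The shapes and parameter counts in the summary follow from simple arithmetic: a Conv2D layer with a k x k kernel has k * k * C_in * C_out weights plus C_out biases, a Dense layer has n_in * n_out weights plus n_out biases, and "valid" convolutions or pooling shrink the spatial size. The plain-Python sketch below reproduces the totals; the helper names are illustrative only and are not Keras functions.

```python
# Illustrative helpers (not Keras APIs) that reproduce the numbers in the summary above.
def conv2d_params(kernel, in_ch, out_ch):
    # One kernel x kernel x in_ch filter per output channel, plus one bias per filter.
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_units, out_units):
    # Weight matrix plus one bias per output unit.
    return in_units * out_units + out_units


def valid_conv_size(size, kernel):
    # Stride-1 convolution with padding="valid" (the Conv2D default).
    return size - kernel + 1


def pool_output_size(size, pool):
    # MaxPooling2D(pool) uses strides=pool and padding="valid" by default.
    return (size - pool) // pool + 1


params = [
    conv2d_params(3, 3, 32),            # conv2d_32
    conv2d_params(3, 32, 64),           # conv2d_33
    *([conv2d_params(3, 64, 64)] * 5),  # conv2d_34 to conv2d_38
    dense_params(64, 256),              # dense_40
    dense_params(256, 10),              # dense_41
]
total = sum(params)
print(f"{total:,}")  # 223,242

spatial = valid_conv_size(32, 3)        # conv2d_32: 30
spatial = valid_conv_size(spatial, 3)   # conv2d_33: 28
spatial = pool_output_size(spatial, 3)  # max_pooling2d_8: 9
print(spatial, valid_conv_size(spatial, 3))  # 9 7 (conv2d_38 uses "valid" padding)
```

The Add, pooling, and Dropout layers contribute no parameters, which is why they show 0 in the Param # column.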
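Explanation
- In addition to models with multiple inputs and outputs, the functional API makes it easy to work with non-linear connectivity topologies.
- This model has non-sequential connections (the outputs of earlier blocks feed directly into `layers.add`), so it cannot be built with the `Sequential` API.
- This is where residual connections come in: the input of a block (for example `block_1_output`) is added to that block's output, which requires both tensors to have the same shape.
- A toy ResNet model for CIFAR10 was built above to demonstrate this.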
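Blocks 2 and 3 in the example repeat the same pattern, so one way to avoid the duplication is to wrap it in a function. The sketch below uses a hypothetical helper, `residual_block` (not part of Keras or the original article), and rebuilds the same toy_resnet architecture; since the layer configuration is unchanged, the parameter count should again be 223,242, although the auto-generated layer names will differ.

```python
from tensorflow import keras
from tensorflow.keras import layers


def residual_block(x, filters=64):
    """Hypothetical helper: two same-padded 3x3 convolutions plus an identity shortcut."""
    shortcut = x
    x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
    x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
    # add() needs the shortcut and the conv output to have identical shapes,
    # so `filters` must match the channel count of the incoming tensor.
    return layers.add([x, shortcut])


inputs = keras.Input(shape=(32, 32, 3), name="img")
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = residual_block(x)  # block 2
x = residual_block(x)  # block 3
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10)(x)
model = keras.Model(inputs, outputs, name="toy_resnet")
print(model.count_params())
```

Writing blocks as functions like this keeps deeper residual networks readable: stacking more blocks is a matter of calling the helper again, as long as each block's input and output shapes match.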