如何使用功能API在Python中处理残差连接?

如何使用功能API在Python中处理残差连接?

Keras存在于Tensorflow软件包中。使用以下代码可以访问它。

import tensorflow
from tensorflow import keras

使用功能API创建的模型比使用顺序API创建的模型更具灵活性。功能API可以处理具有非线性拓扑结构的模型,并且可以共享层并处理多个输入和输出。深度学习模型通常是包含多个层的有向无环图(DAG)。功能API有助于构建图层的图。

我们使用Google Colaboratory来运行以下代码。Google Colab或Colaboratory可在浏览器上运行Python代码,不需要任何配置,并免费使用GPU(图形处理器)。Colaboratory是基于Jupyter Notebook构建的。以下是代码片段。

更多Python相关文章,请阅读:Python 教程

示例

print("用于CIFAR10的玩具ResNet模型")
print("为模型生成的图层")
inputs = keras.Input(shape=(32, 32, 3), name="img")
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
block_1_output = layers.MaxPooling2D(3)(x)

x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_2_output = layers.add([x, block_1_output])

x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_3_output = layers.add([x, block_2_output])

x = layers.Conv2D(64, 3, activation="relu")(block_3_output)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10)(x)

model = keras.Model(inputs, outputs, name="toy_resnet")
print("有关模型的更多信息")
model.summary()

代码来源 – https://www.tensorflow.org/guide/keras/functional

输出

CIFAR10的玩具ResNet模型
生成的模型层
有关模型的更多信息
模型:"toy_resnet"
________________________________________________________________________________
__________________
层(类型)          输出形状          参数 #       连接到
================================================================================
==================
img(InputLayer)       [(None, 32, 32, 3)]    0
________________________________________________________________________________
__________________
conv2d_32(Conv2D)    (None,30,30,32)     896          img [0] [0]
________________________________________________________________________________
__________________
conv2d_33(Conv2D)    (None,28,28,64)    18496         conv2d_32 [0] [0]
________________________________________________________________________________
__________________
max_pooling2d_8(MaxPooling2D)(None,9,9,64)0          conv2d_33 [0] [0]
________________________________________________________________________________
__________________
conv2d_34(Conv2D)       (None,9,9,64)       36928       max_pooling2d_8 [0,0 ]
________________________________________________________________________________
__________________
conv2d_35(Conv2D)       (None,9,9,64)       36928       conv2d_34 [0,0]
________________________________________________________________________________
__________________
add_12(Add)             (None,9,9,64)          0       conv2d_35 [0,0]
                                          max_pooling2d_8 [0,0]
________________________________________________________________________________
__________________
conv2d_36(Conv2D)          (None,9,9,64)    36928       add_12 [0,0]
________________________________________________________________________________
__________________
conv2d_37(Conv2D)          (None,9,9,64)    36928       conv2d_36 [0,0]
________________________________________________________________________________
__________________
add_13(Add)               (None,9,9,64)       0       conv2d_37 [0,0]
                                      add_12 [0,0]
________________________________________________________________________________
__________________
conv2d_38(Conv2D)          (None,7,7,64)    36928       add_13 [0,0]
________________________________________________________________________________
__________________
global_average_pooling2d_1    (Glo(None,64)      0       conv2d_38 [0,0]
________________________________________________________________________________
__________________
dense_40(Dense)             (None,256)          16640    global_average_pooling2d_1 [0,0]
________________________________________________________________________________
__________________
dropout_2(Dropout)          (None,256)          0          dense_40 [0,0]
________________________________________________________________________________
__________________
dense_41(Dense)             (None,10)          2570       dropout_2 [0,0]
================================================================================
==================
总参数:223,242
可训练参数:223,242
不可训练的参数:0
________________________________________________________________________________
__________________

解释

  • 该模型具有多个输入和输出。

  • 函数式API使得在非线性连接拓扑上操作变得容易。

  • 该模型具有非顺序连接,因此无法使用’Sequential’ API。

  • 这就是残差连接发挥作用的地方。

  • 构建了一个使用CIFAR10的示例ResNet模型来演示相同的操作。

Python教程

Java教程

Web教程

数据库教程

图形图像教程

大数据教程

开发工具教程

计算机教程