如何为Tensorflow配置花卉数据集以提高性能?
当创建一个模型时,花卉数据集会给出一个特定的准确率。如果需要为性能配置模型,则使用缓冲预取和Rescaling层。使用Keras模型将此层应用于数据集,将Rescaling层作为Keras模型的一部分。
更多Python相关文章,请阅读:Python 教程
我们将使用包含数千张花卉图片的花卉数据集。它包含5个子目录,每个类别都有一个子目录。
我们将使用Google Colaboratory来运行以下代码。Google Colab或Colaboratory可以在浏览器上运行Python代码,无需配置,可以免费访问GPU(图形处理器)。Colaboratory建立在Jupyter Notebook之上。
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
num_classes = 5
print("A sequential model is built")
model = tf.keras.Sequential([
layers.experimental.preprocessing.Rescaling(1./255),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
代码来源:https://www.tensorflow.org/tutorials/load_data/images
输出
A sequential model is built
说明
- 使用缓冲预取,以便从磁盘中产生数据而不阻塞I/O。
- 这是加载数据的重要步骤。
- ‘.cache()’方法可在第一轮执行后将图像保留在内存中。
- 这确保了数据集在训练模型时不会成为障碍。
- 如果数据集太大而无法适应内存,则可以使用相同的方法创建高性能的磁盘缓存。
- ‘.prefetch()’方法在数据训练时重叠数据预处理和模型执行操作。