为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道
大多数初学者tensorflow教程向读者介绍了feed_dict将数据加载到机器学习模型中的方法,其中数据通过tf.Session.run()或tf.Tensor.eval()函数调用传递给tensorflow 。然而,使用tf.dataAPI,您只需几行代码即可创建高性能数据管道。
在一个简单的feed_dict管道中,每当GPU必须等待CPU为它提供下一批数据时,GPU总是闲置着。
tf.data管道可以异步地预取下一批次以使总空闲时间最小化。通过并行加载和预处理操作,可以进一步加快管道的速度。
实现最小的图像管道
要构建一个简单的数据管道,您需要两个对象。tf.data.Dataset存储您的机器学习数据集,tf.data.Iterator允许您逐个从机器学习数据集中提取项目。
tf.data.Dataset用于图像管道可以(示意图)是这样的:
[
[Tensor(image), Tensor(label)],
[Tensor(image), Tensor(label)],
...
]
然后,您可以使用tf.data.Iterator逐个检索图像标签对。在实践中,多个图像标签对将被一起批处理,以便迭代器一次抽出一整批。数据集可以从源(比如Python中的文件名列表)创建,也可以通过对现有数据集应用转换创建。以下是可能的转换的一些示例:
- Dataset(list of image files) → Dataset(actual images)
- Dataset(6400 images) → Dataset(64 batches with 100 images each)
- Dataset(list of audio files) → Dataset(shuffled list of audio files)
定义计算图
图像的最小数据管道可能如下所示:
以下所有代码都与模型,损失,优化器一起放在您的计算图定义中......首先从文件列表中创建一个张量。
define list of files files = ['a.png', 'b.png', 'c.png', 'd.png'] # create a dataset from filenames dataset = tf.data.Dataset.from_tensor_slices(files)
现在定义一个函数从图像的路径加载图像(作为张量),并使用tf.data.Dataset.map()将该函数应用于机器学习数据集中的所有元素(文件路径)。您还可以向map()添加一个num_parallel_calls=n参数来并行化函数调用。
def load_image(path): image_string = tf.read_file(path) # Don't use tf.image.decode_image, or the output shape will be undefined image = tf.image.decode_jpeg(image_string, channels=3) # This will convert to float values in [0, 1] image = tf.image.convert_image_dtype(image, tf.float32) image = tf.image.resize_images(image, [image_size, image_size]) return image # Apply the function load_image to each filename in the dataset dataset = dataset.map(load_image, num_parallel_calls=8)
接下来用于tf.data.Dataset.batch()创建批次:
# Create batches of 64 images each dataset = dataset.batch(64)
您可能还希望将tf.data.Dataset.prefetch(buffer_size)添加到管道的末尾。这可以确保GPU总是可以立即获得下一批,并如前所述减少GPU的空闲。buffer_size是应该预取的批次数。buffer_size = 1通常是足够的,但在某些情况下,尤其当每批处理时间不同的时候,可以增加它的数量。
dataset = dataset.prefetch(buffer_size=1)
最后,创建一个迭代器以允许我们遍历机器学习数据集。有不同类型的迭代器可用。对于大多数用途,建议使用可initializable 迭代器。
iterator = dataset.make_initializable_iterator()
现在调用tf.data.Iterator.get_next()创建占位符张量,每次评估时,tensorflow 都会填充下一批图像。
batch_of_images = iterator.get_next()
如果您要切换feed_dict,请batch_of_images替换以前的占位符变量。
运行会话
现在像往常一样运行你的机器学习模型,但确保在每个epoch之前评估iterator.initializerop并在每个epoch之后捕获tf.errors.OutOfRangeError异常。
with tf.Session() as session: for i in range(epochs): session.run(iterator.initializer) try: # Go through the entire dataset while True: image_batch = session.run(batch_of_images) except tf.errors.OutOfRangeError: print('End of Epoch.')
程序nvidia-smi允许您监视GPU利用率,并可以帮助您了解数据管道中的瓶颈。GPU的平均利用率通常应高于70-80%。
更完整的数据管道
Shuffle
使用tf.data.Dataset.shuffle() 进行Shuffle文件名。参数指定一次应该shuffled的元素数量。通常,建议立即对整个列表进行Shuffle。
dataset = tf.data.Dataset.from_tensor_slices(files) dataset = dataset.shuffle(len(files))
数据增强
您可以使用的函数tf.image.random_flip_left_right(),tf.image.random_brightness(),tf.image.random_saturation()对你的图片进行简单的数据增强。
# Source def train_preprocess(image): image = tf.image.random_flip_left_right(image) image = tf.image.random_brightness(image, max_delta=32.0 / 255.0) image = tf.image.random_saturation(image, lower=0.5, upper=1.5) # Make sure the image is still in [0, 1] image = tf.clip_by_value(image, 0.0, 1.0) return image
标签
要在图像上加载标签(或其他元数据),只需在创建初始数据集时包含它们:
# files is a python list of image filenames # labels is a numpy array with label data for each image dataset = tf.data.Dataset.from_tensor_slices((files, labels))
确保使用.map()应用于数据集的任何函数都允许标签数据通过:
def load_image(path, label): # load image ... return image, label dataset = dataset.map(load_image)