为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

大多数初学者tensorflow教程向读者介绍了feed_dict将数据加载到机器学习模型中的方法,其中数据通过tf.Session.run()或tf.Tensor.eval()函数调用传递给tensorflow 。然而,使用tf.dataAPI,您只需几行代码即可创建高性能数据管道。

在一个简单的feed_dict管道中,每当GPU必须等待CPU为它提供下一批数据时,GPU总是闲置着。

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

tf.data管道可以异步地预取下一批次以使总空闲时间最小化。通过并行加载和预处理操作,可以进一步加快管道的速度。

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

实现最小的图像管道

要构建一个简单的数据管道,您需要两个对象。tf.data.Dataset存储您的机器学习数据集,tf.data.Iterator允许您逐个从机器学习数据集中提取项目。

tf.data.Dataset用于图像管道可以(示意图)是这样的:

[

[Tensor(image), Tensor(label)],

[Tensor(image), Tensor(label)],

...

]

然后,您可以使用tf.data.Iterator逐个检索图像标签对。在实践中,多个图像标签对将被一起批处理,以便迭代器一次抽出一整批。数据集可以从源(比如Python中的文件名列表)创建,也可以通过对现有数据集应用转换创建。以下是可能的转换的一些示例:

  • Dataset(list of image files) → Dataset(actual images)
  • Dataset(6400 images) → Dataset(64 batches with 100 images each)
  • Dataset(list of audio files) → Dataset(shuffled list of audio files)

定义计算图

图像的最小数据管道可能如下所示:

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

以下所有代码都与模型,损失,优化器一起放在您的计算图定义中......首先从文件列表中创建一个张量。

define list of files
files = ['a.png', 'b.png', 'c.png', 'd.png']
 
# create a dataset from filenames
dataset = tf.data.Dataset.from_tensor_slices(files)

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

现在定义一个函数从图像的路径加载图像(作为张量),并使用tf.data.Dataset.map()将该函数应用于机器学习数据集中的所有元素(文件路径)。您还可以向map()添加一个num_parallel_calls=n参数来并行化函数调用。

def load_image(path):
 image_string = tf.read_file(path)
 # Don't use tf.image.decode_image, or the output shape will be undefined
 image = tf.image.decode_jpeg(image_string, channels=3)
 # This will convert to float values in [0, 1]
 image = tf.image.convert_image_dtype(image, tf.float32)
 image = tf.image.resize_images(image, [image_size, image_size])
 return image
# Apply the function load_image to each filename in the dataset
dataset = dataset.map(load_image, num_parallel_calls=8)

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

接下来用于tf.data.Dataset.batch()创建批次:

# Create batches of 64 images each
dataset = dataset.batch(64)

您可能还希望将tf.data.Dataset.prefetch(buffer_size)添加到管道的末尾。这可以确保GPU总是可以立即获得下一批,并如前所述减少GPU的空闲。buffer_size是应该预取的批次数。buffer_size = 1通常是足够的,但在某些情况下,尤其当每批处理时间不同的时候,可以增加它的数量。

dataset = dataset.prefetch(buffer_size=1)

最后,创建一个迭代器以允许我们遍历机器学习数据集。有不同类型的迭代器可用。对于大多数用途,建议使用可initializable 迭代器。

iterator = dataset.make_initializable_iterator()

现在调用tf.data.Iterator.get_next()创建占位符张量,每次评估时,tensorflow 都会填充下一批图像。

batch_of_images = iterator.get_next()

如果您要切换feed_dict,请batch_of_images替换以前的占位符变量。

运行会话

现在像往常一样运行你的机器学习模型,但确保在每个epoch之前评估iterator.initializerop并在每个epoch之后捕获tf.errors.OutOfRangeError异常。

with tf.Session() as session:
 
 for i in range(epochs): 
 session.run(iterator.initializer)
 
 try:
 # Go through the entire dataset
 while True:
 image_batch = session.run(batch_of_images)
 
 except tf.errors.OutOfRangeError:
 print('End of Epoch.')

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

程序nvidia-smi允许您监视GPU利用率,并可以帮助您了解数据管道中的瓶颈。GPU的平均利用率通常应高于70-80%。

更完整的数据管道

Shuffle

使用tf.data.Dataset.shuffle() 进行Shuffle文件名。参数指定一次应该shuffled的元素数量。通常,建议立即对整个列表进行Shuffle。

dataset = tf.data.Dataset.from_tensor_slices(files) 
dataset = dataset.shuffle(len(files))

数据增强

您可以使用的函数tf.image.random_flip_left_right(),tf.image.random_brightness(),tf.image.random_saturation()对你的图片进行简单的数据增强。

# Source
def train_preprocess(image):
 image = tf.image.random_flip_left_right(image)
 image = tf.image.random_brightness(image, max_delta=32.0 / 255.0)
 image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
 # Make sure the image is still in [0, 1]
 image = tf.clip_by_value(image, 0.0, 1.0)
 return image

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

标签

要在图像上加载标签(或其他元数据),只需在创建初始数据集时包含它们:

# files is a python list of image filenames
# labels is a numpy array with label data for each image
dataset = tf.data.Dataset.from_tensor_slices((files, labels))

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

确保使用.map()应用于数据集的任何函数都允许标签数据通过:

def load_image(path, label): 
 # load image
 ...
 return image, label
dataset = dataset.map(load_image)

为什么tf.data比feed_dict好得多,及如何构建一个简单的数据管道

相关推荐