机器学习:用tfrecord数据集加速Keras

机器学习:用tfrecord数据集加速Keras

机器学习中tfrecord数据集基本上是您的数据集,保存为硬盘驱动器上的协议缓冲区。使用此格式的好处是:

  • 您无需将完整数据集加载到内存中。您可以通过数据集类逐个获取数据。当您的GPU处理数字时,Tensorflow会负责加载更多数据。
  • 由于您不需要先将数据加载到numpy数组中然后将其放入keras / tensorflow会话,因此速度极快。你只是保持C ++的端到端。

要构建自己的输入管道,您需要执行以下步骤。

  • 将数据集转换为TFRecord数据集并将其保存到磁盘。
  • 使用TFRecordDataset类加载此数据集
  • 将其放入入您的Kerasmodel。

转换为TFRecord数据集

创建机器学习数据集非常简单。您需要做的就是使用以下内容定义数据集,Python代码如下:

import tensorflow as tf

# Helperfunctions to make your feature definition more readable

def _int64_feature(value):

return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):

return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# load data in numpy

image,label = ReadFunctionToCreateYourDataAsNumpyArrays()

# create filewriter

writer = tf.python_io.TFRecordWriter(FILEPATH)

# Define the features of your tfrecord

feature = {'image': _bytes_feature(tf.compat.as_bytes(image.tostring())),

'label': _int64_feature(int(label))}

# Serialize to string and write to file

example = tf.train.Example(features=tf.train.Features(feature=feature))

writer.write(example.SerializeToString())

尝试创建小的数据集,这些数据集不大于你的RAM,但足够大,这样tfrecords的序列化会给你带来优势。“相对”小的文件大小可以让您在读取数据时改组数据,并在读取时执行其他很酷的技巧。

我尝试将文件大小设置为4-5 GB(在32 GB RAM Maschine上)。一个文件通常包括一个类别。这使我可以灵活地快速添加和删除类别。

加载数据集以进行训练

这里有一个小GIST,向您展示如何将新创建的tfrecord加载为tf数据集。Python代码如下:

import tensorflow as tf

def _parse_function(proto):

# define your tfrecord again. Remember that you saved your image as a string.

keys_to_features = {'image': tf.FixedLenFeature([], tf.string),

"label": tf.FixedLenFeature([], tf.int64)}

# Load one example

parsed_features = tf.parse_single_example(example_proto, keys_to_features)

# Turn your saved image string into an array

parsed_features['image'] = tf.decode_raw(

parsed_features['image'], tf.uint8)

return parsed_features['image'], parsed_features["label"]

def create_dataset(filepath):

# This works with arrays as well

dataset = tf.data.TFRecordDataset(filepath)

# Maps the parser on every filepath in the array. You can set the number of parallel loaders here

dataset = dataset.map(_parse_function, num_parallel_calls=8)

# This dataset will go on forever

dataset = dataset.repeat()

# Set the number of datapoints you want to load and shuffle

dataset = dataset.shuffle(SHUFFLE_BUFFER)

# Set the batchsize

dataset = dataset.batch(BATCH_SIZE)

# Create an iterator

iterator = dataset.make_one_shot_iterator()

# Create your tf representation of the iterator

image, label = iterator.get_next()

# Bring your picture back in shape

image = tf.reshape(image, [-1, 256, 256, 1])

# Create a one hot array for your labels

label = tf.one_hot(label, NUM_CLASSES)

return image, label

注意,您没有直接加载数据。你只是在建造一条管道。label和image变量只是张量,在以后的tensorflow会话中填充。

Ingest到Kerasmodel

最后一步非常简单。加载创建label和image张量,然后为keras模型创建一个输入层。这个层是通过传递图像创建的。在模型的编译过程中,您还可以以类似的模式交付target_tensors。

最棘手的部分是,Keras不知道一个Epoch需要多少steps。在加载期间,我们告诉tensorflow永远重复数据集。

我通常计算每个文件的数据点数量,并将它们写入.txt文件中。将数据样本的总数除以批处理的大小,可以得到每个epoch的步骤。Python代码如下:

import tensorflow as tf

from tensorflow.python import keras as keras

STEPS_PER_EPOCH= SUM_OF_ALL_DATASAMPLES / BATCHSIZE

#Get your datatensors

image, label = create_dataset(filenames_train)

#Combine it with keras

model_input = keras.layers.Input(tensor=image)

#Build your network

model_output = keras.layers.Flatten(input_shape=(-1, 255, 255, 1))(model_input)

model_output = keras.layers.Dense(1000, activation='relu')(model_output)

#Create your model

train_model = keras.models.Model(inputs=model_input, outputs=model_output)

#Compile your model

train_model.compile(optimizer=keras.optimizers.RMSprop(lr=0.0001),

loss='mean_squared_error',

metrics=[soft_acc],

target_tensors=[label])

#Train the model

train_model.fit(epochs=EPOCHS,

steps_per_epoch=STEPS_PER_EPOC)

#More Kerasstuff here

这就是要让你的神经网络在Keras上使用tfrecord数据集的所有工作。

最后

Tfrecords是改善和清理数据加载的好方法。它提高了数据读取的速度,使您可以将类保存在分类文件中。

甚至可以在读取时写入具有tensorflow的数据filters 或数据transformers (例如,翻转或旋转图像,添加噪声,blocking bad data)。这样,您可以获得原始数据在驱动器上不受影响的优势,并且您不必担心原始数据集的10个不同版本会占用您宝贵的数据空间。这给了您一个优势,您的原始数据在驱动器上保持不变,您不必担心10个不同版本的原始数据会侵蚀您宝贵的数据空间。

相关推荐