tensorflow Dataset及TFRecord一些要点【持续更新】

关于tensorflow结合Dataset与TFRecord这方面看到挺好一篇文章:

https://cloud.tencent.com/developer/article/1088751

github:

https://github.com/YJango/TFRecord-Dataset-Estimator-API/blob/master/TensorFlow%20Dataset%20%2B%20TFRecords.ipynb

dataset要点:

一般先shuffle,然后batch,再repeat

当然可以先repeat再batch,这样做与前面一个的区别就是最后一个batch是会有一部分之前batch里出现过的数据,在测试集上这么做要谨慎。

 

dataset的one_shot_iterator和make_initializable_iterator

看到stackoverflow上有个问题:

https://stackoverflow.com/questions/48091693/tensorflow-dataset-api-diff-between-make-initializable-iterator-and-make-one-sho

其实我觉得题主想问的不是答主说的意思,实际上one-shot迭代器只能迭代一轮,initializable迭代器可以迭代多轮(通过sess.run(iterator.initializer)来实现。那么这两种迭代器的区别就非常明显了,换句话说,题主的第二个代码片段,除了第0个epoch之外,剩下的所有epoch都没有数据,会报错。

可以看下面这段代码来测试:

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2))).shuffle(100).batch(2)
iterator = dataset.make_initializable_iterator()
# iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:

    for i in range(5):
        sess.run(iterator.initializer)
        while True:
            try:
                print(sess.run(one_element))
            except tf.errors.OutOfRangeError:
                print("Epoch %s is done." % i)
                break

相关推荐