用TensorFlow创建TFRecord格式的数据集

下面的代码会生成TFRecord文件,其中每张图片的大小为227*227*1;label是该图片所属类别的整数编号(类别索引),而不是类别的英文名。

原图片是256*256*3的RGB格式.jpg文件。制作数据集时对图片颜色没有要求,所以为了节省空间,先将图片缩放为227*227,再进行灰度化处理。

import tensorflow as tf
import os
import sys
from PIL import Image
import numpy as np

# Root folders of the raw images; each contains one sub-folder per class name.
TRAIN_DATASET_DIR = "E:/python文件/tensorflow_learn/MyNet/images/train/"
TEST_DATASET_DIR = "E:/python文件/tensorflow_learn/MyNet/images/test/"

# Folder that receives train.tfrecords and test.tfrecords.
TFRECORD_DIR = "E:/python文件/tensorflow_learn/MyNet/images/"

# Class names (also the names of the per-class image sub-folders).
classes = {
    "apple_scab",
    "black_rot",
    "cedar_apple_rust",
    "healthy",
}


# 判断tfrecord文件是否存在
def _dataset_exists(tfrecord_dir):
    for split_name in [‘train‘, ‘test‘]:
        # 产生test.tfrecords和 train.tfrecords文件路径
        output_filename = os.path.join(tfrecord_dir, split_name+‘.tfrecords‘)
        if not tf.gfile.Exists(output_filename):
            return False
    return True


def int64_feature(values):
    """Build a ``tf.train.Feature`` holding an int64 list.

    A single scalar is wrapped in a one-element list; a list or tuple is
    passed through unchanged.
    """
    value_list = values if isinstance(values, (tuple, list)) else [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))


def bytes_feature(values):
    """Build a ``tf.train.Feature`` holding a bytes list.

    Mirrors ``int64_feature``: a single byte string is wrapped in a
    one-element list, while a list or tuple of byte strings is used as-is.
    Previously a list argument was wrapped again (``[list]``), which
    ``tf.train.BytesList`` rejects.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


# 获取该类别的所有文件
def _get_filenames_and_classes(dataset_dir):
    photo_filename = []
    for filename in os.listdir(dataset_dir):
        # 获取文件路径
        path = os.path.join(dataset_dir, filename)
        photo_filename.append(path)
    return photo_filename


# 把数据转换为TFRecord格式
def _convert_dataset(split_name, dataset_dir):
    assert split_name in [‘train‘, ‘test‘]
    with tf.Session() as sess:
        output_filename = os.path.join(TFRECORD_DIR, split_name+‘.tfrecords‘)
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            for index, name in enumerate(classes):
                if split_name == ‘train‘:
                    class_path = TRAIN_DATASET_DIR + name + ‘/‘
                else:
                    class_path = TEST_DATASET_DIR + name + ‘/‘
                filenames = _get_filenames_and_classes(class_path)
                for i, img_name in enumerate(filenames):
                    sys.stdout.write(‘\r>>%s %s  Convering image: %d/%d‘ % (split_name, name, i+1, len(filenames)))
                    print(str(img_name))
                    sys.stdout.flush()
                    image_data = Image.open(img_name)
                    image_data = image_data.resize((227, 227))
                    image_data = np.array(image_data.convert(‘L‘))  # 图片灰度化处理
                    img_raw = image_data.tobytes()
                    example = tf.train.Example(
                        features=tf.train.Features(
                            feature={
                                ‘img_raw‘: tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                                ‘label‘: tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                            }
                        )
                    )
                    tfrecord_writer.write(example.SerializeToString())
            # tfrecord_writer.close()


# 判断tfrecord文件是否存在
if _dataset_exists(TFRECORD_DIR):
    print("文件已存在")
else:
    # 数据转换
    _convert_dataset(‘test‘, TEST_DATASET_DIR)
    _convert_dataset(‘train‘, TRAIN_DATASET_DIR)
    print(‘生成tfrecord文件!‘)

相关推荐