如何将 Tensorflow 数据集 API 与训练和验证集一起使用

     2023-03-12     203

关键词:

【中文标题】如何将 Tensorflow 数据集 API 与训练和验证集一起使用【英文标题】:How to use Tensorflow dataset API with training and validation sets 【发布时间】:2018-05-01 14:11:44 【问题描述】:

手头的简单任务:运行 N 个 epoch 的训练,在每个 epoch 之后计算准确的验证准确度。时期大小可以等于完整的训练集或一些预定义的迭代次数。在验证期间,每个验证集输入都必须被评估一次。

将 one_shot_iterators、可初始化迭代器和/或用于该任务的句柄混合在一起的最佳方法是什么?

这是我认为它应该如何工作的脚手架:

def build_training_dataset():
    pass

def build_validation_dataset():
    pass

def construct_train_op(dataset):
    pass

def magic(iterator):
    pass

USE_CUSTOM_EPOCH_SIZE = True
CUSTOM_EPOCH_SIZE = 60
MAX_EPOCHS = 100


training_dataset = build_training_dataset()
validation_dataset = build_validation_dataset()


# Magic goes here to build a nice one-instance dataset
dataset = magic(training_dataset, validation_dataset)

train_op = construct_train_op(dataset)

# Run N epochs in which the training dataset is traversed, followed by the
# validation dataset.
with tf.Session() as sess:
    for epoch in MAX_EPOCHS:

        # train
        if USE_CUSTOM_EPOCH_SIZE:
            for _ in range(CUSTOM_EPOCH_SIZE):
                sess.run(train_op)
        else:
            while True:
                # I guess smth like this:
                try:
                    sess.run(train_op)
                except tf.errors.OutOfRangeError:
                    break # we are done with the epoch

        # validation
        validation_predictions = []
        while True:
            try:
                np.append(validation_predictions, sess.run(train_op)) # but for validation this time
            except tf.errors.OutOfRangeError:
                print('epoch %d finished with accuracy: %f' % (epoch validation_predictions.mean()))
                break 

【问题讨论】:

【参考方案1】:

由于解决方案比我预期的要复杂得多,因此它分为 2 个和平:

0) 两个示例共享的辅助代码:

USE_CUSTOM_EPOCH_SIZE = True
CUSTOM_EPOCH_SIZE = 60
MAX_EPOCHS = 100

TRAIN_SIZE = 500
VALIDATION_SIZE = 145
BATCH_SIZE = 64


def construct_train_op(batch):
    return batch


def build_train_dataset():
    return tf.data.Dataset.range(TRAIN_SIZE) \
        .map(lambda x: x + tf.random_uniform([], -10, 10, tf.int64)) \
        .batch(BATCH_SIZE)

def build_test_dataset():
    return tf.data.Dataset.range(VALIDATION_SIZE) \
        .batch(BATCH_SIZE)

1) 对于等于训练数据集大小的 epoch:

# datasets construction
training_dataset = build_train_dataset()
validation_dataset = build_test_dataset()

# handle constructions. Handle allows us to feed data from different dataset by providing a parameter in feed_dict
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

train_op = construct_train_op(next_element)

training_iterator = training_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

with tf.Session() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

    for epoch in range(MAX_EPOCHS):
        #train
        sess.run(training_iterator.initializer)
        total_in_train = 0
        while True:
            try:
                train_output = sess.run(train_op, feed_dict=handle: training_handle)
                total_in_train += len(train_output)
            except tf.errors.OutOfRangeError:
                assert total_in_train == TRAIN_SIZE
                break # we are done with the epoch

        # validation
        validation_predictions = []
        sess.run(validation_iterator.initializer)
        while True:
            try:
                pred = sess.run(train_op, feed_dict=handle: validation_handle)
                validation_predictions = np.append(validation_predictions, pred)
            except tf.errors.OutOfRangeError:
                assert len(validation_predictions) == VALIDATION_SIZE
                print('Epoch %d finished with accuracy: %f' % (epoch, np.mean(validation_predictions)))
                break

2) 对于自定义 epoch 大小:

# datasets construction
training_dataset = build_train_dataset().repeat() # CHANGE 1
validation_dataset = build_test_dataset()

# handle constructions. Handle allows us to feed data from different dataset by providing a parameter in feed_dict
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()


train_op = construct_train_op(next_element)

training_iterator = training_dataset.make_one_shot_iterator() # CHANGE 2
validation_iterator = validation_dataset.make_initializable_iterator()

with tf.Session() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

    for epoch in range(MAX_EPOCHS):
        #train
        # CHANGE 3: no initiazation, not try/catch
        for _ in range(CUSTOM_EPOCH_SIZE): 
            train_output = sess.run(train_op, feed_dict=handle: training_handle)


        # validation
        validation_predictions = []
        sess.run(validation_iterator.initializer)
        while True:
            try:
                pred = sess.run(train_op, feed_dict=handle: validation_handle)
                validation_predictions = np.append(validation_predictions, pred)
            except tf.errors.OutOfRangeError:
                assert len(validation_predictions) == VALIDATION_SIZE
                print('Epoch %d finished with accuracy: %f' % (epoch, np.mean(validation_predictions)))
                break

【讨论】:

您确定需要在每个 epoch 重新初始化训练迭代器(方法 1)吗?

使用 TensorFlow 的数据集 API 的图像摘要

】使用TensorFlow的数据集API的图像摘要【英文标题】:ImagesummarieswithTensorFlow\'sDatasetAPI【发布时间】:2018-04-3022:05:17【问题描述】:我正在使用TensorFlow\'sDatasetAPI来加载和预处理图像。我想将预处理图像的摘要添加到Tensorboard。推... 查看详情

Tensorflow:如何查找 tf.data.Dataset API 对象的大小

】Tensorflow:如何查找tf.data.DatasetAPI对象的大小【英文标题】:Tensorflow:Howtofindthesizeofatf.data.DatasetAPIobject【发布时间】:2018-11-2721:14:57【问题描述】:我了解DatasetAPI是一种迭代器,它不会将整个数据集加载到内存中,因此无法找... 查看详情

如何在 TensorFlow 中使用我自己的数据将图像拆分为测试和训练集

】如何在TensorFlow中使用我自己的数据将图像拆分为测试和训练集【英文标题】:HowtosplitimagesintotestandtrainsetusingmyowndatainTensorFlow【发布时间】:2020-05-2414:54:50【问题描述】:我在这里有点困惑...我刚刚花了一个小时阅读有关如何... 查看详情

如何使用 TFRecord 数据集使 TensorFlow + Keras 快速运行?

】如何使用TFRecord数据集使TensorFlow+Keras快速运行?【英文标题】:HowdoyoumakeTensorFlow+KerasfastwithaTFRecorddataset?【发布时间】:2017-06-3007:20:09【问题描述】:什么是如何将TensorFlowTFRecord与Keras模型和tf.session.run()一起使用的示例,同时... 查看详情

tensorflow 的数据集 API 返回的大小不是恒定的

】tensorflow的数据集API返回的大小不是恒定的【英文标题】:returnedsizeoftensorflow\'sdatasetAPIisnotconstant【发布时间】:2017-10-1211:48:26【问题描述】:我正在使用tensorflow的datasetAPI。并用简单的案例测试我的代码。下面显示了我使用的... 查看详情

Tensorflow 数据集 API 中的过采样功能

】Tensorflow数据集API中的过采样功能【英文标题】:OversamplingfunctionalityinTensorflowdatasetAPI【发布时间】:2018-04-2411:46:19【问题描述】:我想问一下当前的数据集API是否允许实现过采样算法?我处理高度不平衡的班级问题。我在想,... 查看详情

如何从视频数据集制作数据集(首先是tensorflow)

】如何从视频数据集制作数据集(首先是tensorflow)【英文标题】:Howtomakeadatasetfromvideodatasets(tensorflowfirst)【发布时间】:2019-08-1614:27:42【问题描述】:每个人。现在我有一个对象分类任务,我有一个包含大量视频的数据集。在... 查看详情

从图像本地目录创建 tensorflow 数据集

】从图像本地目录创建tensorflow数据集【英文标题】:Createtensorflowdatasetfromimagelocaldirectory【发布时间】:2019-07-0213:47:39【问题描述】:我在本地有一个非常庞大的图像数据库,数据分布就像每个文件夹都包含一个类别的图像。我... 查看详情

Tensorflow 对象检测 API 中的过拟合

】Tensorflow对象检测API中的过拟合【英文标题】:OverfittinginTensorflowObjectdetectionAPI【发布时间】:2020-06-1623:11:03【问题描述】:我正在自定义数据集(即车牌数据集)上训练tensorflow对象检测API模型。我的目标是使用tensorflowlite将此... 查看详情

python使用tensorflow数据集api读取磁盘上的图像(代码片段)

查看详情

python数据集输入fn使用tensorflow1.4api。(代码片段)

查看详情

用 Python 生成的 Tensorflow 数据集在 Tensorflow Java API(标签图像)中有不同的读数

】用Python生成的Tensorflow数据集在TensorflowJavaAPI(标签图像)中有不同的读数【英文标题】:TensorflowdatasetproducedinPythonhasdifferentreadingsinTensorflowJavaAPI(LabelImage)【发布时间】:2019-04-0901:05:42【问题描述】:背景:我是Tensorflow和AI的新... 查看详情

使用 tensorflow 将数据集拆分为训练和测试

】使用tensorflow将数据集拆分为训练和测试【英文标题】:splitdatasetintotrainandtestusingtensorflow【发布时间】:2021-02-2701:39:28【问题描述】:我想将我的完整数据集(每个原始数据都有多个特征)拆分为训练集和测试集。除了使用scik... 查看详情

如何使用 tensorflow 数据集训练 sklearn 模型?

】如何使用tensorflow数据集训练sklearn模型?【英文标题】:Howtotrainsklearnmodelsusingtensorflowdataset?【发布时间】:2021-05-1301:31:02【问题描述】:我想知道是否可以使用Tensorflow数据集来训练scikit-learn和其他ML框架。那么,例如,我可以... 查看详情

如何正确地重新标记 TensorFlow 数据集?

】如何正确地重新标记TensorFlow数据集?【英文标题】:HowtoproperlyrelabelaTensorFlowdataset?【发布时间】:2022-01-0421:30:09【问题描述】:我目前正在使用TensorFlow处理CIFAR10数据集。由于各种原因,我需要通过预定义的规则更改标签,例... 查看详情

如何使用 tensorflow 数据集访问图像

】如何使用tensorflow数据集访问图像【英文标题】:Howtoaccessimagesusingtensorflowdataset【发布时间】:2021-12-1409:06:43【问题描述】:最近我从thispage下载了CelebA数据集。现在我想使用tensforflow_dataset包中的tfds.load函数访问它。我的zip文... 查看详情

如何在 TensorFlow 中对数据集进行切片?

】如何在TensorFlow中对数据集进行切片?【英文标题】:HowcanslicingdatasetinTensorFlow?【发布时间】:2020-12-0508:15:12【问题描述】:我想在tf.data中切片数据集。我的数据是这样的:dataset=tf.data.Dataset.from_tensor_slices([[0,1,2,3,4],[1,2,3,4,5],[2... 查看详情

在 Tensorflow 的 Dataset API 中,如何将一个元素映射到多个元素?

】在Tensorflow的DatasetAPI中,如何将一个元素映射到多个元素?【英文标题】:InTensorflow\'sDatasetAPIhowdoyoumaponeelementintomultipleelements?【发布时间】:2018-07-0610:00:30【问题描述】:在tensorflowDataset管道中,我想定义一个自定义映射函数... 查看详情