『tensorflow』ssd源码学习_其八:网络训练

hellcat hellcat     2022-12-16     731

关键词:

Fork版本项目地址:SSD

作者使用了分布式训练的写法,这使得训练部分代码异常臃肿,我给出了部分注释。我对于多机分布式并不很熟,而且不是重点,所以不过多介绍,简单的给出一点训练中作者的优化手段,包含优化器选择之类的。

一、滑动平均

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

二、学习率衰减

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(FLAGS,
                                                             dataset.num_samples,
                                                             global_step)

细节实现函数,有三种形式,一种是常数学习率,两种不同的衰减方式(默认参数:exponential):

def configure_learning_rate(flags, num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Args:
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.
    Returns:
      A `Tensor` representing the learning rate.
    """
    decay_steps = int(num_samples_per_epoch / flags.batch_size *
                      flags.num_epochs_per_decay)

    if flags.learning_rate_decay_type == ‘exponential‘:
        return tf.train.exponential_decay(flags.learning_rate,
                                          global_step,
                                          decay_steps,
                                          flags.learning_rate_decay_factor,
                                          staircase=True,
                                          name=‘exponential_decay_learning_rate‘)
    elif flags.learning_rate_decay_type == ‘fixed‘:
        return tf.constant(flags.learning_rate, name=‘fixed_learning_rate‘)
    elif flags.learning_rate_decay_type == ‘polynomial‘:
        return tf.train.polynomial_decay(flags.learning_rate,
                                         global_step,
                                         decay_steps,
                                         flags.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name=‘polynomial_decay_learning_rate‘)

三、优化器选择

optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)

选择很丰富(默认参数:rmsprop):

def configure_optimizer(flags, learning_rate):
    """Configures the optimizer used for training.

    Args:
      learning_rate: A scalar or `Tensor` learning rate.
    Returns:
      An instance of an optimizer.
    """
    if flags.optimizer == ‘adadelta‘:
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate,
            rho=flags.adadelta_rho,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == ‘adagrad‘:
        optimizer = tf.train.AdagradOptimizer(
            learning_rate,
            initial_accumulator_value=flags.adagrad_initial_accumulator_value)
    elif flags.optimizer == ‘adam‘:
        optimizer = tf.train.AdamOptimizer(
            learning_rate,
            beta1=flags.adam_beta1,
            beta2=flags.adam_beta2,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == ‘ftrl‘:
        optimizer = tf.train.FtrlOptimizer(
            learning_rate,
            learning_rate_power=flags.ftrl_learning_rate_power,
            initial_accumulator_value=flags.ftrl_initial_accumulator_value,
            l1_regularization_strength=flags.ftrl_l1,
            l2_regularization_strength=flags.ftrl_l2)
    elif flags.optimizer == ‘momentum‘:
        optimizer = tf.train.MomentumOptimizer(
            learning_rate,
            momentum=flags.momentum,
            name=‘Momentum‘)
    elif flags.optimizer == ‘rmsprop‘:
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=flags.rmsprop_decay,
            momentum=flags.rmsprop_momentum,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == ‘sgd‘:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise ValueError(‘Optimizer [%s] was not recognized‘, flags.optimizer)
    return optimizer

四、训练

实际上中间有好一段分布式梯度计算过程,这里不多介绍,大概就是在各个clone上计算出梯度,汇总梯度,再优化各个clone网络,将优化节点提出作为train_tensor等等。

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
config = tf.ConfigProto(log_device_placement=False,
                        gpu_options=gpu_options)
saver = tf.train.Saver(max_to_keep=5,
                       keep_checkpoint_every_n_hours=1.0,
                       write_version=2,
                       pad_step_number=False)
slim.learning.train(
    train_tensor,
    logdir=FLAGS.train_dir,
    master=‘‘,
    is_chief=True,
    init_fn=tf_utils.get_init_fn(FLAGS),            # 看函数实现就明白了,assign变量用
    summary_op=summary_op,                          # tf.summary.merge节点
    number_of_steps=FLAGS.max_number_of_steps,      # 训练step
    log_every_n_steps=FLAGS.log_every_n_steps,      # 每次model保存step间隔
    save_summaries_secs=FLAGS.save_summaries_secs,  # 每次summary时间间隔
    saver=saver,                                    # tf.train.Saver节点
    save_interval_secs=FLAGS.save_interval_secs,
    session_config=config,                          # sess参数
    sync_optimizer=None)

其中调用的初始化函数如下:

def get_init_fn(flags):
    """Returns a function run by the chief worker to warm-start the training.
    Note that the init_fn is only run when initializing the model during the very
    first global step.

    Returns:
      An init function run by the supervisor.
    """
    if flags.checkpoint_path is None:
        return None
    # Warn the user if a checkpoint exists in the train_dir. Then ignore.
    if tf.train.latest_checkpoint(flags.train_dir):
        tf.logging.info(
            ‘Ignoring --checkpoint_path because a checkpoint already exists in %s‘
            % flags.train_dir)
        return None

    exclusions = []
    if flags.checkpoint_exclude_scopes:
        exclusions = [scope.strip()
                      for scope in flags.checkpoint_exclude_scopes.split(‘,‘)]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)
    # Change model scope if necessary.
    if flags.checkpoint_model_scope is not None:
        variables_to_restore =             var.op.name.replace(flags.model_name,
                                 flags.checkpoint_model_scope): var
             for var in variables_to_restore


    if tf.gfile.IsDirectory(flags.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
    else:
        checkpoint_path = flags.checkpoint_path
    tf.logging.info(‘Fine-tuning from %s. Ignoring missing vars: %s‘ % (checkpoint_path, flags.ignore_missing_vars))

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=flags.ignore_missing_vars)

至此,SSD项目介绍完毕。

如何使用训练好模型见集智专栏的文章最后一部分。

《tensorflow实战google深度学习框架》pdf一套四本+源代码_高清_完整

 TensorFlow实战热门Tensorflow实战书籍PDF高清版一套共四本+源代码,包含《Tensorflow实战》、《Tensorflow:实战Google深度学习框架(完整版)》、《TensorFlow:实战Google深度学习框架(第2版)》与《TensorFlow技术解析与实战》,不能错过的... 查看详情

如何使用tensorflow提供的models训练人脸识别的网络?(代码片段)

如何使用TensorFlow提供的models训练人脸识别的网络?前提是我们已经有train和test的record数据格式文件修改使用的网络,使其为我所用文件路径:models/research/object_detection/samples/configs/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_face.co 查看详情

深度学习(08)_rnn-lstm循环神经网络-03-tensorflow进阶实现(代码片段)

...;点击这里查看本文个人博客地址:点击这里查看关于Tensorflow实现一个简单的二元序列的例子可以点击这里查看关于RNN和LSTM的基础可以查看这里这篇博客主要包含以下内容训练一个RNN模型逐字符生成文本数据(最后的部分)使... 查看详情

『tensorflow』第四弹_classification分类学习_拨云见日

...见mnist......),不过这个分类器实现很底层,能学到不少tensorflow的用法习惯,语言不好描述,值得注意的地方使用大箭头标注了(原来tensorflow的向前传播的调用是这么实现的...之前自己好蠢):1importtensorflowastf2fromtensorflow.example... 查看详情

(转)干货|这篇tensorflow实例教程文章告诉你gans为何引爆机器学习?(附源码)

干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码) 该博客来源自:https://mp.weixin.qq.com/s?__biz=MzA4NzE1NzYyMw==&mid=2247492203&idx=5&sn=3020c3a43bd4dd678782d8aa24996745&chksm=903f1c73a7489565 查看详情

深度学习(07)_rnn-循环神经网络-02-tensorflow中的实现(代码片段)

关于基本的RNN和LSTM的概念和BPTT算法可以查看这里本文个人博客地址:点击这里参考文章:https://r2rt.com/recurrent-neural-networks-in-tensorflow-i.htmlhttps://r2rt.com/styles-of-truncated-backpropagation.html一、源代码实现一个binary例子1、 查看详情

深度学习tensorflow2.0:如何用keras构建自己的网络层?(代码片段)

...层 ?from__future__importabsolute_import,division,print_functionimporttensorflowastftf.keras.backend.clear_session()importtensorflow.kerasaskerasimporttensorflow.keras.layersaslayers#定义网络层就是:设置网络权重和输出到输入的计算过程classMyLayer(layers.Layer):def__init_... 查看详情

基于深度学习的天气识别算法对比研究-tensorflow实现-卷积神经网络(cnn)|第1例(内附源码+数据)(代码片段)

...7;我的环境:语言环境:Python3深度学习环境:TensorFlow2🥂相关教程:编译器教程:新手入门深度学习|1-2:编译器JupyterNotebook深度学习环境配置教程:新手入门深度学习|1-1:配置深度学习环境一... 查看详情

tensorflow学习教程------利用卷积神经网络对mnist数据集进行分类_训练模型(代码片段)

原理就不多讲了,直接上代码,有详细注释。#coding:utf-8importtensorflowastffromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets(‘MNIST_data‘,one_hot=True)#每个批次的大小batch_size=100n_batch=mnis 查看详情

tensorflow+keras深度学习人工智能实践应用_林大贵

 编辑推荐浅入深地讲解Keras与TensorFlow深度学习类神经网络使用实际的数据集配合范例程序代码介绍各种深度学习算法,并示范如何进行数据预处理、训练数据、建立模型和预测结果 内容简介本书提供安装、上机操作指南... 查看详情

如何修改 Tensorflow 2.0 中的 epoch 数?

】如何修改Tensorflow2.0中的epoch数?【英文标题】:HowtomodifynumberofepochsinTensorflow2.0?【发布时间】:2021-10-2803:23:36【问题描述】:我正在构建一个模型,仅通过以下方式检测汽车:方法:Tensorflow2.0和迁移学习预训练模型:ssd_mobilene... 查看详情

ssd训练数据集流程(学习记录)(代码片段)

...6.13+pytorch1.10.2+cuda11.62.训练步骤(1)下载SSD源码可到github进行下载GitHub-amdegroot/ssd.pytorch:APyTorchImplementationofSingleShotMultiBoxDetector(2)下载模型文件VGG16_reducedfc.pth预训练模型下载地址:https://s3.amazonaws.com/am... 查看详情

tensorflow源码学习之一--tf架构

目前总共4篇:tensorflow源码学习之一--tf架构tensorflow源码学习之二--clienttensorflow源码学习之三--mastertensorflow源码学习之四--worker  tensorflow的架构图如下:我们主要使用的是Pythonclient+CAPI+Distributedmaster+Dataflowexecutor(worker&am 查看详情

tensorflow学习笔记:神经网络优化(代码片段)

本文是个人的学习笔记,是跟随北大曹健老师的视频课学习的附:bilibili课程链接和MOOC课程链接以及源码下载链接(提取码:mocm)文章目录一、预备知识1.条件判断和选择(where)2.[0,1)随机数(rando... 查看详情

Tensorflow 图节点是交换的

】Tensorflow图节点是交换的【英文标题】:Tensorflowgraphnodesareexchange【发布时间】:2020-07-1712:24:56【问题描述】:我已经使用微调预训练模型ssd_mobilenet_v2_coco_2018训练了一个模型。在这里,我使用了完全相同的pipeline.config文件进行... 查看详情

21-神经网络模型_超参数搜索(tensorflow系列)(深度学习)(代码片段)

...earning_rate)sklearn_model=KerasRegressor(build_fn=build_model)fromtensorflow.keras.wrappers.scikit_learnimportKerasRegressor#回归神经网络#搜索最佳学习率defbuild_model(hidden_layers=1,layer_size=30,learning_rate=3e-3):model=keras.models.Sequ 查看详情

入门级_tensorflow网络搭建

Tensorflow如何搭建神经网络1.基本概念基于Tensorflow的神经网络:用张量表示数据,用计算图搭建神经网络,用会话执行计算图,优化线上的权重(参数),得到模型张量:张量就是多维数据list,用阶来表示张量的维度0阶张量称作标量... 查看详情

在树莓派 (tensorflow) 上进行对象检测的项目

】在树莓派(tensorflow)上进行对象检测的项目【英文标题】:Projectwithobjectdetectiononraspberrypi(tensorflow)【发布时间】:2020-12-1217:06:34【问题描述】:我正在做一个包含对象检测AI的研究项目,能够通过网络摄像头检测7类对象。使用goo... 查看详情