深入理解batchnorm的原理代码实现以及bn在cnn中的应用(代码片段)

白马金羁侠少年 白马金羁侠少年     2023-02-24     711

关键词:

深入理解BatchNorm的原理、代码实现以及BN在CNN中的应用

BatchNorm是算法岗面试中几乎必考题,本文将带你理解BatchNorm的原理和代码实现,以及详细介绍BatchNorm在CNN中的应用。

一、BatchNorm论文

论文题目:Batch Normalization: Accelerating Deep Network Training byReducing Internal Covariate Shift
论文地址:https://arxiv.org/pdf/1502.03167.pdf

BatchNorm伪代码如下:

二、BatchNorm代码

y = x − mean ⁡ ( x ) Var ⁡ ( x ) + e p s ∗ gamma ⁡ + beta ⁡ \\mathrmy=\\fracx-\\operatornamemean(x)\\sqrt\\operatornameVar(x)+e p s * \\operatornamegamma+\\operatornamebeta y=Var(x) +epsxmean(x)gamma+beta
根据数据维度的不同,PyTorch中的BatchNorm有不同的形式:

2.1 torch.nn.BatchNorm1d

官方文档:torch.nn.BatchNorm1d
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

  • 2D input: (mini_batch, num_feature),常见的结构化数据,如,房价预测问题中x的特征数有100个,torch.nn.BatchNorm1d(100)
  • 3D input: (mini_batch, num_feature, additional_channel),使用时 torch.nn.BatchNorm1d(num_feature),不过这种维度一般不常用

2.2 torch.nn.BatchNorm2d

官方文档:torch.nn.BatchNorm2d
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

  • 4D input: (mini_batch, num_feature_map, p, q),常用于CV的图像数据,如CIFAR10(3x32x32),torch.nn.BatchNorm2d(3)

2.3 BatchNorm层的参数γ,β和统计量

Batch Norm层有可学习的参数γ和β,以及统计量running mean和running var

  • (可学习参数)γ : weight of BatchNorm
  • (可学习参数)β : bias of BatchNorm
  • (统计量)running mean: 预测阶段会使用这个均值
  • (统计量)running var: 预测阶段会使用这个方差

pytorch中用state_dict()可以查看上面这些信息

print("--- 4D:(mini_batch, num_feature, p, q) ---")
m = nn.BatchNorm2d(3, momentum=0.1)  # 例如, CIFAR10数据集是三通道的,3x32x32
print(m.state_dict().keys())
# 输出:odict_keys(['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked'])

2.3.1 train模式

在pytorch中可以使用model.train()将BatchNorm层切换到train模式。

在train模式下参数γ和β会随着网络的反向传播进行梯度更新,而统计量running mean和running var则会用一种特定的方式进行更新。在Pytorch中的更新方式如下:

x ^ new  = ( 1 − \\hatx_\\text new =(1- x^new =(1 momentum ) × x ^ + ) \\times \\hatx+ )×x^+ momentum × x t \\times x_t ×xt

  • x ^ \\hatx x^: running mean or running variance
  • x t x_t xt: input mean and variance(训练时的第t个batch的均值和方差)
  • 默认momentum为0.1

2.3.2 eval模式

在pytorch中可以使用model.eval()将BatchNorm层切换到eval模式。

在eval模式下,我们的模型不可能再等到预测的样本数量达到一个batch时,再进行归一化,而是直接使用train模式得到的统计量running mean和running var进行归一化

2.4 代码:Pytorch实战演练

import torch
import torch.nn as nn

bs = 64

print("Pytorch Batch Norm Layer详解")
print("--- 2D input:(mini_batch, num_feature) ---")
# With Learnable Parameters
m = nn.BatchNorm1d(400)  # 例如,房价预测:x的特征数是400,y是房价
# Without Learnable Parameters(无学习参数γ和β)
# m = nn.BatchNorm1d(100, affine=False)
inputs = torch.randn(bs, 400)
print(m(inputs).shape)


print("Batch Norm层的γ和β是要训练学习的参数")
print("γ:", m.state_dict()['weight'].shape)  # gammar
print("β:", m.state_dict()['bias'].shape)  # beta
print("")


print("--- 3D input:(mini_batch, num_feature, other_channel) ---")
m = nn.BatchNorm1d(32)
inputs = torch.randn(bs, 32, 32)  # 这种格式的数据不常用
print(m(inputs).shape)

print("Batch Norm层的γ和β是要训练学习的参数")
print("γ:", m.state_dict()['weight'].shape)  # gammar
print("β:", m.state_dict()['bias'].shape)  # beta
print("")


print("--- 4D input:(mini_batch, num_feature, H, W) ---")
m = nn.BatchNorm2d(3)  # 例如, CIFAR10数据集是三通道的,3x32x32

inputs = torch.randn(bs, 3, 32, 32)
print(m(inputs).shape)

print("Batch Norm层的γ和β是要训练学习的参数")
print("γ:", m.state_dict()['weight'].shape)  # gammar
print("β:", m.state_dict()['bias'].shape)  # beta
print("Batch Norm层的running_mean和running_var是统计量(主要用于预测阶段)")
print("running_mean:", m.state_dict()['running_mean'].shape)
print("running_var:", m.state_dict()['running_var'].shape)

输出:

Pytorch Batch Norm Layer详解
--- 2D input:(mini_batch, num_feature) ---
torch.Size([64, 400])
Batch Norm层的γ和β是要训练学习的参数
γ: torch.Size([400])
β: torch.Size([400])

--- 3D input:(mini_batch, num_feature, other_channel) ---
torch.Size([64, 32, 32])
Batch Norm层的γ和β是要训练学习的参数
γ: torch.Size([32])
β: torch.Size([32])

--- 4D input:(mini_batch, num_feature, H, W) ---
torch.Size([64, 3, 32, 32])
Batch Norm层的γ和β是要训练学习的参数
γ: torch.Size([3])
β: torch.Size([3])
Batch Norm层的running_mean和running_var是统计量(主要用于预测阶段)
running_mean: torch.Size([3])
running_var: torch.Size([3])

三、BatchNorm在CNN中的应用

我们在第二部分的代码中发现,BatchNorm2d的参数γ和β数量是跟特征图的数量是一致的,并不是我们直观认为的num_feature*H*W个参数,这是为什么呢?

《百面机器学习》P221是这样解释的:
BatchNorm批量归一化在卷积神经网络中应用时,需要注意卷积神经网络的参数共享机制。每一个卷积核的参数在不同位置的神经元当中是共享的,因此同一个特征图的所有神经元也应该被一起归一化!

  • 换句话说就是,你一个特征图用的是共享的卷积核参数,所以这个特征图中的每个神经元(共H*W个)也应该共享参数 γ , β \\gamma, \\beta γ,β。如果有 f f f个卷积核,就对应 f f f个特征图和 f f f组不同的 γ \\gamma γ β \\beta β参数

下面的解释来自hjimce

  • 假如某一层卷积层有6个特征图,每个特征图的大小是100*100,这样就相当于这一层网络有6*100*100个神经元,如果采用BN,就会有6*100*100个参数γ、β,这样岂不是太恐怖了。因此卷积层上的BN使用,其实也是使用了类似权值共享的策略,把一整张特征图当做一个神经元进行处理。
  • 卷积神经网络经过卷积后得到的是一系列的特征图,如果min-batch sizes为m,那么网络某一层输入数据可以表示为四维矩阵(m,f,p,q),m为min-batch sizes,f为特征图个数,p、q分别为特征图的宽高。在cnn中我们可以把每个特征图看成是一个特征处理,因此在使用Batch Normalization,mini-batch size 的大小相当于m*p*q,于是对于每个特征图都只有一对可学习参数:γ、β。

3.1 图解:卷积神经网络中的BatchNorm

这里我特意画了一个图来让大家看清楚CNN中Batchnorm到底是怎么做的

总结来说:

  1. 对于某个特征图而言,一个batch共有m个这样的特征图,并且每个特征图有p*q个神经元,把所有的m*p*q个神经元拉直,然后求得平均值和方差。
  2. 对m个这样特征图的p*q个神经元的每个神经元,利用求出的平均值和方差做下数据变换。

下面是来自于Keras卷积层的BN实现的一小段主要源码:

# Keras BatchNorm
input_shape = self.input_shape
reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis]
if train:
    m = K.mean(X, axis=reduction_axes)
    brodcast_m = K.reshape(m, broadcast_shape)
    std = K.mean(K.square(X - brodcast_m) + self.epsilon, axis=reduction_axes)
    std = K.sqrt(std)
    brodcast_std = K.reshape(std, broadcast_shape)
    mean_update = self.momentum * self.running_mean + (1-self.momentum) * m
    std_update = self.momentum * self.running_std + (1-self.momentum) * std
    self.updates = [(self.running_mean, mean_update),
                    (self.running_std, std_update)]
    X_normed = (X - brodcast_m) / (brodcast_std + self.epsilon)
else:
    brodcast_m = K.reshape(self.running_mean, broadcast_shape)
    brodcast_std = K.reshape(self.running_std, broadcast_shape)
    X_normed = ((X - brodcast_m) /
                (brodcast_std + self.epsilon))
out = K.reshape(self.gamma, broadcast_shape) * X_normed + K.reshape(self.beta, broadcast_shape)

附:pytorch中取mean的操作

import torch

bs = 64
a = torch.randn(bs, 100, 32, 28)
# 将轴0,2,3的元素都放在一起取平均值
print(torch.mean(a, axis=(0, 2, 3)).shape)  # torch.Size([100])

附:CNN网络中的BatchNorm2d

四、BatchNorm的优缺点

BN的优点:

  • 解决内部协变量偏移,简单来说训练过程中,各层分布不同,增大了学习难度,BN缓解了这个问题。当然后来也有论文证明BN有作用和这个没关系,而是可以使损失平面更加的平滑,从而加快收敛速度。
  • 缓解了梯度饱和问题(如果使用sigmoid这种含有饱和区间的激活函数的话),加快收敛。

BN的缺点

  • Batch size比较小的时候,效果会比较差。因为他是用一个batch中的均值和方差来模拟全部数据的均值和方差。比如你一个batch只有2个样本,那你两个样本的均值和方差就不能很好地代表全班人的均值和方差,所以效果肯定就不好。
  • BN是计算机视觉CV的标配,但在自然语音处理NLP中效果一般较差,取而代之的是LN。关于LayerNorm的详解,可以参考我另一篇博客:深入理解NLP中LayerNorm的原理以及LN的代码详解

五、参考资料

[1] 李宏毅2021机器学习 第5节 Batch Normalization(学习笔记)
[2] 深度学习(二十九)Batch Normalization 学习笔记 (讲得挺全面的)
[3] BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm简介(不同Norm的对比)
[4] BatchNorm behaves different in train() and eval() #5406
[5] BatchNorm2d原理、作用及其pytorch中BatchNorm2d函数的参数讲解
[6] 神经网络之BN层
[7] 5 分钟理解 BatchNorm
[8] Pytorch的BatchNorm层使用中容易出现的问题
[9] 【深度学习】深入理解Batch Normalization批标准化
[10] BN踩坑记–谈一下Batch Normalization的优缺点和适用场景

深度学习--深入理解batchnormalization

一、简介  BatchNormalization作为最近一年来DL的重要研究成果,已经广泛被证明其有效性和重要性。虽然有些细节处还解释不清其理论原因,但是实践证明好用才是真的好,别忘了DL从Hinton对深层网络做Pre_Train开始就是一个经验... 查看详情

bn(batchnormalization)原理与使用过程详解

论文名字:BatchNormalization:AcceleratingDeepNetworkTrainingby ReducingInternalCovariateShift论文地址:https://arxiv.org/abs/1502.03167   BN被广泛应用于深度学习的各个地方,由于在实习过程中需要修改网络,修改的网络在训练过程中 查看详情

batchnormalization在cnn中的实现细节

...2.BN在CNN中的实现细节2.1训练过程2.2前向推断过程整天说BatchNorm,CNN的论文里离不开BatchNorm。BN可以使每层输入数据分布相对稳定,加速模型训练时的收敛速度。但BN操作在CNN中具体是如何实现的呢?1.BN在MLP中的实现步... 查看详情

batchnormalization在cnn中的实现细节

...2.BN在CNN中的实现细节2.1训练过程2.2前向推断过程整天说BatchNorm,CNN的论文里离不开BatchNorm。BN可以使每层输入数据分布相对稳定,加速模型训练时的收敛速度。但BN操作在CNN中具体是如何实现的呢?1.BN在MLP中的实现步... 查看详情

batchnormalization在cnn中的实现细节

...2.BN在CNN中的实现细节2.1训练过程2.2前向推断过程整天说BatchNorm,CNN的论文里离不开BatchNorm。BN可以使每层输入数据分布相对稳定,加速模型训练时的收敛速度。但BN操作在CNN中具体是如何实现的呢?1.BN在MLP中的实现步... 查看详情

pytorch基础(12)--torch.nn.batchnorm2d()方法(代码片段)

BatchNormanlization简称BN,也就是数据归一化,对深度学习模型性能的提升有很大的帮助。BN的原理可以查阅我之前的一篇博客。白话详细解读(七)-----BatchNormalization。但为了该篇博客的完整性,在这里简单介绍... 查看详情

深入理解requestanimationframe

前言本文主要参考w3c资料,从底层实现原理的角度介绍了requestAnimationFrame、cancelAnimationFrame,给出了相关的示例代码以及我对实现原理的理解和讨论。本文介绍浏览器中动画有两种实现形式:通过申明元素实现(如SVG中的元素)... 查看详情

深入理解多线程——moniter的实现原理(代码片段)

深入理解多线程(四)——Moniter的实现原理收录于话题#和并发编程有关的那点事儿13个点击上方“Hollis”关注我,精彩内容第一时间呈现。全文字数:1200阅读时间:3分钟本文是《深入理解多线程系列文章》的第四篇。点击查看... 查看详情

《深入理解mybatis原理1》mybatis的架构设计以及实例分析

 《深入理解mybatis原理》MyBatis的架构设计以及实例分析MyBatis是目前非常流行的ORM框架,它的功能很强大,然而其实现却比较简单、优雅。本文主要讲述MyBatis的架构设计思路,并且讨论MyBatis的几个核心部件,然后结合一个sele... 查看详情

深度神经网络中的batchnormalization介绍及实现(代码片段)

   之前在https://blog.csdn.net/fengbingchun/article/details/114493591中介绍DenseNet时,网络中会有BN层,即BatchNormalization,在每个DenseBlock中都会有BN参与运算,下面对BN进行介绍并给出C++和PyTorch实现。   Bat 查看详情

深入理解多线程——synchronized的实现原理(代码片段)

synchronized,是Java中用于解决并发情况下数据同步访问的一个很重要的关键字。当我们想要保证一个共享资源在同一时间只会被一个线程访问到时,我们可以在代码中使用synchronized关键字对类或者对象加锁。那么,本文来介绍一... 查看详情

深入理解aqs(代码片段)

文章目录深入理解AQSAQS概念特点AOS自定义实现锁ReentrantLock原理非公平锁实现原理加锁解锁原理竞争失败原理RenntrantLock可重入的原理可打断原理可打断模式公平锁实现原理非公平锁实现公平锁实现读写锁ReentrantReadWriteLock注意事项... 查看详情

bn(batchnormalization)(代码片段)

...?  1.是什么?  2.有什么用?  3.怎么用?paper:《BatchNormalization:AcceleratingDeepNetworkTrainingbyReducingInternalCovariateShift》  先来思考一个问题:我们知道在神经网络训练开始前,都要对输入数据做一个归一化处理,那么具体... 查看详情

深入理解python虚拟机:调试器实现原理与源码分析(代码片段)

...理,通过了解一个语言的调试器的实现原理我们可以更加深入的理解整个语言的运行机制,可以帮助我们更好的理解程序的执行。深入理解python虚拟机:调试器实现原理与源码分析调试器是一个编程语言非常重要的部分,调试器... 查看详情

深入理解composenavigation实现原理(代码片段)

前言一个纯Compose项目少不了页面导航的支持,而navigation-compose几乎是这方面的唯一选择,这也使得它成为Compose工程的标配二方库。介绍navigation-compose如何使用的文章很多了,比如这篇。其实在代码设计上Navigation也非... 查看详情

深入理解composenavigation实现原理(代码片段)

前言一个纯Compose项目少不了页面导航的支持,而navigation-compose几乎是这方面的唯一选择,这也使得它成为Compose工程的标配二方库。介绍navigation-compose如何使用的文章很多了,比如这篇。其实在代码设计上Navigation也非... 查看详情

华为云技术分享batchnormalization(bn)介绍(代码片段)

BatchNormalization(BN)解决的是InternalCovariateShift(ICS)的问题。1InternalCovariateShift在文中定义为2Thechangeinthedistributionofnetworkactivationsduetothechangeinnetworkparametersduringtraining.也就是在训练的过程中因为网络参数改变引起网络各层输出的... 查看详情

深入理解plasmaplasmacash

这一系列文章将围绕以太坊的二层扩容框架Plasma,介绍其基本运行原理,具体操作细节,安全性讨论以及未来研究方向等。本篇文章主要介绍在Plasma框架下的项目PlasmaCash。在上一篇文章中我们已经理解了Plasma的最小实现PlasmaMVP... 查看详情