睿智的目标检测61——tensorflow2focalloss详解与在yolov4当中的实现(代码片段)

Bubbliiiing Bubbliiiing     2023-03-15     238

关键词:

睿智的目标检测61——Tensorflow2 Focal loss详解与在YoloV4当中的实现

学习前言

TF2的也补上咯。其实和Keras的一摸一样0 0。

什么是Focal Loss

Focal Loss是一种Loss计算方案。其具有两个重要的特点。

1、控制正负样本的权重
2、控制容易分类和难分类样本的权重

正负样本的概念如下:
目标检测本质上是进行密集采样,在一张图像生成成千上万的先验框(或者特征点),将真实框与部分先验框匹配,匹配上的先验框就是正样本,没有匹配上的就是负样本

容易分类和难分类样本的概念如下:
假设存在一个二分类问题,样本1和样本2均为类别1。网络的预测结果中,样本1属于类别1的概率=0.9,样本2属于类别1的概率=0.6,前者预测的比较准确,是容易分类的样本;后者预测的不够准确,是难分类的样本。

如何实现权重控制呢,请往下看:

一、控制正负样本的权重

如下是常用的交叉熵loss,以二分类为例:

我们可以利用如下Pt简化交叉熵loss。

此时:

想要降低负样本的影响,可以在常规的损失函数前增加一个系数αt。与Pt类似:
当label=1的时候,αt=α;
当label=otherwise的时候,αt=1 - α。

a的范围是0到1。此时我们便可以通过设置α实现控制正负样本对loss的贡献
分解开就是:

二、控制容易分类和难分类样本的权重

样本属于某个类,且预测结果中该类的概率越大,其越容易分类 ,在二分类问题中,正样本的标签为1,负样本的标签为0,p代表样本为1类的概率。

对于正样本而言,1-p的值越大,样本越难分类。
对于负样本而言,p的值越大,样本越难分类。

Pt的定义如下:

所以利用1-Pt就可以计算出每个样本属于容易分类或者难分类。

具体实现方式如下。

其中:
( 1 − p t ) γ (1-p_t)^γ (1pt)γ
就是每个样本的容易区分程度, γ γ γ称为调制系数

1、当pt趋于0的时候,调制系数趋于1,对于总的loss的贡献很大。当pt趋于1的时候,调制系数趋于0,也就是对于总的loss的贡献很小。
2、当γ=0的时候,focal loss就是传统的交叉熵损失,可以通过调整γ实现调制系数的改变。

三、两种权重控制方法合并

通过如下公式就可以实现控制正负样本的权重控制容易分类和难分类样本的权重

实现方式

本文以Keras版本的YoloV4为例,给大家进行解析,YoloV4的坐标如下:
https://github.com/bubbliiiing/yolov4-tf2

首先定位YoloV4中,正负样本区分的损失部分,YoloV4的损失由三部分组成,分别为:
location_loss(回归损失)
confidence_loss(目标置信度损失)
class_loss(种类损失)
正负样本区分的损失部分是confidence_loss(目标置信度损失),因此我们在这一部分添加Focal Loss。

首先定位公式中的概率p。raw_pred代表每个特征点的预测结果,取出其中属于置信度的部分,取sigmoid,就是概率p

tf.sigmoid(raw_pred[...,4:5])

首先进行正负样本的平衡,设立参数alpha。

alpha 		# 正样本的平衡参数
1-alpha		# 负样本的平衡参数

然后进行难易分类样本的平衡,设立参数gamma。

(tf.ones_like(raw_pred[...,4:5]) - tf.sigmoid(raw_pred[...,4:5])) ** gamma 	# 正样本的平衡参数
tf.sigmoid(raw_pred[...,4:5]) ** gamma										# 负样本的平衡参数

乘上原来的交叉熵损失即可。

confidence_loss = object_mask * (tf.ones_like(raw_pred[...,4:5]) - tf.sigmoid(raw_pred[...,4:5])) ** gamma * alpha * K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits=True) + \\
            (1 - object_mask) * ignore_mask * tf.sigmoid(raw_pred[...,4:5]) ** gamma * (1 - alpha) * K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits=True)

睿智的目标检测60——pytorch搭建yolov7目标检测平台(代码片段)

睿智的目标检测60——Pytorch搭建YoloV7目标检测平台学习前言源码下载YoloV7改进的部分(不完全)YoloV7实现思路一、整体结构解析二、网络结构解析1、主干网络Backbone介绍2、构建FPN特征金字塔进行加强特征提取3、利用YoloHe... 查看详情

睿智的目标检测51——tensorflow2搭建yolo3目标检测平台(代码片段)

睿智的目标检测51——Tensorflow2搭建yolo3目标检测平台学习前言源码下载YoloV3实现思路一、整体结构解析二、网络结构解析1、主干网络Darknet53介绍2、构建FPN特征金字塔进行加强特征提取3、利用YoloHead获得预测结果三、预测结果的... 查看详情

睿智的目标检测61——tensorflow2focalloss详解与在yolov4当中的实现(代码片段)

睿智的目标检测61——Tensorflow2Focalloss详解与在YoloV4当中的实现学习前言什么是FocalLoss一、控制正负样本的权重二、控制容易分类和难分类样本的权重三、两种权重控制方法合并实现方式学习前言TF2的也补上咯。其实和Keras的一摸... 查看详情

睿智的目标检测65——pytorch搭建detr目标检测平台(代码片段)

睿智的目标检测65——Pytorch搭建DETR目标检测平台学习前言源码下载DETR实现思路一、整体结构解析二、网络结构解析1、主干网络Backbone介绍a、什么是残差网络b、什么是ResNet50模型c、位置编码2、编码网络Encoder网络介绍a、Transformer... 查看详情

睿智的目标检测——pyqt5搭建目标检测界面(代码片段)

睿智的目标检测——PyQt5搭建目标检测界面学习前言基于B导开源的YoloV4-Pytorch源码开发了戴口罩人脸检测系统(21年完成的本科毕设,较为老旧,可自行替换为最新的目标检测算法)。源码下载https://github.com/Egrt/YOL... 查看详情

睿智的目标检测——pytorch搭建yolov7-obb旋转目标检测平台(代码片段)

睿智的目标检测——Pytorch搭建[YoloV7-OBB]旋转目标检测平台学习前言源码下载YoloV7-OBB改进的部分(不完全)YoloV7-OBB实现思路一、整体结构解析二、网络结构解析1、主干网络Backbone介绍2、构建FPN特征金字塔进行加强特征提... 查看详情

睿智的目标检测56——pytorch搭建yolov5目标检测平台(代码片段)

睿智的目标检测56——Pytorch搭建YoloV5目标检测平台学习前言源码下载YoloV5改进的部分(不完全)YoloV5实现思路一、整体结构解析二、网络结构解析1、主干网络Backbone介绍2、构建FPN特征金字塔进行加强特征提取3、利用YoloHe... 查看详情

睿智的目标检测55——keras搭建yolov5目标检测平台(代码片段)

睿智的目标检测55——Keras搭建YoloV5目标检测平台学习前言源码下载YoloV5改进的部分(不完全)YoloV5实现思路一、整体结构解析二、网络结构解析1、主干网络Backbone介绍2、构建FPN特征金字塔进行加强特征提取3、利用YoloHead... 查看详情

保研笔记八——yolov5项目复习(代码片段)

学习转载自:睿智的目标检测56——Pytorch搭建YoloV5目标检测平台_Bubbliiiing的博客-CSDN博客_睿智yolo Pytorch搭建自己的YoloV5目标检测平台(Bubbliiiing源码详解训练预测)-主干网络介绍_哔哩哔哩_bilibili还有一些视频的学习... 查看详情

目标检测61dynamicheadunifyingobjectdetectionheadswithattentions(代码片段)

...tps://github.com/microsoft/DynamicHead出处:CVPR2021一、背景将目标检测中的定位和分类头统一起来一直是一个引人关注的问题,之前的工作尝试了很多方法,但是一直没有得来一个一致的形式。所以,如何提升目标检测头... 查看详情

mmrotate旋转目标检测实战指南(代码片段)

一、环境安装1、创建虚拟环境condacreate-nopenmmlabpython=3.7-ycondaactivateopenmmlab2、安装pytorch(默认cuda10.1)condainstallpytorch==1.7.0torchvision==0.8.0cudatoolkit=10.1-cpytorch 查看详情

mmrotate旋转目标检测实战指南(代码片段)

一、环境安装1、创建虚拟环境condacreate-nopenmmlabpython=3.7-ycondaactivateopenmmlab2、安装pytorch(默认cuda10.1)condainstallpytorch==1.7.0torchvision==0.8.0cudatoolkit=10.1-cpytorch 查看详情

目标检测中的评价指标map以及coco评价标准(代码片段)

pycocotools安装Linux:pipinstallpycocotools;Windows:pipinstallpycocotools-windows训练过程中输出值AveragePrecision(AP)@[IoU=0.50:0.95|area=all|maxDets=100]=0.534AveragePrecision(AP)@[IoU&# 查看详情

本项目基于paddlex实现目标检测(代码片段)

一、项目背景本项目基于paddlex实现目标检测。效果如图下载版本小于2.0.0的paddlex!pipinstall"paddlex<2.0.0"-ihttps://mirror.baidu.com/pypi/simple定义参数#configval_radio=0.2train_list='data/train_list.txt'val_list='data/val_list.txt'... 查看详情

目标检测小脚本:根据xml文件统计类别数(代码片段)

问题场景搜到一个目标检测数据集,但是别人没有详细说明到底有多少类别以及各类别名称,这时候就需要查询xml的标注文件来获取这两个信息。脚本代码importxml.etree.ElementTreeasETimportnumpyasnpimportosif__name__=='__main_... 查看详情

基于mmrotate训练自定义数据集做旋转目标检测2022-3-30(代码片段)

...的成员之一。里面包含了rcnn、fasterrcnn、r3det等各种旋转目标的检测模型,适合于遥感图像领域的目标检测。1.MMrotate下载MMrotate包下载:下载链接目录结构如下:2.环境安装所需要的依赖环境:Linux&WindowsPython3.7&#... 查看详情

目标检测中各种iou说明(代码片段)

一、说明IoU损失是目标检测中最常见的损失函数,表示的就是真实框和预测框的交并比。二、代码实现importmathdefbbox_iou(box1,box2,xywh=True,GIoU=False,DIoU=False,CIoU=False,eps=1e-7):#ReturnsIntersectionoverUnionIoU(n,1)ofbox1(n,4)tobox2(... 查看详情

解读通过风格迁移的浓雾天气条件下无人机图像目标检测方法

...件比较差,解决办法就是通过风格迁移的无人机图像目标检测方法。用这种方法检测目标改变变换为了使目标检测模型从学习目标物纹理和表面信息转变为学习目标物轮廓这一更高难度的任务。实验数据集为Visdrone2019(... 查看详情