5.使用pytorch预先训练的模型执行目标检测(代码片段)

程序媛一枚~ 程序媛一枚~     2022-10-21     689

关键词:

5. 使用PyTorch预先训练的网络执行目标检测

这篇博客将介绍如何使用PyTorch预训练的网络执行目标检测,这些网络是开创性的、最先进的图像分类网络,包括使用ResNet的更快R-CNN、使用MobileNet的更快R-CNN和RetinaNet。

  • 具有ResNet50主干的更快R-CNN(Faster R-CNN with a ResNet50 backbone 更精确,但速度较慢)
  • 具有MobileNet主干的更快R-CNN(Faster R-CNN with a MobileNet v3 backbone 速度更快,但准确度稍低)
  • 具有ResNet50主干的RetinaNet(RetinaNet with a ResNet50 backbone 速度和准确性之间的良好平衡)
    在准确度和检测小物体方面,Faster R-CNN的表现都非常好。然而,这种准确性是有代价的——Faster R-CNN 模型往往比 Single Shot Detectors (SSD) 和 YOLO慢得多。
    为了帮助加快Faster R-CNN架构,可以将计算成本高昂的ResNet主干换成更轻、更高效(但不太准确)的 MobileNet主干。这样做会提高你的速度。
    否则,RetinaNet 是速度和准确性之间的一个很好的折衷方案。

1. 效果图

第一次运行会自动下载模型

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to C:\\Users\\xx/.cache\\torch\\hub\\checkpoints\\fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
Downloading: "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth" to C:\\Users\\xx/.cache\\torch\\hub\\checkpoints\\fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth
Downloading: "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth" to C:\\Users\\xx/.cache\\torch\\hub\\checkpoints\\retinanet_resnet50_fpn_coco-eeacb38b.pth

frcnn-resnet 效果图如下
使用的对象检测器是一个速度更快的R-CNN,带有ResNet50主干。由于网络的设计方式,速度更快的R-CNN往往非常擅长检测图像中的小物体——这一点可以从以下事实中得到证明:不仅在输入图像中检测到了所有的风筝,而且还检测到了其中的人,椅子(人眼几乎看不到),它真实地展示了R-CNN模型在检测小对象方面有多快。

更快的R-CNN和PyTorch可以一起用于检测复杂场景中的小对象。降低默认置信度将允许检测更多对象,但可能会误报。

frcnn-mobilenet 效果图如下"

retinanet 效果图如下"
可以看到蛋糕、酒杯、桌子、刀、胡萝卜、杯子都被成功检测到。

调低置信度会有更多的对象被检测出来,但也可能误报。

实时检测效果图如下
使用带有MobileNet v3的Faster R-CNN 模型(速度最佳),实现了≈5 FPS/秒。还没有达到大于20FPS的真正实时速度,但是有了更快的GPU和更多的优化可以轻松达到目标。

2. 原理

2.1 什么是经过预训练的对象检测网络,包括PyTorch库中构建的对象检测网络

就像ImageNet挑战往往是图像分类的事实标准一样,COCO数据集(上下文中的常见对象)往往是对象检测基准的标准。
该数据集包含90多类日常世界中常见的对象。计算机视觉和深度学习研究人员在COCO数据集上开发、训练和评估最先进的目标检测网络。
大多数研究人员还将预先训练好的权重发布到模型中,以便计算机视觉从业者可以轻松地将对象检测纳入自己的项目中。
本教程将演示如何使用PyTorch使用以下最先进的分类网络执行对象检测:

  • 具有ResNet50主干的更快R-CNN
  • 具有MobileNet主干的更快R-CNN
  • 具有ResNet50主干的RetinaNet

2.2 环境部署

pip install torch torchvision
pip install opencv-contrib-python

下载coco数据集可以通过fiftyone或者github

pip install fiftyone
pip install tensorflow torch torchvision umap-learn # 使用keras及torch
pip install ipywidgets>=7.5 # jupter notebook交互图

3. 源码

3.1 照片目标检测

# USAGE
# python detect_image.py --model frcnn-resnet --image images/man.jpg --labels coco_classes_91.pickle
# python detect_image.py --model frcnn-resnet --image images/fruit.jpg --labels coco_classes_91.pickle --confidence 0.7

# coco_classes.pickle 包含PyTorch预训练对象检测网络所训练的类标签的名称。
# detect_image.py:在静态图像中使用PyTorch执行对象检测
# detect_realtime.py:将PyTorch对象检测应用于实时视频流
# image/: 示例测试图片

# 导入必要的包
import argparse
import pickle

import cv2
import numpy as np
import torch
from torchvision.models import detection  # torchvision.models包含目标检测的预训练模型

# 解析命令行参数
# --image 要执行目标检测的图像路径
# --model 要使用的PyTorch目标检测模型名称(Faster R-CNN + ResNet, Faster R-CNN + MobileNet, or RetinaNet + ResNet)
# --labels: COCO标签文件路径,包含可读性强的类标签containing human readable class labels
# --confidence: 过滤弱检测的置信度阈值
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", type=str, required=False, default='images/banner_eccv18.jpg',
                help="path to the input image")
ap.add_argument("-m", "--model", type=str, default="frcnn-resnet",
                choices=["frcnn-resnet", "frcnn-mobilenet", "retinanet"],
                help="name of the object detection model")
ap.add_argument("-l", "--labels", type=str, default="coco_classes_91.pickle",
                help="path to file containing list of categories in COCO dataset")
ap.add_argument("-c", "--confidence", type=float, default=0.5,
                help="minimum probability to filter weak detections")
args = vars(ap.parse_args())

# 设置使用cpu/gpu
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载COCO数据集标签,生成对应的边界框颜色列表(为每个标签生成随机颜色)
CLASSES = pickle.loads(open(args["labels"], "rb").read())
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

# 初始化一个字典包括模型名及对应的PyTorch模型调用函数
# - 带有ResNet50主干网的快速R-CNN(Faster R-CNN with a ResNet50 backbone 更精确,但速度较慢)
# - 带有MobileNet v3主干网的快速R-CNN(Faster R-CNN with a MobileNet v3 backbone 速度更快,但准确度稍低)
# - 带有ResNet50主干网的RetinaNet(RetinaNet with a ResNet50 backbone 速度和准确性之间的良好平衡)
MODELS = 
    "frcnn-resnet": detection.fasterrcnn_resnet50_fpn,
    "frcnn-mobilenet": detection.fasterrcnn_mobilenet_v3_large_320_fpn,
    "retinanet": detection.retinanet_resnet50_fpn


# 加载模型,设置为评估模式
# pretrained=True:告诉PyTorch在COCO数据集上使用预先训练的权重加载模型架构
# progress=True:如果模型尚未下载和缓存,则显示下载进度条
# num_classes:唯一类的总数
# pretrained_backbone:为目标探测器提供主干网
# model = MODELS[args["model"]](pretrained=True, progress=True,
#                               num_classes=len(CLASSES), pretrained_backbone=True).to(DEVICE)
model = MODELS[args["model"]](pretrained=True, progress=True,
                              num_classes=91, pretrained_backbone=True).to(DEVICE)
model.eval()

# 从磁盘加载图像
image = cv2.imread(args["image"])
orig = image.copy()

# 将颜色通道顺序从BGR转换为RGB(因为PyTorch模型是在RGB顺序图像上训练的)
# 将颜色通道顺序从“通道最后”(OpenCV和Keras/TensorFlow默认值)切换到“通道第一”(PyTorch默认值)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image.transpose((2, 0, 1))

# 添加维度,缩放像素值为[0,1]范围
# 将图像从NumPy数组转换为具有浮点数据类型的张量
image = np.expand_dims(image, axis=0)
image = image / 255.0
image = torch.FloatTensor(image)

# 传递图像到设备,并进行预测
image = image.to(DEVICE)
detections = model(image)[0]

# 遍历检测结果
for i in range(0, len(detections["boxes"])):
    # 获取与检测相关的置信度(即概率)
    confidence = detections["scores"][i]

    # 过滤弱检测
    if confidence > args["confidence"]:
        # 提取类标签的下标,计算对象的边界框坐标
        idx = int(detections["labels"][i])
        box = detections["boxes"][i].detach().cpu().numpy()
        # 获取边界框坐标并将其转换为整数
        (startX, startY, endX, endY) = box.astype("int")

        # 展示类标签到终端
        label = " : :.2f%".format(str(idx), CLASSES[idx], confidence * 100)
        print("[INFO] ".format(label))
        label = ": :.2f%".format(CLASSES[idx], confidence * 100)

        # 绘制边界框和label在图像上
        cv2.rectangle(orig, (startX, startY), (endX, endY),
                      COLORS[idx], 2)
        y = startY - 15 if startY - 15 > 15 else startY + 15
        cv2.putText(orig, label, (startX, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2)

# 展示输出图像
cv2.imshow("Output " + str(args["model"]), orig)
cv2.waitKey(0)

3.2 实时视频流(文件/摄像头)目标检测

# coco_classes.pickle 包含PyTorch预训练对象检测网络所训练的类标签的名称。
# detect_frame.py:在静态图像中使用PyTorch执行对象检测
# detect_realtime.py:将PyTorch对象检测应用于实时视频流
# frame/: 示例测试图片

# USAGE
# python detect_realtime.py --model frcnn-mobilenet --labels coco_classes_91.pickle
# python detect_realtime.py --model frcnn-mobilenet --input images/jurassic_park_trailer.mp4 --labels coco_classes_91.pickle --confidence 0.6

# 导入必要的包
import argparse
import pickle
import time

import cv2
import imutils
import numpy as np
import torch
from imutils.video import FPS  # FPS:测量对象检测管道的近似每秒帧数吞吐率
from imutils.video import VideoStream  # 访问摄像头流
from torchvision.models import detection

# 构建命令行参数及解析
# --model 要使用的PyTorch目标检测模型名称(Faster R-CNN + ResNet, Faster R-CNN + MobileNet, or RetinaNet + ResNet)
# --labels: COCO标签文件路径,包含可读性强的类标签containing human readable class labels
# -i 可选的输入视频文件路径,不输入则使用网络摄像头
# -o 可选的输出视频文件路径
# --confidence: 置信度阈值,过滤弱的假阳性检测
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", type=str, default="frcnn-resnet",
                choices=["frcnn-resnet", "frcnn-mobilenet", "retinanet"],
                help="name of the object detection model")
ap.add_argument("-l", "--labels", type=str, default="coco_classes_90.pickle",
                help="path to file containing list of categories in COCO dataset")
ap.add_argument("-i", "--input", type=str,
                help="path to optional input video file")
ap.add_argument("-o", "--output", type=str,
                help="path to optional output video file")
ap.add_argument("-c", "--confidence", type=float, default=0.5,
                help="minimum probability to filter weak detections")
args = vars(ap.parse_args())

# 设置使用cpu/gpu
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载COCO数据集标签,生成对应的边界框颜色列表(为每个标签生成随机颜色)
CLASSES = pickle.loads(open(args["labels"], "rb").read())
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

# 初始化一个字典包括模型名及对应的PyTorch模型调用函数
# - 带有ResNet50主干网的快速R-CNN(Faster R-CNN with a ResNet50 backbone 更精确,但速度较慢)
# - 带有MobileNet v3主干网的快速R-CNN(Faster R-CNN with a MobileNet v3 backbone 速度更快,但准确度稍低)
# - 带有ResNet50主干网的RetinaNet(RetinaNet with a ResNet50 backbone 速度和准确性之间的良好平衡)
MODELS = 
    "frcnn-resnet": detection.fasterrcnn_resnet50_fpn,
    "frcnn-mobilenet": detection.fasterrcnn_mobilenet_v3_large_320_fpn,
    "retinanet": detection.retinanet_resnet50_fpn


# 加载模型,设置为评估模式
# pretrained=True:告诉PyTorch在COCO数据集上使用预先训练的权重加载模型架构
# progress=True:如果模型尚未下载和缓存,则显示下载进度条
# num_classes:唯一类的总数
# pretrained_backbone:为目标探测器提供主干网
model = MODELS[args["model"]](pretrained=True, progress=True,
                              num_classes=len(CLASSES), pretrained_backbone=True).to(DEVICE)
model.eval()

# 如果没有输入的视频文件路径提供,则获取网络摄像头的指针
# 初始化视频流,允许摄像头预热2s,初始化fps吞吐量
if not args.get("input", False):
    print("[INFO] starting video stream...")
    vs = VideoStream(src=0).start()
    time.sleep(2.0)
# 否则,获取视频文件指针
else:
    print("[INFO] opening video file...")
    vs = cv2.VideoCapture(args["input"])
fps = FPS().start()

# 初始化视频文件writer
writer = None

# 初始化帧的宽度和高度
W = None
H = None

# 遍历视频流里的帧
while True:
    # 从线程化的视频流获取帧,缩放为宽度400px
    # 从视频流中读取一帧,调整其大小(输入帧越小,推断速度越快),然后克隆它,以便以后可以对其进行绘制
    # 获取下一帧,并判断是从摄像头或者文件捕获到的帧
    frame = vs.read()
    frame = frame[1] if args.get("input", False) else frame

    # 如果在文件流未获取到视频帧,则表明到了文件末尾,终止循环
    if args["input"] is not None and frame is None:
        break
    frame = imutils.resize(frame, width=400)
    orig = frame.copy()

    # 将颜色通道顺序从BGR转换为RGB(因为PyTorch模型是在RGB顺序图像上训练的)
    # 将颜色通道顺序从“通道最后”(OpenCV和Keras/TensorFlow默认值)切换到“通道第一”(PyTorch默认值)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = frame.transpose((2, 0, 1))

    # 添加维度,缩放像素值为[0,1]范围
    # 将图像从NumPy数组转换为具有浮点数据类型的张量
    frame = np.expand_dims(frame, axis=0)
    frame = frame / 255.0
    frame = torch.FloatTensor(frame)

    # 传递图像到设备,并进行预测
    frame = frame.to(DEVICE)
    detections = model(frame)[0]

    # 遍历检测结果
    for i in range(0, len(detections["boxes"])):
        # 获取与检测相关的置信度(即概率)
        confidence = detections["scores"][i]

        # 过滤弱检测
        if confidence > args["confidence"]:
            # 提取类标签的下标,计算对象的边界框坐标
            idx = int(detections["labels"][i])
            box = detections["boxes"][i].detach().cpu().numpy()
            # 获取边界框坐标并将其转换为整数
            (startX, startY, endX, endY) = box.astype("int")

            # 展示类标签到终端
            label = ": :.2f%".format(CLASSES[idx], confidence * 100)
            print("[INFO] ".format(label))

            # 绘制边界框和label在图像上
            cv2.rectangle(orig, (startX, startY), (endX, endY),
                          COLORS[idx], 2)
            y = startY - 15 if startY - 15 > 15 else startY + 15
            cv2.putText(orig, label, (startX, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2)

    # 如果帧的宽度和高度为None,则定义WH
    if W is None or H is None:
        (H, W) = orig.shape[:2]

    # 如果需要写入结果视频流到磁盘,则初始化writer
    if args["output"] is not None and writer is None:
        fourcc = cv2.VideoWriter_fourcc(*"MJPG")
        writer = cv2.VideoWriter(args["output"], fourcc, 30,
                                 (W, H), True)

        # 检查是否绘制结果到文件
    if writer is not None:
        writer.write(orig)

    # 展示输出帧
    cv2.imshow("Frame", orig)
    key = cv2.waitKey(1) & 0xFF

    # 按下‘q’键,退出循环
    if key == ord("q"):
        break

    # 更新fps计数器
    fps.update()

# 停止FPS计时器并显示(1)脚本运行时间和(2)大约每秒帧数吞吐量信息。
fps.stop()
print("[INFO] elapsed time: :.2f".format(fps.elapsed()))
print("[INFO] approx. FPS: :.2f".format(fps.fps()))

# 检查是否需要释放视频writer指针
if writer is not None:
    writer.release()

# 如果不使用视频文件,停止线程化的视频流对象
if not args.get("input", False):
    vs.stop()
# 否则释放视频流指针
else:
    vs.release()

# 关闭所有打开的窗口
cv2.destroyAllWindows()

参考

4.使用预训练的pytorch网络进行图像分类(代码片段)

4.使用预训练的PyTorch网络进行图像分类这篇博客将介绍如何使用PyTorch预先训练的网络执行图像分类。利用这些网络只需几行代码就可以准确地对1000个常见对象类别进行分类。这些图像分类网络是开创性的、最先进的图像分类网... 查看详情

4.使用预训练的pytorch网络进行图像分类(代码片段)

4.使用预训练的PyTorch网络进行图像分类这篇博客将介绍如何使用PyTorch预先训练的网络执行图像分类。利用这些网络只需几行代码就可以准确地对1000个常见对象类别进行分类。这些图像分类网络是开创性的、最先进的图像分类网... 查看详情

基于pytorch预训练模型使用fasterrcnn调用摄像头进行目标检测无敌详细!简单!超少代码!(代码片段)

基于pytorch预训练模型使用FasterRCNN调用摄像头进行目标检测【无敌详细!简单!超少代码!】详细完整项目链接:https://download.csdn.net/download/weixin_46570668/86954697?spm=1001.2014.3001.5503使用Pytorch自带的预训练模型fasterr 查看详情

基于pytorch预训练模型使用fasterrcnn调用摄像头进行目标检测无敌详细!简单!超少代码!(代码片段)

基于pytorch预训练模型使用FasterRCNN调用摄像头进行目标检测【无敌详细!简单!超少代码!】详细完整项目链接:https://download.csdn.net/download/weixin_46570668/86954697?spm=1001.2014.3001.5503使用Pytorch自带的预训练模型fasterr... 查看详情

《南溪的目标检测学习笔记》——训练pytorch模型遇到显存不足的情况怎么办

1前言在目标检测中,可能会遇到显存不足的情况,我们在这里记录一下解决方案;2如何判断真正是出现显存(不是“软件误报”)当前需要分配的显存在200MiB以下,例如:RuntimeError:CUDAoutofmemory.Triedtoal... 查看详情

21使用预训练的目标检测与语义分割网络(代码片段)

今天简单测试一下pytorch提供的模型文章目录1.使用训练好的目标检测网络1.1完整代码2.使用训练好的语义分割网络2.1完整代码1.使用训练好的目标检测网络importnumpyasnpimporttorchvisionimporttorchimporttorchvision.transformsastransformsfromPILimportIm... 查看详情

睿智的目标检测56——pytorch搭建yolov5目标检测平台(代码片段)

睿智的目标检测56——Pytorch搭建YoloV5目标检测平台学习前言源码下载YoloV5改进的部分(不完全)YoloV5实现思路一、整体结构解析二、网络结构解析1、主干网络Backbone介绍2、构建FPN特征金字塔进行加强特征提取3、利用YoloHe... 查看详情

如何对经过训练的目标检测模型进行剪枝?

...发布时间】:2021-12-0901:01:47【问题描述】:您好,我已经使用tensorflow1.14对象检测API训练了对象检测模型,我的模型表现良好。但是,我想减少/优化模型的参数以使其更轻。如何在经过训练的模型上使用剪枝?【问题讨论】:你... 查看详情

睿智的目标检测——pytorch搭建yolov7-obb旋转目标检测平台(代码片段)

睿智的目标检测——Pytorch搭建[YoloV7-OBB]旋转目标检测平台学习前言源码下载YoloV7-OBB改进的部分(不完全)YoloV7-OBB实现思路一、整体结构解析二、网络结构解析1、主干网络Backbone介绍2、构建FPN特征金字塔进行加强特征提... 查看详情

openmmlab实战营打卡-第5课

...用训练策略训练自己的检测模型通常基于微调训练:使用基于COCO预训练的检测模型作为梯度下降的“起点”使用自己的数 查看详情

睿智的目标检测65——pytorch搭建detr目标检测平台(代码片段)

睿智的目标检测65——Pytorch搭建DETR目标检测平台学习前言源码下载DETR实现思路一、整体结构解析二、网络结构解析1、主干网络Backbone介绍a、什么是残差网络b、什么是ResNet50模型c、位置编码2、编码网络Encoder网络介绍a、Transformer... 查看详情

深度学习和目标检测系列教程10-300:通过torch训练第一个faster-rcnn模型(代码片段)

...图像数据集上使用Faster-RCNN模型。代码的灵感来自此处的Pytorch文档教程和Kagglehttps://pytorch.org/tutorials/intermediate/torchvision_tutorial 查看详情

tensorflow利用预训练模型进行目标检测:预训练模型的使用(代码片段)

..._tutorial.ipynb 但是一直有问题,没有运行起来,所以先使用一个别人写好的代码上一个在ubuntu下可用的代码链接:https://gitee.com/bubbleit/Ji 查看详情

戴眼镜检测和识别2:pytorch实现戴眼镜检测和识别(含戴眼镜数据集和训练代码)(代码片段)

Pytorch实现戴眼镜检测和识别(含戴眼镜数据集和训练代码)目录Pytorch实现戴眼镜检测和识别(含戴眼镜数据集和训练代码)1.戴眼镜检测和识别方法2.戴眼镜数据集3.人脸检测模型4.戴眼镜分类模型训练(1)项目安装(2ÿ... 查看详情

pytorch从头搭建并训练一个神经网络模型(图像分类cnn)(代码片段)

...多稍微修改下源码的接口满足自己的需求。还从来没有用PyTorch从头搭建并训练一个模型出来。正好最近在较为系统地学PyTorch,就总结一下如何从头搭建并训练一个神经网络模型。1.使用torchvision加载数据集并做预处理我们使... 查看详情

六pytorch进阶训练技巧(代码片段)

六、PyTorch进阶训练技巧文章目录六、PyTorch进阶训练技巧1.自定义损失函数1.1.函数定义1.2.类定义1.2.1.DiceLoss1.2.2.DiceBCELoss1.2.3.IoULoss1.2.4.FocalLoss2.动态调整学习率2.1.使用官方提供的scheduler2.2.自定义scheduler3.模型微调-torchvision3.1使用... 查看详情

手把手教你使用yolov5训练自己的目标检测模型-口罩检测-视频教程(代码片段)

手把手教你使用YOLOV5训练自己的目标检测模型大家好,这里是肆十二(dejahu),好几个月没有更新了,这两天看了一下关注量,突然多了1k多个朋友关注,想必都是大作业系列教程来的小伙伴。既然有... 查看详情

动手学cv-目标检测入门教程6:训练与测试(代码片段)

...小组创作的目标检测入门教程。对应开源项目《动手学CV-Pytorch》的第3章的内容,教程中涉及的代码也可以在项目中找到,后续会持续更新更多的优质内容,欢迎⭐️。如果使用我们教程的内容或图片,请在文章... 查看详情