pytorch学习(代码片段)

采蘑菇的csz 采蘑菇的csz     2023-04-07     201

关键词:

Pytorch1_手写字识别

下面将构造一个三层全连接层的神经网络来对手写数字进行识别。

数据集的载入

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt

batch_size=512
train_loader = torch.utils.data.DataLoader(
      torchvision.datasets.MNIST('mnist_data/',train=True, download=True,
                                transform=torchvision.transforms.Compose([
                                torchvision.transforms.ToTensor()])),
      batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor()])),
    batch_size=batch_size, shuffle=False)

数据集展示

def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0], cmap='gray', interpolation='none')
        plt.title(": ".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')

输出图片如下:

网络搭建

#定义网络
class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        
        #xw+b
        self.fc1 = nn.Linear(28*28,256)
        self.fc2 = nn.Linear(256,64)
        self.fc3 = nn.Linear(64,10)
        
    def forward(self,x):
        #x:[b,1,28,28]
        #h1 = relu(xw1+b1)
        x = F.relu(self.fc1(x))
        #h2 = relu(h1w2+b2)
        x = F.relu(self.fc2(x))
        #h3 = h2w3+b3
        x = self.fc3(x)
        
        return x

模型训练

net = Net()
# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

train_loss = []
for epoch in range(4):
    for batch_idx,(x,y) in enumerate(train_loader):
        
        #x:[b,1,28,28],y:[512]
        #[b,1,28,28]==>[b,784]
        x = x.view(x.size(0),28*28)
        #=>[b,10]
        out = net(x)
        loss = criterion(out,y)
        
        optimizer.zero_grad()
        loss.backward()
        
        #w'=w - lr*grad
        optimizer.step()
        
        train_loss.append(loss.item())
        
        if batch_idx %10 == 0:
            print(epoch,batch_idx,loss.item())

torch.nn.CrossEntropyLoss()使用注意
CrossEntropyLoss(将 nn.LogSoftmax() 和 nn.NLLLoss() 结合在一个类中)一般用于计算分类问题的损失值,可以计算出不同分布之间的差距。

训练损失函数图

def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()

plot_curve(train_loss)

模型预测

#test
total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0),28*28)
    out = net(x)
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
    
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("Teas ACC:",acc)

pytorch学习笔记:pytorch进阶训练技巧(代码片段)

PyTorch实战:PyTorch进阶训练技巧往期学习资料推荐:1.Pytorch实战笔记_GoAI的博客-CSDN博客2.Pytorch入门教程_GoAI的博客-CSDN博客本系列目录:PyTorch学习笔记(一):PyTorch环境安装PyTorch学习笔记(二)... 查看详情

深度学习——pytorch基础(代码片段)

深度学习(1)——Pytorch基础作者:夏风喃喃参考:《动手学深度学习第二版》李沐文章目录深度学习(1)——Pytorch基础一.数据操作二.数据预处理三.绘图四.自动求导五.概率论导入相关包:importtorch ... 查看详情

深度学习——pytorch基础(代码片段)

深度学习(1)——Pytorch基础作者:夏风喃喃参考:《动手学深度学习第二版》李沐文章目录深度学习(1)——Pytorch基础一.数据操作二.数据预处理三.绘图四.自动求导五.概率论导入相关包:importtorch ... 查看详情

pytorch学习笔记:模型定义修改保存(代码片段)

往期学习资料推荐:1.Pytorch实战笔记_GoAI的博客-CSDN博客2.Pytorch入门教程_GoAI的博客-CSDN博客本系列目录:PyTorch学习笔记(一):PyTorch环境安装PyTorch学习笔记(二):简介与基础知识PyTorch学习笔记&#... 查看详情

python学习pytorch第一阶段。(代码片段)

查看详情

pytorch学习笔记3.深度学习基础(代码片段)

...多分类22.全连接层23.激活函数与GPU加速24.测试根据龙良曲Pytorch学习视频整理,视频链接:【计算机-AI】PyTorch学这个就够了!(好课推荐)深度学习与PyTorch入门实战——主讲人龙 查看详情

pytorch迁移学习(transferlearning)代码详解(代码片段)

PyTorch迁移学习代码详解概述为什么使用迁移学习更好的结果节省时间加载模型ResNet152冻层实现模型初始化获取需更新参数训练模型获取数据完整代码概述迁移学习(TransferLearning)是把已学训练好的模型参数用作新训练模型的起始... 查看详情

深度学习pytorch大坑处理(代码片段)

深度学习pytorch大坑处理importtoralprint(torch.cuda.is_available())//false解决方法.https://download.pytorch.org/whl/torch_stable.html用迅雷下载进入AnacondaPrompt,输入condaactivatepytorch下载好torchvision-0.10.1+cu1 查看详情

anaconda配置各版本pytorch深度学习环境(代码片段)

目录1.前言2.配置镜像源3.pytorch,torchvision,python版本对应4.创建并进入虚拟环境5.Pytorch0.4.16.Pytorch1.0.07.Pytorch1.0.18.Pytorch1.1.09.Pytorch1.2.010.Pytorch1.4.011.Pytorch1.5.012.Pytorch1.5.113.Pytorch1.6.0 查看详情

pytorch学习:stgcn(代码片段)

1main.ipynb1.1导入库importrandomimporttorchimportnumpyasnpimportpandasaspdfromsklearn.preprocessingimportStandardScalerfromload_dataimport*fromutilsimport*fromstgcnimport*1.2随机种子torch.manual_seed(2021)to 查看详情

pytorch学习笔记:pytorch生态简介(代码片段)

PyTorch生态简介往期学习资料推荐:1.Pytorch实战笔记_GoAI的博客-CSDN博客2.Pytorch入门教程_GoAI的博客-CSDN博客本系列目录:PyTorch学习笔记(一):PyTorch环境安装PyTorch学习笔记(二):简介与基础知识Py... 查看详情

《深度学习笔记》——pytorch调整学习率(代码片段)

1定义调整学习率函数defadjust_learning_rate(optimizer,epoch,lr):"""SetsthelearningratetotheinitialLRdecayedby10every2epochs"""#optimizer表示优化器对象lr*=(0.1**(epoch//2))forparam_groupinop 查看详情

pytorch入门与实战----pytorch入门(代码片段)

1.深度学习框架 pytorch与其他框架的比较pytorch的学习方法:课程安排:PyTorch是一个基于Python的科学计算库,它有以下特点:类似于NumPy,但是它可以使用GPU可以用它定义深度学习模型,可以灵活地进行深度学习模型的训练和使... 查看详情

pytorch学习系列——环境搭建(代码片段)

文章目录1.安装英伟达驱动2.安装Anaconda环境3.安装pytorch、CUDA和cuDNN环境(1)配置国内镜像加速(2)安装4.查看安装环境的版本4.1查看Anaconda版本4.2查看Nvidia驱动版本4.3查看pytorch版本4.4查看CUDA版本4.5查看cuDNN版本由... 查看详情

深度学习pytorch——数据加载和处理(代码片段)

深度学习Pytorch(五)——数据加载和处理文章目录深度学习Pytorch(五)——数据加载和处理一、下载安装包二、下载数据集三、读取数据集四、编写一个函数看看图像和landmark五、数据集类六、数据可视化七、数... 查看详情

《动手学深度学习》(pytorch版)(代码片段)

《动手学深度学习》PyTorch版前言简介面向人群食用方法方法一方法二方法三目录原书地址引用阅读指南前言读书啦!!!本项目将《动手学深度学习》原书中MXNet代码实现改为PyTorch实现。原书作者:阿斯顿·张、... 查看详情

[九]深度学习pytorch-transforms图像增强(剪裁翻转旋转)(代码片段)

0.往期内容[一]深度学习Pytorch-张量定义与张量创建[二]深度学习Pytorch-张量的操作:拼接、切分、索引和变换[三]深度学习Pytorch-张量数学运算[四]深度学习Pytorch-线性回归[五]深度学习Pytorch-计算图与动态图机制[六]深度学习Pyto... 查看详情

pytorch模型训练实用教程学习笔记:二模型的构建(代码片段)

前言最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读。于是在gayhub上找到了这样一份教程《Pytorch模型训练实用教程》,写得不错,特此根据它来再学习一下Pytorch。仓库地... 查看详情