如何在 Pytorch 中测试自定义数据集?

     2023-03-29     143

关键词:

【中文标题】如何在 Pytorch 中测试自定义数据集?【英文标题】:How do you test a custom dataset in Pytorch? 【发布时间】:2021-07-21 07:09:54 【问题描述】:

我一直在关注 Pytorch 中使用来自 Pytorch 的数据集的教程,这些教程允许您启用是否要使用数据进行训练...但现在我使用的是 .csv 和自定义数据集。

class MyDataset(Dataset):
    def __init__(self, root, n_inp):
        self.df = pd.read_csv(root)
        self.data = self.df.to_numpy()
        self.x , self.y = (torch.from_numpy(self.data[:,:n_inp]),
                           torch.from_numpy(self.data[:,n_inp:]))
    def __getitem__(self, idx):
        return self.x[idx, :], self.y[idx,:]
    def __len__(self):
        return len(self.data)

我如何告诉 Pytorch 不要训练我的 test_dataset,以便我可以将其用作我的模型准确度的参考?

train_dataset = MyDataset("heart.csv", input_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle =True)
test_dataset = MyDataset("heart.csv", input_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle =True)

【问题讨论】:

【参考方案1】:

在 pytorch 中,自定义数据集继承类 Dataset。主要包含两个方法__len__()是指定要迭代的数据集对象的长度,__getitem__()是一次返回一批数据。

一旦数据加载器对象被初始化(train_loadertest_loader 在您的代码中指定),您需要编写一个训练循环和一个测试循环。

def train(model, optimizer, loss_fn, dataloader):
    model.train()
    for i, (input, gt) in enumerate(dataloader):
        if params.use_gpu: #(If training using GPU)
            input, gt = input.cuda(non_blocking = True), gt.cuda(non_blocking = True)
        predicted = model(input)
        loss = loss_fn(predicted, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

你的测试循环应该是:

def test(model,loss_fn, dataloader):
    model.eval()
    for i, (input, gt) in enumerate(dataloader):
        if params.use_gpu: #(If training using GPU)
            input, gt = input.cuda(non_blocking = True), gt.cuda(non_blocking = True)
        predicted = model(input)
        loss     = loss_fn(predicted, gt)

此外,您可以使用指标字典来记录您的预测、损失、时期等。训练循环和测试循环的主要区别在于我们在推理阶段排除了反向传播(zero_grad(), backward(), step())。

最后,

for epoch in range(1, epochs + 1):
    train(model, optimizer, loss_fn, train_loader)
    test(model, loss_fn, test_loader)

【讨论】:

【参考方案2】:

在 pytorch 中进行测试时需要注意以下几点:

    将您的模型置于评估模式,这样dropout 和batch normalization 之类的东西就不会进入训练模式:model.eval() 在您的测试代码周围放置一个包装器以避免计算梯度(节省内存和时间):with torch.no_grad(): 仅根据您的训练集规范化或标准化您的数据。这对于最小/最大归一化或 z 分数标准化很重要,以便模型准确反映测试性能。

除此之外,您编写的内容在我看来还不错,因为您没有对数据应用任何转换(例如,图像翻转或高斯噪声注入)。要显示测试模式下的代码应该是什么样子,请参见下文:

for e in range(num_epochs):
    for B, (dat, label) in enumerate(train_loader):
         #transforms here
         opt.zero_grad()
         out = model(dat.to(device))
         loss = criterion(out)
         loss.backward()
         opt.step()
    with torch.no_grad():
         model.eval()
         global_corr = 0
         for B, (dat,label) in enumerate(test_loader):
             out = model(dat.to(device))
             # get batch eval metrics here!

     

【讨论】:

在pytorch中构建高效的自定义数据集(代码片段)

...。神经网络训练在数据管理上可能很难做到“大规模”。PyTorch最近已经出现在我的圈子里,尽管对Keras和TensorFlow感到满意,但我还是不得不尝试一下。令人惊讶的是,我发现它非常令人耳目一新,非常讨人喜欢,尤其是PyTorch提... 查看详情

如何将基于自定义图像的数据集加载到 Pytorch 中以与 CNN 一起使用?

】如何将基于自定义图像的数据集加载到Pytorch中以与CNN一起使用?【英文标题】:HowdoIloadcustomimagebaseddatasetsintoPytorchforusewithaCNN?【发布时间】:2019-01-0517:15:46【问题描述】:我已经在互联网上搜索了几个小时,以找到解决我问... 查看详情

pytorch自定义数据集模型训练流程(代码片段)

文章目录Pytorch模型自定义数据集训练流程1、任务描述2、导入各种需要用到的包3、分割数据集4、将数据转成pytorch标准的DataLoader输入格式5、导入预训练模型,并修改分类层6、开始模型训练7、利用训好的模型做预测Pytorch模... 查看详情

我用pytorch复现了lenet-5神经网络(自定义数据集篇)!

...:这可能是神经网络LeNet-5最详细的解释了!我用PyTorch复现了LeNet-5神经网络(MNIST手写数据集篇)!我用PyTorch复现了LeNet-5神经网络(CIFAR10数据集篇)!详细介绍了卷积神经网络LeNet-5的理论部分和... 查看详情

如何在 Pytorch Lightning 中使用 numpy 数据集

】如何在PytorchLightning中使用numpy数据集【英文标题】:HowtousenumpydatasetinPytorchLightning【发布时间】:2021-07-2921:49:53【问题描述】:我想使用NumPy制作一个数据集,然后想训练和测试一个简单的模型,例如“线性或逻辑”。我正在... 查看详情

在自定义数据集上使用 roboflow 对象检测 Yolov4 pytorch 模型时出现值错误

】在自定义数据集上使用roboflow对象检测Yolov4pytorch模型时出现值错误【英文标题】:ValueerrorwhileusingroboflowobjectdetectionYolov4pytorchmodeloncustomdataset【发布时间】:2021-12-1319:43:48【问题描述】:我们使用Roboflow进行对象检测,使用Yolov4... 查看详情

使用 PyTorch 加载自定义图像数据集

】使用PyTorch加载自定义图像数据集【英文标题】:LoadingcustomdatasetofimagesusingPyTorch【发布时间】:2019-12-3002:47:07【问题描述】:我正在使用线圈100数据集,该数据集包含100个对象的图像,每个对象的72个图像是从固定相机拍摄的... 查看详情

如何将自定义数据集拆分为训练和测试数据集?

】如何将自定义数据集拆分为训练和测试数据集?【英文标题】:HowdoIsplitacustomdatasetintotrainingandtestdatasets?【发布时间】:2018-11-0518:28:51【问题描述】:importpandasaspdimportnumpyasnpimportcv2fromtorch.utils.data.datasetimportDatasetclassCustomDatasetF... 查看详情

Python,类数据集,如何在pytorch中将图像与其各自的标签连接起来

】Python,类数据集,如何在pytorch中将图像与其各自的标签连接起来【英文标题】:Python,classdataset,howtoconcatenateimageswiththeirrespectivelabelsinpytorch【发布时间】:2020-10-0713:39:57【问题描述】:我是PyTorch的新手,在过去的几天里,我一... 查看详情

如何在 PyTorch 中添加自定义定位损失函数?

】如何在PyTorch中添加自定义定位损失函数?【英文标题】:HowtoaddacustomlocalizationlossfunctioninPyTorch?【发布时间】:2021-10-0323:00:27【问题描述】:我有一个PyTorch网络,它使用Wi-FiRSS数据预测设备的位置。所以输出层包含两个对应于x... 查看详情

pytorch自定义数据集处理/dataset/dataloader等(代码片段)

问题处理自定义数据集是应用PyTorch走向工程实际的重要前提,本文将持续更新介绍自定义数据集处理一些常见方法。方法加载自定义数据集并获取分类数量fromtorchvision.datasetsimportImageFoldertrain_dataset=ImageFolder('D:\\\\data\\\\... 查看详情

如何在 pytorch 中处理大型数据集

】如何在pytorch中处理大型数据集【英文标题】:Howtoworkwithlargedatasetinpytorch【发布时间】:2019-07-1205:13:00【问题描述】:我有一个不适合内存(150G)的庞大数据集,我正在寻找在pytorch中使用它的最佳方法。数据集由几个.npz文件组... 查看详情

我用pytorch复现了lenet-5神经网络(自定义数据集篇)!

...:这可能是神经网络LeNet-5最详细的解释了!我用PyTorch复现了LeNet-5神经网络(MNIST手写数据集篇)!我用PyTorch复现了LeNet-5神经网络(CIFAR10数据集篇)!详细介绍了卷积神经网络LeNe 查看详情

我用pytorch复现了lenet-5神经网络(自定义数据集篇)!

...:这可能是神经网络LeNet-5最详细的解释了!我用PyTorch复现了LeNet-5神经网络(MNIST手写数据集篇)!我用PyTorch复现了LeNet-5神经网络(CIFAR10数据集篇)!详细介绍了卷积神经网络LeNe 查看详情

使用pytorch进行数据处理(代码片段)

...到模型的效果。本文以处理图像数据为例,记录一些使用PyTorch进行图像预处理和数据加载的方法。一、数据的加载??在PyTorch中,数据加载需要自定义数据集类,并用此类来实例化数据对象,实现自定义的数据集需要继承torch.utils... 查看详情

如何在pytorch中获取自定义损失函数的权重?

】如何在pytorch中获取自定义损失函数的权重?【英文标题】:Howtogetweightsforcustomlossfunctioninpytorch?【发布时间】:2018-09-2318:49:33【问题描述】:我在pytorch中有一个模型,想在loss_function中添加L1正则化。但我不想将权重传递给loss_f... 查看详情

使用pytorch框架自己制作做数据集进行图像分类(代码片段)

第一章:Pytorch制作自己的数据集实现图像分类第一章:Pytorch框架制作自己的数据集实现图像分类第二章:Pytorch框架构建残差神经网络(ResNet)第三章:Pytorch框架构建DenseNet神经网络提示:本文代码,含有部... 查看详情

如何在pytorch中批处理对话数据集?

】如何在pytorch中批处理对话数据集?【英文标题】:howtobatchdialogdatasetinpytorch?【发布时间】:2020-03-1416:49:17【问题描述】:我想做一个用于预订餐厅的面向任务的对话聊天机器人。因为每个对话都有不同的序列(例如,有些对... 查看详情