在 PyTorch 中使用 module.to() 移动成员张量

     2023-03-12     88

关键词:

【中文标题】在 PyTorch 中使用 module.to() 移动成员张量【英文标题】:Moving member tensors with module.to() in PyTorch 【发布时间】:2019-07-09 09:49:12 【问题描述】:

我正在 PyTorch 中构建变分自动编码器 (VAE),但在编写与设备无关的代码时遇到问题。 Autoencoder 是nn.Module 的子代,具有编码器和解码器网络,它们也是。通过调用net.to(device),可以将网络的所有权重从一台设备转移到另一台设备。

我遇到的问题是重新参数化技巧:

encoding = mu + noise * sigma

噪声是与musigma 大小相同的张量,并保存为自动编码器模块的成员变量。它在构造函数中初始化,并在每个训练步骤就地重新采样。我这样做是为了避免每一步都构建一个新的噪声张量并将其推送到所需的设备。此外,我想修复评估中的噪音。代码如下:

class VariationalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128

        self._train_noise = torch.randn(batch_size, embedding_size)
        self._eval_noise = torch.randn(1, embedding_size)
        self.noise = self._train_noise

        # Create encoder
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)

    def train(self, mode=True):
        super(VariationalGenerator, self).train(mode)
        self.noise = self._train_noise

    def eval(self):
        super(VariationalGenerator, self).eval()
        self.noise = self._eval_noise

    def forward(self, inputs):
        # Calculate parameters of embedding space
        mu, log_sigma = self.encoder.forward(inputs)
        # Resample noise if training
        if self.training:
            self.noise.normal_()
        # Reparametrize noise to embedding space
        inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
        # Decode to image
        inputs = self.decoder(inputs)

        return inputs, mu, log_sigma

当我现在使用 net.to('cuda:0') 将自动编码器移动到 GPU 时,由于噪声张量未移动,我在转发时遇到错误。

我不想在构造函数中添加设备参数,因为以后仍然无法将其移动到另一个设备。我还尝试将噪声包装到nn.Parameter 中,使其受net.to() 的影响,但这会导致优化器出错,因为噪声被标记为requires_grad=False

任何人都可以使用net.to() 移动所有模块吗?

【问题讨论】:

【参考方案1】:

使用这个:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

现在对于你使用的模型和每个张量

net.to(device)
input = input.to(device)

【讨论】:

这不是问题。问题是当使用net.to(device) 时,噪声张量没有与权重一起移动。【参考方案2】:

经过反复试验,我找到了两种方法:

    使用缓冲区:通过将self._train_noise = torch.randn(batch_size, embedding_size) 替换为self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size),噪声张量作为缓冲区添加到模块中。这也让net.to(device) 影响它。此外,张量现在是 state_dict 的一部分。

    覆盖net.to(device):使用它,噪音不会出现在 state_dict 之外。

    def to(device):
        new_self = super(VariationalGenerator, self).to(device)
        new_self._train_noise = new_self._train_noise.to(device)
        new_self._eval_noise = new_self._eval_noise.to(device)
    
        return new_self
    

【讨论】:

方法 2 通过覆盖 _apply 而不仅仅是 to 可能会更好;然后.cuda() 等也都可以工作。 (my answer)【参考方案3】:

tilman151's second approach 的更好版本可能是覆盖_apply,而不是to。这样net.cuda()net.float() 等都可以正常工作,因为它们都调用_apply 而不是to(可以在the source 中看到,这比你想象的要简单):

def _apply(self, fn):
    super(VariationalGenerator, self)._apply(fn)
    self._train_noise = fn(self._train_noise)
    self._eval_noise = fn(self._eval_noise)
    return self

【讨论】:

_apply 方法似乎被开发人员故意保留为不覆盖。相反,我建议将其注册为带有 persistent=False arg 的缓冲区,因为这将使其远离 state_dict。这不是更稳定的方式吗? @lamo_738 是的,我认为在给出这个答案时不存在非持久缓冲区,但这似乎是现在最好的选择。 我有一个要移动到 GPU 的顺序对象列表。这样做时,_apply 函数中出现以下错误:AttributeError: 'Sequential' object has no attribute 'is_floating_point'【参考方案4】:

通过使用它,您可以将相同的参数应用于您的张量和模块

def to(self, **kwargs):
    module = super(VariationalGenerator, self).to(**kwargs)
    module._train_noise = self._train_noise.to(**kwargs)
    module._eval_noise = self._eval_noise.to(**kwargs)

    return module

【讨论】:

【参考方案5】:

您可以使用nn.Module 缓冲区和参数 - 调用.to(device) 时会考虑这两者并移至device。 优化器正在更新参数(因此它们需要requires_grad=True),缓冲区不是。

所以在你的情况下,我会把构造函数写成:

    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128

        # --- CHANGED LINES ---
        self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size))
        self.register_buffer('_eval_noise', torch.randn(1, embedding_size))
        # --- CHANGED LINES ---

        self.noise = self._train_noise

        # Create encoder
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)

【讨论】:

如何在 pytorch 中使用 tensorboard 调试器?

】如何在pytorch中使用tensorboard调试器?【英文标题】:Howtousetensorboarddebuggerwithpytorch?【发布时间】:2020-01-2300:57:48【问题描述】:我知道pytorch从1.11版本开始支持tensorboard。但我想知道我们是否可以使用pytorch的tensorboard调试器插... 查看详情

在 CrossEntropyLoss 和 BCELoss (PyTorch) 中使用权重

】在CrossEntropyLoss和BCELoss(PyTorch)中使用权重【英文标题】:UsingweightsinCrossEntropyLossandBCELoss(PyTorch)【发布时间】:2021-08-1604:44:25【问题描述】:我正在训练一个PyTorch模型来执行二进制分类。我的少数类约占数据的10%,所以我想使... 查看详情

如何在 PyTorch 中使用 autograd.gradcheck?

】如何在PyTorch中使用autograd.gradcheck?【英文标题】:Howtouseautograd.gradcheckinPyTorch?【发布时间】:2019-12-2820:16:18【问题描述】:文档不包含任何gradcheck的示例用例,它在哪里有用?【问题讨论】:【参考方案1】:这里的文档中提... 查看详情

在 MESS 中使用 pytorch 的 Tensorboard

】在MESS中使用pytorch的Tensorboard【英文标题】:TensorboardwithpytorchinaMESS【发布时间】:2021-04-2513:34:08【问题描述】:我正在使用PyTorch并可视化训练标量,例如tensorboard中的损失。我有3个模型的事件文件位于以下目录中:├─data│... 查看详情

为啥我需要在 Python/Pytorch 中使用两次导入? [复制]

】为啥我需要在Python/Pytorch中使用两次导入?[复制]【英文标题】:WhydoIneedtouseimporttwiceinPython/Pytorch?[duplicate]为什么我需要在Python/Pytorch中使用两次导入?[复制]【发布时间】:2021-11-1509:20:25【问题描述】:我在Youtube上关注Pytorch... 查看详情

如何在代码中使用 PyTorch PackedSequence?

】如何在代码中使用PyTorchPackedSequence?【英文标题】:HowdoyouusePyTorchPackedSequenceincode?【发布时间】:2017-11-2209:53:27【问题描述】:有人可以提供完整的工作代码(不是sn-p,而是在可变长度循环神经网络上运行的代码),说明您... 查看详情

在 pytorch 中使用多个损失函数

】在pytorch中使用多个损失函数【英文标题】:Usingmultiplelossfunctionsinpytorch【发布时间】:2021-02-2014:57:12【问题描述】:我正在处理图像恢复任务,我考虑了多个损失函数。我的计划是考虑3条路线:1:使用多个损失进行监控,但... 查看详情

在 pytorch 中使用 torchvision.transforms 进行数据增强

】在pytorch中使用torchvision.transforms进行数据增强【英文标题】:DataAugmentationwithtorchvision.transformsinpytorch【发布时间】:2019-07-2503:12:11【问题描述】:我发现数据增强可以在PyTorch中使用torchvision.transforms完成。我还读到在每个时期... 查看详情

我可以在 Pytorch 中使用啥来代替 caffe 重量填充物

】我可以在Pytorch中使用啥来代替caffe重量填充物【英文标题】:WhatcanIuseinPytorchtoremplacecaffe\'sweightfiller我可以在Pytorch中使用什么来代替caffe重量填充物【发布时间】:2021-01-2003:35:58【问题描述】:我正在基于下面的caffe模型编写... 查看详情

是否可以在 transform.compose 中使用非 pytorch 增强

】是否可以在transform.compose中使用非pytorch增强【英文标题】:Isitpossibletousenon-pytochaugmentationintransform.compose【发布时间】:2020-11-0509:24:20【问题描述】:我正在研究一个将图像作为Pytorch输入的数据分类问题。我想使用imgaug库,但... 查看详情

在 pytorch Torchtext 中使用 Vocab 获取单词的频率

】在pytorchTorchtext中使用Vocab获取单词的频率【英文标题】:GetfrequencyofwordsusingVocabinpytorchTorchtext【发布时间】:2022-01-2403:05:42【问题描述】:如何在使用build_vocab_from_iterator创建的torchtextvocab中获取标记的频率?文档链接:https://py... 查看详情

如何在 Pytorch 中测试自定义数据集?

】如何在Pytorch中测试自定义数据集?【英文标题】:HowdoyoutestacustomdatasetinPytorch?【发布时间】:2021-07-2107:09:54【问题描述】:我一直在关注Pytorch中使用来自Pytorch的数据集的教程,这些教程允许您启用是否要使用数据进行训练...... 查看详情

如何使用空间转换器在pytorch中裁剪图像?

】如何使用空间转换器在pytorch中裁剪图像?【英文标题】:Howtousespatialtransformertocroptheimageinpytorch?【发布时间】:2019-08-1718:32:59【问题描述】:空间变换网络的论文声称它可以用来裁剪图像。给定裁剪区域(top_left,bottom_right)=(x1,y1... 查看详情

在 PyTorch 中使用张量索引多维张量

】在PyTorch中使用张量索引多维张量【英文标题】:Indexingamulti-dimensionaltensorwithatensorinPyTorch【发布时间】:2019-02-0502:46:21【问题描述】:我有以下代码:a=torch.randint(0,10,[3,3,3,3])b=torch.LongTensor([1,1,1,1])我有一个多维索引b并想用它来... 查看详情

在 Pytorch 中使用 Dropout:nn.Dropout 与 F.dropout

】在Pytorch中使用Dropout:nn.Dropout与F.dropout【英文标题】:Pytorch:nn.Dropoutvs.F.dropout【发布时间】:2019-04-2411:02:05【问题描述】:使用pyTorch有两种方法可以退出torch.nn.Dropout和torch.nn.functional.Dropout。我很难看出它们在使用上的区别:... 查看详情

无法在pytorch python中使用多目标损失函数

】无法在pytorchpython中使用多目标损失函数【英文标题】:Unabletousemultitargetlossfunctioninpytorchpython【发布时间】:2021-11-1719:57:43【问题描述】:我无法在pytorch中使用损失函数进行多标签分类这是我的损失函数:defloss(self,pred,y_true):p... 查看详情

在 Pytorch 中使用逻辑回归的预测返回无穷大

】在Pytorch中使用逻辑回归的预测返回无穷大【英文标题】:PredictionsusingLogisticRegressioninPytorchreturninfinity【发布时间】:2021-04-0308:01:31【问题描述】:我开始观看有关PyTorch的教程,并且正在学习逻辑回归的概念。我使用我拥有的... 查看详情

如何在忽略类中使用 pytorch 闪电精度?

】如何在忽略类中使用pytorch闪电精度?【英文标题】:HowtousepytorchlightningAccuracywithignoreclass?【发布时间】:2021-07-0405:05:13【问题描述】:我有一些使用CrossEntropyLoss和忽略类的训练管道。模型输出log_probs形状(150,3)-表示3个可能的... 查看详情