PyTorch 中是不是存在干净且可扩展的 LSTM 实现? [关闭]

     2023-03-12     34

关键词:

【中文标题】PyTorch 中是不是存在干净且可扩展的 LSTM 实现? [关闭]【英文标题】:Does a clean and extendable LSTM implementation exists in PyTorch? [closed]PyTorch 中是否存在干净且可扩展的 LSTM 实现? [关闭] 【发布时间】:2018-10-14 13:05:50 【问题描述】:

我想自己创建一个LSTM 类,但是我不想再次从头开始重写经典的LSTM 函数。

深入PyTorch的代码,我只发现一个肮脏的实现涉及至少3-4个具有继承性的类:

    https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/modules/rnn.py#L323 https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/modules/rnn.py#L12 https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/_functions/rnn.py#L297

LSTMclean PyTorch 实现是否存在于某处?任何链接都会有所帮助。

例如,我知道 TensorFlow 中存在 LSTM 的干净实现,但我需要派生一个 PyTorch

举个明显的例子,我正在寻找的是一个像 this 一样干净的实现,但是在 PyTorch 中:

【问题讨论】:

【参考方案1】:

我找到的最佳实现在这里https://github.com/pytorch/benchmark/blob/master/rnns/benchmarks/lstm_variants/lstm.py

它甚至实现了四种不同的循环 dropout 变体,非常有用! 如果你把辍学部分拿走,你会得到

import math
import torch as th
import torch.nn as nn

class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        h, c = hidden
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x = x.view(x.size(1), -1)

        # Linear mappings
        preact = self.i2h(x) + self.h2h(h)

        # activations
        gates = preact[:, :3 * self.hidden_size].sigmoid()
        g_t = preact[:, 3 * self.hidden_size:].tanh()
        i_t = gates[:, :self.hidden_size]
        f_t = gates[:, self.hidden_size:2 * self.hidden_size]
        o_t = gates[:, -self.hidden_size:]

        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)

        h_t = th.mul(o_t, c_t.tanh())

        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)

PS:存储库包含更多 LSTM 和其他 RNN 的变体:https://github.com/pytorch/benchmark/tree/master/rnns/benchmarks. 看看吧,也许你心目中的扩展已经存在了!

编辑: 如 cmets 中所述,您可以包装上面的 LSTM 单元来处理顺序输出:

import math
import torch as th
import torch.nn as nn


class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        # As before

    def reset_parameters(self):
        # As before

    def forward(self, x, hidden):

        if hidden is None:
            hidden = self._init_hidden(x)

        # Rest as before

    @staticmethod
    def _init_hidden(input_):
        h = th.zeros_like(input_.view(1, input_.size(1), -1))
        c = th.zeros_like(input_.view(1, input_.size(1), -1))
        return h, c


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.lstm_cell = LSTMCell(input_size, hidden_size, bias)

    def forward(self, input_, hidden=None):
        # input_ is of dimensionalty (1, time, input_size, ...)

        outputs = []
        for x in torch.unbind(input_, dim=1):
            hidden = self.lstm_cell(x, hidden)
            outputs.append(hidden[0].clone())

        return torch.stack(outputs, dim=1)

我没有测试代码,因为我正在使用 convLSTM 实现。如果有问题请告诉我。

更新:固定链接。

【讨论】:

经过一番测试,我意识到它说的是For now, they only support a sequence size of 1。所以这段代码可能需要大量重构才能使用。 我上面给出的代码通常被称为 LSTM 单元。为了处理顺序输入,只需将其包装在一个设置初始隐藏状态的模块中,然后在输入的时间维度上进行迭代,在每个时间点调用 LSTM 单元(类似于这里的操作方式discuss.pytorch.org/t/implementation-of-multiplicative-lstm/…) 还可以看看它是如何在这里完成的 vor 卷积 LSTM github.com/automan000/Convolution_LSTM_PyTorch/blob/master/…。它与常规的 LSTM 非常相似。 您能解释一下为什么在静态方法中使用view 为h 和c 添加一个额外的维度吗? @AnIgnorantWanderer 这是因为此实现假定批量大小为 1,因此它适用于任何长度的输入序列。如果您想使用更大的批量大小,您需要一个固定的序列长度或sequence padding,并且您需要相应地修改代码。【参考方案2】:

我制作了一个简单通用的框架来自定义 LSTM: https://github.com/daehwannam/pytorch-rnn-util

您可以通过设计 LSTM 单元并将它们提供给LSTMFrame 来实现自定义 LSTM。 自定义 LSTM 的一个例子是包中的LayerNormLSTM

# snippet from rnn_util/seq.py
class LayerNormLSTM(LSTMFrame):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0, r_dropout=0, bidirectional=False, layer_norm_enabled=True):
        r_dropout_layer = nn.Dropout(r_dropout)
        rnn_cells = tuple(
            tuple(
                LayerNormLSTMCell(
                    input_size if layer_idx == 0 else hidden_size * (2 if bidirectional else 1),
                    hidden_size,
                    dropout=r_dropout_layer,
                    layer_norm_enabled=layer_norm_enabled)
                for _ in range(2 if bidirectional else 1))
            for layer_idx in range(num_layers))

        super().__init__(rnn_cells, dropout, bidirectional)

LayerNormLSTM 拥有 PyTorch 标准 LSTM 的关键选项和附加选项 r_dropoutlayer_norm_enabled

# example.py
import torch
import rnn_util


bidirectional = True
num_directions = 2 if bidirectional else 1

rnn = rnn_util.LayerNormLSTM(10, 20, 2, dropout=0.3, r_dropout=0.25,
                             bidirectional=bidirectional, layer_norm_enabled=True)
# rnn = torch.nn.LSTM(10, 20, 2, bidirectional=bidirectional)

input = torch.randn(5, 3, 10)
h0 = torch.randn(2 * num_directions, 3, 20)
c0 = torch.randn(2 * num_directions, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

print(output.size())

【讨论】:

JS 文件中的 TS 警告:“类型 X 上不存在属性 X”:是不是可以编写更干净的 JavaScript?

】JS文件中的TS警告:“类型X上不存在属性X”:是不是可以编写更干净的JavaScript?【英文标题】:TSwarninginJSfile:"PropertyXdoesnotexistontypeX":isitpossibletowritecleanerJavaScript?JS文件中的TS警告:“类型X上不存在属性X”:是否可以... 查看详情

如何以干净有效的方式在pytorch中获得小批量?

】如何以干净有效的方式在pytorch中获得小批量?【英文标题】:Howtogetmini-batchesinpytorchinacleanandefficientway?【发布时间】:2017-12-2003:52:37【问题描述】:我正在尝试做一件简单的事情,即使用Torch使用随机梯度下降(SGD)训练线性模... 查看详情

如何检查目录中是不是存在扩展名为 .txt 的文件?

】如何检查目录中是不是存在扩展名为.txt的文件?【英文标题】:Howtocheckifafilewitha.txtextensionexistsinadirectory?如何检查目录中是否存在扩展名为.txt的文件?【发布时间】:2012-08-2100:34:17【问题描述】:我想知道如何检查目录中是... 查看详情

如何检查php中是不是存在mcrypt扩展

】如何检查php中是不是存在mcrypt扩展【英文标题】:Howtocheckifmcryptextensionexistsinphp如何检查php中是否存在mcrypt扩展【发布时间】:2014-08-2422:26:18【问题描述】:我想知道最简单和最快的PHP代码行来检查mcrypt扩展是否可用/安装。有... 查看详情

pytorch为何如此高效好用?

...能知道可以借助C/C++扩展Python,并开发所谓的「扩展」。PyTorch的所有繁重工作由C/C++实现,而不是纯Python。为了定义C/C++中一个新的Python对象类型,你需要定义如下实例的一个类似结构://Pythonobjectthatbackstorch.autograd.VariablestructTHPV... 查看详情

如何仅使用 (*.pdf) 之类的扩展名检查 php 中是不是存在文件

】如何仅使用(*.pdf)之类的扩展名检查php中是不是存在文件【英文标题】:howtocheckfileexistinphpusingonlyextensionlike(*.pdf)如何仅使用(*.pdf)之类的扩展名检查php中是否存在文件【发布时间】:2013-05-2417:13:29【问题描述】:$filename=\'/www/tes... 查看详情

如何检查Golang目录中是不是存在文件`name`(带有任何扩展名)?

】如何检查Golang目录中是不是存在文件`name`(带有任何扩展名)?【英文标题】:Howtocheckfile`name`(withanyextension)doesexistinthedirectoryinGolang?如何检查Golang目录中是否存在文件`name`(带有任何扩展名)?【发布时间】:2017-12-2021:17:52... 查看详情

检查文件扩展名为“.ini”的文件是不是存在,shell

】检查文件扩展名为“.ini”的文件是不是存在,shell【英文标题】:Checkingiffilewithfile-extension\'.ini\'exists,shell检查文件扩展名为“.ini”的文件是否存在,shell【发布时间】:2012-09-1213:31:20【问题描述】:如何检查(使用shell)/dir... 查看详情

PHP - 在不知道扩展名的情况下检查文件是不是存在

】PHP-在不知道扩展名的情况下检查文件是不是存在【英文标题】:PHP-CheckiffileexistswithoutknowingtheextensionPHP-在不知道扩展名的情况下检查文件是否存在【发布时间】:2013-08-3115:43:36【问题描述】:我已经看到thisquestion,但它对我... 查看详情

使用 NoSQL 数据库对 JSON 数据进行高效且可扩展的存储

】使用NoSQL数据库对JSON数据进行高效且可扩展的存储【英文标题】:EfficientandscalablestorageforJSONdatawithNoSQLdatabases【发布时间】:2011-10-2921:25:06【问题描述】:我们正在开展一个项目,该项目应收集日志和审计数据并将其存储在数... 查看详情

可靠且可扩展的实时音频流解决方案?

】可靠且可扩展的实时音频流解决方案?【英文标题】:Reliable&scalablesolutionsforliveaudiostreaming?【发布时间】:2009-05-1214:52:53【问题描述】:我想知道您是否可以分享您关于实时音频流(运行在线广播电台)的想法/资源。我不... 查看详情

PyTorch 中的 nn.functional() 与 nn.sequential() 之间是不是存在计算效率差异

】PyTorch中的nn.functional()与nn.sequential()之间是不是存在计算效率差异【英文标题】:Arethereanycomputationalefficiencydifferencesbetweennn.functional()Vsnn.sequential()inPyTorchPyTorch中的nn.functional()与nn.sequential()之间是否存在计算效率差异【发布时间... 查看详情

PyTorch 中是不是有提取图像补丁的功能?

】PyTorch中是不是有提取图像补丁的功能?【英文标题】:IsthereafunctiontoextractimagepatchesinPyTorch?PyTorch中是否有提取图像补丁的功能?【发布时间】:2018-01-3108:49:19【问题描述】:给定一批图像,我想提取所有可能的图像块,类似... 查看详情

Android - 检查元素是不是存在于 WebView 中(在 DOM 中)

】Android-检查元素是不是存在于WebView中(在DOM中)【英文标题】:Android-CheckifelementexistsinsideWebView(intheDOM)Android-检查元素是否存在于WebView中(在DOM中)【发布时间】:2015-10-0214:58:30【问题描述】:是否有一种简单干净的方法来检... 查看详情

CloudKit 不检索应该存在且可检索的记录

】CloudKit不检索应该存在且可检索的记录【英文标题】:CloudKitdoesnotretrieverecordthatshouldbethereandretrievable【发布时间】:2018-07-1701:12:40【问题描述】:我有一个在CloudKit中创建新记录的函数。成功创建该记录后,它会调用一个名为re... 查看详情

C# 和 CPython 之间的快速且可扩展的 RPC

】C#和CPython之间的快速且可扩展的RPC【英文标题】:FastandscalableRPCbetweenC#andCPython【发布时间】:2012-10-3015:11:20【问题描述】:我需要一些关于C#-用于科学应用的cPython集成的反馈。场景如下:用于数据采集和可视化的C#-用于数据... 查看详情

PyTorch c++ 扩展中的可选张量

】PyTorchc++扩展中的可选张量【英文标题】:OptionaltensorsinPyTorchc++extension【发布时间】:2019-07-0804:29:32【问题描述】:我正在为pytorch编写C++扩展,并使用c++api来执行此操作。对于我的forward函数,我需要传递一个可选的张量。在函... 查看详情

是否存在“SendKeys”的快速且可预测的替代方案?

】是否存在“SendKeys”的快速且可预测的替代方案?【英文标题】:Doesafastandpredictablealternativeto`SendKeys`exist?【发布时间】:2015-03-2501:29:19【问题描述】:我正在尝试为在线游戏编写一个AI,它的工作原理是抓取游戏区域的屏幕区... 查看详情