keras深度学习实战(29)——长短时记忆网络详解与实现(代码片段)

盼小辉丶 盼小辉丶     2022-12-02     450

关键词:

Keras深度学习实战(29)——长短时记忆网络详解与实现

0. 前言

长短时记忆网络 (Long Short Term Memory, LSTM),顾名思义是具有记忆长短期信息能力的神经网络,解决了循环神经网络 (Recurrent neural networks, RNN) 梯度爆炸/消失的问题,是建立在循环神经网络上的一种新型深度学习的时间序列模型,它具有高度的学习能力与模拟能力,具有记忆可持续性的特点,且能预测未来的任意步长。本文首先介绍了 RNN 模型的局限性,从而引入介绍长短时记忆网络 (Long Short Term Memory, LSTM) 的基本原理,最后通过实现 LSTM 进行深入了解。

1. RNN 的局限性

我们首先可视化 RNN 在考虑多个时刻做出预测时的情况,如下所示,随着时间的增加,早期输入的影响会逐渐降低:

更具体的,我们也可以通过公式得到相同的结论,例如我们需要计算第 5 个时刻网络的中间状态:

h 5 = W X 5 + U h 4 = W X 5 + U W X 4 + U 2 W X 3 + U 3 W X 2 + U 4 W X 1 h_5 = WX_5 + Uh_4 = WX_5 + UWX_4 + U_2WX_3 + U_3WX_2 + U_4WX_1 h5=WX5+Uh4=WX5+UWX4+U2WX3+U3WX2+U4WX1

可以看到,随着时间的增加,如果 U > 1 U>1 U>1,则网络中间状态的值高度依赖于 X 1 X_1 X1;而如果 U < 1 U<1 U<1,则网络中间状态值对 X 1 X_1 X1 的依赖就少得多。对 U 矩阵的依赖性还可能在 U 值很小时导致梯度消失,而在 U 值很高时会导致梯度爆炸。
当在预测单词时存在长期依赖性时,RNN 的这种现象将导致无法学习长期依赖关系的问题。为了解决这个问题,我们将引入介绍长短期记忆 (Long Short Term Memory, LSTM) 体系结构。

2. LSTM 模型架构详解

在有关传统 RNN 的问题中,我们了解了 RNN 对于长期依赖问题无济于事。例如,假设输入句子如下:

I live in China. I speak ____.

可以通过关键字 China 来推测以上空中应填充的单词,但该关键字与我们要预测的单词距离 3 个时间戳。如果关键字远离要预测的单词,则需要解决消失/爆炸梯度问题。

2.1 LSTM 架构

在本节中,我们将学习 LSTM 如何帮助克服RNN体系结构的长期依赖缺点,并构建一个简单示例,以便了解 LSTM 的各个组成部分。LSTM 架构示意图如下所示:


可以看到,虽然每个时刻 (h) 的输入 X 和输出保持不变,但是在网络中使用不同的计算方式和激活函数。

2.2 LSTM 各组成部分与计算流程

接下来,我们详细介绍在一个时间戳内的计算过程:


在上图中, x x x h h h 表示输入层和 LSTM 的输出向量,内部状态向量 Memory 存储在单元状态 c c c 中也就是说,相较于基础 RNN 而言,LSTM 将内部状态向量 Memory 和输出分开为两个变量,利用输入门 (Input Gate)、遗忘门 (Forget Gate)和输出门 (Output Gate) 三个门控来控制内部信息的流动。门控机制是一种控制网络中数据流通量的手段,可以较好地控制数据流通的流量程度。

2.2.1 遗忘门

需要忘记的内容是通过“遗忘门”获得的,用于控制上一个时间戳的记忆 c t − 1 c_t-1 ct1 对当前时间戳的影响,遗忘门的控制变量 f t f_t ft 由:

f t = σ ( W x f x ( t ) + W h f h ( t − 1 ) + b f ) f_t=\\sigma(W_xfx^(t)+W_hfh^(t-1)+b_f) ft=σ(Wxfx(t)+Whfh(t1)+bf)

sigmoid 激活函数使网络能够选择性地识别需要忘记的内容。在确定需要忘记的内容后,更新后的单元状态如下:

c t = ( c ( t − 1 ) ⊗ f ) c_t=(c_(t-1)\\otimes f) ct=(c(t1)f)

其中, ⊗ \\otimes 表示逐元素乘法。例如,如果句子的输入序列是 I live in China. I speak ___,可以根据输入的单词 China 来填充空格,在之后,我们可能并不再需要有关国家名称的信息。我们根据当前时间戳需要忘记的内容来更新单元状态。

2.2.2 输入门

输入门用于控制 LSTM 对输入的接受程度,根据当前时间戳提供的输入将其他信息添加到单元状态中,通过 tanh 激活函数获得更新,因此也称为更新门。首先通过对当前时间戳的输入和上一时间戳的输出作非线性变换:

i t = σ ( W x i x ( t ) + W h i h ( t − 1 ) + b i ) i_t=\\sigma(W_xix^(t)+W_hih^(t-1)+b_i) it=σ(Wxix(t)+Whih(t1)+bi)

输入门中,输入更新计算方法如下:

g t = t a n h ( W x g x ( t ) + W h g h ( t − 1 ) + b g ) g_t=tanh(W_xgx^(t)+W_hgh^(t-1)+b_g) gt=tanh(Wxgx(t)+Whgh(t1)+bg)

在当前时间戳中需要忘记某些信息,并在其中添加一些其他信息,此时单元状态将按以下方式更新:

c ( t ) = ( c ( t 1 − ) ⊙ f t ) ⊕ ( i t ⊙ g t ) c^(t)=(c^(t1-)\\odot f_t)\\oplus(i_t\\odot g_t) c(t)=(c(t1)ft)(itgt)

得到的新的状态向量 c ( t ) c^(t) c(t) 即为当前时间戳的状态向量。

2.2.3 输入门

最后一个门称为输出门,我们需要指定输入组合和单元状态的哪一部分需要传递到下一个时刻,输入组合包括当前时间戳的输入和前一时间戳的输出值:

o t = σ ( W x o x ( t ) + W h o h ( t − 1 ) + b o ) o_t=\\sigma(W_xox^(t)+W_hoh^(t-1)+b_o) ot=σ(Wxox(t)+Whoh(t1)+bo)

最终的网络状态值表示如下:

h ( t ) = o t ⊙ t a n h ( c ( t ) ) h^(t)=o_t\\odot tanh(c^(t)) h(t)=ottanh(c(t))

这样,我们就可以利用 LSTM 中的各个门来有选择地识别需要存储在存储器中的信息,从而克服了 RNN 的局限性。

3. 从零开始实现 LSTM

在本小节中,我们通过使用一个简单示例来了解 LSTM 的工作原理。

3.1 LSTM 模型实现

(1) 对输入数据进行预处理,该示例所用输入数据与预处理过程与在 RNN 模型中使用的完全相同:

# 定义输入与输出数据
docs = ['this is','is an']
# define class labels
labels = ['an','example']

from collections import Counter
counts = Counter()
for i,review in enumerate(docs+labels):
    counts.update(review.split())
words = sorted(counts, key=counts.get, reverse=True)
vocab_size=len(words)
word_to_int = word: i for i, word in enumerate(words, 1)

encoded_docs = []
for doc in docs:
    encoded_docs.append([word_to_int[word] for word in doc.split()])
encoded_labels = []
for label in labels:
    encoded_labels.append([word_to_int[word] for word keras深度学习实战——使用长短时记忆网络构建情感分析模型(代码片段)

Keras深度学习实战——使用长短时记忆网络构建情感分析模型0.前言1.构建LSTM模型进行情感分类1.1数据集分析1.2模型构建2.构建多层LSTM进行情感分类相关链接0.前言我们已经学习了如何使用循环神经网络(Recurrentneuralnetworks,RNN)构建... 查看详情

keras深度学习实战(33)——基于lstm的序列预测模型(代码片段)

Keras深度学习实战(33)——基于LSTM的序列预测模型0.前言1.序列学习任务1.1命名实体提取1.2文本摘要1.3机器翻译2.从输出网络返回输出序列2.1传统模型体系结构2.2返回每个时间戳的网络中间状态序列2.3使用双向LSTM网络小... 查看详情

keras深度学习实战(32)——基于lstm预测股价(代码片段)

Keras深度学习实战(32)——基于LSTM预测股价0.前言1.模型与数据集分析1.1数据集分析1.2模型分析2.基于长短时记忆网络LSTM预测股价2.1根据最近五天的股价预测股价2.2考虑最近五天的股价和新闻数据小结系列链接0.前言股价... 查看详情

(转)零基础入门深度学习-长短时记忆网络(lstm)

...理大数据的时代,作为一个有理想有追求的程序员,不懂深度学习(DeepLearning)这个超热的技术,会不会感觉马上就out了?现在救命稻草来了,《零基础入门深度学习》系列文章旨在讲帮助爱编程的你从零基础达到入门级水平。... 查看详情

keras深度学习实战(40)——音频生成(代码片段)

Keras深度学习实战(40)——音频生成0.前言1.模型与数据集分析1.1数据集分析1.2模型分析2.音频生成模型2.1数据集加载与预处理2.2模型构建与训练小结系列链接0.前言我们已经在《文本生成模型》一节中学习了如何利用深... 查看详情

深度学习100例——利用pytorch长短期记忆网络lstm实现股票预测分析|第5例

前言大家好,我是阿光。本专栏整理了《深度学习100例》,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集。正在更新中~✨ 查看详情

keras深度学习实战——卷积神经网络的局限性(代码片段)

Keras深度学习实战(9)——卷积神经网络的局限性0.前言1.卷积神经网络的局限性2.情景1——训练数据集图像尺寸较大3.情景2——训练数据集图像尺寸较小4.情景3——在训练尺寸较大的图像时使用更大池化小结系列链接0.... 查看详情

keras深度学习实战——卷积神经网络详解与实现(代码片段)

Keras深度学习实战(7)——卷积神经网络详解与实现0.前言1.传统神经网络的缺陷1.1构建传统神经网络1.2传统神经网络的缺陷2.使用Python从零开始构建CNN2.1卷积神经网络的基本概念2.2卷积和池化相比全连接网络的优势3.使用... 查看详情

keras深度学习实战(22)——生成对抗网络详解与实现(代码片段)

Keras深度学习实战(22)——生成对抗网络详解与实现0.前言1.生成对抗网络原理2.模型分析3.利用生成对抗网络生成手写数字图像小结系列链接0.前言生成对抗网络(GenerativeAdversarialNetworks,GAN)使用神经网络生成与原始图像集... 查看详情

keras深度学习实战——使用循环神经网络构建情感分析模型(代码片段)

Keras深度学习实战——使用循环神经网络构建情感分析模型0.前言1.使用循环神经网络构建情感分析模型1.1数据集分析1.2构建RNN模型进行情感分析相关链接0.前言在《循环神经详解与实现》一节中,我们已经了解循环神经网络(Recurr... 查看详情

keras深度学习实战(10)——迁移学习(代码片段)

Keras深度学习实战(10)——迁移学习0.前言1.迁移学习1.1迁移学习原理1.2ImageNet数据集介绍2.利用预训练VGG16模型进行性别分类2.1VGG16架构2.2微调模型2.3错误分类的图片示例小结系列链接0.前言在《卷积神经网络的局限性》... 查看详情

keras深度学习实战——使用keras构建神经网络(代码片段)

Keras深度学习实战(2)——使用Keras构建神经网络0前言1.Keras简介与安装2.Keras构建神经网络初体验3.训练香草神经网络3.1香草神经网络与MNIST数据集介绍3.2训练神经网络步骤回顾3.3使用Keras构建神经网络模型3.4关键步骤总结... 查看详情

备战数学建模49-深度学习之长短期记忆网络lstm(rnn)(攻坚战14)(代码片段)

...络,就是LSTM,它是循环神经网络的一种。在使用深度学习处理时序问题时,RNN是最常使用的模型之一。RNN之所以在时序数据上有着优异的表现是因为RNN在 t 时间片时会将 t−1 时间片的隐节点作为当前时间片的输入... 查看详情

备战数学建模49-深度学习之长短期记忆网络lstm(rnn)(攻坚战14)(代码片段)

...络,就是LSTM,它是循环神经网络的一种。在使用深度学习处理时序问题时,RNN是最常使用的模型之一。RNN之所以在时序数据上有着优异的表现是因为RNN在 t 时间片时会将 t−1 时间片的隐节点作为当前时间片的输入... 查看详情

keras深度学习实战——使用循环神经网络构建情感分析模型(代码片段)

Keras深度学习实战——使用循环神经网络构建情感分析模型0.前言1.使用循环神经网络构建情感分析模型1.1数据集分析1.2构建RNN模型进行情感分析相关链接0.前言在《循环神经详解与实现》一节中,我们已经了解循环神经网络(R... 查看详情

keras深度学习实战(19)——使用对抗攻击生成可欺骗神经网络的图像(代码片段)

Keras深度学习实战(19)——使用对抗攻击生成可欺骗神经网络的图像0.前言1.对抗攻击简介2.对抗攻击模型分析2.1模型识别图像流程2.2对抗攻击流程3.使用Keras实现对抗攻击小结系列链接0.前言近年来,深度学习在图像... 查看详情

[人工智能-深度学习-53]:lstm长短记忆时序模型的简化gru模型

...反向传播第1章前序知识1.1RNN循环神经网络模型[人工智能-深度学习-51]:循环神经网络RNN基本原理详解_文火冰糖(王文兵)的博客-CSDN博客https://blog.csdn.net/HiWangWenBing/article/details/1213872851.2 LSTM长短记忆时序模型https://bl... 查看详情

keras深度学习实战——新闻文本分类(代码片段)

Keras深度学习实战(9)——新闻文本分类0.前言1.新闻文本分类任务与神经网络模型分析1.1数据集1.2神经网络模型2.使用神经网络进行新闻文本分类小结系列链接0.前言在先前的应用实战中,我们分析了结构化的数据集&... 查看详情