在tensorflow 1,还有的层tf.compat.v1.keras.layers.CuDNNLSTM
,而在tensorflow 2这一层已经被赞成使用的不提倡使用的是使用cuDNN内置tf.keras.layers.LSTM
有
1. `activation` == `tanh`
2. `recurrent_activation` == `sigmoid`
3. `recurrent_dropout` == 0
4. `unroll` is `False`
5. `use_bias` is `True`
6. Inputs are not masked or strictly right padded.
用于cuDNN实现。我不知道是否存在未实现的bug或某些差异,但是CuDNNLSTM
使用输入偏差和循环偏差似乎有所不同,如LSTM
上述tf2 cuDNN规则所述,该方法仅使用循环偏差。
相关代码
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
print(tf.__version__)
model1 = Sequential()
model1.add(LSTM(1, activation='tanh', recurrent_dropout=0, unroll=False, use_bias=True, return_sequences=0, input_shape=(1, 1)))
print(model1.summary())
model2 = Sequential()
model2.add(CuDNNLSTM(1, return_sequences=0, input_shape=(1, 1)))
print(model2.summary())
2.2.0
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm (LSTM) (None, 1) 12
=================================================================
Total params: 12
Trainable params: 12
Non-trainable params: 0
_________________________________________________________________
None
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
cu_dnnlstm (CuDNNLSTM) (None, 1) 16
=================================================================
Total params: 16
Trainable params: 16
Non-trainable params: 0
_________________________________________________________________
请注意,总参数之间的差异为N_units * 4
,表示每个参数缺少一个额外的偏差向量。
请注意,LSTM的pytorch实现与tf1 CuDNNLSTM匹配,这是我偶然发现的方法。
我是否缺少一些修复程序,或者应该将其提升为github问题?
不,这不是错误。
的2x偏差CuDNNLSTM
是的单独偏差recurrent kernel
。
当在CuDNNLSTM
中提供时tf.keras.layers.LSTM
,您会看到代码以这样的方式编写:它不对a使用单独的偏向,recurrent kernel
而是调用LSTMCell
作为基类并且没有单独的偏向。
您可以model.layers[0].trainable_weights
用来查看两个实现之间的偏差形状差异。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句