我是机器学习和keras的新手。我试图为我的分类问题创建LSTM模型,但收到此错误:(我从互联网上获得了一些样本并试图对其进行修改)
ValueError:输入0与连续图层_1不兼容:预期形状=(无,无,30),发现形状= [无,3,1]这就是我所需要的,我有一个像这样的1,2,3,4序列,其中1,2,3是我的X_train,4是label(Y),所以我的意思是时间步长是3,每个步长只有一个功能
我的标签有30节课。因此,我希望输出是这30个类之一。64是存储单元数。
这是我的代码
def get_lstm():
model = Sequential()
model.add(LSTM(64, input_shape=(3, 30), return_sequences=True))
model.add(LSTM(64))
model.add(Dropout(0.2))
model.add(Dense(30, activation='softmax'))
X_train = user_data[:, 0:3]
X_train = np.asarray(X_train).astype(np.float32)
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
Y_train = user_data[:, 3]
Y_train = np.asarray(Y_train).astype(np.float32)
local_model = Mymodel.get_lstm()
local_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['accuracy'])
local_model.set_weights(global_weights)
local_model.fit(X_train, Y_train, batch_size=32,
epochs=1)
如果您需要更多信息或不清楚,请告诉我。我真的需要你们的帮助,谢谢
不知道为什么要像(3,30)
第一个LSTM一样设置输入形状。如您所言-
这就是我所需要的,我有一个像这样的1,2,3,4序列,其中1,2,3是我的X_train和4是label(Y)。所以我的意思是时间步长是3,每个步长只有一个功能
如果您有3个步骤,并且只有一个功能,则应该这样定义每个序列。
另外,由于模型将始终输出30个长度的概率分布,但是y_train是单个值(唯一的30个类),因此您需要使用losssparse_categorical_crossentropy
而不是categorical_crossentropy
。在这里阅读更多。
from tensorflow.keras import layers, Model, utils
#Dummy data and its shapes
X = np.random.random((100,3,1)) #(100,3,1)
y = np.random.randint(0,29,(100,)) #(100,)
#Design model
inp = layers.Input((3,1))
x = layers.LSTM(64, return_sequences=True)(inp)
x = layers.LSTM(64)(x)
x = layers.Dropout(0.2)(x)
out = layers.Dense(30, activation='softmax')(x)
model = Model(inp, out)
#Compile and fit
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=['accuracy'])
model.fit(X, y, batch_size=32,epochs=3)
Epoch 1/3
4/4 [==============================] - 0s 4ms/step - loss: 3.4005 - accuracy: 0.0400
Epoch 2/3
4/4 [==============================] - 0s 5ms/step - loss: 3.3953 - accuracy: 0.0700
Epoch 3/3
4/4 [==============================] - 0s 8ms/step - loss: 3.3902 - accuracy: 0.0900
utils.plot_model(model, show_layer_names=False, show_shapes=True)
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句