我想创建一个自动编码器子类,作为Keras Model类的子类,我不知道是否有AutoEncoder
必要分别创建编码器和解码器并将它们组合成一个新类,或者我需要在同一类中创建编码器和解码器。
这是一类超简单自动编码器的示例:
INPUT_SHAPE = 254
class AutoEncoder(tf.keras.Model):
def __init__(self):
super().__init__()
# Encoder
self.dense1 = tf.keras.layers.Dense(128, input_shape=(INPUT_SHAPE,), activation='relu')
self.dense2 = tf.keras.layers.Dense(INPUT_SHAPE, activation='relu')
#Decoder
self.dense3 = tf.keras.layers.Dense(128, activation='relu')
self.dense4 = tf.keras.layers.Dense(3, activation='sigmoid')
def __call__(self, inp, training=False):
x = self.dense1(inp)
x = self.dense2(x)
x = self.dense3(x)
x = self.dense4(x)
return x
这将是encoder
并且decoder
是单独的类,我的疑问是如何将两者结合起来?或在这种情况下创建自动编码器的最佳方法是什么。
class Encoder(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(128, input_shape=(INPUT_SHAPE,), activation='relu')
self.dense2 = tf.keras.layers.Dense(INPUT_SHAPE, activation='relu')
def __call__(self, inp, training=False):
x = self.dense1(x)
x = self.dense2(x)
return x
class Decoder(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense3 = tf.keras.layers.Dense(128, input_shape=(INPUT_SHAPE,), activation='relu')
self.dense4 = tf.keras.layers.Dense(3, activation='sigmoid')
def __call__(self, inp, training=False):
x = self.dense3(x)
x = self.dense4(x)
return x
您提供的代码中有2个小错误。好像INPUT_SHAPE
没有提供__init__
。另外,使用call
method代替__call__
class Encoder(tf.keras.Model):
def __init__(self, INPUT_SHAPE):
super(Encoder, self).__init__()
self.dense1 = tf.keras.layers.Dense(128, input_shape=(INPUT_SHAPE,), activation='relu')
self.dense2 = tf.keras.layers.Dense(INPUT_SHAPE, activation='relu')
def call(self, inp, training=False):
x = self.dense1(x)
x = self.dense2(x)
return x
class Decoder(tf.keras.Model):
def __init__(self, INPUT_SHAPE):
super(Decoder, self).__init__()
self.dense3 = tf.keras.layers.Dense(128, input_shape=(INPUT_SHAPE,), activation='relu')
self.dense4 = tf.keras.layers.Dense(3, activation='sigmoid')
def call(self, inp, training=False):
x = self.dense3(x)
x = self.dense4(x)
return x
一旦解决这些问题。您可以使用以下方法定义AE
class AE(tf.keras.Model):
def __init__(self, INPUT_SHAPE):
super(AE, self).__init__()
self.encoder = Encoder(INPUT_SHAPE)
self.decoder = Decoder(INPUT_SHAPE)
def call(self, inp):
out_encoder = self.encoder(inp)
out_decoder = self.decoder(out_encoder)
return out_encoder, out_decoder
是时候采取行动了。让我们实例化此类并检查对象。
INPUT_SHAPE = 10
model = AE(10)
model
>>>
<__main__.AE at 0x7f5bb4ef8dd8>
您还可以检查编码器和解码器
model.encoder
model.decoder
这将给<__main__.Encoder at 0x7f5bb4ed2710>
和<__main__.Decoder at 0x7f5bb4ec99e8>
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句