私は2つのMLPを一緒にトレーニングしようとしています。それぞれが、異なる実数値変数を予測します。これらの2つの出力での損失を最小限に抑えたいのですが、いくつかの「ウォームアップ」反復のためにそのうちの1つを修正したいと思います。
私はテンソルフローに不慣れですが、基本的にはPytorchで次のようなものを探しています。
def loss(self, *args, **kwargs) -> torch.Tensor:
# Extract data
data, target, probability = args
# Iterate through each model and sum nll
nll = []
for index in range(self.num_models):
# Extract mean and variance from prediction
if self._current_it < self.warm_start_it:
predictive_mean = self.mean[index](data)
with torch.no_grad():
predictive_variance = softplus(self.variance[index](data))
else:
with torch.no_grad():
predictive_mean = self.mean[index](data)
predictive_variance = softplus(self.variance[index](data))
# Calculate the loss
nll.append(self.calculate_nll(target, predictive_mean, predictive_variance))
mean_nll = torch.stack(nll).mean()
# Update current iteration
if self.training:
self._current_it += 1
return mean_nll
モデルのcall()
関数内で同様のことができると思います。つまり、次のようになります。
def call(self, step, inputs, training=None, mask=None):
if step < self.warmup:
with tf.GradientTape() as t:
mean_predictions = self.mean(inputs)
var_predictions = self.variance(inputs)
else:
mean_predictions = self.mean(inputs)
with tf.GradientTape() as t:
var_predictions = self.variance(inputs)
return mean_predictions, var_predictions
これは、上記のPytorchと同等のものを取得する正しい方法ですか?
私は次のことをすることになりました:
メインループでは、
mlp = UncertaintyMLP(805, 1)
loss_fn = GaussianNLL()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
epochs = 1000
for epoch in range(epochs):
for step, (x_batch, y_batch) in enumerate(train_dataset):
if epoch > mlp.warmup:
for layer in mlp.mean.layers:
layer.trainable = False
for layer in mlp.variance.layers:
layer.trainable = True
with tf.GradientTape() as tape:
output = mlp(step, x_batch)
loss = loss_fn(y_batch, output)
grads = tape.gradient(loss, mlp.trainable_weights)
optimizer.apply_gradients(zip(grads, mlp.trainable_weights))
およびモデルクラス:
def call(self, step, inputs, training=None, mask=None):
mean_predictions = self.mean(inputs)
var_predictions = tf.math.softplus(self.variance(inputs)
return mean_predictions, var_predictions
ただし、TensorflowがPytorchに相当するものがある場合は、それが何であるかを知りたいと思っていtorch.no_grad()
ます。
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加