我正在用训练模型tensorflow 2.0
。我的训练集中的图像具有不同的分辨率。我建立的模型可以处理可变分辨率(转换层,然后进行全局平均)。我的训练集很小,我想在一个批次中使用完整的训练集。
由于我的图片分辨率不同,因此无法使用model.fit()
。因此,我计划将每个样本分别通过网络传递,累积误差/梯度,然后应用一个优化程序步骤。我可以计算损失值,但是我不知道如何累计损失/梯度。如何累积损失/梯度,然后应用单个优化程序步骤?
代码:
for i in range(num_epochs):
print(f'Epoch: {i + 1}')
total_loss = 0
for j in tqdm(range(num_samples)):
sample = samples[j]
with tf.GradientTape as tape:
prediction = self.model(sample)
loss_value = self.loss_function(y_true=labels[j], y_pred=prediction)
gradients = tape.gradient(loss_value, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
total_loss += loss_value
epoch_loss = total_loss / num_samples
print(f'Epoch loss: {epoch_loss}')
如果我从以下陈述中正确理解:
如何累积损失/梯度,然后应用单个优化程序步骤?
@Nagabhushan尝试累积梯度,然后将优化应用于(平均)累积梯度。@TensorflowSupport提供的答案无法回答。为了仅执行一次优化并从多个磁带上累积梯度,可以执行以下操作:
for i in range(num_epochs):
print(f'Epoch: {i + 1}')
total_loss = 0
# get trainable variables
train_vars = self.model.trainable_variables
# Create empty gradient list (not a tf.Variable list)
accum_gradient = [tf.zeros_like(this_var) for this_var in train_vars]
for j in tqdm(range(num_samples)):
sample = samples[j]
with tf.GradientTape as tape:
prediction = self.model(sample)
loss_value = self.loss_function(y_true=labels[j], y_pred=prediction)
total_loss += loss_value
# get gradients of this tape
gradients = tape.gradient(loss_value, train_vars)
# Accumulate the gradients
accum_gradient = [(acum_grad+grad) for acum_grad, grad in zip(accum_gradient, gradients)]
# Now, after executing all the tapes you needed, we apply the optimization step
# (but first we take the average of the gradients)
accum_gradient = [this_grad/num_samples for this_grad in accum_gradient]
# apply optimization step
self.optimizer.apply_gradients(zip(accum_gradient,train_vars))
epoch_loss = total_loss / num_samples
print(f'Epoch loss: {epoch_loss}')
在训练循环中应避免使用tf.Variable(),因为在尝试将代码作为图形执行时会产生错误。如果您在训练函数中使用tf.Variable(),然后用“ @ tf.function”装饰它或应用“ tf.function(my_train_fcn)”以获得图形函数(即为了提高性能),则执行将增加错误。发生这种情况是因为对tf.Variable函数的跟踪导致的行为与渴望执行(分别为重新利用或创建)时所观察到的行为不同。您可以在tensorflow帮助页面中找到更多信息。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句