Tensorflow 2에서 tf.GradientTape를 사용한 훈련이 fit API를 사용한 훈련과 다른 동작을하는 이유는 무엇입니까?

Wakeme UpNow

Tensorflow 2를 처음 사용합니다.

나는 kerastensorflow 1에서 사용 하는 것에 익숙합니다 . 그리고 저는 일반적으로 fit모델을 훈련시키기 위해 메소드 API를 사용 합니다. 그러나 최근 tensorflow 2에서 그들은 열성적인 실행 을 도입했습니다 . 내가 구현 모두에서 CiFAR-10 데이터 세트에 대한 간단한 이미지 분류를 비교 그래서 fittf.GradientTape20 개 시대 각각 훈련

여러 번 실행 한 후 결과는 다음과 같습니다.

  • fitAPI로 학습 된 모델
    • 훈련 데이터 세트, 손실은 약 0.61-0.65이며 정확도는 76 %-80 %입니다.
    • 검증 데이터 세트, 손실은 약 0.8이며 정확도는 72 %-75 %입니다.
  • 훈련 된 모델 tf.GradientTape
    • 훈련 데이터 세트, 손실은 약 0.15-0.2이며 정확도는 91 %-94 %입니다.
    • 검증 데이터 세트, 손실은 약 1.8-2이며 정확도는 64 %-67 %입니다.

모델이 다른 동작을 보이는 이유를 잘 모르겠습니다. 뭔가 잘못 구현할 수 있다고 생각합니다. tf.GradientTape모델 에서 훈련 데이터 세트를 더 빨리 과적 합하기 시작 하는 것이 이상하다고 생각합니다.

다음은 몇 가지 스 니펫입니다.

  1. fitAPI 사용
model = SimpleClassifier(10)
model.compile(
    optimizer=Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()]
)
model.fit(X[:split_idx, :, :, :], y[:split_idx, :], batch_size=256, epochs=20, validation_data=(X[split_idx:, :, :, :], y[split_idx:, :]))
  1. 사용 tf.GradientTape
with tf.GradientTape() as tape:
    y_pred = model(tf.stop_gradient(train_X))
    loss = loss_fn(train_y, y_pred)
    gradients = tape.gradient(loss, model.trainable_weights)
model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))

전체 코드는 여기 Colab에서 찾을 수 있습니다.

참고 문헌

세바스찬 아니

tf.GradientTape수정 될 수 있는 코드 에는 몇 가지 가 있습니다.
1) trainable_variablesnot trainable_weights. 모델 가중치뿐만 아니라 모든 학습 가능한 변수에 기울기를 적용하려고합니다.

# gradients = tape.gradient(loss, model.trainable_weights)
gradients = tape.gradient(loss, model.trainable_variables)

# and

# model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

2) tf.stop_gradient입력 텐서에서 제거 합니다.

with tf.GradientTape() as tape:
#    y_pred = model(tf.stop_gradient(train_X))
    y_pred = model(train_X, training=True)

훈련 매개 변수도 추가했습니다. 또한 모델 정의에 포함되어야합니다 phase(예 : BatchNormalization 및 Dropout) :

    def call(self, X, training=None):
        X = self.cnn_1(X)
        X = self.bn_1(X, training=training)
        X = self.cnn_2(X)
        X = self.max_pool_2d(X)
        X = self.dropout_1(X)

        X = self.cnn_3(X)
        X = self.bn_2(X, training=training)
        X = self.cnn_4(X)
        X = self.bn_3(X, training=training)
        X = self.cnn_5(X)
        X = self.max_pool_2d(X)
        X = self.dropout_2(X)

        X = self.flatten(X)
        X = self.dense_1(X)
        X = self.dropout_3(X, training=training)
        X = self.dense_2(X)
        return self.out(X)

이러한 몇 가지 변경 사항으로 keras.fit결과에 더 유사한 약간 더 나은 점수를 얻을 수있었습니다 .

[19/20] loss: 0.64020, acc: 0.76965, val_loss: 0.71291, val_acc: 0.75318: 100%|██████████| 137/137 [00:12<00:00, 11.25it/s]
[20/20] loss: 0.62999, acc: 0.77649, val_loss: 0.77925, val_acc: 0.73219: 100%|██████████| 137/137 [00:12<00:00, 11.30it/s]

대답 : 차이점은 아마도 Keras.fit이러한 작업의 대부분을 내부적으로 수행 한 사실이었습니다 .

마지막으로, 명확성과 재현성을 위해 내가 사용한 부분 훈련 / 평가 코드 :

for bIdx, (train_X, train_y) in enumerate(train_batch):
            if bIdx < epoch_max_iter:
                with tf.GradientTape() as tape:
                    y_pred = model(train_X, training=True)
                    loss = loss_fn(train_y, y_pred)
                    total_loss += (np.sum(loss.numpy()) * train_X.shape[0])
                    total_num += train_X.shape[0]
                    # gradients = tape.gradient(loss, model.trainable_weights)
                    gradients = tape.gradient(loss, model.trainable_variables)
                total_acc += (metrics(train_y, y_pred) * train_X.shape[0])

                running_loss = (total_loss/total_num)
                running_acc = (total_acc/total_num)
                # model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
                model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

                pbar.set_description("[{}/{}] loss: {:.5f}, acc: {:.5f}".format(e, epochs, running_loss, running_acc))
                pbar.refresh()
                pbar.update()

그리고 평가 하나 :

# Eval loop
        # Calculate something wrong here
        val_total_loss = 0
        val_total_acc = 0
        total_val_num = 0
        for bIdx, (val_X, val_y) in enumerate(val_batch):
            if bIdx >= max_val_iterations:
                break
            y_pred = model(val_X, training=False)

이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.

침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제

에서 수정
0

몇 마디 만하겠습니다

0리뷰
로그인참여 후 검토

관련 기사

Related 관련 기사

뜨겁다태그

보관