Tensorflow에서 Tensor 마스킹 및 인덱싱을 정확히 어떻게 수행해야합니까?

사자 별자리

저는 지금 2 년 동안 TF를 사용해 왔으며 각 프로젝트에서 마스킹에 대해 넌센스 오류가 많이 발생하는데, 일반적으로 도움이되지 않고 실제로 잘못된 것을 나타내지 않습니다. 또는 그보다 최악의 경우 결과는 잘못되었지만 오류는 없습니다. 나는 항상 더미 데이터로 훈련 루프 밖에서 코드를 테스트하고 괜찮습니다. 그러나 훈련 (적합이라고 부름)에서 TensorFlow가 정확히 무엇을 기대하는지 이해하지 못합니다. 예를 들어, 누군가이 코드가 바이너리 교차 엔트로피에서 작동하지 않는 이유를 알려주세요. 결과가 잘못되었으며 모델이 수렴되지 않지만이 경우 오류는 없습니다.

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        y_true = tf.squeeze(y_true)
        mask = tf.where(y_true!=2)
        y_true = tf.gather_nd(y_true, mask)
        y_pred = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

이것이 올바르게 작동하는 동안 :

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        mask = tf.where(y_true!=2, True, False)
        y_true = y_true[mask]
        y_pred = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

그리고 범주 형 예의 경우 그 반대입니다. 마스크를 y_pred [mask], y_pred [mask [0]]와 같은 인덱스로 사용하거나 tf.squeeze () 등을 사용할 수 없습니다. 그러나 tf.gather_nd () 사용하면 작동합니다. 저는 항상 가능하다고 생각하는 모든 조합을 시도합니다. 왜 그렇게 단순한 것이 이렇게 힘들고 고통 스러워야 하는지를 이해하지 못합니다. Pytorch도 이와 같은가요? Pytorch에 유사한 성가신 세부 사항이 없다는 것을 알고 있으면 기꺼이 전환합니다.

편집 1 : 그들은 훈련 루프 밖에서 올바르게 작동하거나 더 정확하게 그래프 모드를 사용합니다.

y_pred = tf.random.uniform(shape=[10,], minval=0, maxval=1, dtype='float32')
y_true = tf.random.uniform(shape=[10,], minval=0, maxval=2, dtype='int32')

# first method
class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def call(self, y_true, y_pred):
        y_true = tf.squeeze(y_true)
        mask = tf.where(y_true!=2)
        y_true = tf.gather_nd(y_true, mask)
        y_pred = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}

# instantiate
mbxe = MaskedBXE()
print(f'first snippet: {mbxe(y_true, y_pred).numpy()}')


# second method
class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        mask = tf.where(y_true!=2, True, False)
        y_true = y_true[mask]
        y_pred = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)
    
    def get_config(self):
        base_config = super().get_config()
        return {**base_config}
    
# instantiate
mbxe = MaskedBXE()
print(f'second snippet: {mbxe(y_true, y_pred).numpy()}')

첫 번째 스 니펫 : 1.2907861471176147

두 번째 스 니펫 : 1.2907861471176147

편집 2 : @jdehesa가 제안한 것처럼 그래프 모드에서 손실을 인쇄 한 후에는 다릅니다.

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        # first
        y_t = tf.squeeze(y_true)
        mask = tf.where(y_t!=2)
        y_t = tf.gather_nd(y_t, mask)
        y_p = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
        first_loss =  tf.reduce_mean(loss)
        tf.print('first:')
        tf.print(first_loss, summarize=-1)
        # second
        mask = tf.where(y_true!=2, True, False)
        y_t = y_true[mask]
        y_p = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
        second_loss = tf.reduce_mean(loss)
        tf.print('second:')
        tf.print(second_loss, summarize=-1)
        return second_loss

먼저:

0.814215422

둘째:

0.787778914

먼저:

0.779697835

둘째:

0.802924752

. . .

이드 헤사

문제는 첫 번째 버전에서 실수로 브로드 캐스트 된 작업을 수행하여 잘못된 결과를 제공한다는 것입니다. 이것은 작업으로 (?, 1)인해 배치가 모양을 가진 경우 발생합니다 tf.squeeze. 이 예의 모양에 유의하십시오.

import tensorflow as tf

# Make random y_true and y_pred with shape (10, 1)
tf.random.set_seed(10)
y_true = tf.dtypes.cast(tf.random.uniform((10, 1), 0, 3, dtype=tf.int32), tf.float32)
y_pred = tf.random.uniform((10, 1), 0, 1, dtype=tf.float32)

# first
y_t = tf.squeeze(y_true)
mask = tf.where(y_t != 2)
y_t = tf.gather_nd(y_t, mask)
tf.print(tf.shape(y_t))
# [7]
y_p = tf.gather_nd(y_pred, mask)
tf.print(tf.shape(y_p))
# [7 1]
loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
first_loss =  tf.reduce_mean(loss)
tf.print(tf.shape(loss), summarize=-1)
# [7]
tf.print(first_loss, summarize=-1)
# 0.884061277

# second
mask = tf.where(y_true!=2, True, False)
y_t = y_true[mask]
tf.print(tf.shape(y_t))
# [7]
y_p = y_pred[mask]
tf.print(tf.shape(y_p))
# [7]
loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
tf.print(tf.shape(loss), summarize=-1)
# []
second_loss = tf.reduce_mean(loss)
tf.print(second_loss, summarize=-1)
# 1.15896356

첫 번째 버전에서 모두 y_ty_p교차 엔트로피가 기본적으로 계산되도록 7 × 7 텐서에 방송 될 한 다음 평균 "모든 대 모두". 두 번째 경우에는 해당 값의 각 쌍에 대해서만 교차 엔트로피가 계산되며 이는 올바른 작업입니다.

tf.squeeze위의 예 에서 작업 을 간단히 제거 하면 결과가 수정됩니다.

이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.

침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제

에서 수정
0

몇 마디 만하겠습니다

0리뷰
로그인참여 후 검토

관련 기사

분류에서Dev

"방송 및 부울 마스킹을 사용한 팬시 인덱싱"은 어떻게 작동합니까?

분류에서Dev

MEAN 스택에서 일회성 MongoDB 인덱싱을 수행하려면 어떻게해야합니까?

분류에서Dev

Django (drf 및 simplejwt)에서 JWT 기반 인증을 정확히 어떻게 구현해야합니까?

분류에서Dev

Pandas에서 고급 인덱싱을 사용하여 값 마스킹 / 수정

분류에서Dev

프로세스에서 지침을 수정하려면 어떻게해야합니까? Linux 및 ARMv7

분류에서Dev

Amazon ec2 인스턴스에서 자동 확장을 수행하려면 어떻게해야합니까?

분류에서Dev

행 및 열 인덱스로 Numpy 행렬에 요소를 삽입하려면 어떻게해야합니까?

분류에서Dev

새 인스턴스를 마지막 인스턴스보다 정확히 1 더 높이려면 어떻게해야합니까? (UML에서 Java 코드로)

분류에서Dev

Pandas의 특정 인덱스 이전에 모든 행을 가져 오려면 어떻게해야합니까?

분류에서Dev

Tensorflow는 Tensorflow Tensor에서 고유 한 값의 인덱스를 어떻게 얻습니까?

분류에서Dev

R에서 인덱싱 작업을 수행하려면 어떻게해야합니까?

분류에서Dev

비 Ubuntu Linux 커널 4.18에서 스냅 실행을 수정하려면 어떻게해야합니까?

분류에서Dev

C 또는 C ++에서 멀티 태스킹을 수행하려면 어떻게해야합니까?

분류에서Dev

JavaScript에서 이름, 중간 이름, 성을 마스킹하려면 어떻게해야합니까?

분류에서Dev

정수 setter와 바인딩 setter를 XAML 및 C #에서도 사용할 수있는 클래스로 결합하려면 어떻게해야합니까?

분류에서Dev

Tensorflow 2 LSTM 훈련에서 다중 출력을 어떻게 마스킹합니까?

분류에서Dev

람다 식에 정확히 한 번 변수 값을 제공하려면 어떻게해야합니까?

분류에서Dev

sails.js 및 워터 라인에서 중첩 조인을 수행하려면 어떻게해야합니까?

분류에서Dev

다른 차원으로 PyTorch / Numpy에서 마스킹을 어떻게 수행합니까?

분류에서Dev

pentaho PDI (spoon)로 데이터 마스킹을 어떻게 수행해야합니까?

분류에서Dev

Bash에서 편집 한 히스토리 라인을 재설정하려면 어떻게해야합니까?

분류에서Dev

for 루프의 이미지 마스킹을 파이썬의 논리적 인덱싱으로 대체하려면 어떻게해야합니까?

분류에서Dev

SSH에서 내 서버에 원격으로 액세스하려면 정확히 어떻게해야합니까?

분류에서Dev

Pyspark Dataframe의 특정 인덱스에서 행을 추가하거나 대체하려면 어떻게해야합니까?

분류에서Dev

Elasticsearch에서 배열 필드로 정확히 구문을 찾으려면 어떻게해야합니까?

분류에서Dev

Notepad ++에서 선택한 텍스트의 단어 수를 확인하려면 어떻게해야합니까?

분류에서Dev

새로운 단일 연결 목록에서 단일 연결 목록의 홀수 인덱싱 된 노드를 반환하려면 어떻게해야합니까? 첫 번째 노드의 인덱스를 1로 가정합니다.

분류에서Dev

Firefox 브라우저 서명을 어떻게 마스킹합니까?

분류에서Dev

한 행에서 정확히 3 번 발생하는 기준을 충족하는 ID에 대한 행을 반환하려면 어떻게해야합니까?

Related 관련 기사

  1. 1

    "방송 및 부울 마스킹을 사용한 팬시 인덱싱"은 어떻게 작동합니까?

  2. 2

    MEAN 스택에서 일회성 MongoDB 인덱싱을 수행하려면 어떻게해야합니까?

  3. 3

    Django (drf 및 simplejwt)에서 JWT 기반 인증을 정확히 어떻게 구현해야합니까?

  4. 4

    Pandas에서 고급 인덱싱을 사용하여 값 마스킹 / 수정

  5. 5

    프로세스에서 지침을 수정하려면 어떻게해야합니까? Linux 및 ARMv7

  6. 6

    Amazon ec2 인스턴스에서 자동 확장을 수행하려면 어떻게해야합니까?

  7. 7

    행 및 열 인덱스로 Numpy 행렬에 요소를 삽입하려면 어떻게해야합니까?

  8. 8

    새 인스턴스를 마지막 인스턴스보다 정확히 1 더 높이려면 어떻게해야합니까? (UML에서 Java 코드로)

  9. 9

    Pandas의 특정 인덱스 이전에 모든 행을 가져 오려면 어떻게해야합니까?

  10. 10

    Tensorflow는 Tensorflow Tensor에서 고유 한 값의 인덱스를 어떻게 얻습니까?

  11. 11

    R에서 인덱싱 작업을 수행하려면 어떻게해야합니까?

  12. 12

    비 Ubuntu Linux 커널 4.18에서 스냅 실행을 수정하려면 어떻게해야합니까?

  13. 13

    C 또는 C ++에서 멀티 태스킹을 수행하려면 어떻게해야합니까?

  14. 14

    JavaScript에서 이름, 중간 이름, 성을 마스킹하려면 어떻게해야합니까?

  15. 15

    정수 setter와 바인딩 setter를 XAML 및 C #에서도 사용할 수있는 클래스로 결합하려면 어떻게해야합니까?

  16. 16

    Tensorflow 2 LSTM 훈련에서 다중 출력을 어떻게 마스킹합니까?

  17. 17

    람다 식에 정확히 한 번 변수 값을 제공하려면 어떻게해야합니까?

  18. 18

    sails.js 및 워터 라인에서 중첩 조인을 수행하려면 어떻게해야합니까?

  19. 19

    다른 차원으로 PyTorch / Numpy에서 마스킹을 어떻게 수행합니까?

  20. 20

    pentaho PDI (spoon)로 데이터 마스킹을 어떻게 수행해야합니까?

  21. 21

    Bash에서 편집 한 히스토리 라인을 재설정하려면 어떻게해야합니까?

  22. 22

    for 루프의 이미지 마스킹을 파이썬의 논리적 인덱싱으로 대체하려면 어떻게해야합니까?

  23. 23

    SSH에서 내 서버에 원격으로 액세스하려면 정확히 어떻게해야합니까?

  24. 24

    Pyspark Dataframe의 특정 인덱스에서 행을 추가하거나 대체하려면 어떻게해야합니까?

  25. 25

    Elasticsearch에서 배열 필드로 정확히 구문을 찾으려면 어떻게해야합니까?

  26. 26

    Notepad ++에서 선택한 텍스트의 단어 수를 확인하려면 어떻게해야합니까?

  27. 27

    새로운 단일 연결 목록에서 단일 연결 목록의 홀수 인덱싱 된 노드를 반환하려면 어떻게해야합니까? 첫 번째 노드의 인덱스를 1로 가정합니다.

  28. 28

    Firefox 브라우저 서명을 어떻게 마스킹합니까?

  29. 29

    한 행에서 정확히 3 번 발생하는 기준을 충족하는 ID에 대한 행을 반환하려면 어떻게해야합니까?

뜨겁다태그

보관