keras에서 훈련 할 때 기울기 계산을 어떻게 사용자 정의 할 수 있습니까?

지오 필

keras 레이어에 자연스러운 그라디언트를 구현하고 싶습니다. 이것은 이미 제자리에있는 사용자 정의 된 그라디언트 내에서 발생해야합니다. 옵티 마이저를 호출 할 때 계산해야 할 구현 (일반 또는 자연 그라디언트)을 선택할 수 있기를 원합니다.

내가 직면 한 문제는 내가 nat_grad=True(그래프 빌드 시간이 아닌) 훈련 시간에 부울 을 Op에 전달할 때 AutoGraph가 행복하지 않다는 것입니다.

현재 무슨 일이 일어나고 있는지 의사 코드는 다음과 같습니다.

@tf.custom_gradient
def MyOp(inputs, w, nat_grad=False):
    output = w*inputs
    def grad(dy):
        if nat_grad:
            return dy, 1.0
        else:
            return -dy, -1.0
    return output, grad


class MyKerasLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.nat_grad = False
        
    def build(self, input_shape):
        self.w = self.add_weight("w", dtype=tf.float32, trainable=True, initializer=tf.random_normal_initializer)
        super().build(input_shape)

    def call(self, inputs):
        return MyOp(inputs, self.w, self.nat_grad)


class MyModel(tf.keras.Sequential):
    def __init__(self, num_layers):
        super().__init__([tf.keras.Input(shape=[1], batch_size=None, dtype=tf.float32)]+[MyKerasLayer() for _ in range(num_layers)])


def optimize(model, X, Y, nat_grad:bool):
    for layer in model.layers:
        layer.nat_grad = nat_grad
    model.fit(x=X, y=Y)

    
model = MyModel(5)
model.compile(optimizer='SGD', loss=lambda x,y:x-y, metrics=[])
X = np.array([1.0, 2.0, 3.0])
Y = np.array([1.0, 2.0, 3.0])
optimize(model, X, Y, nat_grad=True)
>>> OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

이를 수행하는 올바른 방법은 무엇입니까?

David Vander Mijnsbrugge

Tensorflow 2.x를 사용하면 함수를 tf.graphs [1]로 실행할 수 있습니다. 장식 그래서 grad(dy)함께하는 @tf.function작동해야하지만 이후 새로운 오류로 실행하겠습니다 MyOp소요 nat_grad입력으로이 변수 [2]에 대한 기울기를 기대합니다.

@tf.custom_gradient
def MyOp(inputs, w, nat_grad=False):
    output = w*inputs
    @tf.function
    def grad(dy):
        if nat_grad:
            return dy, 1.0, 0.
        else:
            return -dy, -1.0, 0.
    return output, grad

이것은 이것을 수행하는 방법이 아니며 오히려 그라디언트 op를 2 부분으로 나누고 call.

@tf.custom_gradient
def NatOp(inputs, w):
    output = w*inputs
    def grad(dy):
        return dy, 1.0     
    return output, grad

@tf.custom_gradient
def RegOp(inputs, w):
    output = w*inputs
    def grad(dy):
        return -dy, -1.0
    return output, grad

class MyKerasLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.nat_grad = False
        
    def build(self, input_shape):
        self.w = self.add_weight("w", dtype=tf.float32, trainable=True, initializer=tf.random_normal_initializer)
        super().build(input_shape)

    def call(self, inputs):
        return NatOp(inputs, self.w) if self.nat_grad else RegOp(inputs, self.w)

[1] https://www.tensorflow.org/api_docs/python/tf/function

[2] https://www.tensorflow.org/api_docs/python/tf/custom_gradient

이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.

침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제

에서 수정
0

몇 마디 만하겠습니다

0리뷰
로그인참여 후 검토

관련 기사

분류에서Dev

Torch에서 훈련 시간에 사전 훈련 된 임베딩을 어떻게 수정할 수 있습니까?

분류에서Dev

keras 라이브러리를 사용하여 RNN을 훈련 할 때 계속 발생하는 차원 오류를 어떻게 수정할 수 있습니까?

분류에서Dev

sklearn에서 학습 알고리즘의 훈련 임계 값을 어떻게 변경할 수 있습니까?

분류에서Dev

Microsoft VBA 편집기에서 사용자 정의 색상을 어떻게 사용할 수 있습니까?

분류에서Dev

다른 글꼴 유형을 사용할 때 글꼴 크기를 어떻게 계산할 수 있습니까?

분류에서Dev

손실 함수는 PyTorch에서 기울기를 계산할 모델을 어떻게 알 수 있습니까?

분류에서Dev

특정보기를 선택할 때 수명주기 방법을 어떻게 사용할 수 있습니까?

분류에서Dev

데이터 주석을 사용할 때 사용자 정의 편집기 템플릿을 어떻게 사용할 수 있습니까?

분류에서Dev

메뚜기에서 "for"를 사용할 때 어떻게 트래픽을 제어 할 수 있습니까?

분류에서Dev

React App에서 DataTable을 어떻게 사용자 정의 할 수 있습니까?

분류에서Dev

사전 훈련 된 caffe 모델의 하위 집합을 어떻게 저장할 수 있습니까?

분류에서Dev

자동 레이아웃 종횡비 제약 조건을 사용하여 xib에서로드 된 UICollectionViewCell의 크기를 어떻게 계산할 수 있습니까?

분류에서Dev

C #에서 사용자 지정 직렬화를 사용할 때 XML 요소의 이름을 어떻게 제어 할 수 있습니까?

분류에서Dev

WIX에서 기존 사용자 인터페이스를 어떻게 사용자 정의 할 수 있습니까?

분류에서Dev

내 앱은 사용자의 iOS 기기에있는 소셜 앱을 어떻게 식별 할 수 있습니까?

분류에서Dev

희소 행렬에서 값을 정의 할 때 병렬 처리를 어떻게 활용할 수 있습니까?

분류에서Dev

어떻게 & 여기에 비트 연산자를 사용할 수 있습니까?

분류에서Dev

통계 모델 OLS에서 절편과 기울기를 어떻게 계산할 수 있습니까?

분류에서Dev

MATLAB에서 픽셀을 "교차하는"선의 기울기를 어떻게 감지 할 수 있습니까?

분류에서Dev

Python / Pybrain : 훈련 중에 신경망의 가중치를 어떻게 수정할 수 있습니까?

분류에서Dev

모든 행에서 다른 특정 문자의 발생을 어떻게 계산할 수 있습니까?

분류에서Dev

BASH에서 xmlstarlet을 사용하여 XML 문서의 요소 수를 어떻게 계산할 수 있습니까?

분류에서Dev

Linux에서 사용자의 기본 그룹을 어떻게 변경할 수 있습니까?

분류에서Dev

Linux에서 사용자의 기본 그룹을 어떻게 변경할 수 있습니까?

분류에서Dev

OpenGL을 사용하여 Java에서 내 서클의 크기를 어떻게 조정할 수 있습니까?

분류에서Dev

어떻게 Keras 만들 예측에 HDF5 파일에 저장 모델을 훈련 사용할 수 있습니까?

분류에서Dev

presentTextInputControllerWithSuggestions에서 받아쓰기보기를 어떻게 사용자 정의 할 수 있습니까?

분류에서Dev

SignalR을 사용할 때 파이프에서 특정 HttpModule을 어떻게 제거 할 수 있습니까?

분류에서Dev

정수 기반 입력 C ++의 사용자 입력에 쉼표를 어떻게 사용할 수 있습니까?

Related 관련 기사

  1. 1

    Torch에서 훈련 시간에 사전 훈련 된 임베딩을 어떻게 수정할 수 있습니까?

  2. 2

    keras 라이브러리를 사용하여 RNN을 훈련 할 때 계속 발생하는 차원 오류를 어떻게 수정할 수 있습니까?

  3. 3

    sklearn에서 학습 알고리즘의 훈련 임계 값을 어떻게 변경할 수 있습니까?

  4. 4

    Microsoft VBA 편집기에서 사용자 정의 색상을 어떻게 사용할 수 있습니까?

  5. 5

    다른 글꼴 유형을 사용할 때 글꼴 크기를 어떻게 계산할 수 있습니까?

  6. 6

    손실 함수는 PyTorch에서 기울기를 계산할 모델을 어떻게 알 수 있습니까?

  7. 7

    특정보기를 선택할 때 수명주기 방법을 어떻게 사용할 수 있습니까?

  8. 8

    데이터 주석을 사용할 때 사용자 정의 편집기 템플릿을 어떻게 사용할 수 있습니까?

  9. 9

    메뚜기에서 "for"를 사용할 때 어떻게 트래픽을 제어 할 수 있습니까?

  10. 10

    React App에서 DataTable을 어떻게 사용자 정의 할 수 있습니까?

  11. 11

    사전 훈련 된 caffe 모델의 하위 집합을 어떻게 저장할 수 있습니까?

  12. 12

    자동 레이아웃 종횡비 제약 조건을 사용하여 xib에서로드 된 UICollectionViewCell의 크기를 어떻게 계산할 수 있습니까?

  13. 13

    C #에서 사용자 지정 직렬화를 사용할 때 XML 요소의 이름을 어떻게 제어 할 수 있습니까?

  14. 14

    WIX에서 기존 사용자 인터페이스를 어떻게 사용자 정의 할 수 있습니까?

  15. 15

    내 앱은 사용자의 iOS 기기에있는 소셜 앱을 어떻게 식별 할 수 있습니까?

  16. 16

    희소 행렬에서 값을 정의 할 때 병렬 처리를 어떻게 활용할 수 있습니까?

  17. 17

    어떻게 & 여기에 비트 연산자를 사용할 수 있습니까?

  18. 18

    통계 모델 OLS에서 절편과 기울기를 어떻게 계산할 수 있습니까?

  19. 19

    MATLAB에서 픽셀을 "교차하는"선의 기울기를 어떻게 감지 할 수 있습니까?

  20. 20

    Python / Pybrain : 훈련 중에 신경망의 가중치를 어떻게 수정할 수 있습니까?

  21. 21

    모든 행에서 다른 특정 문자의 발생을 어떻게 계산할 수 있습니까?

  22. 22

    BASH에서 xmlstarlet을 사용하여 XML 문서의 요소 수를 어떻게 계산할 수 있습니까?

  23. 23

    Linux에서 사용자의 기본 그룹을 어떻게 변경할 수 있습니까?

  24. 24

    Linux에서 사용자의 기본 그룹을 어떻게 변경할 수 있습니까?

  25. 25

    OpenGL을 사용하여 Java에서 내 서클의 크기를 어떻게 조정할 수 있습니까?

  26. 26

    어떻게 Keras 만들 예측에 HDF5 파일에 저장 모델을 훈련 사용할 수 있습니까?

  27. 27

    presentTextInputControllerWithSuggestions에서 받아쓰기보기를 어떻게 사용자 정의 할 수 있습니까?

  28. 28

    SignalR을 사용할 때 파이프에서 특정 HttpModule을 어떻게 제거 할 수 있습니까?

  29. 29

    정수 기반 입력 C ++의 사용자 입력에 쉼표를 어떻게 사용할 수 있습니까?

뜨겁다태그

보관