train()またはeval()が呼び出されたときにカスタムpytorchモジュールの動作を変えることができますか?

ihdv:

公式文書によると、train()またはeval()使用すると、特定のモジュールに影響があります。ただし、今はカスタムモジュールでも同様のことを実現したいと考えています。つまりtrain()eval()電源を入れたときに何かを行い、電源を入れたときに別のことを行います。これどうやってするの?

ベリエル:

はい、できます。

あなたがソースコード見ることができるようにeval()そしてtrain()基本的に呼ばれるフラグを変更していますself.training(それは再帰的に呼ばれることに注意してください):

def train(self: T, mode: bool = True) -> T:
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

def eval(self: T) -> T:
    return self.train(False)

このフラグはすべてnn.Moduleので使用できますカスタムモジュールがこの基本クラスを継承する場合、必要なことを達成するのは非常に簡単です。

import torch.nn as nn


class MyCustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        # [...]

    def forward(self, x):
        if self.training:
            # train() -> training logic
        else:
            # eval()  -> inference logic

この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。

侵害の場合は、連絡してください[email protected]

編集
0

コメントを追加

0

関連記事

Related 関連記事

ホットタグ

アーカイブ