公式文書によると、train()
またはeval()
を使用すると、特定のモジュールに影響があります。ただし、今はカスタムモジュールでも同様のことを実現したいと考えています。つまりtrain()
、eval()
電源を入れたときに何かを行い、電源を入れたときに別のことを行います。これどうやってするの?
はい、できます。
あなたがソースコードで見ることができるように、eval()
そしてtrain()
基本的に呼ばれるフラグを変更していますself.training
(それは再帰的に呼ばれることに注意してください):
def train(self: T, mode: bool = True) -> T:
self.training = mode
for module in self.children():
module.train(mode)
return self
def eval(self: T) -> T:
return self.train(False)
このフラグはすべてnn.Module
ので使用できます。カスタムモジュールがこの基本クラスを継承する場合、必要なことを達成するのは非常に簡単です。
import torch.nn as nn
class MyCustomModule(nn.Module):
def __init__(self):
super().__init__()
# [...]
def forward(self, x):
if self.training:
# train() -> training logic
else:
# eval() -> inference logic
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加