例:
import torch
pred = torch.tensor([1,2,1,0,0], device='cuda:0')
correct = torch.tensor([1,0,1,1,0], device='cuda:0')
assigned = torch.tensor([1,2,2,1,0], device='cuda:0')
欲しいですresult = tensor([1,2,1,1,0], device='cuda:0')
。
基本的に、when pred
はcorrect
then と同じcorrect
ですassigned
。
さらに、この計算を勾配計算から除外したいと思います。
テンソルを繰り返さずにこれを行う方法はありますか?
torch.where
あなたが探しているものを正確に実行します:
import torch
pred = torch.tensor([1,2,1,0,0], device='cuda:0')
correct = torch.tensor([1,0,1,1,0], device='cuda:0')
assigned = torch.tensor([1,2,2,1,0], device='cuda:0')
result = torch.where(pred == correct, correct, assigned)
print(result)
# >>> tensor([1, 2, 1, 1, 0], device='cuda:0')
これらのテンソルにはがrequires_grad=True
ないため、勾配計算を回避するために何もする必要はありません。それ以外の場合は、次のようなことができます。
import torch
pred = torch.tensor([1.,2.,1.,0.,0.], device='cuda:0')
correct = torch.tensor([1.,0.,1.,1.,0.], device='cuda:0', requires_grad=True)
assigned = torch.tensor([1.,2.,2.,1.,0.], device='cuda:0', requires_grad=True)
with torch.no_grad():
result = torch.where(pred == correct, correct, assigned)
print(result)
# >>> tensor([1, 2, 1, 1, 0], device='cuda:0')
を使用しない場合はtorch.no_grad()
、次のようになります。
result = torch.where(pred == correct, correct, assigned)
print(result)
# >>> tensor([1., 2., 1., 1., 0.], device='cuda:0', grad_fn=<SWhereBackward>)
これは、次のようにして計算グラフから切り離すことができます。
result = result.detach()
print(result)
# >>> tensor([1., 2., 1., 1., 0.], device='cuda:0')
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加