PyTorchVAEがonnxへの変換に失敗する

jbm

PyTorch VAEをonnxに変換しようとしていますが、次のようになります。 torch.onnx.symbolic.normal does not exist

問題は次のreparametrize()機能に起因しているようです。

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if self.have_cuda:
             eps = torch.normal(torch.zeros(std.size()),torch.ones(std.size())).cuda()
        else:
           eps = torch.normal(torch.zeros(std.size()),torch.ones(std.size()))
        return eps.mul(std).add_(mu)

私も試しました:

eps = torch.cuda.FloatTensor(std.size()).normal_()

エラーが発生しました:

    Schema not found for node. File a bug report.
    Node: %173 : Float(1, 20) = aten::normal(%169, %170, %171, %172), scope: VAE 
    Input types:Float(1, 20), float, float, Generator

そして

eps = torch.randn(std.size()).cuda()

エラーが発生しました:

    builtins.TypeError: i_(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.Node, arg0: str, arg1: int) -> torch._C.Node
    Invoked with: %137 : Tensor = onnx::RandomNormal(), scope: VAE, 'shape', 133 defined in (%133 : int[] = prim::ListConstruct(%128, %132), scope: VAE) (occurred when translating randn)

私はを使用していcudaます。

どんな考えでもありがたいです。おそらく私zはonnxに対して異なる方法で/ latentにアプローチする必要がありますか?

注:ステップスルーすると、が検出さRandomNormal()ていることがわかりtorch.randn()ます。これは正しいはずです。しかし、その時点では実際には引数にアクセスできないので、どうすれば修正できますか?

Yuki Hashimoto

要するに、以下のコードが機能する可能性があります。(少なくとも私の環境では、エラーなしで機能しました)。

と思われ.size()、それはonnxコンパイルエラーが発生しますので、オペレータは、一定ではなく、変数を返すことがあります。(.size()を使用するように変更したときに同じエラーが発生しました)

import torch
import torch.utils.data
from torch import nn
from torch.nn import functional as F



IN_DIMS = 28 * 28
BATCH_SIZE = 10
FEATURE_DIM = 20

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, FEATURE_DIM)
        self.fc22 = nn.Linear(400, FEATURE_DIM)
        self.fc3 = nn.Linear(FEATURE_DIM, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn(BATCH_SIZE, FEATURE_DIM, device='cuda')
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)

        return recon_x

model = VAE().cuda()

dummy_input = torch.randn(BATCH_SIZE, IN_DIMS, device='cuda')
torch.onnx.export(model, dummy_input, "vae.onnx", verbose=True)

この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。

侵害の場合は、連絡してください[email protected]

編集
0

コメントを追加

0

関連記事

分類Dev

Generic MappingMethodが具象型への変換に失敗する

分類Dev

Newtonsoft.JsonがJSONからICollection <Object>への変換に失敗する

分類Dev

KotlinへのJava関数の変換が失敗する

分類Dev

DocXからPDFへのDriveApp変換が失敗する

分類Dev

線形xmlへのXSL変換が失敗する理由

分類Dev

QByteArrayの16進数への変換に失敗する

分類Dev

MSSQLからMySQLへの変換に失敗する

分類Dev

代わりに列を保持する代わりに、pivot_longerが行への変換に失敗する

分類Dev

Downloadhelperが変換に失敗する

分類Dev

Downloadhelperが変換に失敗する

分類Dev

DTDが存在する場合、xsltxmlからhtmlへの変換が失敗する

分類Dev

文字列へのUIColor変換が失敗するのはなぜですか?

分類Dev

配列から文字列への変換が原因でLaravelValidatorが失敗する

分類Dev

文字列のObjectIdへの変換がmongoose4.6.0で失敗する

分類Dev

AutoMapperが単純なDTOの変換に失敗する

分類Dev

接着剤が日付の変換に失敗する

分類Dev

number_format()が科学的記数法の変換に失敗する

分類Dev

C ++フロートから文字列への変換がstringstreamで失敗する

分類Dev

std :: stringstreamを使用した文字列への変換が空白で失敗する

分類Dev

UTF 16LEからUTF16BEへの.NETC#変換が失敗する

分類Dev

xts 0.9.7 から 0.10.0 への更新で Xts 変換が失敗する

分類Dev

ジェネリック関数のKotlinへの変換に失敗する

分類Dev

内部表現への変換に失敗するSpringデータJPA

分類Dev

posixctへの変換に失敗します

分類Dev

SQL Proc '変換に失敗しました' varcharからintへ。なぜ変換するのですか?

分類Dev

jpgフレームの変更が原因でjpgからmp4へのffmpeg変換が失敗する

分類Dev

/ etc / groupへの変更の書き込み中にgroupaddが失敗する

分類Dev

GoogleSQLへの移行時にマスターの変更が失敗する

分類Dev

Oracleカーソルを使用してSQLServerに挿入するときに、多数のPyodbc挿入がintからbigへの変換に失敗する

Related 関連記事

  1. 1

    Generic MappingMethodが具象型への変換に失敗する

  2. 2

    Newtonsoft.JsonがJSONからICollection <Object>への変換に失敗する

  3. 3

    KotlinへのJava関数の変換が失敗する

  4. 4

    DocXからPDFへのDriveApp変換が失敗する

  5. 5

    線形xmlへのXSL変換が失敗する理由

  6. 6

    QByteArrayの16進数への変換に失敗する

  7. 7

    MSSQLからMySQLへの変換に失敗する

  8. 8

    代わりに列を保持する代わりに、pivot_longerが行への変換に失敗する

  9. 9

    Downloadhelperが変換に失敗する

  10. 10

    Downloadhelperが変換に失敗する

  11. 11

    DTDが存在する場合、xsltxmlからhtmlへの変換が失敗する

  12. 12

    文字列へのUIColor変換が失敗するのはなぜですか?

  13. 13

    配列から文字列への変換が原因でLaravelValidatorが失敗する

  14. 14

    文字列のObjectIdへの変換がmongoose4.6.0で失敗する

  15. 15

    AutoMapperが単純なDTOの変換に失敗する

  16. 16

    接着剤が日付の変換に失敗する

  17. 17

    number_format()が科学的記数法の変換に失敗する

  18. 18

    C ++フロートから文字列への変換がstringstreamで失敗する

  19. 19

    std :: stringstreamを使用した文字列への変換が空白で失敗する

  20. 20

    UTF 16LEからUTF16BEへの.NETC#変換が失敗する

  21. 21

    xts 0.9.7 から 0.10.0 への更新で Xts 変換が失敗する

  22. 22

    ジェネリック関数のKotlinへの変換に失敗する

  23. 23

    内部表現への変換に失敗するSpringデータJPA

  24. 24

    posixctへの変換に失敗します

  25. 25

    SQL Proc '変換に失敗しました' varcharからintへ。なぜ変換するのですか?

  26. 26

    jpgフレームの変更が原因でjpgからmp4へのffmpeg変換が失敗する

  27. 27

    / etc / groupへの変更の書き込み中にgroupaddが失敗する

  28. 28

    GoogleSQLへの移行時にマスターの変更が失敗する

  29. 29

    Oracleカーソルを使用してSQLServerに挿入するときに、多数のPyodbc挿入がintからbigへの変換に失敗する

ホットタグ

アーカイブ