私は現在、Flux forJuliaでバッチ更新を実装しようとしています。
計算中に、繰り返し実行してスカラーのバッチを取得します
δ = Gt - model(St)[1]
push!(deltas,δ)
ここで、モデルはニューラルネットワークです
global model= Chain(
Dense(statesize,10, leakyrelu),
Dense(10,10,leakyrelu),
Dense(10,1))
最終的に配列デルタになり、2番目のニューラルネットワークでバッチグラデーション更新(バッチサイズ= 19)を実行したいと思います。ここで、各グラデーションは適切なデルタによって重み付けされます。私が書いた更新関数は
function vupdate2!(S_batch,model,α,deltas)
function v_loss_total(x)
return sum(reshape(deltas,(1,19)) .* model(x))
end
local ps = Flux.params(model)
local gs = Flux.Tracker.gradient(() -> v_loss_total(S_batch), ps)
for p in ps
Flux.Tracker.update!( p, α.* gs[p])
end
end
問題は、勾配が計算されている行がエラーをスローすることです。 MethodError: no method matching Float32(::Tracker.TrackedReal{Float64})
問題は、デルタ配列が追跡されていることだと思います。ランダム入力のv_loss_total関数の出力を見ると、次のようになります。
julia> v_loss_total(S_batch)
-6752.433690476287 (tracked) (tracked)
興味深いことに、この数値は2回追跡されます(?)。これは、追跡された2つの数値(つまり、デルタとモデル(S_batch)のエントリ)を乗算した結果だと思います。最初にデルタ配列を追跡解除する方法はありますか?助けていただければ幸いです。
さて、結局のところ、機能があります
Flux.Tracker.data()
それはまさに私が必要としていたことをします。追跡された番号を受け取り、Float自体を返します。https://github.com/FluxML/Flux.jl/issues/640も参照してください
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加