在TensorFlow中插值权重

Legotin

A有两个具有相同架构的模型,但训练有不同的损失。

这些模型给出不同的结果。我想在这些模型之间结合它们的权重线性进行“平衡”:

combined_weights = k * weights_a + (1 - k) * weights_b

我正在研究的问题允许这种方法。有人认为这种方法比对输出进行插值更有效。

我也有这种方法的PyTorch实现:

k = 0.3
model_a = torch.load(...)
model_b = torch.load(...)
combined_model = OrderedDict()

for i, v_a in model_a.items():
    v_b = model_b[i]
    combined_model[i] = k * v_a + (1 - k) * v_b

torch.save(combined_model, ...)

我如何在TensorFlow中做同样的事情?我试图结合我的模型的权重做到这一点:

k = 0.3
weights_a = model_a.get_weights()
weights_b = model_b.get_weights()

combined_weights = k * weights_a + (1 - k) * weights_b

combined_model.set_weights(combined_weights)

但是在出现错误k * weights_a

can't multiply sequence by non-int of type 'float'

那么,我该怎么办呢?

马图斯·杜布拉瓦(Matus Dubrava)

那是因为类型不匹配。双方weights_aweights_b都不能由花车乘只是普通的Python列表(这是不是你正在寻找反正操作)。

权重以ndarray格式存储为这些列表的元素。使其工作最简单的方法可能是遍历权重并针对每个权重执行操作ndarray

k = 0.3

weights_a = ...
weights_b = ...
combined = []

for i in range(len(weights_a)):
    c = k * weights_a[i] + (1 - k) * weights_b[i]
    combined.append(c)

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类常见问题

在TensorFlow训练的模型中获取某些权重的值

来自分类Dev

使用TensorFlow对图像中的点进行插值采样

来自分类Dev

Matlab:更快找到ND矩阵中每个元素的一维线性插值节点和权重

来自分类Dev

链接中的插值

来自分类Dev

angularjs 中的插值

来自分类Dev

Tensorflow - tf.nn.conv2D() 中的权重值是否发生了变化?

来自分类Dev

Python中的对数插值

来自分类Dev

R中的插值/查找

来自分类Dev

Freemarker 中的嵌套插值

来自分类Dev

如何在Tensorflow中实现逐元素一维插值?

来自分类Dev

在Ruby中为插值变量分配插值

来自分类Dev

评估步骤中的权重衰减 - Tensorflow

来自分类Dev

对列中的值进行awk插值

来自分类Dev

确定画布插值中的步骤数

来自分类Dev

不插值Highcharts中的缺失日期

来自分类Dev

R中连续数据的插值

来自分类Dev

Fragment Shader GLSL中的颜色插值?

来自分类Dev

在R中执行常数插值

来自分类Dev

iframe中的角度2触发插值

来自分类Dev

在bash脚本中插值包含“ $”的变量

来自分类Dev

在天花的表列中插值数组

来自分类Dev

PHP中的数组变量插值

来自分类Dev

图像中像素颜色的插值,MATLAB

来自分类Dev

XAML中的字符串插值

来自分类Dev

Pandas.Series中的加权插值

来自分类Dev

在React Component的属性中插值JSON?

来自分类Dev

线性插值中的布尔逻辑

来自分类Dev

实时图形中的骨骼动画插值

来自分类Dev

v-for中的闭合,属性插值