如何在联邦的Tensorflow中保存模型

向量

如何在打击代码中保存模型

如果要运行代码,请访问https://github.com/tensorflow/federated并下载federated_learning_for_image_classification.ipynb。

如果您在教程federated_learning_for_image_classification.ipynb中告诉我如何保存联邦学习模型,我将不胜感激。



from __future__ import absolute_import, division, print_function
import tensorflow_federated as tff
from matplotlib import pyplot as plt
import tensorflow as tf
import six
import numpy as np
from six.moves import range
import warnings
import collections
import nest_asyncio
import h5py_character
from tensorflow.keras import layers
nest_asyncio.apply()
warnings.simplefilter('ignore')
tf.compat.v1.enable_v2_behavior()
np.random.seed(0)


NUM_CLIENTS = 1
NUM_EPOCHS = 1
BATCH_SIZE = 20
SHUFFLE_BUFFER = 500
num_classes = 3755

if six.PY3:
    tff.framework.set_default_executor(
        tff.framework.create_local_executor(NUM_CLIENTS))  


data_train = h5py_character.load_characters_data()

print(len(data_train.client_ids))

example_dataset = data_train.create_tf_dataset_for_client(
    data_train.client_ids[0])


def preprocess(dataset):
    def element_fn(element):
        # element['data'] = tf.expand_dims(element['data'], axis=-1)
        return collections.OrderedDict([
            # ('x', tf.reshape(element['data'], [-1])),
            ('x', tf.reshape(element['data'], [64, 64, 1])),
            ('y', tf.reshape(element['label'], [1])),
        ])

    return dataset.repeat(NUM_EPOCHS).map(element_fn).shuffle(
        SHUFFLE_BUFFER).batch(BATCH_SIZE)


preprocessed_example_dataset = preprocess(example_dataset)  
print(iter(preprocessed_example_dataset).next())


sample_batch = tf.nest.map_structure(
    lambda x: x.numpy(), iter(preprocessed_example_dataset).next())



def make_federated_data(client_data, client_ids):
    return [preprocess(client_data.create_tf_dataset_for_client(x))
            for x in client_ids]


sample_clients = data_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(data_train, sample_clients)




def create_compiled_keras_model():

    model = tf.keras.Sequential([
        layers.Conv2D(input_shape=(64, 64, 1), filters=64, kernel_size=(3, 3), strides=(1, 1),
                      padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),

        layers.Flatten(),
        layers.Dense(1024, activation='relu'),
        layers.Dense(3755, activation='softmax')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        # metrics=['accuracy'])
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])



    return model


def model_fn():
    keras_model = create_compiled_keras_model()
    global model_to_save
    model_to_save = keras_model
    print(keras_model.summary())
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)


iterative_process = tff.learning.build_federated_averaging_process(model_fn)


state = iterative_process.initialize()

state, metrics = iterative_process.next(state, federated_train_data)

print('round  1, metrics={}'.format(metrics))

for round_num in range(2, 110):
    state, metrics = iterative_process.next(state, federated_train_data)
    print('round {:2d}, metrics={}'.format(round_num, metrics))
基思·拉什(Keith Rush)

粗略地讲,我们将在这里使用对象及其save_checkpoint/load_checkpoint方法。特别是,您可以实例化一个FileCheckpointManager,并要求其state直接保存(几乎)。

state在您的示例中是tff.python.common_libs.anonymous_tuple.AnonymousTuple(IIRC)的实例,该实例docstringtf.convert_to_tensor所需save_checkpoint并在docstring中声明的兼容TFF研究代码中经常使用的通用解决方案是引入Pythonattr的类,以便在返回状态后立即将其从匿名元组转换为其他示例(请参见此处的示例)。

假设以上所述,以下草图应适用:

# state assumed an anonymous tuple, previously created
# N some integer 

ckpt_manager = FileCheckpointManager(...)
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)

要从此检查点还原,可以随时调用:

state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
    ServerState.from_anon_tuple(state))

需要注意的一件事:上面链接的代码指针通常在中tff.python.research...,不包含在pip包中。因此,获取它们的首选方法是将代码放入您自己的项目中,或者拉下存储库并从源代码进行构建。

感谢您对TFF的关注!

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

如何在PyTorch中保存模型架构?

来自分类Dev

如何在Django中保存模型记录

来自分类Dev

Tensorflow:如何在训练中我想要的步骤中保存模型

来自分类Dev

如何不在Tensorflow Keras中保存模型优化器?

来自分类Dev

如何在一次交易中保存相关模型?

来自分类Dev

如何在Ember Data中保存关联模型上的更改

来自分类Dev

如何在Django模型中保存对象列表?

来自分类Dev

如何在R中保存JAGS模型对象?

来自分类Dev

如何在Scala中保存RandomForestClassifier Spark模型?

来自分类Dev

如何在Apache Spark中保存和加载MLLib模型?

来自分类Dev

如何在一次交易中保存相关模型?

来自分类Dev

如何在libsvm中保存经过matlab训练的模型

来自分类Dev

如何在R中保存JAGS模型对象?

来自分类Dev

如何在视图中的模型中保存表单字段

来自分类Dev

如何在Scala中保存RandomForestClassifier Spark模型?

来自分类Dev

如何在ignite中保存决策树训练模型?

来自分类Dev

如何在联邦Tensorflow中绘制增量重量的直方图摘要?

来自分类Dev

无法在 Tensorflow 中保存或恢复模型

来自分类Dev

如何在rails中保存包含另一个模型属性的模型?

来自分类Dev

如何仅在Tensorflow2中保存张量而不是模型

来自分类Dev

如何在Ember.js中保存属性类型为“ date”的模型?

来自分类Dev

如何在Ember.js中保存具有多对多关系的模型?

来自分类Dev

如何在odoo模型中保存只读/可编辑假字段上的值?

来自分类Dev

如何在Swift的CoreData模型中保存动态创建的UISwitch的更改状态?

来自分类Dev

如何在Python中保存所有深度学习模型参数?

来自分类Dev

如何在Vapor 3中保存具有特定ID的模型

来自分类Dev

如何在Ember.js中保存具有多对多关系的模型?

来自分类Dev

如何在一次交易中保存多个Django模型?

来自分类Dev

如何在Django模型中保存两个不同的用户?

Related 相关文章

  1. 1

    如何在PyTorch中保存模型架构?

  2. 2

    如何在Django中保存模型记录

  3. 3

    Tensorflow:如何在训练中我想要的步骤中保存模型

  4. 4

    如何不在Tensorflow Keras中保存模型优化器?

  5. 5

    如何在一次交易中保存相关模型?

  6. 6

    如何在Ember Data中保存关联模型上的更改

  7. 7

    如何在Django模型中保存对象列表?

  8. 8

    如何在R中保存JAGS模型对象?

  9. 9

    如何在Scala中保存RandomForestClassifier Spark模型?

  10. 10

    如何在Apache Spark中保存和加载MLLib模型?

  11. 11

    如何在一次交易中保存相关模型?

  12. 12

    如何在libsvm中保存经过matlab训练的模型

  13. 13

    如何在R中保存JAGS模型对象?

  14. 14

    如何在视图中的模型中保存表单字段

  15. 15

    如何在Scala中保存RandomForestClassifier Spark模型?

  16. 16

    如何在ignite中保存决策树训练模型?

  17. 17

    如何在联邦Tensorflow中绘制增量重量的直方图摘要?

  18. 18

    无法在 Tensorflow 中保存或恢复模型

  19. 19

    如何在rails中保存包含另一个模型属性的模型?

  20. 20

    如何仅在Tensorflow2中保存张量而不是模型

  21. 21

    如何在Ember.js中保存属性类型为“ date”的模型?

  22. 22

    如何在Ember.js中保存具有多对多关系的模型?

  23. 23

    如何在odoo模型中保存只读/可编辑假字段上的值?

  24. 24

    如何在Swift的CoreData模型中保存动态创建的UISwitch的更改状态?

  25. 25

    如何在Python中保存所有深度学习模型参数?

  26. 26

    如何在Vapor 3中保存具有特定ID的模型

  27. 27

    如何在Ember.js中保存具有多对多关系的模型?

  28. 28

    如何在一次交易中保存多个Django模型?

  29. 29

    如何在Django模型中保存两个不同的用户?

热门标签

归档