文档

模型训练

更新时间:
一键部署

模型训练 API 与 tf.keras.Model 模块提供的 API 基本一致,关于 tf.keras.Model 模块的更多信息请参见 tf.keras

重要

本文涉及 API 中所有的占位符,例如"$df0",必须包含单引号或双引号。

模型训练 API 说明

模型训练 API 的使用方法如下:

  1. 继承 JupiterKerasModel 后,通过定义 build 方法自定义模型。build 函数接收一个 input_shape 参数作为模型输入层的 shape,以输出一个 keras.Model 实例。

  2. 提供以下信息,完成实例化自定义模型:

    • uid:模型 ID。

      build_method 支持以下三种输入:

      • from_server

        要求所有 Client 端从 Server 端获取并加载已编译好的模型,从而使所有参与方的初始状态保持一致。

      • from_local

        使用本地训练后的模型,要求 file_uri 参数指定的路径可以获取以 checkpoint 文件夹形式存在的模型。

      • None

        使用 build 方法创建模型,并且不进行同步操作。

    • file_uri:模型保存路径。

    • input_shape:输入形式。

    • build_method:构建方法。

  3. 编译模型。此处接受的参数类似 keras,但新增了一个 strategy 参数,该参数接受一个 fascia.strategy.BaseKerasStrategy 的子类,或者是预定义的字符串,目前支持以下两种策略:

    • fedavg-w

    • fedavg-grad

  4. 训练模型。关于此处接受的参数及类型,请参见本节 参数model.fit 参数说明。

  5. 使用 model.save() 保存模型,该方法接受 saved_path 参数作为保存路径。当不指定该参数时,保存路径由实例化模型时输入的 file_uri 参数决定。

模型训练代码示例

from typing import Union, Tuple, Dict
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import Model
from fascia.biz.model import JupiterKerasModel
from fascia.biz.summary import summary
from fascia.biz.api.dataframe import read_fed_table

roche = read_fed_table("$df0").values()
train_x, train_y = roche[:, 1:], roche[:, 0]
summary.schema(train_x)
summary.schema(train_y)

class CustomKerasModel(JupiterKerasModel):
    def build(self, input_shape) -> Union[Model, Tuple[Model, Dict]]:
        model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(64, activation="relu", input_shape=(72,)), 
                tf.keras.layers.Dense(32, activation="relu"),
                tf.keras.layers.Dense(1, activation='sigmoid'),
            ]
        )
        return model
model = CustomKerasModel(uid='roche_fedavg', file_uri="$model", input_shape=(72,))

# Compile model with FedAVG strategy (weight aggregation)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.metrics.BinaryAccuracy(),
                       tf.keras.metrics.AUC(), 
                       tf.keras.metrics.Precision(), 
                       tf.keras.metrics.Recall()], 
                       strategy='fedavg-w') #tf.metrics.BinaryAccuracy(threshold=0.5),

# Fit model
model.fit(train_x, train_y, batch_size=256, epochs=10, validation_split=0.1, aggregate_freq=1)
model.save()

参数

  • train_x:特征,接受 FedNdArray 或者 "$df0" 形式的占位符。

  • train_y:标签,接受 FedNdArray 或者 "$df0" 形式的占位符。

  • batch_size:批处理大小,接受整数或类似 {'party_a': 32, 'party_b': 64} 格式的字段,注意该设置也会影响验证集。

  • epochs:训练回合数。

  • verbose:是否在过程中显示性能指标。

  • callbacks:类似 Keras.Callback

  • validation_split:验证集的切分比例。

  • validation_data:验证集数据,当提供该数据集时,validation_split 会被忽略。

  • validation_freq:验证频率。

  • shuffle:是否对数据进行 shuffle

  • sample_weightsample 权值,含义类似 keras

  • steps_per_epoch:每个 epoch 中执行的步数,当提供该参数时,batch_size参数会被忽略。

  • aggregate_freq:聚合频率。

  • trainable_parties:指定使用训练的 parties。指定为类似['party_a', 'party_b']时,表示使用指定的参与方参与聚合。默认为 None,即表示使用所有的 parties。

返回值定义

history:训练结果的历史记录,包括全局聚合的性能指标、本地训练集性能指标和验证集性能指标。

History 结果示例

{
  "alice": {"loss": [0.14, 0.12], "accuracy": [0.85, 0.87]},
  "bob": {"loss": [0.14, 0.12], "accuracy": [0.85, 0.87]},
  "__global__": {"loss": [0.14, 0.12], "accuracy": [0.85, 0.87]},
}

  • 本页导读 (0)
文档反馈