Whale提供的GraphKeys工具,可以区分不同聚合操作的Collection。本文为您提供GraphKeys支持的接口格式、参数说明及调用示例。

接口 功能
add_to_collection 将Tensor对象加入到特定聚合方式对应的Collection。
get_all_collections 返回所有非空Collection的所有Tensor对象组成的列表。
get_collection 返回特定聚合方式对应Collection中的所有Tensor对象组成的列表。

背景信息

为提升模型训练的样本吞吐,经常混合使用不同的分布式并行方式,例如混合使用数据并行、模型并行及流水并行。然而Whale默认不对任何算子的输出Tensor Value进行多副本聚合,为了调试或算法收敛,需要定期查看局部或全局lossaccuracy等指标变化。Whale提供若干以graph_key区分的Collections,您可以使用whale.add_to_collectionwhale.get_all_collectionswhale.get_collection接口查看或修改对应Collection的Tensor列表。

add_to_collection

  • 格式
    add_to_collection(tensor, graph_key)
  • 功能

    将某tensor加入到graph_key对应的Collection,以便Whale框架按需自动聚合该tensor的数值。如果未调用该接口,则默认所有Tensor在sess.run时只返回单个副本的对应数值。

  • 参数
    • tensor:待聚合的某算子输出的Tensor对象,TensorFlow Tensor类型。
    • graph_key:Tensor对象的聚合方式,STRING类型,取值请参见GraphKeys取值及场景示例
  • 返回值

  • 示例
    import tensorflow as tf
    import whale as wh
    # 场景说明:查看所有模型副本的全局loss均值。
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=tf.cast(logits, tf.float32))
    wh.add_to_collection(loss, wh.GraphKeys.GLOBAL_MEAN_OBJECTS)

get_all_collections

  • 格式
    get_all_collections()
  • 功能

    返回所有非空Collection的所有Tensor对象组成的列表。

  • 参数

  • 返回值

    LIST类型。如果所有Collection为空,则返回空LIST。

  • 示例
    import whale as wh
    # 场景说明:输出所有Collection元素。
    print(wh.get_all_collections())

get_collection

  • 格式
    get_collection(graph_key)
  • 功能

    返回graph_key对应Collection中的所有Tensor对象组成的列表。

  • 参数

    graph_key:指定相应Collection中Tensor聚合方式,STRING类型,取值详情请参见GraphKeys取值及场景示例

  • 返回值

    LIST类型。如果相应Collection为空,则返回空LIST。

  • 示例
    import whale as wh
    # 场景说明:查看graph key为whale.GraphKeys.GLOBAL_MEAN_OBJECTS时,Collection中的所有Tensor对象。
    print(wh.get_collection(wh.GraphKeys.GLOBAL_MEAN_OBJECTS))

GraphKeys取值及场景示例

取值 描述 场景示例
GLOBAL_CONCAT_OBJECTS 该Collection内的Tensor将进行数值的全局(针对所有模型副本)拼接,拼接维度axis=0 查看一次迭代中所有模型副本消耗的所有数据。
LOCAL_CONCAT_OBJECTS 该Collection内的Tensor将进行数值的局部(针对当前模型副本)拼接,拼接维度axis=0 查看一次迭代中当前模型副本消耗的所有数据。例如,流水并行场景下多个micro-batch的聚合。
GLOBAL_MEAN_OBJECTS 该Collection内的Tensor将进行数值的全局(针对所有模型副本)求平均。 查看一次迭代中的所有模型副本全局loss均值。
LOCAL_MEAN_OBJECTS 该Collection内的Tensor将进行数值的局部(针对当前模型副本)求平均。 查看一次迭代中的当前模型副本局部loss均值,例如,流水并行场景下多个micro-batchloss的均值。
GLOBAL_SUM_OBJECTS 该Collection内的Tensor将进行数值的全局(针对所有模型副本)求和。 查看一次迭代中的所有模型副本全局loss之和。
LOCAL_SUM_OBJECTS 该collection内的Tensor将进行数值的局部(针对当前模型副本)求和。 查看一次迭代中的当前模型副本局部loss之和。例如,流水并行场景下多个micro-batchloss之和。