高斯混合模型训练

更新时间: 2023-11-01 11:47:05

高斯混合模型(Gaussian Mixture Model)表示在总体分布中包含K个高斯子分布的概率模型。您可以使用高斯混合模型训练组件实现模型分类。本文为您介绍高斯混合模型训练组件的配置方法。

使用限制

支持的计算引擎为MaxCompute和Flink。

可视化配置组件参数

Designer支持通过可视化的方式,配置组件参数。

页签

参数

描述

字段设置

向量列名

向量列对应的列名。

参数设置

收敛阈值

当两轮迭代的中心点距离小于收敛阈值时,算法收敛。默认为1.0E~4。

聚类中心点数量

聚类中心点的数量,默认为2。

最大迭代步数

最大迭代步数,默认为100。

随机数种子

正整数,默认为0。

执行调优

节点个数

单个节点内存大小参数配对使用。取值为[1, 9999]的正整数。具体配置方法,详情请参见附录:如何预估资源的使用量

单个节点内存大小,单位M

取值范围为1024 MB~64*1024 MB,具体配置方法,详情请参见附录:如何预估资源的使用量

通过代码方式配置组件

您也可以通过配置代码的方式,来实现高斯混合模型训练组件的相关功能,具体配置方法如下。

  • Python代码

    df_data = pd.DataFrame([
        ["-0.6264538 0.1836433"],
        ["-0.8356286 1.5952808"],
        ["0.3295078 -0.8204684"],
        ["0.4874291 0.7383247"],
        ["0.5757814 -0.3053884"],
        ["1.5117812 0.3898432"],
        ["-0.6212406 -2.2146999"],
        ["11.1249309 9.9550664"],
        ["9.9838097 10.9438362"],
        ["10.8212212 10.5939013"],
        ["10.9189774 10.7821363"],
        ["10.0745650 8.0106483"],
        ["10.6198257 9.9438713"],
        ["9.8442045 8.5292476"],
        ["9.5218499 10.4179416"],
    ])
    
    data = BatchOperator.fromDataframe(df_data, schemaStr='features string')
    dataStream = StreamOperator.fromDataframe(df_data, schemaStr='features string')
    
    gmm = GmmTrainBatchOp() \
        .setVectorCol("features") \
        .setEpsilon(0.)
    
    model = gmm.linkFrom(data)
    
    predictor = GmmPredictBatchOp() \
        .setPredictionCol("cluster_id") \
        .setVectorCol("features") \
        .setPredictionDetailCol("cluster_detail")
    
    predictor.linkFrom(model, data).print()
    
    predictorStream = GmmPredictStreamOp(model) \
        .setPredictionCol("cluster_id") \
        .setVectorCol("features") \
        .setPredictionDetailCol("cluster_detail")
    
    predictorStream.linkFrom(dataStream).print()
    
    StreamOperator.execute()
  • Java代码

    import org.apache.flink.types.Row;
    
    import com.alibaba.alink.operator.batch.BatchOperator;
    import com.alibaba.alink.operator.batch.clustering.GmmPredictBatchOp;
    import com.alibaba.alink.operator.batch.clustering.GmmTrainBatchOp;
    import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
    import com.alibaba.alink.operator.stream.StreamOperator;
    import com.alibaba.alink.operator.stream.clustering.GmmPredictStreamOp;
    import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
    import org.junit.Test;
    
    import java.util.Arrays;
    import java.util.List;
    
    public class GmmTrainBatchOpTest {
        @Test
        public void testGmmTrainBatchOp() throws Exception {
            List <Row> df_data = Arrays.asList(
                Row.of("-0.6264538 0.1836433"),
                Row.of("-0.8356286 1.5952808"),
                Row.of("0.3295078 -0.8204684"),
                Row.of("0.4874291 0.7383247"),
                Row.of("0.5757814 -0.3053884"),
                Row.of("1.5117812 0.3898432"),
                Row.of("-0.6212406 -2.2146999"),
                Row.of("11.1249309 9.9550664"),
                Row.of("9.9838097 10.9438362"),
                Row.of("10.8212212 10.5939013"),
                Row.of("10.9189774 10.7821363"),
                Row.of("10.0745650 8.0106483"),
                Row.of("10.6198257 9.9438713"),
                Row.of("9.8442045 8.5292476"),
                Row.of("9.5218499 10.4179416")
            );
            BatchOperator <?> data = new MemSourceBatchOp(df_data, "features string");
            StreamOperator <?> dataStream = new MemSourceStreamOp(df_data, "features string");
            BatchOperator <?> gmm = new GmmTrainBatchOp()
                .setVectorCol("features")
                .setEpsilon(0.);
            BatchOperator <?> model = gmm.linkFrom(data);
            BatchOperator <?> predictor = new GmmPredictBatchOp()
                .setPredictionCol("cluster_id")
                .setVectorCol("features")
                .setPredictionDetailCol("cluster_detail");
            predictor.linkFrom(model, data).print();
            StreamOperator <?> predictorStream = new GmmPredictStreamOp(model)
                .setPredictionCol("cluster_id")
                .setVectorCol("features")
                .setPredictionDetailCol("cluster_detail");
            predictorStream.linkFrom(dataStream).print();
            StreamOperator.execute();
        }
    }

附录:如何预估资源的使用量

您可以参考以下示例,来预估资源的使用量。

  • 如何预估每个节点的内存大小?

    假设聚类中心点数量为K,输入数据的向量维度为M,则每个节点需要配置的内存大小为:M × M × K × 8 × 2 ×12,即M × M × K × 8 × 2 × 12 ÷ 1024 ÷ 1024 MB。通常每个节点的内存配置为8 GB。

  • 如何预估节点的个数?

    建议按照输入数据的大小配置。例如:输入数据大小为X GB,则建议使用5X个节点。如果资源不足,可以适当降低节点数量。由于存在通信开销,随着节点数量的增加,分布式训练任务速度会先变快,后变慢。如果您观测到训练任务随着节点数量增加之后,速度变慢,则应该停止增加节点数量。

  • 该算法组件支持的数据量大小?

    建议向量维度小于200。

阿里云首页 人工智能平台 PAI 相关技术圈