使用机器学习

开通RDS MySQL机器学习服务后,您需要基于目标数据建立机器学习模型,以便对结果进行预测。本文创建一个包含三种类型鸢尾花特征的数据集,通过学习每种鸢尾花的花瓣、花萼的长度以及宽度,最终根据花瓣以及花萼的尺寸判断鸢尾花的类型。

前提条件

使用机器学习

  1. 准备模型训练用的数据集iris表,包含花瓣、花萼的长度和宽度,以及对应的鸢尾花类型。

    # Training data set: four measurements plus the species name for each flower.
    DROP TABLE IF EXISTS iris;
    CREATE TABLE iris (
        sepal_length FLOAT,
        sepal_width  FLOAT,
        petal_length FLOAT,
        petal_width  FLOAT,
        species      VARCHAR(64)  # species name; compared with species.species when training
    );
    # Load the 150 labelled samples (50 per species), one multi-row INSERT per species.
    # Columns are named explicitly so the load does not depend on column order, and
    # string literals use single quotes so the statements also run with ANSI_QUOTES enabled.
    INSERT INTO iris (sepal_length, sepal_width, petal_length, petal_width, species) VALUES
        (5.1,3.5,1.4,0.2,'setosa'),
        (4.9,3,1.4,0.2,'setosa'),
        (4.7,3.2,1.3,0.2,'setosa'),
        (4.6,3.1,1.5,0.2,'setosa'),
        (5,3.6,1.4,0.2,'setosa'),
        (5.4,3.9,1.7,0.4,'setosa'),
        (4.6,3.4,1.4,0.3,'setosa'),
        (5,3.4,1.5,0.2,'setosa'),
        (4.4,2.9,1.4,0.2,'setosa'),
        (4.9,3.1,1.5,0.1,'setosa'),
        (5.4,3.7,1.5,0.2,'setosa'),
        (4.8,3.4,1.6,0.2,'setosa'),
        (4.8,3,1.4,0.1,'setosa'),
        (4.3,3,1.1,0.1,'setosa'),
        (5.8,4,1.2,0.2,'setosa'),
        (5.7,4.4,1.5,0.4,'setosa'),
        (5.4,3.9,1.3,0.4,'setosa'),
        (5.1,3.5,1.4,0.3,'setosa'),
        (5.7,3.8,1.7,0.3,'setosa'),
        (5.1,3.8,1.5,0.3,'setosa'),
        (5.4,3.4,1.7,0.2,'setosa'),
        (5.1,3.7,1.5,0.4,'setosa'),
        (4.6,3.6,1,0.2,'setosa'),
        (5.1,3.3,1.7,0.5,'setosa'),
        (4.8,3.4,1.9,0.2,'setosa'),
        (5,3,1.6,0.2,'setosa'),
        (5,3.4,1.6,0.4,'setosa'),
        (5.2,3.5,1.5,0.2,'setosa'),
        (5.2,3.4,1.4,0.2,'setosa'),
        (4.7,3.2,1.6,0.2,'setosa'),
        (4.8,3.1,1.6,0.2,'setosa'),
        (5.4,3.4,1.5,0.4,'setosa'),
        (5.2,4.1,1.5,0.1,'setosa'),
        (5.5,4.2,1.4,0.2,'setosa'),
        (4.9,3.1,1.5,0.2,'setosa'),
        (5,3.2,1.2,0.2,'setosa'),
        (5.5,3.5,1.3,0.2,'setosa'),
        (4.9,3.6,1.4,0.1,'setosa'),
        (4.4,3,1.3,0.2,'setosa'),
        (5.1,3.4,1.5,0.2,'setosa'),
        (5,3.5,1.3,0.3,'setosa'),
        (4.5,2.3,1.3,0.3,'setosa'),
        (4.4,3.2,1.3,0.2,'setosa'),
        (5,3.5,1.6,0.6,'setosa'),
        (5.1,3.8,1.9,0.4,'setosa'),
        (4.8,3,1.4,0.3,'setosa'),
        (5.1,3.8,1.6,0.2,'setosa'),
        (4.6,3.2,1.4,0.2,'setosa'),
        (5.3,3.7,1.5,0.2,'setosa'),
        (5,3.3,1.4,0.2,'setosa');
    INSERT INTO iris (sepal_length, sepal_width, petal_length, petal_width, species) VALUES
        (7,3.2,4.7,1.4,'versicolor'),
        (6.4,3.2,4.5,1.5,'versicolor'),
        (6.9,3.1,4.9,1.5,'versicolor'),
        (5.5,2.3,4,1.3,'versicolor'),
        (6.5,2.8,4.6,1.5,'versicolor'),
        (5.7,2.8,4.5,1.3,'versicolor'),
        (6.3,3.3,4.7,1.6,'versicolor'),
        (4.9,2.4,3.3,1,'versicolor'),
        (6.6,2.9,4.6,1.3,'versicolor'),
        (5.2,2.7,3.9,1.4,'versicolor'),
        (5,2,3.5,1,'versicolor'),
        (5.9,3,4.2,1.5,'versicolor'),
        (6,2.2,4,1,'versicolor'),
        (6.1,2.9,4.7,1.4,'versicolor'),
        (5.6,2.9,3.6,1.3,'versicolor'),
        (6.7,3.1,4.4,1.4,'versicolor'),
        (5.6,3,4.5,1.5,'versicolor'),
        (5.8,2.7,4.1,1,'versicolor'),
        (6.2,2.2,4.5,1.5,'versicolor'),
        (5.6,2.5,3.9,1.1,'versicolor'),
        (5.9,3.2,4.8,1.8,'versicolor'),
        (6.1,2.8,4,1.3,'versicolor'),
        (6.3,2.5,4.9,1.5,'versicolor'),
        (6.1,2.8,4.7,1.2,'versicolor'),
        (6.4,2.9,4.3,1.3,'versicolor'),
        (6.6,3,4.4,1.4,'versicolor'),
        (6.8,2.8,4.8,1.4,'versicolor'),
        (6.7,3,5,1.7,'versicolor'),
        (6,2.9,4.5,1.5,'versicolor'),
        (5.7,2.6,3.5,1,'versicolor'),
        (5.5,2.4,3.8,1.1,'versicolor'),
        (5.5,2.4,3.7,1,'versicolor'),
        (5.8,2.7,3.9,1.2,'versicolor'),
        (6,2.7,5.1,1.6,'versicolor'),
        (5.4,3,4.5,1.5,'versicolor'),
        (6,3.4,4.5,1.6,'versicolor'),
        (6.7,3.1,4.7,1.5,'versicolor'),
        (6.3,2.3,4.4,1.3,'versicolor'),
        (5.6,3,4.1,1.3,'versicolor'),
        (5.5,2.5,4,1.3,'versicolor'),
        (5.5,2.6,4.4,1.2,'versicolor'),
        (6.1,3,4.6,1.4,'versicolor'),
        (5.8,2.6,4,1.2,'versicolor'),
        (5,2.3,3.3,1,'versicolor'),
        (5.6,2.7,4.2,1.3,'versicolor'),
        (5.7,3,4.2,1.2,'versicolor'),
        (5.7,2.9,4.2,1.3,'versicolor'),
        (6.2,2.9,4.3,1.3,'versicolor'),
        (5.1,2.5,3,1.1,'versicolor'),
        (5.7,2.8,4.1,1.3,'versicolor');
    INSERT INTO iris (sepal_length, sepal_width, petal_length, petal_width, species) VALUES
        (6.3,3.3,6,2.5,'virginica'),
        (5.8,2.7,5.1,1.9,'virginica'),
        (7.1,3,5.9,2.1,'virginica'),
        (6.3,2.9,5.6,1.8,'virginica'),
        (6.5,3,5.8,2.2,'virginica'),
        (7.6,3,6.6,2.1,'virginica'),
        (4.9,2.5,4.5,1.7,'virginica'),
        (7.3,2.9,6.3,1.8,'virginica'),
        (6.7,2.5,5.8,1.8,'virginica'),
        (7.2,3.6,6.1,2.5,'virginica'),
        (6.5,3.2,5.1,2,'virginica'),
        (6.4,2.7,5.3,1.9,'virginica'),
        (6.8,3,5.5,2.1,'virginica'),
        (5.7,2.5,5,2,'virginica'),
        (5.8,2.8,5.1,2.4,'virginica'),
        (6.4,3.2,5.3,2.3,'virginica'),
        (6.5,3,5.5,1.8,'virginica'),
        (7.7,3.8,6.7,2.2,'virginica'),
        (7.7,2.6,6.9,2.3,'virginica'),
        (6,2.2,5,1.5,'virginica'),
        (6.9,3.2,5.7,2.3,'virginica'),
        (5.6,2.8,4.9,2,'virginica'),
        (7.7,2.8,6.7,2,'virginica'),
        (6.3,2.7,4.9,1.8,'virginica'),
        (6.7,3.3,5.7,2.1,'virginica'),
        (7.2,3.2,6,1.8,'virginica'),
        (6.2,2.8,4.8,1.8,'virginica'),
        (6.1,3,4.9,1.8,'virginica'),
        (6.4,2.8,5.6,2.1,'virginica'),
        (7.2,3,5.8,1.6,'virginica'),
        (7.4,2.8,6.1,1.9,'virginica'),
        (7.9,3.8,6.4,2,'virginica'),
        (6.4,2.8,5.6,2.2,'virginica'),
        (6.3,2.8,5.1,1.5,'virginica'),
        (6.1,2.6,5.6,1.4,'virginica'),
        (7.7,3,6.1,2.3,'virginica'),
        (6.3,3.4,5.6,2.4,'virginica'),
        (6.4,3.1,5.5,1.8,'virginica'),
        (6,3,4.8,1.8,'virginica'),
        (6.9,3.1,5.4,2.1,'virginica'),
        (6.7,3.1,5.6,2.4,'virginica'),
        (6.9,3.1,5.1,2.3,'virginica'),
        (5.8,2.7,5.1,1.9,'virginica'),
        (6.8,3.2,5.9,2.3,'virginica'),
        (6.7,3.3,5.7,2.5,'virginica'),
        (6.7,3,5.2,2.3,'virginica'),
        (6.3,2.5,5,1.9,'virginica'),
        (6.5,3,5.2,2,'virginica'),
        (6.2,3.4,5.4,2.3,'virginica'),
        (5.9,3,5.1,1.8,'virginica');
  2. 创建分类表species,包含鸢尾花的分类以及每种分类对应的标签(class),用数字0、1、2表示。

    # Lookup table that maps each species name to the integer label used for training.
    DROP TABLE IF EXISTS species;
    CREATE TABLE species (
        species VARCHAR(64),
        class   INT
    );
    INSERT INTO species (species, class) VALUES
        ('setosa', 0),      # 0 stands for setosa
        ('versicolor', 1),  # 1 stands for versicolor
        ('virginica', 2);   # 2 stands for virginica
  3. 通过模型训练数据集(iris表)进行模型训练,生成预测模型。

    # NOTE(review): the training and validation queries both draw random rows from the same table, so the two samples can overlap.
    SELECT i.*, s.class FROM iris i, species s WHERE i.species = s.species ORDER BY rand() LIMIT 120 # Standard SQL that selects the rows to train on.
    TO TRAIN DNNClassifier # Marks this as a training statement and selects DNNClassifier as the algorithm.
        WITH model.n_classes = 3, # Number of classes to predict; there are three iris species.
        model.hidden_units = [100, 100], # Hidden layers of the network: two layers with 100 units each.
        optimizer.learning_rate = 0.1, # Learning rate of DNNClassifier; the default value is 0.0001.
        train.epoch = 10, # Number of training epochs; this is a count, so it must be a whole number.
        validation.select = "select i.*, s.class from iris i, species s where i.species = s.species order by rand() limit 30" # Query whose rows are used to validate model accuracy.
    COLUMN sepal_length, sepal_width, petal_length, petal_width # Feature columns: all four measurements.
    LABEL class INTO iris_model; # Prediction target is class; the trained model is saved as iris_model (the name is up to you).
    说明
    • 上述代码为同步任务,需要等待模型训练完成才可进行下一步操作。模型训练时间根据数据量大小、训练参数的不同会有很大变化。如果您的数据量庞大,可执行异步任务,即在训练语句末尾(分号之前)加入async关键字,将任务置于后台执行。例如:

      # Asynchronous variant of the training statement: the trailing async submits it as a background job.
      SELECT i.*, s.class FROM iris i, species s WHERE i.species = s.species ORDER BY rand() LIMIT 120
      TO TRAIN DNNClassifier
        WITH model.n_classes = 3,
        model.hidden_units = [100, 100],
        optimizer.learning_rate = 0.1,
        train.epoch = 10, # Number of training epochs; this is a count, so it must be a whole number.
        validation.select = "select i.*, s.class from iris i, species s where i.species = s.species order by rand() limit 30"
      COLUMN sepal_length, sepal_width, petal_length, petal_width # Feature columns: all four measurements.
      LABEL class INTO iris_model async;

      异步任务执行时,您可以执行SHOW TRAIN;查看任务执行的情况。

    • 模型训练完成后,您可以执行SHOW MODELS;查看已完成的模型。

    • 除了DNNClassifier之外,您还可以选择其他算法,RDS MySQL机器学习支持如下算法:

    • 更多关于机器学习语句的信息,请参见官方说明

  4. 创建测试数据表,提供不同鸢尾花的花瓣、花萼尺寸,用于测试训练完成的模型是否可以根据这些尺寸预测出花的种类。

    # Test samples used to check the trained model; species holds the expected answer.
    # DROP TABLE IF EXISTS keeps this step re-runnable, like the earlier table-creation steps.
    DROP TABLE IF EXISTS iris_test;
    CREATE TABLE iris_test (
        sepal_length FLOAT,
        sepal_width  FLOAT,
        petal_length FLOAT,
        petal_width  FLOAT,
        species      VARCHAR(64)
    );
    # The row (6.3,2.3,4.4,1.3,'versicolor') was listed twice; the exact duplicate is omitted.
    INSERT INTO iris_test (sepal_length, sepal_width, petal_length, petal_width, species) VALUES
        (4.8,3.1,1.6,0.2,'setosa'),
        (5.4,3.4,1.5,0.4,'setosa'),
        (5.2,4.1,1.5,0.1,'setosa'),
        (5.5,4.2,1.4,0.2,'setosa'),
        (6.7,3,5.2,2.3,'virginica'),
        (6.1,2.6,5.6,1.4,'virginica'),
        (7.7,3,6.1,2.3,'virginica'),
        (6.3,3.4,5.6,2.4,'virginica'),
        (6.4,3.1,5.5,1.8,'virginica'),
        (5.5,2.5,4,1.3,'versicolor'),
        (5.5,2.6,4.4,1.2,'versicolor'),
        (6.3,2.3,4.4,1.3,'versicolor');
  5. 创建预测结果表iris_pred。

    # Result table for predictions: the iris_test columns plus the predicted label.
    # DROP TABLE IF EXISTS keeps this step re-runnable, like the earlier table-creation steps.
    DROP TABLE IF EXISTS iris_pred;
    CREATE TABLE iris_pred (
        sepal_length FLOAT,
        sepal_width  FLOAT,
        petal_length FLOAT,
        petal_width  FLOAT,
        species      VARCHAR(64),
        class        INT  # target column named in the TO PREDICT statement
    );
  6. 使用训练完成的模型(iris_model)对测试数据表(iris_test)中的数据进行预测,并将预测结果存入预测结果表(iris_pred)。

    SELECT * FROM iris_test # Standard SQL that selects the rows to predict.
    TO PREDICT iris_pred.class # Marks this as a prediction statement; results are stored in the class column of iris_pred.
    USING iris_model; # Use the model named iris_model.
  7. 查询预测结果表。

    # Show every row of the prediction result table.
    SELECT *
    FROM iris_pred;
    结果
    说明

    class列为预测的结果,species列为答案,您可以根据这两列查看模型预测是否精确。如果结果不够精确,则可以尝试优化步骤3模型训练的代码。如何优化模型训练的代码,请参见官方说明