本文为您介绍如何使用深度学习框架TensorFlow,快速搭架图像识别的预测模型。
前提条件
- 已创建OSS Bucket,并完成了OSS授权,详情请参见创建存储空间和PAI访问云产品授权:OSS。
重要 创建Bucket时,不要开通版本控制,否则可能导致训练失败。
- 已开启GPU,详情请参见MaxCompute资源管理。
背景信息
随着互联网发展,产生了大量图片及语音数据。如何有效利用这些非结构化数据,一直是困扰数据挖掘工程师的一道难题。主要原因包括:
- 通常需要使用深度学习算法,上手门槛高。
- 通常需要依赖GPU计算引擎,计算资源费用高。
PAI-Designer及原PAI-Studio已经预置了使用深度学习框架实现图片分类的模板,您可以直接从模板创建实验,并将其复用到图片鉴黄、物体检测等领域。
数据集
本实验使用CIFAR-10数据集,该数据集包含6万张像素为32*32的彩色图片,共10个类别,分别为飞机、汽车、鸟、毛、鹿、狗、青蛙、马、船及卡车,如下图所示。您可以下载该数据集及相关代码,详情请参见
CIFAR 10案例。

使用过程中将该数据集拆分为训练数据集(5万张图片)和预测数据集(1万张图片)。其中5万张图片的训练数据集又被拆分为5个
data_batch,1万张图片的预测数据集组成
test_batch,如下图所示。

数据准备
将本实验的数据集和相关代码上传至OSS的Bucket路径。例如,在OSS的Bucket下创建
aohai_test文件夹及四个子文件夹,如下图所示。
每个文件夹的作用如下:

- check_point:存储实验生成的模型。
说明 从原PAI-Studio模板创建实验后,必须手动将TensorFlow组件的 checkpoint输出目录/模型输入目录参数配置为已有的OSS文件夹路径,整个实验才能运行。本实验中,将 checkpoint输出目录/模型输入目录配置为 check_point文件夹路径。
- cifar-10-batches-py:存储训练数据集和预测数据集对应的数据源文件cifar-10-batcher-py和预测集文件bird_mount_bluebird.jpg。
- train_code:存储训练数据,即cifar_pai.py。
- predict_code:存储cifar_predict_pai.py。
使用TensorFlow实现图片分类
以原PAI-Studio的操作步骤为例,来说明如何使用TensorFlow实现图片分类,具体操作步骤如下。
- 进入PAI-Designer页面。
- 登录PAI控制台。
- 在左侧导航栏单击工作空间列表,在工作空间列表页面中单击待操作的工作空间名称,进入对应工作空间内。
- 在工作空间页面的左侧导航栏选择 ,进入Designer页面。
- 在可视化建模(Designer)页面右上方,单击前往旧版可视化建模(Studio)。
- 构建实验。
- 运行实验并查看输出结果。
- 单击画布上方的运行。
- 实验运行结束后,您可以在配置的OSS路径(checkpoint输出目录/模型输入目录)下查看预测结果。
训练代码解析
针对
cifar_pai.py文件中的关键代码进行解析:
- 构建CNN图片训练模型
network = input_data(shape=[None, 32, 32, 3], data_preprocessing=img_prep, data_augmentation=img_aug) network = conv_2d(network, 32, 3, activation='relu') network = max_pool_2d(network, 2) network = conv_2d(network, 64, 3, activation='relu') network = conv_2d(network, 64, 3, activation='relu') network = max_pool_2d(network, 2) network = fully_connected(network, 512, activation='relu') network = dropout(network, 0.5) network = fully_connected(network, 10, activation='softmax') network = regression(network, optimizer='adam', loss='categorical_crossentropy', learning_rate=0.001)
- 训练生成模型model.tfl
model = tflearn.DNN(network, tensorboard_verbose=0) model.fit(X, Y, n_epoch=100, shuffle=True, validation_set=(X_test, Y_test), show_metric=True, batch_size=96, run_id='cifar10_cnn') model_path = os.path.join(FLAGS.checkpointDir, "model.tfl") print(model_path) model.save(model_path)
预测代码解析
针对
cifar_predict_pai.py文件中的关键代码进行解析。首先读入图片
bird_bullocks_oriole.jpg,将其调整为32*32像素大小。然后传入
model.predict预测函数评分,返回这张图片对应的十种分类
[‘airplane’,’automobile’,’bird’,’cat’,’deer’,’dog’,’frog’,’horse’,’ship’,’truck’]的权重。最后将权重最高的一个分类作为预测结果返回。
predict_pic = os.path.join(FLAGS.buckets, "bird_bullocks_oriole.jpg")
img_obj = file_io.read_file_to_string(predict_pic)
file_io.write_string_to_file("bird_bullocks_oriole.jpg", img_obj)
img = scipy.ndimage.imread("bird_bullocks_oriole.jpg", mode="RGB")
# Scale it to 32x32
img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')
# Predict
prediction = model.predict([img])
print (prediction[0])
print (prediction[0])
#print (prediction[0].index(max(prediction[0])))
num=['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
print ("This is a %s"%(num[prediction[0].index(max(prediction[0]))]))