深入三维视界:基于PAI-DSW的3D脑部MRI影像分类

更新时间:
复制为 MD 格式

本实验将是您从二维图像分析迈向三维空间理解的关键一步。您将在阿里云PAI-DSW环境中,学习如何处理并分析完整的3D医学影像数据(脑部MRI扫描),并训练一个三维深度学习模型(3D DenseNet),以完成一项有趣的分类任务:根据大脑的MRI扫描来判断其所属的性别。

实验简介

本实验将是您从二维图像分析迈向三维空间理解的关键一步。您将在阿里云PAI-DSW环境中,学习如何处理并分析完整的3D医学影像数据(脑部MRI扫描),并训练一个三维深度学习模型(3D DenseNet),以完成一项有趣的分类任务:根据大脑的MRI扫描来判断其所属的性别。

背景知识

  1. 3D医学影像分类: 与之前处理的2D图像(可视为“照片”)不同,3D医学影像(如CT、MRI)是由成百上千张连续的2D切片组成的“数据体”,它包含了完整的三维空间和解剖结构信息。3D分类任务直接将整个数据体作为输入,能够让模型从立体的角度学习特征,这对于需要全局空间信息的诊断任务至关重要。

  2. MRI (磁共振成像): MRI是一种先进的医学成像技术,它不使用X射线,而是利用强大的磁场和无线电波来生成身体内部结构的详细图像。尤其在脑部、关节等软组织成像方面,MRI具有极高的分辨率和清晰度,是神经科学研究和临床诊断的重要工具。

  3. 3D DenseNet模型: 这是经典DenseNet模型在三维空间的扩展。通过将2D卷积、池化等操作替换为对应的3D版本,模型获得了直接处理三维数据体(Volume)的能力。其密集的特征连接特性在处理信息量巨大的3D数据时依然表现出色。

  4. 模型可解释性 (Occlusion Sensitivity):“AI模型做出决策”的过程常常像一个“黑箱”。模型可解释性技术就是为了打开这个黑箱。遮挡敏感度(Occlusion Sensitivity)是一种直观的可视化方法,它通过依次遮挡输入图像的不同区域,并观察模型预测概率的变化,来判断哪些区域对最终决策的贡献最大。

  5. PAI-DSW: PAI-DSW是一个为开发者量身打造的云端深度学习开发环境。

实验室资源方式简介

进入实操前,请确保阿里云账号满足以下条件:

  • 个人账号资源

    • 使用您个人的云资源进行操作,资源归属于个人。

    • 平台仅提供手册参考,不会对资源做任何操作。

  • 确保已完成云工开物 300 元代金券领取。

  • 已通过实名认证且账户余额 ≥0 元。

  • 在实验页面,当您已阅读并同意上述创建资源的目的以及部分资源可能产生的计费规则。

资源消耗说明

本场景主要涉及以下云产品和服务:PAI、对象存储OSS。

本实验预计产生资源消耗:约10元(以使用ecs.gn6i-c8g1.2xlarge规格的PAI-DSW实例进行1小时的数据处理与模型训练为例估算)。

如果您调整了资源规格、延长了使用时长,或执行了本方案以外的操作,可能导致费用发生变化,请以控制台显示的实际价格和最终账单为准。

  • PAI-DSW: 费用主要由DSW实例的运行时长和其规格决定。本实验选用GPU实例进行模型训练,关机后即停止计费。

  • 对象存储 OSS: 费用由数据存储容量和少量外网下行流量(仅在下载结果时产生)决定。

领取专属权益及创建实验资源

  • 第一步:在开始实验之前,请先点击右侧屏幕的“进入实操”再进行后续操作

    image

  • 第二步:本次实验需要您通过领取阿里云云工开物学生专属300元抵扣券兑换本次实操的云资源,如未领取请先点击领取。(若已领取请跳过)

    image

    重要

    实验产生的费用优先使用优惠券,优惠券使用完毕后需您自行承担。

    学生认证

实验步骤

  1. 进入DSW控制台

    • 登录阿里云,进入机器学习PAI控制台,在左侧导航栏选择【工作空间列表】,点击进入您的工作空间

      image

    • 在工作空间内,选择左侧的【模型开发与训练】—【DSW(Data Science Workshop)

      image

  2. 创建DSW实例

    • 点击【创建实例】

      image

      • 实例名称:自定义一个名称,如 medical-2d-regist

      • 资源组(机型):为了进行深度学习模型训练,我们需要选择GPU实例。点击【筛选】,勾选【GPU】,然后选择一款有库存的GPU机型,例如 ecs.gn6i-c8g1.2xlarge(vCPU: 8, 内存: 32GiB, GPU: NVIDIA T4 16GB)

        说明

        这是成本消耗的主要来源,请务必注意实验后及时停止或删除实例!

        image

      • 镜像:选择一个预置了PyTorch框架的镜像,例如 pytorch:1.12-gpu-py39-cu113-ubuntu20.04

        image

    • 其他保持默认,选择完成后点击【确定】

      image

    • 等待约2-3分钟,直到实例状态变为“运行中”

      image

  3. 进入JupyterLab环境

    • DSW实例列表中,找到刚刚创建的实例,点击右侧的【打开】

    • 返回JupyterLab的启动器(Launcher)页面,点击【Python 3 (PyTorch 1.12)】

    • 创建一个新的Notebook文件

    image

  4. 安装MONAI并导入环境

    • Notebook的第一个代码单元格中,运行以下代码来安装必要的AI开发库。重点: nibabel是专门用于读取.nii.gz3D医学影像格式文件的库。

      !pip install -q "monai-weekly[nibabel, tqdm]"
    • 在下一个单元格中,导入本实验所需的所有Python库。

      import os
      import sys
      import tempfile
      import matplotlib.pyplot as plt
      import torch
      import numpy as np
      import monai
      from monai.apps import download_and_extract
      from monai.data import DataLoader, ImageDataset
      from monai.transforms import Compose, EnsureChannelFirst, RandRotate90, Resize, ScaleIntensity
      
      # 确定运行设备
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      print(f"Using device: {device}")
      
      
  5. 下载并准备IXI脑部MRI数据集

    • 说明

      本实验使用IXI脑部MRI公开数据集的一个子集。我们将通过代码自动下载并解压

      root_dir = tempfile.mkdtemp()
      resource = "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/IXI-T1.tar"
      md5 = "34901a0593b41dd19c1a1f746eac2d58"
      dataset_dir = os.path.join(root_dir, "ixi")
      tarfile_name = f"{dataset_dir}.tar"
      
      if not os.path.exists(dataset_dir):
          download_and_extract(resource, tarfile_name, root_dir, md5)
      
      # 为方便演示,教程已预先定义好少量图像的路径和标签
      images = [os.path.join(dataset_dir, f) for f in os.listdir(dataset_dir) if f.endswith('.nii.gz')][:20]
      # 0代表男性, 1代表女性 (标签为示例,非真实)
      labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])
      
      print(f"数据集已准备完成,共 {len(images)} 个3D MRI扫描。")
      	
      
      
  6. 定义数据预处理流程和加载器

    • 重要

      重点说明:处理3D数据时,一个关键步骤是Resize。由于每个病人的扫描尺寸可能不同,我们需要将所有输入的3D数据体重采样到统一的尺寸(例如96x96x96),才能送入模型进行训练。

      # 定义训练集的变换(包含随机旋转的数据增强)
      train_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize((96, 96, 96)), RandRotate90()])
      # 定义验证集的变换(无数据增强)
      val_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize((96, 96, 96))])
      
      # 创建训练和验证数据集
      train_ds = ImageDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
      val_ds = ImageDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
      
      # 创建数据加载器
      train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)
      val_loader = DataLoader(val_ds, batch_size=2, num_workers=2)
      
      
  7. 定义模型并开始训练

    • 我们初始化一个DenseNet121模型,并通过spatial_dims=3参数告知模型它将要处理的是3D数据

      model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
      loss_function = torch.nn.CrossEntropyLoss()
      optimizer = torch.optim.Adam(model.parameters(), 1e-4)
      
      max_epochs = 5 # 为快速演示,只训练5个周期
      best_metric = -1
      best_metric_epoch = -1
      
      for epoch in range(max_epochs):
          print(f"\n--- Epoch {epoch + 1}/{max_epochs} ---")
          model.train()
          for batch_data in train_loader:
              inputs, labels = batch_data[0].to(device), torch.nn.functional.one_hot(torch.as_tensor(batch_data[1]), num_classes=2).float().to(device)
              optimizer.zero_grad()
              outputs = model(inputs)
              loss = loss_function(outputs, labels)
              loss.backward()
              optimizer.step()
          
          # 在验证集上评估
          model.eval()
          num_correct = 0.0
          metric_count = 0
          with torch.no_grad():
              for val_data in val_loader:
                  val_images, val_labels = val_data[0].to(device), torch.as_tensor(val_data[1]).to(device)
                  val_outputs = model(val_images)
                  value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                  metric_count += len(value)
                  num_correct += value.sum().item()
              
              metric = num_correct / metric_count
              if metric > best_metric:
                  best_metric = metric
                  best_metric_epoch = epoch + 1
                  torch.save(model.state_dict(), "best_metric_model_3d.pth")
      
              print(f"Validation Accuracy: {metric:.4f}, Best Accuracy: {best_metric:.4f}")
      
      print(f"\nTraining completed. Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")
      
      
  8. 可视化模型决策依据:遮挡敏感度分析

    • 训练完成后,我们不仅关心模型“对不对”,还想知道它“为什么对”。我们加载最佳模型,并使用遮挡敏感度技术来可视化模型在进行分类时,最关注大脑的哪些区域。

      model.load_state_dict(torch.load("best_metric_model_3d.pth"))
      # 从验证集中取一个样本
      img, label = val_ds[0]
      img = img.unsqueeze(0).to(device)
      
      # 初始化遮挡敏感度分析器
      occ_sens = monai.visualize.OcclusionSensitivity(nn_module=model, mask_size=12, n_batch=10)
      
      # 为节省时间,我们只分析中间一个切片的热力图
      depth_slice = img.shape[2] // 2
      occ_sens_b_box = [depth_slice - 1, depth_slice, -1, -1, -1, -1]
      
      # 运行分析
      occ_result, _ = occ_sens(x=img, b_box=occ_sens_b_box)
      
      # 获取对应正确类别的热力图
      true_label_index = label.item()
      occ_result_slice = occ_result[0, true_label_index]
      
      # 绘图
      fig, axes = plt.subplots(1, 2, figsize=(12, 6))
      axes[0].set_title("Original MRI Slice")
      axes[0].imshow(img.cpu().numpy()[0, 0, depth_slice, :, :], cmap="gray")
      axes[0].axis("off")
      
      axes[1].set_title("Occlusion Sensitivity Heatmap")
      im = axes[1].imshow(occ_result_slice.detach().cpu(), cmap="jet")
      axes[1].axis("off")
      fig.colorbar(im, ax=axes[1])
      plt.show()
      
      
    • image

清理资源

警告

为避免产生不必要的个人扣费,实验完成后请务必按照以下步骤清理所有资源!

  1. 释放PAI-DSW实例

    • 返回 机器学习PAI控制台 的DSW实例列表页面

    • 找到本次实验创建的实例,点击右侧的【停止】

      image

    • 等待实例状态变为“已停止”后,为确保完全释放,再次点击右侧的【...】更多操作,选择【删除】

      image

    • 在弹出的确认框中点击【停止实例】/【删除实例】

      image

    • 等待一段时间检查是否删除成功

      image

  2. 删除OSS数据和Bucket

    • 进入 对象存储OSS控制台,找到为本次实验创建的Bucket,点击进入

    • 选中所有上传的数据文件和文件夹,点击【删除】,返回Bucket列表,选中该Bucket,点击【删除】,根据提示完成删除操作(可能需要清空碎片)

      image

关闭实验

  • 在完成实验后,点击 结束实操

    image

  • 点击 取消 回到实验页面,点击 确定 跳转实验评分

    image

  • 请为本次实验评分,并给出您的建议,点击 确认,结束本次实验

    image