“AI橡皮擦”:基于PAI-DSW实现医学影像去噪与复原
本文中含有需要您注意的重要提示信息,忽略该信息可能对您的业务造成影响,请务必仔细阅读。
本实验将带您探索深度学习在图像质量增强领域的强大能力。您将在阿里云PAI-DSW环境中,训练一个先进的、基于Transformer架构的模型(Restormer),使其能够像一块“AI橡皮擦”一样,去除医学影像中的噪声和模糊,并将其复原为清晰的高质量图像。
实验简介
本实验将带您探索深度学习在图像质量增强领域的强大能力。您将在阿里云PAI-DSW环境中,训练一个先进的、基于Transformer架构的模型(Restormer),使其能够像一块“AI橡皮擦”一样,去除医学影像中的噪声和模糊,并将其复原为清晰的高质量图像。
背景知识
医学影像配准 (Medical Image Registration): 在医学影像的采集和传输过程中,常常会因为设备抖动、电流干扰或压缩等原因,引入噪声、模糊等“瑕疵”,这些瑕疵会严重影响医生的观察和诊断。影像复原技术的目标就是利用算法自动地去除这些瑕疵,重建出高质量的清晰图像,对于提升诊断准确率至关重要。
Transformer模型: Transformer最初在自然语言处理领域取得了巨大成功,近年来也被广泛应用于计算机视觉任务。其核心优势在于强大的“全局注意力机制”,能够捕捉图像中长距离的依赖关系,这使得它在理解复杂的图像纹理和结构、并进行精细化重建方面,比传统卷积网络更具潜力。
Restormer模型: Restormer是专门为高分辨率图像复原任务设计的一种高效Transformer模型。它通过巧妙的网络结构设计,在保持强大复原能力的同时,显著降低了计算复杂度,使其成为图像去噪、去模糊、去雨等多种复原任务的领先解决方案。
PAI-DSW: PAI-DSW是一个为开发者量身打造的云端深度学习开发环境。它预置了主流的深度学习框架,并提供了高性能的计算资源(如GPU),用户可以通过其交互式的JupyterLab界面,便捷地完成AI项目的开发、训练和评估。
实验室资源方式简介
进入实操前,请确保阿里云账号满足以下条件:
个人账号资源
使用您个人的云资源进行操作,资源归属于个人。
平台仅提供手册参考,不会对资源做任何操作。
确保已完成云工开物 300 元代金券领取。
已通过实名认证且账户余额 ≥0 元。
在实验页面,当您已阅读并同意上述创建资源的目的以及部分资源可能产生的计费规则。
资源消耗说明
本场景主要涉及以下云产品和服务:PAI、对象存储OSS。
本实验预计产生资源消耗:约10元(以使用ecs.gn6i-c8g1.2xlarge规格的PAI-DSW实例进行1小时的数据处理与模型训练为例估算)。
如果您调整了资源规格、延长了使用时长,或执行了本方案以外的操作,可能导致费用发生变化,请以控制台显示的实际价格和最终账单为准。
PAI-DSW: 费用主要由DSW实例的运行时长和其规格决定。本实验选用GPU实例进行模型训练,关机后即停止计费。
对象存储 OSS: 费用由数据存储容量和少量外网下行流量(仅在下载结果时产生)决定。
领取专属权益及创建实验资源
第一步:在开始实验之前,请先点击右侧屏幕的“进入实操”再进行后续操作

第二步:本次实验需要您通过领取阿里云云工开物学生专属300元抵扣券兑换本次实操的云资源,如未领取请先点击领取。(若已领取请跳过)
重要实验产生的费用优先使用优惠券,优惠券使用完毕后需您自行承担。

实验步骤
进入DSW控制台
登录阿里云,进入机器学习PAI控制台,在左侧导航栏选择【工作空间列表】,点击进入您的工作空间

在工作空间内,选择左侧的【模型开发与训练】—【DSW(Data Science Workshop)】

创建DSW实例
点击【创建实例】

实例名称:自定义一个名称,如 medical-2d-regist
资源组(机型):为了进行深度学习模型训练,我们需要选择GPU实例。点击【筛选】,勾选【GPU】,然后选择一款有库存的GPU机型,例如 ecs.gn6i-c8g1.2xlarge(vCPU: 8, 内存: 32GiB, GPU: NVIDIA T4 16GB)
说明这是成本消耗的主要来源,请务必注意实验后及时停止或删除实例!

镜像:选择一个预置了PyTorch框架的镜像,例如 pytorch:1.12-gpu-py39-cu113-ubuntu20.04

其他保持默认,选择完成后点击【确定】

等待约2-3分钟,直到实例状态变为“运行中”

进入JupyterLab环境
在DSW实例列表中,找到刚刚创建的实例,点击右侧的【打开】
返回JupyterLab的启动器(Launcher)页面,点击【Python 3 (PyTorch 1.12)】
创建一个新的Notebook文件

安装MONAI并导入环境
from monai.utils import set_determinism, first from monai.transforms import Compose, LoadImageD, EnsureChannelFirstD, ScaleIntensityd, RandGaussianNoiseD, RandGaussianSmoothD from monai.data import DataLoader, Dataset, CacheDataset from monai.networks.nets.restormer import Restormer from monai.apps import MedNISTDataset from monai.losses import SSIMLoss import os import torch import matplotlib.pyplot as plt import tempfile set_determinism(42) # 设置随机种子以保证实验结果可复现准备训练数据:创建“清晰-退化”图像对
我们将自动下载MedNIST数据集,并筛选出手部X光片。
- 说明
核心思想: 对于每一张清晰的原始图像,我们都需要一个与之对应的“退化”版本作为模型的输入。因此,我们会对原始图像进行人工“加戏”——随机地为其添加高斯噪声和高斯模糊,来模拟真实的低质量图像。
root_dir = tempfile.mkdtemp() # 创建一个临时目录来存放数据 train_data = MedNISTDataset(root_dir=root_dir, section="training", download=True, transform=None) # 筛选出手部X光片,并创建字典 training_datadict = [ {"original_hand": item["image"], "noisy_hand": item["image"]} for item in train_data.data if item["label"] == 4 # label 4 对应手部X光片 ] print(f"数据集已下载,共加载 {len(training_datadict)} 张手部X光片。")接下来定义数据预处理流程。RandGaussianNoiseD和RandGaussianSmoothD将负责对noisy_hand进行随机退化操作。
img_keys = ["original_hand", "noisy_hand"] degradation_key = "noisy_hand" train_transforms = Compose([ LoadImageD(keys=img_keys), EnsureChannelFirstD(keys=img_keys), ScaleIntensityd(keys=img_keys), # 归一化到[0, 1] # --- 只对 noisy_hand 进行退化处理 --- RandGaussianNoiseD(keys=[degradation_key], prob=0.5, std=0.1), RandGaussianSmoothD(keys=[degradation_key], prob=0.5, sigma_x=(0.5, 1.5), sigma_y=(0.5, 1.5)), ])
可视化数据样本
让我们随机抽取一组处理后的图像,直观地看一下退化效果。
check_ds = Dataset(data=training_datadict, transform=train_transforms) check_loader = DataLoader(check_ds, batch_size=1, shuffle=True) check_data = first(check_loader) original_image = check_data["original_hand"][0][0] noisy_image = check_data["noisy_hand"][0][0] plt.figure("check", (12, 6)) plt.subplot(1, 2, 1) plt.title("Degraded Image (Input)") plt.imshow(noisy_image, cmap="gray") plt.subplot(1, 2, 2) plt.title("Original Image (Target)") plt.imshow(original_image, cmap="gray") plt.show()
- 重要
重点说明: 为了让模型学习得更好(提升泛化能力),我们会对训练数据进行“数据增强”,即在每次训练时对图片进行随机的旋转、翻转和缩放,模拟真实世界中可能存在的各种变化。
# 为训练集定义带数据增强的变换 train_transforms = Compose([ LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity(), RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True), RandFlip(spatial_axis=0, prob=0.5), RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), ]) # 验证和测试集不需要数据增强 val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])
定义模型并开始训练
我们初始化一个轻量级的Restormer模型。
- 重要
重点说明: 在图像复原任务中,除了常用的MSE损失,SSIMLoss(结构相似性损失)是另一个非常有效的选择。它更关注图像的结构、亮度和对比度信息,而不是单纯的像素值差异,这有助于生成视觉效果更好的复原图像。
# 创建数据集加载器 train_ds = CacheDataset(data=training_datadict[:1000], transform=train_transforms, cache_rate=1.0) train_loader = DataLoader(train_ds, batch_size=16, shuffle=True) # 初始化模型、损失函数和优化器 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Restormer( spatial_dims=2, in_channels=1, out_channels=1, dim=32, num_blocks=[2, 2], heads=[2, 2], num_refinement_blocks=1, ).to(device) image_loss = SSIMLoss(spatial_dims=2, data_range=1.0) optimizer = torch.optim.Adam(model.parameters(), 1e-4) # 稍微提高学习率 # 开始训练 max_epochs = 50 epoch_loss_values = [] for epoch in range(max_epochs): model.train() epoch_loss = 0 step = 0 for batch_data in train_loader: step += 1 optimizer.zero_grad() noisy = batch_data["noisy_hand"].to(device) original = batch_data["original_hand"].to(device) pred_image = model(noisy) pred_image = torch.sigmoid(pred_image) # 确保输出在[0, 1]范围 loss = 1 - image_loss(input=pred_image, target=original) # SSIM值越大越好,所以用1-SSIM作为损失 loss.backward() optimizer.step() epoch_loss += loss.item() epoch_loss /= step epoch_loss_values.append(epoch_loss) if (epoch + 1) % 5 == 0: # 每5个epoch打印一次日志 print(f"Epoch {epoch + 1}/{max_epochs}, Average Loss: {epoch_loss:.4f}") print("训练完成!") 
在验证数据上评估模型效果
训练完成后,我们用模型去复原一些它从未见过的、同样经过随机退化的图像,并与原始清晰图像进行对比,直观地检验模型的复原能力。
# 准备验证数据 val_ds = CacheDataset(data=training_datadict[2000:2500], transform=train_transforms, cache_rate=1.0) val_loader = DataLoader(val_ds, batch_size=16) # 获取一个批次的预测结果 model.eval() with torch.no_grad(): for batch_data in val_loader: noisy = batch_data["noisy_hand"].to(device) original = batch_data["original_hand"].to(device) pred_image = model(noisy) pred_image = torch.sigmoid(pred_image) break # 只取第一个batch进行可视化 # 将数据转为numpy用于绘图 original_np = original.detach().cpu().numpy()[:, 0] noisy_np = noisy.detach().cpu().numpy()[:, 0] pred_np = pred_image.detach().cpu().numpy()[:, 0] # 绘图对比 num_to_show = 5 plt.figure(figsize=(9, 3 * num_to_show)) for i in range(num_to_show): plt.subplot(num_to_show, 3, i * 3 + 1) if i == 0: plt.title("Degraded Image") plt.imshow(noisy_np[i], cmap="gray") plt.axis("off") plt.subplot(num_to_show, 3, i * 3 + 2) if i == 0: plt.title("Restored Image") plt.imshow(pred_np[i], cmap="gray") plt.axis("off") plt.subplot(num_to_show, 3, i * 3 + 3) if i == 0: plt.title("Original Image") plt.imshow(original_np[i], cmap="gray") plt.axis("off") plt.tight_layout() plt.show()
清理资源
为避免产生不必要的个人扣费,实验完成后请务必按照以下步骤清理所有资源!
释放PAI-DSW实例
返回 机器学习PAI控制台 的DSW实例列表页面
找到本次实验创建的实例,点击右侧的【停止】

等待实例状态变为“已停止”后,为确保完全释放,再次点击右侧的【...】更多操作,选择【删除】

在弹出的确认框中点击【停止实例】/【删除实例】

等待一段时间检查是否删除成功

删除OSS数据和Bucket
进入 对象存储OSS控制台,找到为本次实验创建的Bucket,点击进入
选中所有上传的数据文件和文件夹,点击【删除】,返回Bucket列表,选中该Bucket,点击【删除】,根据提示完成删除操作(可能需要清空碎片)

关闭实验
在完成实验后,点击 结束实操

点击 取消 回到实验页面,点击 确定 跳转实验评分

请为本次实验评分,并给出您的建议,点击 确认,结束本次实验
