评价此页

使用 PyTorch 和 TIAToolbox 进行全视野切片图像分类#

创建于:2023年12月19日 | 最后更新:2025年6月4日 | 最后验证:2024年11月5日

提示

为充分利用本教程,我们建议使用此Colab 版本。这将使您能够对下面介绍的信息进行实验。

简介#

在本教程中,我们将展示如何借助 TIAToolbox,使用 PyTorch 深度学习模型对全视野切片图像 (WSI) 进行分类。WSI 是通过手术或活检获取的人体组织样本的图像,并使用专用扫描仪进行扫描。病理学家和计算病理学研究人员使用它们来在微观层面研究癌症等疾病,以便了解例如肿瘤的生长情况,并帮助改善患者的治疗。

WSI 处理起来具有挑战性,原因在于其巨大的尺寸。例如,一张典型的切片图像大约有100,000x100,000 像素,其中每个像素可能对应切片上约 0.25x0.25 微米。这给加载和处理此类图像带来了挑战,更不用说在单个研究中处理数百甚至数千张 WSI(更大的研究能产生更好的结果)!

传统的图像处理流程不适用于 WSI 处理,因此我们需要更好的工具。这正是 TIAToolbox 发挥作用的地方,因为它提供了一套有用的工具,可以快速且计算高效地导入和处理组织切片。通常,WSI 以金字塔结构保存,其中包含同一图像在不同放大倍率下的多个副本,以便于可视化。金字塔的 0 级(或底层)包含最高放大倍率或缩放级别的图像,而金字塔的更高层则包含基础图像的较低分辨率副本。金字塔结构示意图如下。

WSI 金字塔堆栈 WSI 金字塔堆栈(来源

TIAToolbox 允许我们自动化常见的下游分析任务,例如组织分类。在本教程中,我们将向您展示如何:1. 使用 TIAToolbox 加载 WSI 图像;2. 使用不同的 PyTorch 模型在补丁级别对切片进行分类。在本教程中,我们将提供一个使用 TorchVision ResNet18 模型和自定义 HistoEncoder <jopo666/HistoEncoder>`__ 模型的示例。

让我们开始吧!

设置环境#

要运行本教程中提供的示例,需要以下软件包作为先决条件。

  1. OpenJpeg

  2. OpenSlide

  3. Pixman

  4. TIAToolbox

  5. HistoEncoder (用于自定义模型示例)

请在您的终端中运行以下命令来安装这些软件包

apt-get -y -qq install libopenjp2-7-dev libopenjp2-tools openslide-tools libpixman-1-dev pip install -q ‘tiatoolbox<1.5’ histoencoder && echo “安装完成。”

或者,您可以在 MacOS 上运行 brew install openjpeg openslide 来安装先决条件软件包,而不是使用 apt-get。更多关于安装的信息可以在这里找到。

运行前清理#

为确保适当的清理(例如在异常终止时),本次运行中下载或创建的所有文件都保存在单个目录 global_save_dir 中,我们将其设置为“./tmp/”。为了简化维护,该目录的名称仅在此处出现一次,因此如果需要,可以轻松更改。

warnings.filterwarnings("ignore")
global_save_dir = Path("./tmp/")


def rmdir(dir_path: str | Path) -> None:
    """Helper function to delete directory."""
    if Path(dir_path).is_dir():
        shutil.rmtree(dir_path)
        logger.info("Removing directory %s", dir_path)


rmdir(global_save_dir)  # remove  directory if it exists from previous runs
global_save_dir.mkdir()
logger.info("Creating new directory %s", global_save_dir)

下载数据#

对于我们的示例数据,我们将使用一张全视野切片图像,以及来自 Kather 100k 数据集验证子集的补丁。

wsi_path = global_save_dir / "sample_wsi.svs"
patches_path = global_save_dir / "kather100k-validation-sample.zip"
weights_path = global_save_dir / "resnet18-kather100k.pth"

logger.info("Download has started. Please wait...")

# Downloading and unzip a sample whole-slide image
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/sample_wsis/TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F.svs",
    wsi_path,
)

# Download and unzip a sample of the validation set used to train the Kather 100K dataset
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/datasets/kather100k-validation-sample.zip",
    patches_path,
)
with ZipFile(patches_path, "r") as zipfile:
    zipfile.extractall(path=global_save_dir)

# Download pretrained model weights for WSI classification using ResNet18 architecture
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth",
    weights_path,
)

logger.info("Download is complete.")

读取数据#

我们创建一个补丁列表和相应标签的列表。例如,label_list 中的第一个标签将指示 patch_list 中第一个图像补丁的类别。

# Read the patch data and create a list of patches and a list of corresponding labels
dataset_path = global_save_dir / "kather100k-validation-sample"

# Set the path to the dataset
image_ext = ".tif"  # file extension of each image

# Obtain the mapping between the label ID and the class name
label_dict = {
    "BACK": 0, # Background (empty glass region)
    "NORM": 1, # Normal colon mucosa
    "DEB": 2,  # Debris
    "TUM": 3,  # Colorectal adenocarcinoma epithelium
    "ADI": 4,  # Adipose
    "MUC": 5,  # Mucus
    "MUS": 6,  # Smooth muscle
    "STR": 7,  # Cancer-associated stroma
    "LYM": 8,  # Lymphocytes
}

class_names = list(label_dict.keys())
class_labels = list(label_dict.values())

# Generate a list of patches and generate the label from the filename
patch_list = []
label_list = []
for class_name, label in label_dict.items():
    dataset_class_path = dataset_path / class_name
    patch_list_single_class = grab_files_from_dir(
        dataset_class_path,
        file_types="*" + image_ext,
    )
    patch_list.extend(patch_list_single_class)
    label_list.extend([label] * len(patch_list_single_class))

# Show some dataset statistics
plt.bar(class_names, [label_list.count(label) for label in class_labels])
plt.xlabel("Patch types")
plt.ylabel("Number of patches")

# Count the number of examples per class
for class_name, label in label_dict.items():
    logger.info(
        "Class ID: %d -- Class Name: %s -- Number of images: %d",
        label,
        class_name,
        label_list.count(label),
    )

# Overall dataset statistics
logger.info("Total number of patches: %d", (len(patch_list)))
tiatoolbox tutorial
|2023-11-14|13:15:59.299| [INFO] Class ID: 0 -- Class Name: BACK -- Number of images: 211
|2023-11-14|13:15:59.299| [INFO] Class ID: 1 -- Class Name: NORM -- Number of images: 176
|2023-11-14|13:15:59.299| [INFO] Class ID: 2 -- Class Name: DEB -- Number of images: 230
|2023-11-14|13:15:59.299| [INFO] Class ID: 3 -- Class Name: TUM -- Number of images: 286
|2023-11-14|13:15:59.299| [INFO] Class ID: 4 -- Class Name: ADI -- Number of images: 208
|2023-11-14|13:15:59.299| [INFO] Class ID: 5 -- Class Name: MUC -- Number of images: 178
|2023-11-14|13:15:59.299| [INFO] Class ID: 6 -- Class Name: MUS -- Number of images: 270
|2023-11-14|13:15:59.299| [INFO] Class ID: 7 -- Class Name: STR -- Number of images: 209
|2023-11-14|13:15:59.299| [INFO] Class ID: 8 -- Class Name: LYM -- Number of images: 232
|2023-11-14|13:15:59.299| [INFO] Total number of patches: 2000

如您所见,对于这个补丁数据集,我们有 9 个类别/标签,ID 为 0-8,并附有相关的类别名称,描述了补丁中的主要组织类型。

  • BACK ⟶ 背景(空的玻璃区域)

  • LYM ⟶ 淋巴细胞

  • NORM ⟶ 正常结肠粘膜

  • DEB ⟶ 碎屑

  • MUS ⟶ 平滑肌

  • STR ⟶ 癌症相关基质

  • ADI ⟶ 脂肪组织

  • MUC ⟶ 黏液

  • TUM ⟶ 结直肠腺癌上皮

对图像补丁进行分类#

我们演示如何首先使用 patch 模式为数字切片中的每个补丁获取预测,然后使用 wsi 模式处理大尺寸切片。

定义 PatchPredictor 模型#

PatchPredictor 类运行一个用 PyTorch 编写的基于 CNN 的分类器。

  • model 可以是任何训练好的 PyTorch 模型,但有一个约束条件,即它必须遵循 tiatoolbox.models.abc.ModelABC (docs) <https://tia-toolbox.readthedocs.io/en/latest/_autosummary/tiatoolbox.models.models_abc.ModelABC.html>`__ 的类结构。有关此问题的更多信息,请参阅我们关于高级建模技术的示例笔记本。为了加载自定义模型,您需要编写一个小的预处理函数,如 preproc_func(img),以确保输入张量对于加载的网络是正确的格式。

  • 或者,您可以将 pretrained_model 作为字符串参数传递。这指定了执行预测的 CNN 模型,它必须是此处列出的模型之一。命令将如下所示:predictor = PatchPredictor(pretrained_model='resnet18-kather100k', pretrained_weights=weights_path, batch_size=32)

  • pretrained_weights:当使用 pretrained_model 时,默认情况下也会下载相应的预训练权重。您可以通过 pretrained_weight 参数使用您自己的一组权重来覆盖默认值。

  • batch_size:每次馈送到模型的图像数量。此参数值越高,需要的(GPU)内存容量就越大。

# Importing a pretrained PyTorch model from TIAToolbox
predictor = PatchPredictor(pretrained_model='resnet18-kather100k', batch_size=32)

# Users can load any PyTorch model architecture instead using the following script
model = vanilla.CNNModel(backbone="resnet18", num_classes=9) # Importing model from torchvision.models.resnet18
model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True), strict=True)
def preproc_func(img):
    img = PIL.Image.fromarray(img)
    img = transforms.ToTensor()(img)
    return img.permute(1, 2, 0)
model.preproc_func = preproc_func
predictor = PatchPredictor(model=model, batch_size=32)

预测补丁标签#

我们创建一个预测器对象,然后使用 patch 模式调用 predict 方法。然后我们计算分类准确率和混淆矩阵。

with suppress_console_output():
    output = predictor.predict(imgs=patch_list, mode="patch", on_gpu=ON_GPU)

acc = accuracy_score(label_list, output["predictions"])
logger.info("Classification accuracy: %f", acc)

# Creating and visualizing the confusion matrix for patch classification results
conf = confusion_matrix(label_list, output["predictions"], normalize="true")
df_cm = pd.DataFrame(conf, index=class_names, columns=class_names)
df_cm
|2023-11-14|13:16:03.215| [INFO] Classification accuracy: 0.993000
BACK NORM DEB TUM ADI MUC MUS STR LYM
BACK 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.00000
NORM 0.000000 0.988636 0.000000 0.011364 0.000000 0.000000 0.000000 0.000000 0.00000
DEB 0.000000 0.000000 0.991304 0.000000 0.000000 0.000000 0.000000 0.008696 0.00000
TUM 0.000000 0.000000 0.000000 0.996503 0.000000 0.003497 0.000000 0.000000 0.00000
ADI 0.004808 0.000000 0.000000 0.000000 0.990385 0.000000 0.004808 0.000000 0.00000
MUC 0.000000 0.000000 0.000000 0.000000 0.000000 0.988764 0.000000 0.011236 0.00000
MUS 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.996296 0.003704 0.00000
STR 0.000000 0.000000 0.004785 0.000000 0.000000 0.004785 0.004785 0.985646 0.00000
LYM 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.004310 0.99569


为整个切片预测补丁标签#

现在我们介绍 IOPatchPredictorConfig,这是一个指定模型预测引擎的图像读取和预测写入配置的类。这是必需的,用于告知分类器应读取 WSI 金字塔的哪个级别、处理数据并生成输出。

IOPatchPredictorConfig 的参数定义如下

  • input_resolutions: 一个列表,形式为字典,指定每个输入的分辨率。列表元素必须与目标 model.forward() 中的顺序相同。如果您的模型只接受一个输入,您只需放入一个指定 'units''resolution' 的字典。请注意,TIAToolbox 支持具有多个输入的模型。有关单位和分辨率的更多信息,请参阅 TIAToolbox 文档

  • patch_input_shape:最大输入的形状,格式为(高,宽)。

  • stride_shape:在补丁提取过程中,两个连续补丁之间的步幅(步长)大小。如果用户将 stride_shape 设置为等于 patch_input_shape,则补丁将被提取和处理,没有任何重叠。

wsi_ioconfig = IOPatchPredictorConfig(
    input_resolutions=[{"units": "mpp", "resolution": 0.5}],
    patch_input_shape=[224, 224],
    stride_shape=[224, 224],
)

predict 方法在输入补丁上应用 CNN 并获取结果。以下是参数及其描述

  • mode:要处理的输入类型。根据您的应用选择 patch, tilewsi

  • imgs:输入列表,应为输入瓦片或 WSI 的路径列表。

  • return_probabilities:设置为 True 以获取每个类别的概率以及输入补丁的预测标签。如果您希望合并预测以生成 tilewsi 模式的预测图,您可以设置 return_probabilities=True

  • ioconfig:使用 IOPatchPredictorConfig 类设置 IO 配置信息。

  • resolutionunit(下面未显示):这些参数指定了我们计划提取补丁的 WSI 级别的级别或微米/像素分辨率,可以用来代替 ioconfig。这里我们指定 WSI 级别为 'baseline',这相当于 0 级。通常,这是分辨率最高的级别。在这种特殊情况下,图像只有一个级别。更多信息可以在文档中找到。

  • masks:对应于 imgs 列表中 WSI 掩码的路径列表。这些掩码指定了我们想要从原始 WSI 中提取补丁的区域。如果某个特定 WSI 的掩码被指定为 None,那么该 WSI 的所有补丁(即使是背景区域)的标签都将被预测。这可能导致不必要的计算。

  • merge_predictions:如果需要生成补丁分类结果的 2D 地图,您可以将此参数设置为 True。然而,对于大型 WSI,这将需要较大的可用内存。另一种(默认)解决方案是设置 merge_predictions=False,然后使用 merge_predictions 函数生成 2D 预测地图,您将在后面看到。

由于我们使用的是大型 WSI,补丁提取和预测过程可能需要一些时间(如果您有支持 Cuda 的 GPU 和 PyTorch+Cuda,请确保设置 ON_GPU=True)。

with suppress_console_output():
    wsi_output = predictor.predict(
        imgs=[wsi_path],
        masks=None,
        mode="wsi",
        merge_predictions=False,
        ioconfig=wsi_ioconfig,
        return_probabilities=True,
        save_dir=global_save_dir / "wsi_predictions",
        on_gpu=ON_GPU,
    )

我们通过可视化 wsi_output 来观察预测模型在我们全视野切片图像上的工作情况。我们首先需要合并补丁预测输出,然后将它们作为原始图像上的叠加层进行可视化。和之前一样,merge_predictions 方法用于合并补丁预测。这里我们设置参数 resolution=1.25, units='power' 来生成 1.25 倍放大率下的预测图。如果您想获得更高/更低分辨率(更大/更小)的预测图,您需要相应地更改这些参数。当预测合并后,使用 overlay_patch_prediction 函数将预测图叠加在 WSI 缩略图上,该缩略图应在用于预测合并的分辨率下提取。

overview_resolution = (
    4  # the resolution in which we desire to merge and visualize the patch predictions
)
# the unit of the `resolution` parameter. Can be "power", "level", "mpp", or "baseline"
overview_unit = "mpp"
wsi = WSIReader.open(wsi_path)
wsi_overview = wsi.slide_thumbnail(resolution=overview_resolution, units=overview_unit)
plt.figure(), plt.imshow(wsi_overview)
plt.axis("off")
tiatoolbox tutorial

将预测图叠加在此图像上,如下所示

# Visualization of whole-slide image patch-level prediction
# first set up a label to color mapping
label_color_dict = {}
label_color_dict[0] = ("empty", (0, 0, 0))
colors = cm.get_cmap("Set1").colors
for class_name, label in label_dict.items():
    label_color_dict[label + 1] = (class_name, 255 * np.array(colors[label]))

pred_map = predictor.merge_predictions(
    wsi_path,
    wsi_output[0],
    resolution=overview_resolution,
    units=overview_unit,
)
overlay = overlay_prediction_mask(
    wsi_overview,
    pred_map,
    alpha=0.5,
    label_info=label_color_dict,
    return_ax=True,
)
plt.show()
tiatoolbox tutorial

使用病理学特定模型进行特征提取#

在本节中,我们将展示如何使用 TIAToolbox 提供的 WSI 推理引擎,从一个存在于 TIAToolbox 之外的预训练 PyTorch 模型中提取特征。为了说明这一点,我们将使用 HistoEncoder,这是一个计算病理学特定的模型,通过自监督方式训练,用于从组织学图像中提取特征。该模型可在此处获取

“HistoEncoder:数字病理学的基础模型”(jopo666/HistoEncoder) 作者:赫尔辛基大学的 Pohjonen, Joona 及其团队。

我们将绘制一个将特征图降维到 3D(RGB)的 UMAP 图,以可视化特征如何捕捉上述一些组织类型之间的差异。

# Import some extra modules
import histoencoder.functional as F
import torch.nn as nn

from tiatoolbox.models.engine.semantic_segmentor import DeepFeatureExtractor, IOSegmentorConfig
from tiatoolbox.models.models_abc import ModelABC
import umap

TIAToolbox 定义了一个 ModelABC,这是一个继承自 PyTorch nn.Module 的类,并指定了模型在 TIAToolbox 推理引擎中使用的外观。histoencoder 模型不遵循此结构,因此我们需要将其包装在一个类中,该类的输出和方法是 TIAToolbox 引擎所期望的。

class HistoEncWrapper(ModelABC):
    """Wrapper for HistoEnc model that conforms to tiatoolbox ModelABC interface."""

    def __init__(self: HistoEncWrapper, encoder) -> None:
        super().__init__()
        self.feat_extract = encoder

    def forward(self: HistoEncWrapper, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        """
        out = F.extract_features(self.feat_extract, imgs, num_blocks=2, avg_pool=True)
        return out

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        img_patches_device = batch_data.to('cuda') if on_gpu else batch_data
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        return [output.cpu().numpy()]

现在我们有了包装器,我们将创建我们的特征提取模型并实例化一个DeepFeatureExtractor,以允许我们在 WSI 上使用这个模型。我们将使用与上面相同的 WSI,但这次我们将使用 HistoEncoder 模型从 WSI 的补丁中提取特征,而不是为每个补丁预测某个标签。

# create the model
encoder = F.create_encoder("prostate_medium")
model = HistoEncWrapper(encoder)

# set the pre-processing function
norm=transforms.Normalize(mean=[0.662, 0.446, 0.605],std=[0.169, 0.190, 0.155])
trans = [
    transforms.ToTensor(),
    norm,
]
model.preproc_func = transforms.Compose(trans)

wsi_ioconfig = IOSegmentorConfig(
    input_resolutions=[{"units": "mpp", "resolution": 0.5}],
    patch_input_shape=[224, 224],
    output_resolutions=[{"units": "mpp", "resolution": 0.5}],
    patch_output_shape=[224, 224],
    stride_shape=[224, 224],
)

当我们创建 DeepFeatureExtractor 时,我们将传递 auto_generate_mask=True 参数。这将使用 Otsu 阈值法自动创建组织区域的掩码,以便提取器只处理那些包含组织的补丁。

# create the feature extractor and run it on the WSI
extractor = DeepFeatureExtractor(model=model, auto_generate_mask=True, batch_size=32, num_loader_workers=4, num_postproc_workers=4)
with suppress_console_output():
    out = extractor.predict(imgs=[wsi_path], mode="wsi", ioconfig=wsi_ioconfig, save_dir=global_save_dir / "wsi_features",)

这些特征可以用于训练下游模型,但在这里,为了对特征所代表的内容有一些直观的了解,我们将使用 UMAP 降维将特征可视化在 RGB 空间中。标记为相似颜色的点应该具有相似的特征,因此当我们将 UMAP 降维图叠加在 WSI 缩略图上时,我们可以检查特征是否自然地分离成不同的组织区域。我们将在接下来的单元格中将其与上面的补丁级预测图一起绘制,以比较特征与补丁级预测的情况。

# First we define a function to calculate the umap reduction
def umap_reducer(x, dims=3, nns=10):
    """UMAP reduction of the input data."""
    reducer = umap.UMAP(n_neighbors=nns, n_components=dims, metric="manhattan", spread=0.5, random_state=2)
    reduced = reducer.fit_transform(x)
    reduced -= reduced.min(axis=0)
    reduced /= reduced.max(axis=0)
    return reduced

# load the features output by our feature extractor
pos = np.load(global_save_dir / "wsi_features" / "0.position.npy")
feats = np.load(global_save_dir / "wsi_features" / "0.features.0.npy")
pos = pos / 8 # as we extracted at 0.5mpp, and we are overlaying on a thumbnail at 4mpp

# reduce the features into 3 dimensional (rgb) space
reduced = umap_reducer(feats)

# plot the prediction map the classifier again
overlay = overlay_prediction_mask(
    wsi_overview,
    pred_map,
    alpha=0.5,
    label_info=label_color_dict,
    return_ax=True,
)

# plot the feature map reduction
plt.figure()
plt.imshow(wsi_overview)
plt.scatter(pos[:,0], pos[:,1], c=reduced, s=1, alpha=0.5)
plt.axis("off")
plt.title("UMAP reduction of HistoEnc features")
plt.show()
  • tiatoolbox tutorial
  • UMAP reduction of HistoEnc features

我们看到,我们的补丁级预测器生成的预测图,以及我们的自监督特征编码器生成的特征图,都捕捉到了关于 WSI 中组织类型的相似信息。这是一个很好的健全性检查,表明我们的模型按预期工作。它还表明 HistoEncoder 模型提取的特征正在捕捉不同组织类型之间的差异,因此它们正在编码组织学上相关的信息。

下一步#

在本笔记本中,我们展示了如何使用 PatchPredictorDeepFeatureExtractor 类及其 predict 方法来预测大瓦片和 WSI 补丁的标签或提取特征。我们介绍了 merge_predictionsoverlay_prediction_mask 辅助函数,它们可以合并补丁预测输出,并将生成的预测图作为覆盖层在输入图像/WSI 上进行可视化。

所有过程都在 TIAToolbox 内进行,我们可以轻松地将各个部分组合起来,遵循我们的示例代码。请确保正确设置输入和选项。我们鼓励您进一步研究更改 predict 函数参数对预测输出的影响。我们已经演示了如何在 TIAToolbox 框架中使用您自己的预训练模型或研究社区为特定任务提供的模型,对大型 WSI 进行推理,即使模型结构未在 TIAToolbox 模型类中定义。

您可以通过以下资源了解更多信息