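Building a classification model for fundus_img and oct_img with EfficientNet and ResNet on the GAMMA challenge multi-modality fundus image dataset

Author: Tomoko-hjf

Last updated: February 2023

Abstract: This example tutorial demonstrates how to classify multi-modality fundus images with a two-branch network built from EfficientNet and ResNet.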

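1. Introduction

Glaucoma is an eye disease that causes irreversible vision loss, so timely detection and diagnosis are critical to its treatment.

In this experiment we have clinical data in two modalities: 2D color fundus images (left figure below) and 3D OCT scan volumes (right figure below). The right figure shows only one slice of a 3D volume; each 3D volume actually consists of 256 two-dimensional slices.

Our task is to grade each sample, based on its visual features, into one of three classes: no glaucoma, early glaucoma, or mid-to-advanced glaucoma.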

2. Environment Setup

Import PaddlePaddle and the other packages used for data processing. This tutorial was written against PaddlePaddle 2.4.0.

# Import packages
import os

import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.vision.transforms as trans

import warnings

warnings.filterwarnings("ignore")

print(paddle.__version__)
2.4.0

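3. Dataset

3.1 Data overview

The dataset used in this project comes from Task 1 of the GAMMA challenge held at MICCAI 2021: multi-modality glaucoma grading. It divides glaucoma into three classes: non (no glaucoma), early (early stage), and mid_advanced (moderate or advanced stage).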
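The three classes can be tied to the one-hot labels used later in this tutorial with a small lookup. The sketch below is illustrative only: the helper name one_hot_to_class is hypothetical, and the column order (non, early, mid_advanced) is an assumption that should be checked against the header row of the ground-truth spreadsheet.

```python
# Assumed column order of the one-hot label; verify it against the
# header row of the ground-truth spreadsheet before relying on it.
CLASS_NAMES = ["non", "early", "mid_advanced"]


def one_hot_to_class(one_hot):
    """Return (index, name) for a one-hot label such as [0, 1, 0]."""
    values = [int(v) for v in one_hot]
    expected = [0] * (len(CLASS_NAMES) - 1) + [1]
    if sorted(values) != expected:
        raise ValueError(f"not a valid one-hot label: {one_hot}")
    index = values.index(1)
    return index, CLASS_NAMES[index]


print(one_hot_to_class([0, 1, 0]))  # (1, 'early')
```

The index returned here is the same value that argmax produces on a one-hot row, which is how the dataset class in section 3.4 turns a label into a class index.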

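The dataset contains 200 data pairs covering two clinical imaging modalities: 100 pairs for training and 100 pairs for testing. Each pair in the training set contains a conventional 2D color fundus photograph and a 3D optical coherence tomography (OCT) volume, and each 3D volume consists of 256 two-dimensional slices.

  • You can obtain the data by registering for the PaddlePaddle learning competition: multi-modality glaucoma grading

  • You can also obtain the data by mounting the official dataset grad from the AI Studio community.

Here we mount the official PaddlePaddle dataset.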

3.2 Extracting the dataset

Run the following command to extract the dataset locally.

!tar -xf /home/aistudio/data/data128738/gamma_grading.tar

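3.3 Splitting into training and validation sets

Because the organizers provide only 100 training samples and 100 test samples, the official training set needs to be split further into a training set and a validation set so that model accuracy can be checked during training. Here we split it 8:2.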

# Fraction of the official training set held out for validation
val_ratio = 0.2  # 80 / 20
# Root directory of the training data
trainset_root = "Glaucoma_grading/training/multi-modality_images"
# Label file
gt_file = "Glaucoma_grading/training/glaucoma_grading_training_GT.xlsx"
# Root directory of the test data
testset_root = "Glaucoma_grading/testing/multi-modality_images"
# Read the names of all training samples
filelists = os.listdir(trainset_root)
# Split according to the ratio above
train_filelists, val_filelists = train_test_split(
    filelists, test_size=val_ratio, random_state=42
)
print(
    "Total Nums: {}, train: {}, val: {}".format(
        len(filelists), len(train_filelists), len(val_filelists)
    )
)
Total Nums: 100, train: 80, val: 20
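To make the 8:2 split concrete, here is a minimal standard-library sketch of a seeded shuffle-and-split. It is not what train_test_split does internally and will select different samples; the helper name split_filelist and the folder names are hypothetical. Sorting first is a deliberate choice, since the order returned by os.listdir is not guaranteed.

```python
import random


def split_filelist(names, val_ratio=0.2, seed=42):
    """Shuffle a sorted copy of `names` and split it into (train, val)."""
    # Sort first so the result does not depend on directory listing order.
    shuffled = sorted(names)
    random.Random(seed).shuffle(shuffled)
    n_val = round(len(shuffled) * val_ratio)
    return shuffled[n_val:], shuffled[:n_val]


# Hypothetical folder names standing in for os.listdir(trainset_root).
names = [f"{i:04d}" for i in range(1, 101)]
train_names, val_names = split_filelist(names)
print(len(train_names), len(val_names))  # 80 20
```

With only 100 samples, it may also be worth checking how many samples of each class land in the validation set, since a purely random split does not balance the classes.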

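3.4 Defining the dataset class

Define a custom data loading class, GAMMA_sub1_dataset. It inherits from PaddlePaddle's paddle.io.Dataset base class and implements the __getitem__() and __len__() methods: __getitem__() returns the image data fed to the network for a given index, and __len__() returns the size of the dataset.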

class GAMMA_sub1_dataset(paddle.io.Dataset):
    def __init__(
        self,
        img_transforms,
        oct_transforms,
        dataset_root,
        label_file="",
        filelists=None,
        num_classes=3,
        mode="train",
    ):
        self.dataset_root = dataset_root
        self.img_transforms = img_transforms
        self.oct_transforms = oct_transforms
        self.mode = mode.lower()
        self.num_classes = num_classes
        # Labels are only needed when loading the training set
        if self.mode == "train":
            # Read the labels with pandas; each label is one-hot encoded
            label = {
                row["data"]: row[1:].values
                for _, row in pd.read_excel(label_file).iterrows()
            }

            self.file_list = [
                [f, label[int(f)]] for f in os.listdir(dataset_root)
            ]
        # The test set has no labels
        elif self.mode == "test":
            self.file_list = [[f, None] for f in os.listdir(dataset_root)]

        # If a list of samples is given, keep only those samples
        if filelists is not None:
            self.file_list = [
                item for item in self.file_list if item[0] in filelists
            ]

    def __getitem__(self, idx):
        # Get the sample and label at the given index; real_index is the name of the sample's folder
        real_index, label = self.file_list[idx]
        # Path of the color fundus image
        fundus_img_path = os.path.join(
            self.dataset_root, real_index, real_index + ".jpg"
        )
        # List of OCT slice files; one 3D OCT volume contains 256 two-dimensional slices
        oct_series_list = sorted(
            os.listdir(os.path.join(self.dataset_root, real_index, real_index)),
            key=lambda s: int(s.split("_")[0]),
        )
        # Read the image with OpenCV and convert the channel order BGR -> RGB
        fundus_img = cv2.imread(fundus_img_path)[:, :, ::-1]
        # Read one slice of the 3D OCT volume; note that it is grayscale (cv2.IMREAD_GRAYSCALE)
        oct_series_0 = cv2.imread(
            os.path.join(
                self.dataset_root, real_index, real_index, oct_series_list[0]
            ),
            cv2.IMREAD_GRAYSCALE,
        )
        oct_img = np.zeros(
            (
                oct_series_0.shape[0],
                oct_series_0.shape[1],
                len(oct_series_list),
            ),
            dtype="uint8",
        )
        # Read every slice in turn
        for k, p in enumerate(oct_series_list):
            oct_img[:, :, k] = cv2.imread(
                os.path.join(self.dataset_root, real_index, real_index, p),
                cv2.IMREAD_GRAYSCALE,
            )

        # Apply data augmentation to the color fundus image
        if self.img_transforms is not None:
            fundus_img = self.img_transforms(fundus_img)

        # Apply data augmentation to the 3D OCT volume
        if self.oct_transforms is not None:
            oct_img = self.oct_transforms(oct_img)

        # Reorder the dimensions to [channels, height, width]: H, W, C -> C, H, W
        fundus_img = fundus_img.transpose(2, 0, 1)
        oct_img = oct_img.transpose(2, 0, 1)

        if self.mode == "test":
            return fundus_img, oct_img, real_index
        if self.mode == "train":
            label = label.argmax()
            return fundus_img, oct_img, label

    # Total number of samples in the dataset
    def __len__(self):
        return len(self.file_list)
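The array handling inside __getitem__() can be tried in isolation. The sketch below uses small synthetic arrays in place of the files read by cv2.imread (the sizes are made up and far smaller than the real data) and shows how the slices are stacked into an H x W x C volume, how the volume is transposed to C x H x W, and how a one-hot label becomes a class index.

```python
import numpy as np

# Synthetic stand-ins for the slice images; slice k is filled with the value k.
num_slices, height, width = 4, 6, 8
slices = [np.full((height, width), k, dtype="uint8") for k in range(num_slices)]

# Stack the 2D slices along the last axis: (H, W, num_slices).
oct_img = np.zeros((height, width, num_slices), dtype="uint8")
for k, s in enumerate(slices):
    oct_img[:, :, k] = s

# H, W, C -> C, H, W, so that each slice becomes one input channel.
oct_chw = oct_img.transpose(2, 0, 1)
print(oct_chw.shape)  # (4, 6, 8)

# A one-hot label row is reduced to a class index with argmax.
label = np.array([0, 0, 1]).argmax()
print(label)  # 2
```

With the real data the channel dimension would be 256, one channel per OCT slice.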

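3.5 Data augmentation

Because the training set is small, appropriate data augmentation such as horizontal flips, vertical flips, and proportional cropping can be applied with paddle.vision.transforms to reduce overfitting.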

# Size of the color fundus image
image_size = [256, 256]
# Size of each slice of the 3D OCT volume
oct_img_size = [512, 512]

# Data augmentation
img_train_transforms = trans.Compose(
    [
        # Randomly crop the image by area and aspect ratio, then resize it to the target size
        trans.RandomResizedCrop(
            image_size, scale=(0.90, 1.1), ratio=(0.90, 1.1)
        ),
        # Random horizontal flip
        trans.RandomHorizontalFlip(),
        # Random vertical flip
        trans.RandomVerticalFlip(),
        # Random rotation
        trans.RandomRotation(30),
    ]
)

oct_train_transforms = trans.Compose(
    [
        # Center crop to the target size
        trans.CenterCrop(oct_img_size),
        # Random horizontal flip
        trans.RandomHorizontalFlip(),
        # Random vertical flip
        trans.RandomVerticalFlip(),
    ]
)

img_val_transforms = trans.Compose(
    [
        # Resize the image to a fixed size
        trans.Resize(image_size)
    ]
)

oct_val_transforms = trans.Compose(
    [
        # Center crop to a fixed size
        trans.CenterCrop(oct_img_size)
    ]
)
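To see what the crop and flip operations do to the underlying arrays, here is a plain NumPy sketch. It is not Paddle's implementation (rounding of the crop offsets and the random number source may differ), and the helper names center_crop and random_flip are hypothetical.

```python
import numpy as np


def center_crop(img, size):
    """Crop the central (size[0], size[1]) window from an H x W x C array."""
    h, w = img.shape[:2]
    out_h, out_w = size
    top = (h - out_h) // 2
    left = (w - out_w) // 2
    return img[top:top + out_h, left:left + out_w]


def random_flip(img, rng, prob=0.5):
    """Flip horizontally and/or vertically, each with probability `prob`."""
    if rng.random() < prob:
        img = img[:, ::-1]  # horizontal flip: reverse the width axis
    if rng.random() < prob:
        img = img[::-1]  # vertical flip: reverse the height axis
    return img


rng = np.random.default_rng(0)
volume = np.arange(8 * 10 * 3).reshape(8, 10, 3)  # tiny H x W x C stand-in
cropped = center_crop(volume, (4, 6))
augmented = random_flip(cropped, rng)
print(cropped.shape, augmented.shape)  # (4, 6, 3) (4, 6, 3)
```

Flips only reorder pixels, so every slice of an OCT volume is flipped consistently and the channel (slice) order is left unchanged.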

3.6 Previewing samples from the dataset

# Visualize the training dataset
_train = GAMMA_sub1_dataset(
    dataset_root=trainset_root,
    img_transforms=img_train_transforms,
    oct_transforms=oct_train_transforms,
    filelists=train_filelists,
    label_file=gt_file,
)

for i in range(5):
    fundus_img, oct_img, lab = _train.__getitem__(i)
    plt.subplot(2, 5, i + 1)
    plt.imshow(fundus_img.transpose(1, 2, 0))
    plt.axis("off")
    plt.subplot(2, 5, i + 6)
    # Show the slice as a grayscale image
    plt.imshow(oct_img[100], cmap="gray")
    if i == 0:
        print("fundus_size", fundus_img.shape)
        print("oct_size", oct_img[100].shape)
    plt.axis("off")

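4. Model Architecture

Because the data is multi-modal, two branches are used to process the two modalities separately: EfficientNetB3 for the color fundus image and ResNet for the 3D OCT data. The output features of the two networks are then fed together into a classification head for prediction.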
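The two-branch idea can be sketched with NumPy before any real network is built. Everything here is an assumption made for illustration: the feature sizes are invented, the fusion is taken to be a simple concatenation, and the classification head is reduced to a single linear layer followed by softmax. The model defined in this tutorial may combine the branches differently.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical feature sizes; the real sizes depend on how each branch is configured.
fundus_feat = rng.normal(size=(2, 8))  # batch of 2 fundus feature vectors
oct_feat = rng.normal(size=(2, 5))  # batch of 2 OCT feature vectors

# Assumed fusion: concatenate the two feature vectors along the feature axis.
fused = np.concatenate([fundus_feat, oct_feat], axis=1)

# A single linear layer as a stand-in classification head for 3 classes.
weight = rng.normal(size=(fused.shape[1], 3))
bias = np.zeros(3)
logits = fused @ weight + bias

# Softmax over the classes, then pick the most likely class per sample.
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)
pred = probs.argmax(axis=1)
print(fused.shape, probs.shape, pred.shape)  # (2, 13) (2, 3) (2,)
```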

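4.1 Building the EfficientNet branch

EfficientNet achieves high accuracy and fast inference by scaling the network's resolution (feature map size), depth (number of layers), and width (number of channels per layer) together. The overall structure of the network is shown in the figure below.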
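The effect of the width and depth multipliers can be worked out by hand. The sketch below reproduces the arithmetic used by the EfficientNet class in this tutorial, which truncates with int(); the helper name scale_stages is hypothetical. Reference implementations typically round channel counts to a multiple of 8 and round the repeat count up, so their numbers differ from these.

```python
# (repeats, output channels) of the seven baseline stages, taken from the
# block_setting table in the EfficientNet class of this tutorial.
BASE_STAGES = [(1, 16), (2, 24), (2, 40), (3, 80), (3, 112), (4, 192), (1, 320)]


def scale_stages(width_coefficient, depth_coefficient, stem_channels=32):
    """Apply the width/depth multipliers with int() truncation."""
    stem = int(stem_channels * width_coefficient)
    stages = [
        (int(r * depth_coefficient), int(c * width_coefficient))
        for r, c in BASE_STAGES
    ]
    return stem, stages


# Coefficients used by EfficientNetB3 in this tutorial.
stem, stages = scale_stages(width_coefficient=1.2, depth_coefficient=1.4)
print(stem)  # 38
print(stages)  # [(1, 19), (2, 28), (2, 48), (4, 96), (4, 134), (5, 230), (1, 384)]
```

These values match the channel counts 38, 19, 28, 48, 96, and 134 that appear in the paddle.summary output of this section.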

class ConvBNLayer(nn.Layer):
    def __init__(
        self,
        inchannels,
        outchannels,
        stride,
        kernelsize=3,
        groups=1,
        padding="SAME",
    ):
        super(ConvBNLayer, self).__init__()
        padding = (kernelsize - 1) // 2
        self.conv = nn.Conv2D(
            in_channels=inchannels,
            out_channels=outchannels,
            kernel_size=kernelsize,
            stride=stride,
            padding=padding,
            groups=groups,
        )
        self.bn = nn.BatchNorm2D(outchannels)
        self.Swish = nn.Swish()

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.bn(x)
        x = self.Swish(x)
        return x


############ SE-Net ####################
class SE(nn.Layer):
    def __init__(self, inchannels):
        super(SE, self).__init__()
        self.pooling = nn.AdaptiveAvgPool2D(output_size=(1, 1))
        self.linear0 = nn.Conv2D(
            in_channels=inchannels,
            out_channels=int(inchannels * 0.25),
            kernel_size=1,
        )
        self.linear1 = nn.Conv2D(
            in_channels=int(inchannels * 0.25),
            out_channels=inchannels,
            kernel_size=1,
        )
        self.Swish = nn.Swish()
        self.Sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        x = self.pooling(inputs)
        x = self.linear0(x)
        x = self.Swish(x)
        x = self.linear1(x)
        x = self.Sigmoid(x)
        out = paddle.multiply(x, inputs)

        return out


############# MBConv ########################
class MBConv(nn.Layer):
    def __init__(
        self,
        inchannels,
        outchannels,
        channels_time,
        kernel_size,
        stride,
        connected_dropout,
    ):
        super(MBConv, self).__init__()
        self.stride = stride
        self.layer0 = ConvBNLayer(
            inchannels=inchannels,
            outchannels=inchannels * channels_time,
            kernelsize=1,
            stride=1,
        )
        self.layer1 = ConvBNLayer(
            inchannels=inchannels * channels_time,
            outchannels=inchannels * channels_time,
            kernelsize=kernel_size,
            stride=stride,
            groups=inchannels * channels_time,
        )
        self.SE = SE(inchannels=inchannels * channels_time)
        self.conv0 = nn.Conv2D(
            in_channels=inchannels * channels_time,
            out_channels=outchannels,
            kernel_size=1,
        )
        self.bn0 = nn.BatchNorm2D(outchannels)
        self.conv1 = nn.Conv2D(
            in_channels=inchannels, out_channels=outchannels, kernel_size=1
        )
        self.bn1 = nn.BatchNorm2D(outchannels)
        self.dropout = nn.Dropout(p=connected_dropout)

    def forward(self, inputs):
        y = inputs
        x = self.layer0(inputs)
        x = self.layer1(x)
        x = self.SE(x)
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.dropout(x)
        if self.stride == 2:
            return x
        if self.stride == 1:
            y = self.conv1(inputs)
            y = self.bn1(y)
            return paddle.add(x, y)


#############  Classifier_Head ######################
class Classifier_Head(paddle.nn.Layer):
    def __init__(self, in_channels, num_channel, dropout_rate):
        super(Classifier_Head, self).__init__()
        self.pooling = nn.AdaptiveAvgPool2D(output_size=(1, 1))
        self.conv = ConvBNLayer(
            inchannels=in_channels, outchannels=1280, kernelsize=1, stride=1
        )
        self.dropout = nn.Dropout(p=dropout_rate)
        self.conv1 = nn.Conv2D(
            in_channels=1280,
            out_channels=num_channel,
            kernel_size=1,
            padding="SAME",
        )

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.pooling(x)
        x = self.dropout(x)
        x = self.conv1(x)
        x = paddle.squeeze(x, axis=[2, 3])
        x = F.softmax(x)

        return x


###############  EfficientNet ##################
class EfficientNet(nn.Layer):
    def __init__(
        self,
        in_channels,
        num_class,
        width_coefficient,
        depth_coefficient,
        connected_dropout,
        dropout_rate,
    ):
        super(EfficientNet, self).__init__()
        block_setting = [
            [0, 1, 3, 1, 1, 16],
            [1, 2, 3, 2, 6, 24],
            [2, 2, 5, 2, 6, 40],
            [3, 3, 3, 2, 6, 80],
            [4, 3, 5, 1, 6, 112],
            [5, 4, 5, 2, 6, 192],
            [6, 1, 3, 1, 6, 320],
        ]
        self.block = []
        self.block.append(
            self.add_sublayer(
                "c" + str(-1),
                nn.Conv2D(
                    in_channels=in_channels,
                    out_channels=int(32 * width_coefficient),
                    kernel_size=3,
                    padding="SAME",
                    stride=2,
                ),
            )
        )
        self.block.append(
            self.add_sublayer(
                "bn" + str(-1), nn.BatchNorm2D(int(32 * width_coefficient))
            )
        )
        i = int(32 * width_coefficient)
        for j in range(int(depth_coefficient) - 1):
            if j == int(depth_coefficient) - 2:
                self.block.append(
                    self.add_sublayer(
                        "c" + str(j),
                        nn.Conv2D(
                            in_channels=i,
                            out_channels=int(32 * width_coefficient),
                            kernel_size=3,
                            padding="SAME",
                            stride=2,
                        ),
                    )
                )
                self.block.append(
                    self.add_sublayer(
                        "bn" + str(j),
                        nn.BatchNorm2D(int(32 * width_coefficient)),
                    )
                )
            else:
                self.block.append(
                    self.add_sublayer(
                        "c" + str(j),
                        nn.Conv2D(
                            in_channels=i,
                            out_channels=int(32 * width_coefficient),
                            kernel_size=3,
                            padding="SAME",
                        ),
                    )
                )
                self.block.append(
                    self.add_sublayer(
                        "bn" + str(j),
                        nn.BatchNorm2D(int(32 * width_coefficient)),
                    )
                )
        for n, r, k, s, e, o in block_setting:
            for j in range(int(r * depth_coefficient)):
                if j == int(r * depth_coefficient) - 1:
                    self.block.append(
                        self.add_sublayer(
                            "b" + str(n) + str(j),
                            MBConv(
                                inchannels=i,
                                outchannels=int(o * width_coefficient),
                                channels_time=e,
                                kernel_size=k,
                                stride=s,
                                connected_dropout=connected_dropout,
                            ),
                        )
                    )
                else:
                    self.block.append(
                        self.add_sublayer(
                            "b" + str(n) + str(j),
                            MBConv(
                                inchannels=i,
                                outchannels=int(o * width_coefficient),
                                channels_time=e,
                                kernel_size=k,
                                stride=1,
                                connected_dropout=connected_dropout,
                            ),
                        )
                    )
                i = int(o * width_coefficient)
        self.head = Classifier_Head(
            in_channels=i, num_channel=num_class, dropout_rate=dropout_rate
        )

    def forward(self, x):
        for layer in self.block:
            x = layer(x)
        x = self.head(x)
        return x


def EfficientNetB3(in_channels, num_class):
    return EfficientNet(
        in_channels=in_channels,
        num_class=num_class,
        width_coefficient=1.2,
        depth_coefficient=1.4,
        connected_dropout=0.2,
        dropout_rate=0.3,
    )
# Show the structure of the EfficientNetB3 branch
efficientnet = EfficientNetB3(in_channels=3, num_class=1000)
paddle.summary(efficientnet, (1, 3, 256, 256))
--------------------------------------------------------------------------------
    Layer (type)         Input Shape          Output Shape         Param #    
================================================================================
      Conv2D-1        [[1, 3, 256, 256]]   [1, 38, 128, 128]        1,064     
   BatchNorm2D-1     [[1, 38, 128, 128]]   [1, 38, 128, 128]         152      
      Conv2D-2       [[1, 38, 128, 128]]   [1, 38, 128, 128]        1,482     
   BatchNorm2D-2     [[1, 38, 128, 128]]   [1, 38, 128, 128]         152      
      Swish-1        [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
   ConvBNLayer-1     [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
      Conv2D-3       [[1, 38, 128, 128]]   [1, 38, 128, 128]         380      
   BatchNorm2D-3     [[1, 38, 128, 128]]   [1, 38, 128, 128]         152      
      Swish-2        [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
   ConvBNLayer-2     [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
AdaptiveAvgPool2D-1  [[1, 38, 128, 128]]     [1, 38, 1, 1]            0       
      Conv2D-4         [[1, 38, 1, 1]]        [1, 9, 1, 1]           351      
      Swish-3           [[1, 9, 1, 1]]        [1, 9, 1, 1]            0       
      Conv2D-5          [[1, 9, 1, 1]]       [1, 38, 1, 1]           380      
     Sigmoid-1         [[1, 38, 1, 1]]       [1, 38, 1, 1]            0       
        SE-1         [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
      Conv2D-6       [[1, 38, 128, 128]]   [1, 19, 128, 128]         741      
   BatchNorm2D-4     [[1, 19, 128, 128]]   [1, 19, 128, 128]         76       
     Dropout-1       [[1, 19, 128, 128]]   [1, 19, 128, 128]          0       
      Conv2D-7       [[1, 38, 128, 128]]   [1, 19, 128, 128]         741      
   BatchNorm2D-5     [[1, 19, 128, 128]]   [1, 19, 128, 128]         76       
      MBConv-1       [[1, 38, 128, 128]]   [1, 19, 128, 128]          0       
      Conv2D-8       [[1, 19, 128, 128]]   [1, 114, 128, 128]       2,280     
   BatchNorm2D-6     [[1, 114, 128, 128]]  [1, 114, 128, 128]        456      
      Swish-4        [[1, 114, 128, 128]]  [1, 114, 128, 128]         0       
   ConvBNLayer-3     [[1, 19, 128, 128]]   [1, 114, 128, 128]         0       
      Conv2D-9       [[1, 114, 128, 128]]  [1, 114, 128, 128]       1,140     
   BatchNorm2D-7     [[1, 114, 128, 128]]  [1, 114, 128, 128]        456      
      Swish-5        [[1, 114, 128, 128]]  [1, 114, 128, 128]         0       
   ConvBNLayer-4     [[1, 114, 128, 128]]  [1, 114, 128, 128]         0       
AdaptiveAvgPool2D-2  [[1, 114, 128, 128]]    [1, 114, 1, 1]           0       
     Conv2D-10         [[1, 114, 1, 1]]      [1, 28, 1, 1]          3,220     
      Swish-6          [[1, 28, 1, 1]]       [1, 28, 1, 1]            0       
     Conv2D-11         [[1, 28, 1, 1]]       [1, 114, 1, 1]         3,306     
     Sigmoid-2         [[1, 114, 1, 1]]      [1, 114, 1, 1]           0       
        SE-2         [[1, 114, 128, 128]]  [1, 114, 128, 128]         0       
     Conv2D-12       [[1, 114, 128, 128]]  [1, 28, 128, 128]        3,220     
   BatchNorm2D-8     [[1, 28, 128, 128]]   [1, 28, 128, 128]         112      
     Dropout-2       [[1, 28, 128, 128]]   [1, 28, 128, 128]          0       
     Conv2D-13       [[1, 19, 128, 128]]   [1, 28, 128, 128]         560      
   BatchNorm2D-9     [[1, 28, 128, 128]]   [1, 28, 128, 128]         112      
      MBConv-2       [[1, 19, 128, 128]]   [1, 28, 128, 128]          0       
     Conv2D-14       [[1, 28, 128, 128]]   [1, 168, 128, 128]       4,872     
   BatchNorm2D-10    [[1, 168, 128, 128]]  [1, 168, 128, 128]        672      
      Swish-7        [[1, 168, 128, 128]]  [1, 168, 128, 128]         0       
   ConvBNLayer-5     [[1, 28, 128, 128]]   [1, 168, 128, 128]         0       
     Conv2D-15       [[1, 168, 128, 128]]   [1, 168, 64, 64]        1,680     
   BatchNorm2D-11     [[1, 168, 64, 64]]    [1, 168, 64, 64]         672      
      Swish-8         [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
   ConvBNLayer-6     [[1, 168, 128, 128]]   [1, 168, 64, 64]          0       
AdaptiveAvgPool2D-3   [[1, 168, 64, 64]]     [1, 168, 1, 1]           0       
     Conv2D-16         [[1, 168, 1, 1]]      [1, 42, 1, 1]          7,098     
      Swish-9          [[1, 42, 1, 1]]       [1, 42, 1, 1]            0       
     Conv2D-17         [[1, 42, 1, 1]]       [1, 168, 1, 1]         7,224     
     Sigmoid-3         [[1, 168, 1, 1]]      [1, 168, 1, 1]           0       
        SE-3          [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
     Conv2D-18        [[1, 168, 64, 64]]    [1, 28, 64, 64]         4,732     
   BatchNorm2D-12     [[1, 28, 64, 64]]     [1, 28, 64, 64]          112      
     Dropout-3        [[1, 28, 64, 64]]     [1, 28, 64, 64]           0       
      MBConv-3       [[1, 28, 128, 128]]    [1, 28, 64, 64]           0       
     Conv2D-20        [[1, 28, 64, 64]]     [1, 168, 64, 64]        4,872     
   BatchNorm2D-14     [[1, 168, 64, 64]]    [1, 168, 64, 64]         672      
      Swish-10        [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
   ConvBNLayer-7      [[1, 28, 64, 64]]     [1, 168, 64, 64]          0       
     Conv2D-21        [[1, 168, 64, 64]]    [1, 168, 64, 64]        4,368     
   BatchNorm2D-15     [[1, 168, 64, 64]]    [1, 168, 64, 64]         672      
      Swish-11        [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
   ConvBNLayer-8      [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
AdaptiveAvgPool2D-4   [[1, 168, 64, 64]]     [1, 168, 1, 1]           0       
     Conv2D-22         [[1, 168, 1, 1]]      [1, 42, 1, 1]          7,098     
      Swish-12         [[1, 42, 1, 1]]       [1, 42, 1, 1]            0       
     Conv2D-23         [[1, 42, 1, 1]]       [1, 168, 1, 1]         7,224     
     Sigmoid-4         [[1, 168, 1, 1]]      [1, 168, 1, 1]           0       
        SE-4          [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
     Conv2D-24        [[1, 168, 64, 64]]    [1, 48, 64, 64]         8,112     
   BatchNorm2D-16     [[1, 48, 64, 64]]     [1, 48, 64, 64]          192      
     Dropout-4        [[1, 48, 64, 64]]     [1, 48, 64, 64]           0       
     Conv2D-25        [[1, 28, 64, 64]]     [1, 48, 64, 64]         1,392     
   BatchNorm2D-17     [[1, 48, 64, 64]]     [1, 48, 64, 64]          192      
      MBConv-4        [[1, 28, 64, 64]]     [1, 48, 64, 64]           0       
     Conv2D-26        [[1, 48, 64, 64]]     [1, 288, 64, 64]       14,112     
   BatchNorm2D-18     [[1, 288, 64, 64]]    [1, 288, 64, 64]        1,152     
      Swish-13        [[1, 288, 64, 64]]    [1, 288, 64, 64]          0       
   ConvBNLayer-9      [[1, 48, 64, 64]]     [1, 288, 64, 64]          0       
     Conv2D-27        [[1, 288, 64, 64]]    [1, 288, 32, 32]        7,488     
   BatchNorm2D-19     [[1, 288, 32, 32]]    [1, 288, 32, 32]        1,152     
      Swish-14        [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
   ConvBNLayer-10     [[1, 288, 64, 64]]    [1, 288, 32, 32]          0       
AdaptiveAvgPool2D-5   [[1, 288, 32, 32]]     [1, 288, 1, 1]           0       
     Conv2D-28         [[1, 288, 1, 1]]      [1, 72, 1, 1]         20,808     
      Swish-15         [[1, 72, 1, 1]]       [1, 72, 1, 1]            0       
     Conv2D-29         [[1, 72, 1, 1]]       [1, 288, 1, 1]        21,024     
     Sigmoid-5         [[1, 288, 1, 1]]      [1, 288, 1, 1]           0       
        SE-5          [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
     Conv2D-30        [[1, 288, 32, 32]]    [1, 48, 32, 32]        13,872     
   BatchNorm2D-20     [[1, 48, 32, 32]]     [1, 48, 32, 32]          192      
     Dropout-5        [[1, 48, 32, 32]]     [1, 48, 32, 32]           0       
      MBConv-5        [[1, 48, 64, 64]]     [1, 48, 32, 32]           0       
     Conv2D-32        [[1, 48, 32, 32]]     [1, 288, 32, 32]       14,112     
   BatchNorm2D-22     [[1, 288, 32, 32]]    [1, 288, 32, 32]        1,152     
      Swish-16        [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
   ConvBNLayer-11     [[1, 48, 32, 32]]     [1, 288, 32, 32]          0       
     Conv2D-33        [[1, 288, 32, 32]]    [1, 288, 32, 32]        2,880     
   BatchNorm2D-23     [[1, 288, 32, 32]]    [1, 288, 32, 32]        1,152     
      Swish-17        [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
   ConvBNLayer-12     [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
AdaptiveAvgPool2D-6   [[1, 288, 32, 32]]     [1, 288, 1, 1]           0       
     Conv2D-34         [[1, 288, 1, 1]]      [1, 72, 1, 1]         20,808     
      Swish-18         [[1, 72, 1, 1]]       [1, 72, 1, 1]            0       
     Conv2D-35         [[1, 72, 1, 1]]       [1, 288, 1, 1]        21,024     
     Sigmoid-6         [[1, 288, 1, 1]]      [1, 288, 1, 1]           0       
        SE-6          [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
     Conv2D-36        [[1, 288, 32, 32]]    [1, 96, 32, 32]        27,744     
   BatchNorm2D-24     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     Dropout-6        [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-37        [[1, 48, 32, 32]]     [1, 96, 32, 32]         4,704     
   BatchNorm2D-25     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
      MBConv-6        [[1, 48, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-38        [[1, 96, 32, 32]]     [1, 576, 32, 32]       55,872     
   BatchNorm2D-26     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-19        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-13     [[1, 96, 32, 32]]     [1, 576, 32, 32]          0       
     Conv2D-39        [[1, 576, 32, 32]]    [1, 576, 32, 32]        5,760     
   BatchNorm2D-27     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-20        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-14     [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
AdaptiveAvgPool2D-7   [[1, 576, 32, 32]]     [1, 576, 1, 1]           0       
     Conv2D-40         [[1, 576, 1, 1]]      [1, 144, 1, 1]        83,088     
      Swish-21         [[1, 144, 1, 1]]      [1, 144, 1, 1]           0       
     Conv2D-41         [[1, 144, 1, 1]]      [1, 576, 1, 1]        83,520     
     Sigmoid-7         [[1, 576, 1, 1]]      [1, 576, 1, 1]           0       
        SE-7          [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
     Conv2D-42        [[1, 576, 32, 32]]    [1, 96, 32, 32]        55,392     
   BatchNorm2D-28     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     Dropout-7        [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-43        [[1, 96, 32, 32]]     [1, 96, 32, 32]         9,312     
   BatchNorm2D-29     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
      MBConv-7        [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-44        [[1, 96, 32, 32]]     [1, 576, 32, 32]       55,872     
   BatchNorm2D-30     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-22        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-15     [[1, 96, 32, 32]]     [1, 576, 32, 32]          0       
     Conv2D-45        [[1, 576, 32, 32]]    [1, 576, 32, 32]        5,760     
   BatchNorm2D-31     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-23        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-16     [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
AdaptiveAvgPool2D-8   [[1, 576, 32, 32]]     [1, 576, 1, 1]           0       
     Conv2D-46         [[1, 576, 1, 1]]      [1, 144, 1, 1]        83,088     
      Swish-24         [[1, 144, 1, 1]]      [1, 144, 1, 1]           0       
     Conv2D-47         [[1, 144, 1, 1]]      [1, 576, 1, 1]        83,520     
     Sigmoid-8         [[1, 576, 1, 1]]      [1, 576, 1, 1]           0       
        SE-8          [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
     Conv2D-48        [[1, 576, 32, 32]]    [1, 96, 32, 32]        55,392     
   BatchNorm2D-32     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     Dropout-8        [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-49        [[1, 96, 32, 32]]     [1, 96, 32, 32]         9,312     
   BatchNorm2D-33     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
      MBConv-8        [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-50        [[1, 96, 32, 32]]     [1, 576, 32, 32]       55,872     
   BatchNorm2D-34     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-25        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-17     [[1, 96, 32, 32]]     [1, 576, 32, 32]          0       
     Conv2D-51        [[1, 576, 32, 32]]    [1, 576, 16, 16]        5,760     
   BatchNorm2D-35     [[1, 576, 16, 16]]    [1, 576, 16, 16]        2,304     
      Swish-26        [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
   ConvBNLayer-18     [[1, 576, 32, 32]]    [1, 576, 16, 16]          0       
AdaptiveAvgPool2D-9   [[1, 576, 16, 16]]     [1, 576, 1, 1]           0       
     Conv2D-52         [[1, 576, 1, 1]]      [1, 144, 1, 1]        83,088     
      Swish-27         [[1, 144, 1, 1]]      [1, 144, 1, 1]           0       
     Conv2D-53         [[1, 144, 1, 1]]      [1, 576, 1, 1]        83,520     
     Sigmoid-9         [[1, 576, 1, 1]]      [1, 576, 1, 1]           0       
        SE-9          [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
     Conv2D-54        [[1, 576, 16, 16]]    [1, 96, 16, 16]        55,392     
   BatchNorm2D-36     [[1, 96, 16, 16]]     [1, 96, 16, 16]          384      
     Dropout-9        [[1, 96, 16, 16]]     [1, 96, 16, 16]           0       
      MBConv-9        [[1, 96, 32, 32]]     [1, 96, 16, 16]           0       
     Conv2D-56        [[1, 96, 16, 16]]     [1, 576, 16, 16]       55,872     
   BatchNorm2D-38     [[1, 576, 16, 16]]    [1, 576, 16, 16]        2,304     
      Swish-28        [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
   ConvBNLayer-19     [[1, 96, 16, 16]]     [1, 576, 16, 16]          0       
     Conv2D-57        [[1, 576, 16, 16]]    [1, 576, 16, 16]       14,976     
   BatchNorm2D-39     [[1, 576, 16, 16]]    [1, 576, 16, 16]        2,304     
      Swish-29        [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
   ConvBNLayer-20     [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
AdaptiveAvgPool2D-10  [[1, 576, 16, 16]]     [1, 576, 1, 1]           0       
     Conv2D-58         [[1, 576, 1, 1]]      [1, 144, 1, 1]        83,088     
      Swish-30         [[1, 144, 1, 1]]      [1, 144, 1, 1]           0       
     Conv2D-59         [[1, 144, 1, 1]]      [1, 576, 1, 1]        83,520     
     Sigmoid-10        [[1, 576, 1, 1]]      [1, 576, 1, 1]           0       
       SE-10          [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
     Conv2D-60        [[1, 576, 16, 16]]    [1, 134, 16, 16]       77,318     
   BatchNorm2D-40     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     Dropout-10       [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-61        [[1, 96, 16, 16]]     [1, 134, 16, 16]       12,998     
   BatchNorm2D-41     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     MBConv-10        [[1, 96, 16, 16]]     [1, 134, 16, 16]          0       
     Conv2D-62        [[1, 134, 16, 16]]    [1, 804, 16, 16]       108,540    
   BatchNorm2D-42     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-31        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-21     [[1, 134, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-63        [[1, 804, 16, 16]]    [1, 804, 16, 16]       20,904     
   BatchNorm2D-43     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-32        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-22     [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
AdaptiveAvgPool2D-11  [[1, 804, 16, 16]]     [1, 804, 1, 1]           0       
     Conv2D-64         [[1, 804, 1, 1]]      [1, 201, 1, 1]        161,805    
      Swish-33         [[1, 201, 1, 1]]      [1, 201, 1, 1]           0       
     Conv2D-65         [[1, 201, 1, 1]]      [1, 804, 1, 1]        162,408    
     Sigmoid-11        [[1, 804, 1, 1]]      [1, 804, 1, 1]           0       
       SE-11          [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-66        [[1, 804, 16, 16]]    [1, 134, 16, 16]       107,870    
   BatchNorm2D-44     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     Dropout-11       [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-67        [[1, 134, 16, 16]]    [1, 134, 16, 16]       18,090     
   BatchNorm2D-45     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     MBConv-11        [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-68        [[1, 134, 16, 16]]    [1, 804, 16, 16]       108,540    
   BatchNorm2D-46     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-34        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-23     [[1, 134, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-69        [[1, 804, 16, 16]]    [1, 804, 16, 16]       20,904     
   BatchNorm2D-47     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-35        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-24     [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
AdaptiveAvgPool2D-12  [[1, 804, 16, 16]]     [1, 804, 1, 1]           0       
     Conv2D-70         [[1, 804, 1, 1]]      [1, 201, 1, 1]        161,805    
      Swish-36         [[1, 201, 1, 1]]      [1, 201, 1, 1]           0       
     Conv2D-71         [[1, 201, 1, 1]]      [1, 804, 1, 1]        162,408    
     Sigmoid-12        [[1, 804, 1, 1]]      [1, 804, 1, 1]           0       
       SE-12          [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-72        [[1, 804, 16, 16]]    [1, 134, 16, 16]       107,870    
   BatchNorm2D-48     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     Dropout-12       [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-73        [[1, 134, 16, 16]]    [1, 134, 16, 16]       18,090     
   BatchNorm2D-49     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     MBConv-12        [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-74        [[1, 134, 16, 16]]    [1, 804, 16, 16]       108,540    
   BatchNorm2D-50     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-37        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-25     [[1, 134, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-75        [[1, 804, 16, 16]]    [1, 804, 16, 16]       20,904     
   BatchNorm2D-51     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-38        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-26     [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
AdaptiveAvgPool2D-13  [[1, 804, 16, 16]]     [1, 804, 1, 1]           0       
     Conv2D-76         [[1, 804, 1, 1]]      [1, 201, 1, 1]        161,805    
      Swish-39         [[1, 201, 1, 1]]      [1, 201, 1, 1]           0       
     Conv2D-77         [[1, 201, 1, 1]]      [1, 804, 1, 1]        162,408    
     Sigmoid-13        [[1, 804, 1, 1]]      [1, 804, 1, 1]           0       
       SE-13          [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-78        [[1, 804, 16, 16]]    [1, 134, 16, 16]       107,870    
   BatchNorm2D-52     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     Dropout-13       [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-79        [[1, 134, 16, 16]]    [1, 134, 16, 16]       18,090     
   BatchNorm2D-53     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     MBConv-13        [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-80        [[1, 134, 16, 16]]    [1, 804, 16, 16]       108,540    
   BatchNorm2D-54     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-40        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-27     [[1, 134, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-81        [[1, 804, 16, 16]]    [1, 804, 16, 16]       20,904     
   BatchNorm2D-55     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-41        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-28     [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
AdaptiveAvgPool2D-14  [[1, 804, 16, 16]]     [1, 804, 1, 1]           0       
     Conv2D-82         [[1, 804, 1, 1]]      [1, 201, 1, 1]        161,805    
      Swish-42         [[1, 201, 1, 1]]      [1, 201, 1, 1]           0       
     Conv2D-83         [[1, 201, 1, 1]]      [1, 804, 1, 1]        162,408    
     Sigmoid-14        [[1, 804, 1, 1]]      [1, 804, 1, 1]           0       
       SE-14          [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-84        [[1, 804, 16, 16]]    [1, 230, 16, 16]       185,150    
   BatchNorm2D-56     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     Dropout-14       [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-85        [[1, 134, 16, 16]]    [1, 230, 16, 16]       31,050     
   BatchNorm2D-57     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     MBConv-14        [[1, 134, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-86        [[1, 230, 16, 16]]   [1, 1380, 16, 16]       318,780    
   BatchNorm2D-58    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
      Swish-43       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-29     [[1, 230, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-87       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]       35,880     
   BatchNorm2D-59    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
      Swish-44       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-30    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
AdaptiveAvgPool2D-15 [[1, 1380, 16, 16]]    [1, 1380, 1, 1]           0       
     Conv2D-88        [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
      Swish-45         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-89         [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-15       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-15         [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-90       [[1, 1380, 16, 16]]    [1, 230, 16, 16]       317,630    
   BatchNorm2D-60     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     Dropout-15       [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-91        [[1, 230, 16, 16]]    [1, 230, 16, 16]       53,130     
   BatchNorm2D-61     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     MBConv-15        [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-92        [[1, 230, 16, 16]]   [1, 1380, 16, 16]       318,780    
   BatchNorm2D-62    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
      Swish-46       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-31     [[1, 230, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-93       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]       35,880     
   BatchNorm2D-63    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
      Swish-47       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-32    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
AdaptiveAvgPool2D-16 [[1, 1380, 16, 16]]    [1, 1380, 1, 1]           0       
     Conv2D-94        [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
      Swish-48         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-95         [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-16       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-16         [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-96       [[1, 1380, 16, 16]]    [1, 230, 16, 16]       317,630    
   BatchNorm2D-64     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     Dropout-16       [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-97        [[1, 230, 16, 16]]    [1, 230, 16, 16]       53,130     
   BatchNorm2D-65     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     MBConv-16        [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-98        [[1, 230, 16, 16]]   [1, 1380, 16, 16]       318,780    
   BatchNorm2D-66    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
      Swish-49       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-33     [[1, 230, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-99       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]       35,880     
   BatchNorm2D-67    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
      Swish-50       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-34    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
AdaptiveAvgPool2D-17 [[1, 1380, 16, 16]]    [1, 1380, 1, 1]           0       
     Conv2D-100       [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
      Swish-51         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-101        [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-17       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-17         [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-102      [[1, 1380, 16, 16]]    [1, 230, 16, 16]       317,630    
   BatchNorm2D-68     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     Dropout-17       [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-103       [[1, 230, 16, 16]]    [1, 230, 16, 16]       53,130     
   BatchNorm2D-69     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     MBConv-17        [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-104       [[1, 230, 16, 16]]   [1, 1380, 16, 16]       318,780    
   BatchNorm2D-70    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
      Swish-52       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-35     [[1, 230, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-105      [[1, 1380, 16, 16]]    [1, 1380, 8, 8]        35,880     
   BatchNorm2D-71     [[1, 1380, 8, 8]]     [1, 1380, 8, 8]         5,520     
      Swish-53        [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
   ConvBNLayer-36    [[1, 1380, 16, 16]]    [1, 1380, 8, 8]           0       
AdaptiveAvgPool2D-18  [[1, 1380, 8, 8]]     [1, 1380, 1, 1]           0       
     Conv2D-106       [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
      Swish-54         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-107        [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-18       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-18          [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
     Conv2D-108       [[1, 1380, 8, 8]]      [1, 230, 8, 8]        317,630    
   BatchNorm2D-72      [[1, 230, 8, 8]]      [1, 230, 8, 8]          920      
     Dropout-18        [[1, 230, 8, 8]]      [1, 230, 8, 8]           0       
     MBConv-18        [[1, 230, 16, 16]]     [1, 230, 8, 8]           0       
     Conv2D-110        [[1, 230, 8, 8]]     [1, 1380, 8, 8]        318,780    
   BatchNorm2D-74     [[1, 1380, 8, 8]]     [1, 1380, 8, 8]         5,520     
      Swish-55        [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
   ConvBNLayer-37      [[1, 230, 8, 8]]     [1, 1380, 8, 8]           0       
     Conv2D-111       [[1, 1380, 8, 8]]     [1, 1380, 8, 8]        13,800     
   BatchNorm2D-75     [[1, 1380, 8, 8]]     [1, 1380, 8, 8]         5,520     
      Swish-56        [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
   ConvBNLayer-38     [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
AdaptiveAvgPool2D-19  [[1, 1380, 8, 8]]     [1, 1380, 1, 1]           0       
     Conv2D-112       [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
      Swish-57         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-113        [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-19       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-19          [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
     Conv2D-114       [[1, 1380, 8, 8]]      [1, 384, 8, 8]        530,304    
   BatchNorm2D-76      [[1, 384, 8, 8]]      [1, 384, 8, 8]         1,536     
     Dropout-19        [[1, 384, 8, 8]]      [1, 384, 8, 8]           0       
     Conv2D-115        [[1, 230, 8, 8]]      [1, 384, 8, 8]        88,704     
   BatchNorm2D-77      [[1, 384, 8, 8]]      [1, 384, 8, 8]         1,536     
     MBConv-19         [[1, 230, 8, 8]]      [1, 384, 8, 8]           0       
     Conv2D-116        [[1, 384, 8, 8]]     [1, 1280, 8, 8]        492,800    
   BatchNorm2D-78     [[1, 1280, 8, 8]]     [1, 1280, 8, 8]         5,120     
      Swish-58        [[1, 1280, 8, 8]]     [1, 1280, 8, 8]           0       
   ConvBNLayer-39      [[1, 384, 8, 8]]     [1, 1280, 8, 8]           0       
AdaptiveAvgPool2D-20  [[1, 1280, 8, 8]]     [1, 1280, 1, 1]           0       
     Dropout-20       [[1, 1280, 1, 1]]     [1, 1280, 1, 1]           0       
     Conv2D-117       [[1, 1280, 1, 1]]     [1, 1000, 1, 1]       1,281,000   
 Classifier_Head-1     [[1, 384, 8, 8]]        [1, 1000]              0       
================================================================================
Total params: 14,328,212
Trainable params: 14,261,944
Non-trainable params: 66,268
--------------------------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 754.80
Params size (MB): 54.66
Estimated Total Size (MB): 810.21
--------------------------------------------------------------------------------

{'total_params': 14328212, 'trainable_params': 14261944}
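
The Param # column in the summary above can be cross-checked by hand. The sketch below is not part of the original tutorial: `conv2d_params` and `batchnorm2d_params` are hypothetical helpers, and the kernel sizes and `groups` values are assumptions inferred from the printed counts (for example, a 5x5 depthwise convolution with bias for `Conv2D-57`).

```python
def conv2d_params(in_channels, out_channels, kernel_size, groups=1, bias=True):
    """Parameter count of a 2D convolution with a square kernel."""
    weights = (in_channels // groups) * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


def batchnorm2d_params(channels):
    """Scale, shift, running mean and running variance: four values per channel."""
    return 4 * channels


# Conv2D-57: assumed 5x5 depthwise convolution on 576 channels
print(conv2d_params(576, 576, 5, groups=576))  # 14976
# Conv2D-58 / Conv2D-59: the two 1x1 convolutions of the SE block
print(conv2d_params(576, 144, 1))              # 83088
print(conv2d_params(144, 576, 1))              # 83520
# BatchNorm2D-39 on 576 channels
print(batchnorm2d_params(576))                 # 2304
# Conv2D-117: the 1x1 classifier convolution, 1280 -> 1000
print(conv2d_params(1280, 1000, 1))            # 1281000
```

Half of each BatchNorm2D count (the running mean and variance) is not updated by gradient descent, which is where the non-trainable parameters in the summary come from.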

4.2 Building the ResNet branch

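ResNet (Residual Neural Network) was proposed by Kaiming He et al. Its core idea is to add residual (skip) connections to the convolutional blocks, which mitigates the degradation problem that appears when a network becomes very deep. This makes it practical to train deeper networks and so increase model capacity.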

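The residual connection inside a convolutional block is shown in the figure below:

The overall network structure is shown in the figure below: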

###### BasicBlock ###############
class Basicblock(paddle.nn.Layer):
    def __init__(self, in_channel, out_channel, stride=1):
        super(Basicblock, self).__init__()
        self.stride = stride
        self.conv0 = nn.Conv2D(
            in_channel, out_channel, 3, stride=stride, padding=1
        )
        self.conv1 = nn.Conv2D(out_channel, out_channel, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2D(in_channel, out_channel, 1, stride=stride)
        self.bn0 = nn.BatchNorm2D(out_channel)
        self.bn1 = nn.BatchNorm2D(out_channel)
        self.bn2 = nn.BatchNorm2D(out_channel)

    def forward(self, inputs):
        y = inputs
        x = self.conv0(inputs)
        x = self.bn0(x)
        x = F.relu(x)
        x = self.conv1(x)
        x = self.bn1(x)
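        # Project the shortcut with a 1x1 conv only when the block downsamples,
        # so that its shape matches the main path
        if self.stride == 2:
            y = self.conv2(y)
            y = self.bn2(y)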
        z = F.relu(x + y)
        return z


############ BottleneckBlock ##############
class Bottleneckblock(paddle.nn.Layer):
    def __init__(self, inplane, in_channel, out_channel, stride=1, start=False):
        super(Bottleneckblock, self).__init__()
        self.stride = stride
        self.start = start
        self.conv0 = nn.Conv2D(in_channel, inplane, 1, stride=stride)
        self.conv1 = nn.Conv2D(inplane, inplane, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2D(inplane, out_channel, 1, stride=1)
        self.conv3 = nn.Conv2D(in_channel, out_channel, 1, stride=stride)
        self.bn0 = nn.BatchNorm2D(inplane)
        self.bn1 = nn.BatchNorm2D(inplane)
        self.bn2 = nn.BatchNorm2D(out_channel)
        self.bn3 = nn.BatchNorm2D(out_channel)

    def forward(self, inputs):
        y = inputs
        x = self.conv0(inputs)
        x = self.bn0(x)
        x = F.relu(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
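        # Project the shortcut in the first block of each stage, where the
        # channel count (and possibly the resolution) changes
        if self.start:
            y = self.conv3(y)
            y = self.bn3(y)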
        z = F.relu(x + y)
        return z


###################  ResNet  #################
class ResNet(paddle.nn.Layer):
    def __init__(self, num, bottlenet, in_channels):
        super(ResNet, self).__init__()
        self.conv0 = nn.Conv2D(in_channels, 64, 7, stride=2)
        self.bn = nn.BatchNorm2D(64)
        self.pool1 = nn.MaxPool2D(3, stride=2)
        if bottlenet:
            self.layer0 = self.add_bottleneck_layer(num[0], 64, start=True)
            self.layer1 = self.add_bottleneck_layer(num[1], 128)
            self.layer2 = self.add_bottleneck_layer(num[2], 256)
            self.layer3 = self.add_bottleneck_layer(num[3], 512)
        else:
            self.layer0 = self.add_basic_layer(num[0], 64, start=True)
            self.layer1 = self.add_basic_layer(num[1], 128)
            self.layer2 = self.add_basic_layer(num[2], 256)
            self.layer3 = self.add_basic_layer(num[3], 512)
        self.pool2 = nn.AdaptiveAvgPool2D(output_size=(1, 1))

    def add_basic_layer(self, num, inplane, start=False):
        layer = []
        if start:
            layer.append(Basicblock(inplane, inplane))
        else:
            layer.append(Basicblock(inplane // 2, inplane, stride=2))
        for i in range(num - 1):
            layer.append(Basicblock(inplane, inplane))
        return nn.Sequential(*layer)

    def add_bottleneck_layer(self, num, inplane, start=False):
        layer = []
        if start:
            layer.append(
                Bottleneckblock(inplane, inplane, inplane * 4, start=True)
            )
        else:
            layer.append(
                Bottleneckblock(
                    inplane, inplane * 2, inplane * 4, stride=2, start=True
                )
            )
        for i in range(num - 1):
            layer.append(Bottleneckblock(inplane, inplane * 4, inplane * 4))
        return nn.Sequential(*layer)

    def forward(self, inputs):
        # Stem: 7x7 conv -> BN -> max pool (no ReLU is applied here,
        # unlike the reference ResNet stem)
        x = self.conv0(inputs)
        x = self.bn(x)
        x = self.pool1(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool2(x)
        x = paddle.squeeze(x, axis=[2, 3])
        return x


def ResNet34(in_channels=3):
    return ResNet([3, 4, 6, 3], bottlenet=False, in_channels=in_channels)
# Show the structure of the ResNet branch
resnet = ResNet34(in_channels=256)
paddle.summary(resnet, (1, 256, 512, 512))
--------------------------------------------------------------------------------
    Layer (type)         Input Shape          Output Shape         Param #    
================================================================================
     Conv2D-118      [[1, 256, 512, 512]]  [1, 64, 253, 253]       802,880    
   BatchNorm2D-79    [[1, 64, 253, 253]]   [1, 64, 253, 253]         256      
    MaxPool2D-1      [[1, 64, 253, 253]]   [1, 64, 126, 126]          0       
     Conv2D-119      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
   BatchNorm2D-80    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
     Conv2D-120      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
   BatchNorm2D-81    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
    Basicblock-1     [[1, 64, 126, 126]]   [1, 64, 126, 126]          0       
     Conv2D-122      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
   BatchNorm2D-83    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
     Conv2D-123      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
   BatchNorm2D-84    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
    Basicblock-2     [[1, 64, 126, 126]]   [1, 64, 126, 126]          0       
     Conv2D-125      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
   BatchNorm2D-86    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
     Conv2D-126      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
   BatchNorm2D-87    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
    Basicblock-3     [[1, 64, 126, 126]]   [1, 64, 126, 126]          0       
     Conv2D-128      [[1, 64, 126, 126]]    [1, 128, 63, 63]       73,856     
   BatchNorm2D-89     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-129       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
   BatchNorm2D-90     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-130      [[1, 64, 126, 126]]    [1, 128, 63, 63]        8,320     
   BatchNorm2D-91     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
    Basicblock-4     [[1, 64, 126, 126]]    [1, 128, 63, 63]          0       
     Conv2D-131       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
   BatchNorm2D-92     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-132       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
   BatchNorm2D-93     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
    Basicblock-5      [[1, 128, 63, 63]]    [1, 128, 63, 63]          0       
     Conv2D-134       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
   BatchNorm2D-95     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-135       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
   BatchNorm2D-96     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
    Basicblock-6      [[1, 128, 63, 63]]    [1, 128, 63, 63]          0       
     Conv2D-137       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
   BatchNorm2D-98     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-138       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
   BatchNorm2D-99     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
    Basicblock-7      [[1, 128, 63, 63]]    [1, 128, 63, 63]          0       
     Conv2D-140       [[1, 128, 63, 63]]    [1, 256, 32, 32]       295,168    
  BatchNorm2D-101     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-141       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-102     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-142       [[1, 128, 63, 63]]    [1, 256, 32, 32]       33,024     
  BatchNorm2D-103     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
    Basicblock-8      [[1, 128, 63, 63]]    [1, 256, 32, 32]          0       
     Conv2D-143       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-104     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-144       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-105     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
    Basicblock-9      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-146       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-107     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-147       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-108     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-10      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-149       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-110     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-150       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-111     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-11      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-152       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-113     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-153       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-114     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-12      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-155       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-116     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-156       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-117     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-13      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-158       [[1, 256, 32, 32]]    [1, 512, 16, 16]      1,180,160   
  BatchNorm2D-119     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
     Conv2D-159       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-120     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
     Conv2D-160       [[1, 256, 32, 32]]    [1, 512, 16, 16]       131,584    
  BatchNorm2D-121     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
   Basicblock-14      [[1, 256, 32, 32]]    [1, 512, 16, 16]          0       
     Conv2D-161       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-122     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
     Conv2D-162       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-123     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
   Basicblock-15      [[1, 512, 16, 16]]    [1, 512, 16, 16]          0       
     Conv2D-164       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-125     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
     Conv2D-165       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-126     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
   Basicblock-16      [[1, 512, 16, 16]]    [1, 512, 16, 16]          0       
AdaptiveAvgPool2D-21  [[1, 512, 16, 16]]     [1, 512, 1, 1]           0       
================================================================================
Total params: 22,103,616
Trainable params: 22,086,592
Non-trainable params: 17,024
--------------------------------------------------------------------------------
Input size (MB): 256.00
Forward/backward pass size (MB): 352.82
Params size (MB): 84.32
Estimated Total Size (MB): 693.13
--------------------------------------------------------------------------------

{'total_params': 22103616, 'trainable_params': 22086592}
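
The spatial sizes in the summary above (512 to 253, 126, 63, 32 and 16) follow from the standard convolution output-size formula. The sketch below is an illustration rather than part of the original tutorial: `conv_out` and `resnet34_branch_sizes` are hypothetical helper names, while the kernel sizes, strides and paddings are taken from the `ResNet` and `Basicblock` definitions above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling window along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet34_branch_sizes(size):
    """Trace the feature-map side length through the ResNet34 branch."""
    sizes = []
    size = conv_out(size, 7, stride=2)             # stem conv, no padding
    sizes.append(size)
    size = conv_out(size, 3, stride=2)             # max pool, no padding
    sizes.append(size)
    size = conv_out(size, 3, stride=1, padding=1)  # layer0 keeps the size
    sizes.append(size)
    for _ in range(3):                             # layer1..layer3 downsample
        size = conv_out(size, 3, stride=2, padding=1)
        sizes.append(size)
    return sizes


print(resnet34_branch_sizes(512))  # [253, 126, 126, 63, 32, 16]
# The 1x1 stride-2 shortcut conv yields the same size as the main path
print(conv_out(126, 1, stride=2))  # 63
```

Because the stem convolution uses no padding, a 512x512 input becomes 253x253 rather than the 256x256 a padded stem would give. The adaptive average pooling at the end makes the final [batchSize, 512] output independent of this detail.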

4.3 Assembling the model

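Build a two-branch network. The EfficientNetB3 branch takes the color fundus image and outputs a tensor of shape [batchSize, 1000]. The ResNet34 branch takes the 3D OCT volume, with its 256 slices stacked as input channels, and outputs a tensor of shape [batchSize, 512]. The two outputs are concatenated along the channel dimension (axis 1) into a [batchSize, 1512] tensor, which a single fully connected layer maps to the final prediction over the 3 classes.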

class Model(nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        self.fundus_branch = EfficientNetB3(in_channels=3, num_class=1000)
        self.oct_branch = ResNet34(in_channels=256)
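        # Fully connected layer over the concatenated features
        # (1000 + 512 = 1512); the final number of classes is 3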
        self.decision_branch = nn.Linear(1512, 3)

    def forward(self, fundus_img, oct_img):
        b1 = self.fundus_branch(fundus_img)
        b2 = self.oct_branch(oct_img)
        logit = self.decision_branch(paddle.concat([b1, b2], 1))

        return logit
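
The shape arithmetic behind the fusion step can be sketched without any framework. This is only an illustration under stated assumptions: `fuse_and_classify` is a hypothetical helper, the feature values are placeholders, and the weight layout (one row per class) is chosen for readability and is not Paddle's internal layout.

```python
FUNDUS_DIM = 1000  # output width of the fundus branch
OCT_DIM = 512      # output width of the OCT branch
NUM_CLASSES = 3


def fuse_and_classify(fundus_feat, oct_feat, weight, bias):
    """Concatenate two feature vectors and apply one linear layer.

    weight holds NUM_CLASSES rows, each with FUNDUS_DIM + OCT_DIM entries.
    """
    fused = list(fundus_feat) + list(oct_feat)
    logits = []
    for row, b in zip(weight, bias):
        if len(row) != len(fused):
            raise ValueError("weight row length must match fused feature length")
        logits.append(sum(f * w for f, w in zip(fused, row)) + b)
    return fused, logits


# Placeholder features for a single sample
fundus_feat = [1.0] * FUNDUS_DIM
oct_feat = [2.0] * OCT_DIM
# Zero weights, so the logits reduce to the bias values
weight = [[0.0] * (FUNDUS_DIM + OCT_DIM) for _ in range(NUM_CLASSES)]
bias = [0.0, 0.5, 1.0]

fused, logits = fuse_and_classify(fundus_feat, oct_feat, weight, bias)
print(len(fused))                              # 1512
print(logits)                                  # [0.0, 0.5, 1.0]
print(len(fused) * NUM_CLASSES + NUM_CLASSES)  # 4539 parameters in the layer
```

The fused length of 1512 is why `decision_branch` is declared as `nn.Linear(1512, 3)`. If either branch's output width changes, that first argument has to change with it.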

4.4 Model visualization

model = Model()
paddle.summary(model, [(1, 3, 256, 256), (1, 256, 512, 512)])
--------------------------------------------------------------------------------
    Layer (type)         Input Shape          Output Shape         Param #    
================================================================================
     Conv2D-167       [[1, 3, 256, 256]]   [1, 38, 128, 128]        1,064     
  BatchNorm2D-128    [[1, 38, 128, 128]]   [1, 38, 128, 128]         152      
     Conv2D-168      [[1, 38, 128, 128]]   [1, 38, 128, 128]        1,482     
  BatchNorm2D-129    [[1, 38, 128, 128]]   [1, 38, 128, 128]         152      
      Swish-59       [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
   ConvBNLayer-40    [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
     Conv2D-169      [[1, 38, 128, 128]]   [1, 38, 128, 128]         380      
  BatchNorm2D-130    [[1, 38, 128, 128]]   [1, 38, 128, 128]         152      
      Swish-60       [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
   ConvBNLayer-41    [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
AdaptiveAvgPool2D-22 [[1, 38, 128, 128]]     [1, 38, 1, 1]            0       
     Conv2D-170        [[1, 38, 1, 1]]        [1, 9, 1, 1]           351      
      Swish-61          [[1, 9, 1, 1]]        [1, 9, 1, 1]            0       
     Conv2D-171         [[1, 9, 1, 1]]       [1, 38, 1, 1]           380      
     Sigmoid-20        [[1, 38, 1, 1]]       [1, 38, 1, 1]            0       
       SE-20         [[1, 38, 128, 128]]   [1, 38, 128, 128]          0       
     Conv2D-172      [[1, 38, 128, 128]]   [1, 19, 128, 128]         741      
  BatchNorm2D-131    [[1, 19, 128, 128]]   [1, 19, 128, 128]         76       
     Dropout-21      [[1, 19, 128, 128]]   [1, 19, 128, 128]          0       
     Conv2D-173      [[1, 38, 128, 128]]   [1, 19, 128, 128]         741      
  BatchNorm2D-132    [[1, 19, 128, 128]]   [1, 19, 128, 128]         76       
     MBConv-20       [[1, 38, 128, 128]]   [1, 19, 128, 128]          0       
     Conv2D-174      [[1, 19, 128, 128]]   [1, 114, 128, 128]       2,280     
  BatchNorm2D-133    [[1, 114, 128, 128]]  [1, 114, 128, 128]        456      
      Swish-62       [[1, 114, 128, 128]]  [1, 114, 128, 128]         0       
   ConvBNLayer-42    [[1, 19, 128, 128]]   [1, 114, 128, 128]         0       
     Conv2D-175      [[1, 114, 128, 128]]  [1, 114, 128, 128]       1,140     
  BatchNorm2D-134    [[1, 114, 128, 128]]  [1, 114, 128, 128]        456      
      Swish-63       [[1, 114, 128, 128]]  [1, 114, 128, 128]         0       
   ConvBNLayer-43    [[1, 114, 128, 128]]  [1, 114, 128, 128]         0       
AdaptiveAvgPool2D-23 [[1, 114, 128, 128]]    [1, 114, 1, 1]           0       
     Conv2D-176        [[1, 114, 1, 1]]      [1, 28, 1, 1]          3,220     
      Swish-64         [[1, 28, 1, 1]]       [1, 28, 1, 1]            0       
     Conv2D-177        [[1, 28, 1, 1]]       [1, 114, 1, 1]         3,306     
     Sigmoid-21        [[1, 114, 1, 1]]      [1, 114, 1, 1]           0       
       SE-21         [[1, 114, 128, 128]]  [1, 114, 128, 128]         0       
     Conv2D-178      [[1, 114, 128, 128]]  [1, 28, 128, 128]        3,220     
  BatchNorm2D-135    [[1, 28, 128, 128]]   [1, 28, 128, 128]         112      
     Dropout-22      [[1, 28, 128, 128]]   [1, 28, 128, 128]          0       
     Conv2D-179      [[1, 19, 128, 128]]   [1, 28, 128, 128]         560      
  BatchNorm2D-136    [[1, 28, 128, 128]]   [1, 28, 128, 128]         112      
     MBConv-21       [[1, 19, 128, 128]]   [1, 28, 128, 128]          0       
     Conv2D-180      [[1, 28, 128, 128]]   [1, 168, 128, 128]       4,872     
  BatchNorm2D-137    [[1, 168, 128, 128]]  [1, 168, 128, 128]        672      
      Swish-65       [[1, 168, 128, 128]]  [1, 168, 128, 128]         0       
   ConvBNLayer-44    [[1, 28, 128, 128]]   [1, 168, 128, 128]         0       
     Conv2D-181      [[1, 168, 128, 128]]   [1, 168, 64, 64]        1,680     
  BatchNorm2D-138     [[1, 168, 64, 64]]    [1, 168, 64, 64]         672      
      Swish-66        [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
   ConvBNLayer-45    [[1, 168, 128, 128]]   [1, 168, 64, 64]          0       
AdaptiveAvgPool2D-24  [[1, 168, 64, 64]]     [1, 168, 1, 1]           0       
     Conv2D-182        [[1, 168, 1, 1]]      [1, 42, 1, 1]          7,098     
      Swish-67         [[1, 42, 1, 1]]       [1, 42, 1, 1]            0       
     Conv2D-183        [[1, 42, 1, 1]]       [1, 168, 1, 1]         7,224     
     Sigmoid-22        [[1, 168, 1, 1]]      [1, 168, 1, 1]           0       
       SE-22          [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
     Conv2D-184       [[1, 168, 64, 64]]    [1, 28, 64, 64]         4,732     
  BatchNorm2D-139     [[1, 28, 64, 64]]     [1, 28, 64, 64]          112      
     Dropout-23       [[1, 28, 64, 64]]     [1, 28, 64, 64]           0       
     MBConv-22       [[1, 28, 128, 128]]    [1, 28, 64, 64]           0       
     Conv2D-186       [[1, 28, 64, 64]]     [1, 168, 64, 64]        4,872     
  BatchNorm2D-141     [[1, 168, 64, 64]]    [1, 168, 64, 64]         672      
      Swish-68        [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
   ConvBNLayer-46     [[1, 28, 64, 64]]     [1, 168, 64, 64]          0       
     Conv2D-187       [[1, 168, 64, 64]]    [1, 168, 64, 64]        4,368     
  BatchNorm2D-142     [[1, 168, 64, 64]]    [1, 168, 64, 64]         672      
      Swish-69        [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
   ConvBNLayer-47     [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
AdaptiveAvgPool2D-25  [[1, 168, 64, 64]]     [1, 168, 1, 1]           0       
     Conv2D-188        [[1, 168, 1, 1]]      [1, 42, 1, 1]          7,098     
      Swish-70         [[1, 42, 1, 1]]       [1, 42, 1, 1]            0       
     Conv2D-189        [[1, 42, 1, 1]]       [1, 168, 1, 1]         7,224     
     Sigmoid-23        [[1, 168, 1, 1]]      [1, 168, 1, 1]           0       
       SE-23          [[1, 168, 64, 64]]    [1, 168, 64, 64]          0       
     Conv2D-190       [[1, 168, 64, 64]]    [1, 48, 64, 64]         8,112     
  BatchNorm2D-143     [[1, 48, 64, 64]]     [1, 48, 64, 64]          192      
     Dropout-24       [[1, 48, 64, 64]]     [1, 48, 64, 64]           0       
     Conv2D-191       [[1, 28, 64, 64]]     [1, 48, 64, 64]         1,392     
  BatchNorm2D-144     [[1, 48, 64, 64]]     [1, 48, 64, 64]          192      
     MBConv-23        [[1, 28, 64, 64]]     [1, 48, 64, 64]           0       
     Conv2D-192       [[1, 48, 64, 64]]     [1, 288, 64, 64]       14,112     
  BatchNorm2D-145     [[1, 288, 64, 64]]    [1, 288, 64, 64]        1,152     
      Swish-71        [[1, 288, 64, 64]]    [1, 288, 64, 64]          0       
   ConvBNLayer-48     [[1, 48, 64, 64]]     [1, 288, 64, 64]          0       
     Conv2D-193       [[1, 288, 64, 64]]    [1, 288, 32, 32]        7,488     
  BatchNorm2D-146     [[1, 288, 32, 32]]    [1, 288, 32, 32]        1,152     
      Swish-72        [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
   ConvBNLayer-49     [[1, 288, 64, 64]]    [1, 288, 32, 32]          0       
AdaptiveAvgPool2D-26  [[1, 288, 32, 32]]     [1, 288, 1, 1]           0       
     Conv2D-194        [[1, 288, 1, 1]]      [1, 72, 1, 1]         20,808     
      Swish-73         [[1, 72, 1, 1]]       [1, 72, 1, 1]            0       
     Conv2D-195        [[1, 72, 1, 1]]       [1, 288, 1, 1]        21,024     
     Sigmoid-24        [[1, 288, 1, 1]]      [1, 288, 1, 1]           0       
       SE-24          [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
     Conv2D-196       [[1, 288, 32, 32]]    [1, 48, 32, 32]        13,872     
  BatchNorm2D-147     [[1, 48, 32, 32]]     [1, 48, 32, 32]          192      
     Dropout-25       [[1, 48, 32, 32]]     [1, 48, 32, 32]           0       
     MBConv-24        [[1, 48, 64, 64]]     [1, 48, 32, 32]           0       
     Conv2D-198       [[1, 48, 32, 32]]     [1, 288, 32, 32]       14,112     
  BatchNorm2D-149     [[1, 288, 32, 32]]    [1, 288, 32, 32]        1,152     
      Swish-74        [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
   ConvBNLayer-50     [[1, 48, 32, 32]]     [1, 288, 32, 32]          0       
     Conv2D-199       [[1, 288, 32, 32]]    [1, 288, 32, 32]        2,880     
  BatchNorm2D-150     [[1, 288, 32, 32]]    [1, 288, 32, 32]        1,152     
      Swish-75        [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
   ConvBNLayer-51     [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
AdaptiveAvgPool2D-27  [[1, 288, 32, 32]]     [1, 288, 1, 1]           0       
     Conv2D-200        [[1, 288, 1, 1]]      [1, 72, 1, 1]         20,808     
      Swish-76         [[1, 72, 1, 1]]       [1, 72, 1, 1]            0       
     Conv2D-201        [[1, 72, 1, 1]]       [1, 288, 1, 1]        21,024     
     Sigmoid-25        [[1, 288, 1, 1]]      [1, 288, 1, 1]           0       
       SE-25          [[1, 288, 32, 32]]    [1, 288, 32, 32]          0       
     Conv2D-202       [[1, 288, 32, 32]]    [1, 96, 32, 32]        27,744     
  BatchNorm2D-151     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     Dropout-26       [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-203       [[1, 48, 32, 32]]     [1, 96, 32, 32]         4,704     
  BatchNorm2D-152     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     MBConv-25        [[1, 48, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-204       [[1, 96, 32, 32]]     [1, 576, 32, 32]       55,872     
  BatchNorm2D-153     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-77        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-52     [[1, 96, 32, 32]]     [1, 576, 32, 32]          0       
     Conv2D-205       [[1, 576, 32, 32]]    [1, 576, 32, 32]        5,760     
  BatchNorm2D-154     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-78        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-53     [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
AdaptiveAvgPool2D-28  [[1, 576, 32, 32]]     [1, 576, 1, 1]           0       
     Conv2D-206        [[1, 576, 1, 1]]      [1, 144, 1, 1]        83,088     
      Swish-79         [[1, 144, 1, 1]]      [1, 144, 1, 1]           0       
     Conv2D-207        [[1, 144, 1, 1]]      [1, 576, 1, 1]        83,520     
     Sigmoid-26        [[1, 576, 1, 1]]      [1, 576, 1, 1]           0       
       SE-26          [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
     Conv2D-208       [[1, 576, 32, 32]]    [1, 96, 32, 32]        55,392     
  BatchNorm2D-155     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     Dropout-27       [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-209       [[1, 96, 32, 32]]     [1, 96, 32, 32]         9,312     
  BatchNorm2D-156     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     MBConv-26        [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-210       [[1, 96, 32, 32]]     [1, 576, 32, 32]       55,872     
  BatchNorm2D-157     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-80        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-54     [[1, 96, 32, 32]]     [1, 576, 32, 32]          0       
     Conv2D-211       [[1, 576, 32, 32]]    [1, 576, 32, 32]        5,760     
  BatchNorm2D-158     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-81        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-55     [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
AdaptiveAvgPool2D-29  [[1, 576, 32, 32]]     [1, 576, 1, 1]           0       
     Conv2D-212        [[1, 576, 1, 1]]      [1, 144, 1, 1]        83,088     
      Swish-82         [[1, 144, 1, 1]]      [1, 144, 1, 1]           0       
     Conv2D-213        [[1, 144, 1, 1]]      [1, 576, 1, 1]        83,520     
     Sigmoid-27        [[1, 576, 1, 1]]      [1, 576, 1, 1]           0       
       SE-27          [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
     Conv2D-214       [[1, 576, 32, 32]]    [1, 96, 32, 32]        55,392     
  BatchNorm2D-159     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     Dropout-28       [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-215       [[1, 96, 32, 32]]     [1, 96, 32, 32]         9,312     
  BatchNorm2D-160     [[1, 96, 32, 32]]     [1, 96, 32, 32]          384      
     MBConv-27        [[1, 96, 32, 32]]     [1, 96, 32, 32]           0       
     Conv2D-216       [[1, 96, 32, 32]]     [1, 576, 32, 32]       55,872     
  BatchNorm2D-161     [[1, 576, 32, 32]]    [1, 576, 32, 32]        2,304     
      Swish-83        [[1, 576, 32, 32]]    [1, 576, 32, 32]          0       
   ConvBNLayer-56     [[1, 96, 32, 32]]     [1, 576, 32, 32]          0       
     Conv2D-217       [[1, 576, 32, 32]]    [1, 576, 16, 16]        5,760     
  BatchNorm2D-162     [[1, 576, 16, 16]]    [1, 576, 16, 16]        2,304     
      Swish-84        [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
   ConvBNLayer-57     [[1, 576, 32, 32]]    [1, 576, 16, 16]          0       
AdaptiveAvgPool2D-30  [[1, 576, 16, 16]]     [1, 576, 1, 1]           0       
     Conv2D-218        [[1, 576, 1, 1]]      [1, 144, 1, 1]        83,088     
      Swish-85         [[1, 144, 1, 1]]      [1, 144, 1, 1]           0       
     Conv2D-219        [[1, 144, 1, 1]]      [1, 576, 1, 1]        83,520     
     Sigmoid-28        [[1, 576, 1, 1]]      [1, 576, 1, 1]           0       
       SE-28          [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
     Conv2D-220       [[1, 576, 16, 16]]    [1, 96, 16, 16]        55,392     
  BatchNorm2D-163     [[1, 96, 16, 16]]     [1, 96, 16, 16]          384      
     Dropout-29       [[1, 96, 16, 16]]     [1, 96, 16, 16]           0       
     MBConv-28        [[1, 96, 32, 32]]     [1, 96, 16, 16]           0       
     Conv2D-222       [[1, 96, 16, 16]]     [1, 576, 16, 16]       55,872     
  BatchNorm2D-165     [[1, 576, 16, 16]]    [1, 576, 16, 16]        2,304     
      Swish-86        [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
   ConvBNLayer-58     [[1, 96, 16, 16]]     [1, 576, 16, 16]          0       
     Conv2D-223       [[1, 576, 16, 16]]    [1, 576, 16, 16]       14,976     
  BatchNorm2D-166     [[1, 576, 16, 16]]    [1, 576, 16, 16]        2,304     
      Swish-87        [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
   ConvBNLayer-59     [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
AdaptiveAvgPool2D-31  [[1, 576, 16, 16]]     [1, 576, 1, 1]           0       
     Conv2D-224        [[1, 576, 1, 1]]      [1, 144, 1, 1]        83,088     
      Swish-88         [[1, 144, 1, 1]]      [1, 144, 1, 1]           0       
     Conv2D-225        [[1, 144, 1, 1]]      [1, 576, 1, 1]        83,520     
     Sigmoid-29        [[1, 576, 1, 1]]      [1, 576, 1, 1]           0       
       SE-29          [[1, 576, 16, 16]]    [1, 576, 16, 16]          0       
     Conv2D-226       [[1, 576, 16, 16]]    [1, 134, 16, 16]       77,318     
  BatchNorm2D-167     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     Dropout-30       [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-227       [[1, 96, 16, 16]]     [1, 134, 16, 16]       12,998     
  BatchNorm2D-168     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     MBConv-29        [[1, 96, 16, 16]]     [1, 134, 16, 16]          0       
     Conv2D-228       [[1, 134, 16, 16]]    [1, 804, 16, 16]       108,540    
  BatchNorm2D-169     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-89        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-60     [[1, 134, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-229       [[1, 804, 16, 16]]    [1, 804, 16, 16]       20,904     
  BatchNorm2D-170     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-90        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-61     [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
AdaptiveAvgPool2D-32  [[1, 804, 16, 16]]     [1, 804, 1, 1]           0       
     Conv2D-230        [[1, 804, 1, 1]]      [1, 201, 1, 1]        161,805    
      Swish-91         [[1, 201, 1, 1]]      [1, 201, 1, 1]           0       
     Conv2D-231        [[1, 201, 1, 1]]      [1, 804, 1, 1]        162,408    
     Sigmoid-30        [[1, 804, 1, 1]]      [1, 804, 1, 1]           0       
       SE-30          [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-232       [[1, 804, 16, 16]]    [1, 134, 16, 16]       107,870    
  BatchNorm2D-171     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     Dropout-31       [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-233       [[1, 134, 16, 16]]    [1, 134, 16, 16]       18,090     
  BatchNorm2D-172     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     MBConv-30        [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-234       [[1, 134, 16, 16]]    [1, 804, 16, 16]       108,540    
  BatchNorm2D-173     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-92        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-62     [[1, 134, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-235       [[1, 804, 16, 16]]    [1, 804, 16, 16]       20,904     
  BatchNorm2D-174     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-93        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-63     [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
AdaptiveAvgPool2D-33  [[1, 804, 16, 16]]     [1, 804, 1, 1]           0       
     Conv2D-236        [[1, 804, 1, 1]]      [1, 201, 1, 1]        161,805    
      Swish-94         [[1, 201, 1, 1]]      [1, 201, 1, 1]           0       
     Conv2D-237        [[1, 201, 1, 1]]      [1, 804, 1, 1]        162,408    
     Sigmoid-31        [[1, 804, 1, 1]]      [1, 804, 1, 1]           0       
       SE-31          [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-238       [[1, 804, 16, 16]]    [1, 134, 16, 16]       107,870    
  BatchNorm2D-175     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     Dropout-32       [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-239       [[1, 134, 16, 16]]    [1, 134, 16, 16]       18,090     
  BatchNorm2D-176     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     MBConv-31        [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-240       [[1, 134, 16, 16]]    [1, 804, 16, 16]       108,540    
  BatchNorm2D-177     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-95        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-64     [[1, 134, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-241       [[1, 804, 16, 16]]    [1, 804, 16, 16]       20,904     
  BatchNorm2D-178     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-96        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-65     [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
AdaptiveAvgPool2D-34  [[1, 804, 16, 16]]     [1, 804, 1, 1]           0       
     Conv2D-242        [[1, 804, 1, 1]]      [1, 201, 1, 1]        161,805    
      Swish-97         [[1, 201, 1, 1]]      [1, 201, 1, 1]           0       
     Conv2D-243        [[1, 201, 1, 1]]      [1, 804, 1, 1]        162,408    
     Sigmoid-32        [[1, 804, 1, 1]]      [1, 804, 1, 1]           0       
       SE-32          [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-244       [[1, 804, 16, 16]]    [1, 134, 16, 16]       107,870    
  BatchNorm2D-179     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     Dropout-33       [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-245       [[1, 134, 16, 16]]    [1, 134, 16, 16]       18,090     
  BatchNorm2D-180     [[1, 134, 16, 16]]    [1, 134, 16, 16]         536      
     MBConv-32        [[1, 134, 16, 16]]    [1, 134, 16, 16]          0       
     Conv2D-246       [[1, 134, 16, 16]]    [1, 804, 16, 16]       108,540    
  BatchNorm2D-181     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-98        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-66     [[1, 134, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-247       [[1, 804, 16, 16]]    [1, 804, 16, 16]       20,904     
  BatchNorm2D-182     [[1, 804, 16, 16]]    [1, 804, 16, 16]        3,216     
      Swish-99        [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
   ConvBNLayer-67     [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
AdaptiveAvgPool2D-35  [[1, 804, 16, 16]]     [1, 804, 1, 1]           0       
     Conv2D-248        [[1, 804, 1, 1]]      [1, 201, 1, 1]        161,805    
     Swish-100         [[1, 201, 1, 1]]      [1, 201, 1, 1]           0       
     Conv2D-249        [[1, 201, 1, 1]]      [1, 804, 1, 1]        162,408    
     Sigmoid-33        [[1, 804, 1, 1]]      [1, 804, 1, 1]           0       
       SE-33          [[1, 804, 16, 16]]    [1, 804, 16, 16]          0       
     Conv2D-250       [[1, 804, 16, 16]]    [1, 230, 16, 16]       185,150    
  BatchNorm2D-183     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     Dropout-34       [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-251       [[1, 134, 16, 16]]    [1, 230, 16, 16]       31,050     
  BatchNorm2D-184     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     MBConv-33        [[1, 134, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-252       [[1, 230, 16, 16]]   [1, 1380, 16, 16]       318,780    
  BatchNorm2D-185    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
     Swish-101       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-68     [[1, 230, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-253      [[1, 1380, 16, 16]]   [1, 1380, 16, 16]       35,880     
  BatchNorm2D-186    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
     Swish-102       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-69    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
AdaptiveAvgPool2D-36 [[1, 1380, 16, 16]]    [1, 1380, 1, 1]           0       
     Conv2D-254       [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
     Swish-103         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-255        [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-34       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-34         [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-256      [[1, 1380, 16, 16]]    [1, 230, 16, 16]       317,630    
  BatchNorm2D-187     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     Dropout-35       [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-257       [[1, 230, 16, 16]]    [1, 230, 16, 16]       53,130     
  BatchNorm2D-188     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     MBConv-34        [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-258       [[1, 230, 16, 16]]   [1, 1380, 16, 16]       318,780    
  BatchNorm2D-189    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
     Swish-104       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-70     [[1, 230, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-259      [[1, 1380, 16, 16]]   [1, 1380, 16, 16]       35,880     
  BatchNorm2D-190    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
     Swish-105       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-71    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
AdaptiveAvgPool2D-37 [[1, 1380, 16, 16]]    [1, 1380, 1, 1]           0       
     Conv2D-260       [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
     Swish-106         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-261        [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-35       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-35         [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-262      [[1, 1380, 16, 16]]    [1, 230, 16, 16]       317,630    
  BatchNorm2D-191     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     Dropout-36       [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-263       [[1, 230, 16, 16]]    [1, 230, 16, 16]       53,130     
  BatchNorm2D-192     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     MBConv-35        [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-264       [[1, 230, 16, 16]]   [1, 1380, 16, 16]       318,780    
  BatchNorm2D-193    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
     Swish-107       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-72     [[1, 230, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-265      [[1, 1380, 16, 16]]   [1, 1380, 16, 16]       35,880     
  BatchNorm2D-194    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
     Swish-108       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-73    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
AdaptiveAvgPool2D-38 [[1, 1380, 16, 16]]    [1, 1380, 1, 1]           0       
     Conv2D-266       [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
     Swish-109         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-267        [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-36       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-36         [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-268      [[1, 1380, 16, 16]]    [1, 230, 16, 16]       317,630    
  BatchNorm2D-195     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     Dropout-37       [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-269       [[1, 230, 16, 16]]    [1, 230, 16, 16]       53,130     
  BatchNorm2D-196     [[1, 230, 16, 16]]    [1, 230, 16, 16]         920      
     MBConv-36        [[1, 230, 16, 16]]    [1, 230, 16, 16]          0       
     Conv2D-270       [[1, 230, 16, 16]]   [1, 1380, 16, 16]       318,780    
  BatchNorm2D-197    [[1, 1380, 16, 16]]   [1, 1380, 16, 16]        5,520     
     Swish-110       [[1, 1380, 16, 16]]   [1, 1380, 16, 16]          0       
   ConvBNLayer-74     [[1, 230, 16, 16]]   [1, 1380, 16, 16]          0       
     Conv2D-271      [[1, 1380, 16, 16]]    [1, 1380, 8, 8]        35,880     
  BatchNorm2D-198     [[1, 1380, 8, 8]]     [1, 1380, 8, 8]         5,520     
     Swish-111        [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
   ConvBNLayer-75    [[1, 1380, 16, 16]]    [1, 1380, 8, 8]           0       
AdaptiveAvgPool2D-39  [[1, 1380, 8, 8]]     [1, 1380, 1, 1]           0       
     Conv2D-272       [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
     Swish-112         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-273        [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-37       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-37          [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
     Conv2D-274       [[1, 1380, 8, 8]]      [1, 230, 8, 8]        317,630    
  BatchNorm2D-199      [[1, 230, 8, 8]]      [1, 230, 8, 8]          920      
     Dropout-38        [[1, 230, 8, 8]]      [1, 230, 8, 8]           0       
     MBConv-37        [[1, 230, 16, 16]]     [1, 230, 8, 8]           0       
     Conv2D-276        [[1, 230, 8, 8]]     [1, 1380, 8, 8]        318,780    
  BatchNorm2D-201     [[1, 1380, 8, 8]]     [1, 1380, 8, 8]         5,520     
     Swish-113        [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
   ConvBNLayer-76      [[1, 230, 8, 8]]     [1, 1380, 8, 8]           0       
     Conv2D-277       [[1, 1380, 8, 8]]     [1, 1380, 8, 8]        13,800     
  BatchNorm2D-202     [[1, 1380, 8, 8]]     [1, 1380, 8, 8]         5,520     
     Swish-114        [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
   ConvBNLayer-77     [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
AdaptiveAvgPool2D-40  [[1, 1380, 8, 8]]     [1, 1380, 1, 1]           0       
     Conv2D-278       [[1, 1380, 1, 1]]      [1, 345, 1, 1]        476,445    
     Swish-115         [[1, 345, 1, 1]]      [1, 345, 1, 1]           0       
     Conv2D-279        [[1, 345, 1, 1]]     [1, 1380, 1, 1]        477,480    
     Sigmoid-38       [[1, 1380, 1, 1]]     [1, 1380, 1, 1]           0       
       SE-38          [[1, 1380, 8, 8]]     [1, 1380, 8, 8]           0       
     Conv2D-280       [[1, 1380, 8, 8]]      [1, 384, 8, 8]        530,304    
  BatchNorm2D-203      [[1, 384, 8, 8]]      [1, 384, 8, 8]         1,536     
     Dropout-39        [[1, 384, 8, 8]]      [1, 384, 8, 8]           0       
     Conv2D-281        [[1, 230, 8, 8]]      [1, 384, 8, 8]        88,704     
  BatchNorm2D-204      [[1, 384, 8, 8]]      [1, 384, 8, 8]         1,536     
     MBConv-38         [[1, 230, 8, 8]]      [1, 384, 8, 8]           0       
     Conv2D-282        [[1, 384, 8, 8]]     [1, 1280, 8, 8]        492,800    
  BatchNorm2D-205     [[1, 1280, 8, 8]]     [1, 1280, 8, 8]         5,120     
     Swish-116        [[1, 1280, 8, 8]]     [1, 1280, 8, 8]           0       
   ConvBNLayer-78      [[1, 384, 8, 8]]     [1, 1280, 8, 8]           0       
AdaptiveAvgPool2D-41  [[1, 1280, 8, 8]]     [1, 1280, 1, 1]           0       
     Dropout-40       [[1, 1280, 1, 1]]     [1, 1280, 1, 1]           0       
     Conv2D-283       [[1, 1280, 1, 1]]     [1, 1000, 1, 1]       1,281,000   
 Classifier_Head-2     [[1, 384, 8, 8]]        [1, 1000]              0       
   EfficientNet-2     [[1, 3, 256, 256]]       [1, 1000]              0       
     Conv2D-284      [[1, 256, 512, 512]]  [1, 64, 253, 253]       802,880    
  BatchNorm2D-206    [[1, 64, 253, 253]]   [1, 64, 253, 253]         256      
    MaxPool2D-2      [[1, 64, 253, 253]]   [1, 64, 126, 126]          0       
     Conv2D-285      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
  BatchNorm2D-207    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
     Conv2D-286      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
  BatchNorm2D-208    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
   Basicblock-17     [[1, 64, 126, 126]]   [1, 64, 126, 126]          0       
     Conv2D-288      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
  BatchNorm2D-210    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
     Conv2D-289      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
  BatchNorm2D-211    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
   Basicblock-18     [[1, 64, 126, 126]]   [1, 64, 126, 126]          0       
     Conv2D-291      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
  BatchNorm2D-213    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
     Conv2D-292      [[1, 64, 126, 126]]   [1, 64, 126, 126]       36,928     
  BatchNorm2D-214    [[1, 64, 126, 126]]   [1, 64, 126, 126]         256      
   Basicblock-19     [[1, 64, 126, 126]]   [1, 64, 126, 126]          0       
     Conv2D-294      [[1, 64, 126, 126]]    [1, 128, 63, 63]       73,856     
  BatchNorm2D-216     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-295       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
  BatchNorm2D-217     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-296      [[1, 64, 126, 126]]    [1, 128, 63, 63]        8,320     
  BatchNorm2D-218     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
   Basicblock-20     [[1, 64, 126, 126]]    [1, 128, 63, 63]          0       
     Conv2D-297       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
  BatchNorm2D-219     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-298       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
  BatchNorm2D-220     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
   Basicblock-21      [[1, 128, 63, 63]]    [1, 128, 63, 63]          0       
     Conv2D-300       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
  BatchNorm2D-222     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-301       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
  BatchNorm2D-223     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
   Basicblock-22      [[1, 128, 63, 63]]    [1, 128, 63, 63]          0       
     Conv2D-303       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
  BatchNorm2D-225     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
     Conv2D-304       [[1, 128, 63, 63]]    [1, 128, 63, 63]       147,584    
  BatchNorm2D-226     [[1, 128, 63, 63]]    [1, 128, 63, 63]         512      
   Basicblock-23      [[1, 128, 63, 63]]    [1, 128, 63, 63]          0       
     Conv2D-306       [[1, 128, 63, 63]]    [1, 256, 32, 32]       295,168    
  BatchNorm2D-228     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-307       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-229     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-308       [[1, 128, 63, 63]]    [1, 256, 32, 32]       33,024     
  BatchNorm2D-230     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-24      [[1, 128, 63, 63]]    [1, 256, 32, 32]          0       
     Conv2D-309       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-231     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-310       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-232     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-25      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-312       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-234     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-313       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-235     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-26      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-315       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-237     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-316       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-238     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-27      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-318       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-240     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-319       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-241     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-28      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-321       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-243     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
     Conv2D-322       [[1, 256, 32, 32]]    [1, 256, 32, 32]       590,080    
  BatchNorm2D-244     [[1, 256, 32, 32]]    [1, 256, 32, 32]        1,024     
   Basicblock-29      [[1, 256, 32, 32]]    [1, 256, 32, 32]          0       
     Conv2D-324       [[1, 256, 32, 32]]    [1, 512, 16, 16]      1,180,160   
  BatchNorm2D-246     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
     Conv2D-325       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-247     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
     Conv2D-326       [[1, 256, 32, 32]]    [1, 512, 16, 16]       131,584    
  BatchNorm2D-248     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
   Basicblock-30      [[1, 256, 32, 32]]    [1, 512, 16, 16]          0       
     Conv2D-327       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-249     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
     Conv2D-328       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-250     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
   Basicblock-31      [[1, 512, 16, 16]]    [1, 512, 16, 16]          0       
     Conv2D-330       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-252     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
     Conv2D-331       [[1, 512, 16, 16]]    [1, 512, 16, 16]      2,359,808   
  BatchNorm2D-253     [[1, 512, 16, 16]]    [1, 512, 16, 16]        2,048     
   Basicblock-32      [[1, 512, 16, 16]]    [1, 512, 16, 16]          0       
AdaptiveAvgPool2D-42  [[1, 512, 16, 16]]     [1, 512, 1, 1]           0       
      ResNet-2       [[1, 256, 512, 512]]       [1, 512]              0       
      Linear-1           [[1, 1512]]             [1, 3]             4,539     
================================================================================
Total params: 36,436,367
Trainable params: 36,353,075
Non-trainable params: 83,292
--------------------------------------------------------------------------------
Input size (MB): 256.75
Forward/backward pass size (MB): 1107.63
Params size (MB): 138.99
Estimated Total Size (MB): 1503.37
--------------------------------------------------------------------------------






{'total_params': 36436367, 'trainable_params': 36353075}

5. Model Training

Training runs for 2000 iterations and takes roughly 8 hours.
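The training and validation code in this section scores predictions with `cohen_kappa_score(..., weights="quadratic")`. As a rough illustration of what that metric measures, the pure-Python sketch below computes quadratic-weighted kappa for the three grades. The function name `quadratic_weighted_kappa` is hypothetical and not part of the tutorial, which relies on scikit-learn's implementation; returning NaN in the degenerate case is a choice made for this sketch.

```python
def quadratic_weighted_kappa(y_true, y_pred, num_classes=3):
    """Quadratic-weighted Cohen's kappa for integer labels in [0, num_classes)."""
    n = len(y_true)
    # Observed confusion matrix: rows are true labels, columns are predictions
    observed = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1
    true_hist = [sum(row) for row in observed]
    pred_hist = [
        sum(observed[i][j] for i in range(num_classes)) for j in range(num_classes)
    ]

    weighted_observed = 0.0
    weighted_expected = 0.0
    for i in range(num_classes):
        for j in range(num_classes):
            # The penalty grows with the squared distance between the two grades
            weight = (i - j) ** 2
            # Count expected by chance if labels and predictions were independent
            expected = true_hist[i] * pred_hist[j] / n
            weighted_observed += weight * observed[i][j]
            weighted_expected += weight * expected
    if weighted_expected == 0:
        # Degenerate case (a single class everywhere): kappa is undefined
        return float("nan")
    return 1.0 - weighted_observed / weighted_expected


# Confusing "non" with "early" costs less than confusing "non" with "mid_advanced"
near_miss = quadratic_weighted_kappa([0, 0, 2, 2], [0, 1, 2, 2])
far_miss = quadratic_weighted_kappa([0, 0, 2, 2], [0, 2, 2, 2])
print(near_miss > far_miss)  # True
```

Because the grades are ordered, this weighting rewards a model whose mistakes fall in a neighbouring grade, which plain accuracy does not.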

# Training logic
def train(
    model,
    iters,
    train_dataloader,
    val_dataloader,
    optimizer,
    criterion,
    log_interval,
    eval_interval,
):
    step = 0  # named `step` to avoid shadowing the built-in iter()
    model.train()
    avg_loss_list = []
    avg_kappa_list = []
    # A checkpoint is saved only once validation kappa reaches at least 0.5
    best_kappa = 0.5
    while step < iters:
        for data in train_dataloader:
            step += 1
            if step > iters:
                break
            # Scale pixel values to [0, 1]
            fundus_imgs = (data[0] / 255.0).astype("float32")
            oct_imgs = (data[1] / 255.0).astype("float32")
            labels = data[2].astype("int64")
            logits = model(fundus_imgs, oct_imgs)
            loss = criterion(logits, labels)
            # Collect (prediction, label) pairs for the running kappa score
            for pred, label in zip(logits.numpy().argmax(1), labels.numpy()):
                avg_kappa_list.append([pred, label])

            loss.backward()
            optimizer.step()

            model.clear_gradients()
            # float() works whether the loss tensor has shape [1] or is 0-D
            avg_loss_list.append(float(loss))

            if step % log_interval == 0:
                avg_loss = np.array(avg_loss_list).mean()
                avg_kappa_list = np.array(avg_kappa_list)
                # Compute the quadratic-weighted Cohen's kappa score
                avg_kappa = cohen_kappa_score(
                    avg_kappa_list[:, 0],
                    avg_kappa_list[:, 1],
                    weights="quadratic",
                )
                avg_loss_list = []
                avg_kappa_list = []
                print(
                    "[TRAIN] iter={}/{} avg_loss={:.4f} avg_kappa={:.4f}".format(
                        step, iters, avg_loss, avg_kappa
                    )
                )

            if step % eval_interval == 0:
                avg_loss, avg_kappa = val(model, val_dataloader, criterion)
                print(
                    "[EVAL] iter={}/{} avg_loss={:.4f} kappa={:.4f}".format(
                        step, iters, avg_loss, avg_kappa
                    )
                )
                # Save the model whenever validation kappa matches or beats the best so far
                if avg_kappa >= best_kappa:
                    best_kappa = avg_kappa
                    paddle.save(
                        model.state_dict(),
                        os.path.join(
                            "best_model_{:.4f}".format(best_kappa),
                            "model.pdparams",
                        ),
                    )
                # val() switches the model to eval mode; switch back for training
                model.train()


# Validation logic
def val(model, val_dataloader, criterion):
    model.eval()
    avg_loss_list = []
    cache = []
    with paddle.no_grad():
        for data in val_dataloader:
            fundus_imgs = (data[0] / 255.0).astype("float32")
            oct_imgs = (data[1] / 255.0).astype("float32")
            labels = data[2].astype("int64")

            logits = model(fundus_imgs, oct_imgs)
            for pred, label in zip(logits.numpy().argmax(1), labels.numpy()):
                cache.append([pred, label])

            loss = criterion(logits, labels)
            # float() works whether the loss tensor has shape [1] or is 0-D
            avg_loss_list.append(float(loss))

    cache = np.array(cache)
    kappa = cohen_kappa_score(cache[:, 0], cache[:, 1], weights="quadratic")
    avg_loss = np.array(avg_loss_list).mean()

    return avg_loss, kappa


# Number of worker processes used for data loading
num_workers = 4

# Batch size
batchsize = 2

# Total number of training iterations
iters = 2000

# Optimizer type
optimizer_type = "adam"

# Initial learning rate
init_lr = 1e-3

# Dataset
train_dataset = GAMMA_sub1_dataset(
    dataset_root=trainset_root,
    img_transforms=img_train_transforms,
    oct_transforms=oct_train_transforms,
    filelists=train_filelists,
    label_file=gt_file,
)

val_dataset = GAMMA_sub1_dataset(
    dataset_root=trainset_root,
    img_transforms=img_val_transforms,
    oct_transforms=oct_val_transforms,
    filelists=val_filelists,
    label_file=gt_file,
)

# DataLoader
train_loader = paddle.io.DataLoader(
    train_dataset,
    num_workers=num_workers,
    batch_size=batchsize,
    shuffle=True,
    return_list=True,
    use_shared_memory=False,
)

val_loader = paddle.io.DataLoader(
    val_dataset,
    batch_size=batchsize,
    num_workers=num_workers,
    return_list=True,
    use_shared_memory=False,
)

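if optimizer_type == "adam":
    optimizer = paddle.optimizer.Adam(init_lr, parameters=model.parameters())
else:
    raise ValueError("Unsupported optimizer_type: {}".format(optimizer_type))

# ----- Use cross-entropy as the loss function -----#
criterion = nn.CrossEntropyLoss()

# ----- Train the model -----#
# log_interval=10 matches the output that follows, where [TRAIN] lines appear every 10 iterations
train(
    model,
    iters,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    log_interval=10,
    eval_interval=50,
)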
[TRAIN] iter=10/2000 avg_loss=2.1748 avg_kappa=-0.1957
[TRAIN] iter=20/2000 avg_loss=1.7035 avg_kappa=0.0774
[TRAIN] iter=30/2000 avg_loss=0.9511 avg_kappa=0.7015
[TRAIN] iter=40/2000 avg_loss=1.2922 avg_kappa=0.1606
[TRAIN] iter=50/2000 avg_loss=1.1163 avg_kappa=0.0000
[EVAL] iter=50/2000 avg_loss=10.3795 kappa=-0.2698
[TRAIN] iter=60/2000 avg_loss=1.1850 avg_kappa=-0.0606
[TRAIN] iter=70/2000 avg_loss=1.1729 avg_kappa=-0.2185
[TRAIN] iter=80/2000 avg_loss=1.0922 avg_kappa=0.1093
[TRAIN] iter=90/2000 avg_loss=0.9435 avg_kappa=0.2708
[TRAIN] iter=100/2000 avg_loss=1.2347 avg_kappa=0.0299
[EVAL] iter=100/2000 avg_loss=1.0698 kappa=0.3558
[TRAIN] iter=110/2000 avg_loss=1.0827 avg_kappa=0.2500
[TRAIN] iter=120/2000 avg_loss=1.0206 avg_kappa=-0.1354
[TRAIN] iter=130/2000 avg_loss=1.4113 avg_kappa=0.0968
[TRAIN] iter=140/2000 avg_loss=0.9458 avg_kappa=0.0625
[TRAIN] iter=150/2000 avg_loss=1.1128 avg_kappa=0.0000
[EVAL] iter=150/2000 avg_loss=2.1132 kappa=0.2978
[TRAIN] iter=160/2000 avg_loss=1.0617 avg_kappa=0.3791
[TRAIN] iter=170/2000 avg_loss=1.0624 avg_kappa=0.2808
[TRAIN] iter=180/2000 avg_loss=0.9705 avg_kappa=0.3561
[TRAIN] iter=190/2000 avg_loss=0.9573 avg_kappa=0.1078
[TRAIN] iter=200/2000 avg_loss=0.9949 avg_kappa=0.1176
[EVAL] iter=200/2000 avg_loss=2.2989 kappa=0.2982
[TRAIN] iter=210/2000 avg_loss=1.0670 avg_kappa=-0.0870
[TRAIN] iter=220/2000 avg_loss=0.8388 avg_kappa=-0.0596
[TRAIN] iter=230/2000 avg_loss=1.1535 avg_kappa=0.0625
[TRAIN] iter=240/2000 avg_loss=1.0679 avg_kappa=0.5335
[TRAIN] iter=250/2000 avg_loss=0.9604 avg_kappa=0.4318
[EVAL] iter=250/2000 avg_loss=1.2917 kappa=0.3895
[TRAIN] iter=260/2000 avg_loss=0.8026 avg_kappa=0.1964
[TRAIN] iter=270/2000 avg_loss=0.6450 avg_kappa=0.4536
[TRAIN] iter=280/2000 avg_loss=1.1303 avg_kappa=0.3511
[TRAIN] iter=290/2000 avg_loss=1.0567 avg_kappa=0.4462
[TRAIN] iter=300/2000 avg_loss=1.1012 avg_kappa=0.2647
[EVAL] iter=300/2000 avg_loss=1.0134 kappa=0.5455
[TRAIN] iter=310/2000 avg_loss=1.3930 avg_kappa=-0.3462
[TRAIN] iter=320/2000 avg_loss=0.9429 avg_kappa=0.0050
[TRAIN] iter=330/2000 avg_loss=0.8889 avg_kappa=0.0000
[TRAIN] iter=340/2000 avg_loss=1.1127 avg_kappa=0.1558
[TRAIN] iter=350/2000 avg_loss=1.0205 avg_kappa=0.1822
[EVAL] iter=350/2000 avg_loss=1.5745 kappa=0.0813
[TRAIN] iter=360/2000 avg_loss=1.2648 avg_kappa=0.0000
[TRAIN] iter=370/2000 avg_loss=1.0565 avg_kappa=-0.0714
[TRAIN] iter=380/2000 avg_loss=0.8164 avg_kappa=0.2857
[TRAIN] iter=390/2000 avg_loss=1.0671 avg_kappa=0.0222
[TRAIN] iter=400/2000 avg_loss=1.0092 avg_kappa=0.5305
[EVAL] iter=400/2000 avg_loss=1.0891 kappa=0.2982
[TRAIN] iter=410/2000 avg_loss=0.8200 avg_kappa=0.5745
[TRAIN] iter=420/2000 avg_loss=0.9284 avg_kappa=0.3827
[TRAIN] iter=430/2000 avg_loss=0.9759 avg_kappa=0.1497
[TRAIN] iter=440/2000 avg_loss=0.9261 avg_kappa=0.2188
[TRAIN] iter=450/2000 avg_loss=0.7655 avg_kappa=0.7500
[EVAL] iter=450/2000 avg_loss=1.0015 kappa=0.3229
[TRAIN] iter=460/2000 avg_loss=1.0300 avg_kappa=0.4271
[TRAIN] iter=470/2000 avg_loss=1.1922 avg_kappa=0.0189
[TRAIN] iter=480/2000 avg_loss=1.0318 avg_kappa=0.5122
[TRAIN] iter=490/2000 avg_loss=0.6496 avg_kappa=0.6269
[TRAIN] iter=500/2000 avg_loss=0.7465 avg_kappa=0.6154
[EVAL] iter=500/2000 avg_loss=0.9899 kappa=0.4842
[TRAIN] iter=510/2000 avg_loss=1.1083 avg_kappa=0.1290
[TRAIN] iter=520/2000 avg_loss=1.0746 avg_kappa=0.2053
[TRAIN] iter=530/2000 avg_loss=0.8995 avg_kappa=0.3421
[TRAIN] iter=540/2000 avg_loss=1.0162 avg_kappa=0.4419
[TRAIN] iter=550/2000 avg_loss=0.7716 avg_kappa=0.4583
[EVAL] iter=550/2000 avg_loss=0.8828 kappa=0.1722
[TRAIN] iter=560/2000 avg_loss=0.7626 avg_kappa=0.4754
[TRAIN] iter=570/2000 avg_loss=1.2483 avg_kappa=0.3625
[TRAIN] iter=580/2000 avg_loss=1.0986 avg_kappa=0.2565
[TRAIN] iter=590/2000 avg_loss=0.7910 avg_kappa=0.7143
[TRAIN] iter=600/2000 avg_loss=0.9362 avg_kappa=0.2917
[EVAL] iter=600/2000 avg_loss=0.8705 kappa=0.3438
[TRAIN] iter=610/2000 avg_loss=0.9716 avg_kappa=0.1176
[TRAIN] iter=620/2000 avg_loss=0.9979 avg_kappa=-0.0191
[TRAIN] iter=630/2000 avg_loss=0.6981 avg_kappa=0.6535
[TRAIN] iter=640/2000 avg_loss=0.7475 avg_kappa=0.4444
[TRAIN] iter=650/2000 avg_loss=1.0678 avg_kappa=0.6835
[EVAL] iter=650/2000 avg_loss=1.3692 kappa=0.2368
[TRAIN] iter=660/2000 avg_loss=0.8503 avg_kappa=0.5161
[TRAIN] iter=670/2000 avg_loss=1.0898 avg_kappa=0.2460
[TRAIN] iter=680/2000 avg_loss=0.8219 avg_kappa=0.1632
[TRAIN] iter=690/2000 avg_loss=1.0924 avg_kappa=0.6667
[TRAIN] iter=700/2000 avg_loss=0.9786 avg_kappa=0.3784
[EVAL] iter=700/2000 avg_loss=2.2221 kappa=0.3046
[TRAIN] iter=710/2000 avg_loss=1.1610 avg_kappa=0.3182
[TRAIN] iter=720/2000 avg_loss=1.0428 avg_kappa=0.1453
[TRAIN] iter=730/2000 avg_loss=0.8828 avg_kappa=0.4500
[TRAIN] iter=740/2000 avg_loss=1.0252 avg_kappa=0.2105
[TRAIN] iter=750/2000 avg_loss=0.8088 avg_kappa=0.3443
[EVAL] iter=750/2000 avg_loss=2.0567 kappa=0.0813
[TRAIN] iter=760/2000 avg_loss=0.9734 avg_kappa=0.1579
[TRAIN] iter=770/2000 avg_loss=0.9835 avg_kappa=0.3636
[TRAIN] iter=780/2000 avg_loss=0.7815 avg_kappa=0.2568
[TRAIN] iter=790/2000 avg_loss=0.8366 avg_kappa=0.7799
[TRAIN] iter=800/2000 avg_loss=0.9148 avg_kappa=0.4860
[EVAL] iter=800/2000 avg_loss=0.7692 kappa=0.5909
[TRAIN] iter=810/2000 avg_loss=1.0086 avg_kappa=0.2500
[TRAIN] iter=820/2000 avg_loss=0.6844 avg_kappa=0.5588
[TRAIN] iter=830/2000 avg_loss=0.7571 avg_kappa=0.5455
[TRAIN] iter=840/2000 avg_loss=0.9979 avg_kappa=0.0415
[TRAIN] iter=850/2000 avg_loss=0.9125 avg_kappa=0.2029
[EVAL] iter=850/2000 avg_loss=1.3723 kappa=0.0000
[TRAIN] iter=860/2000 avg_loss=0.9698 avg_kappa=0.2143
[TRAIN] iter=870/2000 avg_loss=0.7990 avg_kappa=0.3902
[TRAIN] iter=880/2000 avg_loss=0.7721 avg_kappa=0.5098
[TRAIN] iter=890/2000 avg_loss=0.9212 avg_kappa=-0.0606
[TRAIN] iter=900/2000 avg_loss=0.8273 avg_kappa=0.2574
[EVAL] iter=900/2000 avg_loss=0.9112 kappa=0.3657
[TRAIN] iter=910/2000 avg_loss=0.8701 avg_kappa=0.6154
[TRAIN] iter=920/2000 avg_loss=0.9726 avg_kappa=0.1818
[TRAIN] iter=930/2000 avg_loss=0.8154 avg_kappa=0.4834
[TRAIN] iter=940/2000 avg_loss=1.0587 avg_kappa=0.4000
[TRAIN] iter=950/2000 avg_loss=1.1157 avg_kappa=0.3092
[EVAL] iter=950/2000 avg_loss=1.0636 kappa=0.3537
[TRAIN] iter=960/2000 avg_loss=0.8583 avg_kappa=0.5882
[TRAIN] iter=970/2000 avg_loss=1.0570 avg_kappa=0.1364
[TRAIN] iter=980/2000 avg_loss=1.1619 avg_kappa=0.2674
[TRAIN] iter=990/2000 avg_loss=0.8051 avg_kappa=0.6685
[TRAIN] iter=1000/2000 avg_loss=0.8298 avg_kappa=0.2609
[EVAL] iter=1000/2000 avg_loss=1.8720 kappa=0.0000
[TRAIN] iter=1010/2000 avg_loss=0.8890 avg_kappa=0.1304
[TRAIN] iter=1020/2000 avg_loss=0.7587 avg_kappa=0.6939
[TRAIN] iter=1030/2000 avg_loss=0.9327 avg_kappa=0.3831
[TRAIN] iter=1040/2000 avg_loss=1.0000 avg_kappa=0.4643
[TRAIN] iter=1050/2000 avg_loss=0.8043 avg_kappa=0.5522
[EVAL] iter=1050/2000 avg_loss=0.9665 kappa=0.2405
[TRAIN] iter=1060/2000 avg_loss=0.7553 avg_kappa=0.5370
[TRAIN] iter=1070/2000 avg_loss=0.8057 avg_kappa=0.2017
[TRAIN] iter=1080/2000 avg_loss=0.8780 avg_kappa=0.3182
[TRAIN] iter=1090/2000 avg_loss=0.9831 avg_kappa=0.2279
[TRAIN] iter=1100/2000 avg_loss=0.9947 avg_kappa=0.2083
[EVAL] iter=1100/2000 avg_loss=0.8967 kappa=0.7635
[TRAIN] iter=1110/2000 avg_loss=0.9275 avg_kappa=0.0809
[TRAIN] iter=1120/2000 avg_loss=1.0070 avg_kappa=-0.0811
[TRAIN] iter=1130/2000 avg_loss=0.8866 avg_kappa=0.4444
[TRAIN] iter=1140/2000 avg_loss=0.7797 avg_kappa=0.4681
[TRAIN] iter=1150/2000 avg_loss=0.6757 avg_kappa=0.4811
[EVAL] iter=1150/2000 avg_loss=1.6824 kappa=0.2553
[TRAIN] iter=1160/2000 avg_loss=0.8656 avg_kappa=0.3966
[TRAIN] iter=1170/2000 avg_loss=0.5916 avg_kappa=0.6970
[TRAIN] iter=1180/2000 avg_loss=1.1541 avg_kappa=-0.1677
[TRAIN] iter=1190/2000 avg_loss=0.7411 avg_kappa=0.6552
[TRAIN] iter=1200/2000 avg_loss=1.1016 avg_kappa=0.4545
[EVAL] iter=1200/2000 avg_loss=1.0677 kappa=0.4091
[TRAIN] iter=1210/2000 avg_loss=0.9788 avg_kappa=0.4366
[TRAIN] iter=1220/2000 avg_loss=0.7062 avg_kappa=0.5516
[TRAIN] iter=1230/2000 avg_loss=0.8960 avg_kappa=0.4583
[TRAIN] iter=1240/2000 avg_loss=0.8596 avg_kappa=0.2273
[TRAIN] iter=1250/2000 avg_loss=0.9534 avg_kappa=0.3991
[EVAL] iter=1250/2000 avg_loss=1.0067 kappa=0.1822
[TRAIN] iter=1260/2000 avg_loss=1.0190 avg_kappa=0.3902
[TRAIN] iter=1270/2000 avg_loss=0.7897 avg_kappa=0.4615
[TRAIN] iter=1280/2000 avg_loss=0.6785 avg_kappa=0.6721
[TRAIN] iter=1290/2000 avg_loss=1.1654 avg_kappa=0.0800
[TRAIN] iter=1300/2000 avg_loss=1.1622 avg_kappa=0.4000
[EVAL] iter=1300/2000 avg_loss=4.2278 kappa=0.3046
[TRAIN] iter=1310/2000 avg_loss=0.6945 avg_kappa=-0.1644
[TRAIN] iter=1320/2000 avg_loss=1.1275 avg_kappa=0.2063
[TRAIN] iter=1330/2000 avg_loss=1.1955 avg_kappa=0.0782
[TRAIN] iter=1340/2000 avg_loss=0.8959 avg_kappa=0.2803
[TRAIN] iter=1350/2000 avg_loss=0.6391 avg_kappa=0.5909
[EVAL] iter=1350/2000 avg_loss=1.1239 kappa=0.1900
[TRAIN] iter=1360/2000 avg_loss=0.9569 avg_kappa=0.2713
[TRAIN] iter=1370/2000 avg_loss=1.1284 avg_kappa=0.2898
[TRAIN] iter=1380/2000 avg_loss=0.9595 avg_kappa=0.3030
[TRAIN] iter=1390/2000 avg_loss=0.7360 avg_kappa=0.5690
[TRAIN] iter=1400/2000 avg_loss=0.4813 avg_kappa=0.7500
[EVAL] iter=1400/2000 avg_loss=0.9864 kappa=0.4538
[TRAIN] iter=1410/2000 avg_loss=0.8431 avg_kappa=0.4595
[TRAIN] iter=1420/2000 avg_loss=1.0595 avg_kappa=0.2208
[TRAIN] iter=1430/2000 avg_loss=0.9496 avg_kappa=0.3333
[TRAIN] iter=1440/2000 avg_loss=0.7636 avg_kappa=0.5779
[TRAIN] iter=1450/2000 avg_loss=0.8095 avg_kappa=0.5312
[EVAL] iter=1450/2000 avg_loss=1.5890 kappa=-0.0601
[TRAIN] iter=1460/2000 avg_loss=0.9720 avg_kappa=0.1532
[TRAIN] iter=1470/2000 avg_loss=0.7457 avg_kappa=0.5714
[TRAIN] iter=1480/2000 avg_loss=0.7661 avg_kappa=0.7287
[TRAIN] iter=1490/2000 avg_loss=0.6770 avg_kappa=0.3617
[TRAIN] iter=1500/2000 avg_loss=0.8323 avg_kappa=0.6444
[EVAL] iter=1500/2000 avg_loss=1.3753 kappa=0.0816
[TRAIN] iter=1510/2000 avg_loss=0.6356 avg_kappa=0.5408
[TRAIN] iter=1520/2000 avg_loss=1.1022 avg_kappa=0.1093
[TRAIN] iter=1530/2000 avg_loss=0.9522 avg_kappa=0.3116
[TRAIN] iter=1540/2000 avg_loss=0.9507 avg_kappa=0.2063
[TRAIN] iter=1550/2000 avg_loss=0.6121 avg_kappa=0.7682
[EVAL] iter=1550/2000 avg_loss=0.8446 kappa=0.5434
[TRAIN] iter=1560/2000 avg_loss=0.7899 avg_kappa=0.4643
[TRAIN] iter=1570/2000 avg_loss=0.7781 avg_kappa=0.5769
[TRAIN] iter=1580/2000 avg_loss=0.6827 avg_kappa=0.4397
[TRAIN] iter=1590/2000 avg_loss=0.7548 avg_kappa=0.4545
[TRAIN] iter=1600/2000 avg_loss=0.6444 avg_kappa=0.6983
[EVAL] iter=1600/2000 avg_loss=1.2631 kappa=0.1905
[TRAIN] iter=1610/2000 avg_loss=0.6851 avg_kappa=0.6897
[TRAIN] iter=1620/2000 avg_loss=0.4364 avg_kappa=0.8500
[TRAIN] iter=1630/2000 avg_loss=0.7269 avg_kappa=0.5192
[TRAIN] iter=1640/2000 avg_loss=0.6269 avg_kappa=0.6284
[TRAIN] iter=1650/2000 avg_loss=0.9471 avg_kappa=0.3750
[EVAL] iter=1650/2000 avg_loss=1.1361 kappa=0.1864
[TRAIN] iter=1660/2000 avg_loss=0.9059 avg_kappa=0.4667
[TRAIN] iter=1670/2000 avg_loss=0.6955 avg_kappa=0.7170
[TRAIN] iter=1680/2000 avg_loss=0.7104 avg_kappa=0.5455
[TRAIN] iter=1690/2000 avg_loss=0.9699 avg_kappa=0.5556
[TRAIN] iter=1700/2000 avg_loss=0.5785 avg_kappa=0.6707
[EVAL] iter=1700/2000 avg_loss=1.9127 kappa=-0.0123
[TRAIN] iter=1710/2000 avg_loss=0.9138 avg_kappa=0.5455
[TRAIN] iter=1720/2000 avg_loss=0.6990 avg_kappa=0.3000
[TRAIN] iter=1730/2000 avg_loss=0.5126 avg_kappa=0.5322
[TRAIN] iter=1740/2000 avg_loss=1.3681 avg_kappa=0.0033
[TRAIN] iter=1750/2000 avg_loss=1.0812 avg_kappa=0.3359
[EVAL] iter=1750/2000 avg_loss=1.4875 kappa=-0.3081
[TRAIN] iter=1760/2000 avg_loss=0.7500 avg_kappa=0.5708
[TRAIN] iter=1770/2000 avg_loss=0.5901 avg_kappa=0.7967
[TRAIN] iter=1780/2000 avg_loss=0.6868 avg_kappa=0.8264
[TRAIN] iter=1790/2000 avg_loss=0.7448 avg_kappa=0.6818
[TRAIN] iter=1800/2000 avg_loss=1.1918 avg_kappa=0.2105
[EVAL] iter=1800/2000 avg_loss=2.2048 kappa=-0.0847
[TRAIN] iter=1810/2000 avg_loss=0.5920 avg_kappa=0.7863
[TRAIN] iter=1820/2000 avg_loss=0.7900 avg_kappa=0.7840
[TRAIN] iter=1830/2000 avg_loss=0.6585 avg_kappa=0.5570
[TRAIN] iter=1840/2000 avg_loss=0.6186 avg_kappa=0.7917
[TRAIN] iter=1850/2000 avg_loss=0.5335 avg_kappa=0.6721
[EVAL] iter=1850/2000 avg_loss=0.8442 kappa=0.3822
[TRAIN] iter=1860/2000 avg_loss=0.7680 avg_kappa=0.7143
[TRAIN] iter=1870/2000 avg_loss=0.8156 avg_kappa=0.5139
[TRAIN] iter=1880/2000 avg_loss=0.5862 avg_kappa=0.3590
[TRAIN] iter=1890/2000 avg_loss=1.2277 avg_kappa=0.3810
[TRAIN] iter=1900/2000 avg_loss=1.0219 avg_kappa=0.2965
[EVAL] iter=1900/2000 avg_loss=1.0719 kappa=0.4101
[TRAIN] iter=1910/2000 avg_loss=0.6899 avg_kappa=0.6721
[TRAIN] iter=1920/2000 avg_loss=0.5552 avg_kappa=0.7078
[TRAIN] iter=1930/2000 avg_loss=0.5540 avg_kappa=0.7818
[TRAIN] iter=1940/2000 avg_loss=0.5370 avg_kappa=0.4578
[TRAIN] iter=1950/2000 avg_loss=0.6274 avg_kappa=0.3689
[EVAL] iter=1950/2000 avg_loss=0.9100 kappa=0.2674
[TRAIN] iter=1960/2000 avg_loss=0.9596 avg_kappa=0.5455
[TRAIN] iter=1970/2000 avg_loss=0.7679 avg_kappa=0.6429
[TRAIN] iter=1980/2000 avg_loss=0.6650 avg_kappa=0.7009
[TRAIN] iter=1990/2000 avg_loss=0.6712 avg_kappa=0.6541
[TRAIN] iter=2000/2000 avg_loss=0.4856 avg_kappa=0.8214
[EVAL] iter=2000/2000 avg_loss=1.3298 kappa=0.0813
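
Validation kappa fluctuates strongly in the log above, so picking out the best evaluation step by eye is error-prone. The following is a minimal sketch for extracting it, assuming the log text has been captured into a string; the helper name `best_eval` is hypothetical and not part of the tutorial.

```python
import re

# Matches lines such as: [EVAL] iter=1100/2000 avg_loss=0.8967 kappa=0.7635
EVAL_PATTERN = re.compile(
    r"\[EVAL\] iter=(\d+)/\d+ avg_loss=(-?\d+\.\d+) kappa=(-?\d+\.\d+)"
)


def best_eval(log_text):
    """Return (iteration, avg_loss, kappa) of the [EVAL] line with the highest kappa."""
    records = [
        (int(m.group(1)), float(m.group(2)), float(m.group(3)))
        for m in EVAL_PATTERN.finditer(log_text)
    ]
    if not records:
        raise ValueError("no [EVAL] lines found")
    return max(records, key=lambda record: record[2])


# A few lines copied from the log above
sample_log = """\
[TRAIN] iter=1100/2000 avg_loss=0.9947 avg_kappa=0.2083
[EVAL] iter=1100/2000 avg_loss=0.8967 kappa=0.7635
[EVAL] iter=1150/2000 avg_loss=1.6824 kappa=0.2553
[EVAL] iter=1750/2000 avg_loss=1.4875 kappa=-0.3081
"""
print(best_eval(sample_log))  # (1100, 0.8967, 0.7635)
```

In the full log above, the highest validation kappa is 0.7635 at iteration 1100, while the final iteration scores only 0.0813. This is why the training loop keeps the best checkpoint rather than the last one.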

Note: if a data-loading error occurs while training on AI Studio, the training set folder may contain an extra .ipynb_checkpoints directory. Delete it with rm -r .ipynb_checkpoints.
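
As an alternative to deleting the directory by hand, the file list can skip hidden entries when it is built. The following is a minimal sketch, assuming each sample lives in its own sub-directory of the training root (as the image paths used in this tutorial suggest); `list_sample_dirs` is a hypothetical helper, not part of the original code.

```python
import os


def list_sample_dirs(root):
    """Return the sorted sample directory names under root, skipping hidden entries."""
    return sorted(
        name
        for name in os.listdir(root)
        # Hidden entries such as .ipynb_checkpoints start with a dot
        if not name.startswith(".") and os.path.isdir(os.path.join(root, name))
    )
```

It could be used as `filelists = list_sample_dirs(trainset_root)` in place of `os.listdir(trainset_root)`. Note that sorting changes the order of the list, so the resulting train/validation split may differ from the one produced by the unsorted listing.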

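6. Model Prediction

6.1 Model Prediction

# Load the trained weights.
# Replace this path with the best_model_* directory saved by your own run; the
# training log above, for example, peaks at a validation kappa of 0.7635.
best_model_path = "./best_model_0.9344/model.pdparams"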
model = Model()
para_state_dict = paddle.load(best_model_path)
model.set_state_dict(para_state_dict)
model.eval()

img_test_transforms = trans.Compose([trans.Resize(image_size)])

oct_test_transforms = trans.Compose([trans.CenterCrop(oct_img_size)])

test_dataset = GAMMA_sub1_dataset(
    dataset_root=testset_root,
    img_transforms=img_test_transforms,
    oct_transforms=oct_test_transforms,
    mode="test",
)

cache = []
# Gradients are not needed at inference time
with paddle.no_grad():
    for fundus_img, oct_img, idx in tqdm(test_dataset):
        # Add a batch dimension
        fundus_img = fundus_img[np.newaxis, ...]
        oct_img = oct_img[np.newaxis, ...]

        fundus_img = paddle.to_tensor((fundus_img / 255.0).astype("float32"))
        oct_img = paddle.to_tensor((oct_img / 255.0).astype("float32"))
        logits = model(fundus_img, oct_img)
        # Store the predicted class as a plain int so it can be used as a dict key
        cache.append([idx, int(logits.numpy().argmax(1)[0])])
 15%|█▍        | 9/62 [01:59<11:43, 13.27s/it]Premature end of JPEG file
100%|██████████| 62/62 [13:47<00:00, 13.34s/it]
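
The checkpoint path used above is hard-coded. Because the training loop saves each improved model into a directory named `best_model_<kappa>`, the best checkpoint can also be located programmatically. The following is a minimal sketch; `find_best_checkpoint` is a hypothetical helper, and it assumes the checkpoint directories sit directly under the given root (the current working directory by default).

```python
import os
import re


def find_best_checkpoint(root="."):
    """Return the model.pdparams path in the best_model_<kappa> directory with the highest kappa."""
    pattern = re.compile(r"^best_model_(-?\d+\.\d+)$")
    best = None
    for name in os.listdir(root):
        match = pattern.match(name)
        path = os.path.join(root, name, "model.pdparams")
        # Ignore unrelated entries and directories without a saved weight file
        if match and os.path.isfile(path):
            kappa = float(match.group(1))
            if best is None or kappa > best[0]:
                best = (kappa, path)
    if best is None:
        raise FileNotFoundError("no best_model_* checkpoint found in {!r}".format(root))
    return best[1]
```

The returned path can then be passed to `paddle.load` in place of the hard-coded `best_model_path`.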

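6.2 Displaying Prediction Results

Display the original images together with the predicted class, where non means normal (no glaucoma), early means early-stage glaucoma, and mid_advanced means intermediate or advanced glaucoma.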

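# Mapping from class index to class name
class_map = {0: "non", 1: "early", 2: "mid_advanced"}

# Adjust the spacing between subplots
plt.tight_layout()
plt.subplots_adjust(wspace=0.6, hspace=0.6)
for i in range(6):
    index, pred_class = cache[i]
    # pred_class may be a one-element array; convert it to an int for the dict lookup
    pred_class = int(np.ravel(pred_class)[0])
    fundus_img = cv2.imread(
        "/home/aistudio/Glaucoma_grading/testing/multi-modality_images/{}/{}.jpg".format(
            index, index
        )
    )
    plt.subplot(2, 3, i + 1)
    # OpenCV loads images as BGR; reverse the channel axis to get RGB
    plt.imshow(fundus_img[:, :, ::-1])
    plt.axis("off")
    # Show the sample ID and the predicted class
    plt.title("data:{}, pred:{}".format(index, class_map[pred_class]))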