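Classification Models for fundus_img and oct_img on the GAMMA Multi-Modal Fundus Image Dataset with EfficientNet and ResNet
Author: Tomoko-hjf
Last updated: February 2023
Abstract: This tutorial shows how to classify multi-modal fundus images with a two-branch network built from EfficientNet and ResNet.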
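1. Introduction
Glaucoma is an irreversible eye disease that leads to vision loss, so timely detection and diagnosis are crucial to its treatment.
In this experiment we have clinical data in two modalities: 2D color fundus photographs (left image below) and 3D OCT scan volumes (right image below). The right image shows only one slice of a 3D volume. Each 3D volume actually consists of 256 2D slices.
Our task is to use visual features to grade each sample into one of three classes: no glaucoma, early glaucoma, or moderate-to-advanced glaucoma.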
2. Environment Setup
Import PaddlePaddle and the other packages needed for data processing. This tutorial was written against PaddlePaddle 2.4.0.
# Import packages
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.vision.transforms as trans
import warnings
warnings.filterwarnings('ignore')
print(paddle.__version__)
2.4.0
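3. Dataset
3.1 Data Description
The dataset used in this project comes from Task 1 (multi-modal glaucoma grading) of the GAMMA challenge held at MICCAI 2021. The data grade glaucoma into three classes: non (no glaucoma), early (early glaucoma), and mid_advanced (moderate-to-advanced glaucoma).
The dataset contains 200 data pairs, each covering two clinical imaging modalities: 100 pairs for training and 100 pairs for testing. Each training pair contains a regular 2D color fundus photograph and a 3D optical coherence tomography (OCT) volume. Each 3D volume consists of 256 2D slices.
Here the data are loaded by mounting the official PaddlePaddle dataset.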
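The following sketch shows how the training labels are read, which makes the label format concrete. It follows the loading code in 3.4 and assumes the ground-truth spreadsheet has a data column with the sample ID followed by one one-hot column per grade. The column names non/early/mid_advanced, the three-row table, and the zero-padded folder name 0002 are illustrative stand-ins, not the real file.

```python
import pandas as pd

# Hypothetical miniature of the ground-truth file: sample ID, then one one-hot column per grade
gt = pd.DataFrame({
    "data": [1, 2, 3],
    "non": [1, 0, 0],
    "early": [0, 1, 0],
    "mid_advanced": [0, 0, 1],
})

# Same construction as in GAMMA_sub1_dataset: sample ID -> one-hot vector (all columns after "data")
label = {row["data"]: row.iloc[1:].values for _, row in gt.iterrows()}

# Folder names are assumed to be zero-padded strings; int() strips the padding
folder_name = "0002"
one_hot = label[int(folder_name)]

# argmax turns the one-hot vector into a class index: 0 = non, 1 = early, 2 = mid_advanced
class_index = int(one_hot.argmax())
print(folder_name, one_hot.tolist(), class_index)
```

The dataset class applies the same argmax in __getitem__, so the network receives integer class indices rather than one-hot vectors.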
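3.3 Splitting the Training and Validation Sets
The organizers released only 100 training samples and 100 test samples. To check model accuracy during training, we further split the official training set into a training set and a validation set at a ratio of 8:2.
# Fraction of the official training set held out for validation
val_ratio = 0.2 # 80 / 20
# Root directory of the training data
trainset_root = "Glaucoma_grading/training/multi-modality_images"
# Label file
gt_file = 'Glaucoma_grading/training/glaucoma_grading_training_GT.xlsx'
# Root directory of the test data
testset_root = "Glaucoma_grading/testing/multi-modality_images"
# List the names of all training sample folders
filelists = os.listdir(trainset_root)
# Split by val_ratio (fixed random_state for reproducibility)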
train_filelists, val_filelists = train_test_split(filelists, test_size=val_ratio, random_state=42)
print("Total Nums: {}, train: {}, val: {}".format(len(filelists), len(train_filelists), len(val_filelists)))
Total Nums: 100, train: 80, val: 20
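The split above is purely random. With only 100 samples, a random 20-sample validation set can end up with class proportions noticeably different from those of the training set. One optional variant, which this tutorial does not use, is to pass the grades to the stratify argument of train_test_split. The sketch below uses a made-up 50/30/20 grade distribution purely for illustration. In practice the grades would come from the ground-truth spreadsheet.

```python
from collections import Counter
from sklearn.model_selection import train_test_split

# Hypothetical sample IDs and grades (not the real GAMMA distribution):
# 50 "non" (0), 30 "early" (1), 20 "mid_advanced" (2)
demo_files = [f"{i:04d}" for i in range(1, 101)]
demo_grades = [0] * 50 + [1] * 30 + [2] * 20

# stratify= keeps each grade's share (nearly) the same in both splits
train_files, val_files, train_grades, val_grades = train_test_split(
    demo_files, demo_grades, test_size=0.2, random_state=42, stratify=demo_grades)

print(len(train_files), len(val_files))  # 80 20
print(sorted(Counter(val_grades).items()))
```

Here the validation set keeps the 50/30/20 proportions exactly (10/6/4 samples). With a purely random split, those counts can drift from sample to sample.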
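3.4 Dataset Class Definition
We define our own data-loading class, GAMMA_sub1_dataset. It inherits from PaddlePaddle's paddle.io.Dataset base class and implements the __getitem__() and __len__() methods. __getitem__() returns the image data to feed into the network for a given index, and __len__() returns the size of the dataset.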
class GAMMA_sub1_dataset(paddle.io.Dataset):
    def __init__(self,
                 img_transforms,
                 oct_transforms,
                 dataset_root,
                 label_file='',
                 filelists=None,
                 num_classes=3,
                 mode='train'):
        self.dataset_root = dataset_root
        self.img_transforms = img_transforms
        self.oct_transforms = oct_transforms
        self.mode = mode.lower()
        self.num_classes = num_classes
        # In training mode the labels must be loaded as well
        if self.mode == 'train':
            # Read the labels with pandas; each label is a one-hot row (all columns after 'data')
            label = {row['data']: row.iloc[1:].values for _, row in pd.read_excel(label_file).iterrows()}
            self.file_list = [[f, label[int(f)]] for f in os.listdir(dataset_root)]
        # In test mode there are no labels
        elif self.mode == "test":
            self.file_list = [[f, None] for f in os.listdir(dataset_root)]
        # If filelists is given, keep only those samples
        if filelists is not None:
            self.file_list = [item for item in self.file_list if item[0] in filelists]

    def __getitem__(self, idx):
        # Fetch the sample and its label; real_index is the name of the sample's folder
        real_index, label = self.file_list[idx]
        # Path of the color fundus image
        fundus_img_path = os.path.join(self.dataset_root, real_index, real_index + ".jpg")
        # Paths of the OCT slices (one 3D volume = 256 2D slices), sorted by their numeric file-name prefix
        oct_series_list = sorted(os.listdir(os.path.join(self.dataset_root, real_index, real_index)), key=lambda s: int(s.split('_')[0]))
        # Read the fundus image with OpenCV and convert BGR -> RGB
        fundus_img = cv2.imread(fundus_img_path)[:, :, ::-1]
        # Read the first OCT slice as grayscale (cv2.IMREAD_GRAYSCALE) to get the slice size
        oct_series_0 = cv2.imread(os.path.join(self.dataset_root, real_index, real_index, oct_series_list[0]), cv2.IMREAD_GRAYSCALE)
        oct_img = np.zeros((oct_series_0.shape[0], oct_series_0.shape[1], len(oct_series_list)), dtype="uint8")
        # Read every slice into its own channel
        for k, p in enumerate(oct_series_list):
            oct_img[:, :, k] = cv2.imread(os.path.join(self.dataset_root, real_index, real_index, p), cv2.IMREAD_GRAYSCALE)
        # Augment the color fundus image
        if self.img_transforms is not None:
            fundus_img = self.img_transforms(fundus_img)
        # Augment the 3D OCT volume
        if self.oct_transforms is not None:
            oct_img = self.oct_transforms(oct_img)
        # Reorder axes to [channels, height, width]: H, W, C -> C, H, W
        fundus_img = fundus_img.transpose(2, 0, 1)
        oct_img = oct_img.transpose(2, 0, 1)
        if self.mode == 'test':
            return fundus_img, oct_img, real_index
        if self.mode == "train":
            # Convert the one-hot label to a class index
            label = label.argmax()
            return fundus_img, oct_img, label

    # Return the total number of samples
    def __len__(self):
        return len(self.file_list)
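Two details of __getitem__ are easy to miss. First, the OCT slice files are sorted by the integer before the first underscore rather than as strings. Second, the slices are stacked as the channels of one H x W x C array, which is later transposed to C x H x W. The NumPy sketch below reproduces both steps on synthetic data. The file names are hypothetical and only follow the "<number>_..." pattern that the sort key expects.

```python
import numpy as np

# Hypothetical slice file names; only the numeric prefix before "_" matters
names = ["10_image.jpg", "2_image.jpg", "1_image.jpg", "0_image.jpg"]
print(sorted(names))  # a plain string sort puts "10_..." before "1_..." and "2_..."
ordered = sorted(names, key=lambda s: int(s.split('_')[0]))
print(ordered)

# Stack synthetic grayscale slices (slice k filled with value k) into one H x W x C array
h, w = 3, 5
slices = [np.full((h, w), k, dtype="uint8") for k in range(len(ordered))]
demo_volume = np.zeros((h, w, len(slices)), dtype="uint8")
for k, s in enumerate(slices):
    demo_volume[:, :, k] = s

# H, W, C -> C, H, W, as at the end of __getitem__
chw = demo_volume.transpose(2, 0, 1)
print(demo_volume.shape, chw.shape)
```

After the transpose, chw[k] is the k-th slice in numeric order. This is why the OCT branch can treat the 256 slices as 256 input channels.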
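3.5 Data Augmentation
Because the training set is small, we apply moderate data augmentation with paddle.vision.transforms to reduce overfitting: random horizontal flips, random vertical flips, random resized crops, and so on.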
# Size of the color fundus image
image_size = [256, 256]
# Size of each slice of the 3D OCT volume
oct_img_size = [512, 512]
# Training-time augmentation
img_train_transforms = trans.Compose([
    # Randomly crop a region of the original image, then resize it to image_size
    trans.RandomResizedCrop(image_size, scale=(0.90, 1.1), ratio=(0.90, 1.1)),
    # Random horizontal flip
    trans.RandomHorizontalFlip(),
    # Random vertical flip
    trans.RandomVerticalFlip(),
    # Random rotation within ±30 degrees
    trans.RandomRotation(30)
])
oct_train_transforms = trans.Compose([
    # Center-crop to oct_img_size
    trans.CenterCrop(oct_img_size),
    # Random horizontal flip
    trans.RandomHorizontalFlip(),
    # Random vertical flip
    trans.RandomVerticalFlip()
])
img_val_transforms = trans.Compose([
    # Resize the image to a fixed size
    trans.Resize(image_size)
])
oct_val_transforms = trans.Compose([
    # Center-crop to a fixed size
    trans.CenterCrop(oct_img_size)
])
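Each OCT volume is stored as a single H x W x 256 array. The spatial transforms in oct_train_transforms therefore act only on the height and width axes, so every slice in a volume receives the same crop and flip and the slices stay aligned. The sketch below illustrates this with two hypothetical NumPy helpers, center_crop and hflip, which mimic the idea of CenterCrop and a horizontal flip. They are not PaddlePaddle's implementation, and the exact crop-offset convention there may differ by a pixel.

```python
import numpy as np

def center_crop(arr, size):
    """Hypothetical helper: keep the central size[0] x size[1] region of an H x W x C array."""
    th, tw = size
    h, w = arr.shape[:2]
    top, left = (h - th) // 2, (w - tw) // 2
    return arr[top:top + th, left:left + tw]

def hflip(arr):
    """Hypothetical helper: mirror an H x W x C array left-right (width axis)."""
    return arr[:, ::-1]

rng = np.random.default_rng(0)
# Synthetic "OCT volume": 8 x 10 pixels with 3 slices stored as channels
volume = rng.integers(0, 256, size=(8, 10, 3)).astype(np.uint8)

out = hflip(center_crop(volume, (4, 4)))
print(out.shape)  # (4, 4, 3)

# Every slice received exactly the same spatial operation, so the slices stay aligned
for c in range(volume.shape[2]):
    expected = volume[2:6, 3:7, c][:, ::-1]
    assert np.array_equal(out[:, :, c], expected)
```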
3.6 Displaying Sample Data
# Visualize a few training samples
_train = GAMMA_sub1_dataset(dataset_root=trainset_root,
                            img_transforms=img_train_transforms,
                            oct_transforms=oct_train_transforms,
                            filelists=train_filelists,
                            label_file=gt_file)
for i in range(5):
    fundus_img, oct_img, lab = _train[i]
    # Top row: fundus images (C, H, W -> H, W, C for display)
    plt.subplot(2, 5, i+1)
    plt.imshow(fundus_img.transpose(1, 2, 0))
    plt.axis("off")
    # Bottom row: slice 100 of each OCT volume, shown as a grayscale image
    plt.subplot(2, 5, i+6)
    plt.imshow(oct_img[100], cmap='gray')
    if i == 0:
        print('fundus_size', fundus_img.shape)
        print('oct_size', oct_img[100].shape)
    plt.axis("off")
plt.show()
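4. Model Construction
Because the data are multi-modal, we use two branches that process each modality separately: EfficientNetB3 for the color fundus images and ResNet for the 3D OCT volumes. The output features of the two networks are then fed together into a classification head for prediction.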
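This section does not show the tutorial's actual fusion network. The NumPy sketch below illustrates the general idea of late fusion: it concatenates one feature vector per modality and applies a single linear layer followed by a softmax over the three grades. The feature sizes are assumptions. The value 1280 matches the width of the 1x1 convolution in Classifier_Head below, and 2048 is the output width of the last bottleneck ResNet stage (512 x 4). The weights are random and untrained.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 2
# Assumed per-modality feature vectors (illustrative sizes, random values)
fundus_feat = rng.standard_normal((batch, 1280))
oct_feat = rng.standard_normal((batch, 2048))

# Late fusion: concatenate the two modality features along the feature axis
fused = np.concatenate([fundus_feat, oct_feat], axis=1)

# A single linear layer as a stand-in classification head (random, untrained weights)
num_classes = 3
W = rng.standard_normal((fused.shape[1], num_classes)) * 0.01
b = np.zeros(num_classes)
logits = fused @ W + b

# Numerically stable softmax over the three grades
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
pred = probs.argmax(axis=1)
print(fused.shape, probs.shape, pred)
```

Concatenation lets the head weigh both modalities jointly. Because the weights here are random, pred is meaningless; only the shapes matter.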
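4.1 Building the EfficientNet Branch
EfficientNet scales the network's resolution (feature-map size), depth (number of layers), and width (number of channels per layer) together, which gives it high accuracy with fast inference. The overall architecture of the network is shown in the figure below.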
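The EfficientNet class below applies this scaling by truncation. Each stage's repeat count becomes int(r * depth_coefficient) and its channel count becomes int(o * width_coefficient). Resolution is set separately through the input size, which is 256 x 256 here. The short sketch below reproduces that arithmetic for the B3 coefficients (width 1.2, depth 1.4), using the (r, o) pairs from block_setting. It is handy for predicting the layer shapes in the summary output.

```python
# (repeats r, output channels o) per stage, copied from block_setting in the code below
base_stages = [(1, 16), (2, 24), (2, 40), (3, 80), (3, 112), (4, 192), (1, 320)]

def scale_stages(stages, width_coefficient, depth_coefficient):
    """Same truncation rule as the tutorial's EfficientNet class:
    repeats -> int(r * depth_coefficient), channels -> int(o * width_coefficient)."""
    return [(int(r * depth_coefficient), int(o * width_coefficient)) for r, o in stages]

# EfficientNetB3 coefficients used in this tutorial
b3 = scale_stages(base_stages, width_coefficient=1.2, depth_coefficient=1.4)
stem_channels = int(32 * 1.2)
total_blocks = sum(r for r, _ in b3)
print(stem_channels, b3, total_blocks)
```

This yields a 38-channel stem and 19 MBConv blocks with widths 19, 28, 48, 96, 134, 230, and 384, which matches the paddle.summary output below. Plain truncation is a simplification. The reference EfficientNet implementation rounds channel counts to a multiple of 8 and rounds repeat counts up, so its B3 layer widths differ slightly.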
class ConvBNLayer(nn.Layer):
    """Conv2D -> BatchNorm2D -> Swish."""
    def __init__(self, inchannels, outchannels, stride, kernelsize = 3, groups = 1, padding="SAME"):
        super(ConvBNLayer, self).__init__()
        # The padding argument is overridden: (k - 1) // 2 gives "same"-style padding for odd kernel sizes
        padding = (kernelsize - 1) // 2
        self.conv = nn.Conv2D(
            in_channels=inchannels,
            out_channels=outchannels,
            kernel_size=kernelsize,
            stride=stride,
            padding=padding,
            groups=groups
        )
        self.bn = nn.BatchNorm2D(outchannels)
        self.Swish = nn.Swish()

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.bn(x)
        x = self.Swish(x)
        return x
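############ SE-Net ####################
class SE(nn.Layer):
    """Squeeze-and-excitation: reweight channels using globally pooled statistics."""
    def __init__(self, inchannels):
        super(SE, self).__init__()
        # Squeeze: global average pooling to one value per channel
        self.pooling = nn.AdaptiveAvgPool2D(output_size=(1, 1))
        # Excitation: 1x1 convs reduce the channels to 1/4, then restore them
        self.linear0 = nn.Conv2D(in_channels=inchannels,
                                 out_channels=int(inchannels*0.25),
                                 kernel_size=1)
        self.linear1 = nn.Conv2D(in_channels=int(inchannels*0.25),
                                 out_channels=inchannels,
                                 kernel_size=1)
        self.Swish = nn.Swish()
        self.Sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        x = self.pooling(inputs)
        x = self.linear0(x)
        x = self.Swish(x)
        x = self.linear1(x)
        # Per-channel weights in (0, 1)
        x = self.Sigmoid(x)
        # Scale the input feature map channel-wise (broadcast over H and W)
        out = paddle.multiply(x, inputs)
        return out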
############# MBConv ########################
class MBConv(nn.Layer):
    def __init__(self, inchannels, outchannels, channels_time, kernel_size, stride, connected_dropout):
        super(MBConv, self).__init__()
        self.stride = stride
        # 1x1 expansion: inchannels -> inchannels * channels_time
        self.layer0 = ConvBNLayer(
            inchannels=inchannels,
            outchannels=inchannels*channels_time,
            kernelsize=1,
            stride=1,
        )
        # k x k depthwise convolution (groups == number of channels)
        self.layer1 = ConvBNLayer(
            inchannels=inchannels*channels_time,
            outchannels=inchannels*channels_time,
            kernelsize=kernel_size,
            stride=stride,
            groups=inchannels*channels_time
        )
        # Squeeze-and-excitation on the expanded features
        self.SE = SE(inchannels = inchannels*channels_time)
        # 1x1 projection to outchannels (BN, no activation)
        self.conv0 = nn.Conv2D(
            in_channels=inchannels*channels_time,
            out_channels=outchannels,
            kernel_size=1
        )
        self.bn0 = nn.BatchNorm2D(outchannels)
        # 1x1 convolution on the shortcut so its channel count matches outchannels
        self.conv1 = nn.Conv2D(
            in_channels=inchannels,
            out_channels=outchannels,
            kernel_size=1
        )
        self.bn1 = nn.BatchNorm2D(outchannels)
        # Element-wise dropout on the main-path output, controlled by connected_dropout
        self.dropout = nn.Dropout(p=connected_dropout)

    def forward(self, inputs):
        x = self.layer0(inputs)
        x = self.layer1(x)
        x = self.SE(x)
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.dropout(x)
        # Downsampling blocks (stride 2) have no shortcut
        if self.stride == 2:
            return x
        # Stride-1 blocks add a projected shortcut
        y = self.conv1(inputs)
        y = self.bn1(y)
        return paddle.add(x, y)
############# Classifier_Head ######################
class Classifier_Head(paddle.nn.Layer):
    def __init__(self, in_channels, num_channel, dropout_rate):
        super(Classifier_Head, self).__init__()
        self.pooling = nn.AdaptiveAvgPool2D(output_size=(1, 1))
        # 1x1 ConvBNLayer expanding the features to 1280 channels
        self.conv = ConvBNLayer(inchannels=in_channels,
                                outchannels=1280,
                                kernelsize=1,
                                stride=1
                                )
        self.dropout = nn.Dropout(p=dropout_rate)
        # 1x1 conv acting as the fully connected classification layer
        self.conv1 = nn.Conv2D(
            in_channels=1280,
            out_channels=num_channel,
            kernel_size=1,
            padding="SAME"
        )

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.pooling(x)
        x = self.dropout(x)
        x = self.conv1(x)
        # [N, num_channel, 1, 1] -> [N, num_channel]
        x = paddle.squeeze(x, axis=[2, 3])
        # Return class probabilities
        x = F.softmax(x)
        return x
############### EfficientNet ##################
class EfficientNet(nn.Layer):
    def __init__(self, in_channels, num_class, width_coefficient, depth_coefficient, connected_dropout, dropout_rate):
        super(EfficientNet, self).__init__()
        # Each row: [stage n, repeats r, kernel size k, stride s, expansion ratio e, output channels o]
        block_setting=[[0, 1, 3, 1, 1, 16],
                       [1, 2, 3, 2, 6, 24],
                       [2, 2, 5, 2, 6, 40],
                       [3, 3, 3, 2, 6, 80],
                       [4, 3, 5, 1, 6, 112],
                       [5, 4, 5, 2, 6, 192],
                       [6, 1, 3, 1, 6, 320]]
        self.block = []
        # Stem: 3x3 stride-2 convolution followed by BN
        self.block.append(self.add_sublayer('c'+str(-1), nn.Conv2D(in_channels=in_channels,
                                                                    out_channels=int(32*width_coefficient),
                                                                    kernel_size=3,
                                                                    padding="SAME",
                                                                    stride=2)))
        self.block.append(self.add_sublayer('bn'+str(-1),nn.BatchNorm2D(int(32*width_coefficient))))
        i = int(32*width_coefficient)
        # Extra stem convolutions: int(depth_coefficient) - 1 of them (none when depth_coefficient < 2)
        for j in range(int(depth_coefficient)-1):
            if j==int(depth_coefficient)-2:
                self.block.append(self.add_sublayer('c'+str(j), nn.Conv2D(in_channels=i,
                                                                           out_channels=int(32*width_coefficient),
                                                                           kernel_size=3,
                                                                           padding="SAME",
                                                                           stride=2)))
                self.block.append(self.add_sublayer('bn'+str(j),nn.BatchNorm2D(int(32*width_coefficient))))
            else:
                self.block.append(self.add_sublayer('c'+str(j), nn.Conv2D(in_channels=i,
                                                                           out_channels=int(32*width_coefficient),
                                                                           kernel_size=3,
                                                                           padding="SAME")))
                self.block.append(self.add_sublayer('bn'+str(j),nn.BatchNorm2D(int(32*width_coefficient))))
        # MBConv stages: repeats scaled by depth_coefficient, channels by width_coefficient;
        # the stage stride s is applied to the last block of each stage
        for n, r, k, s, e, o in block_setting:
            for j in range(int(r*depth_coefficient)):
                if j==int(r*depth_coefficient)-1:
                    self.block.append(self.add_sublayer('b'+str(n)+str(j),
                                                        MBConv(inchannels=i,
                                                               outchannels=int(o*width_coefficient),
                                                               channels_time=e,
                                                               kernel_size=k,
                                                               stride=s,
                                                               connected_dropout=connected_dropout
                                                               )))
                else:
                    self.block.append(self.add_sublayer('b'+str(n)+str(j),
                                                        MBConv(inchannels=i,
                                                               outchannels=int(o*width_coefficient),
                                                               channels_time=e,
                                                               kernel_size=k,
                                                               stride=1,
                                                               connected_dropout=connected_dropout
                                                               )))
                i = int(o*width_coefficient)
        self.head = Classifier_Head(in_channels=i,
                                    num_channel=num_class,
                                    dropout_rate=dropout_rate)

    def forward(self, x):
        # Stem and MBConv blocks in order, then the classification head
        for layer in self.block:
            x = layer(x)
        x = self.head(x)
        return x
def EfficientNetB3(in_channels, num_class):
    return EfficientNet(in_channels=in_channels,
                        num_class=num_class,
                        width_coefficient=1.2,
                        depth_coefficient=1.4,
                        connected_dropout=0.2,
                        dropout_rate=0.3)
# Show the structure of the EfficientNetB3 branch
efficientnet = EfficientNetB3(in_channels=3, num_class=1000)
paddle.summary(efficientnet, (1, 3, 256, 256))
--------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
================================================================================
Conv2D-1 [[1, 3, 256, 256]] [1, 38, 128, 128] 1,064
BatchNorm2D-1 [[1, 38, 128, 128]] [1, 38, 128, 128] 152
Conv2D-2 [[1, 38, 128, 128]] [1, 38, 128, 128] 1,482
BatchNorm2D-2 [[1, 38, 128, 128]] [1, 38, 128, 128] 152
Swish-1 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
ConvBNLayer-1 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
Conv2D-3 [[1, 38, 128, 128]] [1, 38, 128, 128] 380
BatchNorm2D-3 [[1, 38, 128, 128]] [1, 38, 128, 128] 152
Swish-2 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
ConvBNLayer-2 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
AdaptiveAvgPool2D-1 [[1, 38, 128, 128]] [1, 38, 1, 1] 0
Conv2D-4 [[1, 38, 1, 1]] [1, 9, 1, 1] 351
Swish-3 [[1, 9, 1, 1]] [1, 9, 1, 1] 0
Conv2D-5 [[1, 9, 1, 1]] [1, 38, 1, 1] 380
Sigmoid-1 [[1, 38, 1, 1]] [1, 38, 1, 1] 0
SE-1 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
Conv2D-6 [[1, 38, 128, 128]] [1, 19, 128, 128] 741
BatchNorm2D-4 [[1, 19, 128, 128]] [1, 19, 128, 128] 76
Dropout-1 [[1, 19, 128, 128]] [1, 19, 128, 128] 0
Conv2D-7 [[1, 38, 128, 128]] [1, 19, 128, 128] 741
BatchNorm2D-5 [[1, 19, 128, 128]] [1, 19, 128, 128] 76
MBConv-1 [[1, 38, 128, 128]] [1, 19, 128, 128] 0
Conv2D-8 [[1, 19, 128, 128]] [1, 114, 128, 128] 2,280
BatchNorm2D-6 [[1, 114, 128, 128]] [1, 114, 128, 128] 456
Swish-4 [[1, 114, 128, 128]] [1, 114, 128, 128] 0
ConvBNLayer-3 [[1, 19, 128, 128]] [1, 114, 128, 128] 0
Conv2D-9 [[1, 114, 128, 128]] [1, 114, 128, 128] 1,140
BatchNorm2D-7 [[1, 114, 128, 128]] [1, 114, 128, 128] 456
Swish-5 [[1, 114, 128, 128]] [1, 114, 128, 128] 0
ConvBNLayer-4 [[1, 114, 128, 128]] [1, 114, 128, 128] 0
AdaptiveAvgPool2D-2 [[1, 114, 128, 128]] [1, 114, 1, 1] 0
Conv2D-10 [[1, 114, 1, 1]] [1, 28, 1, 1] 3,220
Swish-6 [[1, 28, 1, 1]] [1, 28, 1, 1] 0
Conv2D-11 [[1, 28, 1, 1]] [1, 114, 1, 1] 3,306
Sigmoid-2 [[1, 114, 1, 1]] [1, 114, 1, 1] 0
SE-2 [[1, 114, 128, 128]] [1, 114, 128, 128] 0
Conv2D-12 [[1, 114, 128, 128]] [1, 28, 128, 128] 3,220
BatchNorm2D-8 [[1, 28, 128, 128]] [1, 28, 128, 128] 112
Dropout-2 [[1, 28, 128, 128]] [1, 28, 128, 128] 0
Conv2D-13 [[1, 19, 128, 128]] [1, 28, 128, 128] 560
BatchNorm2D-9 [[1, 28, 128, 128]] [1, 28, 128, 128] 112
MBConv-2 [[1, 19, 128, 128]] [1, 28, 128, 128] 0
Conv2D-14 [[1, 28, 128, 128]] [1, 168, 128, 128] 4,872
BatchNorm2D-10 [[1, 168, 128, 128]] [1, 168, 128, 128] 672
Swish-7 [[1, 168, 128, 128]] [1, 168, 128, 128] 0
ConvBNLayer-5 [[1, 28, 128, 128]] [1, 168, 128, 128] 0
Conv2D-15 [[1, 168, 128, 128]] [1, 168, 64, 64] 1,680
BatchNorm2D-11 [[1, 168, 64, 64]] [1, 168, 64, 64] 672
Swish-8 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
ConvBNLayer-6 [[1, 168, 128, 128]] [1, 168, 64, 64] 0
AdaptiveAvgPool2D-3 [[1, 168, 64, 64]] [1, 168, 1, 1] 0
Conv2D-16 [[1, 168, 1, 1]] [1, 42, 1, 1] 7,098
Swish-9 [[1, 42, 1, 1]] [1, 42, 1, 1] 0
Conv2D-17 [[1, 42, 1, 1]] [1, 168, 1, 1] 7,224
Sigmoid-3 [[1, 168, 1, 1]] [1, 168, 1, 1] 0
SE-3 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
Conv2D-18 [[1, 168, 64, 64]] [1, 28, 64, 64] 4,732
BatchNorm2D-12 [[1, 28, 64, 64]] [1, 28, 64, 64] 112
Dropout-3 [[1, 28, 64, 64]] [1, 28, 64, 64] 0
MBConv-3 [[1, 28, 128, 128]] [1, 28, 64, 64] 0
Conv2D-20 [[1, 28, 64, 64]] [1, 168, 64, 64] 4,872
BatchNorm2D-14 [[1, 168, 64, 64]] [1, 168, 64, 64] 672
Swish-10 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
ConvBNLayer-7 [[1, 28, 64, 64]] [1, 168, 64, 64] 0
Conv2D-21 [[1, 168, 64, 64]] [1, 168, 64, 64] 4,368
BatchNorm2D-15 [[1, 168, 64, 64]] [1, 168, 64, 64] 672
Swish-11 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
ConvBNLayer-8 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
AdaptiveAvgPool2D-4 [[1, 168, 64, 64]] [1, 168, 1, 1] 0
Conv2D-22 [[1, 168, 1, 1]] [1, 42, 1, 1] 7,098
Swish-12 [[1, 42, 1, 1]] [1, 42, 1, 1] 0
Conv2D-23 [[1, 42, 1, 1]] [1, 168, 1, 1] 7,224
Sigmoid-4 [[1, 168, 1, 1]] [1, 168, 1, 1] 0
SE-4 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
Conv2D-24 [[1, 168, 64, 64]] [1, 48, 64, 64] 8,112
BatchNorm2D-16 [[1, 48, 64, 64]] [1, 48, 64, 64] 192
Dropout-4 [[1, 48, 64, 64]] [1, 48, 64, 64] 0
Conv2D-25 [[1, 28, 64, 64]] [1, 48, 64, 64] 1,392
BatchNorm2D-17 [[1, 48, 64, 64]] [1, 48, 64, 64] 192
MBConv-4 [[1, 28, 64, 64]] [1, 48, 64, 64] 0
Conv2D-26 [[1, 48, 64, 64]] [1, 288, 64, 64] 14,112
BatchNorm2D-18 [[1, 288, 64, 64]] [1, 288, 64, 64] 1,152
Swish-13 [[1, 288, 64, 64]] [1, 288, 64, 64] 0
ConvBNLayer-9 [[1, 48, 64, 64]] [1, 288, 64, 64] 0
Conv2D-27 [[1, 288, 64, 64]] [1, 288, 32, 32] 7,488
BatchNorm2D-19 [[1, 288, 32, 32]] [1, 288, 32, 32] 1,152
Swish-14 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
ConvBNLayer-10 [[1, 288, 64, 64]] [1, 288, 32, 32] 0
AdaptiveAvgPool2D-5 [[1, 288, 32, 32]] [1, 288, 1, 1] 0
Conv2D-28 [[1, 288, 1, 1]] [1, 72, 1, 1] 20,808
Swish-15 [[1, 72, 1, 1]] [1, 72, 1, 1] 0
Conv2D-29 [[1, 72, 1, 1]] [1, 288, 1, 1] 21,024
Sigmoid-5 [[1, 288, 1, 1]] [1, 288, 1, 1] 0
SE-5 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
Conv2D-30 [[1, 288, 32, 32]] [1, 48, 32, 32] 13,872
BatchNorm2D-20 [[1, 48, 32, 32]] [1, 48, 32, 32] 192
Dropout-5 [[1, 48, 32, 32]] [1, 48, 32, 32] 0
MBConv-5 [[1, 48, 64, 64]] [1, 48, 32, 32] 0
Conv2D-32 [[1, 48, 32, 32]] [1, 288, 32, 32] 14,112
BatchNorm2D-22 [[1, 288, 32, 32]] [1, 288, 32, 32] 1,152
Swish-16 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
ConvBNLayer-11 [[1, 48, 32, 32]] [1, 288, 32, 32] 0
Conv2D-33 [[1, 288, 32, 32]] [1, 288, 32, 32] 2,880
BatchNorm2D-23 [[1, 288, 32, 32]] [1, 288, 32, 32] 1,152
Swish-17 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
ConvBNLayer-12 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
AdaptiveAvgPool2D-6 [[1, 288, 32, 32]] [1, 288, 1, 1] 0
Conv2D-34 [[1, 288, 1, 1]] [1, 72, 1, 1] 20,808
Swish-18 [[1, 72, 1, 1]] [1, 72, 1, 1] 0
Conv2D-35 [[1, 72, 1, 1]] [1, 288, 1, 1] 21,024
Sigmoid-6 [[1, 288, 1, 1]] [1, 288, 1, 1] 0
SE-6 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
Conv2D-36 [[1, 288, 32, 32]] [1, 96, 32, 32] 27,744
BatchNorm2D-24 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
Dropout-6 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-37 [[1, 48, 32, 32]] [1, 96, 32, 32] 4,704
BatchNorm2D-25 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
MBConv-6 [[1, 48, 32, 32]] [1, 96, 32, 32] 0
Conv2D-38 [[1, 96, 32, 32]] [1, 576, 32, 32] 55,872
BatchNorm2D-26 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-19 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-13 [[1, 96, 32, 32]] [1, 576, 32, 32] 0
Conv2D-39 [[1, 576, 32, 32]] [1, 576, 32, 32] 5,760
BatchNorm2D-27 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-20 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-14 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
AdaptiveAvgPool2D-7 [[1, 576, 32, 32]] [1, 576, 1, 1] 0
Conv2D-40 [[1, 576, 1, 1]] [1, 144, 1, 1] 83,088
Swish-21 [[1, 144, 1, 1]] [1, 144, 1, 1] 0
Conv2D-41 [[1, 144, 1, 1]] [1, 576, 1, 1] 83,520
Sigmoid-7 [[1, 576, 1, 1]] [1, 576, 1, 1] 0
SE-7 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
Conv2D-42 [[1, 576, 32, 32]] [1, 96, 32, 32] 55,392
BatchNorm2D-28 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
Dropout-7 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-43 [[1, 96, 32, 32]] [1, 96, 32, 32] 9,312
BatchNorm2D-29 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
MBConv-7 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-44 [[1, 96, 32, 32]] [1, 576, 32, 32] 55,872
BatchNorm2D-30 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-22 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-15 [[1, 96, 32, 32]] [1, 576, 32, 32] 0
Conv2D-45 [[1, 576, 32, 32]] [1, 576, 32, 32] 5,760
BatchNorm2D-31 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-23 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-16 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
AdaptiveAvgPool2D-8 [[1, 576, 32, 32]] [1, 576, 1, 1] 0
Conv2D-46 [[1, 576, 1, 1]] [1, 144, 1, 1] 83,088
Swish-24 [[1, 144, 1, 1]] [1, 144, 1, 1] 0
Conv2D-47 [[1, 144, 1, 1]] [1, 576, 1, 1] 83,520
Sigmoid-8 [[1, 576, 1, 1]] [1, 576, 1, 1] 0
SE-8 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
Conv2D-48 [[1, 576, 32, 32]] [1, 96, 32, 32] 55,392
BatchNorm2D-32 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
Dropout-8 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-49 [[1, 96, 32, 32]] [1, 96, 32, 32] 9,312
BatchNorm2D-33 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
MBConv-8 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-50 [[1, 96, 32, 32]] [1, 576, 32, 32] 55,872
BatchNorm2D-34 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-25 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-17 [[1, 96, 32, 32]] [1, 576, 32, 32] 0
Conv2D-51 [[1, 576, 32, 32]] [1, 576, 16, 16] 5,760
BatchNorm2D-35 [[1, 576, 16, 16]] [1, 576, 16, 16] 2,304
Swish-26 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
ConvBNLayer-18 [[1, 576, 32, 32]] [1, 576, 16, 16] 0
AdaptiveAvgPool2D-9 [[1, 576, 16, 16]] [1, 576, 1, 1] 0
Conv2D-52 [[1, 576, 1, 1]] [1, 144, 1, 1] 83,088
Swish-27 [[1, 144, 1, 1]] [1, 144, 1, 1] 0
Conv2D-53 [[1, 144, 1, 1]] [1, 576, 1, 1] 83,520
Sigmoid-9 [[1, 576, 1, 1]] [1, 576, 1, 1] 0
SE-9 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
Conv2D-54 [[1, 576, 16, 16]] [1, 96, 16, 16] 55,392
BatchNorm2D-36 [[1, 96, 16, 16]] [1, 96, 16, 16] 384
Dropout-9 [[1, 96, 16, 16]] [1, 96, 16, 16] 0
MBConv-9 [[1, 96, 32, 32]] [1, 96, 16, 16] 0
Conv2D-56 [[1, 96, 16, 16]] [1, 576, 16, 16] 55,872
BatchNorm2D-38 [[1, 576, 16, 16]] [1, 576, 16, 16] 2,304
Swish-28 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
ConvBNLayer-19 [[1, 96, 16, 16]] [1, 576, 16, 16] 0
Conv2D-57 [[1, 576, 16, 16]] [1, 576, 16, 16] 14,976
BatchNorm2D-39 [[1, 576, 16, 16]] [1, 576, 16, 16] 2,304
Swish-29 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
ConvBNLayer-20 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
AdaptiveAvgPool2D-10 [[1, 576, 16, 16]] [1, 576, 1, 1] 0
Conv2D-58 [[1, 576, 1, 1]] [1, 144, 1, 1] 83,088
Swish-30 [[1, 144, 1, 1]] [1, 144, 1, 1] 0
Conv2D-59 [[1, 144, 1, 1]] [1, 576, 1, 1] 83,520
Sigmoid-10 [[1, 576, 1, 1]] [1, 576, 1, 1] 0
SE-10 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
Conv2D-60 [[1, 576, 16, 16]] [1, 134, 16, 16] 77,318
BatchNorm2D-40 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
Dropout-10 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-61 [[1, 96, 16, 16]] [1, 134, 16, 16] 12,998
BatchNorm2D-41 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
MBConv-10 [[1, 96, 16, 16]] [1, 134, 16, 16] 0
Conv2D-62 [[1, 134, 16, 16]] [1, 804, 16, 16] 108,540
BatchNorm2D-42 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-31 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-21 [[1, 134, 16, 16]] [1, 804, 16, 16] 0
Conv2D-63 [[1, 804, 16, 16]] [1, 804, 16, 16] 20,904
BatchNorm2D-43 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-32 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-22 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
AdaptiveAvgPool2D-11 [[1, 804, 16, 16]] [1, 804, 1, 1] 0
Conv2D-64 [[1, 804, 1, 1]] [1, 201, 1, 1] 161,805
Swish-33 [[1, 201, 1, 1]] [1, 201, 1, 1] 0
Conv2D-65 [[1, 201, 1, 1]] [1, 804, 1, 1] 162,408
Sigmoid-11 [[1, 804, 1, 1]] [1, 804, 1, 1] 0
SE-11 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
Conv2D-66 [[1, 804, 16, 16]] [1, 134, 16, 16] 107,870
BatchNorm2D-44 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
Dropout-11 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-67 [[1, 134, 16, 16]] [1, 134, 16, 16] 18,090
BatchNorm2D-45 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
MBConv-11 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-68 [[1, 134, 16, 16]] [1, 804, 16, 16] 108,540
BatchNorm2D-46 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-34 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-23 [[1, 134, 16, 16]] [1, 804, 16, 16] 0
Conv2D-69 [[1, 804, 16, 16]] [1, 804, 16, 16] 20,904
BatchNorm2D-47 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-35 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-24 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
AdaptiveAvgPool2D-12 [[1, 804, 16, 16]] [1, 804, 1, 1] 0
Conv2D-70 [[1, 804, 1, 1]] [1, 201, 1, 1] 161,805
Swish-36 [[1, 201, 1, 1]] [1, 201, 1, 1] 0
Conv2D-71 [[1, 201, 1, 1]] [1, 804, 1, 1] 162,408
Sigmoid-12 [[1, 804, 1, 1]] [1, 804, 1, 1] 0
SE-12 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
Conv2D-72 [[1, 804, 16, 16]] [1, 134, 16, 16] 107,870
BatchNorm2D-48 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
Dropout-12 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-73 [[1, 134, 16, 16]] [1, 134, 16, 16] 18,090
BatchNorm2D-49 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
MBConv-12 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-74 [[1, 134, 16, 16]] [1, 804, 16, 16] 108,540
BatchNorm2D-50 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-37 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-25 [[1, 134, 16, 16]] [1, 804, 16, 16] 0
Conv2D-75 [[1, 804, 16, 16]] [1, 804, 16, 16] 20,904
BatchNorm2D-51 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-38 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-26 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
AdaptiveAvgPool2D-13 [[1, 804, 16, 16]] [1, 804, 1, 1] 0
Conv2D-76 [[1, 804, 1, 1]] [1, 201, 1, 1] 161,805
Swish-39 [[1, 201, 1, 1]] [1, 201, 1, 1] 0
Conv2D-77 [[1, 201, 1, 1]] [1, 804, 1, 1] 162,408
Sigmoid-13 [[1, 804, 1, 1]] [1, 804, 1, 1] 0
SE-13 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
Conv2D-78 [[1, 804, 16, 16]] [1, 134, 16, 16] 107,870
BatchNorm2D-52 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
Dropout-13 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-79 [[1, 134, 16, 16]] [1, 134, 16, 16] 18,090
BatchNorm2D-53 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
MBConv-13 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-80 [[1, 134, 16, 16]] [1, 804, 16, 16] 108,540
BatchNorm2D-54 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-40 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-27 [[1, 134, 16, 16]] [1, 804, 16, 16] 0
Conv2D-81 [[1, 804, 16, 16]] [1, 804, 16, 16] 20,904
BatchNorm2D-55 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-41 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-28 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
AdaptiveAvgPool2D-14 [[1, 804, 16, 16]] [1, 804, 1, 1] 0
Conv2D-82 [[1, 804, 1, 1]] [1, 201, 1, 1] 161,805
Swish-42 [[1, 201, 1, 1]] [1, 201, 1, 1] 0
Conv2D-83 [[1, 201, 1, 1]] [1, 804, 1, 1] 162,408
Sigmoid-14 [[1, 804, 1, 1]] [1, 804, 1, 1] 0
SE-14 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
Conv2D-84 [[1, 804, 16, 16]] [1, 230, 16, 16] 185,150
BatchNorm2D-56 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
Dropout-14 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-85 [[1, 134, 16, 16]] [1, 230, 16, 16] 31,050
BatchNorm2D-57 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
MBConv-14 [[1, 134, 16, 16]] [1, 230, 16, 16] 0
Conv2D-86 [[1, 230, 16, 16]] [1, 1380, 16, 16] 318,780
BatchNorm2D-58 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-43 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-29 [[1, 230, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-87 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 35,880
BatchNorm2D-59 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-44 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-30 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
AdaptiveAvgPool2D-15 [[1, 1380, 16, 16]] [1, 1380, 1, 1] 0
Conv2D-88 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-45 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-89 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-15 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-15 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-90 [[1, 1380, 16, 16]] [1, 230, 16, 16] 317,630
BatchNorm2D-60 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
Dropout-15 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-91 [[1, 230, 16, 16]] [1, 230, 16, 16] 53,130
BatchNorm2D-61 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
MBConv-15 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-92 [[1, 230, 16, 16]] [1, 1380, 16, 16] 318,780
BatchNorm2D-62 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-46 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-31 [[1, 230, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-93 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 35,880
BatchNorm2D-63 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-47 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-32 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
AdaptiveAvgPool2D-16 [[1, 1380, 16, 16]] [1, 1380, 1, 1] 0
Conv2D-94 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-48 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-95 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-16 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-16 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-96 [[1, 1380, 16, 16]] [1, 230, 16, 16] 317,630
BatchNorm2D-64 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
Dropout-16 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-97 [[1, 230, 16, 16]] [1, 230, 16, 16] 53,130
BatchNorm2D-65 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
MBConv-16 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-98 [[1, 230, 16, 16]] [1, 1380, 16, 16] 318,780
BatchNorm2D-66 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-49 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-33 [[1, 230, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-99 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 35,880
BatchNorm2D-67 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-50 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-34 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
AdaptiveAvgPool2D-17 [[1, 1380, 16, 16]] [1, 1380, 1, 1] 0
Conv2D-100 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-51 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-101 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-17 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-17 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-102 [[1, 1380, 16, 16]] [1, 230, 16, 16] 317,630
BatchNorm2D-68 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
Dropout-17 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-103 [[1, 230, 16, 16]] [1, 230, 16, 16] 53,130
BatchNorm2D-69 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
MBConv-17 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-104 [[1, 230, 16, 16]] [1, 1380, 16, 16] 318,780
BatchNorm2D-70 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-52 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-35 [[1, 230, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-105 [[1, 1380, 16, 16]] [1, 1380, 8, 8] 35,880
BatchNorm2D-71 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 5,520
Swish-53 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
ConvBNLayer-36 [[1, 1380, 16, 16]] [1, 1380, 8, 8] 0
AdaptiveAvgPool2D-18 [[1, 1380, 8, 8]] [1, 1380, 1, 1] 0
Conv2D-106 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-54 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-107 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-18 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-18 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
Conv2D-108 [[1, 1380, 8, 8]] [1, 230, 8, 8] 317,630
BatchNorm2D-72 [[1, 230, 8, 8]] [1, 230, 8, 8] 920
Dropout-18 [[1, 230, 8, 8]] [1, 230, 8, 8] 0
MBConv-18 [[1, 230, 16, 16]] [1, 230, 8, 8] 0
Conv2D-110 [[1, 230, 8, 8]] [1, 1380, 8, 8] 318,780
BatchNorm2D-74 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 5,520
Swish-55 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
ConvBNLayer-37 [[1, 230, 8, 8]] [1, 1380, 8, 8] 0
Conv2D-111 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 13,800
BatchNorm2D-75 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 5,520
Swish-56 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
ConvBNLayer-38 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
AdaptiveAvgPool2D-19 [[1, 1380, 8, 8]] [1, 1380, 1, 1] 0
Conv2D-112 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-57 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-113 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-19 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-19 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
Conv2D-114 [[1, 1380, 8, 8]] [1, 384, 8, 8] 530,304
BatchNorm2D-76 [[1, 384, 8, 8]] [1, 384, 8, 8] 1,536
Dropout-19 [[1, 384, 8, 8]] [1, 384, 8, 8] 0
Conv2D-115 [[1, 230, 8, 8]] [1, 384, 8, 8] 88,704
BatchNorm2D-77 [[1, 384, 8, 8]] [1, 384, 8, 8] 1,536
MBConv-19 [[1, 230, 8, 8]] [1, 384, 8, 8] 0
Conv2D-116 [[1, 384, 8, 8]] [1, 1280, 8, 8] 492,800
BatchNorm2D-78 [[1, 1280, 8, 8]] [1, 1280, 8, 8] 5,120
Swish-58 [[1, 1280, 8, 8]] [1, 1280, 8, 8] 0
ConvBNLayer-39 [[1, 384, 8, 8]] [1, 1280, 8, 8] 0
AdaptiveAvgPool2D-20 [[1, 1280, 8, 8]] [1, 1280, 1, 1] 0
Dropout-20 [[1, 1280, 1, 1]] [1, 1280, 1, 1] 0
Conv2D-117 [[1, 1280, 1, 1]] [1, 1000, 1, 1] 1,281,000
Classifier_Head-1 [[1, 384, 8, 8]] [1, 1000] 0
================================================================================
Total params: 14,328,212
Trainable params: 14,261,944
Non-trainable params: 66,268
--------------------------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 754.80
Params size (MB): 54.66
Estimated Total Size (MB): 810.21
--------------------------------------------------------------------------------
{'total_params': 14328212, 'trainable_params': 14261944}
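The Param # column in the summary above follows directly from the layer shapes. A Conv2D with bias has (in_channels / groups) x out_channels x k x k weights plus out_channels biases. Each BatchNorm2D has 4 x C values. The running mean and variance account for 2 x C of these and are stored as non-trainable parameters, which is why the table reports a separate non-trainable count. The small helpers below, with hypothetical names, reproduce a few rows of the table as a sanity check.

```python
def conv2d_params(in_ch, out_ch, kernel, groups=1, bias=True):
    """Weights (in_ch / groups) * out_ch * k * k, plus one bias per output channel."""
    weights = (in_ch // groups) * out_ch * kernel * kernel
    return weights + (out_ch if bias else 0)

def batchnorm2d_params(ch):
    """BatchNorm2D: (trainable scale + shift, non-trainable running mean + variance)."""
    return 2 * ch, 2 * ch

# A few rows of the summary above
print(conv2d_params(3, 38, 3))                  # stem conv, Conv2D-1
print(conv2d_params(38, 38, 3, groups=38))      # 3x3 depthwise conv, Conv2D-3
print(conv2d_params(288, 288, 5, groups=288))   # 5x5 depthwise conv, Conv2D-27
print(conv2d_params(1280, 1000, 1))             # final 1x1 classifier conv, Conv2D-117
print(sum(batchnorm2d_params(38)))              # BatchNorm2D-1
```

The printed values are 1064, 380, 7488, 1281000, and 152, which match the corresponding rows. Depthwise convolutions (groups equal to channels) are very cheap, so most of the parameters come from the 1x1 convolutions.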
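4.2 Building the ResNet Branch
ResNet (Residual Neural Network) was proposed by Kaiming He et al. Its core idea is to add residual connections inside convolutional blocks. These connections address the degradation problem caused by very deep networks, which lets us build deeper networks and increase model capacity.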
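The residual idea can be written as out = ReLU(F(x) + shortcut(x)). The shortcut is the identity when shapes match. When a block changes the spatial size or channel count, the shortcut is a 1x1 projection instead, as Basicblock does with conv2/bn2 when stride == 2. The NumPy sketch below illustrates both cases on a single C x H x W feature map. It omits batch normalization and replaces the residual branch F with placeholder functions, so it shows the data flow rather than a trained block.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

def projection_shortcut(x, weight, stride):
    """1x1 convolution without bias on a C x H x W array: take every `stride`-th
    pixel, then mix channels with a (C_out, C_in) weight matrix."""
    x = x[:, ::stride, ::stride]
    return np.einsum("oc,chw->ohw", weight, x)

def residual_block(x, residual_fn, shortcut_fn=None):
    """out = ReLU(F(x) + shortcut(x)); identity shortcut when shortcut_fn is None."""
    shortcut = x if shortcut_fn is None else shortcut_fn(x)
    return relu(residual_fn(x) + shortcut)

def zero_branch(t):
    # Placeholder residual branch that outputs zeros of the same shape
    return np.zeros_like(t)

def downsampling_branch(t):
    # Placeholder for a stride-2 conv-BN-ReLU-conv-BN path that doubles the channels
    return np.zeros((2 * t.shape[0], t.shape[1] // 2, t.shape[2] // 2))

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 8))

# Identity shortcut: with a zero residual branch the block reduces to ReLU(x)
out_identity = residual_block(x, zero_branch)

# Downsampling: shapes differ (8 x 4 x 4 vs 4 x 8 x 8), so the shortcut needs a
# stride-2 1x1 projection before the addition
w_proj = rng.standard_normal((8, 4))
out_down = residual_block(x, downsampling_branch,
                          lambda t: projection_shortcut(t, w_proj, stride=2))
print(out_identity.shape, out_down.shape)
```

The identity case shows why residual blocks are easy to optimize. If the branch learns nothing, the block still passes its input through, so stacking more blocks does not have to degrade the signal.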
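The residual connection in a convolutional block is shown in the figure below:
The overall network architecture is shown in the figure below: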
###### BasicBlock ###############
class Basicblock(paddle.nn.Layer):
    def __init__(self, in_channel, out_channel, stride = 1):
        super(Basicblock, self).__init__()
        self.stride = stride
        # Main path: two 3x3 convolutions
        self.conv0 = nn.Conv2D(in_channel, out_channel, 3, stride = stride, padding = 1)
        self.conv1 = nn.Conv2D(out_channel, out_channel, 3, stride=1, padding = 1)
        # 1x1 projection shortcut, used only when the block downsamples (stride 2)
        self.conv2 = nn.Conv2D(in_channel, out_channel, 1, stride = stride)
        self.bn0 = nn.BatchNorm2D(out_channel)
        self.bn1 = nn.BatchNorm2D(out_channel)
        self.bn2 = nn.BatchNorm2D(out_channel)

    def forward(self, inputs):
        y = inputs
        x = self.conv0(inputs)
        x = self.bn0(x)
        x = F.relu(x)
        x = self.conv1(x)
        x = self.bn1(x)
        if self.stride == 2:
            y = self.conv2(y)
            y = self.bn2(y)
        # Residual connection: add the shortcut to the main path, then apply ReLU
        z = F.relu(x+y)
        return z
############ BottleneckBlock ##############
class Bottleneckblock(paddle.nn.Layer):
    def __init__(self, inplane, in_channel, out_channel, stride = 1, start = False):
        super(Bottleneckblock, self).__init__()
        self.stride = stride
        self.start = start
        # 1x1 conv maps in_channel to inplane channels (and downsamples when stride == 2)
        self.conv0 = nn.Conv2D(in_channel, inplane, 1, stride = stride)
        # 3x3 conv at the reduced width
        self.conv1 = nn.Conv2D(inplane, inplane, 3, stride=1, padding=1)
        # 1x1 conv expands to out_channel
        self.conv2 = nn.Conv2D(inplane, out_channel, 1, stride=1)
        # 1x1 projection shortcut, used by the first block of each stage (start=True)
        self.conv3 = nn.Conv2D(in_channel, out_channel, 1, stride = stride)
        self.bn0 = nn.BatchNorm2D(inplane)
        self.bn1 = nn.BatchNorm2D(inplane)
        self.bn2 = nn.BatchNorm2D(out_channel)
        self.bn3 = nn.BatchNorm2D(out_channel)

    def forward(self, inputs):
        y = inputs
        x = self.conv0(inputs)
        x = self.bn0(x)
        x = F.relu(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.start:
            y = self.conv3(y)
            y = self.bn3(y)
        # Residual connection followed by ReLU
        z = F.relu(x+y)
        return z
################### ResNet #################
class ResNet(paddle.nn.Layer):
    def __init__(self, num, bottlenet, in_channels):
        super(ResNet, self).__init__()
        # Stem: 7x7 conv with stride 2 and no padding, so a 512x512 input becomes
        # 253x253 (the reference ResNet uses padding=3, which would give 256x256)
        self.conv0 = nn.Conv2D(in_channels, 64, 7, stride=2)
        self.bn = nn.BatchNorm2D(64)
        self.pool1 = nn.MaxPool2D(3, stride=2)
        # Four stages; num[i] is the number of blocks in stage i
        if bottlenet:
            self.layer0 = self.add_bottleneck_layer(num[0], 64, start = True)
            self.layer1 = self.add_bottleneck_layer(num[1], 128)
            self.layer2 = self.add_bottleneck_layer(num[2], 256)
            self.layer3 = self.add_bottleneck_layer(num[3], 512)
        else:
            self.layer0 = self.add_basic_layer(num[0], 64, start = True)
            self.layer1 = self.add_basic_layer(num[1], 128)
            self.layer2 = self.add_basic_layer(num[2], 256)
            self.layer3 = self.add_basic_layer(num[3], 512)
        # Global average pooling down to one value per channel
        self.pool2 = nn.AdaptiveAvgPool2D(output_size = (1, 1))

    def add_basic_layer(self, num, inplane, start = False):
        layer = []
        if start:
            # First stage: channels and resolution stay the same
            layer.append(Basicblock(inplane, inplane))
        else:
            # Later stages: halve the resolution and double the channels
            layer.append(Basicblock(inplane//2, inplane, stride = 2))
        for i in range(num-1):
            layer.append(Basicblock(inplane, inplane))
        return nn.Sequential(*layer)

    def add_bottleneck_layer(self, num, inplane, start = False):
        layer = []
        if start:
            layer.append(Bottleneckblock(inplane, inplane, inplane*4, start = True))
        else:
            layer.append(Bottleneckblock(inplane, inplane*2, inplane*4, stride = 2, start = True))
        for i in range(num-1):
            layer.append(Bottleneckblock(inplane, inplane*4, inplane*4))
        return nn.Sequential(*layer)

    def forward(self, inputs):
        x = self.conv0(inputs)
        x = self.bn(x)
        x = F.relu(x)  # ReLU after the stem BN, as in the reference ResNet
        x = self.pool1(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool2(x)
        # [N, C, 1, 1] -> [N, C]
        x = paddle.squeeze(x, axis=[2, 3])
        return x

def ResNet34(in_channels=3):
    # [3, 4, 6, 3] BasicBlocks per stage, as in ResNet-34
    return ResNet([3, 4, 6, 3], bottlenet = False, in_channels=in_channels)
# Display the ResNet branch; the 256 slices of an OCT volume are its input channels
resnet = ResNet34(in_channels=256)
paddle.summary(resnet, (1, 256, 512, 512))
--------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
================================================================================
Conv2D-118 [[1, 256, 512, 512]] [1, 64, 253, 253] 802,880
BatchNorm2D-79 [[1, 64, 253, 253]] [1, 64, 253, 253] 256
MaxPool2D-1 [[1, 64, 253, 253]] [1, 64, 126, 126] 0
Conv2D-119 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-80 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Conv2D-120 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-81 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Basicblock-1 [[1, 64, 126, 126]] [1, 64, 126, 126] 0
Conv2D-122 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-83 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Conv2D-123 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-84 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Basicblock-2 [[1, 64, 126, 126]] [1, 64, 126, 126] 0
Conv2D-125 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-86 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Conv2D-126 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-87 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Basicblock-3 [[1, 64, 126, 126]] [1, 64, 126, 126] 0
Conv2D-128 [[1, 64, 126, 126]] [1, 128, 63, 63] 73,856
BatchNorm2D-89 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-129 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-90 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-130 [[1, 64, 126, 126]] [1, 128, 63, 63] 8,320
BatchNorm2D-91 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Basicblock-4 [[1, 64, 126, 126]] [1, 128, 63, 63] 0
Conv2D-131 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-92 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-132 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-93 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Basicblock-5 [[1, 128, 63, 63]] [1, 128, 63, 63] 0
Conv2D-134 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-95 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-135 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-96 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Basicblock-6 [[1, 128, 63, 63]] [1, 128, 63, 63] 0
Conv2D-137 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-98 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-138 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-99 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Basicblock-7 [[1, 128, 63, 63]] [1, 128, 63, 63] 0
Conv2D-140 [[1, 128, 63, 63]] [1, 256, 32, 32] 295,168
BatchNorm2D-101 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-141 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-102 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-142 [[1, 128, 63, 63]] [1, 256, 32, 32] 33,024
BatchNorm2D-103 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-8 [[1, 128, 63, 63]] [1, 256, 32, 32] 0
Conv2D-143 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-104 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-144 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-105 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-9 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-146 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-107 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-147 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-108 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-10 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-149 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-110 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-150 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-111 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-11 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-152 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-113 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-153 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-114 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-12 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-155 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-116 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-156 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-117 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-13 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-158 [[1, 256, 32, 32]] [1, 512, 16, 16] 1,180,160
BatchNorm2D-119 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Conv2D-159 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-120 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Conv2D-160 [[1, 256, 32, 32]] [1, 512, 16, 16] 131,584
BatchNorm2D-121 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Basicblock-14 [[1, 256, 32, 32]] [1, 512, 16, 16] 0
Conv2D-161 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-122 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Conv2D-162 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-123 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Basicblock-15 [[1, 512, 16, 16]] [1, 512, 16, 16] 0
Conv2D-164 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-125 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Conv2D-165 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-126 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Basicblock-16 [[1, 512, 16, 16]] [1, 512, 16, 16] 0
AdaptiveAvgPool2D-21 [[1, 512, 16, 16]] [1, 512, 1, 1] 0
================================================================================
Total params: 22,103,616
Trainable params: 22,086,592
Non-trainable params: 17,024
--------------------------------------------------------------------------------
Input size (MB): 256.00
Forward/backward pass size (MB): 352.82
Params size (MB): 84.32
Estimated Total Size (MB): 693.13
--------------------------------------------------------------------------------
{'total_params': 22103616, 'trainable_params': 22086592}
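The spatial sizes in the summary above (253, 126, 63, 32, 16) follow from the usual output-size formula `(H + 2*P - K) // S + 1`. The sketch below is a plain-Python check and not part of the model. It reproduces those sizes and a few parameter counts from the table. It assumes the layer settings in the code above: no padding in the stem conv and max pool, and padding 1 in the stride-2 3x3 convs. It also shows how much the 256-channel OCT input inflates the stem convolution compared with a 3-channel image.

```python
# Plain-Python check of the ResNet34 branch shapes; not part of the model itself.

def out_size(size, kernel, stride, padding=0):
    """Output length along one spatial axis for a conv/pool layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(c_in, c_out, kernel, bias=True):
    """Number of weights (plus optional bias terms) in a square 2D convolution."""
    return c_in * c_out * kernel * kernel + (c_out if bias else 0)


def resnet34_branch_sizes(size=512):
    """Spatial size after the stem conv, the max pool and each downsampling stage."""
    sizes = []
    size = out_size(size, kernel=7, stride=2)      # conv0: 7x7, stride 2, no padding
    sizes.append(size)
    size = out_size(size, kernel=3, stride=2)      # pool1: 3x3 max pool, stride 2, no padding
    sizes.append(size)
    for _ in range(3):                             # layer1..layer3 open with a stride-2 3x3 conv, padding 1
        size = out_size(size, kernel=3, stride=2, padding=1)
        sizes.append(size)
    return sizes


blocks = [3, 4, 6, 3]  # BasicBlocks per stage in ResNet34

print(resnet34_branch_sizes(512))      # [253, 126, 63, 32, 16]
print(out_size(512, 7, 2, padding=3))  # 256 if the stem used padding=3
print(conv_params(256, 64, 7))         # 802880: stem conv over 256 OCT slices
print(conv_params(3, 64, 7))           # 9472: the same conv on a 3-channel image
print(1 + 2 * sum(blocks))             # 33 conv layers on the main path
```

The last line counts the stem plus two 3x3 convs per BasicBlock. In the original ResNet-34 the final fully connected classifier is the 34th layer. This branch drops that layer and leaves classification to the fusion head.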
4.3 Assembling the Model
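We build a two-branch network. EfficientNetB3 learns from the color fundus images, and its branch outputs a tensor of shape [batchSize, 1000]. ResNet34 learns from the 3D OCT volumes, with each volume's 256 slices stacked as the 256 input channels of a 2D network, and its branch outputs a tensor of shape [batchSize, 512]. The two branch outputs are concatenated along the channel (feature) dimension, giving 1000 + 512 = 1512 features. A single fully connected layer then maps these features to the final 3-class prediction.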
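A minimal NumPy sketch of the fusion step's shape arithmetic. It assumes a hypothetical batch of 4 and uses random stand-in features in place of the real branch outputs. Concatenating along axis 1, which is what `paddle.concat([b1, b2], 1)` does in the model code, yields 1512 features. A 1512-to-3 linear layer then holds 1512 x 3 weights plus 3 biases. Paddle stores `nn.Linear` weights as `[in_features, out_features]`, and the sketch uses the same layout.

```python
import numpy as np

batch_size = 4                        # hypothetical batch size
rng = np.random.default_rng(0)

# Random stand-ins for the two branch outputs (shapes taken from the text)
fundus_feat = rng.standard_normal((batch_size, 1000))   # EfficientNetB3 branch
oct_feat = rng.standard_normal((batch_size, 512))       # ResNet34 branch

# Concatenate along axis 1, the feature axis
fused = np.concatenate([fundus_feat, oct_feat], axis=1)

# Fully connected layer: weight [in_features, out_features] plus bias
W = rng.standard_normal((fused.shape[1], 3))
b = np.zeros(3)
logits = fused @ W + b

print(fused.shape)       # (4, 1512)
print(logits.shape)      # (4, 3)
print(W.size + b.size)   # 4539, the parameter count of Linear-1 in the summary
```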
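class Model(nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        # Fundus branch: EfficientNetB3 on the 3-channel color image -> [N, 1000]
        self.fundus_branch = EfficientNetB3(in_channels=3, num_class=1000)
        # OCT branch: ResNet34 with the 256 OCT slices as input channels -> [N, 512]
        self.oct_branch = ResNet34(in_channels=256)
        # Fusion head: 1000 + 512 = 1512 input features, 3 output classes
        self.decision_branch = nn.Linear(1512, 3)

    def forward(self, fundus_img, oct_img):
        b1 = self.fundus_branch(fundus_img)
        b2 = self.oct_branch(oct_img)
        # Concatenate the branch features along axis 1, then classify
        logit = self.decision_branch(paddle.concat([b1, b2], 1))
        return logit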
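`Model.forward` returns raw logits. The minimal NumPy sketch below turns them into grades. It assumes, hypothetically, that label indices 0, 1 and 2 correspond to `non`, `early` and `mid_advanced`, the order in which the dataset description lists them. The actual mapping is set by how `GAMMA_sub1_dataset` reads the label file. The logits in the example are made-up numbers.

```python
import numpy as np

# Hypothetical class order, assumed to follow the dataset description
CLASS_NAMES = ["non", "early", "mid_advanced"]

def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def decode(logits):
    """Return predicted class indices, their names and the class probabilities."""
    probs = softmax(logits)
    idx = probs.argmax(axis=1)
    return idx, [CLASS_NAMES[i] for i in idx], probs

# Made-up logits for two samples
example = np.array([[2.0, 0.5, -1.0],
                    [-0.3, 0.1, 1.7]])
idx, names, probs = decode(example)
print(idx)    # [0 2]
print(names)  # ['non', 'mid_advanced']
```

Softmax does not change which entry is largest, so taking the argmax of the logits directly gives the same prediction. The softmax is only needed when class probabilities are wanted.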
4.4 Model Visualization
model = Model()
paddle.summary(model, [(1, 3, 256, 256), (1, 256, 512, 512)])
--------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
================================================================================
Conv2D-167 [[1, 3, 256, 256]] [1, 38, 128, 128] 1,064
BatchNorm2D-128 [[1, 38, 128, 128]] [1, 38, 128, 128] 152
Conv2D-168 [[1, 38, 128, 128]] [1, 38, 128, 128] 1,482
BatchNorm2D-129 [[1, 38, 128, 128]] [1, 38, 128, 128] 152
Swish-59 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
ConvBNLayer-40 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
Conv2D-169 [[1, 38, 128, 128]] [1, 38, 128, 128] 380
BatchNorm2D-130 [[1, 38, 128, 128]] [1, 38, 128, 128] 152
Swish-60 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
ConvBNLayer-41 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
AdaptiveAvgPool2D-22 [[1, 38, 128, 128]] [1, 38, 1, 1] 0
Conv2D-170 [[1, 38, 1, 1]] [1, 9, 1, 1] 351
Swish-61 [[1, 9, 1, 1]] [1, 9, 1, 1] 0
Conv2D-171 [[1, 9, 1, 1]] [1, 38, 1, 1] 380
Sigmoid-20 [[1, 38, 1, 1]] [1, 38, 1, 1] 0
SE-20 [[1, 38, 128, 128]] [1, 38, 128, 128] 0
Conv2D-172 [[1, 38, 128, 128]] [1, 19, 128, 128] 741
BatchNorm2D-131 [[1, 19, 128, 128]] [1, 19, 128, 128] 76
Dropout-21 [[1, 19, 128, 128]] [1, 19, 128, 128] 0
Conv2D-173 [[1, 38, 128, 128]] [1, 19, 128, 128] 741
BatchNorm2D-132 [[1, 19, 128, 128]] [1, 19, 128, 128] 76
MBConv-20 [[1, 38, 128, 128]] [1, 19, 128, 128] 0
Conv2D-174 [[1, 19, 128, 128]] [1, 114, 128, 128] 2,280
BatchNorm2D-133 [[1, 114, 128, 128]] [1, 114, 128, 128] 456
Swish-62 [[1, 114, 128, 128]] [1, 114, 128, 128] 0
ConvBNLayer-42 [[1, 19, 128, 128]] [1, 114, 128, 128] 0
Conv2D-175 [[1, 114, 128, 128]] [1, 114, 128, 128] 1,140
BatchNorm2D-134 [[1, 114, 128, 128]] [1, 114, 128, 128] 456
Swish-63 [[1, 114, 128, 128]] [1, 114, 128, 128] 0
ConvBNLayer-43 [[1, 114, 128, 128]] [1, 114, 128, 128] 0
AdaptiveAvgPool2D-23 [[1, 114, 128, 128]] [1, 114, 1, 1] 0
Conv2D-176 [[1, 114, 1, 1]] [1, 28, 1, 1] 3,220
Swish-64 [[1, 28, 1, 1]] [1, 28, 1, 1] 0
Conv2D-177 [[1, 28, 1, 1]] [1, 114, 1, 1] 3,306
Sigmoid-21 [[1, 114, 1, 1]] [1, 114, 1, 1] 0
SE-21 [[1, 114, 128, 128]] [1, 114, 128, 128] 0
Conv2D-178 [[1, 114, 128, 128]] [1, 28, 128, 128] 3,220
BatchNorm2D-135 [[1, 28, 128, 128]] [1, 28, 128, 128] 112
Dropout-22 [[1, 28, 128, 128]] [1, 28, 128, 128] 0
Conv2D-179 [[1, 19, 128, 128]] [1, 28, 128, 128] 560
BatchNorm2D-136 [[1, 28, 128, 128]] [1, 28, 128, 128] 112
MBConv-21 [[1, 19, 128, 128]] [1, 28, 128, 128] 0
Conv2D-180 [[1, 28, 128, 128]] [1, 168, 128, 128] 4,872
BatchNorm2D-137 [[1, 168, 128, 128]] [1, 168, 128, 128] 672
Swish-65 [[1, 168, 128, 128]] [1, 168, 128, 128] 0
ConvBNLayer-44 [[1, 28, 128, 128]] [1, 168, 128, 128] 0
Conv2D-181 [[1, 168, 128, 128]] [1, 168, 64, 64] 1,680
BatchNorm2D-138 [[1, 168, 64, 64]] [1, 168, 64, 64] 672
Swish-66 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
ConvBNLayer-45 [[1, 168, 128, 128]] [1, 168, 64, 64] 0
AdaptiveAvgPool2D-24 [[1, 168, 64, 64]] [1, 168, 1, 1] 0
Conv2D-182 [[1, 168, 1, 1]] [1, 42, 1, 1] 7,098
Swish-67 [[1, 42, 1, 1]] [1, 42, 1, 1] 0
Conv2D-183 [[1, 42, 1, 1]] [1, 168, 1, 1] 7,224
Sigmoid-22 [[1, 168, 1, 1]] [1, 168, 1, 1] 0
SE-22 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
Conv2D-184 [[1, 168, 64, 64]] [1, 28, 64, 64] 4,732
BatchNorm2D-139 [[1, 28, 64, 64]] [1, 28, 64, 64] 112
Dropout-23 [[1, 28, 64, 64]] [1, 28, 64, 64] 0
MBConv-22 [[1, 28, 128, 128]] [1, 28, 64, 64] 0
Conv2D-186 [[1, 28, 64, 64]] [1, 168, 64, 64] 4,872
BatchNorm2D-141 [[1, 168, 64, 64]] [1, 168, 64, 64] 672
Swish-68 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
ConvBNLayer-46 [[1, 28, 64, 64]] [1, 168, 64, 64] 0
Conv2D-187 [[1, 168, 64, 64]] [1, 168, 64, 64] 4,368
BatchNorm2D-142 [[1, 168, 64, 64]] [1, 168, 64, 64] 672
Swish-69 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
ConvBNLayer-47 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
AdaptiveAvgPool2D-25 [[1, 168, 64, 64]] [1, 168, 1, 1] 0
Conv2D-188 [[1, 168, 1, 1]] [1, 42, 1, 1] 7,098
Swish-70 [[1, 42, 1, 1]] [1, 42, 1, 1] 0
Conv2D-189 [[1, 42, 1, 1]] [1, 168, 1, 1] 7,224
Sigmoid-23 [[1, 168, 1, 1]] [1, 168, 1, 1] 0
SE-23 [[1, 168, 64, 64]] [1, 168, 64, 64] 0
Conv2D-190 [[1, 168, 64, 64]] [1, 48, 64, 64] 8,112
BatchNorm2D-143 [[1, 48, 64, 64]] [1, 48, 64, 64] 192
Dropout-24 [[1, 48, 64, 64]] [1, 48, 64, 64] 0
Conv2D-191 [[1, 28, 64, 64]] [1, 48, 64, 64] 1,392
BatchNorm2D-144 [[1, 48, 64, 64]] [1, 48, 64, 64] 192
MBConv-23 [[1, 28, 64, 64]] [1, 48, 64, 64] 0
Conv2D-192 [[1, 48, 64, 64]] [1, 288, 64, 64] 14,112
BatchNorm2D-145 [[1, 288, 64, 64]] [1, 288, 64, 64] 1,152
Swish-71 [[1, 288, 64, 64]] [1, 288, 64, 64] 0
ConvBNLayer-48 [[1, 48, 64, 64]] [1, 288, 64, 64] 0
Conv2D-193 [[1, 288, 64, 64]] [1, 288, 32, 32] 7,488
BatchNorm2D-146 [[1, 288, 32, 32]] [1, 288, 32, 32] 1,152
Swish-72 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
ConvBNLayer-49 [[1, 288, 64, 64]] [1, 288, 32, 32] 0
AdaptiveAvgPool2D-26 [[1, 288, 32, 32]] [1, 288, 1, 1] 0
Conv2D-194 [[1, 288, 1, 1]] [1, 72, 1, 1] 20,808
Swish-73 [[1, 72, 1, 1]] [1, 72, 1, 1] 0
Conv2D-195 [[1, 72, 1, 1]] [1, 288, 1, 1] 21,024
Sigmoid-24 [[1, 288, 1, 1]] [1, 288, 1, 1] 0
SE-24 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
Conv2D-196 [[1, 288, 32, 32]] [1, 48, 32, 32] 13,872
BatchNorm2D-147 [[1, 48, 32, 32]] [1, 48, 32, 32] 192
Dropout-25 [[1, 48, 32, 32]] [1, 48, 32, 32] 0
MBConv-24 [[1, 48, 64, 64]] [1, 48, 32, 32] 0
Conv2D-198 [[1, 48, 32, 32]] [1, 288, 32, 32] 14,112
BatchNorm2D-149 [[1, 288, 32, 32]] [1, 288, 32, 32] 1,152
Swish-74 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
ConvBNLayer-50 [[1, 48, 32, 32]] [1, 288, 32, 32] 0
Conv2D-199 [[1, 288, 32, 32]] [1, 288, 32, 32] 2,880
BatchNorm2D-150 [[1, 288, 32, 32]] [1, 288, 32, 32] 1,152
Swish-75 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
ConvBNLayer-51 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
AdaptiveAvgPool2D-27 [[1, 288, 32, 32]] [1, 288, 1, 1] 0
Conv2D-200 [[1, 288, 1, 1]] [1, 72, 1, 1] 20,808
Swish-76 [[1, 72, 1, 1]] [1, 72, 1, 1] 0
Conv2D-201 [[1, 72, 1, 1]] [1, 288, 1, 1] 21,024
Sigmoid-25 [[1, 288, 1, 1]] [1, 288, 1, 1] 0
SE-25 [[1, 288, 32, 32]] [1, 288, 32, 32] 0
Conv2D-202 [[1, 288, 32, 32]] [1, 96, 32, 32] 27,744
BatchNorm2D-151 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
Dropout-26 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-203 [[1, 48, 32, 32]] [1, 96, 32, 32] 4,704
BatchNorm2D-152 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
MBConv-25 [[1, 48, 32, 32]] [1, 96, 32, 32] 0
Conv2D-204 [[1, 96, 32, 32]] [1, 576, 32, 32] 55,872
BatchNorm2D-153 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-77 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-52 [[1, 96, 32, 32]] [1, 576, 32, 32] 0
Conv2D-205 [[1, 576, 32, 32]] [1, 576, 32, 32] 5,760
BatchNorm2D-154 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-78 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-53 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
AdaptiveAvgPool2D-28 [[1, 576, 32, 32]] [1, 576, 1, 1] 0
Conv2D-206 [[1, 576, 1, 1]] [1, 144, 1, 1] 83,088
Swish-79 [[1, 144, 1, 1]] [1, 144, 1, 1] 0
Conv2D-207 [[1, 144, 1, 1]] [1, 576, 1, 1] 83,520
Sigmoid-26 [[1, 576, 1, 1]] [1, 576, 1, 1] 0
SE-26 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
Conv2D-208 [[1, 576, 32, 32]] [1, 96, 32, 32] 55,392
BatchNorm2D-155 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
Dropout-27 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-209 [[1, 96, 32, 32]] [1, 96, 32, 32] 9,312
BatchNorm2D-156 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
MBConv-26 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-210 [[1, 96, 32, 32]] [1, 576, 32, 32] 55,872
BatchNorm2D-157 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-80 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-54 [[1, 96, 32, 32]] [1, 576, 32, 32] 0
Conv2D-211 [[1, 576, 32, 32]] [1, 576, 32, 32] 5,760
BatchNorm2D-158 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-81 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-55 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
AdaptiveAvgPool2D-29 [[1, 576, 32, 32]] [1, 576, 1, 1] 0
Conv2D-212 [[1, 576, 1, 1]] [1, 144, 1, 1] 83,088
Swish-82 [[1, 144, 1, 1]] [1, 144, 1, 1] 0
Conv2D-213 [[1, 144, 1, 1]] [1, 576, 1, 1] 83,520
Sigmoid-27 [[1, 576, 1, 1]] [1, 576, 1, 1] 0
SE-27 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
Conv2D-214 [[1, 576, 32, 32]] [1, 96, 32, 32] 55,392
BatchNorm2D-159 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
Dropout-28 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-215 [[1, 96, 32, 32]] [1, 96, 32, 32] 9,312
BatchNorm2D-160 [[1, 96, 32, 32]] [1, 96, 32, 32] 384
MBConv-27 [[1, 96, 32, 32]] [1, 96, 32, 32] 0
Conv2D-216 [[1, 96, 32, 32]] [1, 576, 32, 32] 55,872
BatchNorm2D-161 [[1, 576, 32, 32]] [1, 576, 32, 32] 2,304
Swish-83 [[1, 576, 32, 32]] [1, 576, 32, 32] 0
ConvBNLayer-56 [[1, 96, 32, 32]] [1, 576, 32, 32] 0
Conv2D-217 [[1, 576, 32, 32]] [1, 576, 16, 16] 5,760
BatchNorm2D-162 [[1, 576, 16, 16]] [1, 576, 16, 16] 2,304
Swish-84 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
ConvBNLayer-57 [[1, 576, 32, 32]] [1, 576, 16, 16] 0
AdaptiveAvgPool2D-30 [[1, 576, 16, 16]] [1, 576, 1, 1] 0
Conv2D-218 [[1, 576, 1, 1]] [1, 144, 1, 1] 83,088
Swish-85 [[1, 144, 1, 1]] [1, 144, 1, 1] 0
Conv2D-219 [[1, 144, 1, 1]] [1, 576, 1, 1] 83,520
Sigmoid-28 [[1, 576, 1, 1]] [1, 576, 1, 1] 0
SE-28 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
Conv2D-220 [[1, 576, 16, 16]] [1, 96, 16, 16] 55,392
BatchNorm2D-163 [[1, 96, 16, 16]] [1, 96, 16, 16] 384
Dropout-29 [[1, 96, 16, 16]] [1, 96, 16, 16] 0
MBConv-28 [[1, 96, 32, 32]] [1, 96, 16, 16] 0
Conv2D-222 [[1, 96, 16, 16]] [1, 576, 16, 16] 55,872
BatchNorm2D-165 [[1, 576, 16, 16]] [1, 576, 16, 16] 2,304
Swish-86 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
ConvBNLayer-58 [[1, 96, 16, 16]] [1, 576, 16, 16] 0
Conv2D-223 [[1, 576, 16, 16]] [1, 576, 16, 16] 14,976
BatchNorm2D-166 [[1, 576, 16, 16]] [1, 576, 16, 16] 2,304
Swish-87 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
ConvBNLayer-59 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
AdaptiveAvgPool2D-31 [[1, 576, 16, 16]] [1, 576, 1, 1] 0
Conv2D-224 [[1, 576, 1, 1]] [1, 144, 1, 1] 83,088
Swish-88 [[1, 144, 1, 1]] [1, 144, 1, 1] 0
Conv2D-225 [[1, 144, 1, 1]] [1, 576, 1, 1] 83,520
Sigmoid-29 [[1, 576, 1, 1]] [1, 576, 1, 1] 0
SE-29 [[1, 576, 16, 16]] [1, 576, 16, 16] 0
Conv2D-226 [[1, 576, 16, 16]] [1, 134, 16, 16] 77,318
BatchNorm2D-167 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
Dropout-30 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-227 [[1, 96, 16, 16]] [1, 134, 16, 16] 12,998
BatchNorm2D-168 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
MBConv-29 [[1, 96, 16, 16]] [1, 134, 16, 16] 0
Conv2D-228 [[1, 134, 16, 16]] [1, 804, 16, 16] 108,540
BatchNorm2D-169 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-89 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-60 [[1, 134, 16, 16]] [1, 804, 16, 16] 0
Conv2D-229 [[1, 804, 16, 16]] [1, 804, 16, 16] 20,904
BatchNorm2D-170 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-90 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-61 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
AdaptiveAvgPool2D-32 [[1, 804, 16, 16]] [1, 804, 1, 1] 0
Conv2D-230 [[1, 804, 1, 1]] [1, 201, 1, 1] 161,805
Swish-91 [[1, 201, 1, 1]] [1, 201, 1, 1] 0
Conv2D-231 [[1, 201, 1, 1]] [1, 804, 1, 1] 162,408
Sigmoid-30 [[1, 804, 1, 1]] [1, 804, 1, 1] 0
SE-30 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
Conv2D-232 [[1, 804, 16, 16]] [1, 134, 16, 16] 107,870
BatchNorm2D-171 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
Dropout-31 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-233 [[1, 134, 16, 16]] [1, 134, 16, 16] 18,090
BatchNorm2D-172 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
MBConv-30 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-234 [[1, 134, 16, 16]] [1, 804, 16, 16] 108,540
BatchNorm2D-173 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-92 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-62 [[1, 134, 16, 16]] [1, 804, 16, 16] 0
Conv2D-235 [[1, 804, 16, 16]] [1, 804, 16, 16] 20,904
BatchNorm2D-174 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-93 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-63 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
AdaptiveAvgPool2D-33 [[1, 804, 16, 16]] [1, 804, 1, 1] 0
Conv2D-236 [[1, 804, 1, 1]] [1, 201, 1, 1] 161,805
Swish-94 [[1, 201, 1, 1]] [1, 201, 1, 1] 0
Conv2D-237 [[1, 201, 1, 1]] [1, 804, 1, 1] 162,408
Sigmoid-31 [[1, 804, 1, 1]] [1, 804, 1, 1] 0
SE-31 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
Conv2D-238 [[1, 804, 16, 16]] [1, 134, 16, 16] 107,870
BatchNorm2D-175 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
Dropout-32 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-239 [[1, 134, 16, 16]] [1, 134, 16, 16] 18,090
BatchNorm2D-176 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
MBConv-31 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-240 [[1, 134, 16, 16]] [1, 804, 16, 16] 108,540
BatchNorm2D-177 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-95 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-64 [[1, 134, 16, 16]] [1, 804, 16, 16] 0
Conv2D-241 [[1, 804, 16, 16]] [1, 804, 16, 16] 20,904
BatchNorm2D-178 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-96 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-65 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
AdaptiveAvgPool2D-34 [[1, 804, 16, 16]] [1, 804, 1, 1] 0
Conv2D-242 [[1, 804, 1, 1]] [1, 201, 1, 1] 161,805
Swish-97 [[1, 201, 1, 1]] [1, 201, 1, 1] 0
Conv2D-243 [[1, 201, 1, 1]] [1, 804, 1, 1] 162,408
Sigmoid-32 [[1, 804, 1, 1]] [1, 804, 1, 1] 0
SE-32 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
Conv2D-244 [[1, 804, 16, 16]] [1, 134, 16, 16] 107,870
BatchNorm2D-179 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
Dropout-33 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-245 [[1, 134, 16, 16]] [1, 134, 16, 16] 18,090
BatchNorm2D-180 [[1, 134, 16, 16]] [1, 134, 16, 16] 536
MBConv-32 [[1, 134, 16, 16]] [1, 134, 16, 16] 0
Conv2D-246 [[1, 134, 16, 16]] [1, 804, 16, 16] 108,540
BatchNorm2D-181 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-98 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-66 [[1, 134, 16, 16]] [1, 804, 16, 16] 0
Conv2D-247 [[1, 804, 16, 16]] [1, 804, 16, 16] 20,904
BatchNorm2D-182 [[1, 804, 16, 16]] [1, 804, 16, 16] 3,216
Swish-99 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
ConvBNLayer-67 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
AdaptiveAvgPool2D-35 [[1, 804, 16, 16]] [1, 804, 1, 1] 0
Conv2D-248 [[1, 804, 1, 1]] [1, 201, 1, 1] 161,805
Swish-100 [[1, 201, 1, 1]] [1, 201, 1, 1] 0
Conv2D-249 [[1, 201, 1, 1]] [1, 804, 1, 1] 162,408
Sigmoid-33 [[1, 804, 1, 1]] [1, 804, 1, 1] 0
SE-33 [[1, 804, 16, 16]] [1, 804, 16, 16] 0
Conv2D-250 [[1, 804, 16, 16]] [1, 230, 16, 16] 185,150
BatchNorm2D-183 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
Dropout-34 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-251 [[1, 134, 16, 16]] [1, 230, 16, 16] 31,050
BatchNorm2D-184 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
MBConv-33 [[1, 134, 16, 16]] [1, 230, 16, 16] 0
Conv2D-252 [[1, 230, 16, 16]] [1, 1380, 16, 16] 318,780
BatchNorm2D-185 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-101 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-68 [[1, 230, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-253 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 35,880
BatchNorm2D-186 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-102 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-69 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
AdaptiveAvgPool2D-36 [[1, 1380, 16, 16]] [1, 1380, 1, 1] 0
Conv2D-254 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-103 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-255 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-34 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-34 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-256 [[1, 1380, 16, 16]] [1, 230, 16, 16] 317,630
BatchNorm2D-187 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
Dropout-35 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-257 [[1, 230, 16, 16]] [1, 230, 16, 16] 53,130
BatchNorm2D-188 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
MBConv-34 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-258 [[1, 230, 16, 16]] [1, 1380, 16, 16] 318,780
BatchNorm2D-189 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-104 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-70 [[1, 230, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-259 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 35,880
BatchNorm2D-190 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-105 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-71 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
AdaptiveAvgPool2D-37 [[1, 1380, 16, 16]] [1, 1380, 1, 1] 0
Conv2D-260 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-106 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-261 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-35 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-35 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-262 [[1, 1380, 16, 16]] [1, 230, 16, 16] 317,630
BatchNorm2D-191 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
Dropout-36 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-263 [[1, 230, 16, 16]] [1, 230, 16, 16] 53,130
BatchNorm2D-192 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
MBConv-35 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-264 [[1, 230, 16, 16]] [1, 1380, 16, 16] 318,780
BatchNorm2D-193 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-107 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-72 [[1, 230, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-265 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 35,880
BatchNorm2D-194 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-108 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-73 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
AdaptiveAvgPool2D-38 [[1, 1380, 16, 16]] [1, 1380, 1, 1] 0
Conv2D-266 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-109 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-267 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-36 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-36 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-268 [[1, 1380, 16, 16]] [1, 230, 16, 16] 317,630
BatchNorm2D-195 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
Dropout-37 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-269 [[1, 230, 16, 16]] [1, 230, 16, 16] 53,130
BatchNorm2D-196 [[1, 230, 16, 16]] [1, 230, 16, 16] 920
MBConv-36 [[1, 230, 16, 16]] [1, 230, 16, 16] 0
Conv2D-270 [[1, 230, 16, 16]] [1, 1380, 16, 16] 318,780
BatchNorm2D-197 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 5,520
Swish-110 [[1, 1380, 16, 16]] [1, 1380, 16, 16] 0
ConvBNLayer-74 [[1, 230, 16, 16]] [1, 1380, 16, 16] 0
Conv2D-271 [[1, 1380, 16, 16]] [1, 1380, 8, 8] 35,880
BatchNorm2D-198 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 5,520
Swish-111 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
ConvBNLayer-75 [[1, 1380, 16, 16]] [1, 1380, 8, 8] 0
AdaptiveAvgPool2D-39 [[1, 1380, 8, 8]] [1, 1380, 1, 1] 0
Conv2D-272 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-112 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-273 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-37 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-37 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
Conv2D-274 [[1, 1380, 8, 8]] [1, 230, 8, 8] 317,630
BatchNorm2D-199 [[1, 230, 8, 8]] [1, 230, 8, 8] 920
Dropout-38 [[1, 230, 8, 8]] [1, 230, 8, 8] 0
MBConv-37 [[1, 230, 16, 16]] [1, 230, 8, 8] 0
Conv2D-276 [[1, 230, 8, 8]] [1, 1380, 8, 8] 318,780
BatchNorm2D-201 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 5,520
Swish-113 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
ConvBNLayer-76 [[1, 230, 8, 8]] [1, 1380, 8, 8] 0
Conv2D-277 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 13,800
BatchNorm2D-202 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 5,520
Swish-114 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
ConvBNLayer-77 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
AdaptiveAvgPool2D-40 [[1, 1380, 8, 8]] [1, 1380, 1, 1] 0
Conv2D-278 [[1, 1380, 1, 1]] [1, 345, 1, 1] 476,445
Swish-115 [[1, 345, 1, 1]] [1, 345, 1, 1] 0
Conv2D-279 [[1, 345, 1, 1]] [1, 1380, 1, 1] 477,480
Sigmoid-38 [[1, 1380, 1, 1]] [1, 1380, 1, 1] 0
SE-38 [[1, 1380, 8, 8]] [1, 1380, 8, 8] 0
Conv2D-280 [[1, 1380, 8, 8]] [1, 384, 8, 8] 530,304
BatchNorm2D-203 [[1, 384, 8, 8]] [1, 384, 8, 8] 1,536
Dropout-39 [[1, 384, 8, 8]] [1, 384, 8, 8] 0
Conv2D-281 [[1, 230, 8, 8]] [1, 384, 8, 8] 88,704
BatchNorm2D-204 [[1, 384, 8, 8]] [1, 384, 8, 8] 1,536
MBConv-38 [[1, 230, 8, 8]] [1, 384, 8, 8] 0
Conv2D-282 [[1, 384, 8, 8]] [1, 1280, 8, 8] 492,800
BatchNorm2D-205 [[1, 1280, 8, 8]] [1, 1280, 8, 8] 5,120
Swish-116 [[1, 1280, 8, 8]] [1, 1280, 8, 8] 0
ConvBNLayer-78 [[1, 384, 8, 8]] [1, 1280, 8, 8] 0
AdaptiveAvgPool2D-41 [[1, 1280, 8, 8]] [1, 1280, 1, 1] 0
Dropout-40 [[1, 1280, 1, 1]] [1, 1280, 1, 1] 0
Conv2D-283 [[1, 1280, 1, 1]] [1, 1000, 1, 1] 1,281,000
Classifier_Head-2 [[1, 384, 8, 8]] [1, 1000] 0
EfficientNet-2 [[1, 3, 256, 256]] [1, 1000] 0
Conv2D-284 [[1, 256, 512, 512]] [1, 64, 253, 253] 802,880
BatchNorm2D-206 [[1, 64, 253, 253]] [1, 64, 253, 253] 256
MaxPool2D-2 [[1, 64, 253, 253]] [1, 64, 126, 126] 0
Conv2D-285 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-207 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Conv2D-286 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-208 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Basicblock-17 [[1, 64, 126, 126]] [1, 64, 126, 126] 0
Conv2D-288 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-210 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Conv2D-289 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-211 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Basicblock-18 [[1, 64, 126, 126]] [1, 64, 126, 126] 0
Conv2D-291 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-213 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Conv2D-292 [[1, 64, 126, 126]] [1, 64, 126, 126] 36,928
BatchNorm2D-214 [[1, 64, 126, 126]] [1, 64, 126, 126] 256
Basicblock-19 [[1, 64, 126, 126]] [1, 64, 126, 126] 0
Conv2D-294 [[1, 64, 126, 126]] [1, 128, 63, 63] 73,856
BatchNorm2D-216 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-295 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-217 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-296 [[1, 64, 126, 126]] [1, 128, 63, 63] 8,320
BatchNorm2D-218 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Basicblock-20 [[1, 64, 126, 126]] [1, 128, 63, 63] 0
Conv2D-297 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-219 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-298 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-220 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Basicblock-21 [[1, 128, 63, 63]] [1, 128, 63, 63] 0
Conv2D-300 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-222 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-301 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-223 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Basicblock-22 [[1, 128, 63, 63]] [1, 128, 63, 63] 0
Conv2D-303 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-225 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Conv2D-304 [[1, 128, 63, 63]] [1, 128, 63, 63] 147,584
BatchNorm2D-226 [[1, 128, 63, 63]] [1, 128, 63, 63] 512
Basicblock-23 [[1, 128, 63, 63]] [1, 128, 63, 63] 0
Conv2D-306 [[1, 128, 63, 63]] [1, 256, 32, 32] 295,168
BatchNorm2D-228 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-307 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-229 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-308 [[1, 128, 63, 63]] [1, 256, 32, 32] 33,024
BatchNorm2D-230 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-24 [[1, 128, 63, 63]] [1, 256, 32, 32] 0
Conv2D-309 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-231 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-310 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-232 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-25 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-312 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-234 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-313 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-235 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-26 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-315 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-237 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-316 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-238 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-27 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-318 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-240 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-319 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-241 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-28 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-321 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-243 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Conv2D-322 [[1, 256, 32, 32]] [1, 256, 32, 32] 590,080
BatchNorm2D-244 [[1, 256, 32, 32]] [1, 256, 32, 32] 1,024
Basicblock-29 [[1, 256, 32, 32]] [1, 256, 32, 32] 0
Conv2D-324 [[1, 256, 32, 32]] [1, 512, 16, 16] 1,180,160
BatchNorm2D-246 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Conv2D-325 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-247 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Conv2D-326 [[1, 256, 32, 32]] [1, 512, 16, 16] 131,584
BatchNorm2D-248 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Basicblock-30 [[1, 256, 32, 32]] [1, 512, 16, 16] 0
Conv2D-327 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-249 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Conv2D-328 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-250 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Basicblock-31 [[1, 512, 16, 16]] [1, 512, 16, 16] 0
Conv2D-330 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-252 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Conv2D-331 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,359,808
BatchNorm2D-253 [[1, 512, 16, 16]] [1, 512, 16, 16] 2,048
Basicblock-32 [[1, 512, 16, 16]] [1, 512, 16, 16] 0
AdaptiveAvgPool2D-42 [[1, 512, 16, 16]] [1, 512, 1, 1] 0
ResNet-2 [[1, 256, 512, 512]] [1, 512] 0
Linear-1 [[1, 1512]] [1, 3] 4,539
================================================================================
Total params: 36,436,367
Trainable params: 36,353,075
Non-trainable params: 83,292
--------------------------------------------------------------------------------
Input size (MB): 256.75
Forward/backward pass size (MB): 1107.63
Params size (MB): 138.99
Estimated Total Size (MB): 1503.37
--------------------------------------------------------------------------------
{'total_params': 36436367, 'trainable_params': 36353075}
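The parameter counts in the summary follow from standard layer formulas, so a few rows can be rechecked by hand to confirm the two-branch model is wired as intended. The sketch below uses textbook formulas: a Conv2D with bias has k·k·C_in·C_out + C_out weights, a BatchNorm2D stores four length-C vectors (scale, shift, running mean, running variance; the last two are what the summary counts as non-trainable), and a Linear layer has in·out + out weights. One further reading is an inference from the shapes, not something the summary states: Linear-1's 1,512 inputs look like the ResNet's 512 features concatenated with 1,000 features from the fundus branch. The helper names are hypothetical.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Weights of a k x k Conv2D layer: k*k*C_in*C_out, plus C_out biases."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def batchnorm2d_params(c):
    """BatchNorm2D keeps scale, shift, running mean and running variance: 4 * C values."""
    return 4 * c


def linear_params(n_in, n_out):
    """Linear layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


def conv_out_size(size, k, stride, padding):
    """Spatial output size of a convolution (floor division)."""
    return (size + 2 * padding - k) // stride + 1


# Rows taken from the summary table above
print(conv2d_params(64, 128, 1))    # 8320     (Conv2D-296, 1x1 projection shortcut)
print(conv2d_params(128, 128, 3))   # 147584   (Conv2D-297)
print(conv2d_params(256, 512, 3))   # 1180160  (Conv2D-324)
print(batchnorm2d_params(128))      # 512      (BatchNorm2D-217)
print(linear_params(1512, 3))       # 4539     (Linear-1, three glaucoma grades)

# 63 -> 32 downsampling by a stride-2 3x3 conv, assuming padding=1
print(conv_out_size(63, k=3, stride=2, padding=1))  # 32

# The float32 OCT volume fed to ResNet-2 has shape [1, 256, 512, 512]
oct_bytes = 1 * 256 * 512 * 512 * 4
print(oct_bytes / 1024 ** 2)        # 256.0
```

Assuming the summary reports MB as 1024² bytes, the OCT volume alone accounts for 256 of the 256.75 MB input size. The remaining 0.75 MB is the fundus image.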
5. Model Training
Training runs for 2000 iterations and takes roughly 8 hours.
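It helps to express this iteration budget in epochs. The split above leaves 80 training samples, and the training cell below uses a batch size of 2. One epoch is therefore 40 iterations, and 2000 iterations is about 50 passes over the training data. The 8-hour figure is only the author's rough estimate, so the per-iteration time derived from it is approximate too. A minimal sketch of the arithmetic:

```python
import math

num_train = 80        # training samples after the 8:2 split
batch_size = 2        # batchsize used in the training cell
iters = 2000          # total training iterations
eval_interval = 50    # validation frequency passed to train()
hours = 8             # rough wall-clock estimate quoted above

# paddle.io.DataLoader keeps the last partial batch by default (drop_last=False)
iters_per_epoch = math.ceil(num_train / batch_size)
epochs = iters / iters_per_epoch
seconds_per_iter = hours * 3600 / iters
evals_every_epochs = eval_interval / iters_per_epoch

print(iters_per_epoch, epochs, seconds_per_iter, evals_every_epochs)  # 40 50.0 14.4 1.25
```

With these settings, the model is validated every 1.25 epochs.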
# Training loop
def train(model, iters, train_dataloader, val_dataloader, optimizer, criterion, log_interval, eval_interval):
    iter_num = 0
    model.train()
    avg_loss_list = []
    avg_kappa_list = []
    # Only checkpoints whose validation kappa reaches at least 0.5 are saved
    best_kappa = 0.5
    while iter_num < iters:
        for data in train_dataloader:
            iter_num += 1
            if iter_num > iters:
                break
            # Scale pixel values to [0, 1]
            fundus_imgs = (data[0] / 255.).astype("float32")
            oct_imgs = (data[1] / 255.).astype("float32")
            labels = data[2].astype('int64')
            logits = model(fundus_imgs, oct_imgs)
            loss = criterion(logits, labels)
            # Collect (prediction, label) pairs for the running kappa
            for p, l in zip(logits.numpy().argmax(1), labels.numpy()):
                avg_kappa_list.append([p, l])
            loss.backward()
            optimizer.step()
            model.clear_gradients()
            avg_loss_list.append(loss.numpy()[0])
            if iter_num % log_interval == 0:
                avg_loss = np.array(avg_loss_list).mean()
                avg_kappa_list = np.array(avg_kappa_list)
                # Quadratic-weighted Cohen's kappa over the last log_interval iterations
                avg_kappa = cohen_kappa_score(avg_kappa_list[:, 0], avg_kappa_list[:, 1], weights='quadratic')
                avg_loss_list = []
                avg_kappa_list = []
                print("[TRAIN] iter={}/{} avg_loss={:.4f} avg_kappa={:.4f}".format(iter_num, iters, avg_loss, avg_kappa))
            if iter_num % eval_interval == 0:
                avg_loss, avg_kappa = val(model, val_dataloader, criterion)
                print("[EVAL] iter={}/{} avg_loss={:.4f} kappa={:.4f}".format(iter_num, iters, avg_loss, avg_kappa))
                # Save the model whenever the validation kappa matches or beats the best so far
                if avg_kappa >= best_kappa:
                    best_kappa = avg_kappa
                    paddle.save(model.state_dict(), os.path.join("best_model_{:.4f}".format(best_kappa), 'model.pdparams'))
                # val() put the model in eval mode; switch back before training continues
                model.train()
# Validation loop
def val(model, val_dataloader, criterion):
    model.eval()
    avg_loss_list = []
    cache = []
    with paddle.no_grad():
        for data in val_dataloader:
            fundus_imgs = (data[0] / 255.).astype("float32")
            oct_imgs = (data[1] / 255.).astype("float32")
            labels = data[2].astype('int64')
            logits = model(fundus_imgs, oct_imgs)
            # Collect (prediction, label) pairs over the whole validation set
            for p, l in zip(logits.numpy().argmax(1), labels.numpy()):
                cache.append([p, l])
            loss = criterion(logits, labels)
            avg_loss_list.append(loss.numpy()[0])
    cache = np.array(cache)
    kappa = cohen_kappa_score(cache[:, 0], cache[:, 1], weights='quadratic')
    avg_loss = np.array(avg_loss_list).mean()
    return avg_loss, kappa
# Number of worker subprocesses used for data loading
num_workers = 4
# Batch size
batchsize = 2
# Total number of training iterations
iters = 2000
# Optimizer type
optimizer_type = "adam"
# Initial learning rate
init_lr = 1e-3
# Datasets
train_dataset = GAMMA_sub1_dataset(dataset_root=trainset_root,
                                   img_transforms=img_train_transforms,
                                   oct_transforms=oct_train_transforms,
                                   filelists=train_filelists,
                                   label_file=gt_file)
val_dataset = GAMMA_sub1_dataset(dataset_root=trainset_root,
                                 img_transforms=img_val_transforms,
                                 oct_transforms=oct_val_transforms,
                                 filelists=val_filelists,
                                 label_file=gt_file)
# Data loaders
train_loader = paddle.io.DataLoader(
    train_dataset,
    num_workers=num_workers,
    batch_size=batchsize,
    shuffle=True,
    return_list=True,
    use_shared_memory=False
)
val_loader = paddle.io.DataLoader(
    val_dataset,
    batch_size=batchsize,
    num_workers=num_workers,
    return_list=True,
    use_shared_memory=False
)
if optimizer_type == "adam":
    optimizer = paddle.optimizer.Adam(init_lr, parameters=model.parameters())
else:
    raise ValueError("Unsupported optimizer_type: {}".format(optimizer_type))
# ----- Use cross-entropy as the loss function -----
criterion = nn.CrossEntropyLoss()
# ----- Train the model -----
# log_interval=10 matches the [TRAIN] lines printed below (one every 10 iterations)
train(model, iters, train_loader, val_loader, optimizer, criterion, log_interval=10, eval_interval=50)
[TRAIN] iter=10/2000 avg_loss=2.1748 avg_kappa=-0.1957
[TRAIN] iter=20/2000 avg_loss=1.7035 avg_kappa=0.0774
[TRAIN] iter=30/2000 avg_loss=0.9511 avg_kappa=0.7015
[TRAIN] iter=40/2000 avg_loss=1.2922 avg_kappa=0.1606
[TRAIN] iter=50/2000 avg_loss=1.1163 avg_kappa=0.0000
[EVAL] iter=50/2000 avg_loss=10.3795 kappa=-0.2698
[TRAIN] iter=60/2000 avg_loss=1.1850 avg_kappa=-0.0606
[TRAIN] iter=70/2000 avg_loss=1.1729 avg_kappa=-0.2185
[TRAIN] iter=80/2000 avg_loss=1.0922 avg_kappa=0.1093
[TRAIN] iter=90/2000 avg_loss=0.9435 avg_kappa=0.2708
[TRAIN] iter=100/2000 avg_loss=1.2347 avg_kappa=0.0299
[EVAL] iter=100/2000 avg_loss=1.0698 kappa=0.3558
[TRAIN] iter=110/2000 avg_loss=1.0827 avg_kappa=0.2500
[TRAIN] iter=120/2000 avg_loss=1.0206 avg_kappa=-0.1354
[TRAIN] iter=130/2000 avg_loss=1.4113 avg_kappa=0.0968
[TRAIN] iter=140/2000 avg_loss=0.9458 avg_kappa=0.0625
[TRAIN] iter=150/2000 avg_loss=1.1128 avg_kappa=0.0000
[EVAL] iter=150/2000 avg_loss=2.1132 kappa=0.2978
[TRAIN] iter=160/2000 avg_loss=1.0617 avg_kappa=0.3791
[TRAIN] iter=170/2000 avg_loss=1.0624 avg_kappa=0.2808
[TRAIN] iter=180/2000 avg_loss=0.9705 avg_kappa=0.3561
[TRAIN] iter=190/2000 avg_loss=0.9573 avg_kappa=0.1078
[TRAIN] iter=200/2000 avg_loss=0.9949 avg_kappa=0.1176
[EVAL] iter=200/2000 avg_loss=2.2989 kappa=0.2982
[TRAIN] iter=210/2000 avg_loss=1.0670 avg_kappa=-0.0870
[TRAIN] iter=220/2000 avg_loss=0.8388 avg_kappa=-0.0596
[TRAIN] iter=230/2000 avg_loss=1.1535 avg_kappa=0.0625
[TRAIN] iter=240/2000 avg_loss=1.0679 avg_kappa=0.5335
[TRAIN] iter=250/2000 avg_loss=0.9604 avg_kappa=0.4318
[EVAL] iter=250/2000 avg_loss=1.2917 kappa=0.3895
[TRAIN] iter=260/2000 avg_loss=0.8026 avg_kappa=0.1964
[TRAIN] iter=270/2000 avg_loss=0.6450 avg_kappa=0.4536
[TRAIN] iter=280/2000 avg_loss=1.1303 avg_kappa=0.3511
[TRAIN] iter=290/2000 avg_loss=1.0567 avg_kappa=0.4462
[TRAIN] iter=300/2000 avg_loss=1.1012 avg_kappa=0.2647
[EVAL] iter=300/2000 avg_loss=1.0134 kappa=0.5455
[TRAIN] iter=310/2000 avg_loss=1.3930 avg_kappa=-0.3462
[TRAIN] iter=320/2000 avg_loss=0.9429 avg_kappa=0.0050
[TRAIN] iter=330/2000 avg_loss=0.8889 avg_kappa=0.0000
[TRAIN] iter=340/2000 avg_loss=1.1127 avg_kappa=0.1558
[TRAIN] iter=350/2000 avg_loss=1.0205 avg_kappa=0.1822
[EVAL] iter=350/2000 avg_loss=1.5745 kappa=0.0813
[TRAIN] iter=360/2000 avg_loss=1.2648 avg_kappa=0.0000
[TRAIN] iter=370/2000 avg_loss=1.0565 avg_kappa=-0.0714
[TRAIN] iter=380/2000 avg_loss=0.8164 avg_kappa=0.2857
[TRAIN] iter=390/2000 avg_loss=1.0671 avg_kappa=0.0222
[TRAIN] iter=400/2000 avg_loss=1.0092 avg_kappa=0.5305
[EVAL] iter=400/2000 avg_loss=1.0891 kappa=0.2982
[TRAIN] iter=410/2000 avg_loss=0.8200 avg_kappa=0.5745
[TRAIN] iter=420/2000 avg_loss=0.9284 avg_kappa=0.3827
[TRAIN] iter=430/2000 avg_loss=0.9759 avg_kappa=0.1497
[TRAIN] iter=440/2000 avg_loss=0.9261 avg_kappa=0.2188
[TRAIN] iter=450/2000 avg_loss=0.7655 avg_kappa=0.7500
[EVAL] iter=450/2000 avg_loss=1.0015 kappa=0.3229
[TRAIN] iter=460/2000 avg_loss=1.0300 avg_kappa=0.4271
[TRAIN] iter=470/2000 avg_loss=1.1922 avg_kappa=0.0189
[TRAIN] iter=480/2000 avg_loss=1.0318 avg_kappa=0.5122
[TRAIN] iter=490/2000 avg_loss=0.6496 avg_kappa=0.6269
[TRAIN] iter=500/2000 avg_loss=0.7465 avg_kappa=0.6154
[EVAL] iter=500/2000 avg_loss=0.9899 kappa=0.4842
[TRAIN] iter=510/2000 avg_loss=1.1083 avg_kappa=0.1290
[TRAIN] iter=520/2000 avg_loss=1.0746 avg_kappa=0.2053
[TRAIN] iter=530/2000 avg_loss=0.8995 avg_kappa=0.3421
[TRAIN] iter=540/2000 avg_loss=1.0162 avg_kappa=0.4419
[TRAIN] iter=550/2000 avg_loss=0.7716 avg_kappa=0.4583
[EVAL] iter=550/2000 avg_loss=0.8828 kappa=0.1722
[TRAIN] iter=560/2000 avg_loss=0.7626 avg_kappa=0.4754
[TRAIN] iter=570/2000 avg_loss=1.2483 avg_kappa=0.3625
[TRAIN] iter=580/2000 avg_loss=1.0986 avg_kappa=0.2565
[TRAIN] iter=590/2000 avg_loss=0.7910 avg_kappa=0.7143
[TRAIN] iter=600/2000 avg_loss=0.9362 avg_kappa=0.2917
[EVAL] iter=600/2000 avg_loss=0.8705 kappa=0.3438
[TRAIN] iter=610/2000 avg_loss=0.9716 avg_kappa=0.1176
[TRAIN] iter=620/2000 avg_loss=0.9979 avg_kappa=-0.0191
[TRAIN] iter=630/2000 avg_loss=0.6981 avg_kappa=0.6535
[TRAIN] iter=640/2000 avg_loss=0.7475 avg_kappa=0.4444
[TRAIN] iter=650/2000 avg_loss=1.0678 avg_kappa=0.6835
[EVAL] iter=650/2000 avg_loss=1.3692 kappa=0.2368
[TRAIN] iter=660/2000 avg_loss=0.8503 avg_kappa=0.5161
[TRAIN] iter=670/2000 avg_loss=1.0898 avg_kappa=0.2460
[TRAIN] iter=680/2000 avg_loss=0.8219 avg_kappa=0.1632
[TRAIN] iter=690/2000 avg_loss=1.0924 avg_kappa=0.6667
[TRAIN] iter=700/2000 avg_loss=0.9786 avg_kappa=0.3784
[EVAL] iter=700/2000 avg_loss=2.2221 kappa=0.3046
[TRAIN] iter=710/2000 avg_loss=1.1610 avg_kappa=0.3182
[TRAIN] iter=720/2000 avg_loss=1.0428 avg_kappa=0.1453
[TRAIN] iter=730/2000 avg_loss=0.8828 avg_kappa=0.4500
[TRAIN] iter=740/2000 avg_loss=1.0252 avg_kappa=0.2105
[TRAIN] iter=750/2000 avg_loss=0.8088 avg_kappa=0.3443
[EVAL] iter=750/2000 avg_loss=2.0567 kappa=0.0813
[TRAIN] iter=760/2000 avg_loss=0.9734 avg_kappa=0.1579
[TRAIN] iter=770/2000 avg_loss=0.9835 avg_kappa=0.3636
[TRAIN] iter=780/2000 avg_loss=0.7815 avg_kappa=0.2568
[TRAIN] iter=790/2000 avg_loss=0.8366 avg_kappa=0.7799
[TRAIN] iter=800/2000 avg_loss=0.9148 avg_kappa=0.4860
[EVAL] iter=800/2000 avg_loss=0.7692 kappa=0.5909
[TRAIN] iter=810/2000 avg_loss=1.0086 avg_kappa=0.2500
[TRAIN] iter=820/2000 avg_loss=0.6844 avg_kappa=0.5588
[TRAIN] iter=830/2000 avg_loss=0.7571 avg_kappa=0.5455
[TRAIN] iter=840/2000 avg_loss=0.9979 avg_kappa=0.0415
[TRAIN] iter=850/2000 avg_loss=0.9125 avg_kappa=0.2029
[EVAL] iter=850/2000 avg_loss=1.3723 kappa=0.0000
[TRAIN] iter=860/2000 avg_loss=0.9698 avg_kappa=0.2143
[TRAIN] iter=870/2000 avg_loss=0.7990 avg_kappa=0.3902
[TRAIN] iter=880/2000 avg_loss=0.7721 avg_kappa=0.5098
[TRAIN] iter=890/2000 avg_loss=0.9212 avg_kappa=-0.0606
[TRAIN] iter=900/2000 avg_loss=0.8273 avg_kappa=0.2574
[EVAL] iter=900/2000 avg_loss=0.9112 kappa=0.3657
[TRAIN] iter=910/2000 avg_loss=0.8701 avg_kappa=0.6154
[TRAIN] iter=920/2000 avg_loss=0.9726 avg_kappa=0.1818
[TRAIN] iter=930/2000 avg_loss=0.8154 avg_kappa=0.4834
[TRAIN] iter=940/2000 avg_loss=1.0587 avg_kappa=0.4000
[TRAIN] iter=950/2000 avg_loss=1.1157 avg_kappa=0.3092
[EVAL] iter=950/2000 avg_loss=1.0636 kappa=0.3537
[TRAIN] iter=960/2000 avg_loss=0.8583 avg_kappa=0.5882
[TRAIN] iter=970/2000 avg_loss=1.0570 avg_kappa=0.1364
[TRAIN] iter=980/2000 avg_loss=1.1619 avg_kappa=0.2674
[TRAIN] iter=990/2000 avg_loss=0.8051 avg_kappa=0.6685
[TRAIN] iter=1000/2000 avg_loss=0.8298 avg_kappa=0.2609
[EVAL] iter=1000/2000 avg_loss=1.8720 kappa=0.0000
[TRAIN] iter=1010/2000 avg_loss=0.8890 avg_kappa=0.1304
[TRAIN] iter=1020/2000 avg_loss=0.7587 avg_kappa=0.6939
[TRAIN] iter=1030/2000 avg_loss=0.9327 avg_kappa=0.3831
[TRAIN] iter=1040/2000 avg_loss=1.0000 avg_kappa=0.4643
[TRAIN] iter=1050/2000 avg_loss=0.8043 avg_kappa=0.5522
[EVAL] iter=1050/2000 avg_loss=0.9665 kappa=0.2405
[TRAIN] iter=1060/2000 avg_loss=0.7553 avg_kappa=0.5370
[TRAIN] iter=1070/2000 avg_loss=0.8057 avg_kappa=0.2017
[TRAIN] iter=1080/2000 avg_loss=0.8780 avg_kappa=0.3182
[TRAIN] iter=1090/2000 avg_loss=0.9831 avg_kappa=0.2279
[TRAIN] iter=1100/2000 avg_loss=0.9947 avg_kappa=0.2083
[EVAL] iter=1100/2000 avg_loss=0.8967 kappa=0.7635
[TRAIN] iter=1110/2000 avg_loss=0.9275 avg_kappa=0.0809
[TRAIN] iter=1120/2000 avg_loss=1.0070 avg_kappa=-0.0811
[TRAIN] iter=1130/2000 avg_loss=0.8866 avg_kappa=0.4444
[TRAIN] iter=1140/2000 avg_loss=0.7797 avg_kappa=0.4681
[TRAIN] iter=1150/2000 avg_loss=0.6757 avg_kappa=0.4811
[EVAL] iter=1150/2000 avg_loss=1.6824 kappa=0.2553
[TRAIN] iter=1160/2000 avg_loss=0.8656 avg_kappa=0.3966
[TRAIN] iter=1170/2000 avg_loss=0.5916 avg_kappa=0.6970
[TRAIN] iter=1180/2000 avg_loss=1.1541 avg_kappa=-0.1677
[TRAIN] iter=1190/2000 avg_loss=0.7411 avg_kappa=0.6552
[TRAIN] iter=1200/2000 avg_loss=1.1016 avg_kappa=0.4545
[EVAL] iter=1200/2000 avg_loss=1.0677 kappa=0.4091
[TRAIN] iter=1210/2000 avg_loss=0.9788 avg_kappa=0.4366
[TRAIN] iter=1220/2000 avg_loss=0.7062 avg_kappa=0.5516
[TRAIN] iter=1230/2000 avg_loss=0.8960 avg_kappa=0.4583
[TRAIN] iter=1240/2000 avg_loss=0.8596 avg_kappa=0.2273
[TRAIN] iter=1250/2000 avg_loss=0.9534 avg_kappa=0.3991
[EVAL] iter=1250/2000 avg_loss=1.0067 kappa=0.1822
[TRAIN] iter=1260/2000 avg_loss=1.0190 avg_kappa=0.3902
[TRAIN] iter=1270/2000 avg_loss=0.7897 avg_kappa=0.4615
[TRAIN] iter=1280/2000 avg_loss=0.6785 avg_kappa=0.6721
[TRAIN] iter=1290/2000 avg_loss=1.1654 avg_kappa=0.0800
[TRAIN] iter=1300/2000 avg_loss=1.1622 avg_kappa=0.4000
[EVAL] iter=1300/2000 avg_loss=4.2278 kappa=0.3046
[TRAIN] iter=1310/2000 avg_loss=0.6945 avg_kappa=-0.1644
[TRAIN] iter=1320/2000 avg_loss=1.1275 avg_kappa=0.2063
[TRAIN] iter=1330/2000 avg_loss=1.1955 avg_kappa=0.0782
[TRAIN] iter=1340/2000 avg_loss=0.8959 avg_kappa=0.2803
[TRAIN] iter=1350/2000 avg_loss=0.6391 avg_kappa=0.5909
[EVAL] iter=1350/2000 avg_loss=1.1239 kappa=0.1900
[TRAIN] iter=1360/2000 avg_loss=0.9569 avg_kappa=0.2713
[TRAIN] iter=1370/2000 avg_loss=1.1284 avg_kappa=0.2898
[TRAIN] iter=1380/2000 avg_loss=0.9595 avg_kappa=0.3030
[TRAIN] iter=1390/2000 avg_loss=0.7360 avg_kappa=0.5690
[TRAIN] iter=1400/2000 avg_loss=0.4813 avg_kappa=0.7500
[EVAL] iter=1400/2000 avg_loss=0.9864 kappa=0.4538
[TRAIN] iter=1410/2000 avg_loss=0.8431 avg_kappa=0.4595
[TRAIN] iter=1420/2000 avg_loss=1.0595 avg_kappa=0.2208
[TRAIN] iter=1430/2000 avg_loss=0.9496 avg_kappa=0.3333
[TRAIN] iter=1440/2000 avg_loss=0.7636 avg_kappa=0.5779
[TRAIN] iter=1450/2000 avg_loss=0.8095 avg_kappa=0.5312
[EVAL] iter=1450/2000 avg_loss=1.5890 kappa=-0.0601
[TRAIN] iter=1460/2000 avg_loss=0.9720 avg_kappa=0.1532
[TRAIN] iter=1470/2000 avg_loss=0.7457 avg_kappa=0.5714
[TRAIN] iter=1480/2000 avg_loss=0.7661 avg_kappa=0.7287
[TRAIN] iter=1490/2000 avg_loss=0.6770 avg_kappa=0.3617
[TRAIN] iter=1500/2000 avg_loss=0.8323 avg_kappa=0.6444
[EVAL] iter=1500/2000 avg_loss=1.3753 kappa=0.0816
[TRAIN] iter=1510/2000 avg_loss=0.6356 avg_kappa=0.5408
[TRAIN] iter=1520/2000 avg_loss=1.1022 avg_kappa=0.1093
[TRAIN] iter=1530/2000 avg_loss=0.9522 avg_kappa=0.3116
[TRAIN] iter=1540/2000 avg_loss=0.9507 avg_kappa=0.2063
[TRAIN] iter=1550/2000 avg_loss=0.6121 avg_kappa=0.7682
[EVAL] iter=1550/2000 avg_loss=0.8446 kappa=0.5434
[TRAIN] iter=1560/2000 avg_loss=0.7899 avg_kappa=0.4643
[TRAIN] iter=1570/2000 avg_loss=0.7781 avg_kappa=0.5769
[TRAIN] iter=1580/2000 avg_loss=0.6827 avg_kappa=0.4397
[TRAIN] iter=1590/2000 avg_loss=0.7548 avg_kappa=0.4545
[TRAIN] iter=1600/2000 avg_loss=0.6444 avg_kappa=0.6983
[EVAL] iter=1600/2000 avg_loss=1.2631 kappa=0.1905
[TRAIN] iter=1610/2000 avg_loss=0.6851 avg_kappa=0.6897
[TRAIN] iter=1620/2000 avg_loss=0.4364 avg_kappa=0.8500
[TRAIN] iter=1630/2000 avg_loss=0.7269 avg_kappa=0.5192
[TRAIN] iter=1640/2000 avg_loss=0.6269 avg_kappa=0.6284
[TRAIN] iter=1650/2000 avg_loss=0.9471 avg_kappa=0.3750
[EVAL] iter=1650/2000 avg_loss=1.1361 kappa=0.1864
[TRAIN] iter=1660/2000 avg_loss=0.9059 avg_kappa=0.4667
[TRAIN] iter=1670/2000 avg_loss=0.6955 avg_kappa=0.7170
[TRAIN] iter=1680/2000 avg_loss=0.7104 avg_kappa=0.5455
[TRAIN] iter=1690/2000 avg_loss=0.9699 avg_kappa=0.5556
[TRAIN] iter=1700/2000 avg_loss=0.5785 avg_kappa=0.6707
[EVAL] iter=1700/2000 avg_loss=1.9127 kappa=-0.0123
[TRAIN] iter=1710/2000 avg_loss=0.9138 avg_kappa=0.5455
[TRAIN] iter=1720/2000 avg_loss=0.6990 avg_kappa=0.3000
[TRAIN] iter=1730/2000 avg_loss=0.5126 avg_kappa=0.5322
[TRAIN] iter=1740/2000 avg_loss=1.3681 avg_kappa=0.0033
[TRAIN] iter=1750/2000 avg_loss=1.0812 avg_kappa=0.3359
[EVAL] iter=1750/2000 avg_loss=1.4875 kappa=-0.3081
[TRAIN] iter=1760/2000 avg_loss=0.7500 avg_kappa=0.5708
[TRAIN] iter=1770/2000 avg_loss=0.5901 avg_kappa=0.7967
[TRAIN] iter=1780/2000 avg_loss=0.6868 avg_kappa=0.8264
[TRAIN] iter=1790/2000 avg_loss=0.7448 avg_kappa=0.6818
[TRAIN] iter=1800/2000 avg_loss=1.1918 avg_kappa=0.2105
[EVAL] iter=1800/2000 avg_loss=2.2048 kappa=-0.0847
[TRAIN] iter=1810/2000 avg_loss=0.5920 avg_kappa=0.7863
[TRAIN] iter=1820/2000 avg_loss=0.7900 avg_kappa=0.7840
[TRAIN] iter=1830/2000 avg_loss=0.6585 avg_kappa=0.5570
[TRAIN] iter=1840/2000 avg_loss=0.6186 avg_kappa=0.7917
[TRAIN] iter=1850/2000 avg_loss=0.5335 avg_kappa=0.6721
[EVAL] iter=1850/2000 avg_loss=0.8442 kappa=0.3822
[TRAIN] iter=1860/2000 avg_loss=0.7680 avg_kappa=0.7143
[TRAIN] iter=1870/2000 avg_loss=0.8156 avg_kappa=0.5139
[TRAIN] iter=1880/2000 avg_loss=0.5862 avg_kappa=0.3590
[TRAIN] iter=1890/2000 avg_loss=1.2277 avg_kappa=0.3810
[TRAIN] iter=1900/2000 avg_loss=1.0219 avg_kappa=0.2965
[EVAL] iter=1900/2000 avg_loss=1.0719 kappa=0.4101
[TRAIN] iter=1910/2000 avg_loss=0.6899 avg_kappa=0.6721
[TRAIN] iter=1920/2000 avg_loss=0.5552 avg_kappa=0.7078
[TRAIN] iter=1930/2000 avg_loss=0.5540 avg_kappa=0.7818
[TRAIN] iter=1940/2000 avg_loss=0.5370 avg_kappa=0.4578
[TRAIN] iter=1950/2000 avg_loss=0.6274 avg_kappa=0.3689
[EVAL] iter=1950/2000 avg_loss=0.9100 kappa=0.2674
[TRAIN] iter=1960/2000 avg_loss=0.9596 avg_kappa=0.5455
[TRAIN] iter=1970/2000 avg_loss=0.7679 avg_kappa=0.6429
[TRAIN] iter=1980/2000 avg_loss=0.6650 avg_kappa=0.7009
[TRAIN] iter=1990/2000 avg_loss=0.6712 avg_kappa=0.6541
[TRAIN] iter=2000/2000 avg_loss=0.4856 avg_kappa=0.8214
[EVAL] iter=2000/2000 avg_loss=1.3298 kappa=0.0813
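Both the [TRAIN] and [EVAL] lines report a quadratic-weighted kappa computed with sklearn's cohen_kappa_score(..., weights='quadratic'). The three grades are ordered, so this metric penalizes predicting mid_advanced for a non case more heavily than predicting early. The NumPy sketch below spells out the formula over a fixed label set {0, 1, 2}. It is for illustration only and is not the code used in training.

A caveat: when labels is not passed, sklearn builds its weight matrix only over the classes present in the given data. On the few samples gathered between [TRAIN] lines, it can therefore differ from this fixed-label version. Passing labels=[0, 1, 2] to cohen_kappa_score should make the two agree in non-degenerate cases.

```python
import numpy as np


def quadratic_weighted_kappa(y_true, y_pred, num_classes=3):
    """Quadratic-weighted Cohen's kappa over the fixed label set 0..num_classes-1."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    # Observed matrix O[i, j]: number of samples with true class i and predicted class j
    observed = np.zeros((num_classes, num_classes), dtype=np.float64)
    np.add.at(observed, (y_true, y_pred), 1)
    # Expected matrix E under independence: outer product of the two marginals, scaled to N
    n = observed.sum()
    expected = np.outer(observed.sum(axis=1), observed.sum(axis=0)) / n
    # Quadratic disagreement weights w[i, j] = (i - j)^2 / (K - 1)^2
    idx = np.arange(num_classes)
    weights = (idx[:, None] - idx[None, :]) ** 2 / (num_classes - 1) ** 2
    den = (weights * expected).sum()
    if den == 0:
        # Degenerate case (e.g. one class everywhere): kappa is undefined
        return float("nan")
    return 1.0 - (weights * observed).sum() / den


print(quadratic_weighted_kappa([0, 1, 2, 2], [0, 1, 2, 2]))  # 1.0  (perfect agreement)
print(quadratic_weighted_kappa([0, 0, 1, 1], [0, 1, 0, 1]))  # 0.0  (chance-level agreement)
print(quadratic_weighted_kappa([0, 2], [2, 0]))              # -1.0 (maximal disagreement)
```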
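Note: if a data-loading error occurs while training on AI Studio, the likely cause is a hidden .ipynb_checkpoints folder inside the training-set directory, which os.listdir(trainset_root) then treats as a sample. Delete it by running rm -r .ipynb_checkpoints inside that directory.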
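Another option is for the split code to skip hidden entries, so that .ipynb_checkpoints (or any other dot-prefixed folder) never reaches the file list. This is a minimal sketch that assumes every sample is a sub-folder of trainset_root whose name does not start with a dot. The helper name list_sample_dirs is hypothetical.

```python
import os
import tempfile


def list_sample_dirs(root):
    """Return sample folder names under root, skipping hidden entries and plain files."""
    return sorted(
        name for name in os.listdir(root)
        if not name.startswith('.') and os.path.isdir(os.path.join(root, name))
    )


# Usage example on a throwaway directory that mimics the dataset layout
with tempfile.TemporaryDirectory() as root:
    for name in ['0001', '0002', '.ipynb_checkpoints']:
        os.makedirs(os.path.join(root, name))
    print(list_sample_dirs(root))  # ['0001', '0002']
```

In the split cell, filelists = list_sample_dirs(trainset_root) could replace the bare os.listdir call. Sorting also makes the input to train_test_split independent of the arbitrary order that os.listdir returns.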
6. Model Prediction
6.1 Predicting on the Test Set
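The checkpoint loaded below, best_model_0.9344, encodes a validation kappa that does not appear in the training log above, so it presumably comes from a different run. The saving rule in train() determines which directories a run leaves behind: it starts from best_kappa = 0.5 and saves whenever avg_kappa >= best_kappa. The sketch below replays that rule over [EVAL] lines. The regular expression assumes the exact log format printed by train(), and replay_checkpoints is a hypothetical helper.

```python
import re

# Matches lines such as "[EVAL] iter=300/2000 avg_loss=1.0134 kappa=0.5455"
EVAL_RE = re.compile(r"\[EVAL\] iter=\d+/\d+ avg_loss=[\d.]+ kappa=(-?[\d.]+)")


def replay_checkpoints(log_text, init_best=0.5):
    """Replay train()'s saving rule and return the checkpoint directories it would create."""
    best = init_best
    saved = []
    for match in EVAL_RE.finditer(log_text):
        kappa = float(match.group(1))
        if kappa >= best:  # same comparison as in train()
            best = kappa
            saved.append("best_model_{:.4f}".format(best))
    return saved


# A few [EVAL] lines copied from the training log above
log = """
[EVAL] iter=250/2000 avg_loss=1.2917 kappa=0.3895
[EVAL] iter=300/2000 avg_loss=1.0134 kappa=0.5455
[EVAL] iter=800/2000 avg_loss=0.7692 kappa=0.5909
[EVAL] iter=1100/2000 avg_loss=0.8967 kappa=0.7635
[EVAL] iter=1150/2000 avg_loss=1.6824 kappa=0.2553
"""
print(replay_checkpoints(log))  # ['best_model_0.5455', 'best_model_0.5909', 'best_model_0.7635']
```

For the full log above, the best saved checkpoint is best_model_0.7635 (iteration 1100). To reproduce that particular run, best_model_path would need to point there.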
# Load the trained weights
# The directory name encodes the best validation kappa of the run that saved it;
# point this at a directory that your own training run actually produced.
best_model_path = "./best_model_0.9344/model.pdparams"
model = Model()
para_state_dict = paddle.load(best_model_path)
model.set_state_dict(para_state_dict)
model.eval()
img_test_transforms = trans.Compose([
    trans.Resize(image_size)
])
oct_test_transforms = trans.Compose([
    trans.CenterCrop(oct_img_size)
])
test_dataset = GAMMA_sub1_dataset(dataset_root=testset_root,
                                  img_transforms=img_test_transforms,
                                  oct_transforms=oct_test_transforms,
                                  mode='test')
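cache = []
# Gradients are not needed at inference time
with paddle.no_grad():
    for fundus_img, oct_img, idx in tqdm(test_dataset):
        # Add a batch dimension of size 1
        fundus_img = fundus_img[np.newaxis, ...]
        oct_img = oct_img[np.newaxis, ...]
        # Scale pixel values to [0, 1], as during training
        fundus_img = paddle.to_tensor((fundus_img / 255.).astype("float32"))
        oct_img = paddle.to_tensor((oct_img / 255.).astype("float32"))
        logits = model(fundus_img, oct_img)
        # Store the sample id with argmax(1), a length-1 array holding the predicted class index
        cache.append([idx, logits.numpy().argmax(1)])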
15%|█▍ | 9/62 [01:59<11:43, 13.27s/it]Premature end of JPEG file
100%|██████████| 62/62 [13:47<00:00, 13.34s/it]
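For a batch of one, each entry that the prediction loop appends to cache holds logits.numpy().argmax(1), a length-1 NumPy array rather than a Python int. NumPy arrays are unhashable, so they cannot serve directly as keys into a dict such as class_map. Extracting the scalar first avoids a TypeError. A small illustration with made-up logits (the numbers are hypothetical):

```python
import numpy as np

class_map = {0: 'non', 1: 'early', 2: 'mid_advanced'}

# Hypothetical logits for one image: shape (batch=1, num_classes=3)
logits = np.array([[0.2, 1.7, -0.5]], dtype=np.float32)
pred = logits.argmax(1)
print(pred, pred.shape)         # [1] (1,)

# Convert the length-1 array to a Python int before the dict lookup
print(class_map[int(pred[0])])  # early

try:
    class_map[pred]             # an ndarray is not hashable
except TypeError as err:
    print('TypeError:', err)
```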
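6.2 Visualizing the Predictions
Show the original fundus images together with the predicted classes: non means normal (no glaucoma), early means early-stage glaucoma, and mid_advanced means moderate-to-advanced glaucoma.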
# Map class indices to class names
class_map = {0: 'non', 1: 'early', 2: 'mid_advanced'}
# Spacing between subplots
plt.subplots_adjust(wspace=0.6, hspace=0.6)
for i in range(6):
    index, pred_class = cache[i]
    fundus_img = cv2.imread(os.path.join(testset_root, str(index), '{}.jpg'.format(index)))
    plt.subplot(2, 3, i + 1)
    # OpenCV loads images as BGR; reverse the channel axis to get RGB for matplotlib
    plt.imshow(fundus_img[:, :, ::-1])
    plt.axis("off")
    # Show the sample id and its predicted class; pred_class is a length-1 array, so take its scalar
    plt.title('data:{}, pred:{}'.format(index, class_map[int(pred_class[0])]))