We consider three CNN (convolutional neural network) configurations: a custom CNN, a custom ResNet, and the public ResNet (the ResNet built into PyTorch).
The configuration is selected with the variable MODEL = 0/1/2 in the source file CNN.py.
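For reference, the selection might look like the following sketch. The actual branching in CNN.py may differ; the constructor arguments shown here use the MNIST settings from the listings below, and the module names myCNN/myResNet are taken from the listing captions.

import torch.nn as nn
from torchvision import models
from myCNN import CNN4        # custom CNN, Listing 3-1
from myResNet import ResNet   # custom ResNet, Listing 3-4

MODEL = 0        # 0: custom CNN, 1: custom ResNet, 2: public (PyTorch built-in) ResNet
out_classes = 10 # e.g. MNIST

if MODEL == 0:
    net = CNN4((1, 28, 28), [128, 128, 128, 128], [0.3, 0.3, 0.3], out_classes)
elif MODEL == 1:
    net = ResNet(1, 2, 128, 0.3, out_classes)
else:
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    net.fc = nn.Linear(net.fc.in_features, out_classes)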
Figure 3-1 shows the network structure of the custom CNN [3].
For the convolutional part we consider the four variants (1) to (4) below, called CNN2/CNN4/CNN6/CNN8.
The convolutional part consists of convolution (conv), batch normalization (bn), the ReLU activation function (relu),
max pooling (max pool), and dropout (dropout).
It is followed by the fully connected (fc) part.
Because this is multi-class classification, CrossEntropyLoss is used as the loss function.
Figure 3-1  Network structure of the custom CNN
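As a reference for the loss setup, CrossEntropyLoss takes raw logits and integer class labels; a minimal sketch (variable names are illustrative, not taken from CNN.py):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()       # softmax is applied internally; pass raw logits
logits = torch.randn(50, 10)            # dummy mini-batch of 50 samples, 10 classes
labels = torch.randint(0, 10, (50,))    # integer class indices 0..9
loss = criterion(logits, labels)
print(loss.item())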
The output size of a convolution is computed by the following formula:
    output size = [(input size + 2*padding - kernel size) / stride] + 1
where [ ] denotes rounding down. With input size = N, padding = 1, kernel size = 3, and stride = 1, this gives
    output size = [(N + 2*1 - 3) / 1] + 1 = N,
so the output has the same size as the input.
With a pooling size of 2, the pixel size is halved every time pooling is applied.
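As a concrete check for MNIST (28x28): the two poolings give 28 → 14 → 7, which is why Listing 3-1 below computes out_pixel as the ceiling of in_size/4. A small sketch of the arithmetic:

import numpy as np

def conv_out(n, padding=1, kernel=3, stride=1):
    # output size = [(n + 2*padding - kernel) / stride] + 1, rounded down
    return (n + 2 * padding - kernel) // stride + 1

n = 28                    # MNIST height/width
n = conv_out(n)           # 28: a 3x3 / stride-1 / padding-1 convolution keeps the size
n = int(np.ceil(n / 2))   # 14: max pooling with size 2 (ceil_mode=True)
n = conv_out(n)           # 14
n = int(np.ceil(n / 2))   # 7: so fc1 sees 7 * 7 * 128 = 6272 features (cf. Listing 3-2)
print(n)                  # 7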
Listing 3-1 shows the source code of the custom CNN network class.
The constructor takes the following four arguments:
・in_size : number of color channels and height/width in pixels of the input ((1, 28, 28) for MNIST)
・channels : number of channels for each convolution (array length = number of convolutions)
・dropout : dropout probabilities (array length = number of rows + 1; 3 for CNN4)
・out_classes : number of output classes
Listing 3-1  Source code of the custom CNN (CNN4, myCNN.py)
import numpy as np
import torch.nn as nn


class CNN4(nn.Module):

    def __init__(self, in_size, channels, dropout, out_classes):
        super().__init__()

        assert (len(in_size) == 3) and (len(channels) == 4) and (len(dropout) == 3)
        # pixel size after two 2x2 poolings (ceil_mode=True)
        out_pixel = (int(np.ceil(in_size[1] / 4)), int(np.ceil(in_size[2] / 4)))

        self.conv1 = nn.Conv2d(in_size[0], channels[0], kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(channels[1], channels[2], kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(channels[2], channels[3], kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.bn2 = nn.BatchNorm2d(channels[1])
        self.bn3 = nn.BatchNorm2d(channels[2])
        self.bn4 = nn.BatchNorm2d(channels[3])
        self.dropout1 = nn.Dropout(dropout[0])
        self.dropout2 = nn.Dropout(dropout[1])
        self.dropout3 = nn.Dropout(dropout[2])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(2, ceil_mode=True)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(out_pixel[0] * out_pixel[1] * channels[-1], channels[-1])
        self.fc2 = nn.Linear(channels[-1], out_classes)

        # convolutional part
        self.features = nn.Sequential(
            self.conv1, self.bn1, self.relu,
            self.conv2, self.bn2, self.relu, self.maxpool, self.dropout1,
            self.conv3, self.bn3, self.relu,
            self.conv4, self.bn4, self.relu, self.maxpool, self.dropout2,
        )

        # fully connected part
        self.classifier = nn.Sequential(
            self.fc1,
            self.relu,
            self.dropout3,
            self.fc2,
        )

    def forward(self, x):
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x
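As a quick check (not part of CNN.py; the module name myCNN is taken from the listing caption), the network printed in Listing 3-2 corresponds to the following settings:

import torch
from myCNN import CNN4

net = CNN4(in_size=(1, 28, 28), channels=[128, 128, 128, 128],
           dropout=[0.3, 0.3, 0.3], out_classes=10)
x = torch.randn(50, 1, 28, 28)   # dummy mini-batch of 50 MNIST-sized images
print(net(x).shape)              # torch.Size([50, 10])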
The output of the (optional) print statement for the custom CNN network is shown in Listing 3-2.
It mirrors the source code exactly as written.
Listing 3-2  Output of the custom CNN network (1) (CNN4, MNIST, 128 channels)
CNN4(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout1): Dropout(p=0.3, inplace=False)
  (dropout2): Dropout(p=0.3, inplace=False)
  (dropout3): Dropout(p=0.3, inplace=False)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=6272, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (features): Sequential(
    (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (7): Dropout(p=0.3, inplace=False)
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (15): Dropout(p=0.3, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=6272, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)
The output of the (optional) summary function for the custom CNN is shown in Listing 3-3.
The leading 50 in the "Output Shape" column is the mini-batch size.
To use the summary function with PyTorch, torchinfo must be installed
by running the following in the Anaconda Prompt:
$ conda install conda-forge::torchinfo
Then insert the following two lines into the source file CNN.py:
from torchinfo import summary      (before the definition of the main function)
summary(net, (50, 1, 28, 28))      (after net is created; shape for MNIST)
The output of the summary function does not appear in Spyder's IPython console,
so the script has to be run from the Anaconda Prompt with the following command:
$ python CNN.py
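Put together, the two inserted lines might sit in CNN.py roughly as follows; this is a sketch only, and the actual structure of main() will differ (the module name myCNN is assumed):

from torchinfo import summary   # inserted before the definition of main()
from myCNN import CNN4          # module name assumed

def main():
    net = CNN4((1, 28, 28), [128, 128, 128, 128], [0.3, 0.3, 0.3], 10)
    summary(net, (50, 1, 28, 28))   # inserted after net is created; (batch, channels, height, width)
    # ... training and evaluation continue here in the actual CNN.py

if __name__ == "__main__":
    main()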
Listing 3-3  Output of the custom CNN network (2) (CNN4, MNIST, 128 channels)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
CNN4                                     [50, 10]                  --
├─Sequential: 1-1                        [50, 128, 7, 7]           443,520
│    └─Conv2d: 2-1                       [50, 128, 28, 28]         1,280
│    └─BatchNorm2d: 2-2                  [50, 128, 28, 28]         256
├─Sequential: 1-8                        --                        (recursive)
│    └─ReLU: 2-3                         [50, 128, 28, 28]         --
├─Sequential: 1-9                        --                        (recursive)
│    └─Conv2d: 2-4                       [50, 128, 28, 28]         147,584
│    └─BatchNorm2d: 2-5                  [50, 128, 28, 28]         256
├─Sequential: 1-8                        --                        (recursive)
│    └─ReLU: 2-6                         [50, 128, 28, 28]         --
├─Sequential: 1-9                        --                        (recursive)
│    └─MaxPool2d: 2-7                    [50, 128, 14, 14]         --
│    └─Dropout: 2-8                      [50, 128, 14, 14]         --
│    └─Conv2d: 2-9                       [50, 128, 14, 14]         147,584
│    └─BatchNorm2d: 2-10                 [50, 128, 14, 14]         256
├─Sequential: 1-8                        --                        (recursive)
│    └─ReLU: 2-11                        [50, 128, 14, 14]         --
├─Sequential: 1-9                        --                        (recursive)
│    └─Conv2d: 2-12                      [50, 128, 14, 14]         147,584
│    └─BatchNorm2d: 2-13                 [50, 128, 14, 14]         256
├─Sequential: 1-8                        --                        (recursive)
│    └─ReLU: 2-14                        [50, 128, 14, 14]         --
├─Sequential: 1-9                        --                        (recursive)
│    └─MaxPool2d: 2-15                   [50, 128, 7, 7]           --
│    └─Dropout: 2-16                     [50, 128, 7, 7]           --
├─Flatten: 1-10                          [50, 6272]                --
├─Sequential: 1-11                       [50, 10]                  --
│    └─Linear: 2-17                      [50, 128]                 802,944
│    └─ReLU: 2-18                        [50, 128]                 --
│    └─Dropout: 2-19                     [50, 128]                 --
│    └─Linear: 2-20                      [50, 10]                  1,290
==========================================================================================
Total params: 1,692,810
Trainable params: 1,692,810
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 8.77
==========================================================================================
Input size (MB): 0.16
Forward/backward pass size (MB): 200.76
Params size (MB): 5.00
Estimated Total Size (MB): 205.91
==========================================================================================
ResNet is a network architecture that addresses the vanishing-gradient problem that arises when the number of layers is large [5]-[8].
Figure 3-2 shows the network structure of the custom ResNet.
An arbitrary number of ResNet unit blocks are stacked in sequence.
The distinguishing feature of a ResNet unit block is its shortcut addition.
In the figure, one ResNet unit block contains two convolutions.
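In equation form, one ResNet unit block computes (matching the forward method in Listing 3-4 below):
    out = dropout( relu( bn2(conv2(relu(bn1(conv1(x))))) + x ) )
The "+ x" term is the shortcut: gradients can flow through the identity path even when the convolution path contributes little, which is what mitigates the vanishing-gradient problem.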
Listing 3-4 shows the source code of the custom ResNet network classes.
The ResNet constructor takes the following arguments:
・in_filters : number of color channels of the input (1 for MNIST)
・num_blocks : number of ResNet unit blocks
・channels : number of channels
・dropout : dropout probability
・out_classes : number of output classes
Listing 3-4  Source code of the custom ResNet (myResNet.py)
import torch
import torch.nn as nn


# ResNet unit block: two 3x3 convolutions plus a shortcut addition
class BasicBlock(nn.Module):

    def __init__(self, channels, dropout):
        super().__init__()

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.dropout1 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += x                 # ResNet shortcut addition
        out = self.relu(out)
        out = self.dropout1(out)

        return out


class ResNet(nn.Module):

    def __init__(self, in_filters, num_blocks, channels, dropout, out_classes):
        super().__init__()

        self.conv1 = nn.Conv2d(in_filters, channels, kernel_size=3, stride=1, padding=1)  # , bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(2, ceil_mode=True)
        self.dropout1 = nn.Dropout(dropout)

        # num_blocks ResNet unit blocks in sequence
        self.blocks = nn.Sequential(*[BasicBlock(channels, dropout) for _ in range(num_blocks)])

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(channels, out_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.dropout1(x)

        x = self.blocks(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x
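As a quick check (not part of CNN.py; the module name myResNet is taken from the listing caption), the network printed in Listing 3-5 corresponds to the following settings:

import torch
from myResNet import ResNet

net = ResNet(in_filters=1, num_blocks=2, channels=128, dropout=0.3, out_classes=10)
print(net)                                       # gives the output shown in Listing 3-5
print(net(torch.randn(50, 1, 28, 28)).shape)     # torch.Size([50, 10])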
The output of the (optional) print statement for the custom ResNet network is shown in Listing 3-5.
Listing 3-5  Output of the custom ResNet network (1) (MNIST, 2 blocks, 128 channels)
ResNet(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
  (dropout1): Dropout(p=0.3, inplace=False)
  (blocks): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout1): Dropout(p=0.3, inplace=False)
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout1): Dropout(p=0.3, inplace=False)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=1)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc): Linear(in_features=128, out_features=10, bias=True)
)
The output of the (optional) summary(net, (50, 1, 28, 28)) call for the custom ResNet is shown in Listing 3-6.
The leading 50 in the "Output Shape" column is the mini-batch size.
Listing 3-6  Output of the custom ResNet network (2) (MNIST, 2 blocks, 128 channels)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [50, 10]                  --
├─Conv2d: 1-1                            [50, 128, 28, 28]         1,280
├─BatchNorm2d: 1-2                       [50, 128, 28, 28]         256
├─ReLU: 1-3                              [50, 128, 28, 28]         --
├─MaxPool2d: 1-4                         [50, 128, 14, 14]         --
├─Dropout: 1-5                           [50, 128, 14, 14]         --
├─Sequential: 1-6                        [50, 128, 14, 14]         --
│    └─BasicBlock: 2-1                   [50, 128, 14, 14]         --
│    │    └─Conv2d: 3-1                  [50, 128, 14, 14]         147,456
│    │    └─BatchNorm2d: 3-2             [50, 128, 14, 14]         256
│    │    └─ReLU: 3-3                    [50, 128, 14, 14]         --
│    │    └─Conv2d: 3-4                  [50, 128, 14, 14]         147,456
│    │    └─BatchNorm2d: 3-5             [50, 128, 14, 14]         256
│    │    └─ReLU: 3-6                    [50, 128, 14, 14]         --
│    │    └─Dropout: 3-7                 [50, 128, 14, 14]         --
│    └─BasicBlock: 2-2                   [50, 128, 14, 14]         --
│    │    └─Conv2d: 3-8                  [50, 128, 14, 14]         147,456
│    │    └─BatchNorm2d: 3-9             [50, 128, 14, 14]         256
│    │    └─ReLU: 3-10                   [50, 128, 14, 14]         --
│    │    └─Conv2d: 3-11                 [50, 128, 14, 14]         147,456
│    │    └─BatchNorm2d: 3-12            [50, 128, 14, 14]         256
│    │    └─ReLU: 3-13                   [50, 128, 14, 14]         --
│    │    └─Dropout: 3-14                [50, 128, 14, 14]         --
├─AdaptiveAvgPool2d: 1-7                 [50, 128, 1, 1]           --
├─Flatten: 1-8                           [50, 128]                 --
├─Linear: 1-9                            [50, 10]                  1,290
==========================================================================================
Total params: 593,674
Trainable params: 593,674
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 5.83
==========================================================================================
Input size (MB): 0.16
Forward/backward pass size (MB): 160.57
Params size (MB): 2.37
Estimated Total Size (MB): 163.10
==========================================================================================
PyTorch includes many pre-built networks that are easy to use [5].
They come with highly tuned parameters (weights).
To use the public ResNet (the ResNet built into PyTorch),
enable exactly one of the first three lines below;
they select ResNet18, ResNet34, and ResNet50, respectively.
Note that if the argument is changed to weights=None, the pretrained parameters are not loaded and the network has to be trained from scratch.
Because the public ResNet is trained on ImageNet [9] (224x224 pixels, color, 1000 classes),
the number of output classes must be changed with the fourth line.
net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)   # ResNet18
net = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)   # ResNet34
net = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)   # ResNet50
net.fc = nn.Linear(net.fc.in_features, len(strclasses))          # replace the final fc layer: 1000 classes -> number of classes in this task
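These lines assume that torchvision's models module and torch.nn are imported in CNN.py, and that strclasses holds the list of class names (an assumption about CNN.py's variables); for example:

import torch.nn as nn
from torchvision import models

strclasses = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']   # e.g. the 10 MNIST classes (assumed)

net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
net.fc = nn.Linear(net.fc.in_features, len(strclasses))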
The output of the print statement for ResNet18 is shown in Listing 3-7.
Listing 3-7  Output of the ResNet18 network (1)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=10, bias=True)
)
The output of the (optional) summary(net, (50, 3, 224, 224)) call for ResNet18 is shown in Listing 3-8.
The leading 50 in the "Output Shape" column is the mini-batch size.
Listing 3-8  Output of the ResNet18 network (2)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [50, 10]                  --
├─Conv2d: 1-1                            [50, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [50, 64, 112, 112]        128
├─ReLU: 1-3                              [50, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [50, 64, 56, 56]          --
├─Sequential: 1-5                        [50, 64, 56, 56]          --
│    └─BasicBlock: 2-1                   [50, 64, 56, 56]          --
│    │    └─Conv2d: 3-1                  [50, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-2             [50, 64, 56, 56]          128
│    │    └─ReLU: 3-3                    [50, 64, 56, 56]          --
│    │    └─Conv2d: 3-4                  [50, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-5             [50, 64, 56, 56]          128
│    │    └─ReLU: 3-6                    [50, 64, 56, 56]          --
│    └─BasicBlock: 2-2                   [50, 64, 56, 56]          --
│    │    └─Conv2d: 3-7                  [50, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-8             [50, 64, 56, 56]          128
│    │    └─ReLU: 3-9                    [50, 64, 56, 56]          --
│    │    └─Conv2d: 3-10                 [50, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-11            [50, 64, 56, 56]          128
│    │    └─ReLU: 3-12                   [50, 64, 56, 56]          --
├─Sequential: 1-6                        [50, 128, 28, 28]         --
│    └─BasicBlock: 2-3                   [50, 128, 28, 28]         --
│    │    └─Conv2d: 3-13                 [50, 128, 28, 28]         73,728
│    │    └─BatchNorm2d: 3-14            [50, 128, 28, 28]         256
│    │    └─ReLU: 3-15                   [50, 128, 28, 28]         --
│    │    └─Conv2d: 3-16                 [50, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-17            [50, 128, 28, 28]         256
│    │    └─Sequential: 3-18             [50, 128, 28, 28]         8,448
│    │    └─ReLU: 3-19                   [50, 128, 28, 28]         --
│    └─BasicBlock: 2-4                   [50, 128, 28, 28]         --
│    │    └─Conv2d: 3-20                 [50, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-21            [50, 128, 28, 28]         256
│    │    └─ReLU: 3-22                   [50, 128, 28, 28]         --
│    │    └─Conv2d: 3-23                 [50, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-24            [50, 128, 28, 28]         256
│    │    └─ReLU: 3-25                   [50, 128, 28, 28]         --
├─Sequential: 1-7                        [50, 256, 14, 14]         --
│    └─BasicBlock: 2-5                   [50, 256, 14, 14]         --
│    │    └─Conv2d: 3-26                 [50, 256, 14, 14]         294,912
│    │    └─BatchNorm2d: 3-27            [50, 256, 14, 14]         512
│    │    └─ReLU: 3-28                   [50, 256, 14, 14]         --
│    │    └─Conv2d: 3-29                 [50, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-30            [50, 256, 14, 14]         512
│    │    └─Sequential: 3-31             [50, 256, 14, 14]         33,280
│    │    └─ReLU: 3-32                   [50, 256, 14, 14]         --
│    └─BasicBlock: 2-6                   [50, 256, 14, 14]         --
│    │    └─Conv2d: 3-33                 [50, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-34            [50, 256, 14, 14]         512
│    │    └─ReLU: 3-35                   [50, 256, 14, 14]         --
│    │    └─Conv2d: 3-36                 [50, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-37            [50, 256, 14, 14]         512
│    │    └─ReLU: 3-38                   [50, 256, 14, 14]         --
├─Sequential: 1-8                        [50, 512, 7, 7]           --
│    └─BasicBlock: 2-7                   [50, 512, 7, 7]           --
│    │    └─Conv2d: 3-39                 [50, 512, 7, 7]           1,179,648
│    │    └─BatchNorm2d: 3-40            [50, 512, 7, 7]           1,024
│    │    └─ReLU: 3-41                   [50, 512, 7, 7]           --
│    │    └─Conv2d: 3-42                 [50, 512, 7, 7]           2,359,296
│    │    └─BatchNorm2d: 3-43            [50, 512, 7, 7]           1,024
│    │    └─Sequential: 3-44             [50, 512, 7, 7]           132,096
│    │    └─ReLU: 3-45                   [50, 512, 7, 7]           --
│    └─BasicBlock: 2-8                   [50, 512, 7, 7]           --
│    │    └─Conv2d: 3-46                 [50, 512, 7, 7]           2,359,296
│    │    └─BatchNorm2d: 3-47            [50, 512, 7, 7]           1,024
│    │    └─ReLU: 3-48                   [50, 512, 7, 7]           --
│    │    └─Conv2d: 3-49                 [50, 512, 7, 7]           2,359,296
│    │    └─BatchNorm2d: 3-50            [50, 512, 7, 7]           1,024
│    │    └─ReLU: 3-51                   [50, 512, 7, 7]           --
├─AdaptiveAvgPool2d: 1-9                 [50, 512, 1, 1]           --
├─Linear: 1-10                           [50, 10]                  5,130
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 90.68
==========================================================================================
Input size (MB): 30.11
Forward/backward pass size (MB): 1986.97
Params size (MB): 44.73
Estimated Total Size (MB): 2061.81
==========================================================================================