# -*- coding: utf-8 -*-
"""
myResNet.py
Custom, hand-written ResNet (自作ResNet).
"""

import torch
from torch import nn

# Basic residual unit of the ResNet (ResNetの単位block)
class BasicBlock(nn.Module):
    """Residual unit made of two 3x3 conv + BatchNorm stages and an identity shortcut.

    Both convolutions use kernel 3, stride 1 and padding 1 with the same
    channel count, so the branch output has the same shape as its input.
    The input can therefore be added back directly, without a projection.

    Args:
        channels: Number of input (and output) channels.
        dropout: Drop probability for the element-wise ``nn.Dropout`` applied
            after the final activation.
    """

    def __init__(self, channels, dropout):
        super().__init__()

        # Shared settings for both "same"-size convolutions. There is no bias
        # because each conv is immediately followed by a BatchNorm layer.
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=False)

        # Keep this registration order (conv1, bn1, relu, conv2, bn2, dropout1):
        # it fixes the state_dict keys and the order in which the conv weights
        # consume the RNG during initialisation.
        self.conv1 = nn.Conv2d(channels, channels, **conv_opts)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)  # reused for both activations
        self.conv2 = nn.Conv2d(channels, channels, **conv_opts)
        self.bn2 = nn.BatchNorm2d(channels)
        self.dropout1 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the residual branch, add the shortcut, then apply ReLU and dropout."""
        residual = self.bn1(self.conv1(x))
        residual = self.bn2(self.conv2(self.relu(residual)))
        # Identity shortcut. The in-place add only touches the fresh bn2
        # output, never the caller's `x`.
        residual += x
        return self.dropout1(self.relu(residual))

# Custom ResNet model (自作ResNet)
class ResNet(nn.Module):
    """Small residual classifier with constant feature width.

    Layer order:
    3x3 conv stem -> BatchNorm -> ReLU -> 2x2 max-pool -> dropout
    -> ``num_blocks`` x BasicBlock -> global average pool -> flatten -> linear.

    Args:
        in_filters: Number of channels of the input tensor.
        num_blocks: Number of BasicBlocks stacked after the stem.
        channels: Feature width used by the stem and by every block.
        dropout: Drop probability used after the stem and inside each block.
        out_classes: Number of output features (logits) of the final linear layer.
    """

    def __init__(self, in_filters, num_blocks, channels, dropout, out_classes):
        super().__init__()

        # --- stem (attribute names/order define state_dict keys and init order)
        self.conv1 = nn.Conv2d(
            in_filters, channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        # ceil_mode=True keeps a partial window, so odd spatial sizes don't
        # lose their last row/column.
        self.maxpool = nn.MaxPool2d(2, ceil_mode=True)
        self.dropout1 = nn.Dropout(dropout)

        # --- residual body: blocks are created in order, one after another
        self.blocks = nn.Sequential(
            *(BasicBlock(channels, dropout) for _ in range(num_blocks))
        )

        # --- classification head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(channels, out_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through stem, residual blocks and head; return the logits."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool, self.dropout1,
            self.blocks,
            self.avgpool, self.flatten, self.fc,
        )
        for layer in pipeline:
            x = layer(x)
        return x
