# -*- coding: utf-8 -*-
"""
myCNN.py
自作CNN
maxpoolでサイズが半減するのでネットワークの深さに限度がある
"""

import numpy as np
from torch import nn

class CNN2(nn.Module):
    """CNN with two 3x3 convolution layers and a two-layer linear classifier.

    Layout: conv-BN-ReLU, conv-BN-ReLU, 2x2 max-pool, dropout, then
    flatten -> linear -> ReLU -> dropout -> linear.

    Every layer is stored as an attribute (``conv1``, ``bn1``, ``fc1``, ...)
    and is also registered inside ``features`` / ``classifier``.

    Args:
        in_size: Length-3 sequence; ``in_size[0]`` is the number of input
            channels, ``in_size[1]`` and ``in_size[2]`` are the two spatial
            sizes of the input.
        channels: Output channels of the 2 conv layers. A single number or a
            length-1 sequence is reused for every layer.
        dropout: Dropout probabilities for the 2 dropout layers (after the
            pool, and inside the classifier). A single number or a length-1
            sequence is reused for every layer.
        out_classes: Number of output features of the final linear layer.

    Raises:
        AssertionError: If ``in_size`` does not have length 3, or
            ``channels`` / ``dropout`` have a length other than 1 or 2.
    """

    def __init__(self, in_size, channels, dropout, out_classes):
        super().__init__()

        # Accept a bare number as shorthand for a one-element sequence.
        if np.isscalar(channels):
            channels = [channels]
        if np.isscalar(dropout):
            dropout = [dropout]

        # A single value is broadcast to every layer of that kind.
        if len(channels) == 1:
            channels = [channels[0]] * 2
        if len(dropout) == 1:
            dropout = [dropout[0]] * 2

        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if len(in_size) != 3 or len(channels) != 2 or len(dropout) != 2:
            raise AssertionError(
                "CNN2 requires len(in_size) == 3, len(channels) in (1, 2) "
                "and len(dropout) in (1, 2)"
            )

        # The convolutions (kernel 3, stride 1, padding 1) keep the spatial
        # size; the single ceil-mode 2x2 pool halves it, rounding up.
        out_pixel = (int(np.ceil(in_size[1] / 2)), int(np.ceil(in_size[2] / 2)))

        self.conv1    = nn.Conv2d(in_size[0],  channels[0], kernel_size=3, stride=1, padding=1)
        self.conv2    = nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.bn1      = nn.BatchNorm2d(channels[0])
        self.bn2      = nn.BatchNorm2d(channels[1])
        self.dropout1 = nn.Dropout(dropout[0])
        self.dropout2 = nn.Dropout(dropout[1])
        self.relu     = nn.ReLU(inplace=True)
        self.maxpool  = nn.MaxPool2d(2, ceil_mode=True)
        self.flatten  = nn.Flatten()
        self.fc1      = nn.Linear(out_pixel[0] * out_pixel[1] * channels[-1], channels[-1])
        self.fc2      = nn.Linear(channels[-1], out_classes)

        self.features = nn.Sequential(
            self.conv1, self.bn1, self.relu,
            self.conv2, self.bn2, self.relu, self.maxpool, self.dropout1,
        )

        self.classifier = nn.Sequential(
            self.fc1,
            self.relu,
            self.dropout2,
            self.fc2,
        )

    def forward(self, x):
        """Run the feature extractor, flatten, then the classifier.

        Args:
            x: Input batch. ``fc1`` is sized for inputs whose non-batch
                dimensions equal ``in_size``.

        Returns:
            The output of ``fc2``: ``out_classes`` values per sample, with
            no activation applied afterwards.
        """
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

class CNN4(nn.Module):
    """CNN with four 3x3 convolution layers and a two-layer linear classifier.

    Layout: two stages of (conv-BN-ReLU, conv-BN-ReLU, 2x2 max-pool,
    dropout), then flatten -> linear -> ReLU -> dropout -> linear.

    Every layer is stored as an attribute (``conv1``, ``bn1``, ``fc1``, ...)
    and is also registered inside ``features`` / ``classifier``.

    Args:
        in_size: Length-3 sequence; ``in_size[0]`` is the number of input
            channels, ``in_size[1]`` and ``in_size[2]`` are the two spatial
            sizes of the input.
        channels: Output channels of the 4 conv layers. A single number or a
            length-1 sequence is reused for every layer.
        dropout: Dropout probabilities for the 3 dropout layers (one per
            stage, plus one inside the classifier). A single number or a
            length-1 sequence is reused for every layer.
        out_classes: Number of output features of the final linear layer.

    Raises:
        AssertionError: If ``in_size`` does not have length 3, ``channels``
            has a length other than 1 or 4, or ``dropout`` has a length
            other than 1 or 3.
    """

    def __init__(self, in_size, channels, dropout, out_classes):
        super().__init__()

        # Accept a bare number as shorthand for a one-element sequence.
        if np.isscalar(channels):
            channels = [channels]
        if np.isscalar(dropout):
            dropout = [dropout]

        # A single value is broadcast to every layer of that kind.
        if len(channels) == 1:
            channels = [channels[0]] * 4
        if len(dropout) == 1:
            dropout = [dropout[0]] * 3

        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if len(in_size) != 3 or len(channels) != 4 or len(dropout) != 3:
            raise AssertionError(
                "CNN4 requires len(in_size) == 3, len(channels) in (1, 4) "
                "and len(dropout) in (1, 3)"
            )

        # The convolutions keep the spatial size; each of the 2 ceil-mode
        # pools halves it rounding up, which equals ceil(size / 4).
        out_pixel = (int(np.ceil(in_size[1] / 4)), int(np.ceil(in_size[2] / 4)))

        self.conv1    = nn.Conv2d(in_size[0],  channels[0], kernel_size=3, stride=1, padding=1)
        self.conv2    = nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.conv3    = nn.Conv2d(channels[1], channels[2], kernel_size=3, stride=1, padding=1)
        self.conv4    = nn.Conv2d(channels[2], channels[3], kernel_size=3, stride=1, padding=1)
        self.bn1      = nn.BatchNorm2d(channels[0])
        self.bn2      = nn.BatchNorm2d(channels[1])
        self.bn3      = nn.BatchNorm2d(channels[2])
        self.bn4      = nn.BatchNorm2d(channels[3])
        self.dropout1 = nn.Dropout(dropout[0])
        self.dropout2 = nn.Dropout(dropout[1])
        self.dropout3 = nn.Dropout(dropout[2])
        self.relu     = nn.ReLU(inplace=True)
        self.maxpool  = nn.MaxPool2d(2, ceil_mode=True)
        self.flatten  = nn.Flatten()
        self.fc1      = nn.Linear(out_pixel[0] * out_pixel[1] * channels[-1], channels[-1])
        self.fc2      = nn.Linear(channels[-1], out_classes)

        self.features = nn.Sequential(
            self.conv1, self.bn1, self.relu,
            self.conv2, self.bn2, self.relu, self.maxpool, self.dropout1,
            self.conv3, self.bn3, self.relu,
            self.conv4, self.bn4, self.relu, self.maxpool, self.dropout2,
        )

        self.classifier = nn.Sequential(
            self.fc1,
            self.relu,
            self.dropout3,
            self.fc2,
        )

    def forward(self, x):
        """Run the feature extractor, flatten, then the classifier.

        Args:
            x: Input batch. ``fc1`` is sized for inputs whose non-batch
                dimensions equal ``in_size``.

        Returns:
            The output of ``fc2``: ``out_classes`` values per sample, with
            no activation applied afterwards.
        """
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

class CNN6(nn.Module):
    """CNN with six 3x3 convolution layers and a two-layer linear classifier.

    Layout: three stages of (conv-BN-ReLU, conv-BN-ReLU, 2x2 max-pool,
    dropout), then flatten -> linear -> ReLU -> dropout -> linear.

    Every layer is stored as an attribute (``conv1``, ``bn1``, ``fc1``, ...)
    and is also registered inside ``features`` / ``classifier``.

    Args:
        in_size: Length-3 sequence; ``in_size[0]`` is the number of input
            channels, ``in_size[1]`` and ``in_size[2]`` are the two spatial
            sizes of the input.
        channels: Output channels of the 6 conv layers. A single number or a
            length-1 sequence is reused for every layer.
        dropout: Dropout probabilities for the 4 dropout layers (one per
            stage, plus one inside the classifier). A single number or a
            length-1 sequence is reused for every layer.
        out_classes: Number of output features of the final linear layer.

    Raises:
        AssertionError: If ``in_size`` does not have length 3, ``channels``
            has a length other than 1 or 6, or ``dropout`` has a length
            other than 1 or 4.
    """

    def __init__(self, in_size, channels, dropout, out_classes):
        super().__init__()

        # Accept a bare number as shorthand for a one-element sequence.
        if np.isscalar(channels):
            channels = [channels]
        if np.isscalar(dropout):
            dropout = [dropout]

        # A single value is broadcast to every layer of that kind.
        if len(channels) == 1:
            channels = [channels[0]] * 6
        if len(dropout) == 1:
            dropout = [dropout[0]] * 4

        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if len(in_size) != 3 or len(channels) != 6 or len(dropout) != 4:
            raise AssertionError(
                "CNN6 requires len(in_size) == 3, len(channels) in (1, 6) "
                "and len(dropout) in (1, 4)"
            )

        # The convolutions keep the spatial size; each of the 3 ceil-mode
        # pools halves it rounding up, which equals ceil(size / 8).
        out_pixel = (int(np.ceil(in_size[1] / 8)), int(np.ceil(in_size[2] / 8)))

        self.conv1    = nn.Conv2d(in_size[0],  channels[0], kernel_size=3, stride=1, padding=1)
        self.conv2    = nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.conv3    = nn.Conv2d(channels[1], channels[2], kernel_size=3, stride=1, padding=1)
        self.conv4    = nn.Conv2d(channels[2], channels[3], kernel_size=3, stride=1, padding=1)
        self.conv5    = nn.Conv2d(channels[3], channels[4], kernel_size=3, stride=1, padding=1)
        self.conv6    = nn.Conv2d(channels[4], channels[5], kernel_size=3, stride=1, padding=1)
        self.bn1      = nn.BatchNorm2d(channels[0])
        self.bn2      = nn.BatchNorm2d(channels[1])
        self.bn3      = nn.BatchNorm2d(channels[2])
        self.bn4      = nn.BatchNorm2d(channels[3])
        self.bn5      = nn.BatchNorm2d(channels[4])
        self.bn6      = nn.BatchNorm2d(channels[5])
        self.dropout1 = nn.Dropout(dropout[0])
        self.dropout2 = nn.Dropout(dropout[1])
        self.dropout3 = nn.Dropout(dropout[2])
        self.dropout4 = nn.Dropout(dropout[3])
        self.relu     = nn.ReLU(inplace=True)
        self.maxpool  = nn.MaxPool2d(2, ceil_mode=True)
        self.flatten  = nn.Flatten()
        self.fc1      = nn.Linear(out_pixel[0] * out_pixel[1] * channels[-1], channels[-1])
        self.fc2      = nn.Linear(channels[-1], out_classes)

        self.features = nn.Sequential(
            self.conv1, self.bn1, self.relu,
            self.conv2, self.bn2, self.relu, self.maxpool, self.dropout1,
            self.conv3, self.bn3, self.relu,
            self.conv4, self.bn4, self.relu, self.maxpool, self.dropout2,
            self.conv5, self.bn5, self.relu,
            self.conv6, self.bn6, self.relu, self.maxpool, self.dropout3,
        )

        self.classifier = nn.Sequential(
            self.fc1,
            self.relu,
            self.dropout4,
            self.fc2,
        )

    def forward(self, x):
        """Run the feature extractor, flatten, then the classifier.

        Args:
            x: Input batch. ``fc1`` is sized for inputs whose non-batch
                dimensions equal ``in_size``.

        Returns:
            The output of ``fc2``: ``out_classes`` values per sample, with
            no activation applied afterwards.
        """
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

class CNN8(nn.Module):
    """CNN with eight 3x3 convolution layers and a two-layer linear classifier.

    Layout: four stages of (conv-BN-ReLU, conv-BN-ReLU, 2x2 max-pool,
    dropout), then flatten -> linear -> ReLU -> dropout -> linear.

    Every layer is stored as an attribute (``conv1``, ``bn1``, ``fc1``, ...)
    and is also registered inside ``features`` / ``classifier``.

    Args:
        in_size: Length-3 sequence; ``in_size[0]`` is the number of input
            channels, ``in_size[1]`` and ``in_size[2]`` are the two spatial
            sizes of the input.
        channels: Output channels of the 8 conv layers. A single number or a
            length-1 sequence is reused for every layer.
        dropout: Dropout probabilities for the 5 dropout layers (one per
            stage, plus one inside the classifier). A single number or a
            length-1 sequence is reused for every layer.
        out_classes: Number of output features of the final linear layer.

    Raises:
        AssertionError: If ``in_size`` does not have length 3, ``channels``
            has a length other than 1 or 8, or ``dropout`` has a length
            other than 1 or 5.
    """

    def __init__(self, in_size, channels, dropout, out_classes):
        super().__init__()

        # Accept a bare number as shorthand for a one-element sequence.
        if np.isscalar(channels):
            channels = [channels]
        if np.isscalar(dropout):
            dropout = [dropout]

        # A single value is broadcast to every layer of that kind.
        if len(channels) == 1:
            channels = [channels[0]] * 8
        if len(dropout) == 1:
            dropout = [dropout[0]] * 5

        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if len(in_size) != 3 or len(channels) != 8 or len(dropout) != 5:
            raise AssertionError(
                "CNN8 requires len(in_size) == 3, len(channels) in (1, 8) "
                "and len(dropout) in (1, 5)"
            )

        # The convolutions keep the spatial size; each of the 4 ceil-mode
        # pools halves it rounding up, which equals ceil(size / 16).
        out_pixel = (int(np.ceil(in_size[1] / 16)), int(np.ceil(in_size[2] / 16)))

        self.conv1    = nn.Conv2d(in_size[0],  channels[0], kernel_size=3, stride=1, padding=1)
        self.conv2    = nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.conv3    = nn.Conv2d(channels[1], channels[2], kernel_size=3, stride=1, padding=1)
        self.conv4    = nn.Conv2d(channels[2], channels[3], kernel_size=3, stride=1, padding=1)
        self.conv5    = nn.Conv2d(channels[3], channels[4], kernel_size=3, stride=1, padding=1)
        self.conv6    = nn.Conv2d(channels[4], channels[5], kernel_size=3, stride=1, padding=1)
        self.conv7    = nn.Conv2d(channels[5], channels[6], kernel_size=3, stride=1, padding=1)
        self.conv8    = nn.Conv2d(channels[6], channels[7], kernel_size=3, stride=1, padding=1)
        self.bn1      = nn.BatchNorm2d(channels[0])
        self.bn2      = nn.BatchNorm2d(channels[1])
        self.bn3      = nn.BatchNorm2d(channels[2])
        self.bn4      = nn.BatchNorm2d(channels[3])
        self.bn5      = nn.BatchNorm2d(channels[4])
        self.bn6      = nn.BatchNorm2d(channels[5])
        self.bn7      = nn.BatchNorm2d(channels[6])
        self.bn8      = nn.BatchNorm2d(channels[7])
        self.dropout1 = nn.Dropout(dropout[0])
        self.dropout2 = nn.Dropout(dropout[1])
        self.dropout3 = nn.Dropout(dropout[2])
        self.dropout4 = nn.Dropout(dropout[3])
        self.dropout5 = nn.Dropout(dropout[4])
        self.relu     = nn.ReLU(inplace=True)
        self.maxpool  = nn.MaxPool2d(2, ceil_mode=True)
        self.flatten  = nn.Flatten()
        self.fc1      = nn.Linear(out_pixel[0] * out_pixel[1] * channels[-1], channels[-1])
        self.fc2      = nn.Linear(channels[-1], out_classes)

        self.features = nn.Sequential(
            self.conv1, self.bn1, self.relu,
            self.conv2, self.bn2, self.relu, self.maxpool, self.dropout1,
            self.conv3, self.bn3, self.relu,
            self.conv4, self.bn4, self.relu, self.maxpool, self.dropout2,
            self.conv5, self.bn5, self.relu,
            self.conv6, self.bn6, self.relu, self.maxpool, self.dropout3,
            self.conv7, self.bn7, self.relu,
            self.conv8, self.bn8, self.relu, self.maxpool, self.dropout4,
        )

        self.classifier = nn.Sequential(
            self.fc1,
            self.relu,
            self.dropout5,
            self.fc2,
        )

    def forward(self, x):
        """Run the feature extractor, flatten, then the classifier.

        Args:
            x: Input batch. ``fc1`` is sized for inputs whose non-batch
                dimensions equal ``in_size``.

        Returns:
            The output of ``fc2``: ``out_classes`` values per sample, with
            no activation applied afterwards.
        """
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

class CNN10(nn.Module):
    """CNN with ten 3x3 convolution layers and a two-layer linear classifier.

    Layout: five stages of (conv-BN-ReLU, conv-BN-ReLU, 2x2 max-pool,
    dropout), then flatten -> linear -> ReLU -> dropout -> linear.

    Every layer is stored as an attribute (``conv1``, ``bn1``, ``fc1``, ...)
    and is also registered inside ``features`` / ``classifier``.

    Args:
        in_size: Length-3 sequence; ``in_size[0]`` is the number of input
            channels, ``in_size[1]`` and ``in_size[2]`` are the two spatial
            sizes of the input.
        channels: Output channels of the 10 conv layers. A single number or
            a length-1 sequence is reused for every layer.
        dropout: Dropout probabilities for the 6 dropout layers (one per
            stage, plus one inside the classifier). A single number or a
            length-1 sequence is reused for every layer.
        out_classes: Number of output features of the final linear layer.

    Raises:
        AssertionError: If ``in_size`` does not have length 3, ``channels``
            has a length other than 1 or 10, or ``dropout`` has a length
            other than 1 or 6.
    """

    def __init__(self, in_size, channels, dropout, out_classes):
        super().__init__()

        # Accept a bare number as shorthand for a one-element sequence.
        if np.isscalar(channels):
            channels = [channels]
        if np.isscalar(dropout):
            dropout = [dropout]

        # A single value is broadcast to every layer of that kind.
        if len(channels) == 1:
            channels = [channels[0]] * 10
        if len(dropout) == 1:
            dropout = [dropout[0]] * 6

        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if len(in_size) != 3 or len(channels) != 10 or len(dropout) != 6:
            raise AssertionError(
                "CNN10 requires len(in_size) == 3, len(channels) in (1, 10) "
                "and len(dropout) in (1, 6)"
            )

        # The convolutions keep the spatial size; each of the 5 ceil-mode
        # pools halves it rounding up, which equals ceil(size / 32).
        out_pixel = (int(np.ceil(in_size[1] / 32)), int(np.ceil(in_size[2] / 32)))

        self.conv1    = nn.Conv2d(in_size[0],  channels[0], kernel_size=3, stride=1, padding=1)
        self.conv2    = nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.conv3    = nn.Conv2d(channels[1], channels[2], kernel_size=3, stride=1, padding=1)
        self.conv4    = nn.Conv2d(channels[2], channels[3], kernel_size=3, stride=1, padding=1)
        self.conv5    = nn.Conv2d(channels[3], channels[4], kernel_size=3, stride=1, padding=1)
        self.conv6    = nn.Conv2d(channels[4], channels[5], kernel_size=3, stride=1, padding=1)
        self.conv7    = nn.Conv2d(channels[5], channels[6], kernel_size=3, stride=1, padding=1)
        self.conv8    = nn.Conv2d(channels[6], channels[7], kernel_size=3, stride=1, padding=1)
        self.conv9    = nn.Conv2d(channels[7], channels[8], kernel_size=3, stride=1, padding=1)
        self.conv10   = nn.Conv2d(channels[8], channels[9], kernel_size=3, stride=1, padding=1)
        self.bn1      = nn.BatchNorm2d(channels[0])
        self.bn2      = nn.BatchNorm2d(channels[1])
        self.bn3      = nn.BatchNorm2d(channels[2])
        self.bn4      = nn.BatchNorm2d(channels[3])
        self.bn5      = nn.BatchNorm2d(channels[4])
        self.bn6      = nn.BatchNorm2d(channels[5])
        self.bn7      = nn.BatchNorm2d(channels[6])
        self.bn8      = nn.BatchNorm2d(channels[7])
        self.bn9      = nn.BatchNorm2d(channels[8])
        self.bn10     = nn.BatchNorm2d(channels[9])
        self.dropout1 = nn.Dropout(dropout[0])
        self.dropout2 = nn.Dropout(dropout[1])
        self.dropout3 = nn.Dropout(dropout[2])
        self.dropout4 = nn.Dropout(dropout[3])
        self.dropout5 = nn.Dropout(dropout[4])
        self.dropout6 = nn.Dropout(dropout[5])
        self.relu     = nn.ReLU(inplace=True)
        self.maxpool  = nn.MaxPool2d(2, ceil_mode=True)
        self.flatten  = nn.Flatten()
        self.fc1      = nn.Linear(out_pixel[0] * out_pixel[1] * channels[-1], channels[-1])
        self.fc2      = nn.Linear(channels[-1], out_classes)

        self.features = nn.Sequential(
            self.conv1,  self.bn1,  self.relu,
            self.conv2,  self.bn2,  self.relu, self.maxpool, self.dropout1,
            self.conv3,  self.bn3,  self.relu,
            self.conv4,  self.bn4,  self.relu, self.maxpool, self.dropout2,
            self.conv5,  self.bn5,  self.relu,
            self.conv6,  self.bn6,  self.relu, self.maxpool, self.dropout3,
            self.conv7,  self.bn7,  self.relu,
            self.conv8,  self.bn8,  self.relu, self.maxpool, self.dropout4,
            self.conv9,  self.bn9,  self.relu,
            self.conv10, self.bn10, self.relu, self.maxpool, self.dropout5,
        )

        self.classifier = nn.Sequential(
            self.fc1,
            self.relu,
            self.dropout6,
            self.fc2,
        )

    def forward(self, x):
        """Run the feature extractor, flatten, then the classifier.

        Args:
            x: Input batch. ``fc1`` is sized for inputs whose non-batch
                dimensions equal ``in_size``.

        Returns:
            The output of ``fc2``: ``out_classes`` values per sample, with
            no activation applied afterwards.
        """
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x
