# -*- coding: utf-8 -*-
"""
dlspa.py
fdtd2d/fdtd3d対応
"""

import sys
import torch
from torch import nn, optim
from torchvision import models
#from torchinfo import summary

import utils, dataset, fit, plot
#import myCNN, myResNet

def main(argv):
    # 学習or推論
    Imode = 1  # 1=学習, 2=推論
    # 計算結果ファイル名（image=S行列, label=誘電率,導電率）
    datafile = '../fdtd2d/fdtd.bin'
    #datafile = '../fdtd3d/fdtd.bin'
    # モデルファイル
    modelfile = 'dlspa.pth'

    # 引数で指定するとき, python dlspa.py [1/2 [datafile [modelfile]]]
    if len(argv) > 1:
        Imode = int(argv[1])
        assert Imode == 1 or Imode == 2
    if len(argv) > 2:
        datafile = argv[2]
    if len(argv) > 3:
        modelfile = argv[3]

    # 計算パラメーター
    if   Imode == 1:
        # 学習
        load_model = 0     # 通常0, 前回保存したmodelfileからrestartするときは1
        save_model = 1     # 通常1, 計算終了時にmodelfileを保存しないときは0
        ndata = -1         # データ数, -1のときはすべてのデータ
        batch_size = 60    # バッチサイズ(通常50-100程度)
        num_epochs = 30   # エポック数
        train_ratio = 0.8  # 訓練データの割合(通常0.8程度)
    elif Imode == 2:
        # 推論
        load_model = 1     # 通常1:modelfileを読み込む
        save_model = 0     # 通常0:modelfileは保存しない
        ndata = -1         # データ数, -1のときはすべてのデータ
        batch_size = 50    # バッチサイズ(=50:ダミー)
        num_epochs = 0     # エポック数(=0)
        train_ratio = 0    # 訓練データの割合(=0)

    # 学習・推論共通
    ecomp = 3  # label成分: 1=Re(Er), 2=Im(Er), 3=両方 (通常3)
    scomp = 2  # image成分: 1=S, 2=Re(S),Im(S), 3=Re(S),Im(S),S (通常2)

    # 混合精度(T/F)
    mixed = torch.cuda.is_available()

    # device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # 乱数初期化
    utils.torch_seed()

    # 学習データを読み込む
    idata, image, label = utils.readdata(datafile)
    #print(idata); return

    # debug, 図形確認
    #plot.plot2d_data(image, label); return

    # dataset
    train_set, valid_set = dataset.dataset(ndata, image, label, scomp, ecomp, train_ratio)#, sfactor=sfactor)
    #image, label = next(iter(train_set)); print(image.shape, label.shape); return

    # dataloader
    train_loader, valid_loader = dataset.dataloader(train_set, valid_set, batch_size=batch_size)
    #image, label = next(iter(train_loader)); print(image.shape, label.shape); return

    # 入出力サイズ
    mlabel = (2 if ecomp == 3 else 1)  # label成分数
    in_channels = idata[1] * scomp
    num_classes = idata[4] * idata[5] * idata[6] * mlabel
    #print(in_channels, num_classes); return

    # myCNN/myResNet
    #model = myCNN.CNN10((in_channels, idata[1], idata[2]), [192], [0], num_classes)
    #model = myResNet.ResNet(in_channels, 4, 64, 0, num_classes)

    # 公開ResNet
    # 重みなし
    model = models.resnet18(weights=None)
    #model = models.resnet34(weights=None)
    #model = models.resnet50(weights=None)
    # 重みあり
    #model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    #model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
    #model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    # model付け替え（ResNet18/34/50共通）
    model.conv1 = nn.Conv2d(in_channels, 64, 7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # VGG（32px以上, momentum?, 性能悪い?, 大px向き?）
    #model = models.vgg11_bn(weights=None)
    #model = models.vgg11_bn(weights=models.VGG11_BN_Weights.DEFAULT)
    #model.features[0] = nn.Conv2d(in_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #model.classifier[6] = nn.Linear(model.classifier[6].in_features, out_features=num_classes, bias=True)

    # debug
    #summary(model, (60, 2, 42, 42)); return
    #print(model); return

    # restart（前回と同じmodelのとき有効）
    if load_model == 1:
        utils.load_model(model, modelfile, device)

    # 損失関数, 回帰
    criterion = nn.MSELoss()

    # 最適化関数
    #optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(model.parameters())

    # 計算条件表示
    print('data=%d+%d=%d %dx%dx%dx%d %dx%dx%dx%d=%d batch=%d epochs=%d'
        % (len(train_set), len(valid_set), len(train_set) + len(valid_set),
           idata[2], idata[3], scomp, idata[1], 
           idata[4], idata[5], idata[6], mlabel, idata[4] * idata[5] * idata[6] * mlabel, batch_size, num_epochs))

    # 学習実行(計算時間の主要部)
    preds, history = fit.fit(model, device, num_epochs, train_loader, valid_loader, optimizer, criterion, mixed,
        modelfile if (save_model == 1) else None, step=2)

    # 計算結果labelをファイルに保存する
    outfile = 'dlspa.npz'
    utils.save_label(outfile, idata, valid_set, preds, ecomp)

    # debug, loss履歴を図形表示する
    plot.plot2d_history(history, ymax=5)

    # debug, loss履歴をloss.logに出力する
    utils.output_history(history, 'loss.log')

if __name__ == '__main__':
    main(sys.argv)
