# -*- coding: utf-8 -*-
"""
fit.py
"""

import time
import numpy as np
import torch
from torch.amp import autocast, GradScaler

# Training: one full-precision pass over the training data.
def _train(model, device, train_loader, optimizer, criterion):
    """Run a single training epoch.

    Switches ``model`` to training mode and, for every ``(data, target)``
    batch yielded by ``train_loader``, moves both tensors to ``device``,
    clears the gradients, evaluates ``criterion(model(data), target)``,
    back-propagates it and takes one optimizer step.

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.
    """
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Clear the gradients left over from the previous batch.
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

# Validation: evaluate the model on the validation data.
def _validation(model, device, valid_loader, criterion):
    """Evaluate ``model`` on ``valid_loader`` without computing gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        device: Device each batch is moved to before the forward pass.
        valid_loader: Iterable of ``(data, target)`` batches that also
            exposes a sized ``dataset`` attribute.
        criterion: Callable ``criterion(output, target)`` returning a
            scalar tensor.

    Returns:
        A tuple ``(preds, loss)``. ``preds`` is a 1-D NumPy array holding
        every model output, flattened, in loader order. ``loss`` is the
        square root of the batch-size-weighted sum of the per-batch
        criterion values divided by ``len(valid_loader.dataset)``
        (presumably an RMSE when ``criterion`` is a mean-reduced MSE --
        confirm against the caller).
    """
    model.eval()
    loss_sum = 0
    # Collect per-batch outputs and concatenate once at the end; appending
    # to a NumPy array inside the loop would re-copy all earlier
    # predictions on every batch. The empty float32 seed keeps the result
    # dtype (and the empty-loader result) the same as before.
    chunks = [np.zeros(0, dtype=np.float32)]
    with torch.no_grad():
        for data, target in valid_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = criterion(output, target)
            # Weight the batch loss by the batch size.
            loss_sum += loss.item() * len(data)
            chunks.append(output.to('cpu').numpy().reshape(-1))

    preds = np.concatenate(chunks)
    return preds, np.sqrt(loss_sum / len(valid_loader.dataset))

# Training driver: train for several epochs and keep the best validation result.
def fit(model, device, num_epochs, train_loader, valid_loader, optimizer, criterion,
    mixed, modelfile, step=1):
    """Train ``model`` and track the epoch with the lowest validation loss.

    Validation runs once before any training (epoch 0) and once after each
    of the ``num_epochs`` training epochs. Whenever the validation loss
    improves on the best seen so far, the predictions are remembered and,
    if ``modelfile`` is given, the model's ``state_dict`` is written there.

    Args:
        model: Module to train; it is moved to ``device``.
        device: Device used for training and validation.
        num_epochs: Number of training epochs.
        train_loader: Loader passed to the training routine.
        valid_loader: Loader passed to the validation routine.
        optimizer: Optimizer stepped by the training routine.
        criterion: Loss function.
        mixed: If true, train with ``_train_mixed`` (mixed precision),
            otherwise with ``_train``.
        modelfile: Path for the best checkpoint, or ``None`` to skip saving.
        step: A progress line is printed every ``step`` epochs.

    Returns:
        A tuple ``(preds_best, history)``. ``preds_best`` holds the
        validation predictions from the best epoch (``None`` if no epoch
        ever compared lower than the initial bound, e.g. when every loss
        is NaN). ``history`` is a float32 array of length
        ``num_epochs + 1`` with the validation loss of each epoch.
    """
    # t0: start time; t1: time of the previous progress line.
    t0 = time.time()
    t1 = t0

    # Move the model to the target device.
    model = model.to(device)

    preds_best = None
    saved = False
    # Start from +inf so that the first comparable loss always becomes the
    # best, regardless of the loss scale.
    min_loss = float('inf')
    history = np.zeros(num_epochs + 1, dtype=np.float32)

    # Loop over epochs; epoch 0 only validates the untrained model.
    for epoch in range(num_epochs + 1):

        # Training (skipped on the first iteration).
        if epoch > 0:
            if not mixed:
                _train(model, device, train_loader, optimizer, criterion)
            else:
                _train_mixed(model, device, train_loader, optimizer, criterion)

        # Validation.
        preds, valid_loss = _validation(model, device, valid_loader, criterion)

        # Record the loss and report progress: epoch, loss, total elapsed
        # seconds and seconds since the previous progress line.
        history[epoch] = valid_loss
        if epoch % step == 0:
            t2 = time.time()
            print('%4d %.5f %8.1f(%5.1f)[sec]' % (epoch, valid_loss, t2 - t0, t2 - t1), flush=True)
            t1 = t2

        # New best validation loss.
        if valid_loss < min_loss:
            min_loss = valid_loss
            preds_best = preds.copy()
            # Save the parameters (for restarting).
            if modelfile is not None:
                torch.save(model.state_dict(), modelfile)
                saved = True

    # Summary messages; only claim a save when one actually happened.
    print('min  %.5f' % min_loss)
    if saved:
        print('Save %s' % modelfile)

    return preds_best, history

# 訓練、テンソルコア用混合精度、結果は少し変わる
# 画像サイズが大きいとき(64以上)約2倍速くなる
# https://pytorch.org/docs/stable/notes/amp_examples.html
def _train_mixed(model, device, train_loader, optimizer, criterion):

    scaler = GradScaler('cuda')

    model.train()
    for data, target in train_loader:
        # これでもよい
        #data = data.to(device)
        #target = target.to(device)
        
        # cuda
        data = data.cuda()
        target = target.cuda()

        # 勾配初期化
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast(device_type='cuda', dtype=torch.float16):   # モデルの演算精度を自動選択する
            output = model(data)
            loss = criterion(output, target)

        # 損失計算のスケール化、逆誤差伝搬
        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()  # スケール化されたlossのbackward関数を使用する

        # オプティマイザのstep処理
        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)  # optimizer.step()の代わりに、scaler.step()を使用する

        # Scalerの更新
        # Updates the scale for next iteration.
        scaler.update()

        #output = model(data)
        #loss = criterion(output, target)
        #loss.backward()
        #optimizer.step()
