# -*- coding: utf-8 -*-
"""
utils.py
各種関数
"""

import os, sys
import numpy as np
import torch

# 学習データを読み込む
def readdata(fn):
    """Load a training-data file and return its header, images and labels.

    The file is read sequentially in native byte order:

    * 9 int32 header values ``idata``: ``idata[0]`` is the sample count N,
      ``idata[1:4]`` the image grid dims, ``idata[4:7]`` the label grid dims,
      ``idata[7]`` / ``idata[8]`` the bytes per stored image / label element
      (4 = float32, 2 = uint16, 1 = uint8).
    * 8 float32 values ``fdata``: (lo, hi) pairs that undo the integer
      normalization. ``fdata[0:2]`` and ``fdata[2:4]`` apply to image channels
      0 and 1; ``fdata[4:6]`` and ``fdata[6:8]`` apply to label channels 0 and 1.
      An integer value v is mapped to ``lo + (hi - lo) * v / vmax``.
    * N image samples, then N label samples, each of shape ``dims + (2,)``.

    Per the original comments, channel 0/1 hold real/imaginary parts
    (image: S-matrix; label: permittivity).

    Args:
        fn: path of the data file.

    Returns:
        ``(idata, image, label)`` where ``idata`` is the int32 header array,
        ``image`` is float32 of shape ``(N, *idata[1:4], 2)`` and ``label`` is
        float32 of shape ``(N, *idata[4:7], 2)``.

    Raises:
        SystemExit: (status 1) if ``fn`` is not an existing file.
        ValueError: on an unsupported element size or a truncated file.
    """
    if not os.path.isfile(fn):
        print('file %s not found' % fn)
        # Non-zero status: a missing input is an error (sys.exit() reports 0).
        sys.exit(1)

    print('Load %s' % fn)

    # Open once and read sequentially. All sizes are converted to Python ints,
    # so byte counts cannot overflow the int32 header type (the old code had
    # to force an int64 offset for this, notably on Windows).
    with open(fn, 'rb') as fp:
        idata = _read_exact(fp, 'i4', 9)
        fdata = _read_exact(fp, 'f4', 8)

        nsample = int(idata[0])
        image_dims = tuple(int(v) for v in idata[1:4])
        label_dims = tuple(int(v) for v in idata[4:7])

        image = _read_samples(fp, nsample, image_dims, int(idata[7]), fdata[0:4])
        label = _read_samples(fp, nsample, label_dims, int(idata[8]), fdata[4:8])

    return idata, image, label

def _read_samples(fp, nsample, dims, size, vrange):
    """Read ``nsample`` float32 arrays of shape ``dims + (2,)`` from binary file ``fp``.

    ``size`` is the stored bytes per element: 4 (float32, used as-is),
    2 (uint16) or 1 (uint8). Integer data are divided by the type's maximum
    (65535 / 255) and then mapped as follows:

    * channel 0 becomes ``vrange[0] + (vrange[1] - vrange[0]) * x``;
    * channel 1 becomes ``vrange[2] + (vrange[3] - vrange[2]) * x``.

    Raises:
        ValueError: for an unsupported ``size`` or if the file ends early.
    """
    if size == 4:
        dtype, scale = 'f4', None
    elif size == 2:
        dtype, scale = 'u2', 65535
    elif size == 1:
        dtype, scale = 'u1', 255
    else:
        # Fail loudly: skipping would leave zeros and misalign every later read.
        raise ValueError('unsupported element size %d (expected 4, 2 or 1)' % size)

    shape = dims + (2,)
    count = 2 * dims[0] * dims[1] * dims[2]
    out = np.zeros((nsample,) + shape, dtype=np.float32)
    for n in range(nsample):
        raw = _read_exact(fp, dtype, count).reshape(shape)
        if scale is None:
            out[n] = raw
            continue
        # Undo the [0, 1] normalization per channel (float32 arithmetic).
        x = raw.astype(np.float32) / scale
        out[n, ..., 0] = vrange[0] + (vrange[1] - vrange[0]) * x[..., 0]
        out[n, ..., 1] = vrange[2] + (vrange[3] - vrange[2]) * x[..., 1]
    return out

def _read_exact(fp, dtype, count):
    """Read exactly ``count`` items of ``dtype`` from binary file ``fp`` into a new array.

    Raises:
        ValueError: if the file ends before ``count`` items are available.
    """
    dtype = np.dtype(dtype)
    nbytes = dtype.itemsize * count
    buf = fp.read(nbytes)
    if len(buf) != nbytes:
        raise ValueError('unexpected end of file: wanted %d bytes, got %d'
                         % (nbytes, len(buf)))
    # frombuffer returns a read-only view of ``buf``; copy to get a normal array.
    return np.frombuffer(buf, dtype=dtype).copy()

# 乱数初期化
def torch_seed(seed=123):
    """Seed PyTorch's random generators and request deterministic kernels.

    Args:
        seed: seed passed to ``torch.manual_seed`` and ``torch.cuda.manual_seed``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may choose different kernels between runs.
    torch.backends.cudnn.benchmark = False
    # use_deterministic_algorithms is a function and must be called. Assigning
    # to it would replace the library function with a bool and enable nothing.
    # warn_only=True makes ops that lack a deterministic implementation warn
    # instead of raising, so existing training code keeps running.
    # NOTE: deterministic cuBLAS on CUDA additionally needs the
    # CUBLAS_WORKSPACE_CONFIG environment variable set before CUDA init.
    torch.use_deterministic_algorithms(True, warn_only=True)

# モデルファイルを読み込む
def load_model(model, modelfile, device):
    """Load a saved state_dict from ``modelfile`` into ``model`` in place.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        modelfile: path of a checkpoint containing a state_dict that
            ``torch.load(..., weights_only=True)`` can read.
        device: passed as ``map_location`` to ``torch.load``
            (e.g. ``'cpu'`` or a ``torch.device``).

    Raises:
        SystemExit: (status 1) if ``modelfile`` is not an existing file.
    """
    if not os.path.isfile(modelfile):
        print('file %s not found' % modelfile)
        # Non-zero status so callers/shells see the failure (sys.exit() is 0).
        sys.exit(1)
    # weights_only=True limits unpickling to tensors and plain containers.
    model.load_state_dict(torch.load(modelfile, map_location=device, weights_only=True))
    print("Load %s" % modelfile)

# 検証データのラベルの正解と予測値をファイルに保存する
def save_label(fn, idata, valid_set, preds, ecomp):
    """Save validation-set ground-truth labels together with predictions to a ``.npz`` file.

    Args:
        fn: output path for ``np.savez`` (which appends ``.npz`` if missing).
        idata: array stored as ``arr_0``.
        valid_set: iterable of ``(input, label)`` pairs; only the labels are
            used. Labels may be torch tensors (CPU or CUDA, with or without
            grad) or array-likes.
        preds: array stored as ``arr_2``.
        ecomp: array stored as ``arr_3``.

    All labels are flattened and concatenated in iteration order into ``arr_1``.
    Positional ``np.savez`` names (arr_0..arr_3) are kept so the archive can be
    read back by position.
    """
    parts = []
    for _, label in valid_set:
        if torch.is_tensor(label):
            # A bare .numpy() fails for CUDA tensors and tensors requiring grad.
            label = label.detach().cpu().numpy()
        parts.append(np.asarray(label).reshape(-1))
    # Concatenating flattened labels also works when label shapes differ;
    # an empty validation set yields an empty float64 array as before.
    trues = np.concatenate(parts) if parts else np.array([], dtype=np.float64)

    np.savez(fn, idata, trues, preds, ecomp)
    print('Save %s' % fn)

# 検証データのラベルの正解と予測値をファイルから読み込む
def load_label(fn):
    """Load the first four arrays of a ``.npz`` archive, in stored order.

    Args:
        fn: path of an ``.npz`` archive holding at least four arrays.

    Returns:
        ``(idata, trues, preds, ecomp)`` — the archive's arrays 0..3.
    """
    # The NpzFile keeps the archive open until closed. The context manager
    # releases the handle; indexing reads each array fully into memory first.
    with np.load(fn) as d:
        idata   = d[d.files[0]]
        trues   = d[d.files[1]]
        preds   = d[d.files[2]]
        ecomp   = d[d.files[3]]
    print('Load %s' % fn)

    return idata, trues, preds, ecomp

# debug, loss履歴をファイルに出力する
def output_history(history, fn):
    """Write a loss history to a UTF-8 text file for debugging.

    Each entry becomes one line ``"<index> <loss>"``, where the index counts
    from 0 and the loss is formatted with five decimal places.

    Args:
        history: iterable of loss values (anything ``%f`` can format).
        fn: output path; an existing file is overwritten.
    """
    rows = ('%d %.5f\n' % (step, value) for step, value in enumerate(history))
    with open(fn, 'wt', encoding='utf-8') as out:
        out.writelines(rows)
