# -*- coding: utf-8 -*-
"""
dataset.py
"""

import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

# dataset
# image : S行列(実部,虚部), 5D配列
# label : 誘電率,導電率, 5D配列
def dataset(ndata, image, label, scomp, ecomp, train_ratio, resize=72):
    """Build randomly split training and validation sample lists.

    Every sample is a two-element list ``[image_tensor, label_tensor]``:

    * ``image_tensor`` is a float32 tensor with ``scomp * nfreq`` channels,
      resized to ``resize x resize``.  Channels are frequency-major: the
      ``scomp`` components of frequency 0 come first, then frequency 1, ...
    * ``label_tensor`` is the flattened label (or selected label component)
      of the same sample.

    Args:
        ndata: Number of leading samples to use.  A value ``<= 0`` or larger
            than ``image.shape[0]`` selects every sample.
        image: 5-D array ``[sample, frequency, H, W, part]`` where ``part``
            0 is the real and 1 the imaginary part (S-matrix data according
            to the file's header comment).
        label: 5-D array ``[sample, Mx, My, Mz, component]`` whose last axis
            must have length 2 (permittivity and conductivity according to
            the file's header comment; presumably in that order).
        scomp: Image channels per frequency:
            1 = magnitude, 2 = real and imaginary,
            3 = real, imaginary and magnitude.
        ecomp: Label selection: 1 = component 0, 2 = component 1, 3 = both
            (component-major once flattened).
        train_ratio: A sample goes to the training list when its uniform
            random draw in [0, 1) is below this value, otherwise to the
            validation list.
        resize: Edge length in pixels of the square output image.  The
            default of 72 is the value previously hard-coded (chosen for
            ResNet per the original comment).

    Returns:
        ``(train_set, valid_set)``, two lists of samples.

    Raises:
        AssertionError: If the array ranks, the sample counts, the label
            component count or the ``scomp`` / ``ecomp`` codes are invalid.
    """
    assert image.ndim == 5
    assert label.ndim == 5
    assert image.shape[0] == label.shape[0]
    assert (scomp >= 1) and (scomp <= 3)
    assert (ecomp >= 1) and (ecomp <= 3)

    # Number of samples actually used, and number of frequencies.
    nsample = image.shape[0]
    if ndata <= 0 or ndata > nsample:
        ndata = nsample
    nfreq = image.shape[1]

    # NOTE: the inputs are not normalised here.  The original author noted
    # that real use needs the same normalisation for training and inference,
    # and that mean/std normalisation seemed to have no effect.

    # 5D -> 4D: split the real / imaginary parts of the samples that are used.
    real = image[:ndata, :, :, :, 0]
    imag = image[:ndata, :, :, :, 1]

    # Per-frequency components in channel order.  The magnitude is only
    # computed when the selected layout actually contains it.
    if scomp == 1:
        components = [np.sqrt(real**2 + imag**2)]
    elif scomp == 2:
        components = [real, imag]
    else:
        components = [real, imag, np.sqrt(real**2 + imag**2)]

    # The label layout is the same for every sample, so check it once.
    assert label.shape[4] == 2

    # One uniform draw per sample decides which list it goes to.
    to_train = (torch.rand(ndata) < train_ratio).tolist()

    # Preprocessing: resize to a square image.
    transform_resize = transforms.Resize(size=(resize, resize))

    # Scratch buffer reused for every sample; it is cloned before use below.
    buffer = np.zeros(
        (scomp * nfreq, image.shape[2], image.shape[3]), dtype=np.float32
    )
    train_set, valid_set = [], []
    for i in range(ndata):
        # Component k of frequency f lands in channel f * scomp + k, i.e.
        # every scomp-th channel starting at k.
        for k, component in enumerate(components):
            buffer[k::scomp] = component[i]

        # Clone so the tensor does not alias the reused scratch buffer.
        image_tensor = transform_resize(torch.from_numpy(buffer).clone())

        # Label: move the component axis to the front -> [0/1, Mx, My, Mz],
        # pick the requested component(s) and flatten to 1-D.
        label0 = np.transpose(label[i].copy(), (3, 0, 1, 2))
        if ecomp == 1:
            label1 = label0[0]
        elif ecomp == 2:
            label1 = label0[1]
        else:
            label1 = label0
        label_tensor = torch.from_numpy(label1.reshape(-1)).clone()

        # Random split.  No noise is added to the validation images (the
        # original only contained a disabled experiment for that).
        sample = [image_tensor, label_tensor]
        if to_train[i]:
            train_set.append(sample)
        else:
            valid_set.append(sample)

    return train_set, valid_set

# dataloader
def dataloader(train_set, valid_set, batch_size=50):
    """Wrap the training and validation sample collections in DataLoaders.

    Args:
        train_set: Training samples, or ``None``.
        valid_set: Validation samples, or ``None``.
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, valid_loader)``.  An entry is ``None`` when the
        corresponding collection is ``None`` or empty.  The training loader
        shuffles its samples; the validation loader keeps their order.
    """
    def _build(samples, shuffle):
        # A missing or empty collection gets no loader at all.
        if samples is None or len(samples) == 0:
            return None
        return DataLoader(samples, batch_size=batch_size, shuffle=shuffle)

    return _build(train_set, True), _build(valid_set, False)
