# -*- coding: utf-8 -*-
"""
plot.py
図形出力
"""

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

import plotutils

# debug, 入力データを図形表示する
def plot2d_data(image, label):
    """Debug helper: show the raw input arrays as pages of contour maps.

    Two pages are drawn with plotutils.contour:
      1. label -- channel 0 of the last axis (commented as permittivity).
         When axis 3 has length 1 (fdtd2d) that single plane is used;
         otherwise (fdtd3d) the plane at index 6 of axis 3, clamped into range.
      2. image -- magnitude sqrt(Re**2 + Im**2) at index 0 of axis 1, where
         channels 0 / 1 of the last axis are taken as Re / Im (commented as
         the S matrix).
    Both arrays must be 5-D (checked with assert).
    """
    # shared figure settings
    win_size = (7, 5)   # window size (inch)
    cols, rows, step = 4, 3, 1
    style = {'transpose': True, 'origin': 'lower', 'cmap': 'rainbow'}
    n_levels = 10

    def page_stats(pages):
        # per-page "<1-based index>  <mean>" strings plus the global min / max
        means = ['%d  %.5f' % (num, np.mean(page)) for num, page in enumerate(pages, start=1)]
        return means, np.min(pages), np.max(pages)

    # --- label ---
    assert label.ndim == 5
    n_planes = label.shape[3]
    if n_planes == 1:
        plane = 0                                  # fdtd2d: whole domain (Nz = 1)
        title = 'label: fdtd2d'
    else:
        plane = max(0, min(n_planes - 1, 6))       # fdtd3d: one Z plane
        title = 'label: (fdtd3d Z=%d)' % plane
    field = label[:, :, :, plane, 0]               # channel 1 would be conductivity
    assert field.ndim == 3
    means, lo, hi = page_stats(field)
    header = 'data=%d step=%d Nx=%d Ny=%d min=%.5f max=%.5f' % (field.shape[0], step, field.shape[1], field.shape[2], lo, hi)
    plotutils.contour(title, win_size, cols, rows, step, field, lo, hi, means, header, **style, levels=n_levels)

    # --- image ---
    assert image.ndim == 5
    freq = max(0, min(image.shape[1] - 1, 0))
    spectrum = image[:, freq]
    # |S| is shown; Re(S) or Im(S) alone could be plotted instead
    magnitude = np.sqrt(spectrum[:, :, :, 0]**2 + spectrum[:, :, :, 1]**2)
    means, lo, hi = page_stats(magnitude)
    header = 'data=%d step=%d Tx=%d Rx=%d min=%.5f max=%.5f' % (image.shape[0], step, image.shape[2], image.shape[3], lo, hi)
    # NOTE(review): unlike the label page, `levels` is not passed here, so
    # plotutils.contour's own default applies -- confirm this is intended.
    plotutils.contour('image: S', win_size, cols, rows, step, magnitude, lo, hi, means, header, **style)

# debug, loss履歴を図形表示する
def plot2d_history(history, ymax=1):
    """Debug helper: plot the loss value recorded for each epoch.

    history -- sequence of loss values, one per epoch.
    ymax    -- upper limit of the y axis (the lower limit is 0).
    Nothing is drawn when history holds fewer than two entries.
    """
    last_epoch = len(history) - 1
    if last_epoch >= 1:
        plt.figure('loss', figsize=(6, 4), layout='tight')
        plt.plot(history)
        plt.xlim(0, last_epoch)
        plt.ylim(0, ymax)
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.grid(True)
        plt.show()

# label, 2D図形出力, fdtd2d
def plot2d_labels_2d(trues, preds, dt, *, vmin=5, vmax=55):
    """Plot fdtd2d labels: ground truth and prediction as paired contour maps.

    Args:
        trues, preds: 5-D arrays of identical shape. Only index 0 of the last
            axis is used (commented as Nz = 1). When axis 1 has length 2 it is
            treated as (real, imaginary) and the magnitude is plotted;
            otherwise channel 0 is plotted as-is.
        dt: string appended to the window title.
        vmin, vmax: colour-scale limits. The defaults (5, 55) are the values
            that used to be hard-coded. Pass None to use the min / max over
            both arrays instead.
    """
    assert trues.ndim == preds.ndim == 5
    assert trues.shape == preds.shape

    # plot parameters
    figsize = (4, 2)   # window size (inch)
    transpose = True   # swap the two plot axes
    origin = 'lower'   # origin at the lower-left
    cmap = 'rainbow'   # colormap
    levels = 10        # number of contour levels
    pair = 1           # layout of true/pred pairs: 0 = horizontal, 1 = vertical
    mx = 6             # number of plots per row
    dnum = [3, 18, 3]  # data numbers (first, last, step), 1-based

    # 5-D -> 3-D
    if trues.shape[1] == 2:
        # real + imaginary -> magnitude
        p1 = np.sqrt(trues[:, 0, :, :, 0]**2 + trues[:, 1, :, :, 0]**2)
        p2 = np.sqrt(preds[:, 0, :, :, 0]**2 + preds[:, 1, :, :, 0]**2)
    else:
        # real or imaginary part only
        p1 = trues[:, 0, :, :, 0]
        p2 = preds[:, 0, :, :, 0]

    # colour-scale limits: fixed by default, data-driven (over all data) when None
    if vmin is None:
        vmin = min(np.min(p1), np.min(p2))
    if vmax is None:
        vmax = max(np.max(p1), np.max(p2))

    # per-data mean of the truth and mean absolute error
    err = np.abs(p1 - p2)
    ave = p1.mean(axis=(1, 2))
    dif = err.mean(axis=(1, 2))
    str_ave = ['%d %.4f' % (i + 1, d) for i, d in enumerate(ave)]
    str_dif = ['%d %.4f' % (i + 1, d) for i, d in enumerate(dif)]

    figtitle = 'fdtd2d label: (true, pred) %s' % dt
    suptitle = 'total=%d %dx%d min=%.4g max=%.4g dif=%.4f' % \
        (p1.shape[0], p1.shape[1], p1.shape[2], vmin, vmax, err.mean())
    print(suptitle)
    plotutils.contour2(figtitle, figsize, pair, mx, dnum, p1, p2, vmin, vmax, str_ave, str_dif, suptitle, transpose=transpose, origin=origin, cmap=cmap, levels=levels)

# label, 2D図形出力, fdtd3d, 複数データ
def plot2d_labels_3d(trues, preds, dt, *, idata0=1, idata1=2, plane='Z', vmin=5, vmax=55):
    """Plot fdtd3d labels slice by slice, one page per selected data item.

    Each page pairs the ground truth with the prediction, one contour map per
    slice, via plotutils.contour2.

    Args:
        trues, preds: 5-D arrays of identical shape. When axis 1 has length 2
            it is treated as (real, imaginary) and the magnitude is plotted;
            otherwise channel 0 is plotted as-is. Axes 2-4 are the ones the
            `plane` selection labels X, Y, Z.
        dt: string appended to the window title.
        idata0, idata1: 0-based half-open range [idata0, idata1) of data items
            to plot; swapped if reversed and clamped to the data count. The
            defaults select the second item (the first if there is only one),
            as the previously hard-coded values did.
        plane: 'X', 'Y' or 'Z' -- axis along which each volume is sliced.
        vmin, vmax: colour-scale limits. The defaults (5, 55) are the values
            that used to be hard-coded. Pass None to use the per-item
            min / max over truth and prediction instead.

    Raises:
        ValueError: if plane is not 'X', 'Y' or 'Z'.
    """
    assert trues.ndim == preds.ndim == 5
    assert trues.shape == preds.shape

    # axis permutation that moves the slicing axis to the front
    perms = {'X': (0, 1, 2), 'Y': (1, 0, 2), 'Z': (2, 0, 1)}
    if plane not in perms:
        raise ValueError("plane must be 'X', 'Y' or 'Z', got %r" % (plane,))
    perm = perms[plane]

    # plot parameters
    figsize = (7, 2)   # window size (inch)
    transpose = True   # swap the two plot axes
    origin = 'lower'   # origin at the lower-left
    cmap = 'rainbow'   # colormap
    levels = 10        # number of contour levels
    pair = 1           # layout of true/pred pairs: 0 = horizontal, 1 = vertical
    mx = 10            # number of plots per row
    dnum = [1, 10, 1]  # slice numbers (first, last, step), 1-based

    # normalise the data range so that 0 <= first <= last <= ndata
    ndata = trues.shape[0]
    first, last = min(idata0, idata1), max(idata0, idata1)
    first = max(0, min(first, ndata - 1))
    last = max(0, min(last, ndata))

    # one data item = one page
    for idata in range(first, last):
        if trues.shape[1] == 2:
            # real + imaginary -> magnitude
            p1 = np.sqrt(trues[idata, 0]**2 + trues[idata, 1]**2)
            p2 = np.sqrt(preds[idata, 0]**2 + preds[idata, 1]**2)
        else:
            # real or imaginary part only
            p1 = trues[idata, 0]
            p2 = preds[idata, 0]
        p1 = p1.transpose(perm)
        p2 = p2.transpose(perm)

        # colour-scale limits for this item
        lo = min(np.min(p1), np.min(p2)) if vmin is None else vmin
        hi = max(np.max(p1), np.max(p2)) if vmax is None else vmax

        # per-slice mean of the truth and mean absolute error
        err = np.abs(p1 - p2)
        ave = p1.mean(axis=(1, 2))
        dif = err.mean(axis=(1, 2))
        str_ave = ['%d %.4f' % (i + 1, d) for i, d in enumerate(ave)]
        str_dif = ['%d %.4f' % (i + 1, d) for i, d in enumerate(dif)]

        figtitle = 'fdtd3d label data=%d %s' % (idata + 1, dt)
        suptitle = 'data=%d/%d %s %dx%dx%d min=%.4g max=%.4g dif=%.4f' % \
            (idata + 1, ndata, plane, trues.shape[2], trues.shape[3], trues.shape[4], lo, hi, err.mean())
        print(suptitle)
        plotutils.contour2(figtitle, figsize, pair, mx, dnum, p1, p2, lo, hi, str_ave, str_dif, suptitle, transpose=transpose, origin=origin, cmap=cmap, levels=levels)

# label, 3D図形出力(fdtd3d)
def _label_faces(a3d, vmin, vmax):
    """Build the per-element squares and face colours drawn by plot3d_labels.

    One unit square is produced for every element (i, j, k) of the 3-D array
    a3d, lying in the plane z = k with corners (i, j) .. (i + 1, j + 1); the
    first corner is repeated at the end to close the outline. Faces are
    ordered with k outermost, then i, then j.

    The colour runs from white to red as (1, 1 - v, 1 - v), where
    v = (value - vmin) / (vmax - vmin) is clipped to [0, 1] so out-of-range
    values saturate instead of producing invalid RGB. When vmax == vmin every
    face is white.

    Returns:
        verts  -- int array of shape (nx * ny * nz, 5, 3)
        colors -- float array of shape (nx * ny * nz, 3)
    """
    a3d = np.asarray(a3d)
    nx, ny, nz = a3d.shape
    # 'ij' indexing + C-order ravel gives k outermost, then i, then j
    k, i, j = np.meshgrid(np.arange(nz), np.arange(nx), np.arange(ny), indexing='ij')
    k, i, j = k.ravel(), i.ravel(), j.ravel()
    origins = np.stack([i, j, k], axis=-1)
    corners = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 0]])
    verts = origins[:, None, :] + corners[None, :, :]

    span = vmax - vmin
    if span == 0:
        v = np.zeros(k.shape)
    else:
        v = (a3d[i, j, k] - vmin) / span
    v = np.clip(v, 0.0, 1.0)
    colors = np.stack([np.ones_like(v), 1 - v, 1 - v], axis=-1)
    return verts, colors


def plot3d_labels(figtitle, a3d, vmin, vmax):
    """Draw a 3-D array as stacked, semi-transparent coloured squares.

    Every element a3d[i, j, k] becomes a unit square at height z = k, shaded
    from white (at or below vmin) to red (at or above vmax). The window is
    titled figtitle, and plt.show() is called before returning.
    """
    nx, ny, nz = a3d.shape
    verts, colors = _label_faces(a3d, vmin, vmax)

    fig = plt.figure(figtitle, figsize=(6, 5), layout='tight')
    ax = fig.add_subplot(projection='3d')
    poly = Poly3DCollection(verts, facecolor=colors, alpha=0.3)
    ax.add_collection3d(poly)

    ax.set_xlim(0, nx)
    ax.set_ylim(0, ny)
    ax.set_zlim(0, nz)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_aspect('equal')
    ax.view_init(elev=20, azim=225, roll=0)
    ax.grid(visible=False)

    suptitle = 'Nx=%d Ny=%d Nz=%d min=%.5f max=%.5f' % (nx, ny, nz, vmin, vmax)
    plt.suptitle(suptitle)

    plt.show()

# ヒストグラム
def histogram2(trues, preds, bins=10, range=(0, 1)):
    """Show histograms of the true values (top) and predicted values (bottom).

    Args:
        trues, preds: 5-D arrays of identical shape; all elements are binned.
        bins: number of histogram bins.
        range: (low, high) histogram range, also used as the x-axis limits.
            (The name shadows the builtin but is part of the public signature.)
    """
    # Validate before opening the window so a bad call leaves no empty figure.
    assert trues.ndim == preds.ndim == 5
    assert trues.shape == preds.shape

    plt.figure('histogram', figsize=(5, 4), layout='tight')
    for row, (values, name) in enumerate(((trues, 'true'), (preds, 'pred')), start=1):
        d = values.reshape(-1)
        ax = plt.subplot(2, 1, row)
        ax.hist(d, bins=bins, range=range)
        ax.set_title('min=%.3f max=%.3f' % (np.min(d), np.max(d)))
        ax.set_xlim(range[0], range[1])
        ax.set_xlabel(name)

    suptitle = '%dx%dx%dx%dx%d=%d' % (*trues.shape, trues.size)
    plt.suptitle(suptitle)

    plt.show()
