# -*- coding: utf-8 -*-
"""
plot_current.py
"""

import numpy as np
import matplotlib.pyplot as plt

# plot current distribution (2D/3D)
def plot(Post, Ne, Nfreq, Freq, E_posm, E_posp, Title, Xc):
    """Dispatch to the 2D and/or 3D current-distribution plots.

    ``Post['current']`` carries two switches: element 0 selects the 2D plot
    and element 1 the 3D plot.  A plot is produced only when its switch
    equals 1.  All remaining arguments are forwarded unchanged.
    """
    # 2D
    want_2d = Post['current'][0] == 1
    if want_2d:
        _plot2d(Post, Title, Ne, Nfreq, Freq, Xc)

    # 3D
    want_3d = Post['current'][1] == 1
    if want_3d:
        _plot3d(Post, Title, Ne, Nfreq, Freq, E_posm, E_posp, Xc)


# plot current distribution (2D) (private)
def _plot2d(Post, Title, Ne, Nfreq, Freq, Xc):
    for ifreq in range(Nfreq):
        # set data
        amp = abs(Xc[ifreq]) * 1e3  # mA
        phs = np.angle(Xc[ifreq], True)  # degree

        # figure
        strfig = 'OpenMOM - current distribution (%d/%d)' % (ifreq + 1, Nfreq)
        fig = plt.figure(strfig, figsize=(Post['w2d'][0], Post['w2d'][1]))
        ax1 = fig.add_subplot()

        # X-axis
        ax1.set_xlim(0, Ne - 1)
        ax1.set_xticks([0, Ne - 1])
        ax1.set_xticklabels(['1', str(Ne)])
        ax1.set_xlabel('element No.')

        # amplitude
        ax1.plot(amp, 'k')
        ax1.set_ylim(bottom=0)
        ax1.set_ylabel('amplitude [mA]')
        ax1.grid(True)

        # phase
        ax2 = ax1.twinx()
        ax2.plot(phs, 'r--')
        ax2.set_ylim(-180, 180)
        ax2.set_ylabel('phase [deg]', color='r')
        ax2.set_yticks([-180, -135, -90, -45, 0, 45, 90, 135, 180])

        # title
        ax2.set_title('%s\n f = %.3g%s, max = %.4f[mA]' % \
            (Title, Freq[ifreq] * Post['fscale'], Post['funit'], np.max(amp)))

# plot current distribution (3D) (private)
def _plot3d(Post, Title, Ne, Nfreq, Freq, E_posm, E_posp, Xc):
    for ifreq in range(Nfreq):
        # set data
        amp = abs(Xc[ifreq]) * 1e3
        cmax = np.max(amp)
        cmin = np.min(amp)
        
        # figure
        strfig = 'OpenMOM - current distribution (3D) (%d/%d)' % (ifreq + 1, Nfreq)
        fig = plt.figure(strfig, figsize=(Post['w3d'][0], Post['w3d'][1]))
        ax = fig.add_subplot(projection='3d')
        
        # plot lines
        for ie in range(Ne):
            x = [E_posm[ie, 0], E_posp[ie, 0]]
            y = [E_posm[ie, 1], E_posp[ie, 1]]
            z = [E_posm[ie, 2], E_posp[ie, 2]]
            icolor = round(255 * (amp[ie] - cmin) / (cmax - cmin))
            icolor = max(0, min(255, icolor))
            color = plt.cm.rainbow(icolor)
            ax.plot(x, y, z, color=color)

        ax.set_aspect('equal')
        ax.view_init(elev = 90 - Post['w3d'][2], azim = Post['w3d'][3], roll = 0)

        # title
        ax.set_title('%s\nf = %.3g%s, max = %.4f[mA]' % \
           (Title, Freq[ifreq] * Post['fscale'], Post['funit'], np.max(amp)))

        # label
        ax.set_xlabel('X[m]')
        ax.set_ylabel('Y[m]')
        ax.set_zlabel('Z[m]')

        # limit
        gmin = np.zeros(3, float)
        gmax = np.zeros(3, float)
        for m in range(3):
            gmin[m] = min(np.min(E_posm[:, m]), np.min(E_posp[:, m]))
            gmax[m] = max(np.max(E_posm[:, m]), np.max(E_posp[:, m]))
        margin = 0.05 * sum(gmax - gmin)
        gmin -= margin
        gmax += margin
        ax.set_xlim(gmin[0], gmax[0])
        ax.set_ylim(gmin[1], gmax[1])
        ax.set_zlim(gmin[2], gmax[2])
