# -*- coding: utf-8 -*-
"""
plot_near1d.py
"""

import numpy as np
import matplotlib.pyplot as plt
import post.nearfield

# plot near field
def plot(Post, Nfreq, Freq, Nfeed, Iplanewave, Planewave, E_posc, E_posm, E_posp, E_lng, E_tan, Iground, Title, Xc):
    """Plot the near field along straight lines, one figure per (line, frequency).

    Line ``il`` runs from
    ``(Post['n1dposx'][il][0], Post['n1dposy'][il][0], Post['n1dposz'][il][0])`` to
    ``(Post['n1dposx'][il][1], Post['n1dposy'][il][1], Post['n1dposz'][il][1])``.
    It is sampled at ``Post['n1ddiv'][il] + 1`` evenly spaced points.

    At each point, ``post.nearfield.field`` is evaluated for every frequency.
    The values are written to ``near1d.log`` in linear units, and the component
    ``Post['n1dcompo'][il]`` is plotted against the point number. Figures are
    created here but not shown.

    Post keys read:
        n1ddiv, n1dposx, n1dposy, n1dposz, n1dcompo ('E', 'Ex', 'Ey', 'Ez',
        'H', 'Hx', 'Hy', 'Hz'), n1ddb (1 = amplitudes in dB),
        n1dscale ([0, ...] = auto scale, otherwise [flag, ymin, ymax, ydiv]),
        n1dnoinc, w2d (figure size), fscale and funit (frequency display).

    Returns immediately, without creating the log, when there are no lines or
    ``Nfreq < 1``.

    Raises:
        ValueError: if any line requests an unknown component. This is checked
            before any field is computed or the log file is opened.
    """
    C = 2.99792458e8   # speed of light [m/s]
    EPS = 1e-10        # amplitude floor so log10 never sees 0

    nline = len(Post['n1ddiv'])

    if (nline < 1) or (Nfreq < 1):
        return

    # resolve (and validate) every line's component before doing any work
    layouts = [_n1d_component(Post['n1dcompo'][il]) for il in range(nline)]

    db = (Post['n1ddb'] == 1)
    scale = Post['n1dscale']
    user_ticks = (scale[0] == 1)
    amp_columns = [0, 1, 2, 3, 7, 8, 9, 10]   # amplitude columns (phases stay in degrees)

    nfig = 0
    # 'with' guarantees the log is closed even if a field evaluation or plot call raises
    with open('near1d.log', 'wt', encoding='utf-8') as fp:
        for il in range(nline):
            ndiv = Post['n1ddiv'][il]
            posx = np.linspace(Post['n1dposx'][il][0], Post['n1dposx'][il][1], ndiv + 1)
            posy = np.linspace(Post['n1dposy'][il][0], Post['n1dposy'][il][1], ndiv + 1)
            posz = np.linspace(Post['n1dposz'][il][0], Post['n1dposz'][il][1], ndiv + 1)
            compo = Post['n1dcompo'][il]
            icmp, icol, scmp = layouts[il]
            strunit = _n1d_unit(compo, db)
            total = compo in ('E', 'H')

            x = np.linspace(0, ndiv, ndiv + 1)     # abscissa: sample point number
            eh = np.zeros((ndiv + 1, 14), float)   # one 14-column field row per point

            for ifreq in range(Nfreq):
                # E, H at every sample point
                k0 = (2 * np.pi * Freq[ifreq]) / C
                for n in range(ndiv + 1):
                    pos = np.array([posx[n], posy[n], posz[n]])
                    # NOTE(review): E_tan is passed before E_lng, the reverse of this
                    # function's own parameter order -- confirm against post.nearfield.field
                    eh[n] = post.nearfield.field(Nfeed, Iplanewave, Planewave, E_posc, E_posm, E_posp, E_tan, E_lng, Iground, Xc, pos, ifreq, k0, Post['n1dnoinc'])

                # log linear values, before any dB conversion
                _log_n1d(fp, il, Freq[ifreq], compo, ndiv, posx, posy, posz, eh)

                # to dB
                if db:
                    eh[:, amp_columns] = 20 * np.log10(np.maximum(eh[:, amp_columns], EPS))

                # scale
                dmax = np.max(eh[:, icmp[0]], axis=0)
                ymin, ymax, ydiv = _n1d_scale(dmax, db, scale)

                # figure
                nfig += 1
                strfig = 'OpenMOM - near field 1d (%d/%d)' % (nfig, nline * Nfreq)
                fig = plt.figure(strfig, figsize=(Post['w2d'][0], Post['w2d'][1]))
                ax = fig.add_subplot()

                if total:
                    # E/Ex/Ey/Ez or H/Hx/Hy/Hz : amplitude only
                    for col, color, label in zip(icmp, icol, scmp):
                        ax.plot(x, eh[:, col], color=color, label=label)
                    ax.set_ylabel(compo + ' ' + strunit)
                    ax.legend(loc='best')
                else:
                    # Ex/Ey/Ez/Hx/Hy/Hz : amplitude on the left axis
                    ax.plot(x, eh[:, icmp[0]], color='k')
                    ax.set_ylabel('amplitude ' + strunit)
                ax.set_xlim(x[0], x[-1])
                ax.set_ylim(ymin, ymax)
                if user_ticks:
                    # linspace needs an integer count; user input may arrive as a float
                    ax.set_yticks(np.linspace(ymin, ymax, int(ydiv) + 1))
                ax.grid(True)

                if not total:
                    # phase on a twin right axis; draw on ax2 explicitly
                    ax2 = ax.twinx()
                    ax2.plot(x, eh[:, icmp[1]], color='r', linestyle='--')
                    ax2.set_xlim(x[0], x[-1])
                    ax2.set_ylabel('phase [deg]', color='r')
                    ax2.set_ylim(-180, 180)
                    ax2.set_yticks([-180, -135, -90, -45, 0, 45, 90, 135, 180])
                    ax2.grid(False)

                # x-label
                xstr = 'point No. : (%g, %g, %g) - (%g, %g, %g) [m]' % (posx[0], posy[0], posz[0], posx[-1], posy[-1], posz[-1])
                ax.set_xlabel(xstr)

                # title
                ax.set_title('%s\n%s, f = %.3f%s, max = %.4g%s' %
                    (Title, compo, Freq[ifreq] * Post['fscale'], Post['funit'], dmax, strunit))


# column layout of a 1-D plot component (private)
def _n1d_component(compo):
    """Return ``(icmp, icol, scmp)`` describing how to plot component *compo*.

    ``icmp`` lists the columns of the 14-column field array to plot:
    - For 'E'/'H': the total followed by the three Cartesian amplitudes.
    - For a single component: its amplitude column, then its phase column.

    ``icol`` and ``scmp`` are the line colors and legend labels. They are only
    filled in for 'E'/'H'.

    Raises:
        ValueError: if *compo* is not one of E, Ex, Ey, Ez, H, Hx, Hy, Hz.
    """
    table = {
        'E':  ([0, 1, 2, 3], ['k', 'r', 'g', 'b'], ['E', 'Ex', 'Ey', 'Ez']),
        'Ex': ([1, 4], [], []),
        'Ey': ([2, 5], [], []),
        'Ez': ([3, 6], [], []),
        'H':  ([7, 8, 9, 10], ['k', 'r', 'g', 'b'], ['H', 'Hx', 'Hy', 'Hz']),
        'Hx': ([8, 11], [], []),
        'Hy': ([9, 12], [], []),
        'Hz': ([10, 13], [], []),
    }
    if compo not in table:
        raise ValueError('near field 1d: unknown component %r '
                         '(expected E, Ex, Ey, Ez, H, Hx, Hy or Hz)' % (compo,))
    return table[compo]


# amplitude unit label of a 1-D plot component (private)
def _n1d_unit(compo, db):
    """Return the bracketed unit label for *compo*, in dB when *db* is true."""
    base = 'V/m' if compo.startswith('E') else 'A/m'
    return ('[dB%s]' if db else '[%s]') % base


# amplitude axis range (private)
def _n1d_scale(dmax, db, scale):
    """Return ``(ymin, ymax, ydiv)`` for the amplitude axis.

    When ``scale[0] == 0`` (auto scale), the axis runs from 50 dB below *dmax*
    to *dmax* in dB mode, or from 0 to *dmax* in linear mode, with 10
    divisions. Otherwise ``scale[1:4]`` supplies ymin, ymax and ydiv.
    """
    if scale[0] == 0:
        return ((dmax - 50) if db else 0), dmax, 10
    return scale[1], scale[2], scale[3]
    
# near1d.log (private)
def _log_n1d(fp, iline, freq, compo, ndiv, posx, posy, posz, eh):
    """Append one table of 1-D near-field samples to the open text file *fp*.

    The table consists of:
    - A ``#<iline> : frequency[Hz] = ...`` line.
    - A column header line.
    - One row per sample point ``n = 0 .. ndiv``: the point number, its X/Y/Z
      coordinates, then seven values from ``eh[n]``.

    The seven values come from columns 0, 1, 4, 2, 5, 3, 6 when *compo* starts
    with 'E', and from columns 7, 8, 11, 9, 12, 10, 13 otherwise. Their header
    labels are the total amplitude, then each Cartesian component's amplitude
    and phase [deg].
    """
    electric = compo.startswith('E')

    # table header
    fp.write('#%d : frequency[Hz] = %.3e\n' % (iline, freq))
    fp.write(' No.     X[m]        Y[m]        Z[m]       ')
    fp.write('E[V/m]      Ex[V/m]   Ex[deg]    Ey[V/m]   Ey[deg]    Ez[V/m]   Ez[deg]\n'
             if electric else
             'H[A/m]      Hx[A/m]   Hx[deg]    Hy[A/m]   Hy[deg]    Hz[A/m]   Hz[deg]\n')

    # table body: which eh columns feed the seven value fields, in print order
    columns = (0, 1, 4, 2, 5, 3, 6) if electric else (7, 8, 11, 9, 12, 10, 13)
    point_fmt = '%4d%12.3e%12.3e%12.3e'
    value_fmt = '%12.4e%12.4e%9.3f%12.4e%9.3f%12.4e%9.3f\n'
    for row in range(ndiv + 1):
        fp.write(point_fmt % (row, posx[row], posy[row], posz[row]))
        fp.write(value_fmt % tuple(eh[row, col] for col in columns))
