# -*- coding: utf-8 -*-
"""
plot_near2d.py
"""

import numpy as np
import matplotlib.pyplot as plt
import post.nearfield

C = 2.99792458e8
EPS = 1e-10

# near field on planes (2D only)
# TODO : 3D, animation
def plot(Post, Nfreq, Freq, Nfeed, Iplanewave, Planewave, E_posc, E_posm, E_posp, E_lng, E_tan, Iground, Title, Xc):
    """Plot the near field on every requested 2D plane and log the raw values.

    For each plane listed in ``Post`` and each of the ``Nfreq`` frequencies the
    field is sampled on a regular grid with ``post.nearfield.field``, written to
    the text file ``near2d.log`` (current directory, overwritten on each call)
    and drawn as a contour figure: one amplitude figure, plus one phase figure
    for the single-axis components (Ex, Ey, Ez, Hx, Hy, Hz).
    Figures are only created here; this function does not call ``plt.show()``.

    Keys read from ``Post``:
        n2dcompo   : per plane, component name ('E', 'Ex', ..., 'H', 'Hx', ...)
        n2ddir     : per plane, plane normal 'X' / 'Y' / 'Z'
        n2ddiv     : per plane, number of divisions along the two in-plane axes
        n2dpos     : per plane, [pos0, start1, end1, start2, end2]
        n2dcontour : 0/2 = filled contour, 1/3 = contour lines
                     (values outside 0..3 make the function return at once)
        n2ddb      : 1 = convert amplitudes to dB, otherwise linear
        n2dscale   : [0 = auto scale / otherwise user scale, min, max, divisions]
        n2dobj     : 1 = overplot the geometry
        w2d        : figure size (width, height)
        fscale, funit : frequency scale factor and unit string for the title
        n1dnoinc   : forwarded as the last argument of post.nearfield.field
                     NOTE(review): a 'n1d' key is used for the 2D plot - confirm
                     that this is intended.

    Parameters:
        Nfreq : number of frequencies to plot
        Freq  : frequencies, indexed 0 .. Nfreq - 1
        Nfeed, Iplanewave, Planewave, E_posc, E_lng, E_tan, Iground, Xc :
            forwarded unchanged to post.nearfield.field
        E_posm, E_posp : also used by the geometry overplot, where they are
            indexed as [:, 0..2]
        Title : first line of every figure title

    Returns:
        None. Returns without doing anything when there is no plane, no
        frequency, or ``Post['n2dcontour']`` is out of range.
    """
    nplane = len(Post['n2dcompo'])

    if (nplane < 1) or (Nfreq < 1):
        return
    if (Post['n2dcontour'] < 0) or (Post['n2dcontour'] > 3):
        return

    # No. of figures : 'E' and 'H' have amplitude only, the others amplitude and phase
    nfig = 0
    for compo in Post['n2dcompo']:
        if (compo == 'E') or (compo == 'H'):
            nfig += 1 * Nfreq
        else:
            nfig += 2 * Nfreq

    ifig = 0
    # log : the context manager closes the file even if sampling or plotting raises
    with open('near2d.log', 'wt', encoding='utf-8') as fp:
        for ip in range(nplane):
            # setup
            cdir = Post['n2ddir'][ip]  # ='X'/'Y'/'Z'
            compo = Post['n2dcompo'][ip]
            ndiv = Post['n2ddiv'][ip]
            pos = Post['n2dpos'][ip]
            pos0 = pos[0]
            pos1 = np.linspace(pos[1], pos[2], ndiv[0] + 1)
            pos2 = np.linspace(pos[3], pos[4], ndiv[1] + 1)
            # 3D position : row = horizontal, column = vertical
            pos2d = _n2d_positions(cdir, pos0, pos1, pos2)
            # contour data : row = vertical, column = horizontal
            eh = np.zeros((ndiv[1] + 1, ndiv[0] + 1, 14), float)

            strunit = _n2d_units(compo, Post['n2ddb'])
            str1, str2 = _n2d_axis_names(cdir)
            icmp = _n2d_component_index(compo)

            for ifreq in range(Nfreq):
                # E, H
                k0 = (2 * np.pi * Freq[ifreq]) / C
                for n1 in range(ndiv[0] + 1):
                    for n2 in range(ndiv[1] + 1):
                        # pass a private copy so the callee cannot alias pos2d
                        point = np.array(pos2d[n1, n2])
                        eh[n2, n1] = post.nearfield.field(Nfeed, Iplanewave, Planewave, E_posc, E_posm, E_posp, E_tan, E_lng, Iground, Xc, point, ifreq, k0, Post['n1dnoinc'])

                # log (raw values, before dB conversion and clipping)
                _log_n2d(fp, ip, Freq[ifreq], compo, ndiv, pos2d, eh)

                # convert the amplitude components to dB
                if Post['n2ddb'] == 1:
                    for m in [0, 1, 2, 3, 7, 8, 9, 10]:
                        eh[:, :, m] = 20 * np.log10(np.maximum(eh[:, :, m], EPS))

                # scale
                dmax = np.max(eh[:, :, icmp[0]])
                if Post['n2dscale'][0] == 0:
                    # automatic scale
                    fmax = dmax
                    if Post['n2ddb'] == 1:
                        fmin = fmax - 30
                    else:
                        fmin = 0
                else:
                    # user-specified scale
                    fmin = Post['n2dscale'][1]
                    fmax = Post['n2dscale'][2]
                # do not draw a figure for constant data
                if abs(fmin - fmax) < EPS:
                    print('*** constant data : %s = %e\n' % (compo, fmin))
                    continue
                # Clip to [fmin, fmax]. Only the plotted amplitude is clipped:
                # the phase component shares this array and must not be forced
                # into the amplitude range before its own contour is drawn.
                amplitude = eh[:, :, icmp[0]]
                eh[:, :, icmp[0]] = np.maximum(np.minimum(amplitude, fmax), fmin)

                # plot, m = 0/1 : amplitude/phase
                for m in range(len(icmp)):
                    # figure
                    ifig += 1
                    strfig = 'OpenMOM - near field 2d (%d/%d)' % (ifig, nfig)
                    fig = plt.figure(strfig, figsize=(Post['w2d'][0], Post['w2d'][1]))
                    ax = fig.add_subplot()

                    if m == 0:
                        levels = np.linspace(fmin, fmax, Post['n2dscale'][3] + 1)
                    else:
                        levels = np.linspace(-180, 180, Post['n2dscale'][3] + 1)

                    # contour
                    if   (Post['n2dcontour'] == 0) or (Post['n2dcontour'] == 2):
                        CS = ax.contourf(pos1, pos2, eh[:, :, icmp[m]], levels, cmap='rainbow')
                    elif (Post['n2dcontour'] == 1) or (Post['n2dcontour'] == 3):
                        CS = ax.contour(pos1, pos2, eh[:, :, icmp[m]], levels, cmap='rainbow')
                    ax.set_aspect('equal')

                    # color bar
                    if m == 0:
                        cbar = fig.colorbar(CS)
                    else:
                        cbar = fig.colorbar(CS, ticks=[-180, -120, -60, 0, 60, 120, 180])
                    cbar.set_label(compo + ' ' + strunit[m])

                    # overplot geometry
                    if Post['n2dobj'] == 1:
                        _n2dobj(ax, cdir, [pos1[0], pos1[-1]], [pos2[0], pos2[-1]], E_posm, E_posp)

                    # X,Y label
                    ax.set_xlabel(str1 + '[m]')
                    ax.set_ylabel(str2 + '[m]')

                    # title (the maximum is shown on the amplitude figure only)
                    strmax = ', max = %.4g%s' % (dmax, strunit[0]) if m == 0 else ''
                    ax.set_title('%s\n%s = %g[m], f = %.3f%s%s' % \
                        (Title, cdir, pos0, Freq[ifreq] * Post['fscale'], Post['funit'], strmax))


# unit labels [amplitude, phase] of a component (private)
def _n2d_units(compo, db):
    if compo.startswith('E'):
        amplitude = '[dBV/m]' if db == 1 else '[V/m]'
    elif compo.startswith('H'):
        amplitude = '[dBA/m]' if db == 1 else '[A/m]'
    else:
        amplitude = ''
    return [amplitude, '[deg]']


# XYZ sample points of a plane, shape (len(pos1), len(pos2), 3) (private)
def _n2d_positions(cdir, pos0, pos1, pos2):
    pos2d = np.zeros((len(pos1), len(pos2), 3), float)
    along1 = pos1[:, np.newaxis]  # varies with the first index
    along2 = pos2[np.newaxis, :]  # varies with the second index
    if   cdir == 'X':
        # Y-Z
        pos2d[:, :, 0] = pos0
        pos2d[:, :, 1] = along1
        pos2d[:, :, 2] = along2
    elif cdir == 'Y':
        # X-Z
        pos2d[:, :, 0] = along1
        pos2d[:, :, 1] = pos0
        pos2d[:, :, 2] = along2
    elif cdir == 'Z':
        # X-Y
        pos2d[:, :, 0] = along1
        pos2d[:, :, 1] = along2
        pos2d[:, :, 2] = pos0
    # any other direction leaves every point at the origin
    return pos2d


# names of the horizontal and vertical axes of a plane (private)
def _n2d_axis_names(cdir):
    names = {'X': ('Y', 'Z'), 'Y': ('X', 'Z'), 'Z': ('X', 'Y')}
    return names.get(cdir, ('', ''))


# indices [amplitude] or [amplitude, phase] of a component in the field vector (private)
def _n2d_component_index(compo):
    table = {
        'E':  [0],
        'Ex': [1, 4],
        'Ey': [2, 5],
        'Ez': [3, 6],
        'H':  [7],
        'Hx': [8, 11],
        'Hy': [9, 12],
        'Hz': [10, 13],
    }
    # a fresh list is returned so callers never share the table entries
    return list(table.get(compo, []))

# overplot geometry (private)
def _n2dobj(ax, cdir, xs, ys, E_posm, E_posp):
    """Overplot, in black, the geometry segments lying inside the plot window.

    Row i of ``E_posm`` and row i of ``E_posp`` are taken as the two end points
    of one segment; both arrays are indexed as [:, 0..2] (X, Y, Z).
    A segment is drawn only when both of its end points are inside the window.

    Parameters:
        ax     : axes object; only its ``plot`` method is used
        cdir   : plane normal 'X' / 'Y' / 'Z'; selects which two coordinates
                 become the horizontal and vertical axes
        xs, ys : [first, last] window limits along the horizontal / vertical axis
        E_posm, E_posp : 2D arrays of end-point coordinates

    Any other ``cdir`` draws nothing.
    """
    # (horizontal, vertical) coordinate columns shown for each plane normal
    columns = {'X': (1, 2), 'Y': (0, 2), 'Z': (0, 1)}
    if cdir not in columns:
        # no projection is defined for this direction, so there is nothing to draw
        return
    ih, iv = columns[cdir]
    x1 = E_posm[:, ih]
    y1 = E_posm[:, iv]
    x2 = E_posp[:, ih]
    y2 = E_posp[:, iv]

    # v lies between lo and hi when (v - lo) * (v - hi) <= 0 ; EPS is the tolerance
    inside = \
        ((x1 - xs[0]) * (x1 - xs[1]) <= EPS) & \
        ((x2 - xs[0]) * (x2 - xs[1]) <= EPS) & \
        ((y1 - ys[0]) * (y1 - ys[1]) <= EPS) & \
        ((y2 - ys[0]) * (y2 - ys[1]) <= EPS)

    # 2 x K inputs : each column is drawn as one line from the first to the second end point
    ax.plot([x1[inside], x2[inside]], [y1[inside], y2[inside]], color='k', lw=1)

# log (private)
def _log_n2d(fp, ip, freq, compo, ndiv, pos2d, eh):
    """Write the field on one plane at one frequency to the log as a text table.

    Parameters:
        fp    : writable text stream; only ``write`` is used
        ip    : 0-based plane index (logged as ip + 1)
        freq  : frequency [Hz]
        compo : component name; only its first letter ('E' or 'H') selects
                which half of the field vector is listed
        ndiv  : number of divisions along the two in-plane axes
        pos2d : sample positions, indexed [n1, n2, 0..2]
        eh    : field values, indexed [n2, n1, 0..13]

    When ``compo`` starts with neither 'E' nor 'H' the field columns and the
    line ends are left out, so header and rows run together.
    """
    # decide once which column titles and which field entries are listed
    if compo.startswith('E'):
        titles = 'E[V/m]      Ex[V/m]   Ex[deg]    Ey[V/m]   Ey[deg]    Ez[V/m]   Ez[deg]\n'
        picks = (0, 1, 4, 2, 5, 3, 6)
    elif compo.startswith('H'):
        titles = 'H[A/m]      Hx[A/m]   Hx[deg]    Hy[A/m]   Hy[deg]    Hz[A/m]   Hz[deg]\n'
        picks = (7, 8, 11, 9, 12, 10, 13)
    else:
        titles = ''
        picks = None

    # header
    fp.write('#%d : frequency[Hz] = %.3e\n' % (ip + 1, freq))
    fp.write('  No.  No.     X[m]        Y[m]        Z[m]       ' + titles)

    # body : one row per sample point, first index outermost
    fmt_point = '%5d%5d%12.3e%12.3e%12.3e'
    fmt_field = '%12.4e%12.4e%9.3f%12.4e%9.3f%12.4e%9.3f\n'
    count1 = ndiv[0] + 1
    count2 = ndiv[1] + 1
    for i1 in range(count1):
        for i2 in range(count2):
            xyz = pos2d[i1, i2, :]
            values = eh[i2, i1, :]
            row = fmt_point % (i1, i2, xyz[0], xyz[1], xyz[2])
            if picks is not None:
                row += fmt_field % tuple(values[k] for k in picks)
            fp.write(row)
