# -*- coding: utf-8 -*-
"""
plot_far2d.py
"""

import numpy as np
import matplotlib.pyplot as plt
import sol.farfield

# speed of light in vacuum [m/s]
C = 2.99792458e8
# free-space wave impedance mu0 * c [ohm] with mu0 = 4*pi*1e-7 [H/m] (~376.73)
ETA0 = C * 4 * np.pi * 1e-7
# small floor applied before log10 (dB conversion) and used as the "zero"
# threshold when deciding whether a far-field component is worth plotting
EPS = 1e-10

# far field (3D)
def plot(
    Post, Ne, Nfreq, Freq, Nfeed, Iplanewave, Planewave,
    E_posc, E_lng, E_tan, E_ifeed, E_feed, E_iload, E_load,
    Iground, Title, Z0, Xc, Zin):
    """Draw the far-field pattern over a (theta, phi) grid as 3D surfaces.

    One matplotlib figure is created per frequency and per enabled component
    (``Post['f2dcompo'][m] == 1``), and the far-field table of every
    frequency is written to ``far2d.log`` in the current directory.

    Nothing is done unless ``Post['f2d'][0] >= 1`` and the theta / phi
    division counts ``Post['f2d'][1]`` and ``Post['f2d'][2]`` are both >= 2.

    Side effects:
        * Components whose maximum over the grid is below ``EPS`` are
          disabled in place in ``Post['f2dcompo']`` (they then stay disabled
          for the remaining frequencies).
        * Figures are created but not shown; presumably the caller calls
          ``plt.show()`` — TODO confirm.
    """
    if (Post['f2d'][0] < 1) or (Post['f2d'][1] < 2) or (Post['f2d'][2] < 2):
        return

    # log file: closed on return and also if anything below raises
    fname = 'far2d.log'
    with open(fname, 'wt', encoding='utf-8') as fp:

        # far field factor (one per frequency)
        ffctr = np.zeros(Nfreq, float)
        for ifreq in range(Nfreq):
            ffctr[ifreq] = sol.farfield.factor(Ne, Nfeed, E_ifeed, E_feed, Zin, Z0, ifreq, Post['mloss'])

        # grid: theta over [0, 180] and phi over [0, 360] [deg], endpoints included
        nth = Post['f2d'][1] + 1
        nph = Post['f2d'][2] + 1
        pfar = np.zeros((nth, nph, 7), float)  # always holds linear values
        th = np.linspace(0, 180, nth)
        ph = np.linspace(0, 360, nph)
        ph2d, th2d = np.meshgrid(np.deg2rad(ph), np.deg2rad(th))

        nfig = 0
        for ifreq in range(Nfreq):
            # set data (linear)
            for ith in range(nth):
                for iph in range(nph):
                    _, pfar[ith, iph, :] = sol.farfield.field(ifreq, Freq[ifreq], th[ith], ph[iph], ffctr[ifreq], Ne, E_posc, E_lng, E_tan, Iground, Xc)

            # skip zero component
            for m in range(7):
                if Post['f2dcompo'][m] == 1 and np.max(pfar[:, :, m]) < EPS:
                    print('*** %s : max = 0' % Post['farcomp'][m])
                    Post['f2dcompo'][m] = 0

            # statistics; cfar is evaluated at (180 - theta_i, phi_i + 180),
            # presumably the propagation direction of the incident plane
            # wave needed by the optical theorem — confirm
            cfar, _ = sol.farfield.field(ifreq, Freq[ifreq], 180 - Planewave[0], Planewave[1] + 180, ffctr[ifreq], Ne, E_posc, E_lng, E_tan, Iground, Xc)
            stat1, stat2 = _statistics(Post, Ne, Nfeed, Iplanewave, Planewave, E_iload, E_load, Xc, nth, nph, ifreq, Freq[ifreq], cfar, pfar)

            # log (expects linear values)
            _log_f2d(fp, Freq[ifreq], nth, nph, th, ph, pfar)

            # values to draw: dB or linear (pfar itself is left untouched)
            if Post['f2ddb'] == 1:
                pplot = 10 * np.log10(np.maximum(pfar, EPS))
            else:
                pplot = pfar

            for m in range(7):
                if Post['f2dcompo'][m] == 0:
                    continue
                comp = pplot[:, :, m]

                # scale and boresight: first maximum in (theta, phi) row-major order
                ith_max, iph_max = np.unravel_index(np.argmax(comp), comp.shape)
                rmax = comp[ith_max, iph_max]
                thmax = th[ith_max]
                phmax = ph[iph_max]

                if Post['f2ddb'] == 1:
                    # dB mode: show a dynamic range of |f2dscale[2] - f2dscale[1]| below the peak
                    rmin = rmax - abs(Post['f2dscale'][2] - Post['f2dscale'][1])
                else:
                    rmin = 0

                # figure
                nfig += 1
                strfig = 'OpenMOM - far field (3D) (%d/%d)' % (nfig, Nfreq * sum(Post['f2dcompo']))
                fig = plt.figure(strfig, figsize=(Post['w3d'][0], Post['w3d'][1]))
                ax = fig.add_subplot(projection='3d')
                ax.view_init(elev=90 - Post['w3d'][2], azim=Post['w3d'][3], roll=0)

                # plot: spherical -> Cartesian, radius normalized to [0, 1]
                r = np.maximum((comp - rmin) / (rmax - rmin), 0)
                x = r * np.cos(ph2d) * np.sin(th2d)
                y = r * np.sin(ph2d) * np.sin(th2d)
                z = r * np.cos(th2d)
                ax.plot_surface(x, y, z)
                ax.set_aspect('equal')

                # title
                str2 = '%s, f = %.3g%s' % (Post['farcomp'][m], Freq[ifreq] * Post['fscale'], Post['funit'])
                str5 = 'max = %.4g%s @ (theta, phi) = (%.1f, %.1f)[deg]' % (rmax, Post['f2dunit'], thmax, phmax)
                ax.set_title('%s\n%s\n%s\n%s\n%s' % (Title, str2, stat1, stat2, str5))

                # label
                ax.set_xlabel('X')
                ax.set_ylabel('Y')
                ax.set_zlabel('Z')

# statistics (private)
def _statistics(Post, Ne, Nfeed, Iplanewave, Planewave, E_iload, E_load, Xc, nth, nph, ifreq, freq, cfar, pfar):
    stat1 = ''
    stat2 = ''

    pmax = 0
    psum = 0
    fctr = (np.pi / (nth - 1)) * (2 * np.pi / (nph - 1)) / (4 * np.pi)
    for ith in range(nth): #= 1 : nth
        for iph in range(nph - 1): #= 1 : nph - 1
            th = np.pi * ith / (nth - 1)
            powf = pfar[ith, iph, 0]
            psum += fctr * np.sin(th) * powf
            pmax = np.maximum(powf, pmax)

    if Nfeed > 0:
        # feed
        pmax = pmax / psum
        if Post['f2ddb'] == 1:
            pmaxdb = 10 * np.log10(np.maximum(pmax, EPS))
            stat1 = 'directive gain = %.3f[dBi]' % pmaxdb
        else:
            stat1 = 'directive gain = %.3f' % pmax
        stat2 = 'efficiency = %.3f[%%]' % (psum * 100)
    else:
        # plane wave
        if Post['f2ddb'] == 1:
            psumdb = 10 * np.log10(np.maximum(psum, EPS))
            stat1 = 'total cross section = %.3f[dBm^2]' % psumdb
        else:
            stat1 = 'total cross section = %.3e[m^2]' % psum

        # power loss by R
        ploss = 0
        for ie in range(Ne):
            if E_iload[ie] == 1:  # = R
                ploss += 0.5 * E_load[ie] * np.abs(Xc[ifreq, ie])**2

        # optical theorem error
        k0 = (2 * np.pi * freq) / C
        rhs = _opterror(Iplanewave, Planewave, k0, cfar)
        ei = 1
        lhs = psum + (2 * ETA0 * ploss) / ei**2
        #print(rhs, lhs)
        oerr = np.abs(1 - (rhs / lhs))
        stat2 = 'optical theorem error = %.3f[%%]' % (oerr * 100)

    return stat1, stat2

# optical error (private)
def _opterror(Iplanewave, Planewave, k0, cfar):
    pol = Iplanewave
    a   = Planewave[2]
    r   = Planewave[3]

    etheta = cfar[0]
    ephi   = cfar[1]
    polfctr = 1
    if   pol == 1:
        polfctr = etheta.imag
    elif pol == 2:
        polfctr = ephi.imag
    elif pol == 3:
        polfctr = np.sqrt(0.5) * (-etheta.imag - ephi.real)
    elif pol == 4:
        polfctr = np.sqrt(0.5) * (-etheta.imag + ephi.real)
    elif pol == 5:
        cosa = np.cos(np.deg2rad(a))
        sina = np.sin(np.deg2rad(a))
        emajor = 1j * (-cosa * etheta + sina * ephi)
        eminor = 1j * (-sina * etheta - cosa * ephi)
        polfctr = -(emajor + (1j * r * eminor)).real / np.sqrt(1 + r**2)

    oerr = (np.sqrt(4 * np.pi) / k0) * polfctr

    return oerr

# far2d.log (private)
def _log_f2d(fp, freq, nth, nph, th, ph, pfar):
    """Write one frequency's far-field table to the open text stream ``fp``.

    Every grid point (theta index, phi index) gets one row with its angles in
    degrees, the 7 components of ``pfar`` converted to dB (after flooring at
    ``EPS``), and a last column equal to column 3 minus column 4 in dB
    (labelled AxialRatio in the header).
    """
    fp.write('frequency[Hz] = %.3e\n' % freq)
    fp.write(' No. No. theta[deg] phi[deg]   E-abs[dB]  E-theta[dB]    E-phi[dB]  E-major[dB]  E-minor[dB]   E-RHCP[dB]   E-LHCP[dB] AxialRatio[dB]\n')

    # 2 indices, 2 angles, 7 components and the axial ratio
    row_fmt = '%4d%4d %9.1f%9.1f' + '%13.4f' * 8 + '\n'

    def rows():
        for i in range(nth):
            for j in range(nph):
                db = 10 * np.log10(np.maximum(pfar[i, j, :], EPS))
                values = (i, j, th[i], ph[j], *(db[k] for k in range(7)), db[3] - db[4])
                yield row_fmt % values

    fp.writelines(rows())
