# -*- coding: utf-8 -*-
# make wire grid model

import math
import numpy as np
from numba import jit, prange

# Geometric tolerance: two points are treated as coincident when the sum of
# the absolute differences of their coordinates is below EPS.
EPS = 1e-8

@jit(cache=True, nopython=True, nogil=True, parallel=True)
#@jit(cache=True, nopython=True)
def makedata(
    Iradiusall, Radiusall, Ngeom, G_gtype, G_cosys, G_pos, G_div, G_ifeed, G_feed, G_iload, G_load, G_iradius, G_radius, G_offset, Iground):
    """Build the wire-element model from the geometry units.

    Each geometry unit ig is either
      * a line unit (G_gtype[ig] == 1): G_div[ig][0] segments from
        G_pos[ig][0] to G_pos[ig][1], or
      * a surface unit (G_gtype[ig] == 2): a wire grid over the quadrilateral
        G_pos[ig][0..3] with G_div[ig][0] x G_div[ig][1] cells.
    Node positions are interpolated in the unit's own coordinate system
    (G_cosys[ig]); _set_element converts them to XYZ and adds G_offset[ig].

    Candidate elements are dropped when they have (almost) zero length,
    duplicate an already accepted element, or, when Iground == 1, lie
    entirely in the z == 0 plane.  Feed/load data is attached only for line
    units, on their middle segment.

    Returns:
        (Ne, Nfeed, Nload, E_posc, E_posm, E_posp, E_lng, E_tan, E_rad,
         E_ifeed, E_feed, E_iload, E_load)
        Ne is the number of accepted elements and every E_* array has Ne
        rows.  E_lng[:, 0] is the element length; E_lng[:, 1] / E_lng[:, 2]
        are the mean lengths of all elements sharing the minus / plus node.
        Nfeed / Nload count the elements with E_ifeed > 0 / E_iload > 0.
    """

    # upper bound on the number of elements (before any are dropped)
    dim = 0
    for ig in range(Ngeom):
        n12 = G_div[ig][0]
        n14 = G_div[ig][1]
        if   G_gtype[ig] == 1:
            dim += n12
        elif G_gtype[ig] == 2:
            # (n14 + 1) lines along 1->2 and (n12 + 1) lines along 1->4
            dim += n12 * (n14 + 1) \
                 + n14 * (n12 + 1)

    # work arrays sized for the upper bound
    e_posc  = np.zeros((dim, 3), 'f8')  # center position
    e_posm  = np.zeros((dim, 3), 'f8')  # minus position
    e_posp  = np.zeros((dim, 3), 'f8')  # plus position
    e_lng   = np.zeros((dim, 3), 'f8')  # center/minus/plus length
    e_tan   = np.zeros((dim, 3), 'f8')  # tangential vector
    e_rad   = np.zeros( dim,     'f8')  # wire radius
    e_ifeed = np.zeros( dim,     'i4')  # 0/1
    e_feed  = np.zeros((dim, 2), 'f8')  # V, deg
    e_iload = np.zeros( dim,     'i4')  # 0/1/2/3
    e_load  = np.zeros( dim,     'f8')  # R/L/C

    # ne = number of accepted elements.  Slot ne is (re)written for every
    # candidate and is only kept when ne is advanced.
    ne = 0
    for ig in range(Ngeom):
        if G_gtype[ig] == 1:
            # line unit
            n12 = G_div[ig][0]
            for i12 in range(n12):
                # minus/plus node position
                a12m = (i12 - 0) / n12
                a12p = (i12 + 1) / n12
                cm = (1 - a12m) * G_pos[ig][0] \
                   + (    a12m) * G_pos[ig][1]
                cp = (1 - a12p) * G_pos[ig][0] \
                   + (    a12p) * G_pos[ig][1]

                # set wire element data
                e_posc[ne], e_posm[ne], e_posp[ne], e_lng[ne], e_tan[ne], e_rad[ne], e_ifeed[ne], e_feed[ne], e_iload[ne], e_load[ne] \
                    = _set_element(Iradiusall, Radiusall, G_cosys[ig], G_iradius[ig], G_radius[ig], G_offset[ig], cm, cp)

                # drop zero-length / duplicate / on-ground elements
                if not _keep_element(ne, e_lng, e_posm, e_posp, Iground):
                    continue

                # feed and load go on the middle segment (for even n12, the
                # segment that ends at the middle node)
                middle = (2 * i12 == n12 - 2) or (2 * i12 == n12 - 1)

                # feed
                if (G_ifeed[ig] == 1) and middle:
                    e_ifeed[ne] = G_ifeed[ig]
                    e_feed[ne] = G_feed[ig]

                # load
                if (G_iload[ig] > 0) and middle:
                    e_iload[ne] = G_iload[ig]
                    e_load[ne] = G_load[ig]

                # accept the element
                ne += 1

        elif (G_gtype[ig] == 2):
            # surface unit (bilinear map of corners 0,1,2,3)
            n12 = G_div[ig][0]
            n14 = G_div[ig][1]

            # (1) direction 1->2
            for i14 in range(n14 + 1):
                for i12 in range(n12):
                    # minus/plus node position
                    a12m = (i12 + 0) / n12
                    a12p = (i12 + 1) / n12
                    a14  = (i14 + 0) / n14
                    cm = (1 - a12m) * (1 - a14) * G_pos[ig][0] \
                       + (    a12m) * (1 - a14) * G_pos[ig][1] \
                       + (    a12m) * (    a14) * G_pos[ig][2] \
                       + (1 - a12m) * (    a14) * G_pos[ig][3]
                    cp = (1 - a12p) * (1 - a14) * G_pos[ig][0] \
                       + (    a12p) * (1 - a14) * G_pos[ig][1] \
                       + (    a12p) * (    a14) * G_pos[ig][2] \
                       + (1 - a12p) * (    a14) * G_pos[ig][3]

                    # set wire element data
                    e_posc[ne], e_posm[ne], e_posp[ne], e_lng[ne], e_tan[ne], e_rad[ne], e_ifeed[ne], e_feed[ne], e_iload[ne], e_load[ne] \
                        = _set_element(Iradiusall, Radiusall, G_cosys[ig], G_iradius[ig], G_radius[ig], G_offset[ig], cm, cp)

                    # drop zero-length / duplicate / on-ground elements
                    if not _keep_element(ne, e_lng, e_posm, e_posp, Iground):
                        continue

                    # accept the element
                    ne += 1

            # (2) direction 1->4
            for i12 in range(n12 + 1):
                for i14 in range(n14):
                    # minus/plus node position
                    a14m = (i14 + 0) / n14
                    a14p = (i14 + 1) / n14
                    a12  = (i12 + 0) / n12
                    cm = (1 - a14m) * (1 - a12) * G_pos[ig][0] \
                       + (    a14m) * (1 - a12) * G_pos[ig][3] \
                       + (    a14m) * (    a12) * G_pos[ig][2] \
                       + (1 - a14m) * (    a12) * G_pos[ig][1]
                    cp = (1 - a14p) * (1 - a12) * G_pos[ig][0] \
                       + (    a14p) * (1 - a12) * G_pos[ig][3] \
                       + (    a14p) * (    a12) * G_pos[ig][2] \
                       + (1 - a14p) * (    a12) * G_pos[ig][1]

                    # set wire element data
                    e_posc[ne], e_posm[ne], e_posp[ne], e_lng[ne], e_tan[ne], e_rad[ne], e_ifeed[ne], e_feed[ne], e_iload[ne], e_load[ne] \
                        = _set_element(Iradiusall, Radiusall, G_cosys[ig], G_iradius[ig], G_radius[ig], G_offset[ig], cm, cp)

                    # drop zero-length / duplicate / on-ground elements
                    if not _keep_element(ne, e_lng, e_posm, e_posp, Iground):
                        continue

                    # accept the element
                    ne += 1

    # number of elements
    Ne = ne

    # trim the work arrays to the Ne accepted elements (fresh copies)
    E_posc  = e_posc[:Ne].copy()   # center position
    E_posm  = e_posm[:Ne].copy()   # minus position
    E_posp  = e_posp[:Ne].copy()   # plus position
    E_lng   = e_lng[:Ne].copy()    # center/minus/plus length
    E_tan   = e_tan[:Ne].copy()    # tangential vector
    E_rad   = e_rad[:Ne].copy()    # wire radius
    E_ifeed = e_ifeed[:Ne].copy()  # 0/1
    E_feed  = e_feed[:Ne].copy()   # V, deg
    E_iload = e_iload[:Ne].copy()  # 0/1/2/3
    E_load  = e_load[:Ne].copy()   # R/L/C

    # element length at the minus node (column 1) and plus node (column 2).
    # Only column 0 is read, so the iterations are independent.
    for ie in prange(Ne):
        E_lng[ie, 1] = _node_length(E_posm[ie], E_posm, E_posp, E_lng)
        E_lng[ie, 2] = _node_length(E_posp[ie], E_posm, E_posp, E_lng)

    # number of feed points and number of loads
    Nfeed = np.sum(E_ifeed > 0)
    Nload = np.sum(E_iload > 0)

    return Ne, Nfeed, Nload, E_posc, E_posm, E_posp, E_lng, E_tan, E_rad, E_ifeed, E_feed, E_iload, E_load

# (private) decide whether candidate element ie is kept
@jit(cache=True, nopython=True)
def _keep_element(ie, e_lng, e_posm, e_posp, iground):
    """Return False if element ie must be dropped, True otherwise.

    An element is dropped when its length e_lng[ie, 0] is below EPS, when it
    coincides with one of the elements 0 .. ie-1 (see _count2), or, when
    iground == 1, when both of its end points have |z| < EPS.
    """
    # skip zero length
    if e_lng[ie, 0] < EPS:
        return False

    # skip double count (O(ie) scan of the accepted elements)
    if _count2(ie, e_posm, e_posp):
        return False

    # skip element lying on the ground plane
    if (iground == 1) and (abs(e_posm[ie, 2]) < EPS) and (abs(e_posp[ie, 2]) < EPS):
        return False

    return True

# (private) mean element length at a node
@jit(cache=True, nopython=True)
def _node_length(pos, E_posm, E_posp, E_lng):
    """Return the mean E_lng[:, 0] of all elements having an end point at pos.

    End points match when their L1 distance to pos is below EPS.  pos must be
    an end point of at least one element (always true when called from
    makedata, because the element itself matches); otherwise the division
    by zero raises.
    """
    count = 0
    total = 0.0
    for je in range(E_posm.shape[0]):
        dm = abs(E_posm[je, 0] - pos[0]) \
           + abs(E_posm[je, 1] - pos[1]) \
           + abs(E_posm[je, 2] - pos[2])
        dp = abs(E_posp[je, 0] - pos[0]) \
           + abs(E_posp[je, 1] - pos[1]) \
           + abs(E_posp[je, 2] - pos[2])
        if (dm < EPS) or (dp < EPS):
            count += 1
            total += E_lng[je, 0]
    return total / count

# (private) set wire element data
@jit(cache=True, nopython=True)
def _set_element(iradiusall, radiusall, cosys, iradius, radius, offset, cm, cp):
    """Build the data of one wire element from its minus/plus nodes.

    cm / cp are the minus / plus node positions in coordinate system
    ``cosys`` (see _conv2xyz); they are converted to XYZ and shifted by
    ``offset``.

    Wire radius:
        iradius == 1                      -> radius
        iradius == 0 and iradiusall == 1  -> radiusall
        iradius == 0 and iradiusall == 2  -> radiusall * element length
        iradius == 0 otherwise            -> 0.2 * element length

    Returns:
        (eposc, eposm, eposp, elng, etan, erad, ifeed, efeed, iload, eload)
        elng is [length, 0, 0]; etan is the unit tangent from minus to plus
        (zero vector for a zero-length element); the feed/load fields are
        zero-initialised.

    Raises:
        ValueError: if iradius is neither 0 nor 1.
    """
    # node positions in XYZ
    eposm = _conv2xyz(cosys, cm) + offset
    eposp = _conv2xyz(cosys, cp) + offset
    eposc = (eposm + eposp) / 2

    # center length (minus/plus node lengths stay 0 here)
    elng = np.zeros(3, 'f8')
    elng[0] = math.sqrt(
        (eposp[0] - eposm[0])**2 +
        (eposp[1] - eposm[1])**2 +
        (eposp[2] - eposm[2])**2)

    # tangent vector; avoid a 0/0 NaN for degenerate (zero-length) elements
    if elng[0] > 0:
        etan = (eposp - eposm) / elng[0]
    else:
        etan = np.zeros(3, 'f8')

    # radius (every accepted iradius value defines erad; others are rejected
    # instead of leaving erad undefined)
    if iradius == 1:
        erad = radius
    elif iradius == 0:
        if iradiusall == 1:
            erad = radiusall
        elif iradiusall == 2:
            erad = radiusall * elng[0]
        else:
            erad = 0.2 * elng[0]
    else:
        raise ValueError("iradius must be 0 or 1")

    # initialize feed/load
    ifeed = 0
    efeed = np.zeros(2, 'f8')
    iload = 0
    eload = 0.0

    return eposc, eposm, eposp, elng, etan, erad, ifeed, efeed, iload, eload

# (private) coordinate conversion to XYZ
@jit(cache=True, nopython=True)
def _conv2xyz(cosys, pos):
    """Convert the 3-vector ``pos`` from coordinate system ``cosys`` to XYZ.

    cosys == 1 : (x, y, z)                  -> returned as a copy
    cosys == 2 : (r, phi [deg], z)          cylindrical
    cosys == 3 : (r, theta [deg], phi [deg]) spherical, theta measured
                 from the +z axis
    Any other cosys yields the zero vector.
    """
    # Cartesian input: nothing to convert
    if cosys == 1:
        return np.copy(pos)

    out = np.zeros(3, 'f8')

    if cosys == 2:
        phi = np.deg2rad(pos[1])
        out[0] = pos[0] * np.cos(phi)
        out[1] = pos[0] * np.sin(phi)
        out[2] = pos[2]
    elif cosys == 3:
        theta = np.deg2rad(pos[1])
        phi = np.deg2rad(pos[2])
        rho = pos[0] * np.sin(theta)  # distance from the z axis
        out[0] = rho * np.cos(phi)
        out[1] = rho * np.sin(phi)
        out[2] = pos[0] * np.cos(theta)

    return out

# (private) check double count of elements
#@jit(cache=True, nopython=True, nogil=True, parallel=True)
@jit(cache=True, nopython=True)
def _count2(ie, E_posm, E_posp):
    """Return True if element ie coincides with one of elements 0 .. ie-1.

    Two elements coincide when their end points agree within EPS (sum of
    absolute coordinate differences) in either orientation: minus/minus and
    plus/plus, or minus/plus and plus/minus.  Returns False for ie <= 0.
    """
    if ie <= 0:
        return False

    m = E_posm[ie]
    p = E_posp[ie]

    for k in range(ie):
        qm = E_posm[k]
        qp = E_posp[k]

        # same orientation
        mm = abs(m[0] - qm[0]) + abs(m[1] - qm[1]) + abs(m[2] - qm[2])
        pp = abs(p[0] - qp[0]) + abs(p[1] - qp[1]) + abs(p[2] - qp[2])
        if mm < EPS and pp < EPS:
            return True

        # reversed orientation
        mp = abs(m[0] - qp[0]) + abs(m[1] - qp[1]) + abs(m[2] - qp[2])
        pm = abs(p[0] - qm[0]) + abs(p[1] - qm[1]) + abs(p[2] - qm[2])
        if mp < EPS and pm < EPS:
            return True

    return False
