# -*- coding: utf-8 -*-
# zmatrix.py

import math
import numpy as np
from numba import jit, prange

# Threshold on kwave * r used by _green: at or below it, _green switches from
# the exp(-jkr)/(4*pi*r) form to its near-zero-separation formula.
EPS = 1e-6
# Speed of light in vacuum [m/s].
C = 2.99792458e8
# Free-space wave impedance eta0 = c * mu0, with mu0 = 4*pi*1e-7 (about 376.73 ohm).
ETA0 = C * 4 * math.pi * 1e-7

# Impedance matrix
@jit(cache=True, nopython=True, nogil=True, parallel=True)
# Serial build for debugging: @jit(cache=True, nopython=True)
def zmatrix(Z, Ne, E_posc, E_posm, E_posp, E_lng, E_tan, E_rad, E_iload, E_load, Iground, freq):
    """Fill the impedance matrix ``Z`` in place for frequency ``freq``.

    Steps:
      1. ``_zmatrix`` writes the lower triangle (with diagonal) for the
         free-space term.
      2. When ``Iground > 0``, ``_zmatrix`` is called again in ground-image
         mode, which subtracts the image contribution.
      3. The upper triangle is copied from the lower one (Z is symmetric).
      4. Lumped R/L/C loads are added to the diagonal.

    Parameters
    ----------
    Z : ndarray
        Output matrix, modified in place. Presumably complex with shape
        (Ne, Ne), since complex load terms are added to its diagonal.
        TODO confirm against the caller.
    Ne : int
        Number of elements.
    E_posc, E_posm, E_posp, E_lng, E_tan, E_rad :
        Per-element geometry, forwarded unchanged to ``_zmatrix``.
    E_iload : array of int
        Load kind per element: 1 = R, 2 = L, 3 = C. Any other value means
        no load.
    E_load : array
        Load value per element. It is read as R, L or C depending on
        ``E_iload``.
    Iground : int
        The ground-image term is included when this is > 0.
    freq : float
        Frequency. It is used to form omega = 2*pi*freq and
        kwave = omega / C.
    """
    omega = 2 * math.pi * freq
    kwave = omega / C

    # Lower triangle + diagonal: free-space term, then the ground image.
    # The -j*ETA0/kwave scale factor is applied inside _zmatrix.
    _zmatrix(Z, 0, Ne, kwave, E_posc, E_posm, E_posp, E_lng, E_tan, E_rad)
    if Iground > 0:
        _zmatrix(Z, 1, Ne, kwave, E_posc, E_posm, E_posp, E_lng, E_tan, E_rad)

    # Symmetric matrix: copy each lower-triangle entry to its upper mirror.
    # Writes touch only the upper triangle and reads only the lower, so the
    # parallel iterations are independent.
    for col in prange(Ne):
        for row in range(col):
            Z[row, col] = Z[col, row]

    # Lumped loads on the diagonal.
    for n in range(Ne):
        kind = E_iload[n]
        value = E_load[n]
        if kind == 1:  # R
            Z[n, n] += value
        elif kind == 2:  # L: +j*omega*L
            Z[n, n] += 1j * omega * value
        elif kind == 3 and abs(value) > 1e-16:  # C: -j/(omega*C); C ~ 0 is skipped
            Z[n, n] += -1j / (omega * value)

# NOTE: The triple-quoted string below is a disabled, earlier per-element
# implementation (`_zelement`). It is a bare string expression, so it is never
# compiled or executed. `zmatrix` calls `_zmatrix` instead, and `_zmatrix`
# evaluates the same per-element expression inline, scaled by -j*ETA0/kwave.
"""
# (private) 行列要素
@jit(cache=True, nopython=True)
def _zelement(gnd, i, j, kwave, E_posc, E_posm, E_posp, E_lng, E_tan, E_rad):

    # グラウンド因子
    sgn = -1 if gnd else 1

    # X/Y/Z座標(修正されるのでcopy)
    pci = E_posc[i].copy()
    pcj = E_posc[j].copy()
    pmi = E_posm[i].copy()
    pmj = E_posm[j].copy()
    ppi = E_posp[i].copy()
    ppj = E_posp[j].copy()

    # グラウンド:Z座標反転
    pcj[2] *= sgn
    pmj[2] *= sgn
    ppj[2] *= sgn

    # 要素長 (中心/-節点/+節点)
    lci = E_lng[i, 0]
    lcj = E_lng[j, 0]
    lmi = E_lng[i, 1]
    lmj = E_lng[j, 1]
    lpi = E_lng[i, 2]
    lpj = E_lng[j, 2]

    # 導線半径
    rwire = (E_rad[i] + E_rad[j]) / 2

    # ij座標差
    dcc = pci - pcj
    dmm = pmi - pmj
    dmp = pmi - ppj
    dpm = ppi - pmj
    dpp = ppi - ppj

    # ij距離
    rcc = math.sqrt(dcc[0]**2 + dcc[1]**2 + dcc[2]**2)
    rmm = math.sqrt(dmm[0]**2 + dmm[1]**2 + dmm[2]**2)
    rmp = math.sqrt(dmp[0]**2 + dmp[1]**2 + dmp[2]**2)
    rpm = math.sqrt(dpm[0]**2 + dpm[1]**2 + dpm[2]**2)
    rpp = math.sqrt(dpp[0]**2 + dpp[1]**2 + dpp[2]**2)

    # Green関数
    gcc = _green(rcc, lci, lcj, kwave, rwire)
    gmm = _green(rmm, lmi, lmj, kwave, rwire)
    gmp = _green(rmp, lmi, lpj, kwave, rwire)
    gpm = _green(rpm, lpi, lmj, kwave, rwire)
    gpp = _green(rpp, lpi, lpj, kwave, rwire)

    aip = (E_tan[i, 0] * E_tan[j, 0]) \
        + (E_tan[i, 1] * E_tan[j, 1]) \
        + (E_tan[i, 2] * E_tan[j, 2]) * sgn

    wss = (kwave * lci) * (kwave * lcj)

    return gpp - gpm - gmp + gmm - (wss * aip * gcc)
"""

# (private) Green's function
@jit(cache=True, nopython=True)
def _green(r, s1, s2, kwave, rwire):
    """Return the scalar Green's function value for separation ``r``.

    Regular case (``kwave * r > EPS``)::

        (cos(kr) - j*sin(kr)) / (4*pi*r)      i.e. exp(-j*k*r) / (4*pi*r)

    Otherwise, at (near-)zero separation, the closed-form value below is
    returned instead::

        asinh(a) / (2*pi*ds) - j*k / (4*pi)

    Here ``ds = (s1 + s2) / 2`` is the mean of the two segment lengths and
    ``a = ds / (2 * rwire)``. asinh(a) is evaluated as
    ``log(a + sqrt(a**2 + 1))``.
    """
    kr = kwave * r
    if not (kr > EPS):
        # (Near-)coincident points: closed-form term from segment length and wire radius.
        mean_len = (s1 + s2) / 2
        ratio = mean_len / (2 * rwire)
        real_part = np.log(ratio + np.sqrt(ratio**2 + 1)) / (2 * np.pi * mean_len)
        return real_part - 1j * kwave / (4 * np.pi)
    return (np.cos(kr) - 1j * np.sin(kr)) / (4 * np.pi * r)

# (private) Matrix elements: lower triangle (including the diagonal) of Z.
@jit(cache=True, nopython=True)
def _zmatrix_distance(pa, i, pb, j, sgn):
    """Return the distance between point pa[i] and point pb[j] with pb[j]'s z scaled by sgn.

    With sgn = -1, pb[j] is mirrored about the z = 0 plane (ground image).
    The function works on scalars only, so it allocates nothing and is safe
    to call from independent ``prange`` iterations. Components are promoted
    to float64, so the arithmetic is float64 regardless of the input dtype.
    """
    dx = float(pa[i, 0]) - float(pb[j, 0])
    dy = float(pa[i, 1]) - float(pb[j, 1])
    dz = float(pa[i, 2]) - float(pb[j, 2]) * sgn
    return math.sqrt(dx**2 + dy**2 + dz**2)


@jit(cache=True, nopython=True, nogil=True, parallel=True)
# Serial build for debugging: @jit(cache=True, nopython=True)
def _zmatrix(Z, gnd, Ne, kwave, E_posc, E_posm, E_posp, E_lng, E_tan, E_rad):
    """Write the lower triangle (with diagonal) of the impedance matrix ``Z``.

    For each pair (i, j) with j <= i, the element value is::

        -j*ETA0/kwave * (gpp - gpm - gmp + gmm - k^2 * lci * lcj * (t_i . t_j) * gcc)

    Here each g?? is ``_green`` evaluated between the center (c), minus-node
    (m) or plus-node (p) points of elements i and j.

    Parameters
    ----------
    Z : ndarray
        Matrix modified in place. Only entries with j <= i are touched.
        Presumably complex with shape (Ne, Ne); TODO confirm against the caller.
    gnd : int
        0 selects the free-space term, and ``Z[i, j]`` is assigned.
        Nonzero selects the ground image: element j's z coordinates and the z
        part of the tangent product are negated, and the result is
        subtracted from ``Z[i, j]``.
    Ne : int
        Number of elements.
    kwave : float
        Wavenumber k.
    E_posc, E_posm, E_posp : 2-D arrays
        Per-element center, minus-node and plus-node coordinates
        (x, y, z in columns 0..2).
    E_lng : 2-D array
        Per-element lengths. Column 0 is the center, column 1 the minus node
        and column 2 the plus node.
    E_tan : 2-D array
        Per-element tangent vector components (columns 0..2).
    E_rad : 1-D array
        Per-element wire radius.
    """
    # Ground factor: -1 mirrors element j about the z = 0 plane (image).
    sgn = -1 if gnd else 1

    # Common scale factor applied to every element.
    cfctr = -1j * ETA0 / kwave

    # Every prange iteration uses only its own scalar locals. Scratch arrays
    # allocated outside the parallel loop would be shared by all threads and
    # overwritten concurrently (a data race), so none are used here.
    for i in prange(Ne):
        for j in range(i + 1):  # lower triangle + diagonal
            # Element lengths (center / minus node / plus node)
            lci = E_lng[i, 0]
            lcj = E_lng[j, 0]
            lmi = E_lng[i, 1]
            lmj = E_lng[j, 1]
            lpi = E_lng[i, 2]
            lpj = E_lng[j, 2]

            # Wire radius: mean of the two elements' radii
            rwire = (E_rad[i] + E_rad[j]) / 2

            # Distances between element i's points and element j's (possibly mirrored) points
            rcc = _zmatrix_distance(E_posc, i, E_posc, j, sgn)
            rmm = _zmatrix_distance(E_posm, i, E_posm, j, sgn)
            rmp = _zmatrix_distance(E_posm, i, E_posp, j, sgn)
            rpm = _zmatrix_distance(E_posp, i, E_posm, j, sgn)
            rpp = _zmatrix_distance(E_posp, i, E_posp, j, sgn)

            # Green's function values
            gcc = _green(rcc, lci, lcj, kwave, rwire)
            gmm = _green(rmm, lmi, lmj, kwave, rwire)
            gmp = _green(rmp, lmi, lpj, kwave, rwire)
            gpm = _green(rpm, lpi, lmj, kwave, rwire)
            gpp = _green(rpp, lpi, lpj, kwave, rwire)

            # Tangent dot product. The z term takes the ground sign for the image.
            aip = (E_tan[i, 0] * E_tan[j, 0]) \
                + (E_tan[i, 1] * E_tan[j, 1]) \
                + (E_tan[i, 2] * E_tan[j, 2]) * sgn

            wss = (kwave * lci) * (kwave * lcj)

            z = cfctr * (gpp - gpm - gmp + gmm - (wss * aip * gcc))

            if gnd:
                # Ground image: subtract from the free-space value already stored.
                Z[i, j] -= z
            else:
                Z[i, j] = z