# -*- coding: utf-8 -*-
# solve.py

import numpy as np
import scipy as sp
import cupy as cp
import time
import sol.zmatrix, sol.planewave

# Speed of light in vacuum [m/s]; used to turn a frequency into a wavenumber.
C = 299_792_458.0
# Threshold below which the summed magnitude of the RHS vector counts as "no source".
EPS = 1.0e-10

# Compute the current distribution
def solve(GPU, c_dtype, Ne, Nfreq, Freq, Nfeed, Iplanewave, Planewave,
    E_posc, E_posm, E_posp, E_lng, E_tan, E_rad, E_ifeed, E_feed, E_iload, E_load, Iground):
    """Solve the dense linear system ``Z x = b`` once for every frequency.

    For each frequency the impedance matrix is assembled by
    ``sol.zmatrix.zmatrix`` and the right-hand side by ``_rhs``; the system
    is then solved either on the CPU (SciPy) or on the GPU (CuPy).

    Parameters
    ----------
    GPU : bool
        When true the system is solved with ``cupy.linalg.solve``,
        otherwise with ``scipy.linalg.solve``.
    c_dtype : dtype
        Complex dtype used for the work arrays ``Z`` and ``b``.
    Ne : int
        Order of the linear system (``Z`` is ``Ne x Ne``).
    Nfreq : int
        Number of frequencies to process.
    Freq : sequence
        Frequencies; ``Freq[ifreq]`` is used for iteration ``ifreq``.
    Nfeed, Iplanewave, Planewave, E_posc, E_posm, E_posp, E_lng, E_tan,
    E_rad, E_ifeed, E_feed, E_iload, E_load, Iground
        Forwarded unchanged to ``sol.zmatrix.zmatrix`` and/or ``_rhs``;
        their layout is defined by those routines.

    Returns
    -------
    Xc : numpy.ndarray
        Array of shape ``(Nfreq, Ne)`` and dtype ``complex128``; row
        ``ifreq`` holds the solution vector for ``Freq[ifreq]``.
    tcpu : list
        Two accumulated times in seconds: ``tcpu[0]`` for matrix and
        right-hand-side setup, ``tcpu[1]`` for the linear solve (on the GPU
        path this is measured with CUDA events and includes the
        host/device transfers).
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.linalg` is loaded on every SciPy version.
    from scipy import linalg as sp_linalg

    # Work arrays, reused for every frequency. They are handed to
    # zmatrix() / _rhs(), whose return values are not used, so those
    # routines are expected to fill them in place.
    Z = np.zeros((Ne, Ne), c_dtype)
    b = np.zeros(Ne, c_dtype)
    # Result: one solution vector per frequency, always double precision.
    Xc = np.zeros((Nfreq, Ne), 'c16')

    tcpu = [0] * 2

    if GPU:
        # CUDA events used to time the GPU solve
        start = cp.cuda.Event()
        end   = cp.cuda.Event()

    # loop over frequencies
    for ifreq in range(Nfreq):
        t0 = time.time()

        # progress display
        print('                   %d/%d' % (ifreq + 1, Nfreq))

        # compute the Z matrix
        sol.zmatrix.zmatrix(Z, Ne, E_posc, E_posm, E_posp, E_lng, E_tan, E_rad, E_iload, E_load, Iground, Freq[ifreq])

        # prepare the right-hand-side vector
        _rhs(b, Ne, Nfeed, Iplanewave, Planewave, E_posc, E_lng, E_tan, E_ifeed, E_feed, Iground, Freq[ifreq])

        t1 = time.time()
        tcpu[0] += t1 - t0

        # compute the current distribution
        if not GPU:
            # CPU
            x = sp_linalg.solve(Z, b)
            tcpu[1] += time.time() - t1
        else:
            # GPU: upload, solve, download; timed with CUDA events
            start.record()
            d_Z = cp.array(Z)
            d_b = cp.array(b)
            d_x = cp.linalg.solve(d_Z, d_b)
            x = cp.asnumpy(d_x)
            end.record()
            end.synchronize()
            # get_elapsed_time() reports milliseconds
            tcpu[1] += cp.cuda.get_elapsed_time(start, end) * 1e-3

        # Row assignment copies (and casts to complex128), so no explicit
        # copy of x is required.
        Xc[ifreq] = x

    return Xc, tcpu

# (private) fill the right-hand-side vector of the linear system
def _rhs(b, Ne, Nfeed, Iplanewave, Planewave, E_posc, E_lng, E_tan, E_ifeed, E_feed, Iground, freq):
    """Write the excitation into ``b`` in place.

    * ``Nfeed > 0`` (feed excitation): for every index ``ie`` with
      ``E_ifeed[ie] > 0`` the entry becomes
      ``E_feed[ie, 0] * exp(1j * deg2rad(E_feed[ie, 1]))``; all other
      entries of ``b`` are left untouched.
    * otherwise (plane-wave excitation): every entry is set to the dot
      product of the field returned by ``sol.planewave.planewave`` at
      ``E_posc[ie]`` with ``E_tan[ie]``, scaled by ``E_lng[ie, 0]``.

    Nothing is returned. If the summed magnitude of ``b`` is below ``EPS``
    an error message is printed and execution continues.
    """
    if Nfeed <= 0:
        # plane wave: wavenumber from the frequency
        wavenumber = 2 * np.pi * freq / C
        theta, phi = Planewave[0], Planewave[1]
        a, r = Planewave[2], Planewave[3]
        for idx in range(Ne):
            e_inc, _ = sol.planewave.planewave(
                E_posc[idx], theta, phi, Iplanewave, a, r, wavenumber, Iground)
            tangential = np.dot(e_inc, E_tan[idx])
            b[idx] = tangential * E_lng[idx, 0]
    else:
        # feed: magnitude and phase (given in degrees) per fed entry
        for idx in range(Ne):
            if not (E_ifeed[idx] > 0):
                continue
            phase = np.deg2rad(E_feed[idx, 1])
            b[idx] = E_feed[idx, 0] * np.exp(1j * phase)

    # warn when the excitation is (numerically) zero
    total = np.abs(b).sum()
    if total < EPS:
        print('*** error : no source.')
