"""
cim.py, a non-linear eigenvalue solver.
Copyright (C) 2017 N. Marsic, F. Wolf, S. Schoeps and H. De Gersem,
Institut fuer Theorie Elektromagnetischer Felder (TEMF),
Technische Universitaet Darmstadt.

See the LICENSE.txt and README.md for more license and copyright information.
"""

import getdp as GetDP
import numpy as np

class GetDPWave:
    """A GetDP solver using complex arithmetic for time-harmonic wave problems

    Only one instance of this class makes sense: every GetDP variable is set
    and read through module-level GetDP functions, hence shared between all
    instances.

    The .pro file should use the following variables:
    angularFreqRe -- the real part of the angular frequency to use
    angularFreqIm -- the imaginary part of the angular frequency to use
    nRHS -- the number of right-hand sides that should be considered (default: 1)
    b_i / b~{i} -- the ith right-hand side (RHS) to use (first index is 0)
    x_i / x~{i} -- the ith solution vector computed by GetDP (first index is 0)
    doInit -- should some initialization be done (default: 0)?
    doSolve -- should Ax_i = b_i be solved for all i (default: 0)?
    doPostpro -- should only a view be created for x (default: 0)?
    doApply -- should only the application of x be computed (default: 0)?
    fileName -- post-processing file name

    For every GetDP run, exactly one of the do* flags is set to 1 and all the
    others are explicitly reset to 0.

    Complex vectors are exchanged with GetDP as flat lists of floats that
    interleave real and imaginary parts: [Re(v0), Im(v0), Re(v1), ...].
    """

    # Task flags read by the .pro file (see __run)
    __FLAGS = ("doInit", "doSolve", "doPostpro", "doApply")

    def __init__(self, pro, mesh, resolution, optional=None):
        """Instantiates a new GetDPWave and runs a full '-solve'

        Keyword arguments:
        pro -- the .pro file to use
        mesh -- the .msh file to use
        resolution -- the resolution (from the .pro file) to use
        optional -- optional arguments for GetDP (default: None, meaning no
                    optional argument)

        Optional argument has the following structure:
        ["GetDP option", "value 1", "value 2", ..., "GetDP option", ...]
        """
        # Save
        self.__pro        = pro
        self.__mesh       = mesh
        self.__resolution = resolution
        # Copy: never share a mutable default, and ignore later changes the
        # caller may make to its own list
        self.__optional   = [] if optional is None else list(optional)

        # Generate DoFs and assemble a first system
        self.__run("doInit", ["-solve", self.__resolution, "-v", "2"])

    def solution(self, i):
        """Returns the solution for the ith RHS (first index is 0)

        The result is a complex column vector of shape (n, 1).
        """
        return self.__toNumpy(GetDP.GetDPGetNumber("x_" + str(i)))

    def size(self):
        """Returns the number of degrees of freedom

        Computed as the length of the vector currently stored in x_0.
        """
        return self.solution(0).shape[0]

    def apply(self, x, w):
        """Applies x to the operator with a complex angular frequency w

        x is written into x_0 before the run. This method updates
        self.solution().
        """
        GetDP.GetDPSetNumber("angularFreqRe", np.real(w).tolist())
        GetDP.GetDPSetNumber("angularFreqIm", np.imag(w).tolist())
        GetDP.GetDPSetNumber("x_0", self.__toGetDP(x))
        self.__run("doApply", ["-cal"])

    def solve(self, b, w):
        """Solves with b as RHS and w as complex angular frequency

        b is a matrix and each column is a different RHS; a 1D array is
        treated as a single RHS. The solution for RHS i is then available
        through solution(i).
        """
        # Accept a single RHS given as a 1D array by making it a column
        b = np.asarray(b)
        if b.ndim == 1:
            b = b[:, np.newaxis]

        # Number of RHS
        nRHS = b.shape[1]

        # Set variables
        GetDP.GetDPSetNumber("nRHS", nRHS)
        GetDP.GetDPSetNumber("angularFreqRe", np.real(w).tolist())
        GetDP.GetDPSetNumber("angularFreqIm", np.imag(w).tolist())

        # Set RHS
        for i in range(nRHS):
            GetDP.GetDPSetNumber("b_" + str(i), self.__toGetDP(b[:, i]))

        # Solve
        self.__run("doSolve", ["-cal"])

    def view(self, x, fileName):
        """Generates a post-processing view

        Keyword arguments:
        x -- the solution vector to use (written into x_0)
        fileName -- the post-processing file name

        This method generates a linear system
        """
        GetDP.GetDPSetNumber("x_0", self.__toGetDP(x))
        GetDP.GetDPSetString("fileName", fileName)
        self.__run("doPostpro", ["-cal"])

    def __run(self, flag, command):
        """Raises one task flag, resets the others and runs GetDP

        Keyword arguments:
        flag -- the task flag to set to 1 (one of __FLAGS)
        command -- GetDP arguments inserted between the mesh and the
                   optional arguments (e.g. ["-cal"])
        """
        # The flags are documented as defaulting to 0. Reset them explicitly
        # so a flag raised by an earlier call cannot leak into this run.
        for name in self.__FLAGS:
            GetDP.GetDPSetNumber(name, 1 if name == flag else 0)

        GetDP.GetDP(["getdp", self.__pro,
                     "-msh",  self.__mesh] + command + self.__optional)

    @staticmethod
    def __toNumpy(vGetDP):
        """Takes a GetDP list and returns a numpy array

        vGetDP interleaves real and imaginary parts. The result is a complex
        column vector of shape (vGetDP.size() // 2, 1).
        """
        # Integer division: '/' yields a float on Python 3, which is neither
        # a valid array shape nor a valid range() bound
        size   = int(vGetDP.size()) // 2
        vNumpy = np.empty((size, 1), dtype=complex)

        for i in range(size):
            vNumpy[i, 0] = complex(vGetDP[2*i], vGetDP[2*i + 1])

        return vNumpy

    @staticmethod
    def __toGetDP(vNumpy):
        """Takes a numpy array and returns a GetDP list

        vNumpy holds one value per row: a 1D array, or a column vector such
        as the one returned by solution(). The result is a list of floats
        interleaving real and imaginary parts.
        """
        # Flatten to one complex value per row without scalar-converting
        # 1-element arrays (deprecated in NumPy)
        values = np.asarray(vNumpy, dtype=complex)
        values = values.reshape(values.shape[0])

        vGetDP = np.empty(2 * values.shape[0])
        vGetDP[0::2] = values.real
        vGetDP[1::2] = values.imag

        return vGetDP.tolist()