import solver as Solver
import numpy  as np
import numpy.matlib

def simple(operator, origin, radius,
           nodes=100, maxIt=10, lStart=1, lStep=1, rankTol=1e-4, verbose=True):
    """Solves an eigenvalue problem using Beyn's algorithm (simple version)

    A random matrix VHat with l columns is drawn and the contour integral A0
    is computed on a circle. l is increased by lStep until A0 becomes rank
    deficient (rank k < l). The k dominant singular triplets of A0 and the
    first-order integral A1 are then used to build a small k x k matrix B.
    The eigenvalues of B are returned.

    Keyword arguments:
    operator -- the solver defining the operator to use; it must provide
                size(), solve(b, w) and solution()
    origin -- the origin (in the complex plane) of the above circular contour
    radius -- the radius of the circular contour used to search the eigenvalues
    nodes -- the number of nodes for the trapezoidal integration rule (optional)
    maxIt -- the maximal number of iteration for constructing A0 (optional)
    lStart -- the number of columns used for A0 when algorithm starts (optional)
    lStep -- the step used for increasing the number of columns of A0 (optional)
    rankTol -- the tolerance for the rank test (optional)
    verbose -- should I be verbose? (optional)

    Returns the computed eigenvalues: a 1-D numpy array of length k, where k
    is the numerical rank of A0.

    Raises RuntimeError if the rank of A0 is zero, if it equals the operator
    size, or if no rank-deficient A0 is found within maxIt iterations.
    """

    # Display the parameters used
    if verbose:
        display(nodes, maxIt, lStart, lStep, rankTol, origin, radius)

    # Initialise A0 search
    myPath = path(nodes, origin, radius)
    hasK   = False
    it     = 0
    m      = operator.size()
    l      = lStart
    k      = -1

    # Search a rank-deficient A0 (more columns l than its rank k)
    if verbose:
        print("Searching A0...")
    while not hasK and it < maxIt:
        if verbose:
            print(" # Iteration: " + str(it + 1))

        vHat = randomMatrix(m, l)                       # Take a random VHat
        A0   = integrate(operator, myPath, 0, vHat)     # Compute A0
        k    = np.linalg.matrix_rank(A0, tol=rankTol)   # Rank test

        if k == 0:                                      # Rank is zero?
            raise RuntimeError(zeroRankErrorStr())      #  -> Stop
        elif k == m:                                    # Maximum rank reached?
            raise RuntimeError(maxRankErrorStr())       #  -> Stop
        elif k < l:                                     # Found a null SV?
            hasK = True                                 #  -> We have A0
        else:                                           # Matrix is full rank?
            l = l + lStep                               #  -> Increase A0 size

        it += 1                                         # Keep on searching A0

    # Fail only when no suitable A0 was found: succeeding on the very last
    # allowed iteration is a success, not a failure
    if not hasK:
        raise RuntimeError(maxItErrorStr())
    if verbose:
        print("Constructing linear EVP...")

    # Compute V, S and Wh
    #  NB: For SVD(A) = V*S*Wh, numpy computes {v, s, w}, such that:
    #      v = V; diag(s) = S and w = Wh
    V, S, Wh = np.linalg.svd(A0, full_matrices=False, compute_uv=True)

    # Keep the k dominant singular triplets (numpy sorts s in descending
    # order). Using k rather than dropping only the last one handles a rank
    # drop larger than one, and the case l > m where V has only m columns.
    V0    = V[:, :k]
    W0    = Wh[:k, :].H
    S0Inv = np.matrix(np.diag(1.0 / S[:k]))

    # Compute A1 (with the same VHat as A0) and the k x k matrix B
    A1 = integrate(operator, myPath, 1, vHat)
    B  = V0.H * A1 * W0 * S0Inv

    # Eigenvalues of B (eigenvectors are not needed)
    if verbose:
        print("Solving linear EVP...")
    myLambda = np.linalg.eig(B)[0]

    # Done
    if verbose:
        print("Done!")
    return myLambda


## Public API: only simple is exported, everything below is a helper
__all__ = ["simple"]


## Helper functions
def path(nodes, origin, radius):
    """Returns the points of a discretised circular contour

    Keyword arguments:
    nodes -- the number of segments used to discretise the contour
    origin -- the origin (in the complex plane) of the circular contour
    radius -- the radius of the circular contour

    The n-th point is origin + radius * exp(2*pi*i * n * (1/nodes)) for
    n = 0, ..., nodes. The returned list therefore holds nodes+1 points, and
    its first and last points coincide (up to machine precision), closing
    the contour.
    """
    dTheta = 1.0 / nodes
    return [origin + radius * np.exp(1j * 2 * np.pi * n * dTheta)
            for n in range(nodes + 1)]


def integrate(operator, path, order, vHat):
    """Computes the countour integral of Beyn's method, that is matrix A_p

    At every point z of path, the integrand is multiSolve(operator, vHat, z)
    multiplied by z**order. Consecutive points are joined by straight
    segments.

    NB: the result is not normalised. Each segment contributes
    (F(z_i) + F(z_i+1)) * (z_i+1 - z_i), which is twice the trapezoidal
    rule, and no 1/(2*pi*i) factor is applied. A common scale factor on A0
    and A1 does not change the eigenvalues computed from them. It does
    scale the singular values that are compared with a rank tolerance.

    Keyword arguments:
    operator -- the solver defining the operator to use
    path -- the sequence of points to integrate on
    order -- the order of Beyn's integral (that is the 'p' defining matrix A_p)
    vHat -- the RHS matrix defining Beyn's integral

    Returns a complex numpy matrix with the shape of vHat
    """

    def integrand(z):
        # Solutions for all columns of vHat at z, weighted by z**order
        values  = multiSolve(operator, vHat, z)
        values *= np.power(z, order)
        return values

    result = np.matlib.zeros(vHat.shape, dtype=complex)

    zPrev = path[0]
    fPrev = integrand(zPrev)
    for zNext in path[1:]:
        fNext   = integrand(zNext)
        result += (fPrev + fNext) * (zNext - zPrev)
        zPrev, fPrev = zNext, fNext

    return result


def multiSolve(solver, B, w):
    """Solves for every column of B (one solve per right-hand side)

    Keyword arguments:
    solver -- the solver to use; it must provide size(), solve(b, w) and
              solution()
    B -- the matrix of RHS, one RHS per column
    w -- the complex frequency to use

    Returns a complex numpy matrix of shape (solver.size(), number of columns
    of B), whose column i holds solver.solution() after solving column i
    """
    nRhs = B.shape[1]
    X    = np.matlib.empty((solver.size(), nRhs), dtype=complex)

    for col in range(nRhs):
        # solution() is read right after solve(), so the solver is expected
        # to keep the latest solution internally
        solver.solve(B[:, col], w)
        X[:, col] = solver.solution()

    return X


def randomMatrix(n, m):
    """Returns a random complex matrix of the given size (n x m)

    The real and imaginary parts are two independent draws of
    numpy.matlib.rand (the real part is drawn first).
    """
    realPart = np.matlib.rand(n, m)
    imagPart = np.matlib.rand(n, m)
    return realPart + imagPart * 1j


def display(nodes, maxIt, lStart, lStep, rankTol, origin, radius):
    """Prints a summary of the parameters used by Beyn's method (simple)

    Keyword arguments:
    nodes -- the number of nodes for the trapezoidal integration rule
    maxIt -- the maximal number of iteration for constructing A0
    lStart -- the number of columns used for A0 when algorithm starts
    lStep -- the step used for increasing the number of columns of A0
    rankTol -- the tolerance for the rank test
    origin -- the origin (in the complex plane) of the circular contour
    radius -- the radius of the circular contour
    """
    separator = "---------------------------------------"

    def item(label, value):
        # Labels are padded so that every value starts in the same column
        print((" # " + label).ljust(39) + " " + value)

    # float() accepts plain Python numbers as well as numpy scalars / 0-d
    # arrays. With recent numpy, np.real/np.imag of a plain Python number
    # return a Python float, which has no .tolist() method.
    originStr = ("(" + format(float(np.real(origin)), '+.2e') + ") + (" +
                 format(float(np.imag(origin)), '+.2e') + ")j")

    print("Beyn's contour integral method (simple)")
    print(separator)
    item("Nodes used for the trapezoidal rule:", str(nodes))
    item("Maximum number of iterations:", str(maxIt))
    item("Initial size of col(A0):", str(lStart))
    item("Step size for col(A0):", str(lStep))
    item("Rank test tolerance:", format(rankTol, '.2e'))
    print(separator)
    item("Circular path origin:", originStr)
    item("Circular path radius:", format(radius, '+.2e'))
    print(separator)


def zeroRankErrorStr():
    """Returns the error message raised when A0 has a rank of zero

    A zero rank most likely means the contour encloses no eigenvalue.
    """
    message = ("Found a rank of zero: "
               "the contour is probably enclosing no eigenvalues")
    return message


def maxRankErrorStr():
    """Returns a string explaining the probable reason of a maximal rank

    A rank equal to the operator size suggests the contour encloses more
    eigenvalues than the method can resolve.
    """
    return ("Maximal rank found: " +
            "the contour is probably enclosing too many eigenvalues")


def maxItErrorStr():
    """Returns the error message used when the maximum number of iterations
    is reached without finding a rank-deficient A0"""
    return "Maximum number of iterations reached!"