from interface import interf_assemble

import numpy as np
import mpi4py.MPI as mpi
import scipy.sparse.linalg as splg

def factorize(A, ldir):
    """
    Compute the sparse LU factorisation of A with some rows replaced by
    identity rows, and return a function solving the modified system.

    Every row listed in ``ldir`` is replaced by the corresponding row of the
    identity matrix, so that for ``x = solve(bb)`` one has ``x[i] == bb[i]``
    for those rows while the remaining rows satisfy ``(A x)[j] == bb[j]``.

    Parameters
    ----------
    A: scipy.sparse matrix (any format; a CSR copy is made, A is not modified)
    ldir: iterable of row indices to replace by identity rows, or None
          (duplicated indices are harmless)

    Returns
    -------
    solve: callable, ``x = solve(bb)`` solves the modified system
    """
    # Work on a CSR copy: the row zeroing below relies on the CSR layout
    # (data/indptr describe rows), whatever format the caller passed, and
    # the caller's matrix must stay untouched.
    Alu = A.tocsr(copy=True)

    if ldir is not None:
        for i in ldir:
            # Zero every stored entry of row i, then put 1 on its diagonal.
            Alu.data[Alu.indptr[i]:Alu.indptr[i + 1]] = 0.
            Alu[i, i] = 1.

    # splg.factorized expects a CSC matrix.
    return splg.factorized(Alu.tocsc())

def cg(A, b, x0, ldir, xdir, tol = 1e-6, itmax = 500, withInfo=False, verbose=False):
    """
    Sequential conjugate gradient for A x = b with Dirichlet conditions.

    Input parameters
    ----------------

    A: system matrix, anything providing ``A.dot(v)`` (e.g. a scipy.sparse
       matrix); CG assumes it is symmetric positive definite on the
       non-Dirichlet unknowns
    b: right-hand side
    x0: initial guess (not modified, a copy is used)
    ldir: indices of the Dirichlet unknowns, or None
    xdir: values imposed on the Dirichlet unknowns
    tol: relative tolerance on ||g|| / ||g0||
          Default: 1e-6
    itmax: maximum number of iterations
    withInfo: if True, also return the list of relative residuals and the
              iteration counter
    verbose: print the iterations
             Default: False

    Returns
    -------
    x, or (x, residual, it) when withInfo is True
    """

    x = x0.copy()
    if ldir is not None:
        x[ldir] = xdir

    # Gradient g = A x - b; Dirichlet rows are zeroed so that those unknowns
    # are never updated.
    g = A.dot(x) - b
    if ldir is not None:
        g[ldir] = 0.

    g0 = np.sqrt(np.dot(g, g))
    if g0 == 0.:
        g0 = 1.

    if verbose:
        print('  || Ax0 - b || = %s' % g0)

    w = g.copy()

    it = 0
    residual = []

    while it < itmax:
        Aw = A.dot(w)
        if ldir is not None:
            Aw[ldir] = 0.

        Aww = np.dot(Aw, w)
        if Aww == 0.:
            # w == 0: the current x already solves the system; the step
            # length would be 0/0 and would turn x into NaN.
            break

        rho = -np.dot(g, w) / Aww

        x[:] = x + rho * w
        g[:] = g + rho * Aw

        res = np.sqrt(np.dot(g, g)) / g0
        residual.append(res)
        if verbose:
            print('  iteration %s residu -> %s' % (it, res))
        if res < tol:
            break

        # New direction, A-conjugate to the previous one.
        gamma = -np.dot(g, Aw) / Aww
        w[:] = g + gamma * w

        it += 1

    if withInfo:
        return x, residual, it
    else:
        return x

def cgPar(A, b, x0, ldir, xdir, interf, tol = 1e-6, itmax = 500, withInfo=False, verbose=False):
    """
    Parallel conjugate gradient (MPI).

    Each MPI process passes its local data (presumably one subdomain per
    process). Vectors prefixed with ``a_`` are assembled across the
    interfaces with ``interf_assemble``. A global scalar product is the sum
    over all processes of the local dot product between a local
    (non-assembled) vector and an assembled one.

    Input parameters
    ----------------

    A: local system matrix, anything providing ``A.dot(v)``
    b: local right-hand side
    x0: initial guess (not modified, a copy is used)
    ldir: indices of the local Dirichlet unknowns, or None
    xdir: values imposed on the Dirichlet unknowns
    tol: relative tolerance on the global residual norm
          Default: 1e-6
    interf: list of interfaces (passed to interf_assemble)
    itmax: maximum number of iterations
    withInfo: if True, also return the list of relative residuals and the
              iteration counter
    verbose: print the iterations (on rank 0 only)
             Default: False

    Returns
    -------
    x, or (x, residual, it) when withInfo is True
    """

    comm = mpi.COMM_WORLD
    rank = comm.Get_rank()

    def gdot(u, v):
        # Global scalar product: sum of the local dot products over all
        # processes. ``op`` is passed by keyword: the old positional
        # ``recvobj`` argument no longer exists in recent mpi4py releases.
        return comm.allreduce(np.dot(u, v), op=mpi.SUM)

    x = x0.copy()
    if ldir is not None:
        x[ldir] = xdir

    # Local gradient g and its assembled counterpart a_g (in-place assembly).
    g = A.dot(x) - b
    if ldir is not None:
        g[ldir] = 0.
    a_g = g.copy()
    interf_assemble(a_g, interf)

    g0 = np.sqrt(gdot(a_g, g))
    if g0 == 0.:
        g0 = 1.

    if verbose and rank == 0:
        print('|| Ax0 - b || = %s' % g0)

    w = a_g.copy()

    it = 0
    residual = []

    while it < itmax:
        Aw = A.dot(w)
        if ldir is not None:
            Aw[ldir] = 0.

        gw = gdot(g, w)
        Aww = gdot(Aw, w)
        if Aww == 0.:
            # w == 0 on every process (Aww is global, so all ranks agree):
            # the system is already solved and rho would be 0/0.
            break

        rho = -gw / Aww

        x[:] = x + rho * w
        g[:] = g + rho * Aw

        a_g = g.copy()
        interf_assemble(a_g, interf)

        res = np.sqrt(gdot(a_g, g)) / g0
        residual.append(res)
        if verbose and rank == 0:
            print('iteration %s residu -> %s' % (it, res))
        if res < tol:
            break

        gamma = -gdot(a_g, Aw) / Aww
        w[:] = a_g + gamma * w

        it += 1

    if withInfo:
        return x, residual, it
    else:
        return x


def cgPrecPar(A, b, x0, ldir, xdir, interf, tol = 1e-6, itmax = 50, withInfo=False, verbose=False):
    """
    Preconditioned parallel conjugate gradient (MPI).

    The interior unknowns are first eliminated by a local solve with the
    interface and Dirichlet values frozen. The preconditioned gradient is
    the solution of the same factorised problem whose interface values are
    the assembled gradient, with zero Dirichlet values and zero interior
    right-hand side.

    Input parameters
    ----------------

    A: local system matrix (scipy.sparse, it is factorised)
    b: local right-hand side
    x0: initial guess (not modified)
    ldir: indices of the local Dirichlet unknowns, or None
    xdir: values imposed on the Dirichlet unknowns
    tol: relative tolerance on the global residual norm
          Default: 1e-6
    interf: list of interfaces (each providing ``nodes2send``)
    itmax: maximum number of iterations
    withInfo: if True, also return the list of relative residuals and the
              iteration counter
    verbose: print the iterations (on rank 0 only)
             Default: False

    Returns
    -------
    x, or (x, residual, it) when withInfo is True
    """

    comm = mpi.COMM_WORLD
    rank = comm.Get_rank()

    def gdot(u, v):
        # Global scalar product: sum of the local dot products over all
        # processes (``op`` by keyword, see mpi4py Comm.allreduce).
        return comm.allreduce(np.dot(u, v), op=mpi.SUM)

    # Interface nodes of this subdomain, each listed once; empty when the
    # subdomain has no interface.
    lf = np.zeros(0, dtype=int)
    for i in interf:
        lf = np.concatenate((lf, i.nodes2send))
    lf = np.unique(lf)

    # Rows frozen in the factorised matrix: Dirichlet + interface nodes.
    if ldir is not None:
        ldirf = np.concatenate((np.asarray(ldir, dtype=int), lf))
    else:
        ldirf = lf

    solve = factorize(A, ldirf)

    def precond(assembled_g):
        # Solution equal to the assembled gradient on the interface, zero on
        # Dirichlet nodes, with a zero interior right-hand side.
        bb = np.zeros(b.size)
        bb[lf] = assembled_g[lf]
        return solve(bb)

    # Initial solution
    x = x0.copy()
    if ldir is not None:
        x[ldir] = xdir

    # Initial gradient, used as the reference norm g0
    g = A.dot(x) - b
    if ldir is not None:
        g[ldir] = 0.
    a_g = g.copy()
    interf_assemble(a_g, interf)
    g0 = np.sqrt(gdot(a_g, g))
    if g0 == 0.:
        g0 = 1.

    if verbose and rank == 0:
        print('|| Ax0 - b || = %s' % g0)

    # Make g = 0 on the interior nodes:
    # xi = (Aii)^-1 (bi - Aif.xf), interface and Dirichlet values frozen.
    bb = b.copy()
    bb[lf] = x[lf]
    if ldir is not None:
        bb[ldir] = xdir
    x = solve(bb)

    # Gradient after the interior solve
    g = A.dot(x) - b
    if ldir is not None:
        g[ldir] = 0.
    a_g = g.copy()
    interf_assemble(a_g, interf)

    w = precond(a_g)

    it = 0
    residual = []

    while it < itmax:
        Aw = A.dot(w)
        if ldir is not None:
            Aw[ldir] = 0.

        gw = gdot(g, w)
        Aww = gdot(Aw, w)
        if Aww == 0.:
            # w == 0 everywhere (Aww is global, so all ranks agree): the
            # interface gradient vanished and rho would be 0/0.
            break

        rho = -gw / Aww

        x[:] = x + rho * w
        g[:] = g + rho * Aw

        a_g = g.copy()
        interf_assemble(a_g, interf)

        res = np.sqrt(gdot(a_g, g)) / g0
        residual.append(res)
        if verbose and rank == 0:
            print('iteration %s residu -> %s' % (it, res))
        if res < tol:
            break

        Mg = precond(a_g)

        gamma = -gdot(Mg, Aw) / Aww
        w[:] = Mg + gamma * w

        it += 1

    if withInfo:
        return x, residual, it
    else:
        return x

def schur(A, b, x0, ldir, xdir, interf, tol = 1e-6, itmax = 50, withInfo=False, verbose=False):
    """
    Schur complement method (slides 9, 10 and 11 of F.X. Roux's course).
    Conjugate gradient on the interface problem S.xf = c.

    S.w is never formed: it is applied by solving the interior problem with
    interface values w (zero Dirichlet values, zero interior right-hand
    side) and taking the interface rows of A times that solution.

    Input parameters
    ----------------

    A: local system matrix (scipy.sparse, it is factorised)
    b: local right-hand side
    x0: initial guess (only its interface and Dirichlet values are used)
    ldir: indices of the local Dirichlet unknowns, or None
    xdir: values imposed on the Dirichlet unknowns
    tol: relative tolerance on the global residual norm
          Default: 1e-6
    interf: list of interfaces (each providing ``nodes2send``)
    itmax: maximum number of iterations
    withInfo: if True, also return the list of relative residuals and the
              iteration counter
    verbose: print the iterations (on rank 0 only)
             Default: False

    Returns
    -------
    x, or (x, residual, it) when withInfo is True
    """

    comm = mpi.COMM_WORLD
    rank = comm.Get_rank()

    def gdot(u, v):
        # Global scalar product: sum of the local dot products over all
        # processes (``op`` by keyword, see mpi4py Comm.allreduce).
        return comm.allreduce(np.dot(u, v), op=mpi.SUM)

    # Interface nodes of this subdomain, each listed once. The scalar
    # products below run over lf-indexed vectors, so a node appearing in
    # several interfaces (e.g. a cross point) must not be counted twice.
    lf = np.zeros(0, dtype=int)
    for i in interf:
        lf = np.concatenate((lf, i.nodes2send))
    lf = np.unique(lf)

    # Rows frozen in the factorised matrix: Dirichlet + interface nodes.
    if ldir is not None:
        ldirf = np.concatenate((np.asarray(ldir, dtype=int), lf))
    else:
        ldirf = lf

    solve = factorize(A, ldirf)

    def interior_solution(interface_values):
        # Solve Aii.xi = bi - Aif.xf with the interface and Dirichlet values
        # frozen; returns the full local vector.
        bb = b.copy()
        bb[lf] = interface_values
        if ldir is not None:
            bb[ldir] = xdir
        return solve(bb)

    # Initial global gradient, used as the reference norm g0
    x = x0.copy()
    if ldir is not None:
        x[ldir] = xdir
    g = A.dot(x) - b
    if ldir is not None:
        g[ldir] = 0.
    a_g = g.copy()
    interf_assemble(a_g, interf)

    g0 = np.sqrt(gdot(a_g, g))
    if g0 == 0.:
        g0 = 1.

    if verbose and rank == 0:
        print('|| Ax0 - b || = %s' % g0)

    # Initial interface solution (fancy indexing returns a copy)
    xf = x[lf]
    x = interior_solution(xf)

    # Interface gradient: gf is local, a_gf is assembled
    g = A.dot(x) - b
    if ldir is not None:
        g[ldir] = 0.
    gf = g[lf]
    a_g = g.copy()
    interf_assemble(a_g, interf)
    a_gf = a_g[lf]

    w = a_gf.copy()

    it = 0
    residual = []

    while it < itmax:

        # Sw: interface rows of A applied to the interior extension of w
        bb = np.zeros(b.size)
        bb[lf] = w
        xx = solve(bb)
        Axx = A.dot(xx)
        if ldir is not None:
            Axx[ldir] = 0.
        Sw = Axx[lf]

        gfw = gdot(gf, w)
        Sww = gdot(Sw, w)
        if Sww == 0.:
            # w == 0 everywhere (Sww is global, so all ranks agree): the
            # interface gradient vanished and rho would be 0/0.
            break

        rho = -gfw / Sww

        xf[:] = xf + rho * w
        gf[:] = gf + rho * Sw

        g[lf] = gf
        a_g = g.copy()
        interf_assemble(a_g, interf)
        a_gf = a_g[lf]

        res = np.sqrt(gdot(a_g, g)) / g0
        residual.append(res)
        if verbose and rank == 0:
            print('iteration %s residu -> %s' % (it, res))
        if res < tol:
            break

        gamma = -gdot(a_gf, Sw) / Sww
        w[:] = a_gf + gamma * w

        it += 1

    # Recover the interior unknowns from the final interface values
    x = interior_solution(xf)
    if withInfo:
        return x, residual, it
    else:
        return x

