from buildMatrix import buildA, buildAt, buildb, setDirichletCondition, initsol
from scipy.sparse.linalg import dsolve
import numpy as np
import mpi4py.MPI as mpi
import math

def solveMonoDomain(domain, problem, fct_scd_membre, fct_bords, uinit=None):
    """
    Compute a single-domain solution.

    If ``problem.q == 0`` the system A u = b is solved once, after the
    Dirichlet conditions have been applied to b.  Otherwise ``domain.nt``
    steps are performed: at each step the right-hand side
    r = At^H u + b is built from the previous iterate, the Dirichlet
    conditions are applied to r, and (A + At) u = r is solved.

    Parameters:
    -----------
        domain : domainClass
            number of points, space and time steps, ...

        problem : problemClass
            physical parameters and boundary conditions

        fct_scd_membre : function
            function initialising the right-hand side (passed to buildb)

        fct_bords : function
            boundary function passed to setDirichletCondition

        uinit : optional
            initial condition passed to initsol; only used when
            problem.q != 0

    Output:
    -------
        the computed solution, as returned by spsolve (the last step
        when problem.q != 0)

    """
    ### Construct the linear system (A is needed in both cases)
    A = buildA(domain, problem)

    if problem.q == 0:
        b = buildb(domain, problem.bc, fct_scd_membre)
        setDirichletCondition(b, domain, problem, fct_bords)
        return dsolve.spsolve(A, b)

    At = buildAt(domain, problem)
    A = A + At

    b = buildb(domain, problem.bc, fct_scd_membre)

    u = initsol(domain, problem.bc, uinit)

    # At^H, built once outside the loop.  This is the product the original
    # At.rmatvec(u) computed, but sparse matrices do not provide rmatvec in
    # recent SciPy releases.
    AtH = At.conj().transpose()
    for i in range(domain.nt):
        # Single-argument print: same output on Python 2 and Python 3.
        print('time iteration -> %s' % i)
        r = AtH.dot(u) + b
        setDirichletCondition(r, domain, problem, fct_bords)
        u = dsolve.spsolve(A, r)
    return u

def solveMultiDomain(domain, problem, fct_scd_membre, fct_bords, uinit=None,
                     tol=1e-6, maxiter=None):
    """
    Compute the solution on the local subdomain with a Schwarz iteration.

    Each MPI process solves the system of its own subdomain.  Interface
    values are then exchanged with the neighbouring subdomains through
    updateInterfaces, and the local system is solved again.  This repeats
    until the global relative change between two successive iterates drops
    to ``tol`` or below.

    Parameters:
    -----------
        domain : domainClass
            number of points, space and time steps, neighbours and
            interfaces of the local subdomain, ...

        problem : problemClass
            physical parameters and boundary conditions

        fct_scd_membre : function
            function initialising the right-hand side (passed to buildb)

        fct_bords : function
            boundary function passed to setDirichletCondition

        uinit : optional
            unused here; kept so the signature matches solveMonoDomain

        tol : float, optional
            stopping threshold on the relative residual, i.e. the global
            (all-process) 2-norm of u^k - u^(k-1) divided by that of the
            first iteration (default 1e-6)

        maxiter : int or None, optional
            maximum number of Schwarz iterations; None (default) means
            no limit

    Output:
    -------
        the local part of the solution, as returned by spsolve

    Raises:
    -------
        NotImplementedError if problem.q != 0 (only the stationary problem
        is handled)

    """
    if problem.q != 0:
        # Previously this fell through to `return u` with u unbound.
        raise NotImplementedError(
            "solveMultiDomain only handles problem.q == 0")

    comm = mpi.COMM_WORLD
    rank = comm.Get_rank()

    ### Construct the linear system
    A = buildA(domain, problem)
    b = buildb(domain, problem.bc, fct_scd_membre)
    setDirichletCondition(b, domain, problem, fct_bords)
    u = dsolve.spsolve(A, b)
    uold = u.copy()

    residu = 1.
    residu0 = None
    nbite = 0
    # residu and nbite are identical on every rank (residu comes from an
    # allreduce), so all ranks leave the loop together.  This matters because
    # updateInterfaces communicates with the neighbours.
    while residu > tol and (maxiter is None or nbite < maxiter):
        r = b + updateInterfaces(u, domain, problem)
        u = dsolve.spsolve(A, r)

        diff = u - uold
        # Keyword `op` works with both old and new mpi4py signatures.
        norm = math.sqrt(comm.allreduce(np.dot(diff, diff), op=mpi.SUM))
        if residu0 is None:
            residu0 = norm
        # A zero first change means u is already a fixed point: converged.
        residu = norm/residu0 if residu0 > 0. else 0.

        nbite += 1
        if rank == 0:
            print("iteration %s -> residual = %s" % (nbite, residu))
        uold[:] = u
    return u

def updateInterfaces(u, domain, problem):
    """
    Exchange interface values with the neighbouring subdomains.

    For each interface of the local subdomain (``domain.bordInterface``,
    paired with the neighbour rank in ``domain.voisin``), the row or column
    of ``u`` next to that boundary is sent to the neighbour.  The values
    received from the neighbour are scaled by nu/dx^2 or nu/dy^2 and added
    to the boundary row or column of the returned vector.

    ``u`` is indexed as i + nx*j (nx points per row).  Boundary ids:
        0 : receives into column 0,    sends column 1
        1 : receives into row 0,       sends row 1
        2 : receives into column nx-1, sends column nx-2
        3 : receives into row ny-1,    sends row ny-2

    All ranks of COMM_WORLD must call this function, because it ends with a
    barrier.

    Parameters:
    -----------
        u : numpy array of size nxNoD*nyNoD (float64, sent as MPI.DOUBLE)
            current local iterate

        domain : domainClass
            grid sizes (nxNoD, nyNoD), steps (dx, dy), neighbour ranks
            (voisin) and matching interface ids (bordInterface)

        problem : problemClass
            provides the coefficient nu

    Output:
    -------
        numpy array of size nxNoD*nyNoD holding the interface contribution
        (zero away from the interfaces)

    Raises:
    -------
        ValueError for an interface id other than 0, 1, 2 or 3

    """
    comm = mpi.COMM_WORLD
    rank = comm.Get_rank()

    vois = domain.voisin
    bint = domain.bordInterface

    nx, ny = domain.nxNoD, domain.nyNoD
    n = nx*ny
    dx2 = domain.dx*domain.dx
    dy2 = domain.dy*domain.dy

    # interface id -> (entries of u sent to the neighbour,
    #                  entries of res receiving the neighbour's values,
    #                  number of received values, squared step)
    layout = {
        0: (slice(1, n, nx), slice(0, n, nx), ny, dx2),
        1: (slice(nx, 2*nx), slice(0, nx), nx, dy2),
        2: (slice(nx - 2, n, nx), slice(nx - 1, n, nx), ny, dx2),
        3: (slice(nx*(ny - 2), nx*(ny - 1)), slice(nx*(ny - 1), n), nx, dy2),
    }

    # Validate before posting any message so no request is left dangling.
    for b in bint:
        if b not in layout:
            raise ValueError("unknown interface id %r (expected 0, 1, 2 or 3)"
                             % (b,))

    res = np.zeros(n)

    # Post every send first (non-blocking) so that the blocking receives
    # below cannot deadlock.  The buffers stay referenced in `sendbufs`
    # until the requests are completed.
    sendbufs = []
    requests = []
    for b, v in zip(bint, vois):
        sendbuf = np.ascontiguousarray(u[layout[b][0]])
        sendbufs.append(sendbuf)
        requests.append(comm.Issend([sendbuf, mpi.DOUBLE],
                                    dest=v,
                                    tag=100*v + rank))

    for b, v in zip(bint, vois):
        _, recv_slice, size, d2 = layout[b]
        recvbuf = np.zeros(size)
        comm.Recv([recvbuf, mpi.DOUBLE],
                  source=v,
                  tag=rank*100 + v)
        res[recv_slice] += recvbuf*problem.nu/d2

    # Complete the synchronous sends.  Without this the MPI requests leak and
    # the send buffers are still formally owned by MPI after returning.
    mpi.Request.Waitall(requests)
    comm.Barrier()
    return res
