import numpy as np
import mpi4py.MPI as mpi

class Interface:
    """
    classe interface
    ----------------

    domvois: domaine voisin
    nodes2send: liste des indices des points a envoyer
    nodes2update: liste des indices des points a mettre a jour
    recvbuf: buffer de reception
    sendbuf: buffer d'envoi

    """
    def __init__(self, nodes, domvois, label=None):
        self.domvois = domvois
        self.label = label
        self.nodes2send = np.asarray(nodes)
        self.nodes2update = np.asarray(nodes)
        self.recvbuf = np.zeros(nodes.shape)
        self.sendbuf = np.zeros(nodes.shape)

    def __str__(self):
        s = 'proc ' + str(mpi.COMM_WORLD.rank)
        s += ' -> Interface with domain ' + str(self.domvois)
        s += '\n\tnodes to send\n' + str(self.nodes2send)
        s += '\n\tnodes to update\n' + str(self.nodes2update)
        s += '\n\n'
        return s

    def __repr__(self):
        return self.__str__()



def build_interface(m):
    """
    Construction de la liste des interfaces 

    Parametre d'entree:
    -------------------

    m: un maillage

    Sortie:
    -------
    
    liste des interfaces

    """
    interf = []
    rank = mpi.COMM_WORLD.rank
    size = mpi.COMM_WORLD.size

    for i in xrange(size):
        nbnodes = np.array(0, dtype=np.int)
        if i == rank:
            nbnodes.fill(m.nodes.shape[0])
            
        # diffusion generale du nombre de points du sous-domaine i
        mpi.COMM_WORLD.Bcast([nbnodes, mpi.INT], root=i)
        node_listext = np.empty(nbnodes, dtype=np.int)

        if i == rank:
            node_listext[:] = m.loc2glo[:]

        # diffusion generale de la correspondance locale to globale des points du sous-domaine i
        mpi.COMM_WORLD.Bcast([node_listext, mpi.INT], root=i)
        
        # recherche des concordances
        if i != rank:
            intf_list(m.loc2glo, node_listext, interf,  i)
        
    return interf

def intf_list(node_list, node_listext, interf, proc):
    """
    Recherche si des points sont sur un sous-domaine voisin

    Parametres d'entree:
    --------------------

    node_list: liste des indices des points du sous-domaine
    node_listext: liste des indices des points d'un sous-domaine voisin
    interf: liste d'instances Interface
    proc: numero du voisin

    """

    nt = max(np.max(node_list), np.max(node_listext)) + 1
    mask = np.zeros(nt)
    mask[node_listext] = 1

    nn = np.sum(mask[node_list])
    
    if nn == 0:
        return
    
    ind = np.where(mask[node_list] > 0)[0]
    mask[ind] = -ind

    interf.append(Interface(ind, proc))
    
def intf_listb(node_list, border, node_listext, border_ext, interf, proc):
    """
    Recherche si des points sont sur un sous-domaine voisin

    Parametres d'entree:
    --------------------

    node_list: liste des indices des points du sous-domaine
    node_listext: liste des indices des points d'un sous-domaine voisin
    interf: liste d'instances Interface
    proc: numero du voisin

    """

    nt = max(np.max(node_list), np.max(node_listext)) + 1
    mask = np.zeros(nt)
    mask[node_listext] = 1

    nn = np.sum(mask[node_list])
    
    if nn == 0:
        return
    
    ind = np.where(mask[border] > 0)[0]
    mask[border[ind]] = -1
    ind = np.where(mask[node_list] < 0)[0]
    print 'border', mpi.COMM_WORLD.rank, proc, ind
    

def interf_assemble(x, interf):
    """
    Mise a jour des points interfaces

    Parametres d'entree:
    --------------------

    x: vecteur a mettre a jour
    interf: liste des interfaces

    """
    rank = mpi.COMM_WORLD.rank

    for i in interf:
        i.sendbuf[:] = x[i.nodes2send]
        mpi.COMM_WORLD.Isend([i.sendbuf, mpi.DOUBLE], dest = i.domvois, tag = 100*i.domvois + rank)

    for i in interf:
        mpi.COMM_WORLD.Recv([i.recvbuf, mpi.DOUBLE], source = i.domvois, tag = 100*rank + i.domvois)
        x[i.nodes2update] = x[i.nodes2update] + i.recvbuf[:]

    mpi.COMM_WORLD.Barrier()
