# -*- coding: utf-8 -*- 
# --------------------------------------------------------------------
__author__ = 'Loic Gouarin'
__all__ = ['Mesh']
__docformat__ = 'restructuredtext'
# --------------------------------------------------------------------

import numpy as np
import mpi4py.MPI as mpi

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm

# Element type description.
# Keys are the element-type ids used as keys of Mesh.elements
# (originally borrowed from the gmsh numbering).
_dico_element = {1: 'line',
                 2: 'quadrangle'}

# Number of nodes for each element type, keyed like _dico_element so that
# _dico_elemsize[k] is valid for every k in _dico_element.
# NOTE(review): gmsh itself numbers the 4-node quadrangle as type 3, which is
# presumably why this table used the key 3; that key is kept as an alias so
# that any existing lookup with 3 keeps working -- confirm and drop if unused.
_dico_elemsize = {1: 2,
                  2: 4,
                  3: 4}

class Mesh:
    """
    Structured mesh built from 1-D coordinate arrays.

    A mesh is defined by

        - ndim : dimension of the mesh (number of coordinate arrays given
          to fromCoords),
        - nodes : array of node coordinates, one row per node and three
          columns (coordinates that are not provided stay at 0),
        - elements : dictionary keyed by element type

                     element type:
                         - 1: line,
                         - 2: quadrangle.

                     Each value is itself a dictionary. When it is empty
                     there is no element of that type. Otherwise it holds
                     one key per region, and the value is the connectivity
                     array (one row per element) of that region.

    """
    def __init__(self):
        """Create an empty mesh: no nodes and no element of any type."""
        self.nodes = None
        # one (initially empty) region dictionary per known element type
        self.elements = {}
        for k in _dico_element:
            self.elements[k] = {}
        # last wireframe drawn by plotSolution (None until the first plot)
        self.oldcol = None

    def fromCoords(self, x, y=None, npx=1, npy=1, overlap=0):
        """
        Build the mesh from the 1-D coordinate arrays x and y.

        The nodes and elements are stored on the instance (self.nodes,
        self.elements, self.nx, self.ny, self.ndim); nothing is returned.

        Input parameters:
        -----------------

        x: points in the x direction

        y: points in the y direction
           default: None (self.ny is then not set)

        npx: number of subdivisions along x
             default: 1
             (accepted but not used by this method)

        npy: number of subdivisions along y
             default: 1
             (accepted but not used by this method)

        overlap: default: 0
                 (accepted but not used by this method)

        Example:
        --------

            import mesh
            import numpy as np

            m = mesh.Mesh()
            m.fromCoords(np.linspace(0, 1, 10), np.linspace(0, 1, 10))

        """
        coords = (np.asarray(x, dtype=np.double),)
        self.nx = coords[0].size

        if y is not None:
            coords += (np.asarray(y, dtype=np.double),)
            self.ny = coords[1].size

        self.ndim = len(coords)

        # number of points per direction and total number of nodes
        shape = [c.size for c in coords]
        size = 1
        for n in shape:
            size *= n

        self._setNodes(coords, shape, size)
        self._setElements(shape, size)

    def numberOfNodes(self):
        """
        Return the number of nodes of the mesh.

        :Output parameters:

            - number of nodes (rows of self.nodes)

        """
        return self.nodes.shape[0]

    def numberOfElements(self, elementType=None):
        """
        Return the number of elements of the mesh.

        :Input parameter:

            - elementType : if None, count the elements of every type,
                            otherwise count only the elements whose type
                            is elementType

                          - 1: line,
                          - 2: quadrangle.

        :Output parameters:

            - number of elements, summed over all the regions

        A KeyError is raised when elementType is not a key of
        self.elements.

        """
        if elementType is None:
            groups = list(self.elements.values())
        else:
            groups = [self.elements[elementType]]

        s = 0
        for regions in groups:
            # an empty region dictionary simply contributes nothing
            for connectivity in regions.values():
                s += connectivity.shape[0]
        return s

    def _setNodes(self, coords, shape, size):
        """
        Fill self.nodes with the tensor product of the coordinate arrays.

        The index grid is built on the reversed shape, so the first
        coordinate (x) varies fastest: in 2-D, node k = j*nx + i has the
        coordinates (x[i], y[j], 0).
        """
        nn = np.indices(shape[-1::-1])
        self.nodes = np.zeros((size, 3), dtype=np.double)

        for i, c in enumerate(coords):
            # nn[-1 - i] is the index grid running along direction i
            self.nodes[:, i] = c[nn[-1 - i]].flatten()

    def _setElements(self, shape, size):
        """
        Build the quadrangles (type 2) and the boundary lines (type 1).

        Quadrangles are stored in region 1. Boundary lines are stored in
        four regions:

            - 1: first row of nodes (nodes 0 .. shape[0]-1),
            - 2: last node of every row,
            - 3: region 1 mirrored through ``size - 1 - index`` (last row,
                 traversed backwards),
            - 4: region 2 mirrored the same way (first node of every row,
                 traversed backwards).

        All connectivity arrays end up C-contiguous with dtype int32.
        """
        # lower-left corner of every cell: every node except those of the
        # last row and those closing a row
        nn = np.arange(size - shape[0])
        nn = nn[(nn + 1) % shape[0] != 0]

        # np.int32 instead of the removed np.int alias; the array was
        # converted to int32 below anyway
        quad = np.empty((nn.size, 4), dtype=np.int32)
        quad[:, 0] = nn
        quad[:, 1] = nn + 1
        quad[:, 2] = nn + shape[0] + 1
        quad[:, 3] = nn + shape[0]

        self.elements[2][1] = quad

        line = np.arange(shape[0])
        edges = np.asarray([line[:-1], line[1:]]).transpose()
        self.elements[1][1] = edges
        self.elements[1][3] = size - 1 - edges

        line = np.arange(shape[0] - 1, size, shape[0])
        edges = np.asarray([line[:-1], line[1:]]).transpose()
        self.elements[1][2] = edges
        self.elements[1][4] = size - 1 - edges

        for regions in self.elements.values():
            for r in list(regions.keys()):
                regions[r] = np.ascontiguousarray(regions[r], np.int32)

    def _asGrid(self, values, nx, ny):
        """
        Reshape a per-node array into a 2-D grid of shape (ny, nx).

        Nodes are numbered with x varying fastest (see _setNodes), so the
        row-major grid has ny rows of nx values.
        """
        return np.asarray(values).reshape(ny, nx)

    def _clearCollections(self):
        """Remove every collection currently drawn on self.ax."""
        # Artist.remove() is used rather than mutating ax.collections,
        # which recent matplotlib releases expose as a read-only sequence.
        for artist in list(self.ax.collections):
            artist.remove()

    def showMeshCarac(self):
        """
        Print the number of nodes and the number of elements of the mesh.

        """
        print('Number of nodes: %d' % self.numberOfNodes())
        print('Number of elements: %d' % self.numberOfElements())

    def show(self, withLabel=False, onlyBorder=False):
        """
        Display the mesh with vtk.

        The meshes of all the MPI processes are gathered on rank 0, which
        hands them to vtktools.vtkRepresentation (wireframe mode).

        """
        import vtktools

        M = mpi.COMM_WORLD.gather(self, 0)
        if mpi.COMM_WORLD.rank == 0:
            vtktools.vtkRepresentation(M, wireframe=True, withLabel=withLabel, onlyBorder=onlyBorder)

    def plotSolution(self, x, animated=False, parallel=True):
        """
        Plot the solution x as a 3-D wireframe with matplotlib (rank 0 only).

        :Input parameters:

            - x : values of the solution, one per node of the local mesh
            - animated : if True, matplotlib interactive mode is switched
                         on when the figure is created
            - parallel : if True, the nodes and solutions of all the MPI
                         processes are gathered on rank 0 and drawn on the
                         same axes (collective call); otherwise rank 0
                         draws its own mesh only

        The figure is created on the first call; on the following calls the
        previous wireframe(s) are replaced by the new ones.

        """
        rank = mpi.COMM_WORLD.rank
        if self.oldcol is None and rank == 0:
            if animated:
                plt.ion()
            self.wframe = None
            self.fig = plt.figure()
            self.ax = Axes3D(self.fig)
            # recent matplotlib releases no longer attach an Axes3D to its
            # figure automatically
            if self.ax not in self.fig.axes:
                self.fig.add_axes(self.ax)

        if parallel:
            M = mpi.COMM_WORLD.gather(self.nodes, 0)
            Nx = mpi.COMM_WORLD.gather(self.nx, 0)
            Ny = mpi.COMM_WORLD.gather(self.ny, 0)
            SOL = mpi.COMM_WORLD.gather(x, 0)
            if rank == 0:
                # pyplot.hold only exists in old matplotlib releases
                if hasattr(plt, 'hold'):
                    plt.hold(True)

                # drop the previous plot once, before drawing, so that the
                # wireframes of all the subdomains stay visible
                if self.oldcol is not None:
                    self._clearCollections()

                for nodes, nx, ny, sol in zip(M, Nx, Ny, SOL):
                    X = self._asGrid(nodes[:, 0], nx, ny)
                    Y = self._asGrid(nodes[:, 1], nx, ny)
                    Z = self._asGrid(sol, nx, ny)
                    self.wframe = self.ax.plot_wireframe(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet)

                self.oldcol = self.wframe

                plt.draw()

        else:
            if rank == 0:
                X = self._asGrid(self.nodes[:, 0], self.nx, self.ny)
                Y = self._asGrid(self.nodes[:, 1], self.nx, self.ny)
                Z = self._asGrid(x, self.nx, self.ny)
                self.wframe = self.ax.plot_wireframe(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet)
                if self.oldcol is not None:
                    self.oldcol.remove()

                self.oldcol = self.wframe

                plt.draw()

    def showSolution(self, x, parallel=True, saveInFile=None):
        """
        Display the solution x on the mesh with vtk (with a color bar).

        :Input parameters:

            - x : values of the solution on the local mesh
            - parallel : if True, the meshes and solutions of all the MPI
                         processes are gathered on rank 0 before display
                         (collective call); otherwise rank 0 displays its
                         own mesh and solution only
            - saveInFile : forwarded unchanged to
                           vtktools.vtkRepresentation

        """
        import vtktools

        if parallel:
            M = mpi.COMM_WORLD.gather(self, 0)
            SOL = mpi.COMM_WORLD.gather(x, 0)
            if mpi.COMM_WORLD.rank == 0:
                vtktools.vtkRepresentation(M, SOL, withColorBar=True, saveInFile=saveInFile)
        else:
            if mpi.COMM_WORLD.rank == 0:
                vtktools.vtkRepresentation([self], [x], withColorBar=True, saveInFile=saveInFile)

if __name__ == '__main__':
    # Small driver: build a 101 x 101 mesh of the unit square, report how
    # long the construction took and display the mesh with vtk.
    # The print calls are written so that they produce the same output with
    # the Python 2 print statement and the Python 3 print function.
    import time
    import sys

    if len(sys.argv) != 3:
        print("usage: mpiexec -np 4 python mesh.py 2 2")
        sys.exit(0)

    nbpts = 101
    x = np.linspace(0., 1., nbpts)

    m = Mesh()
    t1 = time.time()

    # number of subdivisions along x and y, read from the command line
    npx = int(sys.argv[1])
    npy = int(sys.argv[2])
    m.fromCoords(x, x, npx, npy)
    print('execution time %s' % (time.time() - t1))

    m.show()


        
