import mpi4py.MPI as mpi
import matplotlib.pylab as plt
import numpy as np
import vtk

def plotSolution_vtk(u, domain):
    """
    Display the global solution with VTK.

    Every MPI process writes its local piece of the solution to
    ``temperature_piece<rank>.vts``. After a barrier, process 0 reads all
    pieces back and renders them in one interactive window: a
    semi-transparent coloured surface, 30 iso-contours per piece and a
    colour bar, all on the same global scalar range.

    Must be called collectively by every process of COMM_WORLD (it
    performs allreduce and barrier operations).

    Parameters:
        u: local solution, a flat array of nxNoD * nyNoD values stored
           with the x index varying fastest (value (i, j) is u[i + j*nxNoD])
        domain: domain object providing I, J, lx, ly, nxNoD, nyNoD
                (presumably the subdomain indices, lengths and node counts
                in x and y -- confirm against the domain class)
    """
    comm = mpi.COMM_WORLD

    # Global extrema, so every piece is coloured on the same scale.
    # Passing `op=` by keyword works with both mpi4py 1.x
    # (allreduce(sendobj, recvobj, op)) and mpi4py >= 2 (allreduce(sendobj, op)).
    maxu = comm.allreduce(u.max(), op=mpi.MAX)
    minu = comm.allreduce(u.min(), op=mpi.MIN)

    # VTK: each process writes its own piece.
    _write_vts_piece(u, domain, "temperature_piece" + str(comm.rank) + ".vts")
    # All pieces must be on disk before process 0 starts reading them.
    comm.barrier()

    # VTK: process 0 reads every piece and draws the image.
    if comm.rank == 0:
        filenames = ["temperature_piece" + str(r) + ".vts"
                     for r in range(comm.size)]
        _render_vts_pieces(filenames, minu, maxu)


def _vtk_set_input_data(algorithm, dataset):
    """Attach a data object as input using the VTK >= 6 or the VTK 5 API."""
    if hasattr(algorithm, "SetInputData"):
        algorithm.SetInputData(dataset)
    else:
        algorithm.SetInput(dataset)


def _write_vts_piece(u, domain, filename):
    """Write this process's piece of the solution as a structured grid (.vts)."""
    nx, ny = domain.nxNoD, domain.nyNoD
    x = np.linspace(domain.lx * domain.I, domain.lx * domain.I + domain.lx, nx)
    y = np.linspace(domain.ly * domain.J, domain.ly * domain.J + domain.ly, ny)

    points = vtk.vtkPoints()
    temperature = vtk.vtkFloatArray()
    temperature.SetName("Temperature")
    # vtkStructuredGrid expects points with the i (x) index varying fastest,
    # which is also the storage order of u.
    for j in range(ny):
        for i in range(nx):
            points.InsertNextPoint(x[i], y[j], 0.)
            temperature.InsertNextValue(u[i + j * nx])

    grid = vtk.vtkStructuredGrid()
    grid.SetExtent(0, nx - 1, 0, ny - 1, 0, 0)
    grid.SetPoints(points)
    grid.GetPointData().SetScalars(temperature)

    writer = vtk.vtkXMLStructuredGridWriter()
    writer.SetFileName(filename)
    # writer.SetDataModeToAscii()  # handy when debugging the file contents
    _vtk_set_input_data(writer, grid)
    writer.Write()


def _render_vts_pieces(filenames, minu, maxu):
    """Read the .vts pieces and show them together in one interactive window."""
    window = vtk.vtkRenderWindow()
    window.SetSize(800, 500)
    window.SetPosition(50, 50)
    window.SetWindowName('Temperature')
    renderer = vtk.vtkRenderer()

    # A single lookup table shared by every mapper and by the colour bar, so
    # all pieces and the legend use exactly the same colour mapping.
    lut = vtk.vtkLookupTable()
    lut.SetTableRange(minu, maxu)
    lut.Build()

    for filename in filenames:
        reader = vtk.vtkXMLStructuredGridReader()
        reader.SetFileName(filename)
        reader.Update()

        # Iso-contours of this piece.
        contour = vtk.vtkContourFilter()
        contour.SetInputConnection(reader.GetOutputPort())
        contour.GenerateValues(30, minu, maxu)
        contour_mapper = vtk.vtkPolyDataMapper()
        contour_mapper.SetInputConnection(contour.GetOutputPort())
        contour_mapper.SetLookupTable(lut)
        contour_mapper.SetScalarRange(minu, maxu)
        contour_actor = vtk.vtkActor()
        contour_actor.SetMapper(contour_mapper)
        renderer.AddActor(contour_actor)

        # Semi-transparent coloured surface of this piece.
        plane_mapper = vtk.vtkDataSetMapper()
        plane_mapper.SetInputConnection(reader.GetOutputPort())
        plane_mapper.SetLookupTable(lut)
        plane_mapper.SetScalarRange(minu, maxu)
        plane_actor = vtk.vtkActor()
        plane_actor.SetMapper(plane_mapper)
        plane_actor.GetProperty().SetOpacity(0.4)
        renderer.AddActor(plane_actor)

    colorbar = vtk.vtkScalarBarActor()
    colorbar.SetLookupTable(lut)
    colorbar.SetTitle('Temperature')
    renderer.AddActor(colorbar)

    window.AddRenderer(renderer)
    interactor = vtk.vtkRenderWindowInteractor()
    interactor.SetRenderWindow(window)
    interactor.Initialize()
    # Blocks until the user closes the window.
    interactor.Start()

 


def plotSolution(u, domain):
    """
    Display the global solution with matplotlib.

    Every process sends its local solution and grid coordinates to
    process 0. Process 0 draws all pieces as filled contours on a common
    colour scale and shows the figure.

    Must be called collectively by every process of COMM_WORLD (it
    performs gather and allreduce operations).

    Parameters:
        u: local solution, a flat array of nxNoD * nyNoD values stored
           with the x index varying fastest (reshaped to (nyNoD, nxNoD))
        domain: domain object providing I, J, lx, ly, nxNoD, nyNoD
                (presumably the subdomain indices, lengths and node counts
                in x and y -- confirm against the domain class)
    """
    comm = mpi.COMM_WORLD

    # Collect every local solution on process 0 (None on the other ranks).
    # Keyword arguments work with both mpi4py 1.x and mpi4py >= 2, whose
    # positional signatures differ.
    U = comm.gather(u, root=0)

    # Global extrema, so every piece uses the same colour levels.
    maxu = comm.allreduce(u.max(), op=mpi.MAX)
    minu = comm.allreduce(u.min(), op=mpi.MIN)

    I, lx, nxNoD = domain.I, domain.lx, domain.nxNoD
    J, ly, nyNoD = domain.J, domain.ly, domain.nyNoD

    x = np.linspace(lx * I, lx * I + lx, nxNoD)
    y = np.linspace(ly * J, ly * J + ly, nyNoD)

    # Collect the local grids on process 0.
    X = comm.gather(x, root=0)
    Y = comm.gather(y, root=0)

    # Only process 0 displays the global solution.
    if comm.rank != 0:
        return

    # contourf requires strictly increasing levels; widen a degenerate range
    # (e.g. a constant solution) instead of crashing.
    if maxu <= minu:
        minu, maxu = minu - 0.5, maxu + 0.5
    levels = np.linspace(minu, maxu, 60)

    plt.figure()
    # plt.hold() was removed in matplotlib 3.0. Successive contourf calls on
    # the same axes already overlay, so no hold call is needed.
    filled = None
    for xi, yi, ui in zip(X, Y, U):
        xt, yt = np.meshgrid(xi, yi)
        filled = plt.contourf(xt, yt, np.asarray(ui).reshape(xt.shape), levels)
    plt.colorbar(filled)
    plt.show()
