# Public API of this module. The only public plotting entry point defined
# here is ``vtkRepresentation``; listing a name that does not exist would make
# ``from <module> import *`` fail with AttributeError.
__all__ = ['vtkRepresentation']
__docformat__ = 'restructuredtext'
import vtk
import random
import numpy as np

# Mesh element type -> VTK cell class used to build the cell.
# Kept in sync with ``_vtk_element_size`` below (same keys).
_vtk_element = {0: vtk.vtkVertex,
                1: vtk.vtkLine,
                2: vtk.vtkQuad}

# Mesh element type -> number of nodes per element (connectivity row length).
_vtk_element_size = {0: 1,
                     1: 2,
                     2: 4}


def addPoint(nodes, grid, shift=0):
    """Copy node coordinates into a new ``vtkPoints`` and attach it to *grid*.

    :param nodes: array-like of shape ``(nbVert, 2)`` or ``(nbVert, >=3)``.
        Only the first three columns are used; a 2-column (planar) array gets
        ``z = 0``.
    :param grid: VTK point set (e.g. ``vtkUnstructuredGrid``) whose points are
        replaced.
    :param shift: offset added to every point id. Node ``i`` is stored at id
        ``i + shift``; ids ``0 .. shift-1`` are allocated but left unset.
    """
    coords = np.asarray(nodes, dtype=float)
    nbVert = coords.shape[0]
    if coords.ndim == 2 and coords.shape[1] == 2:
        # Planar mesh: pad a zero z column so every point is 3-D.
        coords = np.column_stack((coords, np.zeros(nbVert)))

    points = vtk.vtkPoints()
    # Preallocate the final size (highest id written is nbVert + shift - 1),
    # then fill in place. An explicit loop is used because map() is lazy on
    # Python 3 and would never insert anything.
    points.SetNumberOfPoints(nbVert + shift)
    for pid, (x, y, z) in enumerate(coords[:, :3], start=shift):
        points.SetPoint(pid, x, y, z)
    grid.SetPoints(points)

def addElements(elements, eltype, grid, shift=0):
    """Insert every element of one type into *grid* as VTK cells.

    :param elements: mapping whose values are iterables of connectivity rows.
        Each row lists the node ids of one element; the mapping keys are
        ignored.
    :param eltype: element type, used as a key into ``_vtk_element`` and
        ``_vtk_element_size``.
    :param grid: ``vtkUnstructuredGrid`` receiving the cells.
    :param shift: offset added to every node id (should match the ``shift``
        given to ``addPoint``).
    :raises KeyError: if *eltype* is not a known element type.
    :raises ValueError: if a row does not have the number of nodes expected
        for *eltype*.
    """
    cell = _vtk_element[eltype]()
    nbNodes = _vtk_element_size[eltype]
    # Loop invariants: the same cell object is reused as an id buffer.
    cellType = cell.GetCellType()
    pointIds = cell.GetPointIds()

    for rows in elements.values():
        for row in rows:
            if len(row) != nbNodes:
                raise ValueError(
                    "element of type %r has %d nodes, expected %d"
                    % (eltype, len(row), nbNodes))
            # Explicit loop: map() is lazy on Python 3 and would set no ids.
            for localId, nodeId in enumerate(row):
                pointIds.SetId(localId, int(nodeId) + shift)
            # InsertNextCell copies the ids, so reusing pointIds is safe.
            grid.InsertNextCell(cellType, pointIds)

def addSol(sol, solName, grid):
    """Attach a nodal scalar field to *grid* as its active point scalars.

    :param sol: array-like of per-node values, flattened in C order. It is
        expected to follow the point ordering created by ``addPoint``.
    :param solName: name given to the VTK array.
    :param grid: VTK dataset whose point data receives the array.
    """
    values = np.asarray(sol, dtype=float).ravel()
    data = vtk.vtkDoubleArray()
    data.SetName(solName)
    # Preallocate and fill in place. An explicit loop is used because map()
    # is lazy on Python 3 and would leave the array empty.
    data.SetNumberOfValues(values.size)
    for i, value in enumerate(values):
        data.SetValue(i, float(value))
    grid.GetPointData().SetScalars(data)

def _set_input(consumer, data):
    """Feed *data* to a VTK mapper/filter on both old and new VTK versions.

    VTK 6 removed ``SetInput(dataObject)`` in favour of ``SetInputData``.
    Use the new method when it exists and fall back to the legacy one.
    """
    if hasattr(consumer, 'SetInputData'):
        consumer.SetInputData(data)
    else:
        consumer.SetInput(data)


def _build_grid(mesh, element_types):
    """Return a ``vtkUnstructuredGrid`` with *mesh*'s nodes and its cells of the given types."""
    grid = vtk.vtkUnstructuredGrid()
    addPoint(mesh.nodes, grid)
    nbElement = mesh.numberOfElements()
    grid.Allocate(nbElement, nbElement)
    for eltype, elements in mesh.elements.items():
        # Empty element groups are skipped.
        if eltype in element_types and elements != {}:
            addElements(elements, eltype, grid)
    return grid


def _show_or_save(ren, saveInFile, windowSize):
    """Open an interactive window for *ren*, or render it to ``saveInFile + '.png'``."""
    renWin = vtk.vtkRenderWindow()
    renWin.AddRenderer(ren)
    renWin.SetSize(*windowSize)

    if saveInFile is None:
        iren = vtk.vtkRenderWindowInteractor()
        iren.SetRenderWindow(renWin)
        # Initialize() must be called before the interactor handles events.
        iren.Initialize()
        renWin.Render()
        iren.Start()
    else:
        renWin.Render()
        # Grab the rendered frame buffer and write it as PNG.
        image = vtk.vtkWindowToImageFilter()
        image.SetInput(renWin)
        writer = vtk.vtkPNGWriter()
        writer.SetInputConnection(image.GetOutputPort())
        writer.SetFileName(saveInFile + ".png")
        writer.Write()


def vtkRepresentation(meshList, solList=None, solName="solution", saveInFile = None,
                      wireframe=False, withLabel = False, onlyBorder = False,
                      withBorder=True, withColorBar = False, windowSize=(750, 750)):
    """Render one or more meshes, optionally coloured by nodal solutions.

    :param meshList: meshes to draw. Each must provide ``nodes``,
        ``elements`` (element type -> group mapping) and ``numberOfElements()``.
    :param solList: optional per-mesh nodal values, parallel to *meshList*.
        All meshes share one colour scale spanning the global min/max.
    :param solName: name of the VTK scalar array.
    :param saveInFile: if ``None``, open an interactive window. Otherwise
        render to ``saveInFile + ".png"``.
    :param wireframe: draw mesh cells as wireframe, 2 px wide, with a random
        colour per mesh.
    :param withLabel: overlay point-id labels.
    :param onlyBorder: draw only type-1 elements instead of types 2 and 3.
    :param withBorder: additionally draw type-1 elements as thick black lines.
    :param withColorBar: add one horizontal scalar bar at the bottom.
    :param windowSize: ``(width, height)`` of the render window in pixels.
    """
    ren = vtk.vtkRenderer()
    ren.SetBackground((1, 1, 1))

    if solList is not None:
        # One common scalar range so every mesh uses the same colour scale.
        vmin = min(np.min(u) for u in solList)
        vmax = max(np.max(u) for u in solList)

    # NOTE(review): _vtk_element has no entry for type 3, so a mesh that
    # contains type-3 elements raises KeyError in addElements. Confirm which
    # element types are intended here.
    bodyTypes = (1,) if onlyBorder else (2, 3)

    gridMapper = None
    for imesh, mesh in enumerate(meshList):
        grid = _build_grid(mesh, bodyTypes)
        if solList is not None:
            addSol(solList[imesh], solName, grid)

        gridMapper = vtk.vtkDataSetMapper()
        _set_input(gridMapper, grid)
        if solList is not None:
            gridMapper.SetScalarRange((vmin, vmax))

        gridActor = vtk.vtkActor()
        gridActor.SetMapper(gridMapper)
        if wireframe:
            prop = gridActor.GetProperty()
            prop.SetRepresentationToWireframe()
            prop.SetLineWidth(2)
            prop.SetColor((random.random(), random.random(), random.random()))
        ren.AddActor(gridActor)

        if withLabel:
            ptsLabeledMapper = vtk.vtkLabeledDataMapper()
            _set_input(ptsLabeledMapper, grid)
            ptsLabeledActor = vtk.vtkActor2D()
            ptsLabeledActor.SetMapper(ptsLabeledMapper)
            ren.AddActor(ptsLabeledActor)

    # Every mapper shares the same scalar range, so a single colour bar (taken
    # from the last mapper) is enough. Adding one per mesh only stacked
    # identical bars on top of each other.
    if withColorBar and gridMapper is not None:
        colorbar = vtk.vtkScalarBarActor()
        colorbar.SetLookupTable(gridMapper.GetLookupTable())
        colorbar.SetOrientationToHorizontal()
        colorbar.SetWidth(.8)
        colorbar.SetHeight(.17)
        colorbar.SetPosition(0.1, 0.)
        colorbar.VisibilityOn()
        ren.AddActor(colorbar)

    if withBorder:
        for mesh in meshList:
            borderMapper = vtk.vtkDataSetMapper()
            _set_input(borderMapper, _build_grid(mesh, (1,)))
            borderActor = vtk.vtkActor()
            borderActor.SetMapper(borderMapper)
            prop = borderActor.GetProperty()
            prop.SetRepresentationToWireframe()
            prop.SetLineWidth(3)
            prop.SetColor((0, 0, 0))
            ren.AddActor(borderActor)

    _show_or_save(ren, saveInFile, windowSize)
    

    
