"""
Resolution de l'equation 1D monodomaine

     
- nu laplacien(u) + eta u = f
      

avec condition de Dirichlet homogene sur le bord
par un schema aux differences finies

Auteur: L. Gouarin
"""

import scipy.sparse as sp
import scipy.sparse.linalg as linsolve
import matplotlib.pyplot as plt 
import numpy as np
import time
import mpi4py.MPI as mpi

def laplacian(nx, dx, nu, eta, alphag=None, alphad=None):
    """
    Assemble the finite-difference matrix of ``-nu u'' + eta u``.

    Every row starts from the centred 3-point stencil
    ``(-nu/dx**2, 2*nu/dx**2 + eta, -nu/dx**2)``.

    Parameters
    ----------
    nx : int
        Number of grid points (the matrix is ``nx x nx``).
    dx : float
        Grid spacing.
    nu : float
        Diffusion coefficient.
    eta : float
        Zeroth-order (reaction) coefficient.
    alphag, alphad : float or None
        When given, the first (left) / last (right) row is overwritten by
        a one-sided Robin row:
        ``(u_0 - u_1)/dx + alphag*u_0`` on the left and
        ``(u_{n-1} - u_{n-2})/dx + alphad*u_{n-1}`` on the right.
        When None, that row keeps the interior stencil.

    Returns
    -------
    scipy.sparse.lil_matrix
        The assembled matrix. It is kept in LIL format so that boundary
        rows can still be modified cheaply afterwards.
    """
    A = sp.lil_matrix((nx, nx))

    # Off-diagonal coefficient. setdiag ignores the surplus entry of the
    # length-nx array on the +/-1 diagonals.
    off_diag = -nu*np.ones(nx)/dx**2
    A.setdiag((2.*nu/dx**2 + eta)*np.ones(nx))
    A.setdiag(off_diag, 1)
    A.setdiag(off_diag, -1)

    if alphag is not None:
        A[0, 0] = 1./dx + alphag
        A[0, 1] = -1./dx

    if alphad is not None:
        A[-1, -1] = 1./dx + alphad
        A[-1, -2] = -1./dx

    return A

def dirichletCondition(A, b, alphag=None, alphad=None):
    """
    Impose homogeneous boundary data on the first and last unknowns.

    For each end (left: row 0 with ``alphag``, right: row -1 with
    ``alphad``), the right-hand side entry is always set to 0. If the
    corresponding alpha is None, the matrix row is also replaced by an
    identity row, which enforces ``u = 0`` there. Otherwise the row
    already assembled in ``A`` is left untouched.

    ``A`` and ``b`` are modified in place; nothing is returned.
    """
    for row, alpha in ((0, alphag), (-1, alphad)):
        if alpha is None:
            A[row, :] = 0.
            A[row, row] = 1.
        b[row] = 0.

def plotSolution(x, u, k=None):
    """
    Gather the solution of every subdomain on rank 0 and plot it.

    This is a collective call: every process of ``mpi.COMM_WORLD`` must
    call it, because it performs two ``gather`` operations rooted at 0.
    Only rank 0 draws. Successive subdomains alternate between blue and
    red.

    Parameters
    ----------
    x : array_like
        Grid coordinates owned by the calling process.
    u : array_like
        Solution values on ``x``.
    k : int or None, optional
        Schwarz iteration number displayed in the title. When None, no
        title is set.
    """
    comm = mpi.COMM_WORLD

    xsol = comm.gather(x, 0)
    sol = comm.gather(u, 0)

    if comm.rank != 0:
        return

    color = ['b', 'r']

    # plt.hold no longer exists (removed in Matplotlib 3.0). Curves
    # accumulate by default, and cla() clears the previous frame.
    plt.cla()
    for i, (xi, ui) in enumerate(zip(xsol, sol)):
        plt.plot(xi, ui, color[i % 2], lw=2)
    if k is not None:
        plt.title("Schwarz iteration = %d"%(k))
    plt.draw()
    # plt.pause runs the GUI event loop so the frame is actually shown.
    # time.sleep would block it.
    plt.pause(.5)


# Discretisation and physical parameters
nx = 101
nu = 1.
eta = 1.

dx = 1./(nx - 1)

x = np.linspace(0., 1., nx)
# None keeps the interior stencil on that boundary row, so that
# dirichletCondition replaces it by an identity row (u = 0). A float
# would instead select the Robin row assembled by laplacian.
alphad = None
alphag = None
A = laplacian(nx, dx, nu, eta, alphag, alphad)

# Right-hand side: Gaussian bump centred at x = 0.5
b = 100.*np.exp(-500.*(x - .5)**2)

# Boundary conditions (modifies A and b in place)
dirichletCondition(A, b, alphag, alphad)

# Sparse LU factorisation; LU(rhs) returns the solution of A u = rhs
LU = linsolve.factorized(A.tocsc())

# Solve, then plot the solution against the grid coordinates
u = LU(b)
plt.plot(x, u)
plt.show()
        


