import mylib.mesh as mesh
import mylib.cg as cg   
import mylib.fem as fem
import mylib.interface as interface

import numpy as np
import numpy.linalg as nplg
import mpi4py.MPI as mpi
import scipy.sparse as sp
import matplotlib.pylab as plt

import sys

# Parallel domain-decomposition test driver.
#
# Every MPI rank builds its own overlapping subdomain mesh plus the global
# mesh, solves a local problem with CG on the restricted global residual,
# weights the local solution with a partition of unity (RAS) and the
# prolongated corrections of all ranks are summed to update the global
# iterate.  The loop stops when the relative residual drops below `tol` or
# after `itmax` iterations.
#
# NOTE(review): presumably npx*npy has to match the number of MPI ranks (as in
# the usage line below) -- confirm against mesh.Mesh.fromCoords.
if len(sys.argv) != 3:
    print("usage: mpiexec -np 4 python test_parallel.py 2 2")
    # A wrong invocation is an error, so report it through the exit status.
    sys.exit(1)

#####################################
# mpi variable
#
comm = mpi.COMM_WORLD
rank = comm.rank

#####################################
# input parameters
#
# number of nodes per direction
nx, ny = 61, 61
# overlap value handed to the local mesh builder
overlap = 4

# number of subdomains per direction (command line)
npx = int(sys.argv[1])
npy = int(sys.argv[2])
# stopping criteria of the outer iteration
itmax = 500
tol = 1e-6

#####################################
# create mesh
#
# compute coordinates of nodes
x = np.linspace(0., 1., nx)
y = np.linspace(0., 1., ny)

# build local mesh
m = mesh.Mesh()
m.fromCoords(x, y, npx, npy, overlap)

# build global mesh
mglob = mesh.Mesh()
mglob.fromCoords(x, y)

# build interface
interf = interface.build_interface(m)

# number of nodes of the local mesh / of the global mesh
nloc = m.nx*m.ny
nglob = m.nxtot*m.nytot

# build restriction matrix: row k has a single 1 in column loc2glo[k], so
# R*v extracts the local entries of a global vector and R.T*w scatters a
# local vector back into a global one (zero elsewhere)
R = sp.csr_matrix((np.ones(nloc), (np.arange(nloc), m.loc2glo)),
                  shape=(nloc, nglob))

# build partition of unity: the weight of a local node is
# 1/(1 + number of interfaces that list it in nodes2send)
D = np.ones(nloc)
for itf in interf:
    D[itf.nodes2send] += 1
D = 1./D

# build global matrix
A = fem.buildEFLaplaciandMatrix(mglob)
# build local matrix
Ai = R*A*R.T

# init right-hand side
b = np.ones(mglob.nodes.shape[0])
b = fem.buildRhs(mglob, b)

# set Dirichlet condition
xdirg, ldirg = fem.buildDirichletBc(mglob, [1., 1., 1., 1.])
xdir, ldir = fem.buildDirichletBc(m, [1., 1., 1., 1.])

# receive buffer for the summed global correction; kept under its own name
# so that the coordinate array `y` is not overwritten
ycorr = np.empty(b.shape)

# initial guess carrying the Dirichlet values on the global boundary nodes
u0 = np.ones(b.shape)
u0[ldirg] = xdirg
u = u0.copy()

# initial residual, ignored on the Dirichlet nodes
r = b - A*u0
r[ldirg] = 0.

norm0 = np.sqrt(np.dot(r, r))

# avoid a division by zero when the initial guess is already exact
if norm0 == 0.:
    norm0 = 1.
# relative residual of the initial guess
norm = 1.

# the local solves compute corrections, hence homogeneous Dirichlet values
if ldir is not None:
    xdir[:] = 0.

# iterative solver
it = 0
while norm > tol and it < itmax:
    # solve local problem on each subdomain
    bi = R*r
    ui = cg.cg(Ai, bi, np.zeros(bi.shape), ldir, xdir)
    bi = D*ui # RAS
    #bi[:] = ui # ASM
    yi = R.T*bi

    # compute global correction: sum of the contributions of all ranks
    comm.Allreduce([yi, mpi.DOUBLE], [ycorr, mpi.DOUBLE], op=mpi.SUM)

    u += ycorr
    r[:] = b - A*u
    r[ldirg] = 0.

    norm = np.sqrt(np.dot(r, r))/norm0

    if rank == 0:
        # single-argument print: same output on Python 2 and Python 3
        print('iteration %s residu -> %s' % (it, norm))

    it += 1

    mglob.plotSolution(u, animated=True, parallel=False)

