import mylib.mesh as mesh
import mylib.cg as cg   
import mylib.fem as fem
import mylib.interface as interface
import mylib.coarseGrid as CGrid

import numpy as np
import numpy.linalg as nplg
import mpi4py.MPI as mpi
import scipy.sparse as sp
import matplotlib.pylab as plt

import sys

# Exactly two command-line arguments are required; they are read further down
# as npx and npy.
if len(sys.argv) != 3:
    # Passing a string to sys.exit writes it to stderr and terminates with a
    # non-zero status, so a wrong invocation is not reported as a success.
    sys.exit("usage: mpiexec -np 4 python test_parallel.py 2 2")

#####################################
# MPI communicator and rank of this process
#
comm = mpi.COMM_WORLD
rank = comm.rank

#####################################
# run configuration
#
# base number of nodes per direction (rescaled by npx / npy when the mesh
# is created)
nx = ny = 21
# overlap value handed to the local mesh builder
overlap = 4

# the two command-line arguments, converted to integers
npx, npy = int(sys.argv[1]), int(sys.argv[2])

# conjugate gradient controls: iteration cap and relative tolerance
itmax, tol = 500, 1e-6
# when True, rank 0 prints the residual history
verbose = True

#####################################
# create mesh
#
# weak scaling: the node count per direction grows with npx / npy
nx = (nx - 1)*npx + 1
ny = (ny - 1)*npy + 1

# node coordinates on [0, 1] in each direction
x = np.linspace(0., 1., nx)
y = np.linspace(0., 1., ny)

# build local mesh (built with the npx, npy and overlap arguments)
m = mesh.Mesh()
m.fromCoords(x, y, npx, npy, overlap)

# build global mesh (same coordinates, no partition arguments)
mglob = mesh.Mesh()
mglob.fromCoords(x, y)

# build interface
interf = interface.build_interface(m)

# build restriction matrix: one unit entry per local node, placed in the
# column given by m.loc2glo, so R*v picks the local entries out of a global
# vector and R.T*v scatters local entries back into a global vector
nloc = m.nx*m.ny
R = sp.csr_matrix((np.ones(nloc), (np.arange(nloc), m.loc2glo)),
                  shape=(nloc, m.nxtot*m.nytot))

# build partition of unity: each local node gets the weight
# 1 / (1 + number of interfaces listing it in nodes2send)
D = np.ones(nloc)
for i in interf:
    D[i.nodes2send] += 1
D = 1./D

# build global matrix
A = fem.buildEFLaplaciandMatrix(mglob)
# build local matrix as the restriction of the global one
Ai = R*A*R.T

# init right-hand side from a unit vector with one entry per global node
b = np.ones(mglob.nodes.shape[0])
b = fem.buildRhs(mglob, b)

# set Dirichlet condition; xdir* are used below as values and ldir* as indices
# NOTE(review): the four 1. entries are presumably one value per boundary
# side -- confirm against fem.buildDirichletBc
xdirg, ldirg = fem.buildDirichletBc(mglob, [1., 1., 1., 1.])
xdir, ldir = fem.buildDirichletBc(m, [1., 1., 1., 1.])

# init Z and E for P_BNN preconditioner
Z = CGrid.buildZ(R, D)
Z[ldirg] = 0.
E = CGrid.buildE(A, Z, npx, npy)

# initial guess: zero everywhere except the Dirichlet entries
u0 = np.zeros(b.shape)
u0[ldirg] = xdirg
u = u0.copy()

# initial residual, zeroed on the Dirichlet entries
g = A*u0 - b
g[ldirg] = 0.

# the local Dirichlet values are zeroed before being passed to the local
# solves (ldir may be None, in which case there is nothing to reset)
if ldir is not None:
    xdir[:] = 0.

#####################################
# P_BNN preconditioned conjugate gradient
#
def _apply_pbnn(residual):
    """Apply the P_BNN preconditioner to *residual* and return the result.

    Steps, in order: CGrid.Q and CGrid.P are applied to the residual, the
    P-part is restricted with R and used as right-hand side of a local
    cg.cg solve with Ai, the local solutions are scattered back with R.T
    and summed over all ranks, CGrid.PT is applied to the sum and the
    Q-part is added. Entries listed in ldirg are zeroed in the result.

    This is a collective operation: every rank must call it.
    """
    q_part = CGrid.Q(residual, Z, E)
    p_part = CGrid.P(residual, A, Z, E)

    # local solve on this rank's subdomain, starting from a zero guess
    local_rhs = R*p_part
    local_sol = cg.cg(Ai, local_rhs, np.zeros(local_rhs.shape), ldir, xdir)

    # sum the scattered local solutions over all ranks
    scattered = R.T*local_sol
    summed = np.empty(u0.shape)
    comm.Allreduce([scattered, mpi.DOUBLE], [summed, mpi.DOUBLE], op=mpi.SUM)

    out = CGrid.PT(summed, A, Z, E)
    out[:] = out + q_part
    out[ldirg] = 0.
    return out


# preconditioned initial residual; it is also the first search direction
y = _apply_pbnn(g)
w = y.copy()

# reference value used to normalise the residual printed and tested below
g0 = np.sqrt(np.dot(g, y))

# avoid dividing by zero when normalising
if g0 == 0.:
    g0 = 1.

if verbose and rank == 0:
    print('|| Ax0 - b || = %s' % (g0,))

it = 0
# conjugate gradient
while it < itmax:
    Aw = A*w
    Aw[ldirg] = 0.

    wAw = np.dot(Aw, w)
    if wAw == 0.:
        # the step length and the conjugation coefficient both divide by
        # wAw; a zero value (e.g. w == 0) would only propagate NaN, so stop
        break

    # step length along the search direction w
    rho = -np.dot(g, w) / wAw

    u[:] = u + rho*w
    g[:] = g + rho*Aw

    y = _apply_pbnn(g)

    res = np.sqrt(np.dot(g, y))/g0

    if verbose and rank == 0:
        print('iteration %s residu -> %s' % (it, res))
    if res < tol:
        break

    # new search direction from the preconditioned residual
    gamma = -np.dot(y, Aw)/wAw
    w[:] = y + gamma*w

    it += 1

    mglob.plotSolution(u, animated=True, parallel=False)

