import mylib.mesh as mesh
import mylib.fem as fem
import mylib.cg as cg   
import mylib.interface as interface

import mpi4py.MPI as mpi
import numpy as np
import matplotlib.pylab as plt
import string
import sys

# Expect exactly two positional arguments: the number of subdomains along x
# and along y.  A usage error is a failure: report it on stderr and exit with
# a non-zero status so mpiexec and any calling script can detect it.
if len(sys.argv) != 3:
    sys.stderr.write("usage: mpiexec -np 4 python test_parallel.py 2 2\n")
    sys.exit(1)

# MPI communicator spanning every process started by mpiexec.
comm = mpi.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
# Keep the terminal readable: rank 0 prints to the console, every other rank
# redirects its stdout to its own file (output001.txt, output002.txt, ...).
# '%03d' % rank gives the same zero-padded name as string.zfill(rank, 3),
# which is deprecated and no longer exists in Python 3.
# The file is opened line-buffered (buffering=1) so each rank's log is written
# out as it goes instead of possibly being lost in an unflushed buffer when
# the MPI job is aborted.
if rank != 0:
    fileName = 'output%03d.txt' % rank
    foutput = open(fileName, 'w', 1)
    sys.stdout = foutput

#####################################
# create mesh
print 'Create mesh'
# size of rectangular domain
xmin = 0.
xmax = 1.
ymin = xmin
ymax = xmax

# number of nodes per direction in global mesh
nx = 101
ny = 101

# compute coordinates of nodes
x = np.linspace(xmin, xmax, nx)
y = np.linspace(ymin, ymax, ny)

# number of subdomains per direction 
npx = int(sys.argv[1])
npy = int(sys.argv[2])

# build mesh
m = mesh.Mesh() 
m.fromCoords(x, y, npx, npy)

# show mesh
#m.show()

#######################################
# build list of interfaces
print 'Create list of interfaces'
interf = interface.build_interface(m)

#######################################
# create matrix
print 'Create matrix A'
A = fem.buildEFLaplacianMatrix(m)

#######################################
# create rhs
print 'Create right hand size b'
b = np.ones(m.nodes.shape[0])
b = fem.buildRhs(m, b)
#b = np.zeros(m.nodes.shape[0])

########################################
# create dirichlet boundaries conditions
print 'Create dirichlet boundaries conditions'
xdir, ldir = fem.buildDirichletBc(m, [1., 1., 1., 1.], [ymin, xmax, ymax, xmin])

########################################
# solve problem
print 'Resolution of Ax = b'
#initial solution
x0 = np.zeros(b.shape)
# conjugate gradient resolution
print ' parallel conjugate gradient resolution'
x, residual, ite = cg.cgPar(A, b, x0, ldir, xdir, interf, withInfo=True, verbose=False)
#m.plotSolution(x, parallel=True)
#plt.show()
print '   convergence in', ite, ' iterations'
print '   ||Ax - b|| / ||Ax0 - b|| = ', residual[ite]
print '   ***************************************************'
print ' parallel preconditioned conjugate gradient resolution'
x, residual, ite = cg.cgPrecPar(A, b, x0, ldir, xdir, interf, withInfo=True, verbose=False)
#m.plotSolution(x, parallel=True)
#plt.show()
print '   convergence in', ite, ' iterations'
print '   ||Ax - b|| / ||Ax0 - b|| = ', residual[ite]
print '   ***************************************************'
print ' schur method (parallel conjugate gradient on interface problem)'
x, residual, ite = cg.schur(A, b, x0, ldir, xdir, interf, withInfo=True, verbose=False)
print '   convergence in', ite, ' iterations'
print '   ||Ax - b|| / ||Ax0 - b|| = ', residual[ite]

# show solution: at this point x holds the solution returned by the last
# solver run above (cg.schur)
# alternative display:
#m.showSolution(x)

# NOTE(review): these two calls run on every MPI rank, so plt.show()
# presumably opens (and blocks on) a window in each process.  Confirm whether
# plotSolution(..., parallel=True) needs all ranks (e.g. a collective gather)
# before restricting the display to rank 0.
m.plotSolution(x, parallel=True)
plt.show()

# uncomment to plot the residual history of the last solver (log scale)
#plt.figure()
#plt.semilogy(residual)
#plt.show()
