import warnings

import numpy as np
import mpi4py.MPI as mpi
import scipy.sparse.linalg as splg

import cg

def buildZ(R, D):
    """Return the vector ``R^T (D (R 1))`` built from a vector of ones.

    A ones vector with one entry per column of ``R`` is pushed through ``R``,
    weighted by ``D`` and pulled back with ``R.transpose()``.  Every product
    uses the ``*`` operator, so ``R`` and ``D`` are presumably matrix-like
    objects for which ``*`` is a matrix-vector product (e.g. ``scipy.sparse``
    matrices) -- confirm against callers.
    """
    column_ones = np.ones(R.shape[1])
    weighted = D * (R * column_ones)
    return R.transpose() * weighted

def buildE(A, Z, npx, npy):
    """Assemble the dense coarse matrix ``E[j, i] = Z_j^T (A Z_i)`` (collective).

    Each rank of ``mpi.COMM_WORLD`` contributes its own vector ``Z``.  The
    vectors are gathered on every rank, so every rank builds the same
    ``(size, size)`` matrix, where ``size`` is the communicator size.

    Parameters
    ----------
    A : operator for which ``A * v`` is defined on the ``Z`` vectors
        (presumably a scipy sparse matrix -- confirm against callers).
    Z : this rank's coarse vector.
    npx, npy : not used by this function; kept for interface compatibility.

    Returns
    -------
    numpy.ndarray of shape ``(size, size)``.
    """
    comm = mpi.COMM_WORLD
    # mpi4py >= 2 lowercase ``allgather`` takes only the send object; the old
    # extra ``recvobj`` positional argument raises TypeError there.
    allZ = comm.allgather(Z)

    size = comm.size
    E = np.zeros((size, size))
    # ``range`` instead of ``xrange`` keeps this working on Python 2 and 3.
    for i in range(size):
        t = A * allZ[i]
        for j in range(size):
            E[j, i] = np.dot(allZ[j].transpose(), t)
    return E

def Q(u, Z, E):
    """Apply the coarse correction ``Z_all E^{-1} Z_all^T`` to ``u`` (collective).

    ``Z_all`` is the set of coarse vectors ``[Z_0, ..., Z_{p-1}]``, one per
    rank of ``mpi.COMM_WORLD``.  All ranks must call this function, because
    it uses ``allgather`` and ``Allreduce``.

    Parameters
    ----------
    u : vector for which ``np.dot(Z.transpose(), u)`` is defined.
    Z : this rank's coarse vector.
    E : ``(size, size)`` coarse matrix, e.g. as assembled by ``buildE``.

    Returns
    -------
    float64 numpy.ndarray with the shape of ``Z``; identical on every rank.

    Warns
    -----
    RuntimeWarning if the GMRES coarse solve reports non-convergence.
    """
    comm = mpi.COMM_WORLD
    rank = comm.rank

    # This rank's entry of the coarse right-hand side Z_all^T u.
    vaux = np.dot(Z.transpose(), u)
    # mpi4py >= 2 ``allgather`` accepts only the send object.
    allvaux = comm.allgather(vaux)

    # Every rank solves the same small coarse system redundantly.
    sol, info = splg.gmres(E, allvaux)
    if info != 0:
        # Non-convergence used to be silently ignored; surface it without
        # aborting the outer iteration.
        warnings.warn(
            "Q: coarse GMRES solve did not converge (info=%d)" % info,
            RuntimeWarning,
            stacklevel=2,
        )

    # Each rank scales its own vector; summing over ranks yields Z_all * sol.
    # The buffer is declared as MPI.DOUBLE, so make it contiguous float64.
    resl = np.ascontiguousarray(sol[rank] * Z, dtype=np.float64)
    res = np.empty(resl.shape)
    comm.Allreduce([resl, mpi.DOUBLE], [res, mpi.DOUBLE], op=mpi.SUM)
    return res

def P(u, A, Z, E):
    """Return ``u - A Q(u)``, i.e. apply ``I - A Q`` to ``u`` (collective via ``Q``).

    ``Q`` is the coarse correction defined in this module; ``A`` must support
    ``A * v`` on its result.
    """
    correction = A * Q(u, Z, E)
    correction -= u           # A Q u - u
    return -correction        # u - A Q u

def PT(u, A, Z, E):
    """Return ``u - Q(A u)``, i.e. apply ``I - Q A`` to ``u`` (collective via ``Q``).

    This applies the coarse correction after ``A`` instead of before it,
    unlike ``P``.
    """
    coarse = Q(A * u, Z, E)
    coarse -= u               # Q A u - u
    return -coarse            # u - Q A u

def Mm1(Ai, bi, ui0, ldir, xdir):
    """Run the project's ``cg.cg`` on the local system and return its result.

    Presumably this applies a local preconditioner by solving ``Ai ui = bi``
    with CG from the initial guess ``ui0``.  The exact semantics of ``cg.cg``
    are not visible here -- confirm.

    Parameters
    ----------
    Ai, bi, ui0 : forwarded unchanged to ``cg.cg``.
    ldir : forwarded to ``cg.cg``.  Presumably the list of Dirichlet dofs,
        or None -- confirm.
    xdir : used only when ``ldir`` is None, in which case it is forwarded
        unchanged.  Otherwise a zero array of ``xdir.shape`` is passed
        instead (presumably homogeneous Dirichlet values -- confirm).

    Returns
    -------
    Whatever ``cg.cg`` returns.  Previously this value was computed and
    discarded, so the function always returned None.
    """
    if ldir is None:
        dirichlet_values = xdir
    else:
        dirichlet_values = np.zeros(xdir.shape)
    ui = cg.cg(Ai, bi, ui0, ldir, dirichlet_values)
    return ui
    


    
