from multiprocessing import Pool
import pyShift.cartTh as CTH
import pyShift.volume as PSV
import numpy as np
import sharedmem
import time

def volume_parallel(m, pool, volume_func=None):
    """Compute ``volume_func(m)`` by splitting *m* in two along axis 1.

    *m* is cut into two halves along axis 1 that overlap in exactly one
    index (``nv1``). *volume_func* is mapped over the halves with *pool*,
    and the two results are concatenated along axis 0.

    NOTE(review): this assumes *volume_func* maps k samples along axis 1
    of its input to k - 1 entries along axis 0 of its output, so that the
    concatenation equals ``volume_func(m)``. The ``__main__`` check relies
    on this too; TODO confirm against ``pyShift.volume``.

    Parameters
    ----------
    m : array
        Array with at least two dimensions; it is sliced as
        ``m[:, a:b, ...]``.
    pool : object
        Anything with a ``map(func, iterable)`` method, e.g. a
        ``multiprocessing.Pool``. With a real process pool, *volume_func*
        must be picklable (a module-level function).
    volume_func : callable, optional
        Kernel applied to each half. Defaults to ``PSV.volume``.

    Returns
    -------
    numpy.ndarray
        The per-half results joined with ``np.concatenate(..., axis=0)``.

    Raises
    ------
    ValueError
        If ``m.shape[1] < 3``, which is too few to split into two halves.
    """
    if volume_func is None:
        # Resolved lazily so callers passing their own kernel never touch it.
        volume_func = PSV.volume
    n = m.shape[1]
    if n <= 2:
        # Explicit raise instead of `assert`, which is stripped under -O.
        raise ValueError(
            "volume_parallel needs at least 3 entries along axis 1, got %d" % n)
    nv = n - 1
    # Floor division: under Python 3 `/` yields a float, which is not a
    # valid slice index. `//` behaves identically under Python 2.
    nv1 = nv // 2
    # The halves share index nv1. Under the assumption above, the first
    # contributes nv1 entries and the second nv - nv1, with nothing lost
    # at the seam.
    halves = [m[:, :nv1 + 1, ...], m[:, nv1:, ...]]
    parts = pool.map(volume_func, halves)
    return np.concatenate(parts, axis=0)

def volume_parallel_sharedmem(m, pool, volume_func=None):
    """Compute ``volume_func(m)`` in two halves after staging *m* in sharedmem.

    *m* is first copied into a buffer allocated with ``sharedmem.empty``
    (same shape and dtype). That buffer is cut into two halves along axis 1
    that overlap in exactly one index (``nv1``). *volume_func* is mapped
    over the halves with *pool*, and the two results are concatenated
    along axis 0.

    NOTE(review): this assumes *volume_func* maps k samples along axis 1
    of its input to k - 1 entries along axis 0 of its output; TODO confirm
    against ``pyShift.volume``.

    NOTE(review): in ``__main__`` the pool is created before this buffer
    is allocated. Whether workers see the slices through shared memory or
    receive pickled copies therefore depends on how sharedmem arrays are
    serialised by ``pool.map``. Verify before attributing any timing
    difference to shared memory.

    Parameters
    ----------
    m : array
        Array with at least two dimensions; it is sliced as
        ``m[:, a:b, ...]``.
    pool : object
        Anything with a ``map(func, iterable)`` method, e.g. a
        ``multiprocessing.Pool``. With a real process pool, *volume_func*
        must be picklable (a module-level function).
    volume_func : callable, optional
        Kernel applied to each half. Defaults to ``PSV.volume``.

    Returns
    -------
    numpy.ndarray
        The per-half results joined with ``np.concatenate(..., axis=0)``.

    Raises
    ------
    ValueError
        If ``m.shape[1] < 3``, which is too few to split into two halves.
    """
    if volume_func is None:
        # Resolved lazily so callers passing their own kernel never touch it.
        volume_func = PSV.volume
    n = m.shape[1]
    if n <= 2:
        # Explicit raise instead of `assert`, which is stripped under -O.
        # Checked before allocating the shared buffer.
        raise ValueError(
            "volume_parallel_sharedmem needs at least 3 entries along axis 1, "
            "got %d" % n)
    m_shared = sharedmem.empty(m.shape, m.dtype)
    m_shared[...] = m
    nv = n - 1
    # Floor division: under Python 3 `/` yields a float, which is not a
    # valid slice index. `//` behaves identically under Python 2.
    nv1 = nv // 2
    # The halves share index nv1 so no entry is lost at the seam.
    halves = [m_shared[:, :nv1 + 1, ...], m_shared[:, nv1:, ...]]
    parts = pool.map(volume_func, halves)
    return np.concatenate(parts, axis=0)

if __name__ == '__main__':
    # Benchmark: serial PSV.volume vs. the two pool-based variants on the
    # same input, then check that all three results agree exactly.
    pool = Pool()
    try:
        n = 31
        # NOTE(review): presumably n samples per axis; confirm with pyShift.
        m = CTH.cartThNumpy(n, n, n)

        t0 = time.time()
        v = PSV.volume(m)
        t1 = time.time()
        vp1 = volume_parallel(m, pool)
        t2 = time.time()
        vp2 = volume_parallel_sharedmem(m, pool)
        t3 = time.time()
    finally:
        # Always release the worker processes, even if a computation fails.
        pool.close()
        pool.join()

    # np.array_equal also reports a shape mismatch as a plain False,
    # whereas `(a == b).all()` may error out or mislead in that case.
    assert np.array_equal(v, vp1)
    assert np.array_equal(v, vp2)

    # Wall-clock seconds: serial, parallel (copied halves), parallel
    # (sharedmem). A single parenthesised %-formatted string prints the
    # same text under Python 2 and Python 3.
    print("%s %s %s" % (t1 - t0, t2 - t1, t3 - t2))
