module linalg_mod
  use omp_lib
  use CSR_mod
  implicit none
contains
  !> Conjugate gradient solver for the linear system A x = b.
  !!
  !! The whole iteration runs inside a single OpenMP parallel region; scalar
  !! bookkeeping is done in `single` sections so that every thread of the team
  !! sees the same values after the implied barrier and takes the same branch.
  !!
  !! The residual that is tested against `tol` and reported is
  !! ||A x - b||_2 divided by the norm of the initial residual (or by 1 when
  !! the initial residual is exactly zero).
  !!
  !! The routine prints the number of threads, the number of iterations
  !! performed and the final relative residual on standard output.
  !!
  !! NOTE(review): the method presumably expects A to be symmetric positive
  !! definite; nothing in this routine checks it.
  !!
  !! @param A matrix in CSR format; A%n is used as the size of the system
  !! @param b right-hand side
  !! @param x initial guess on entry, computed solution on exit
  !! @param tol tolerance on the relative residual
  !! @param itmax maximum number of iterations
  subroutine cg(A, b, x, tol, itmax)
    ! NOTE(review): A carries no intent because it is forwarded to matmult,
    ! whose interface is not visible from this routine.
    type(CSR) :: A
    real(kind=8), dimension(A%n), intent(in) :: b
    real(kind=8), dimension(A%n), intent(inout) :: x
    real(kind=8), intent(in) :: tol
    integer, intent(in) :: itmax

    ! g: current residual A x - b, w: search direction, Aw: A applied to w
    real(kind=8), dimension(A%n) :: g, w, Aw
    real(kind=8) :: residual, residual0, rho, beta, Aww, gg, gw, gAw
    integer :: it, i, nbThreads
    logical :: converged, breakdown

    ! Shared reduction accumulators must be zero before their first use.
    gg = 0.d0
    Aww = 0.d0
    gw = 0.d0
    gAw = 0.d0

    it = 0
    nbThreads = 1
    converged = .false.
    breakdown = .false.

    !$omp parallel

    ! Only one thread writes the shared counter, to avoid a data race.
    !$omp single
    nbThreads = omp_get_num_threads()
    !$omp end single

    ! Called by every thread of the team.
    ! NOTE(review): matmult is defined elsewhere; it is expected to store
    ! A*x into g and to share the work between the calling threads - confirm.
    call matmult(A, x, g)

    !$omp workshare
    g = g - b
    w = g
    !$omp end workshare

    !$omp do reduction(+:gg)
    do i = 1, A%n
      gg = gg + g(i)*g(i)
    end do
    !$omp end do

    !$omp single
    residual0 = sqrt(gg)
    ! Avoid dividing by zero when the initial guess is already exact.
    if (residual0 == 0.d0) residual0 = 1.d0
    residual = sqrt(gg)/residual0
    ! Test before iterating: with a zero residual the first step would
    ! evaluate 0/0 and fill x with NaN.
    converged = (gg == 0.d0) .or. (residual < tol)
    !$omp end single

    do while (.not. converged .and. it < itmax)
       call matmult(A, w, Aw)

       !$omp do reduction(+:Aww,gw)
       do i = 1, A%n
         Aww = Aww + Aw(i)*w(i)
         gw = gw + g(i)*w(i)
       end do
       !$omp end do

       !$omp single
       ! A vanishing w.Aw makes the step length undefined: stop instead of
       ! propagating Inf/NaN into x.
       breakdown = (Aww == 0.d0)
       if (.not. breakdown) then
         rho = -gw/Aww
         gg = 0.d0
       end if
       !$omp end single

       ! The flag was set before the barrier of the single section, so all
       ! threads leave the loop together.
       if (breakdown) exit

       !$omp workshare
       x = x + rho*w
       g = g + rho*Aw
       !$omp end workshare

       !$omp do reduction(+:gg)
       do i = 1, A%n
         gg = gg + g(i)*g(i)
       end do
       !$omp end do

       !$omp single
       residual = sqrt(gg)/residual0
       ! it counts the iterations that have been completed.
       it = it + 1
       converged = residual < tol
       !$omp end single

       if (converged) exit

       !$omp do reduction(+:gAw)
       do i = 1, A%n
         gAw = gAw + g(i)*Aw(i)
       end do
       !$omp end do

       !$omp single
       beta = -gAw/Aww
       ! Reset the accumulators for the next iteration.
       gAw = 0.d0
       Aww = 0.d0
       gw = 0.d0
       !$omp end single

       !$omp workshare
       w = g + beta*w
       !$omp end workshare

    end do
    !$omp end parallel

    print '("Conjugate gradient solver on ",i2, " threads")', nbThreads
    if (converged) then
      print '("Convergence in ", i4, " iteration(s), residual = ", e20.14)', it, residual
    else
      print '("No convergence in ", i4, " iteration(s), residual = ", e20.14)', it, residual
    end if

  end subroutine cg

end module linalg_mod
