module linalg_mod
  use omp_lib
  use CSR_mod
  implicit none
contains
  !> Conjugate gradient solver for the linear system A x = b.
  !!
  !! Iterates from the initial guess held in x. It stops when the relative
  !! residual norm ||A x - b|| / ||A x0 - b|| drops below tol, or after itmax
  !! iterations. The conjugate gradient method assumes A is symmetric positive
  !! definite; this routine does not check that. On exit it prints the OpenMP
  !! thread count and a convergence summary on standard output.
  !!
  !! @param A     matrix in CSR format (A%n unknowns)
  !! @param b     right-hand side
  !! @param x     on entry the initial guess, on exit the computed solution
  !! @param tol   tolerance on the relative residual norm
  !! @param itmax maximum number of iterations
  subroutine cg(A, b, x, tol, itmax)
    ! A has no intent attribute: the intent matmult (CSR_mod) declares for
    ! its matrix argument is not visible from here.
    type(CSR) :: A
    real(kind=8), dimension(A%n), intent(in)    :: b
    real(kind=8), dimension(A%n), intent(inout) :: x
    real(kind=8), intent(in) :: tol
    integer, intent(in) :: itmax

    ! Work vectors are allocatable rather than automatic arrays so that large
    ! systems do not overflow the stack; they are deallocated on return.
    real(kind=8), dimension(:), allocatable :: g, w, Aw
    real(kind=8) :: residual, residual0, rho, gamma, aww
    integer :: it, niter, nbThreads
    logical :: converged

    allocate(g(A%n), w(A%n), Aw(A%n))

    ! g = A x - b, assuming matmult(A, v, r) stores A*v in r (TODO confirm
    ! against CSR_mod); the first search direction w is g itself.
    call matmult(A, x, g)
    g = g - b
    w = g

    residual0 = sqrt(dot_product(g, g))
    niter = 0

    if (residual0 == 0.d0) then
      ! The initial guess already solves the system exactly. Iterating would
      ! give A w = 0, hence rho = 0/0, and overwrite x with NaN.
      residual = 0.d0
      converged = .true.
    else
      ! Relative residual of the initial guess; this is what gets reported
      ! if no iteration is performed (itmax <= 0).
      residual = 1.d0
      converged = .false.
      do it = 1, itmax
        call matmult(A, w, Aw)
        aww = dot_product(Aw, w)
        ! Step length along w.
        rho = -dot_product(g, w)/aww

        x = x + rho*w
        g = g + rho*Aw
        niter = it

        residual = sqrt(dot_product(g, g))/residual0

        !!print*, 'iteration', it, 'residual ->', residual

        if (residual < tol) then
          converged = .true.
          exit
        end if

        ! Choose gamma so that the new direction satisfies w_new . (A w) = 0.
        gamma = -dot_product(g, Aw)/aww
        w = g + gamma*w
      end do
    end if

    ! Query the team size once; "single" avoids every thread writing the
    ! shared variable (a data race under OpenMP rules).
    !$omp parallel default(none) shared(nbThreads)
    !$omp single
    nbThreads = omp_get_num_threads()
    !$omp end single
    !$omp end parallel

    ! i0 uses minimal width, so large thread or iteration counts are printed
    ! instead of asterisks.
    print '("Conjugate gradient solver on ", i0, " threads")', nbThreads
    if (converged) then
      print '("Convergence in ", i0, " iteration(s), residual = ", e20.14)', niter, residual
    else
      print '("No convergence in ", i0, " iteration(s), residual = ", e20.14)', niter, residual
    end if
  end subroutine cg

end module linalg_mod
