!> Iterative linear solvers (sequential and domain-decomposed conjugate
!! gradient) operating on CSR matrices.
module linalg_mod
  use CSR_mod
  use interface_mod
  use laplacian_mod
  use util_mod
  implicit none
contains
  !> Conjugate gradient solver for A x = b.
  !!
  !! The iteration stops when the residual norm, relative to the initial
  !! residual norm, drops below tol, or after itmax iterations. The relative
  !! residual is printed at every iteration. The method assumes A is symmetric
  !! positive definite; this is not checked.
  !!
  !! @param A     CSR matrix (intent left unspecified: A is forwarded to
  !!              matmult, whose dummy-argument intent is not visible here)
  !! @param b     right-hand side, of size A%n
  !! @param x     on entry the initial guess, on exit the computed solution
  !! @param tol   tolerance on the relative residual norm
  !! @param itmax maximum number of iterations
  subroutine cg(A, b, x, tol, itmax)
    type(CSR) :: A
    real(kind=8), dimension(A%n), intent(inout) :: x
    real(kind=8), dimension(A%n), intent(in) :: b
    real(kind=8), intent(in) :: tol
    integer, intent(in) :: itmax
    real(kind=8), dimension(A%n) :: g, w, Aw
    real(kind=8) :: residual, residual0, rho, gamma, aww
    integer :: it

    ! Initial gradient g = A x - b, which is also the first search direction.
    call matmult(A, x, g)
    g = g - b
    w = g

    residual0 = sqrt(dot_product(g, g))

    ! A zero initial residual means x already solves the system. Iterating
    ! anyway would evaluate rho = 0/0 and overwrite x with NaN.
    if (residual0 == 0.d0) return

    it = 0

    do while (it < itmax)
       call matmult(A, w, Aw)
       aww = dot_product(Aw, w)

       ! Breakdown: the step length is undefined when (A w, w) = 0
       ! (singular/indefinite matrix or vanished direction). Stop with the
       ! current iterate instead of dividing by zero.
       if (aww == 0.d0) then
          print*, 'cg: breakdown, (Aw, w) = 0 at iteration', it
          exit
       end if

       ! Step length minimising the quadratic functional along w.
       rho = -dot_product(g, w)/aww

       x = x + rho*w
       g = g + rho*Aw

       residual = sqrt(dot_product(g, g))/residual0

       print*, 'iteration', it, 'residual ->', residual

       if (residual < tol) exit

       ! New direction, made A-conjugate to the previous one.
       gamma = -dot_product(g, Aw)/aww
       w = g + gamma*w

       it = it + 1

    end do
  end subroutine cg

  !> Parallel conjugate gradient solver, one sub-domain per MPI rank.
  !!
  !! Same algorithm as cg, with the matrix-vector product done by matmultpara
  !! and the scalar products by dot_product_para. Only rank 0 prints the
  !! relative residual.
  !!
  !! NOTE(review): the early return and the loop exits assume that
  !! dot_product_para returns the same value on every rank (presumably a
  !! global reduction), so that all ranks leave the solver together. The
  !! original tolerance test already relied on this; confirm against
  !! dot_product_para.
  !!
  !! @param A      CSR matrix of the sub-domain (intent left unspecified: A is
  !!               forwarded to matmultpara)
  !! @param b      right-hand side of the sub-domain
  !! @param x      on entry the initial guess, on exit the computed solution
  !!               of the sub-domain
  !! @param interf interface descriptors forwarded to matmultpara (intent left
  !!               unspecified: matmultpara's interface is not visible here)
  !! @param tol    tolerance on the relative residual norm
  !! @param itmax  maximum number of iterations
  subroutine cgPara(A, b, x, interf, tol, itmax)
    type(CSR) :: A
    type(interface), dimension(:) :: interf
    real(kind=8), dimension(:), intent(inout) :: x
    real(kind=8), dimension(:), intent(in) :: b
    real(kind=8), intent(in) :: tol
    integer, intent(in) :: itmax
    real(kind=8), dimension(A%n) :: g, w, Aw
    real(kind=8) :: residual, residual0, rho, gamma, aww
    integer :: it
    integer :: ierr, rank

    call mpi_comm_rank(mpi_comm_world, rank, ierr)

    ! Initial gradient g = A x - b, which is also the first search direction.
    call matmultpara(A, x, g, interf)
    g = g - b
    w = g

    residual0 = sqrt(dot_product_para(g, g))

    ! A zero initial residual means x already solves the system. Iterating
    ! anyway would evaluate rho = 0/0 and overwrite x with NaN.
    if (residual0 == 0.d0) return

    it = 0

    do while (it < itmax)
       call matmultpara(A, w, Aw, interf)

       aww = dot_product_para(Aw, w)

       ! Breakdown: the step length is undefined when (A w, w) = 0. Stop
       ! with the current iterate instead of dividing by zero.
       if (aww == 0.d0) then
          if (rank == 0) print*, 'cgPara: breakdown, (Aw, w) = 0 at iteration', it
          exit
       end if

       ! Step length minimising the quadratic functional along w.
       rho = -dot_product_para(g, w)/aww

       x = x + rho*w
       g = g + rho*Aw

       residual = sqrt(dot_product_para(g, g))/residual0

       if (rank == 0) print*, 'iteration', it, 'residual ->', residual

       if (residual < tol) exit

       ! New direction, made A-conjugate to the previous one.
       gamma = -dot_product_para(g, Aw)/aww
       w = g + gamma*w

       it = it + 1

    end do
  end subroutine cgPara

end module linalg_mod
