// Finite-element solver for the Poisson problem.
// The program writes two files: carre_tria.mesh, the meshed computational domain (readable with medit or ffmedit),
// and carre_tria.sol, the computed solution.
// Both can be visualised with the command "medit carre_tria" or "ffmedit carre_tria", as covered in the FreeFem training.


#include <stdlib.h>
#include <assert.h>
#include "petscsys.h"
#include "petscvec.h"
#include "petscmat.h"
#include "petscksp.h"

/* A mesh vertex: a pair of coordinates and an integer label. */
typedef struct {
	double c[2];	/* c[0] and c[1] hold the two coordinates of the vertex */
	int    label;	/* integer tag attached to the vertex */
} Point;

/* Pointer-to-vertex shorthand. */
typedef Point *pPoint;

/* A triangular element: three vertex indices and an integer label. */
typedef struct {
	int v[3];	/* indices of the three vertices of the triangle */
	int label;	/* integer tag attached to the triangle */
} Tria;

/* Pointer-to-triangle shorthand. */
typedef Tria *pTria;

/* A triangular mesh: vertex and triangle counts with their arrays. */
typedef struct {
	int    nPoint;	/* number of vertices */
	int    nTria;	/* number of triangles */
	pPoint point;	/* vertex array */
	pTria  tria;	/* triangle array */
} Mesh;

/* Pointer-to-mesh shorthand. */
typedef Mesh *pMesh;

/*
 * Fill `mesh` with a structured triangulation of a square grid.
 *
 * The grid has (nbIntervalle + 1) vertices per side, spaced by h, stored row
 * by row: vertex (row, col) has coordinates (col*h, row*h). Every grid cell is
 * split into two triangles, so the mesh holds 2*nbIntervalle*nbIntervalle
 * triangles, all labelled 0.
 *
 * Vertex labels: 1 on the first row, 3 on the last row, 4 on the first column,
 * 2 on the last column. Column labels take precedence over row labels, so the
 * corners carry 4 or 2. All other vertices keep the zero label left by calloc.
 *
 * Both arrays are heap-allocated here; allocation failure trips an assert.
 * Always returns 0.
 */
int createMesh(Mesh* mesh, int nbIntervalle, double h)
{
	const int nodesPerSide = nbIntervalle + 1;
	int row, col;
	pPoint vertex;
	pTria cell;

	/* ---- vertices ---- */
	mesh->nPoint = nodesPerSide * nodesPerSide;
	mesh->point = (pPoint)calloc(mesh->nPoint, sizeof(Point));
	assert(mesh->point);

	/* The vertices are visited in storage order, so a running pointer suffices. */
	vertex = mesh->point;
	for (row = 0; row < nodesPerSide; row++) {
		for (col = 0; col < nodesPerSide; col++, vertex++) {
			vertex->c[0] = col * h;
			vertex->c[1] = row * h;

			/* Highest-priority label first: columns win over rows. */
			if (col == nbIntervalle)
				vertex->label = 2;
			else if (col == 0)
				vertex->label = 4;
			else if (row == nbIntervalle)
				vertex->label = 3;
			else if (row == 0)
				vertex->label = 1;
		}
	}

	/* ---- triangles ---- */
	mesh->nTria = 2 * nbIntervalle * nbIntervalle;
	mesh->tria = (pTria)calloc(mesh->nTria, sizeof(Tria));
	assert(mesh->tria);

	/* Two consecutive triangles per grid cell, cells visited row by row. */
	cell = mesh->tria;
	for (row = 0; row < nbIntervalle; row++) {
		for (col = 0; col < nbIntervalle; col++, cell += 2) {
			/* Indices of the four vertices of the current grid cell. */
			const int lowerLeft  = row * nodesPerSide + col;
			const int lowerRight = lowerLeft + 1;
			const int upperLeft  = lowerLeft + nodesPerSide;
			const int upperRight = upperLeft + 1;

			cell[0].v[0] = lowerLeft;
			cell[0].v[1] = lowerRight;
			cell[0].v[2] = upperLeft;
			cell[0].label = 0;

			cell[1].v[0] = upperRight;
			cell[1].v[1] = upperLeft;
			cell[1].v[2] = lowerRight;
			cell[1].label = 0;
		}
	}

	return 0;
}

/*
 * Write the mesh to the text file "carre_tria.mesh".
 *
 * The file contains a header (format version and dimension 2), a "Vertices"
 * section with one "x y label" line per vertex, a "Triangles" section with one
 * "v0 v1 v2 label" line per triangle (vertex indices written 1-based), and a
 * closing "End" keyword.
 *
 * Only rank 0 of `comm` opens and writes the file. Letting every rank
 * fopen(..., "w") the same path would let a late rank truncate what rank 0
 * has already written.
 *
 * comm : communicator whose rank 0 performs the write
 * mesh : mesh to dump (read only)
 *
 * Returns 0 on success (and on every rank other than 0), non-zero if closing
 * the file fails. If the file cannot be opened the whole job is aborted with
 * MPI_Abort.
 */
int writeMesh(MPI_Comm comm, Mesh* mesh)
{
	int i, rank;
	FILE *pFile;

	MPI_Comm_rank(comm, &rank);
	if (rank != 0)
		return 0;

	pFile = fopen("carre_tria.mesh", "w");
	if (pFile == NULL) {
		PetscPrintf(comm, "Impossible to write file: carre_tria.mesh\n");
		/* Abort the whole job: exiting on this rank alone would strand the others. */
		MPI_Abort(comm, 1);
		return 1;
	}

	/* The trailing newline keeps the dimension value and the next keyword on separate lines. */
	PetscFPrintf(comm, pFile, "MeshVersionFormatted 2 \n Dimension\n2\n");

	PetscFPrintf(comm, pFile, "Vertices\n %d \n", mesh->nPoint);
	for (i = 0; i < mesh->nPoint; i++) {
		PetscFPrintf(comm, pFile, "%lf %lf %d\n",
		             mesh->point[i].c[0], mesh->point[i].c[1], mesh->point[i].label);
	}

	PetscFPrintf(comm, pFile, "Triangles\n %d \n", mesh->nTria);
	for (i = 0; i < mesh->nTria; i++) {
		/* Vertex indices are stored 0-based in memory and written 1-based. */
		PetscFPrintf(comm, pFile, "%d %d %d %d\n",
		             mesh->tria[i].v[0] + 1, mesh->tria[i].v[1] + 1, mesh->tria[i].v[2] + 1,
		             mesh->tria[i].label);
	}

	PetscFPrintf(comm, pFile, "End\n");

	/* fclose flushes buffered output; a failure here means the file is incomplete. */
	if (fclose(pFile) != 0)
		return 1;
	return 0;
}

/*
 * Write the solution vector to the text file "carre_tria.sol".
 *
 * The distributed vector is first gathered into a vector holding all entries,
 * then rank 0 of `comm` writes a header followed by one value per mesh vertex
 * and a closing "End" keyword.
 *
 * comm     : communicator used to determine the writing rank
 * mesh     : mesh the solution lives on; only nPoint is read
 * solution : distributed vector with one entry per mesh vertex
 *
 * Returns 0 on success or a non-zero error code. If the output file cannot be
 * opened the whole job is aborted with MPI_Abort.
 */
int writeSolution(MPI_Comm comm, Mesh* mesh, Vec solution)
{
	PetscErrorCode ierr;
	int rank, i;
	Vec vout;
	VecScatter vecscat;
	double* xsol3;
	FILE *pFile;

	MPI_Comm_rank(comm, &rank);

	/* Gather the values spread over the processes into a complete copy of the vector.
	   NOTE(review): only rank 0 reads the copy, so a gather-to-rank-0 scatter would
	   probably be enough — confirm before changing. */
	ierr = VecScatterCreateToAll(solution,&vecscat,&vout);CHKERRQ(ierr);
	ierr = VecScatterBegin(vecscat,solution,vout,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
	ierr = VecScatterEnd(vecscat,solution,vout,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);

	if (rank == 0) {
		ierr = VecGetArray(vout,&xsol3);CHKERRQ(ierr);

		pFile = fopen("carre_tria.sol", "w");
		if (pFile == NULL) {
			PetscPrintf(comm, "Impossible to write file: carre_tria.sol\n");
			/* Abort the whole job: exiting on this rank alone would strand the others. */
			MPI_Abort(comm, 1);
			return 1;
		}

		PetscFPrintf(comm,pFile,"MeshVersionFormatted 2 \n Dimension\n2\nSolAtVertices %d\n1 1\n", mesh->nPoint);
		for (i = 0; i < mesh->nPoint; i++) {
			PetscFPrintf(comm,pFile,"%e \n",xsol3[i]);
		}
		PetscFPrintf(comm,pFile,"End\n");

		/* Close the stream so the data is flushed and the handle is released. */
		if (fclose(pFile) != 0) {
			PetscPrintf(comm, "Error while closing file: carre_tria.sol\n");
		}

		ierr = VecRestoreArray(vout, &xsol3);CHKERRQ(ierr);
	}

	ierr = VecScatterDestroy(&vecscat);CHKERRQ(ierr);
	ierr = VecDestroy(&vout);CHKERRQ(ierr);
	return 0;
}

/*
 * Release the vertex and triangle arrays of `mesh` and reset both pointers to
 * NULL. The counts nPoint and nTria are left untouched. Always returns 0.
 */
int deleteMesh(Mesh* mesh)
{
	/* free(NULL) is a no-op, so the pointers need no guard. */
	free(mesh->point);
	mesh->point = NULL;

	free(mesh->tria);
	mesh->tria = NULL;

	return 0;
}

/*
 * Compute the half-open index range [*first, *last) of the `total` items that
 * process `rank` out of `size` handles: equal chunks of total/size items, with
 * the last process also taking the remainder.
 */
static void localRange(int total, int rank, int size, int *first, int *last)
{
	int chunk = total / size;

	*first = rank * chunk;
	*last = (rank == size - 1) ? total : *first + chunk;
}

/*
 * Program entry point.
 *
 * Builds a triangular mesh of the unit square with step h = 0.01, assembles the
 * P1 stiffness matrix in parallel, imposes Dirichlet values by penalisation
 * (1 on the vertices labelled 1, 0 on the vertices labelled 2, 3 and 4), solves
 * the linear system with a KSP solver preconditioned by "asm", and writes the
 * mesh and the solution to disk.
 *
 * Returns 0 on success, or the PETSc error code of the first failing call.
 */
int main(int argc, char** argv)
{
	int i;

	/* PETSc + parallelism */
	int rank, size;
	PetscErrorCode ierr;
	int start, end;

	/* Matrix + vectors */
	Mat A;
	Vec rhs, x;

	/* Solver + preconditioner */
	KSP ksp;
	PC pc;
	int its;

	/* Mesh */
	Mesh mesh;

	/* The domain is [a;b]x[a;b] */
	double a, b;
	double h = 0.;	/* space step */
	int nbIntervalle = 0;

	/* Element stiffness matrix (row-major 3x3), shared by every triangle of the mesh. */
	double matElem[9] = {1.,-0.5,-0.5,-0.5,0.5,0.,-0.5,0.,0.5};
	int column[3];

	/* Penalty ("very large value") used to impose the Dirichlet conditions. */
	double TGV = 1e20;

	/* PETSc + MPI initialisation */
	ierr = PetscInitialize(&argc,&argv, PETSC_NULL, PETSC_NULL);
	if (ierr) return ierr;
	MPI_Comm_size(MPI_COMM_WORLD, &size);
	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
	PetscPrintf(MPI_COMM_SELF,"Proc : %d\n", rank);

	/* Geometry of the triangular mesh */
	a = 0.;
	b = 1.;
	h = 0.01;
	/* Round to nearest: a plain truncation of (b-a)/h can lose one interval
	   when the quotient lands just below an integer. */
	nbIntervalle = (int)((b - a)/h + 0.5);

	/* Build the mesh that supports the computation and dump it to disk */
	createMesh(&mesh,nbIntervalle,h);
	writeMesh(MPI_COMM_WORLD, &mesh);

	/* Matrix A */
	ierr = MatCreate(MPI_COMM_WORLD, &A);CHKERRQ(ierr);
	ierr = MatSetSizes(A, PETSC_DECIDE, PETSC_DECIDE, mesh.nPoint, mesh.nPoint);CHKERRQ(ierr);
	ierr = MatSetType(A, "mpiaij");CHKERRQ(ierr);
	ierr = MatSetFromOptions(A);CHKERRQ(ierr);
	ierr = MatSetUp(A);CHKERRQ(ierr);

	/* Right-hand side rhs */
	ierr = VecCreate(MPI_COMM_WORLD, &rhs);CHKERRQ(ierr);
	ierr = VecSetSizes(rhs, PETSC_DECIDE, mesh.nPoint);CHKERRQ(ierr);
	ierr = VecSetType(rhs, "mpi");CHKERRQ(ierr);
	ierr = VecSetFromOptions(rhs);CHKERRQ(ierr);
	/* Only boundary entries are set below; make the zero interior entries explicit
	   instead of relying on the initial contents of the vector. */
	ierr = VecSet(rhs, 0.);CHKERRQ(ierr);

	/* Solution vector x, with the same layout as rhs. It is created once only:
	   duplicating a second time would leak the first vector. */
	ierr = VecDuplicate(rhs, &x);CHKERRQ(ierr);

	/* Assembly of the matrix: each process handles a contiguous slice of the elements */
	localRange(mesh.nTria, rank, size, &start, &end);

	PetscPrintf(PETSC_COMM_WORLD,"Mat Elem \n");

	for (i = start; i < end; i++){
		column[0] = mesh.tria[i].v[0];
		column[1] = mesh.tria[i].v[1];
		column[2] = mesh.tria[i].v[2];
		ierr = MatSetValues(A,3,&column[0],3,&column[0],&matElem[0],ADD_VALUES);CHKERRQ(ierr);
	}

	/* The matrix must be assembled before switching from ADD_VALUES to INSERT_VALUES */
	ierr = MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
	ierr = MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);

	/* Boundary conditions:
	   value 1 on the vertices labelled 1,
	   value 0 on the other boundaries (vertices labelled 2, 3 and 4).
	   Each process handles a contiguous slice of the vertices. */
	localRange(mesh.nPoint, rank, size, &start, &end);

	for (i = start; i < end; i++){
		double bcValue;

		switch (mesh.point[i].label) {
		case 1:
			bcValue = 1.;
			break;
		case 2:
		case 3:
		case 4:
			bcValue = 0.;
			break;
		default:
			continue;	/* not a boundary vertex: nothing to impose */
		}

		/* Penalisation: overwrite the diagonal with TGV and set rhs to value*TGV */
		column[0] = i;
		ierr = MatSetValues(A,1,&column[0],1,&column[0],&TGV,INSERT_VALUES);CHKERRQ(ierr);
		ierr = VecSetValue(rhs,column[0],bcValue*TGV,INSERT_VALUES);CHKERRQ(ierr);
	}

	ierr = MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
	ierr = MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);

	ierr = VecAssemblyBegin(rhs);CHKERRQ(ierr);
	ierr = VecAssemblyEnd(rhs);CHKERRQ(ierr);

	/* Solver ksp.
	   NOTE(review): the 4-argument KSPSetOperators form with a MatStructure flag
	   appears to belong to older PETSc releases — confirm against the PETSc
	   version this file is built with. */
	ierr = KSPCreate(PETSC_COMM_WORLD,&ksp);CHKERRQ(ierr);
	ierr = KSPSetOperators(ksp,A,A,DIFFERENT_NONZERO_PATTERN);CHKERRQ(ierr);
	/* Preconditioner type; command-line options can still override it */
	ierr = KSPGetPC(ksp,&pc);CHKERRQ(ierr);
	ierr = PCSetType(pc,"asm");CHKERRQ(ierr);
	ierr = KSPSetFromOptions(ksp);CHKERRQ(ierr);

	/* Solve the linear system A*x = rhs */
	ierr = KSPSolve(ksp,rhs,x);CHKERRQ(ierr);
	ierr = KSPGetIterationNumber(ksp,&its);CHKERRQ(ierr);
	PetscPrintf(PETSC_COMM_WORLD,"its : %d\n", its);

	ierr = writeSolution(MPI_COMM_WORLD, &mesh, x);CHKERRQ(ierr);

	/* Destroy the objects */
	ierr = KSPDestroy(&ksp);CHKERRQ(ierr);
	ierr = MatDestroy(&A);CHKERRQ(ierr);
	ierr = VecDestroy(&rhs);CHKERRQ(ierr);
	ierr = VecDestroy(&x);CHKERRQ(ierr);
	deleteMesh(&mesh);
	PetscFinalize();
	return 0;
}
