/* 
 * Copyright 2008-2009 CAPS entreprise. All rights reserved.
 */

#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h>

#define TITLE "sgemm4_guard"
#define PLOT "replot"

#include "util.h"
#include "sgemm.h"

/*
 * Guarded HMPP codelets for sgemm().
 *
 * Both codelet directives below are attached to the same sgemm() prototype
 * and differ only in their label and in their guard:
 *   - sgemm2 carries cond="size<130";
 *   - sgemm3 carries cond="size>=130".
 * The two guard expressions are complementary, so exactly one of them holds
 * for any given value of size.
 *
 * Clauses common to both directives:
 *   - target=CUDA;
 *   - vin1, vin2 and vout are declared mirror arguments with
 *     transfer=manual (main() issues the matching allocate, advancedload,
 *     delegatedstore and free directives);
 *   - alpha and beta use transfer=atfirstcall;
 *   - n uses transfer=atcall.
 *
 * NOTE(review): "size" in the guards is presumably the variable of that name
 * in scope at the callsite in main() -- confirm against the HMPP reference.
 * The definition of sgemm() is not visible in this file.
 */
#pragma hmpp sgemm2 codelet, target=CUDA, args[vin1;vin2;vout].mirror, args[vin1;vin2;vout].transfer=manual, args[alpha;beta].transfer=atfirstcall, args[n].transfer=atcall, cond="size<130"
#pragma hmpp sgemm3 codelet, target=CUDA, args[vin1;vin2;vout].mirror, args[vin1;vin2;vout].transfer=manual, args[alpha;beta].transfer=atfirstcall, args[n].transfer=atcall, cond="size>=130"
void sgemm( int n, float alpha, const float vin1[n][n], const float vin2[n][n], float beta, float vout[n][n] );

/*
 * Benchmark driver for the guarded sgemm codelets.
 *
 * Usage: <program> <seed> <from> <to>
 *
 * For every size returned by getSizes(from, to) (the list is terminated by a
 * 0 entry), sgemm() is called twice through the sgemm2/sgemm3 callsites and
 * the shortest wallclock() interval is kept.  One line per size is written to
 * TITLE".dat" (size, NB_FLOP(size) / best time) and a few elements of vout
 * are echoed to stdout.  Finally the gnuplot script TITLE".gp" is generated.
 *
 * Returns 0 on success and 1 on any failure (bad argument count, data file
 * cannot be opened or flushed, init()/getSizes()/printGnuplotFile() failure).
 * Every exit path after the data file has been opened closes it, and every
 * exit path after a successful init() releases the matrices.
 */
int main(int argc, char **argv) {
  const int nb_runs = 2;   /* timed repetitions per size; the best one is kept */
  int status = 1;          /* pessimistic: set to 0 only after full success */
  FILE *data_file = NULL;
  int *sizes = NULL;
  float alpha, beta, *vin1 = NULL, *vin2 = NULL, *vout = NULL;
  int i;

  if( argc != 4 ) {
    printf( "usage: %s <seed> <from> <to>\n", argv[0] );
    return 1;
  }

  data_file = fopen( TITLE".dat", "w" );
  if( ! data_file ) {
    perror( TITLE".dat" );
    return 1;
  }

  int seed = atoi( argv[1] );
  int range_from = atoi( argv[2] );
  int range_to = atoi( argv[3] );

  if( ! init( range_to, seed, &alpha, &beta, &vin1, &vin2, &vout ) ) {
    printf( "Initialization failed.\n" );
    /* The state of vin1/vin2/vout after a failed init() is not visible from
     * here, so they are deliberately not freed; only the file is closed. */
    goto close_file;
  }

  sizes = getSizes( range_from, range_to );
  if( ! sizes ) {
    /* Defensive: whether getSizes() can return NULL is not visible here. */
    printf( "Size list generation failed.\n" );
    goto cleanup;
  }

  for( i = 0 ; sizes[i] != 0 ; i++ ) {
    int size = sizes[i];
    double best = 0;

    #pragma hmpp sgemm2 allocate, data["vin1";"vin2";"vout"], size={size*size}
    #pragma hmpp sgemm2 advancedload, data["vin1";"vin2";"vout"]

    #pragma hmpp sgemm3 allocate, data["vin1";"vin2";"vout"], size={size*size}
    #pragma hmpp sgemm3 advancedload, data["vin1";"vin2";"vout"]

    int j;
    for( j = 0 ; j < nb_runs ; j++ ) {
      double current, t0, t1;

      t0 = wallclock();

      #pragma hmpp sgemm2 callsite
      #pragma hmpp sgemm3 callsite
      sgemm( size, alpha, vin1, vin2, beta, vout );

      t1 = wallclock();
      current = t1 - t0;

      /* Key on the run index rather than on best == 0, so that a genuine
       * zero-length measurement is not mistaken for "not measured yet". */
      if( j == 0 || current < best )
        best = current;
    }

    #pragma hmpp sgemm2 delegatedstore, data["vout"]
    #pragma hmpp sgemm3 delegatedstore, data["vout"]

    #pragma hmpp sgemm2 free, data["vin1";"vin2";"vout"]
    #pragma hmpp sgemm3 free, data["vin1";"vin2";"vout"]

    fprintf( data_file, "%8d %10lf\n", size,  (double)NB_FLOP((long long)size) / (double)best );

    if( size >= 2 ) {
      printf("[%4d x %4d] %12f %12f (...) %12f %12f \n", size, size, vout[0], vout[1], vout[size*(size-1) + size-2], vout[size*(size-1) + size-1]);
    } else {
      /* A 1x1 result has a single element: reading vout[1] or
       * vout[size*(size-1) + size-2] (index -1) would be out of bounds. */
      printf("[%4d x %4d] %12f \n", size, size, vout[0]);
    }
  }

  /* fclose() flushes the buffered results; a failure here means the data
   * file is incomplete, so report it instead of ignoring it. */
  {
    int close_rc = fclose( data_file );
    data_file = NULL;
    if( close_rc != 0 ) {
      perror( TITLE".dat" );
      goto cleanup;
    }
  }

  if( ! printGnuplotFile( TITLE".gp", TITLE, PLOT ) )
    goto cleanup;

  status = 0;

cleanup:
  free( sizes );
  free( vin1 );
  free( vin2 );
  free( vout );

close_file:
  if( data_file )
    fclose( data_file );

  return status;
}
