// left preconditioned GMRES with ILU(0) preconditioner
// FGMRES with RCI in Intel MKL, used as the non-preconditioned version applied to (QA)x=Qb
// for Seminar of Cybermedia Center, Osaka University, 21st November 2018
// "On solution of large linear system : iterative solver"
// Copyright 2018, Atsushi Suzuki : Atsushi.Suzuki@cas.cmc.osaka-u.ac.jp

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <math.h>

#include <mkl.h>

#include <time.h>

// Time stamp type used by the timing helpers in this file.
typedef struct timespec elapsed_t;

// Copy the seconds and nanoseconds fields of time stamp b into a.
// The body is wrapped in do { } while (0) so that the macro expands to a
// single statement and stays correct inside an unbraced if/else or loop.
#define COPYTIME(a, b) do { \
  (a).tv_sec = (b).tv_sec; \
  (a).tv_nsec = (b).tv_nsec; \
} while (0)

// Store the current reading of the monotonic clock in *tm.
// CLOCK_MONOTONIC is used here in place of CLOCK_REALTIME.
void get_realtime(elapsed_t *tm)
{
  const clockid_t clock_source = CLOCK_MONOTONIC;
  clock_gettime(clock_source, tm);
}

// Return the interval time1 - time0 in seconds as a double.
// The whole-second and nanosecond differences are formed separately and the
// nanosecond part is scaled by 1e-9 before being added.
double convert_time(elapsed_t time1, elapsed_t time0)
{
  const double diff_sec = (double)time1.tv_sec - (double)time0.tv_sec;
  const double diff_nsec = (double)time1.tv_nsec - (double)time0.tv_nsec;
  return diff_sec + diff_nsec / 1.0e+9;
}

// Return the seconds field of the time stamp, narrowed to int.
int convert_sec(elapsed_t t)
{
  const int whole_seconds = (int)t.tv_sec;
  return whole_seconds;
}

// Return the nanoseconds field of the time stamp expressed in microseconds.
// The division is carried out in double precision and the result is
// truncated to int.
int convert_microsec(elapsed_t t)
{
  const double microseconds = t.tv_nsec / 1.0e+3;
  return (int)microseconds;
}

// Sparse matrix held in three-array CSR form, as filled by mkl_dcsrcoo in
// main (which requests one-based indices).
struct csr_matrix {
  MKL_INT nrow;   // presumably the number of rows
  MKL_INT nnz;    // presumably the number of stored entries
  MKL_INT *ia;    // row pointer array; main allocates nrow + 1 entries
  MKL_INT *ja;    // column index of each stored entry; main allocates nnz
  double *coefs;  // value of each stored entry; main allocates nnz
};

int main(int argc, char **argv)
{
  int n, itmp, jtmp, ktmp;
  char fname[256], fname1[256];
  char buf[1024];
  MKL_INT nrow, nnz, nnz_orig;
  int *ptrows, *indcols;
  MKL_INT *irow, *jcol;
  double *val, *coefsilu, *sol, *exact, *rhs;
  int max_iter;
  double tol_eps;
  csr_matrix a;
 
  FILE *fp;
  bool isSym, upper_flag;
  int nexcl = 0;

  if (argc < 3) {
    fprintf(stderr, "%s [data file] [max_iter] [tol_eps]\n", argv[0]);
    exit(-1);
  }    
  strcpy(fname, argv[1]);
  max_iter = atoi(argv[2]);
  tol_eps = atof(argv[3]);

    // read from the file
  if ((fp = fopen(fname, "r")) == NULL) {
    fprintf(stderr, "fail to open %s\n", fname);
  }
  fgets(buf, 256, fp);
  //
  if (strstr(buf, "symmetric") != NULL) {
   isSym = true;
  }
  else {
    isSym = false;
  }

  fprintf(stderr, "symmetric = %s\n", isSym ? "true " : "false");

  while (1) {
    fgets(buf, 256, fp);
    if (buf[0] != '%') {
      sscanf(buf, "%d %d %d", &itmp, &jtmp, &ktmp);
      nrow = itmp;
      nnz_orig = ktmp;
      break;
    }
  }
  nnz = isSym ? nnz_orig * 2 : nnz_orig;
  irow = new MKL_INT[nnz];
  jcol = new MKL_INT[nnz];
  val = new double[nnz];
  {
    int ii = 0;
    int itmp, jtmp;
    double vtmp;
    for (int i = 0; i < nnz_orig; i++) {
      fscanf(fp, "%d %d %lf", &itmp, &jtmp, &vtmp);
      if (vtmp != 0.0 || (itmp == jtmp)) {
	irow[ii] = itmp; // one based
	jcol[ii] = jtmp; // one based
	val[ii] = vtmp;
	ii++;
	if (isSym && (itmp != jtmp)) {
	  irow[ii] = jtmp; // one based
	  jcol[ii] = itmp; // one based
	  val[ii] = vtmp;
	  ii++;
	}
      }
    }
    fprintf(stderr, "%d\n", ii);
    nnz = ii;
  }

  fclose (fp);
  if (nnz_orig > nnz) {
    fprintf(stderr, "%s %d : zero entries excluded %d -> %d\n",
	  __FILE__, __LINE__, nnz_orig, nnz);
  }

  clock_t t0_cpu, t1_cpu, t2_cpu;
  elapsed_t t0_elapsed, t1_elapsed, t2_elapsed;

  t0_cpu = clock();
  get_realtime(&t0_elapsed);

  MKL_INT job[6] = {0};
  job[0] = 2; // COO -> CSR with increasing order in each row
  job[1] = 1; // one based
  job[2] = 1; // one based
  a.ia = new MKL_INT[nrow + 1];
  a.ja = new MKL_INT[nnz];
  a.coefs = new double[nnz];
  coefsilu = new double[nnz];
  MKL_INT info;
  mkl_dcsrcoo(job, &nrow, a.coefs, a.ja, a.ia, &nnz, val, irow, jcol, &info);
#if 0
  for (int i = 0; i < nrow; i++) {
    for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
      fprintf(stderr, "%d %d %g\n", i, a.ja[k - 1] - 1, a.coefs[k - 1]);
    }
  }
#endif
  MKL_INT ierr;
  MKL_INT RCI_request;
  MKL_INT ipar[128];
  double dpar[128];
  double *tmp, *tmp0, *tmp1;

  ipar[1] = 6;
  ipar[4] = max_iter;
  ipar[14] = max_iter;
  int size_tmp = ((2*ipar[14]+1)*nrow+(ipar[14]*(ipar[14]+9))/2 + 1);
  fprintf(stderr, "%d\n", size_tmp);
  tmp = new double[size_tmp];
  tmp0 = new double[nrow];
  tmp1 = new double[nrow];
  
  dfgmres_init (&nrow, sol, rhs, &RCI_request, ipar, dpar, tmp);
  for (int i = 0; i < 23; i++) {
    fprintf(stderr, "%d ", (int)ipar[i]);
  }
  fprintf(stderr, "\n");
  // ILU factorization by MKL function
  dcsrilu0(&nrow, a.coefs, a.ia, a.ja, coefsilu, ipar, dpar, &ierr);
  
  char cvar = 'n';
  char uplo, transa, diag;
  ipar[1] = 6;          // type of output for error : standard Fortran output
  ipar[4] = max_iter;
  ipar[14] = max_iter;  // non-restarted version
  ipar[7] = 1;   // stopping test for the maximum number of iterations
  ipar[8] = 1;   // residual stopping test
  ipar[9] = 0;
  ipar[10] = 0;  // non preconditioned version of FGMRES == GMRES
  ipar[11] = 1;  // zero norm of the currently generated vector
  dpar[0] = tol_eps;
  sol = new double[nrow];
  exact = new double[nrow];
  rhs = new double[nrow];
  for (int i = 0; i < nrow; i++) {
    exact[i] = (double)(i % 100);
  }
  mkl_dcsrgemv(&cvar, &nrow, a.coefs, a.ia, a.ja, exact, tmp0);

  uplo = 'l';
  transa = 'n';
  diag = 'u';
  mkl_dcsrtrsv(&uplo, &transa, &diag, &nrow, coefsilu, a.ia, a.ja,
	       tmp0, tmp1);
  uplo = 'u';
  transa = 'n';
  diag = 'n';
  mkl_dcsrtrsv(&uplo, &transa, &diag, &nrow, coefsilu, a.ia, a.ja,
	       tmp1, rhs);

  for (int i = 0; i < nrow; i++) {
    sol[i] = 0.0;
  }

  dfgmres_check (&nrow, sol, rhs, &RCI_request, ipar, dpar, tmp);
  fprintf(stderr, "after check %d\n", RCI_request);

  t1_cpu = clock();
  get_realtime(&t1_elapsed);

  int m = 0;
  while (1) {
    dfgmres (&nrow, sol, rhs, &RCI_request, ipar, dpar, tmp);

    if (RCI_request <= 0) {
      break;
    }
    if (RCI_request == 1) {
      fprintf(stderr, "%d %g\n", m, dpar[4]);
      mkl_dcsrgemv(&cvar, &nrow, a.coefs, a.ia, a.ja,
		   &tmp[ipar[21] - 1],
		   tmp0);

      uplo = 'l';
      transa = 'n';
      diag = 'u';
      mkl_dcsrtrsv(&uplo, &transa, &diag, &nrow, coefsilu, a.ia, a.ja,
		   tmp0, tmp1);
      uplo = 'u';
      transa = 'n';
      diag = 'n';
      mkl_dcsrtrsv(&uplo, &transa, &diag, &nrow, coefsilu, a.ia, a.ja,
		   tmp1, &tmp[ipar[22] - 1]);

    }
    m++;
  }
  fprintf(stderr, "\n");
  MKL_INT itercount;
  dfgmres_get (&nrow, sol, rhs, &RCI_request, ipar, dpar, tmp,
	       &itercount);

  t2_cpu = clock();
  get_realtime(&t2_elapsed);

  
  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t2_cpu - t0_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t2_elapsed, t0_elapsed));

  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t1_cpu - t0_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t1_elapsed, t0_elapsed));

  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t2_cpu - t1_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t2_elapsed, t1_elapsed));

  double norm0 = 0.0;
  double norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm1 += (sol[i] - exact[i]) * (sol[i] - exact[i]);
    norm0 += exact[i] * exact[i];
  }
  fprintf(stderr, "error = %g\n", sqrt(norm1 / norm0));

  mkl_dcsrgemv(&cvar, &nrow, a.coefs, a.ia, a.ja, sol, tmp0);

  uplo = 'l';
  transa = 'n';
  diag = 'u';
  mkl_dcsrtrsv(&uplo, &transa, &diag, &nrow, coefsilu, a.ia, a.ja,
	       tmp0, tmp1);
  uplo = 'u';
  transa = 'n';
  diag = 'n';
  mkl_dcsrtrsv(&uplo, &transa, &diag, &nrow, coefsilu, a.ia, a.ja,
	       tmp1, exact);

  norm0 = 0.0;
  norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm1 += (exact[i] - rhs[i]) * (exact[i] - rhs[i]);
    norm0 += rhs[i] * rhs[i];
  }
  fprintf(stderr, "residual = %g\n", sqrt(norm1 / norm0));

}
