// left preconditioned CG with Additive Schwarz preconditioner
// preconditioned CG with RCI in intel MKL
// for the Seminar of the Cybermedia Center, Osaka University, 21st November 2018
// "On solution of large linear system : iterative solver"
// Copyright 2018, Atsushi Suzuki : Atsushi.Suzuki@cas.cmc.osaka-u.ac.jp

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <math.h>
#include <vector>

#include <mkl.h>
#include <time.h>
#include "metis.h"

// Time stamp type used by the timing helpers below.
typedef struct timespec elapsed_t;
// Copies time stamp b into a.  The do { } while (0) wrapper makes the macro
// a single statement: the former expansion "(sec copy);(nsec copy)" left the
// tv_nsec copy outside an unbraced if-body and broke "if (...) COPYTIME(..);
// else ...".
#define COPYTIME(a, b) do {          \
    (a).tv_sec = (b).tv_sec;         \
    (a).tv_nsec = (b).tv_nsec;       \
  } while (0)

// Stores the current reading of CLOCK_MONOTONIC in *tm; elapsed times are
// obtained by subtracting two such readings (see convert_time).
// CLOCK_REALTIME was the earlier choice and is no longer used.
void get_realtime(elapsed_t *tm)
{
  static const clockid_t kTimerClock = CLOCK_MONOTONIC;
  clock_gettime(kTimerClock, tm);
}

// Returns the difference time1 - time0 in seconds, combining the whole
// second and nanosecond fields.
double convert_time(elapsed_t time1, elapsed_t time0)
{
  const double delta_sec = (double)time1.tv_sec - (double)time0.tv_sec;
  const double delta_nsec = (double)time1.tv_nsec - (double)time0.tv_nsec;
  return delta_sec + delta_nsec / 1.0e+9;
}

// Whole-second field of a time stamp, narrowed to int.
int convert_sec(elapsed_t t)
{
  const time_t whole_seconds = t.tv_sec;
  return static_cast<int>(whole_seconds);
}

// Sub-second field of a time stamp expressed in microseconds
// (tv_nsec / 1000, truncated by the conversion to int).
int convert_microsec(elapsed_t t)
{
  const double microseconds = t.tv_nsec / 1.0e+3;
  return static_cast<int>(microseconds);
}

// Compressed sparse row matrix.  Row i occupies positions
// ia[i] .. ia[i + 1] - 1 of ja (column indices) and coefs (values).
// In this file the indices are zero-based.
struct csr_matrix {
  MKL_INT nrow;   // number of rows
  MKL_INT nnz;    // number of stored entries, equal to ia[nrow]
  MKL_INT *ia;    // row pointers, nrow + 1 entries
  MKL_INT *ja;    // column indices, nnz entries
  double *coefs;  // entry values, nnz entries
};

int main(int argc, char **argv)
{
  int itmp, jtmp, ktmp;
  char fname[256], fname1[256];
  char buf[1024];
  int nrow, nnz, nnz_orig;
  int *ptrows, *indcols;
  MKL_INT *irow, *jcol;
  double *val, *coefs, *sol, *exact, *rhs, *tmps;
  int max_iter;
  double tol_eps;
  csr_matrix a;
  
  FILE *fp;
  bool isSym;
  int nexcl = 0;
  int nparts = 8;
  int layer_overlap = 1;
  if (argc < 5) {
    fprintf(stderr,
	    "%s [data file] [max_iter] [tol_eps] [nparts] [layer_overlap]\n",
	    argv[0]);
    exit(-1);
  }    
  strcpy(fname, argv[1]);
  max_iter = atoi(argv[2]);
  tol_eps = atof(argv[3]);
  nparts = atoi(argv[4]);
  layer_overlap = atoi(argv[5]);
  // read from the file
  if ((fp = fopen(fname, "r")) == NULL) {
    fprintf(stderr, "fail to open %s\n", fname);
  }
  fgets(buf, 256, fp);
  //
  if (strstr(buf, "symmetric") != NULL) {
   isSym = true;
  }
  else {
    isSym = false;
    fprintf(stderr, "%s %d : input CSR matix is not symmetic\n",
	    __FILE__, __LINE__);
    fclose(fp);
    exit(-1);
  }

  fprintf(stderr, "symmetric = %s\n", isSym ? "true " : "false");

  while (1) {
    fgets(buf, 256, fp);
    if (buf[0] != '%') {
      sscanf(buf, "%d %d %d", &itmp, &jtmp, &ktmp);
      nrow = (MKL_INT)itmp;
      nnz_orig = (MKL_INT)ktmp;
      break;
    }
  }
  nnz = isSym ? nnz_orig * 2 : nnz_orig;
  irow = new MKL_INT[nnz];
  jcol = new MKL_INT[nnz];
  val = new double[nnz];
  {
    int ii = 0;
    int itmp, jtmp;
    double vtmp;
    for (int i = 0; i < nnz_orig; i++) {
      fscanf(fp, "%d %d %lf", &itmp, &jtmp, &vtmp);
      if (vtmp != 0.0 || (itmp == jtmp)) {
	irow[ii] = itmp - 1; // zero based
	jcol[ii] = jtmp - 1; // zero based
	val[ii] = vtmp;
	ii++;
	if (isSym && (itmp != jtmp)) {
	  irow[ii] = jtmp - 1; // zero based
	  jcol[ii] = itmp - 1; // zero based
	  val[ii] = vtmp;
	  ii++;
	}
      }
    }
    fprintf(stderr, "%d\n", ii);
    nnz = ii;
  }

  fclose (fp);
  if (nnz_orig > nnz) {
    fprintf(stderr, "%s %d : zero entries excluded %d -> %d\n",
	  __FILE__, __LINE__, nnz_orig, nnz);
  }

  clock_t t0_cpu, t1_cpu, t2_cpu;
  elapsed_t t0_elapsed, t1_elapsed, t2_elapsed;

  t0_cpu = clock();
  get_realtime(&t0_elapsed);

  MKL_INT job[6] = {0};
  job[0] = 2; // COO -> CSR with increasing order in each row
  a.nrow = nrow;
  a.nnz = nnz;
  a.ia = new MKL_INT[nrow + 1];
  a.ja = new MKL_INT[nnz];
  a.coefs = new double[nnz];
  MKL_INT info;
  mkl_dcsrcoo(job, &nrow, a.coefs, a.ja, a.ia, &nnz, val, irow, jcol, &info);

#if 0
  for (int i = 0; i < nrow; i++) {
    for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
      fprintf(stderr, "%d %d %g\n", i, a.ja[k], a.coefs[k]);
    }
  }
#endif

  idx_t *xadj, *adjcy, *part;
  xadj = new idx_t[nrow + 1];
  adjcy = new idx_t[nnz - nrow];
  part = new idx_t[nrow];
  int *itmps0, *itmps1;
  itmps0 = new int[nrow];
  itmps1 = new int[nrow];
  
  for (int i = 0; i < nrow; i++) {
    itmps0[i] = 0;
    for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
      if (a.ja[k] != i) {
	itmps0[i]++;
      }
    }
  }
  xadj[0] = 0;
  for (int i = 0; i < nrow; i++) {
    xadj[i + 1] = xadj[i] + itmps0[i];
  }
  for (int i = 0; i < nrow; i++) {
    itmps0[i] = xadj[i];
  }
     
  for (int i = 0; i < nrow; i++) {
    for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
      if (a.ja[k] != i) {
	adjcy[itmps0[i]] = a.ja[k];
	itmps0[i]++;
      }
    }
  }
  int ncon = 1;
  idx_t options[METIS_NOPTIONS] = {0} ;
  int objval;
  METIS_SetDefaultOptions(options);
  options[METIS_OPTION_NUMBERING] = 0;
  options[METIS_OPTION_DBGLVL] = METIS_DBG_INFO;
  fprintf(stderr, "nparts = %d\n", nparts);
  METIS_PartGraphRecursive(&nrow, &ncon, xadj, adjcy, NULL, NULL, NULL, &nparts,
			   NULL, NULL, options, &objval, part);
#if 0
  for (int i = 0; i < nrow; i++) {
    fprintf(stderr, "%d ", part[i]);
  }
#endif
  fprintf(stderr, "\n");
  delete [] xadj;
  delete [] adjcy;
  
  int **mask, *mask_work;
  mask_work = new int[nrow * nparts];
  mask = new int*[nrow];
  for (int i = 0; i < nparts; i++) {
    mask[i] = &mask_work[i * nrow];
  }
  memset(mask_work, 0, sizeof(int) * nrow * nparts); // zero clear
  // 1st layer overlap
  for (int i = 0; i < nrow; i++) {
    mask[part[i]][i] = 1;    
  }
  for (int ll = 0; ll < layer_overlap; ll++) {
    for (int n = 0; n < nparts; n++) {
      for (int i = 0; i < nrow; i++) {
	itmps0[i] = mask[n][i];
      }
      for (int i = 0; i < nrow; i++) {
	if (itmps0[i] == 1) {
	  for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
	    mask[n][a.ja[k]] = 1;
	  }
	}
      }
    }
  }
  csr_matrix *aa;
  aa = new csr_matrix[nparts];
  std::vector<int> *submatrix_indx;
  submatrix_indx = new std::vector<int>[nparts];
  double *partition_unity;
  partition_unity = new double[nrow];
  for (int i = 0; i < nrow; i++) {
    partition_unity[i] = 0.0;
  }
  
  for (int n = 0; n < nparts; n++) {
    for (int i = 0; i < nrow; i++) {
      if (mask[n][i] == 1) {
	submatrix_indx[n].push_back(i);
	partition_unity[i] += 1.0;
      }
    }
  }
  for (int i = 0; i < nrow; i++) {
    partition_unity[i] = 1.0 / partition_unity[i];
  }
#if 0
  for (int i = 0; i < nrow; i++) {
    fprintf(stderr, "%g ", partition_unity[i]);
  }
  fprintf(stderr, "\n");
#endif
  for (int n = 0; n < nparts; n++) {
    aa[n].nrow =submatrix_indx[n].size();
    aa[n].ia = new int[aa[n].nrow + 1];
    // inverse index to generate sub sparse matrix
    for (int i = 0; i < aa[n].nrow; i++) {
      itmps1[submatrix_indx[n][i]] = i;
    }
    for (int i = 0; i < nrow; i++) {
      itmps0[i] = 0;
    }
    for (int i = 0; i < aa[n].nrow; i++) {
      int ii = submatrix_indx[n][i];
      for (int k = a.ia[ii]; k < a.ia[ii + 1]; k++) {
	if (isSym) {
	  if (mask[n][a.ja[k]] == 1 && a.ja[k] >= ii) {
	    itmps0[i]++;
	  }
	}
	else {
	  if (mask[n][a.ja[k]] == 1) {
	    itmps0[i]++;
	  }
	}
      }
    }
    aa[n].ia[0] = 0;
    for (int i = 0; i < aa[n].nrow; i++) {
      aa[n].ia[i + 1] = aa[n].ia[i] + itmps0[i];
    }
    for (int i = 0; i < aa[n].nrow; i++) {
      itmps0[i] = aa[n].ia[i];
    }
    aa[n].nnz = aa[n].ia[aa[n].nrow];
    aa[n].ja = new int[aa[n].nnz];
    aa[n].coefs = new double[aa[n].nnz];
    for (int i = 0; i < aa[n].nrow; i++) {
      int ii = submatrix_indx[n][i];
      for (int k = a.ia[ii]; k < a.ia[ii + 1]; k++) {
	if (isSym) {
	  if (mask[n][a.ja[k]] == 1 && a.ja[k] >= ii) {
	    aa[n].ja[itmps0[i]] = itmps1[a.ja[k]];
	    aa[n].coefs[itmps0[i]] = a.coefs[k];
	    itmps0[i]++;
	  }
	}
	else {
	  if (mask[n][a.ja[k]] == 1) {
	    aa[n].ja[itmps0[i]] = itmps1[a.ja[k]];
	    aa[n].coefs[itmps0[i]] = a.coefs[k];
	    itmps0[i]++;
	  }
	}
      }
    }
  }
  delete [] mask;
  delete [] mask_work;
#if 0
  for (int n = 0; n < nparts; n++) {
    fprintf(stderr,
	    "*** part = %d nrow = %d nnz = %d\n", n, aa[n].nrow, aa[n].nnz);
    for (int i = 0; i < aa[n].nrow; i++) {
      fprintf(stderr, "%d %d %g : ", i, submatrix_indx[n][i],
	      partition_unity[submatrix_indx[n][i]]);
      for (int k = aa[n].ia[i]; k < aa[n].ia[i + 1]; k++) {
	fprintf(stderr, "%d ", aa[n].ja[k]);
      }
      fprintf(stderr, "\n");
    }
  }
#endif
  MKL_INT **iparam;
  MKL_INT maxfct, mnum, phase, error, msglvl, mtype;
  maxfct = 1;
  mnum = 1;
  msglvl = 1;
  MKL_INT nrhs = 1;
  void ***pt;
  MKL_INT idum; /* Integer dummy. */
  pt = new void**[nparts];
  iparam = new MKL_INT*[nparts];
  for (int n = 0; n < nparts; n++) {
    pt[n] = new void*[64];
    iparam[n] = new MKL_INT[64];
  }
  mtype = isSym ? -2 : 1;   // real and structually symmetric
  for (int n = 0; n < nparts; n++) {
    pardisoinit(pt[n], &mtype, iparam[n]);
    iparam[n][34] = 1; // zero-based
  }
  double **x, **b;
  x = new double*[nparts];
  b = new double*[nparts];
  for (int n = 0; n < nparts; n++) {
    x[n] = new double[aa[n].nrow];
    b[n] = new double[aa[n].nrow];
  }
  
  for (int n = 0; n < nparts; n++) {
    phase = 11;
    pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
	    &aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja, &idum, &nrhs,
	    iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
    phase = 22;
    pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
	    &aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja, &idum, &nrhs,
	    iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
  }  

  
  MKL_INT RCI_request;
  MKL_INT ipar[128];
  double dpar[128];
  double *tmp;

  ipar[1] = 6;
  ipar[4] = max_iter;
  int size_tmp = 4 * nrow;
  fprintf(stderr, "%d\n", size_tmp);
  tmp = new double[size_tmp];
  dcg_init (&nrow, sol, rhs, &RCI_request, ipar, dpar, tmp);
  for (int i = 0; i < 11; i++) {
    fprintf(stderr, "%d ", (int)ipar[i]);
  }
  fprintf(stderr, "\n");
  char cvar = 'n';
  ipar[1] = 6;          // type of output for error : standard Fortran output
  ipar[4] = max_iter;
  ipar[7] = 1;   // stopping test for the maximum number of iterations
  ipar[8] = 1;   // residual stopping test
  ipar[9] = 0;
  ipar[10] = 1;  // preconditioned version of CG

  dpar[0] = tol_eps * tol_eps;
  sol = new double[nrow];
  exact = new double[nrow];
  rhs = new double[nrow];
  tmps = new double[nrow];
  for (int i = 0; i < nrow; i++) {
    exact[i] = (double)(i % 100);
  }
  mkl_cspblas_dcsrgemv(&cvar, &nrow, a.coefs, a.ia, a.ja, exact, rhs);

  
  msglvl = 0;

  for (int i = 0; i < nrow; i++) {
    sol[i] = 0.0;
  }

  dcg_check (&nrow, sol, rhs, &RCI_request, ipar, dpar, tmp);

  t1_cpu = clock();
  get_realtime(&t1_elapsed);

  int m = 0;
  while (1) {
    dcg (&nrow, sol, rhs, &RCI_request, ipar, dpar, tmp);

    if (RCI_request <= 0) {
      break;
    }
    if (RCI_request == 1) {
      fprintf(stderr, "%d %g\n", m, sqrt(dpar[4]));
	
      mkl_cspblas_dcsrgemv (&cvar, &nrow, a.coefs, a.ia, a.ja,
			    &tmp[0], &tmp[nrow]);
      m++;
    }
    
    if (RCI_request == 3) {
      for (int i = 0; i < nrow; i++) {
	tmps[i] = 0.0;
      }
      for (int n = 0; n < nparts; n++) {
	phase = 33;
	for (int i = 0; i < submatrix_indx[n].size(); i++) {
	  b[n][i] = tmp[2 * nrow + submatrix_indx[n][i]];
	}
	pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
		&aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja,
		&idum, &nrhs,
		iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
	for (int i = 0; i < submatrix_indx[n].size(); i++) {
	  int ii = submatrix_indx[n][i];
	  tmps[ii] += x[n][i]; //  * partition_unity[ii]; for symmetric operator
	}
      }
      for (int i = 0; i < nrow; i++) {
	tmp[3 * nrow + i] = tmps[i];
      }      
    }
  }
  fprintf(stderr, "\n");
  MKL_INT itercount;
  dcg_get (&nrow, sol, rhs, &RCI_request, ipar, dpar, tmp,
	       &itercount);

  t2_cpu = clock();
  get_realtime(&t2_elapsed);

  
  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t2_cpu - t0_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t2_elapsed, t0_elapsed));

  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t1_cpu - t0_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t1_elapsed, t0_elapsed));

  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t2_cpu - t1_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t2_elapsed, t1_elapsed));

  double norm0 = 0.0;
  double norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm1 += (sol[i] - exact[i]) * (sol[i] - exact[i]);
    norm0 += exact[i] * exact[i];
  }
  fprintf(stderr, "error = %g\n", sqrt(norm1 / norm0));
  mkl_cspblas_dcsrgemv(&cvar, &nrow, a.coefs, a.ia, a.ja, sol, exact);

  norm0 = 0.0;
  norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm1 += (exact[i] - rhs[i]) * (exact[i] - rhs[i]);
    norm0 += rhs[i] * rhs[i];
  }
  fprintf(stderr, "residual = %g\n", sqrt(norm1 / norm0));
}
