// left preconditioned GMRES with Restricted Additive Schwarz preconditioner
// GMRES routine is based on https://math.nist.gov/iml++/gmres.h.txt
// for the Seminar of the Cybermedia Center, Osaka University, 21st November 2018
// "On solution of large linear system : iterative solver"
// Copyright 2018, Atsushi Suzuki : Atsushi.Suzuki@cas.cmc.osaka-u.ac.jp

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <math.h>
#include <vector>
#include <map>

#include <mkl.h>

#include <mkl_cblas.h>
#include <mkl_trans.h>

#include <time.h>
#include "metis.h"

// Timestamp type used for elapsed (wall-clock) time measurement.
typedef struct timespec elapsed_t;
// Copy timestamp b into a.  The do { } while (0) wrapper makes the macro a
// single statement, so "if (c) COPYTIME(x, y);" guards both assignments and
// an "else" may follow it.  Both arguments are evaluated twice.
#define COPYTIME(a, b) do {                     \
    (a).tv_sec = (b).tv_sec;                    \
    (a).tv_nsec = (b).tv_nsec;                  \
  } while (0)

// Store the current reading of the monotonic clock in *tm.
// (CLOCK_MONOTONIC is used instead of CLOCK_REALTIME; the return value of
// clock_gettime is not checked.)
void get_realtime(elapsed_t *tm)
{
  const clockid_t source = CLOCK_MONOTONIC;
  clock_gettime(source, tm);
}

// Elapsed time in seconds from time0 to time1 (i.e. time1 - time0), with the
// whole-second and nanosecond differences combined in double precision.
double convert_time(elapsed_t time1, elapsed_t time0)
{
  const double delta_sec = static_cast<double>(time1.tv_sec)
                         - static_cast<double>(time0.tv_sec);
  const double delta_nsec = static_cast<double>(time1.tv_nsec)
                          - static_cast<double>(time0.tv_nsec);
  return delta_sec + delta_nsec / 1.0e+9;
}

// Whole-second field of timestamp t, narrowed to int.
int convert_sec(elapsed_t t)
{
  const int seconds = static_cast<int>(t.tv_sec);
  return seconds;
}

// Sub-second part of timestamp t expressed in microseconds
// (nanoseconds / 1000, truncated toward zero).
int convert_microsec(elapsed_t t)
{
  const double microseconds = static_cast<double>(t.tv_nsec) / 1.0e+3;
  return static_cast<int>(microseconds);
}

// Integer type handed to the CBLAS routines: an alias of MKL's MKL_INT.
// NOTE(review): its width presumably depends on the MKL interface layer that
// is linked (LP64 vs ILP64) -- confirm against the build configuration.
typedef MKL_INT BLAS_INT;

double blas_dot(const int n, const double* x, const int incx, 
		 const double* y, const int incy)
{
  return cblas_ddot((BLAS_INT)n, x, (BLAS_INT)incx, y, (BLAS_INT)incy);
}

void blas_copy(const int n, const double* x, const int incx, double* y,
	  const int incy)
{
  if ((incx == 1) && (incy == 1)) {
    memcpy((void *)y, (void *)x, n * sizeof(double));
  }
  else {
    cblas_dcopy((BLAS_INT)n, x, (BLAS_INT)incx, y, (BLAS_INT)incy);
  }
}

// Set y[0..n-1] to +0.0 by clearing their bytes.
// A non-positive n is a no-op; without this guard a negative n would be
// converted to a huge size_t byte count and memset would write past y.
void blas_zero(const int n, double* y)
{
  if (n <= 0) {
    return;
  }
  memset(static_cast<void *>(y), 0,
         static_cast<std::size_t>(n) * sizeof(double));
}

// Euclidean norm of the n-vector x with stride incX: the square root of the
// dot product of x with itself.
double blas_l2norm(const int n, double *x, const int incX)
{
  return sqrt(blas_dot(n, x, incX, x, incX));
}

// Sparse matrix in compressed sparse row (CSR) form, indexed from zero as
// used in this file: the entries of row i occupy positions
// ia[i] .. ia[i + 1] - 1 of ja (column indices) and coefs (values).
struct csr_matrix {
  MKL_INT nrow;    // number of rows
  MKL_INT nnz;     // number of stored entries
  MKL_INT *ia;     // row pointers, nrow + 1 of them
  MKL_INT *ja;     // column index of each stored entry
  double *coefs;   // value of each stored entry
};

void SpMV(const int nrow, const csr_matrix a, std::vector<double> &x, std::vector<double> &y)
{
  const double zero(0.0);
  for (int i = 0; i < nrow; i++) {
    y[i] = zero;
    for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
      int j = a.ja[k];
      y[i] += a.coefs[k] * x[j];
    }
  }
}

// Compute the Givens rotation (cs, sn) that, applied by ApplyPlaneRotation to
// the pair (dx, dy), annihilates dy.  dx and dy are only read.  The larger
// of |dx|, |dy| is used as the divisor so the intermediate ratio stays <= 1.
void GeneratePlaneRotation(double &dx, double &dy, double &cs, double &sn)
{
  if (dy == 0.0) {
    cs = 1.0;
    sn = 0.0;
    return;
  }
  if (fabs(dx) < fabs(dy)) {
    const double ratio = dx / dy;
    const double scale = 1.0 / sqrt(1.0 + ratio * ratio);
    sn = scale;
    cs = ratio * scale;
  } else {
    const double ratio = dy / dx;
    const double scale = 1.0 / sqrt(1.0 + ratio * ratio);
    cs = scale;
    sn = ratio * scale;
  }
}


// Rotate the pair (dx, dy) in place by the Givens rotation (cs, sn):
//   dx <-  cs * dx + sn * dy
//   dy <- -sn * dx + cs * dy
// Both results are formed from the incoming values before either is stored.
void ApplyPlaneRotation(double &dx, double &dy, double &cs, double &sn)
{
  const double rotated_x = cs * dx + sn * dy;
  const double rotated_y = cs * dy - sn * dx;
  dy = rotated_y;
  dx = rotated_x;
}

// Form the GMRES iterate: solve the (k+1) x (k+1) upper-triangular system
// H y = s by back substitution, then accumulate x += sum_{j<=k} y[j] * v[j].
//   x : iterate, updated in place
//   k : index of the last Krylov column to use; k < 0 leaves x unchanged
//   h : rows of the (rotated) Hessenberg matrix, h[row][col]; only entries
//       h[j][i] with j <= i <= k are read
//   s : right-hand side; s[0..k] is used and the caller's vector is not
//       modified (a local copy is solved in place)
//   v : basis vectors v[0..k], each at least x.size() long
void 
Update(std::vector<double> &x, int k, const std::vector<double>* h,
       const std::vector<double> &s, const std::vector<double> *v)
{
  std::vector<double> y(s);

  // Back substitution, column oriented: once y[i] is final, eliminate it
  // from the rows above.
  for (int i = k; i >= 0; i--) {
    y[i] /= h[i][i];
    for (int j = i - 1; j >= 0; j--) {
      y[j] -= h[j][i] * y[i];
    }
  }

  const std::size_t len = x.size();
  for (int j = 0; j <= k; j++) {
    for (std::size_t n = 0; n < len; n++) {
      x[n] += v[j][n] * y[j];
    }
  }
}

int main(int argc, char **argv)
{
  int itmp, jtmp, ktmp;
  char fname[256], fname1[256];
  char buf[1024];
  MKL_INT nrow, nnz, nnz_orig;
  int *ptrows, *indcols;
  MKL_INT *irow, *jcol;
  double *val, *coefs;
  int max_iter;
  double tol_eps;
  csr_matrix a;
  
  FILE *fp;
  bool isSym;
  int nexcl = 0;
  int nparts = 8;
  int layer_overlap = 1;
  
  if (argc < 5) {
    fprintf(stderr,
	    "%s [data file] [max_iter] [tol_eps] [nparts]", argv[0]);
    exit(-1);
  }    
  strcpy(fname, argv[1]);
  max_iter = atoi(argv[2]);
  tol_eps = atof(argv[3]);
  nparts = atoi(argv[4]);
  layer_overlap = atoi(argv[5]);
  
  // read from the file
  if ((fp = fopen(fname, "r")) == NULL) {
    fprintf(stderr, "fail to open %s\n", fname);
  }
  fgets(buf, 256, fp);
  //
  if (strstr(buf, "symmetric") != NULL) {
   isSym = true;
  }
  else {
    isSym = false;
  }

  fprintf(stderr, "symmetric = %s\n", isSym ? "true " : "false");

  while (1) {
    fgets(buf, 256, fp);
    if (buf[0] != '%') {
      sscanf(buf, "%d %d %d", &itmp, &jtmp, &ktmp);
      nrow = (MKL_INT)itmp;
      nnz_orig = (MKL_INT)ktmp;
      break;
    }
  }
  nnz = isSym ? nnz_orig * 2 : nnz_orig;
  irow = new MKL_INT[nnz];
  jcol = new MKL_INT[nnz];
  val = new double[nnz];
  {
    int ii = 0;
    int itmp, jtmp;
    double vtmp;
    for (int i = 0; i < nnz_orig; i++) {
      fscanf(fp, "%d %d %lf", &itmp, &jtmp, &vtmp);
      if (vtmp != 0.0 || (itmp == jtmp)) {
	irow[ii] = itmp - 1; // zero based
	jcol[ii] = jtmp - 1; // zero based
	val[ii] = vtmp;
	ii++;
	if (isSym && (itmp != jtmp)) {
	  irow[ii] = jtmp - 1; // zero based
	  jcol[ii] = itmp - 1; // zero based
	  val[ii] = vtmp;
	  ii++;
	}
      }
    }
    nnz = ii;
  }

  fclose (fp);
  if (nnz_orig > nnz) {
    fprintf(stderr, "%s %d : zero entries excluded %d -> %d\n",
	  __FILE__, __LINE__, nnz_orig, nnz);
  }

  clock_t t0_cpu, t1_cpu, t2_cpu;
  elapsed_t t0_elapsed, t1_elapsed, t2_elapsed;

  t0_cpu = clock();
  get_realtime(&t0_elapsed);

  MKL_INT job[6] = {0};
  job[0] = 2; // COO -> CSR with increasing order in each row
  a.nrow = nrow;
  a.nnz = nnz;
  a.ia = new MKL_INT[nrow + 1];
  a.ja = new MKL_INT[nnz];
  a.coefs = new double[nnz];
  MKL_INT info;
  mkl_dcsrcoo(job, &nrow, a.coefs, a.ja, a.ia, &nnz, val, irow, jcol, &info);
#if 0
  for (int i = 0; i < nrow; i++) {
    for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
      fprintf(stderr, "%d %d %g\n", i, a.ja[k], a.coefs[k]);
    }
  }
#endif

  idx_t *xadj, *adjcy, *part;
  xadj = new idx_t[nrow + 1];
  adjcy = new idx_t[nnz - nrow];
  part = new idx_t[nrow];
  int *itmps0, *itmps1;
  itmps0 = new int[nrow];
  itmps1 = new int[nrow];
  
  for (int i = 0; i < nrow; i++) {
    itmps0[i] = 0;
    for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
      if (a.ja[k] != i) {
	itmps0[i]++;
      }
    }
  }
  xadj[0] = 0;
  for (int i = 0; i < nrow; i++) {
    xadj[i + 1] = xadj[i] + itmps0[i];
  }
  for (int i = 0; i < nrow; i++) {
    itmps0[i] = xadj[i];
  }
     
  for (int i = 0; i < nrow; i++) {
    for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
      if (a.ja[k] != i) {
	adjcy[itmps0[i]] = a.ja[k];
	itmps0[i]++;
      }
    }
  }
  int ncon = 1;
  idx_t options[METIS_NOPTIONS] = {0} ;
  int objval;
  METIS_SetDefaultOptions(options);
  options[METIS_OPTION_NUMBERING] = 0;
  options[METIS_OPTION_DBGLVL] = METIS_DBG_INFO;
  METIS_PartGraphRecursive(&nrow, &ncon, xadj, adjcy, NULL, NULL, NULL, &nparts,
			   NULL, NULL, options, &objval, part);

  delete [] xadj;
  delete [] adjcy;
  
  int **mask, *mask_work;
  mask_work = new int[nrow * nparts];
  mask = new int*[nrow];
  for (int i = 0; i < nparts; i++) {
    mask[i] = &mask_work[i * nrow];
  }
  memset(mask_work, 0, sizeof(int) * nrow * nparts); // zero clear
  // 1st layer overlap
  for (int i = 0; i < nrow; i++) {
    mask[part[i]][i] = 1;    
  }
  for (int ll = 0; ll < layer_overlap; ll++) {
    for (int n = 0; n < nparts; n++) {
      for (int i = 0; i < nrow; i++) {
	itmps0[i] = mask[n][i];
      }
      for (int i = 0; i < nrow; i++) {
	if (itmps0[i] == 1) {
	  for (int k = a.ia[i]; k < a.ia[i + 1]; k++) {
	    mask[n][a.ja[k]] = 1;
	  }
	}
      }
    }
  }
  csr_matrix *aa;
  aa = new csr_matrix[nparts];
  std::vector<int> *submatrix_indx;
  submatrix_indx = new std::vector<int>[nparts];
  double *partition_unity;
  partition_unity = new double[nrow];
  for (int i = 0; i < nrow; i++) {
    partition_unity[i] = 0.0;
  }
  
  for (int n = 0; n < nparts; n++) {
    for (int i = 0; i < nrow; i++) {
      if (mask[n][i] == 1) {
	submatrix_indx[n].push_back(i);
	partition_unity[i] += 1.0;
      }
    }
  }
  for (int i = 0; i < nrow; i++) {
    partition_unity[i] = 1.0 / partition_unity[i];
  }
#if 0
  for (int i = 0; i < nrow; i++) {
    fprintf(stderr, "%g ", partition_unity[i]);
  }
  fprintf(stderr, "\n");
#endif
  for (int n = 0; n < nparts; n++) {
    aa[n].nrow =submatrix_indx[n].size();
    aa[n].ia = new int[aa[n].nrow + 1];
    // inverse index to generate sub sparse matrix
    for (int i = 0; i < aa[n].nrow; i++) {
      itmps1[submatrix_indx[n][i]] = i;
    }
    for (int i = 0; i < nrow; i++) {
      itmps0[i] = 0;
    }
    for (int i = 0; i < aa[n].nrow; i++) {
      int ii = submatrix_indx[n][i];
      for (int k = a.ia[ii]; k < a.ia[ii + 1]; k++) {
	if (isSym) {
	  if (mask[n][a.ja[k]] == 1 && a.ja[k] >= ii) {
	    itmps0[i]++;
	  }
	}
	else {
	  if (mask[n][a.ja[k]] == 1) {
	    itmps0[i]++;
	  }
	}
      }
    }
    aa[n].ia[0] = 0;
    for (int i = 0; i < aa[n].nrow; i++) {
      aa[n].ia[i + 1] = aa[n].ia[i] + itmps0[i];
    }
    for (int i = 0; i < aa[n].nrow; i++) {
      itmps0[i] = aa[n].ia[i];
    }
    aa[n].nnz = aa[n].ia[aa[n].nrow];
    aa[n].ja = new int[aa[n].nnz];
    aa[n].coefs = new double[aa[n].nnz];
    for (int i = 0; i < aa[n].nrow; i++) {
      int ii = submatrix_indx[n][i];
      for (int k = a.ia[ii]; k < a.ia[ii + 1]; k++) {
	if (isSym) {
	  if (mask[n][a.ja[k]] == 1 && a.ja[k] >= ii) {
	    aa[n].ja[itmps0[i]] = itmps1[a.ja[k]];
	    aa[n].coefs[itmps0[i]] = a.coefs[k];
	    itmps0[i]++;
	  }
	}
	else {
	  if (mask[n][a.ja[k]] == 1) {
	    aa[n].ja[itmps0[i]] = itmps1[a.ja[k]];
	    aa[n].coefs[itmps0[i]] = a.coefs[k];
	    itmps0[i]++;
	  }
	}
      }
    }
  }
  delete [] mask;
  delete [] mask_work;
  fprintf(stderr, "%s %d : nparts = %d : ", __FILE__, __LINE__, nparts);
  for (int n = 0; n < nparts; n++) {
    fprintf(stderr, "%d ", aa[n].nrow);
  }
  fprintf(stderr, "\n");
  
#if 0
  for (int n = 0; n < nparts; n++) {
    fprintf(stderr,
	    "*** part = %d nrow = %d nnz = %d\n", n, aa[n].nrow, aa[n].nnz);
    for (int i = 0; i < aa[n].nrow; i++) {
      fprintf(stderr, "%d %d %g : ", i, submatrix_indx[n][i],
	      partition_unity[submatrix_indx[n][i]]);
      for (int k = aa[n].ia[i]; k < aa[n].ia[i + 1]; k++) {
	fprintf(stderr, "%d ", aa[n].ja[k]);
      }
      fprintf(stderr, "\n");
    }
  }
#endif
  MKL_INT **iparam;
  MKL_INT maxfct, mnum, phase, error, msglvl, mtype;
  maxfct = 1;
  mnum = 1;
  msglvl = 1;

  MKL_INT nrhs = 1;
  void ***pt;
  MKL_INT idum; /* Integer dummy. */
  pt = new void**[nparts];
  iparam = new MKL_INT*[nparts];
  for (int n = 0; n < nparts; n++) {
    pt[n] = new void*[64];
    iparam[n] = new MKL_INT[64];
  }
  mtype = isSym ? -2 : 11;   // real and structually symmetric
  for (int n = 0; n < nparts; n++) {
    pardisoinit(pt[n], &mtype, iparam[n]);
    iparam[n][34] = 1; // zero-based
  }
  double **x, **b;
  x = new double*[nparts];
  b = new double*[nparts];
  for (int n = 0; n < nparts; n++) {
    x[n] = new double[aa[n].nrow];
    b[n] = new double[aa[n].nrow];
  }
  
  for (int n = 0; n < nparts; n++) {
    phase = 11;
    pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
	    &aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja, &idum, &nrhs,
	    iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
    phase = 22;
    pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
	    &aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja, &idum, &nrhs,
	    iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
  }  

  
  MKL_INT RCI_request;
  MKL_INT ipar[128];
  double dpar[128];

  ipar[1] = 6;
  ipar[4] = max_iter;
  ipar[14] = max_iter;

  std::vector<double> s, cs(max_iter + 1), sn(max_iter + 1), w(nrow);
  std::vector<double> * v = new std::vector<double>[max_iter + 1];
  std::vector<double> sol(nrow), rhs(nrow), exact(nrow), tmps(nrow);
  for (int i = 0; i < max_iter + 1; i++) {
    v[i].resize(nrow);
  }
  std::vector<double> *Hessenberg;
  Hessenberg = new std::vector<double>[max_iter + 1];
  for (int i = 0; i < max_iter + 1; i++) {
    Hessenberg[i].resize(max_iter);
  }

  char cvar = 'n';
  ipar[1] = 6;
  ipar[4] = max_iter;
  ipar[14] = max_iter;
  ipar[7] = 1;
  ipar[8] = 1;
  ipar[9] = 0;
  ipar[11] = 1;
  dpar[0] = tol_eps;

  for (int i = 0; i < nrow; i++) {
    exact[i] = (double)(i % 100);
  }

  SpMV(nrow, a, exact, rhs);
				 
  msglvl = 0;
  blas_zero(nrow, &tmps[0]);
  for (int n = 0; n < nparts; n++) {
    phase = 33;
    for (int i = 0; i < submatrix_indx[n].size(); i++) {
      b[n][i] = rhs[submatrix_indx[n][i]];
    }
    pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
	    &aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja, &idum, &nrhs,
	    iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
    for (int i = 0; i < submatrix_indx[n].size(); i++) {
      int ii = submatrix_indx[n][i];
      tmps[ii] += x[n][i] * partition_unity[ii];
    }
  }

  blas_copy(nrow, &tmps[0], 1, &rhs[0], 1);

  for (int i = 0; i < nrow; i++) {
    sol[i] = 0.0;
  }

  t1_cpu = clock();
  get_realtime(&t1_elapsed);

  double beta = blas_l2norm(nrow, &rhs[0], 1);
  std::vector<double> r(nrow);
  blas_copy(nrow, &rhs[0], 1, &r[0], 1);
  for (int n = 0; n < nrow; n++) {
    v[0][n] = r[n] / beta;
  }
  s.resize(max_iter + 1, 0.0);
  s[0] = beta;
  fprintf(stderr, "beta = %.12e\n", beta);
  t1_cpu = clock();
  get_realtime(&t1_elapsed);

  int m;
  bool flag_conv = false;
  for (m = 0; m < max_iter; m++){
    SpMV(nrow, a, v[m], w);


    blas_zero(nrow, &tmps[0]);

    for (int n = 0; n < nparts; n++) {
      phase = 33;
      for (int i = 0; i < submatrix_indx[n].size(); i++) {
	b[n][i] = w[submatrix_indx[n][i]];
      }
      pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
	      &aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja,
	      &idum, &nrhs,
	      iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
      for (int i = 0; i < submatrix_indx[n].size(); i++) {
	int ii = submatrix_indx[n][i];
	tmps[ii] += x[n][i] * partition_unity[ii];
      }
    }
    blas_copy(nrow, &tmps[0], 1, &w[0], 1);

    // Arnoldi process by modified Gram-Schmidt
    for (int k = 0; k <= m; k++) {
      Hessenberg[k][m] = blas_dot(nrow, &w[0], 1, &v[k][0], 1);
      for (int i = 0; i < nrow; i++) {
	w[i] -= Hessenberg[k][m] * v[k][i];
      }
    }
    double htmp = blas_l2norm(nrow, &w[0], 1);
    Hessenberg[m + 1][m] = htmp;

    htmp = 1.0 / htmp;
    for (int i = 0; i < nrow; i++) {
      v[m + 1][i] = w[i] * htmp;
    }
    for (int k = 0; k < m; k++) {
      ApplyPlaneRotation(Hessenberg[k][m], Hessenberg[k + 1][m],
			 cs[k], sn[k]);
    }
    
    GeneratePlaneRotation(Hessenberg[m][m], Hessenberg[m +1][m],
				  cs[m], sn[m]);
    ApplyPlaneRotation(Hessenberg[m][m], Hessenberg[m + 1][m],
			       cs[m], sn[m]);
    ApplyPlaneRotation(s[m], s[m + 1], cs[m], sn[m]);
    fprintf(stderr, "%d %.12e\n", m, fabs(s[m + 1]));
    if ( (fabs(s[m + 1]) / beta ) < tol_eps) {
      flag_conv = true;
      break;
    }
  } // loop : m
  fprintf(stderr, "\n");

  Update(sol, flag_conv ? m : (m - 1), Hessenberg, s, v);

  t2_cpu = clock();
  get_realtime(&t2_elapsed);

  
  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t2_cpu - t0_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t2_elapsed, t0_elapsed));

  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t1_cpu - t0_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t1_elapsed, t0_elapsed));

  fprintf(stderr, "%s %d %.4e %.4e\n",  __FILE__, __LINE__,
	    (double)(t2_cpu - t1_cpu) / (double)CLOCKS_PER_SEC,
	    convert_time(t2_elapsed, t1_elapsed));

  double norm0 = 0.0;
  double norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm1 += (sol[i] - exact[i]) * (sol[i] - exact[i]);
    norm0 += exact[i] * exact[i];
  }
  fprintf(stderr, "error = %g\n", sqrt(norm1 / norm0));

  SpMV(nrow, a, sol, exact);
  
  for (int i = 0; i < nrow; i++) {
    tmps[i] = 0.0;
  }

  for (int n = 0; n < nparts; n++) {
    phase = 33;
    for (int i = 0; i < submatrix_indx[n].size(); i++) {
      b[n][i] = exact[submatrix_indx[n][i]];
    }
    pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
	    &aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja,
	    &idum, &nrhs,
	    iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
    for (int i = 0; i < submatrix_indx[n].size(); i++) {
      int ii = submatrix_indx[n][i];
	  tmps[ii] += x[n][i] * partition_unity[ii];
    }
  }
  blas_copy(nrow, &tmps[0], 1, &exact[0], 1);
  //  for (int i = 0; i < nrow; i++) {
  //    exact[i] = tmps[i];
  //  }      
  norm0 = 0.0;
  norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm1 += (exact[i] - rhs[i]) * (exact[i] - rhs[i]);
    norm0 += rhs[i] * rhs[i];
  }
  fprintf(stderr, "residual = %.12e\n", sqrt(norm1 / norm0));
  for (int n = 0; n < nparts; n++) {
    phase = -1;
    pardiso(pt[n], &maxfct, &mnum, &mtype, &phase,
	    &aa[n].nrow, (void *)aa[n].coefs, aa[n].ia, aa[n].ja, &idum, &nrhs,
	    iparam[n], &msglvl, (void *)b[n], (void *)x[n], &error);
  }
}
