#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fcntl.h>
#include <unistd.h>
#include <math.h>

#include "elapsed_time.hpp"

#include "mkl_types.h"
#include "mkl_pardiso.h"
#include "mkl_service.h"

#include <iostream>
#include <sstream>
#include <fstream>
#include <list>
#include <complex>
using namespace std;

// Inserts the entry (col, v) into a single row whose column indices are kept
// in ascending order.  cols and vals are parallel lists of the same length.
// Returns true when a new entry was stored, false when `col` was already
// present (in that case the stored value is left untouched).
template<typename T>
static bool insert_sorted_entry(std::list<int> &cols, std::list<T> &vals,
                                int col, const T &v)
{
  // Fast path: empty row or strictly larger than every stored column.
  // The empty() test must come first: back() on an empty list is undefined.
  if (cols.empty() || cols.back() < col) {
    cols.push_back(col);
    vals.push_back(v);
    return true;
  }
  std::list<int>::iterator it = cols.begin();
  typename std::list<T>::iterator iv = vals.begin();
  for ( ; it != cols.end(); ++it, ++iv) {
    if (*it == col) {
      return false;
    }
    if (*it > col) {
      cols.insert(it, col);
      vals.insert(iv, v);
      return true;
    }
  }
  return false; // not reached: back() >= col guarantees a hit in the loop
}

// Builds per-row sorted lists of (column, value) pairs from coordinate data.
//
//   ind_cols_tmp, val_tmp : arrays of nrow lists (parallel); entries are
//                           inserted so that each row stays sorted by column
//   nrow                  : number of rows (= length of the two list arrays)
//   nnz                   : in  : number of entries in irow/jcol/val
//                           out : number of entries actually stored in the
//                                 lists (duplicates and out-of-range entries
//                                 are not counted, zeros added by the
//                                 symmetrization are)
//   irow, jcol, val       : zero-based coordinate entries
//   symmetrize            : when true, for every stored (i, j) a zero-valued
//                           (j, i) is added if it is missing
//
// Returns true when the symmetrization had to add at least one entry.
// A repeated (row, col) pair keeps the first value and is reported on stderr.
template<typename T>
bool generate_CSR(std::list<int>* ind_cols_tmp, std::list<T>* val_tmp, 
                  int nrow, int *nnz, 
                  int *irow, int *jcol, T* val, bool symmetrize)
{
  bool flag_modified = false;
  const T zero(0.0);
  const int nnz_in = *nnz;
  int nnz1 = 0;   // number of entries really stored in the lists

  for (int i = 0; i < nnz_in; i++) {
    const int ii = irow[i];
    const int jj = jcol[i];
    // Row indices address the list arrays; with symmetrization the column
    // index is used as a row index as well, so it must be in range too.
    if (ii < 0 || ii >= nrow || jj < 0 || (symmetrize && jj >= nrow)) {
      fprintf(stderr, "index out of range, skipped (%d %d)\n", ii, jj);
      continue;
    }
    if (insert_sorted_entry(ind_cols_tmp[ii], val_tmp[ii], jj, val[i])) {
      nnz1++;
    }
    else {
      fprintf(stderr, "already exits? (%d %d)\n", ii, jj);
    }
  }

  // symmetrize : add an explicit zero at (jj, i) for every stored (i, jj)
  if (symmetrize) {
    for (int i = 0; i < nrow; i++) {
      for (std::list<int>::iterator jt = ind_cols_tmp[i].begin();
           jt != ind_cols_tmp[i].end(); ++jt) {
        const int jj = (*jt);
        // For jj == i the entry is found and nothing is inserted, so the
        // list being iterated is never modified through this call.
        if (insert_sorted_entry(ind_cols_tmp[jj], val_tmp[jj], i, zero)) {
          flag_modified = true;
          nnz1++;
        }
      }
    }
  }

  *nnz = nnz1;
  return flag_modified;
}

template
bool generate_CSR<double>(std::list<int>* ind_cols_tmp,
                          std::list<double>* val_tmp, 
                          int nrow, int *nnz, 
                          int *irow, int *jcol, double* val, bool symmetrize);

// Flattens the per-row sorted lists into CSR arrays.
//
//   ptrows   : out, nrow + 1 row pointers (zero-based)
//   indcols  : out, column indices
//   coefs    : out, values
//   nrow     : number of rows
//   upper_flag : rows are expected to start with their diagonal entry
//                (upper triangular storage); when the first stored column is
//                not the diagonal, or the row is empty, an explicit zero
//                diagonal is written in front of the row
//   isSym    : only used when upper_flag is false: rows are then expected to
//              end with their diagonal entry and a zero diagonal is appended
//              when it is missing; with isSym false rows are copied as is
//   ind_cols_tmp, val_tmp : nrow parallel lists of columns / values
//
// Up to one diagonal entry may be added per row, hence the caller has to
// provide indcols / coefs with room for (stored entries + nrow) elements.
template<typename T>
void copy_CSR(int *ptrows, int *indcols, T* coefs, int nrow, 
              bool upper_flag, bool isSym,
              std::list<int>* ind_cols_tmp, std::list<T>* val_tmp)
{
  const T zero(0.0);
  ptrows[0] = 0;
  for (int i = 0; i < nrow; i++) {
    std::list<int> &cols = ind_cols_tmp[i];
    const int itmp = (int)cols.size();
    int k = ptrows[i];       // position where the stored entries start
    if (upper_flag) {
      // front() is only valid on a non-empty list
      if (!cols.empty() && cols.front() == i) {
        ptrows[i + 1] = ptrows[i] + itmp;
      }
      else {
        fprintf(stderr, "zero is added to diagonal : %d\n", i);
        ptrows[i + 1] = ptrows[i] + itmp + 1;
        indcols[k] = i;
        coefs[k] = zero;
        k++;                 // stored entries follow the new diagonal
      }
    }
    else {
      // !isSym is tested first so that back() is never evaluated when its
      // result is irrelevant, and never on an empty list
      if (!isSym || (!cols.empty() && cols.back() == i)) {
        ptrows[i + 1] = ptrows[i] + itmp;
      }
      else {
        fprintf(stderr, "zero is added to diagonal : %d\n", i);
        ptrows[i + 1] = ptrows[i] + itmp + 1;
        indcols[ptrows[i + 1] - 1] = i;
        coefs[ptrows[i + 1] - 1] = zero;
      }
    }
    std::list<int>::iterator it = cols.begin();
    typename std::list<T>::iterator iv = val_tmp[i].begin();
    for ( ; it != cols.end(); ++it, ++iv, k++) {
      indcols[k] = *it;
      coefs[k] = *iv;
    }
  } // loop : i
}

template
void copy_CSR<double>(int *ptrows, int *indcols, double* coefs, int nrow, 
              bool upper_flag, bool isSym,
              std::list<int>* ind_cols_tmp, std::list<double>* val_tmp);

// Returns the larger of the two arguments; b is returned whenever the
// comparison a > b does not hold.
inline
double SpMAX(double a, double b)
{
  if (a > b) {
    return a;
  }
  return b;
}

// Sparse matrix-vector product y = A x for a zero-based CSR matrix.
//
//   y       : out, result vector of length nrow (must not alias x)
//   x       : in, vector of length nrow
//   nrow    : number of rows
//   ptrows, indcols, coefs : CSR row pointers, column indices and values
//   isSym   : false : the arrays hold every entry of A
//             true  : the arrays hold one triangle of a symmetric matrix
//                     (each off-diagonal pair stored once); every stored
//                     off-diagonal entry (i, j) is also applied as (j, i)
//
// The symmetric case used to be an empty branch that left y untouched.
void SpMV(double *y,  double *x, int nrow,
          int *ptrows, int *indcols, double *coefs, bool isSym)
{
  // y receives scattered contributions in the symmetric case, so it has to
  // be cleared completely before any row is processed.
  for (int i = 0; i < nrow; i++) {
    y[i] = 0.0;
  }
  for (int i = 0; i < nrow; i++) {
    for (int k = ptrows[i]; k < ptrows[i + 1]; k++) {
      const int j = indcols[k];
      y[i] += coefs[k] * x[j];
      if (isSym && j != i) {
        y[j] += coefs[k] * x[i];   // mirrored entry of the other triangle
      }
    }
  }
}

int main(int argc, char **argv)
{
  int n, itmp, jtmp;
  char fname[256], fname1[256], fname2[256];
  char buf[1024];
  int nrow, nnz, flag, nnz1;
  int *ptrows, *indcols;
  int *irow, *jcol;
  double *val, *coefs;
  int decomposer;
  int num_threads;
  double eps_pivot;
  int numlevels = -1;
  int minNodes = 128;
  std::list<int>* ind_cols_tmp;
  std::list<double>* val_tmp;
  FILE *fp0, *fp;
  bool isSym;
  bool upper_flag = true;
  bool kernel_detection_all = false;
  bool rhs_data = false;
  bool flag_modified = false;
  int perturb = -1;
  if (argc < 2) {
    fprintf(stderr, 
	    "MM-Pardiso [data file] [num_threads]\n");
    exit(-1);
  }    
  strcpy(fname, argv[1]);

  num_threads = atoi(argv[2]);

  if (argc >= 4) {
    perturb = atoi(argv[3]);
  }

  if (argc >= 5) {
    strcpy(fname1, argv[4]);
    rhs_data = true;
  }

  // read from the file
  if ((fp = fopen(fname, "r")) == NULL) {
    fprintf(stderr, "fail to open %s\n", fname);
  }
  fgets(buf, 256, fp);
  //
  if (strstr(buf, "symmetric") != NULL) {
   isSym = true;
  }
  else {
    isSym = false;
    upper_flag = false;
  }

  fprintf(stderr, "symmetric = %s\n", isSym ? "true " : "false");
  while (1) {
    fgets(buf, 256, fp);
    if (buf[0] != '%') {
      sscanf(buf, "%d %d %d", &nrow, &itmp, &nnz);
      break;
    }
  }
  irow = new int[nnz];
  jcol = new int[nnz];
  nnz1 = 0;
  {
    val = new double[nnz];
    if (upper_flag) {
      for (int i = 0; i < nnz; i++) {
	fscanf(fp, "%d %d %lf", &jcol[i], &irow[i], &val[i]); // read lower
	irow[i]--;
	jcol[i]--;
	if (isSym && irow[i] > jcol[i]) {
	  //	  fprintf(stderr, "exchanged : %d > %d\n", irow[i], jcol[i]);
	  itmp = irow[i];
	  irow[i] = jcol[i];
	  jcol[i] = itmp;
	}
      }
    }
    else {
      int ii = 0;
      int itmp, jtmp;
      double vtmp;
      nnz1 = 0;
      for (int i = 0; i < nnz; i++) {
	fscanf(fp, "%d %d %lf", &itmp, &jtmp, &vtmp);
//	if (vtmp != 0.0) {
	if (true) {
	  irow[ii] = itmp - 1;
	  jcol[ii] = jtmp - 1;
	  val[ii] = vtmp;
	  ii++;
	}
	else {
	  nnz1++;
	}
      }
    }
  }
  fclose (fp);
  if (nnz1 > 0) {
    fprintf(stderr, "%s %d : %d zero entries excluded %d -> %d\n",
	    __FILE__, __LINE__, nnz1, nnz, (nnz - nnz1));
  }
  nnz = nnz - nnz1;
  
  ind_cols_tmp = new std::list<int>[nrow];
  fprintf(stderr, "%s %d : getnerate_CSR\n", __FILE__, __LINE__);
  {
    val_tmp = new std::list<double>[nrow];
    nnz1 = nnz;
    flag_modified = generate_CSR<double>(ind_cols_tmp, val_tmp, 
					 nrow, &nnz1,
					 irow, jcol, val, true);
  }
  if (flag_modified) {
    fprintf(stderr, "%s %d : matrix is not structual symmetric %d ->%d\n",
	    __FILE__, __LINE__, nnz, nnz1);
  }
    nnz = nnz1;
  delete [] irow;
  delete [] jcol;
  delete [] val;

  fprintf(stderr, "%s %d : copy_CSR\n", __FILE__, __LINE__);
  ptrows = new int[nrow + 1];
  indcols = new int[nnz];
  {
    coefs = new double[nnz];
    copy_CSR<double>( ptrows, indcols, coefs, 
		     nrow, upper_flag, isSym,
		     ind_cols_tmp, val_tmp);
  }
  delete [] ind_cols_tmp;

  delete [] val_tmp;

#if 0
  if ((fp = fopen("debug.matrix.data", "w")) != NULL) {
    for (int i = 0; i < nrow; i++) {
      fprintf(fp, "%d : %d :: ", i, (ptrows[i + 1] - ptrows[i]));
      for (int k = ptrows[i]; k < ptrows[i + 1]; k++) {
	fprintf(fp, "%d ", indcols[k]);
      }
      fprintf(fp, "\n");
    }
  }
  fclose(fp);
#endif

  clock_t t0_cpu, t1_cpu, t2_cpu, t3_cpu, t4_cpu, t5_cpu;
  elapsed_t t0_elapsed, t1_elapsed, t2_elapsed, t3_elapsed;
  elapsed_t t4_elapsed, t5_elapsed;

  mkl_set_num_threads(num_threads);
  t0_cpu = clock();
  
  
  double *x = new double[nrow];
  double *y = new double[nrow];
  double *w = new double[nrow];
  double *z = new double[nrow];
  if (rhs_data) {
    fprintf(stderr, "%s %d : rhs = %s\n", __FILE__, __LINE__, fname1);
    int itmp0, itmp1, itmp2;
    double vtmp;
    if ((fp = fopen(fname1, "r")) == NULL) {
      fprintf(stderr, "fail to open %s\n", fname);
    }
    fgets(buf, 256, fp);
    while (1) {
      fgets(buf, 256, fp);
      if (buf[0] != '%') {
	sscanf(buf, "%d %d %d", &itmp0, &itmp1, &itmp2);
	if (itmp0 != nrow) {
	  fprintf(stderr, "%s %d : size mismatched %d %d\n",
		  __FILE__, __LINE__, nrow, itmp0);
	  exit(-1);
	}
	break;
      }
    } // while(1)
    for (int i = 0; i < nrow; i++) {
      fscanf(fp, "%d %d %lf", &itmp0, &itmp1, &vtmp);
      y[(itmp0 - 1)] = vtmp;
    }
  }     // if (rhs_data)
  else {
    for (int i = 0; i < nrow; i++) {
      y[i] = (double)(i % 11);
    }
    SpMV(z, y, nrow, ptrows, indcols, coefs, isSym);
    SpMV(y, z, nrow, ptrows, indcols, coefs, isSym);
  }
  mkl_set_num_threads(num_threads);
  { 
    char mklversion[256];
    mkl_get_version_string(mklversion, 256);
    fprintf(stderr, "%s\n", mklversion);
  }
  //  struct timespec ts0, ts1;
  
  MKL_INT mtype; 
  if (isSym) {
    mtype = (-2); /* Real symmetric matrix */
  }
  else {
    mtype = 11; // Real and structually symmetric
  }
  fprintf(stderr, "%s %d : mtype = %d\n", __FILE__, __LINE__, (int)mtype);
  /* RHS and solution vectors. */
  /* Internal solver memory pointer pt, */
  /* 32-bit: int pt[64]; 64-bit: long int pt[64] */
  /* or void *pt[64] should be OK on both architectures */
  
  /* Pardiso control parameters. */
  MKL_INT *iparm = new MKL_INT[64];
  MKL_INT maxfct, mnum, phase, error, msglvl;
  /* Auxiliary variables. */
  //  double ddum; /* Double dummy */
  //  MKL_INT idum; /* Integer dummy. */
  MKL_INT nrhs = 1;
  void *pt[64];
  MKL_INT idum; /* Integer dummy. */
  /* -------------------------------------------------------------------- */
  /* .. Setup Pardiso control parameters. */
  /* -------------------------------------------------------------------- */
  for (int i = 0; i < 64; i++) {
    iparm[i] = 0;
  }
  iparm[0] = 1; // do not use default vaules 
  iparm[1] = 2; // fill-in reducing ordering : 2 = METIS, 3 : nested dissection
  iparm[3] = 0; // preconditioned CGS
  iparm[4] = 0; // user permutation
  iparm[5] = 0; // write solution on x
  iparm[7] = 3; // iterative refinement step
  iparm[9] = perturb > 0 ? perturb : (isSym ? 8 : 13) ; // pivoting pertubation : 10^-8 = default value for mtype=-2
  iparm[10] = (isSym ? 0 : 1); // scaling vectors : default value for mtype=-2
  iparm[11] = 0; // solving with transposed matrix
  iparm[12] = (isSym ? 0 : 1); // improved accuracy using symmetic weighted watchings
  iparm[17] = -1; // Output: numbers of non-zeros elements in the fartors
  iparm[18] = -1; // Output: MFLOPS of factorization
  iparm[20] = 1; // pivoting for symmetric indefinite matrices 1x1 and 2x2
  iparm[23] = 0; // parallel fartocization control : 1 two-level scheduling
  iparm[24] = 0; // parallel forward/backward solve control
  iparm[26] = 0; // matryx checker
  iparm[27] = 0; // set single or double precision of PARDISO
  iparm[30] = 0; // partial sol. for sparse right-hand sides and sparse solution
  iparm[34] = 1; // C or Fortran style array indexing : 0 = Fortran
  iparm[59] = 0; // version of PARDISO : 0 = in-core
  maxfct = 1; /* Maximum number of numerical factorizations. */
  mnum = 1; /* Which factorization to use. */
  msglvl = 1; /* Print statistical information in file */
  error = 0; /* Initialize error flag */

  for (int i = 0; i < 64; i++) {
    pt[i] = 0;
  }
  
  phase = 11;
  t0_cpu = clock();
  get_realtime(&t0_elapsed);
  pardiso(pt, &maxfct, &mnum, &mtype, &phase,
	  &nrow, coefs,
	  ptrows, indcols, &idum, &nrhs,
	  iparm, &msglvl, (void *)y, (void *)x, &error);
  get_realtime(&t1_elapsed);
  t1_cpu = clock();    
  phase = 22;
  t2_cpu = clock();
  get_realtime(&t2_elapsed);
  pardiso(pt, &maxfct, &mnum, &mtype, &phase,
	  &nrow, coefs,
	  ptrows, indcols, &idum, &nrhs,
	  iparm, &msglvl, (void *)y, (void *)x, &error);
  get_realtime(&t3_elapsed);
  t3_cpu = clock();
    
  phase = 33;
  t4_cpu = clock();
  get_realtime(&t4_elapsed);
  pardiso(pt, &maxfct, &mnum, &mtype, &phase,
	  &nrow, coefs,
	    ptrows, indcols, &idum, &nrhs,
	  iparm, &msglvl, (void *)y, (void *)x, &error);
  get_realtime(&t5_elapsed);
  t5_cpu = clock();
  //  fprintf(stderr, "%s %d : iterative refinement is done with %d times\n",
  //	  __FILE__, __LINE__, iparm[6]);

  SpMV(z, x, nrow, ptrows, indcols, coefs, isSym); // recompute RHS

  phase = 33;
  t4_cpu = clock();
  get_realtime(&t4_elapsed);
  pardiso(pt, &maxfct, &mnum, &mtype, &phase,
	  &nrow, coefs,
	    ptrows, indcols, &idum, &nrhs,
	  iparm, &msglvl, (void *)z, (void *)w, &error);
  get_realtime(&t5_elapsed);
  t5_cpu = clock();
  fprintf(stderr, "%s %d : iterative refinement is done with %d times\n",
	  __FILE__, __LINE__, iparm[6]);
  SpMV(y, w, nrow, ptrows, indcols, coefs, isSym);
  int itmp0, itmp1, itmp2;
  double vtmp;
  double norm0, norm1;
  norm0 = 0.0;
  norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm0 += x[i] * x[i];
    norm1 += (w[i] - x[i]) * (w[i] - x[i]);
  }
  fprintf(stderr, "%s %d : ## error = %18.7e = %18.7e / %18.7e\n",
	  __FILE__, __LINE__,
	  sqrt(norm1 / norm0), sqrt(norm1), sqrt(norm0));
  
  norm0 = 0.0;
  norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm0 += z[i] * z[i];
    norm1 += (z[i] - y[i]) * (z[i] - y[i]);
  }
  fprintf(stderr, "%s %d : ## residual = %18.7e = %18.7e / %18.7e\n",
	  __FILE__, __LINE__,
	  sqrt(norm1 / norm0), sqrt(norm1), sqrt(norm0));
  
  fprintf(stderr,
	  "## symbolic fact    : cpu time = %.4e elapsed time = %.4e\n", 
	  (double)(t1_cpu - t0_cpu) / (double)CLOCKS_PER_SEC,
	  convert_time(t1_elapsed, t0_elapsed));
  
  fprintf(stderr,
	  "## numeric fact     : cpu time = %.4e elapsed time = %.4e\n", 
	  (double)(t3_cpu - t2_cpu) / (double)CLOCKS_PER_SEC,
	  convert_time(t3_elapsed, t2_elapsed));
  
  fprintf(stderr,
	  "## solve single RHS : cpu time = %.4e elapsed time = %.4e\n", 
	  (double)(t5_cpu - t4_cpu) / (double)CLOCKS_PER_SEC,
	  convert_time(t5_elapsed, t4_elapsed));
  phase = -1;
  pardiso(pt, &maxfct, &mnum, &mtype, &phase,
	  &nrow, coefs,
	    ptrows, indcols, &idum, &nrhs,
	  iparm, &msglvl, (void *)NULL, (void *)NULL, &error);
  delete [] ptrows;
  delete [] indcols;
  delete [] x;
  delete [] y;
  delete [] z;
  delete [] w;
}
