#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fcntl.h>
#include <unistd.h>
#include <math.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>

#include <omp.h>

#include "dmumps_c.h"
#include "elapsed_time.hpp"

// Sentinel stored in id.comm_fortran to request MPI_COMM_WORLD, following
// the convention of the MUMPS C examples.  A typed constant instead of an
// object-like macro, and no file-wide `using namespace std;` (nothing in
// this file needs it).
static const int USE_COMM_WORLD = -987654;


void SpMV(double *y, double *x, int nrow, int nnz,
	  MUMPS_INT *irow, MUMPS_INT *jcol, double *val, bool isSym)
{
  for (int i = 0; i < nrow; i++) {
    y[i] = 0.0;
  }
  if (isSym) {
    for (int k = 0; k < nnz; k++) {
      if (irow[k] != jcol[k]) {
	y[irow[k] - 1] += val[k] * x[jcol[k] - 1];
	y[jcol[k] - 1] += val[k] * x[irow[k] - 1];
      }
      else {
	y[irow[k] - 1] += val[k] * x[jcol[k] - 1];
      }
    }
  }
  else {
    for (int k = 0; k < nnz; k++) {
      y[irow[k] - 1] += val[k] * x[jcol[k] - 1];
    }
  }
}

int main(int argc, char **argv)
{
  int n, itmp, jtmp, ktmp;
  char fname[256], fname1[256];
  char buf[1024];
  int nrow, nnz, flag;
  int *ptrows, *indcols;
  MUMPS_INT *irow, *jcol;
  double *val, *coefs;
  int num_threads = 1;
  int ordering = 3;
  double eps_pivot = 1.0e-2;
  double null_pivot = 1.0e-8;
  FILE *fp;
  bool isSym;
  bool rhs_data = false;
    
  DMUMPS_STRUC_C id;
    
  clock_t t0_cpu, t1_cpu, t2_cpu, t3_cpu, t4_cpu, t5_cpu;
  elapsed_t t0_elapsed, t1_elapsed, t2_elapsed, t3_elapsed;
  elapsed_t t4_elapsed, t5_elapsed;
  
  if (argc == 1) {
    fprintf(stderr,
	    "MM-MUMPS [matrix data] [pivot thres.] [null thres.] [rhs data]\n");
    exit(-1);
  }    
  strcpy(fname, argv[1]);
  ordering = atoi(argv[2]);
  num_threads = atoi(argv[3]);
  if (argc >= 5) {
    eps_pivot = atof(argv[4]);
  }
  if (argc >= 6) {
    null_pivot = atof(argv[5]);
  }
  if (argc >= 7) {
    strcpy(fname1, argv[6]);
    rhs_data = true;
  }
  // read from the file
  if ((fp = fopen(fname, "r")) == NULL) {
    fprintf(stderr, "fail to open %s\n", fname);
  }
  fgets(buf, 256, fp);
  //
  if (strstr(buf, "symmetric") != NULL) {
   isSym = true;
  }
  else {
    isSym = false;
  }
  if (strstr(buf, "complex") != NULL) {
    fprintf(stderr, "%s %d : complex data is not supoorted\n",
	    __FILE__, __LINE__);
    exit(-1);
  }

  fprintf(stderr, "symmetric = %s\n", isSym ? "true " : "false");

  omp_set_dynamic(0);
  omp_set_num_threads(num_threads);
    
  while (1) {
    fgets(buf, 256, fp);
    if (buf[0] != '%') {
      sscanf(buf, "%d %d %d", &nrow, &itmp, &nnz);
      break;
    }
  }
  irow = new MUMPS_INT[nnz];
  jcol = new MUMPS_INT[nnz];

  {
    val = new double[nnz];
    for (int i = 0; i < nnz; i++) {
      fscanf(fp, "%d %d %lf", &jcol[i], &irow[i], &val[i]); // read lower
      if (isSym && irow[i] > jcol[i]) {
	//	fprintf(stderr, "exchanged : %d > %d\n", irow[i], jcol[i]);
	itmp = irow[i];
	irow[i] = jcol[i];
	jcol[i] = itmp;
      }
    }
  }
  fclose (fp);
  double *x = new double[nrow];
  double *y = new double[nrow];
  double *z = new double[nrow];
  double *w = new double[nrow];

  if (rhs_data) {
    // read from the file
    if ((fp = fopen(fname1, "r")) == NULL) {
      fprintf(stderr, "fail to open %s\n", fname1);
    }
    fgets(buf, 256, fp);
    //
    while (1) {
      fgets(buf, 256, fp);
      if (buf[0] != '%') {
	sscanf(buf, "%d %d %d", &itmp, &jtmp, &ktmp);
	break;
      }
    }
    
    {
      for (int i = 0; i < nrow; i++) {
	fscanf(fp, "%d %d %lf", &itmp, &jtmp, &y[i]);
      }
    }
    fclose (fp);
  }
  id.job = (-1); // job init
  id.par = 1;
  id.sym = isSym ? 2 : 0;
  id.comm_fortran = USE_COMM_WORLD;
  dmumps_c(&id);

  id.job = 1; // symbolic
  id.n = nrow;
  id.nz = nnz;
  id.irn = irow;
  id.jcn = jcol;
  id.icntl[0] = 6;
  id.icntl[1] = 6;
  id.icntl[2] = 6;
  id.icntl[3] = 2;
  id.icntl[6] = ordering; // 3 : SCOTCH, 5 : METIS
  id.icntl[12] = 1;
  id.icntl[23] = 1;

  t0_cpu = clock();
  get_realtime(&t0_elapsed);
  dmumps_c(&id);
  t1_cpu = clock();
  get_realtime(&t1_elapsed);

  id.job = 2; // numeirc
  id.a = val;
  id.icntl[7] = 0;
  id.cntl[0] = eps_pivot;
  id.cntl[2] = (-1.0) * null_pivot;
    
  t2_cpu = clock();
  get_realtime(&t2_elapsed);
  dmumps_c(&id);
  t3_cpu = clock();
  get_realtime(&t3_elapsed);
  fprintf(stderr, "%s %d : NumericFact() done\n", __FILE__, __LINE__);

  if (rhs_data) {
    for (int i = 0; i < nrow; i++) {
      x[i] = y[i];
    }
  }
  else {
    for (int i = 0; i < nrow; i++) {
      y[i] = (double)(i % 11);
    }
    SpMV(z, y, nrow, nnz, irow, jcol, val, isSym); // 
    SpMV(y, z, nrow, nnz, irow, jcol, val, isSym); // 
    for (int i = 0; i < nrow; i++) {
      x[i] = y[i];
    }
  }
  id.job = 3; // forward/backward
  id.nrhs = 1;
  id.rhs = x;

  t4_cpu = clock();
  get_realtime(&t4_elapsed);  
  dmumps_c(&id);
  t5_cpu = clock();
  get_realtime(&t5_elapsed);

  SpMV(y, x, nrow, nnz, irow, jcol, val, isSym); // recompute RHS

  for (int i = 0; i < nrow; i++) {
    z[i] = y[i];
  }
  id.job = 3; // forward/backward
  id.nrhs = 1;
  id.rhs = z;

  t4_cpu = clock();
  get_realtime(&t4_elapsed);  
  dmumps_c(&id);
  t5_cpu = clock();
  get_realtime(&t5_elapsed);
  
  SpMV(w, z, nrow, nnz, irow, jcol, val, isSym); // for residual

  double norm0, norm1;
  norm0 = 0.0;
  norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm0 += x[i] * x[i];
    norm1 += (z[i] - x[i]) * (z[i] - x[i]);
  }
  fprintf(stderr, "## error    : %18.7e = %18.7e / %18.7e\n",
	  sqrt(norm1 / norm0), sqrt(norm1), sqrt(norm0));

  norm0 = 0.0;
  norm1 = 0.0;
  for (int i = 0; i < nrow; i++) {
    norm0 += y[i] * y[i];
    norm1 += (w[i] - y[i]) * (w[i] - y[i]);
  }
  
  fprintf(stderr, "## residual : %18.7e = %18.7e / %18.7e\n",
	  sqrt(norm1 / norm0), sqrt(norm1), sqrt(norm0));

  fprintf(stderr,
	  "## symbolic fact    : cpu time = %.4e elapsed time = %.4e\n", 
	  (double)(t1_cpu - t0_cpu) / (double)CLOCKS_PER_SEC,
	  convert_time(t1_elapsed, t0_elapsed));
  
  fprintf(stderr,
	  "## numeric fact     : cpu time = %.4e elapsed time = %.4e\n", 
	  (double)(t3_cpu - t2_cpu) / (double)CLOCKS_PER_SEC,
	  convert_time(t3_elapsed, t2_elapsed));
  
  fprintf(stderr,
	  "## solve single RHS : cpu time = %.4e elapsed time = %.4e\n", 
	  (double)(t5_cpu - t4_cpu) / (double)CLOCKS_PER_SEC,
	  convert_time(t5_elapsed, t4_elapsed));
  
  id.job = -2; // clean up
  dmumps_c(&id);
  
  delete [] irow;
  delete [] jcol;
  delete [] val;
  delete [] x;
  delete [] y;
  delete [] z;
  delete [] w;
}

