#include "distributed_vector.h"
#include <iostream.h>
#include <stdio.h>
#include <math.h>
#define pi 3.14159265358979323846e0

//dummy distributed_vector template

/*template <class T>
class distributed_vector{
   T *data;
   int size;
  public:
     distributed_vector(int n){
         data = new T[n];
         size = n;
        }
     typedef T * pariterator;
     typedef T * iterator;
     pariterator parbegin(){ return data; }
     pariterator parend(){ return data+size; }
     pariterator begin(){ return data; }
     pariterator end(){ return data+size; }
};
int RTS_LocalLocation(){ return 0; }
int RTS_NumProc(){ return 1; }
void RTS_Barrier(){};

template <class PI, class Func>
void par_apply(PI begin, PI end, Func F, PI base){
    for(; begin != end; begin++) F(*begin);
}
template <class PI, class Func>
void par_apply(PI begin, PI end, PI b2, Func F, PI base){
    for(; begin != end; begin++, b2++) F(*begin, *b2);
}
template <class PI, class Func>
void par_apply(PI begin, PI end, PI b2, PI b3, Func F, PI base){
    for(; begin != end; begin++, b2++, b3++) F(*begin, *b2, *b3);
}
template <class PI, class Func>
void par_apply(PI begin, PI end, PI b2, PI b3, PI b4, Func F, PI base){
    for(; begin != end; begin++, b2++, b3++, b4++) F(*begin, *b2, *b3, *b4);
}
template <class PI, class Func>
void par_apply(PI begin, PI end, PI b2, PI b3, PI b4, PI b5, Func F, PI base){
    for(; begin != end; begin++, b2++, b3++, b4++, b5++) 
             F(*begin, *b2, *b3, *b4, *b5);
}
template <class PI, class Func1, class Func2, class I>
I par_reduction(PI begin, PI end, Func1 S, Func2 F, I init, PI base){
    I sum = init;
    for(; begin != end; begin++) sum = S(sum, F(*begin));
    return sum;
}

template <class T>
class plus{
   public:
     plus(){}
     T operator()(T a, T b){ return a+b; }
};
*/
// end of dummy stuff;
     
// A C++ FFT and Sin Transform.

// FFT bundles a complex Fourier transform and the sine transform built on
// top of it, together with the tables both of them need.
class FFT{
   int n;       // transform size, exactly as handed to the constructor
   int log2n;   // mylog2(n), cached by the constructor
   // Roots of unity, split into real (wr) and imaginary (wi) parts.
   // wr[k] / wi[k] point at the coefficients used by butterfly stage k;
   // initw allocates and fills them.
   double *wr[20];
   double *wi[20];
   // Two scratch arrays of size+1 doubles, allocated by the constructor and
   // used as working storage by SinTransform.
   double *tmp1;
   double *tmp2;
   // Bit-reversal permutation, built by initw and applied as the last step
   // of ComplexFourierTransform.
   int p[1024];
  public:
   // size is the length n described below.  The tables above are fixed
   // size (20 stages, 1024 permutation entries), so size has to be small
   // enough to fit them; nothing checks this.
   //   * Complex transform: acts on n/2 complex values, real parts in the
   //     first n/2 slots of the array and imaginary parts in the next n/2.
   //   * Sine transform: acts on the n-1 real values stored at 1 .. n-1.
   FFT(int size);

   // Allocates and fills wr, wi and p for a complex transform of length n.
   void initw(int n);

   // Number of times n can be halved (integer division) before it is <= 1,
   // i.e. floor(log2(n)) for n >= 1, and 0 for anything smaller.
   int mylog2(int n){
        int halvings = 0;
        for (; n > 1; n = n/2)
             ++halvings;
        return halvings;
        }

   // Complex Fourier transform of the n/2 complex values packed in v;
   // the second half of each array holds the imaginary parts.
   void ComplexFourierTransform(double *v, double *result);

   // Helpers for the complex transform: one butterfly stage, and reversal
   // of the low k bits of i.
   void step(int k, double *ar, double *ai, double *cr, double *ci);
   int rev(int k, int i);

   // computes result[i] = sum(j = 0; j < n; j++) sin(i*j*pi/n)*v[j]
   void SinTransform(double *v, double *result);
};


// step is the basic computational step used in the kernel of the complex 
// Fourier transform.  It computes 
// c(2*i) = a(i) + w(i)*a(n/4+i); c(2*i+1) = a(i) - w(i)*a(n/4+i);

void FFT::step(int k, double *ar, double *ai, double *cr, double *ci)
{
        int i;
        int hn = n/4;
        double tr, ti;
        for( i = 0; i<hn; i++){
                tr = wr[k][i]*ar[hn+i]-wi[k][i]*ai[hn+i];
                ti = wr[k][i]*ai[hn+i]+wi[k][i]*ar[hn+i];
                cr[2*i] = ar[i] + tr;
                cr[2*i+1] = ar[i] - tr;
                ci[2*i] = ai[i] + ti;
                ci[2*i+1] = ai[i] - ti;
           }
 }

// ComplexFourierTransform computes the complex fourier transform of a 
// Vector of length n/2
// where the real part is stored in the first n/2 places and the imaginary 
// part is in the second n/2 places. 
void FFT::ComplexFourierTransform(double *x, double *result)    
{
        int togle, i;
        int logn;
        double *xr, *xi, *yr, *yi;

        logn = log2n-1;
        xr = &(x[0]);
        xi = &(x[n/2]);
        yr = &(result[0]);
        yi = &(result[n/2]);
        togle = 1;
        for( i = 0; i < logn; i++){
                if (togle == 1){
                        step(i,xr,xi,yr,yi);
                        togle = 0;
                        }
                else {
                        step(i,yr,yi,xr,xi);
                        togle = 1;
                        }
                }
        if (2*(logn/2) != logn){
                for( i= 0; i <n/2; i++){
                        xr[i] = yr[i];
                        xi[i] = yi[i];
                }
           }
        for( i=0; i < n/2; i++){
                yr[i] = xr[p[i]];
                yi[i] = xi[p[i]];
           } 
}

// rev returns the number formed by the k low-order bits of i written in
// the opposite order (bit 0 of i becomes bit k-1 of the result).

int FFT::rev(int k, int i)
{
        int reversed = 0;
        for (int bit = 0; bit < k; bit++) {
                // peel the lowest bit off i ...
                const int low = (i % 2 != 0) ? 1 : 0;
                i = (i - low)/2;
                // ... and push it onto the low end of the result
                reversed = 2*reversed + low;
        }
        return reversed;
}


// Records the transform size, allocates the two scratch buffers
// (size+1 doubles each) and builds the tables for a complex transform of
// half that size.
FFT::FFT(int size)
      : n(size),
        log2n(mylog2(size)),
        tmp1(new double[size+1]),
        tmp2(new double[size+1])
{
      initw(n/2);
}

// Initw allocates and initializes the arrays wr and wi.
// wr and wi are arrays of size log2n by n  where n is the
// size of the Vector you want for Complex Fourier Transforms.
// NOTE: this is 1/2 the size of the Vector lenght used for
// a sine transform because the real coeficients in a sine
// transform are folded into a complex transform of 1/2 size.
// P is an array that contains a bit reversal permutation.
// Wr and Wi should be viewed as the log2(n) stages of n complex
// complex coefficients used in the CFT. 
// Initw should be called once at the initialization of the
// program.
           
void FFT:: initw(int n)
{
     int log2n = mylog2(n);
     int i,j,k,m;
     double x;
        x = 2.0e0*pi/n;
           
        for(i = 0; i < log2n; i++) {
                wr[i] = new double[n];
                wi[i] = new double[n];
                }
        
        for( i = 0; i < n/2; i++){
                p[i] = rev(log2n-1,i);
             }
        
        for( i=0; i<n/2; i++){
                wr[log2n-1][p[i]] =  (double) cos(x*i);
                wi[log2n-1][p[i]] =  (double) sin(x*i);
             }
         m = 1;
         for( i=0; i<log2n-1; i++ ){
                for( j= 0; j < n/2; j++){
                    k = (m-1)&j;
                    wr[i][j] = wr[log2n-1][k];
                    wi[i][j] = wi[log2n-1][k];
                }       
                m = 2*m;
             }
         for( i=0; i <n; i++){
                p[i] = rev(log2n,i);
             }
   }

// SinTransform computes sum(i=0; i < n){ y[i]* sin(i*j*pi/n) }
// The vector is actually stored in y[1] ... y[n-1]
void FFT::SinTransform(double *y, double *z)
{
        int i;
        double c, s,k;
        double x, t1, tnm1;
        double *a, *b;

        a = tmp1;
        b = tmp2;
        t1 = y[1];
        tnm1 = y[n-1];
        for( i= 1; i <n; i++){
                a[i] = 0.5*(y[i]+y[n-i]);
                b[i] = 0.5*(y[i]-y[n-i]);
           }
        a[0] = y[0];
        b[0] = 0.0;
        
        x = 2.0*pi/n;
        for( i= 1; i < n/2; i++){
                c =   cos(x*i);
                s =   sin(x*i); 
                z[i] =  -s*(a[2*i+1]-a[2*i-1]) +c*a[2*i] +(b[2*i+1]-b[2*i-1]);
                z[n/2+i] = -b[2*i] +c*(a[2*i+1]-a[2*i-1]) +s*a[2*i];
           }
        z[0] =   t1-tnm1;
        z[n/2] = t1+tnm1;
        ComplexFourierTransform(z, y);  
        for( i=0; i <n/2; i++){
                a[2*i] = y[i];
                a[2*i+1] = y[n/2+i];
           }
        
        x = pi/n;
        k = sqrt((double) (2.0/n));
        for( i=1; i < n; i++ ){
                s =  sin(x*i);
                s = 0.25/s;
                z[i] = k*(a[i]*(0.5+s)+a[n-i]*(s-0.5));
           }
 }
// ******************************************************************
// ****************************************************************
// ********** example starts here **********************************8
// ****************************************************************** 
// a grid row.
 
#define N 1024

// Row: one row of the grid.  It stores N+1 doubles (indices 0 .. N) plus
// the row's own index within the grid.
class Row{
   int s;            // number of stored values, always N+1
   double p[N+1];    // the values
   int position;     // index of this row, set via set_index()

   // Shared body of the two copy constructors: copies the s values of r.
   void copy_values(const Row& r) {
     for (int i=0; i<s; i++)
       p[i]=r.p[i];
   }
  public: 
   // A fresh row is all zeros with index 0.  Every member is initialized
   // here because the copy constructors read position and all of p; rows
   // that were default-constructed and then copied (for example the Row
   // members of functors passed by value) would otherwise read
   // indeterminate values.
   Row() : s(N+1), position(0) {
      for (int i = 0; i < s; i++)
        p[i] = 0.0;
      }
  ~Row() {}

  // these are needed for casting from global_ref<Row> to Row
  Row(Row& r) : s(r.s),position(r.position) {
    copy_values(r);
  } 
  Row(const Row& r) : s(r.s),position(r.position) {
    copy_values(r);
  } 

   // Pointer to the first stored value.
   double *data(){ return &p[0]; }
   // Number of stored values (N+1).
   int size(){ return s; }
   // Row index accessors.
   int index(){return position;}
   void set_index(int i){ position = i; }
   // Unchecked element access: by reference and by value.
   double &operator[](int i) { return p[i]; }
   double get_data_i(int i) { return p[i]; }
   // Prints "[index]: " followed by every value with 5 decimals, then a
   // newline.
   void print() const{
      printf("[%d]: ", position);
      for(int i = 0; i < s; i++)
          printf("%7.5f ", p[i]);
      printf("\n");
      }
};

class print_it{
  public:
    print_it(){}
    void operator()(const Row &r){ r.print(); }
};

// Prints every row of r by running print_it over [parbegin, parend).
// The trailing r.begin() argument follows the par_apply calling convention
// used throughout this file.
void print(distributed_vector<Row> &r) {
   par_apply(r.parbegin(),
             r.parend(),
             print_it(),
             r.begin());
}

      
class tridiag_solver{
    double *diag;
    int n;
  public:
   tridiag_solver(int m): n(m){
        diag = new double[m];
        }
   void solve(double a, double b, Row &x, Row &y){
      int i;
      double c;
      diag[1] = a;
      x[1] = y[1];
      for(i = 1; i < n-1; i++){
                c = b/diag[i];
                diag[i+1] = a - b*c;
                x[i+1] = y[i+1] - c*x[i];
                }
        x[n-1] = x[n-1]/diag[n-1];
        for(i = n-2; i > 0; i--){
                x[i] = (x[i]-b*x[i+1])/diag[i];
                }
        x[0] = 0.0; x[n] = 0.0;
        }
};


// Problem size shared by the global solver objects, the functors and
// tests() below; same value as the N macro.
const int n = N;

// File-wide transform and tridiagonal-solver instances.  They are used by
// the row_ffts and solve_tridiag functors respectively.  Both are
// constructed during static initialization, fft first.
FFT fft(n);
tridiag_solver tridiag(n);

// Function object for the two-sequence par_apply: sine-transforms row y
// and leaves the result in row x, using the global fft object.
class row_ffts{
  public:
   row_ffts(){}
   void operator()(Row &x, Row &y){
     // SinTransform takes (input, output), hence the swapped order: the
     // row being modified (x) is the transform's second argument.
     double *input = y.data();
     double *output = x.data();
     fft.SinTransform(input, output);
   }
};

// Function object for the two-sequence par_apply: solves one tridiagonal
// system per row with the global tridiag solver.  x receives the solution,
// y supplies the right-hand side.
class solve_tridiag{
  public:
    solve_tridiag(){}
  // argument order is (output, input) to match the par_apply call, which
  // lists the modified sequence first
    void operator()( Row &x, Row &y){
                const int i = x.index();
                // Diagonal entry for row i: 4 - 2*cos(pi*i/n), kept in full
                // double precision.  (It used to be rounded through float,
                // which threw away about half of the significant digits of
                // a coefficient feeding an otherwise double-precision solve.)
                const double gamma = pi*i/n;
                const double a = 4.0 - 2.0*cos(gamma);
                const double b = -1.0;   // off-diagonal entries
                tridiag.solve(a,b, x, y);
           }
};
        


class transpose{
  Row R;
  public:
   transpose(distributed_vector<Row> &X, int i){
            distributed_vector<Row>::pariterator r;
           r =  X.parbegin();
	   R = r[i];
           };
   void operator()(Row &x){
     int i = R.index();
     int j = x.index();
     x[i] = R[j];
   }
};

// Function object applying the 5-point difference stencil to the grid U:
// for an interior row i and interior column j,
//   result[j] = 4*U[i][j] - (U[i][j+1] + U[i][j-1] + U[i-1][j] + U[i+1][j]).
// The first and last row, and the first and last entry of every row, are
// left untouched.
class DiffOper{
    distributed_vector<Row>::pariterator u;   // start of the grid being differenced
    Row mid, left, right;                     // local copies of rows i, i-1, i+1
    public:
       DiffOper(distributed_vector<Row> &U){
           u = U.parbegin();
       }
       void operator()(Row &result){
           const int i = result.index();
           const int last = result.size()-1;
           if (i <= 0 || i >= last)
               return;                        // boundary row: nothing to do

           // copy the three rows the stencil needs
           mid = u[i];
           left = u[i-1];
           right = u[i+1];

           for (int j = 1; j < last; j++) {
               result[j] = 4.0*mid[j] -( mid[j+1] + mid[j-1] +
                                         left[j] + right[j]);
           }
       }
};

// Function object for par_reduction: returns the sum over the interior
// columns j of (AU[j] - 10.0*i/j)^2 for interior row i, where 10.0*i/j is
// the right-hand side tests() loaded into the grid.  Boundary rows
// contribute 0.
class squares{
    public:
      squares(){}
      double operator()(Row &AU){
         const int i = AU.index();
         const int last = AU.size()-1;
         double err = 0.0;
         if (i <= 0 || i >= last)
           return err;

         for (int j = 1; j < last; j++) {
           const double diff = AU[j]-10.0*i/j;
           err = err + diff * diff;
         }
         return err;
         }
};

//define DEBUG

// tests: driver for the whole example.  It fills a right-hand-side grid F,
// runs sine transform -> transpose -> tridiagonal solves -> transpose ->
// sine transform to obtain U, applies the difference operator to U and
// compares the result with the original right-hand side, then prints the
// scaled squared error and the per-phase timer readings.  Always returns 0.
int tests(){

  // F starts as the right-hand side, U as the working/solution grid;
  // both have n+1 rows.
  distributed_vector<Row> F(n+1), U(n+1);
  distributed_vector<Row>::iterator f, fe;
  distributed_vector<Row>::iterator u;
  cout << "finished creation" << endl;
  f =F.begin();
  fe=F.end();
  u =U.begin();

  // First row index for this location.
  // NOTE(review): presumably begin()/end() cover only the rows held locally
  // and each location holds (n+1)/RTS_NumLocations() consecutive rows --
  // confirm against distributed_vector.
  int i=RTS_LocalLocation()*((n+1)/RTS_NumLocations());
  for(; f != fe; f++, u++, i++) {
      int j;
      f->set_index(i);
      u->set_index(i);
      // interior columns: F = 10*i/j, U = 1
      for(j = 1; j < n; j++){
           (*f)[j] = 10.0*i/j;
           (*u)[j] = 1.0;
       }
      // first and last column are zero in every row
      (*f)[0] = 0.0;
      (*u)[0] = 0.0;
      (*f)[n] = 0.0;
      (*u)[n] = 0.0;
      // first and last row are zero everywhere
      if((i == 0) || (i == n) )
         for(j = 0; j < n+1; j++){
             (*f)[j] = 0.0; (*u)[j] = 0.0;
             }
       }

  // Timers used below: 0 = sine transforms, 1 = transposes,
  // 2 = difference operator, 3 = reduction, 4 = tridiagonal solves.
  tulip_UserTimerClear(0);
  tulip_UserTimerClear(1);
  tulip_UserTimerClear(2);
  tulip_UserTimerClear(3);
  tulip_UserTimerClear(4);

  RTS_Barrier();

#ifdef DEBUG

  printf("finished init\n");
  printf("print F:\n");
  print(F);
  printf("print U:\n");
  print(U);
#endif  
  // now begin the solution process.  
  // first compute the sine transforms of all the rows: U = transform(F).
  // U is being modified so it needs to be the first sequence argument.
  tulip_UserTimerStart(0);
  par_apply(U.parbegin(),U.parend(),F.parbegin(),row_ffts(),U.begin());
  tulip_UserTimerStop(0);
#ifdef DEBUG
  printf("after fft print U:\n");
  print(U);
#endif

  // now transpose the grid: each pass copies row i of U into column i of F.
  tulip_UserTimerStart(1);
for (i=0; i < N  ; i++)
  par_apply(F.parbegin(),F.parend(),transpose(U,i),F.begin());
  tulip_UserTimerStop(1);  
#ifdef DEBUG  
  printf("after transpose print F:\n");
  print(F);
#endif
  // now solve the tridiagonal equations: U = solution, F = right-hand side.
  // had to switch order here too since U is changed
  tulip_UserTimerStart(4);
  par_apply(U.parbegin(),U.parend(), F.parbegin(), solve_tridiag(), U.begin());
  tulip_UserTimerStop(4);

#ifdef DEBUG  
  printf("after tridiag: print U:\n");  
  print(U);
#endif
  // now transpose again (U into F, as above)
  tulip_UserTimerStart(1);
for (i=0; i < N  ; i++)
  par_apply(F.parbegin(),F.parend(),transpose(U,i),F.begin());
  tulip_UserTimerStop(1);
#ifdef DEBUG  
  printf("after transpose print F:\n");  
  print(F);
#endif

  // now sine transform again: U = transform(F).
    tulip_UserTimerStart(0);
  par_apply(U.parbegin(),U.parend(),F.parbegin(),row_ffts(),U.begin());
    tulip_UserTimerStop(0);
  // calculation now complete; U holds the computed solution
#ifdef DEBUG  
  printf("after fft: print U:\n");
  print(U);
#endif  
  // check for errors.  apply the difference operator to U, writing into F.
  tulip_UserTimerStart(2);
  par_apply(F.parbegin(),F.parend(), DiffOper(U), F.begin());
  tulip_UserTimerStop(2);

#ifdef DEBUG  
  printf("apply oper: print F:\n");  
  print(F);
#endif

  // sum, over all rows, the squared differences between F and the original
  // right-hand side (see squares)
    tulip_UserTimerStart(3);
  double resid = par_reduction(F.parbegin(),F.parend(), 
                 plus<double>(), squares(), (double) 0.0, 
                 F.begin());
    tulip_UserTimerStop(3);

  RTS_Barrier();
  // only context 0 reports
  if (tulip_MyContext()==0) {
    cout  << "Processors = " << RTS_NumLocations() << endl;
    cout << "square of l2 error is = " << resid/(n*n) << endl;
    cout << "should be less than 0.3" << endl;

    cout << "time for fft  = " << tulip_UserTimerElapsed(0) << endl;
    cout << "time for trns = " << tulip_UserTimerElapsed(1) << endl;
    cout << "time for trid = " << tulip_UserTimerElapsed(4) << endl;
    cout << "time for oper = " << tulip_UserTimerElapsed(2) << endl;
    cout << "time for redu = " << tulip_UserTimerElapsed(3) << endl;
    cout << "total time = " << tulip_UserTimerElapsed(0)+
      tulip_UserTimerElapsed(1)+
      tulip_UserTimerElapsed(2)+
      tulip_UserTimerElapsed(3)+
      tulip_UserTimerElapsed(4) << endl;
    // same numbers again on a single line: fft, trns, trid, oper, redu, total
    cout << tulip_UserTimerElapsed(0) << " " << 
      tulip_UserTimerElapsed(1) << " " <<
      tulip_UserTimerElapsed(4) << " " <<
      tulip_UserTimerElapsed(2) << " " <<
      tulip_UserTimerElapsed(3) << " " <<
      tulip_UserTimerElapsed(0)+
      tulip_UserTimerElapsed(1)+
      tulip_UserTimerElapsed(2)+
      tulip_UserTimerElapsed(3)+
      tulip_UserTimerElapsed(4) << endl;
  }
  RTS_Barrier();  

  return 0;
}




extern "C" {
  // Entry point with C linkage: sets mem_init, initializes the runtime via
  // RTS_init(), then runs tests().  The value returned by tests() is not
  // used; this function always returns 1.  argc and argv are ignored.
  // NOTE(review): mem_init is declared outside this file -- confirm its
  // meaning against the runtime headers.
  int tulip_ParallelMain(int argc, char* argv[])
  {
      mem_init = 1;
      RTS_init();
      const int status = tests();
      (void) status;
      return 1;
  }
}
