/***********************************************************************
 * algebra.cc : an example of function algebra with differentiation    *
 * This version uses virtual mechanism.                                *
 *                                                                     *
 * Author : Pascal Havé, IFP Energies nouvelles (pascal.have@ifpen.fr) *
 * This example is a part of the lecture 'C++ Express' given for       *
 * CEMRACS 2012 (http://smai.emath.fr/cemracs/cemracs12)               *
 ***********************************************************************/

#include <cmath>
#include <iostream>
#include <fstream>
#include <iomanip>
#include <stdexcept>

// Build switch: define CPPPTR (e.g. uncomment the line below) to use plain
// raw pointers instead of boost::shared_ptr. In that mode nothing in this file
// ever deletes an expression node.
// #define CPPPTR // use define to disable boost::shared_ptr

// Forward declaration of the function interface (defined further down).
class IFunction;

// Scalar type used for all evaluations.
typedef double Real;

// Handle type used to refer to expression nodes:
//  - CPPPTR defined : raw pointer to a const IFunction (no ownership);
//  - otherwise      : reference-counted boost::shared_ptr, so a node may be
//                     shared by several expressions (e.g. in derivatives).
#ifdef CPPPTR
typedef const IFunction * IFunctionPtr;
#else /* CPPPTR */
#include <boost/shared_ptr.hpp>
typedef boost::shared_ptr<IFunction> IFunctionPtr;
#endif /* CPPPTR */

// Value-type handle on an IFunction node, so that user code can build and pass
// expressions around without managing dynamic allocation itself.
class Function {
public:
#ifndef CPPPTR
  // Adopts a freshly allocated node (the shared_ptr becomes its owner).
  Function(IFunction * f)
    : m_f(f)
  { }
#endif /* CPPPTR */
  // Wraps an already existing node handle.
  Function(const IFunctionPtr & f)
    : m_f(f)
  { }

  virtual ~Function() { }

  // Gives back the underlying handle; this lets a Function be passed
  // wherever an IFunctionPtr is expected (e.g. node constructors).
  operator IFunctionPtr() const
  {
    return m_f;
  }

  // Evaluates the wrapped expression at x (defined once IFunction is complete).
  Real operator()(const Real x) const;

  // Approximate size in bytes of this handle plus the wrapped expression.
  int memory_size() const;

private:
  const IFunctionPtr m_f; // wrapped expression node
};

// Abstract interface of an expression node: a real function of one variable
// that can be evaluated, differentiated symbolically and printed.
class IFunction {
public:
  // Virtual destructor. Nodes are created as derived types (Sinus, Add, ...)
  // but, in the shared_ptr build, Function(IFunction *) hands the shared_ptr
  // a *base* pointer, so the node is later deleted through IFunction *.
  // Without a virtual destructor that delete is undefined behaviour, and in
  // practice the derived members (the child nodes) are never released.
  virtual ~IFunction() { }

  // Value of the function at x.
  virtual Real eval(Real x) const = 0;
  // Symbolic derivative with respect to x, returned as a new expression.
  virtual Function diff() const = 0;
  // Writes a textual form of the expression to o and returns o.
  virtual std::ostream & print(std::ostream & o) const = 0;
  // Approximate footprint in bytes of this node and its sub-expressions
  // (the implementations in this file count a shared sub-expression once per
  // path that reaches it).
  virtual int memory_size() const = 0;
};

// Identity function: f(x) = x.
class Identity : public IFunction {
public:
  Identity() { }

  virtual Real eval(Real x) const { return x; }
  // Defined out of class: x' = 1.
  virtual Function diff() const;
  virtual std::ostream & print(std::ostream & o) const { return o << 'x'; }
  virtual int memory_size() const { return sizeof(*this); }
};

// Constant function: f(x) = v for every x.
class Constant : public IFunction {
public:
  Constant(Real v) : m_v(v) { }

  // The argument is irrelevant for a constant.
  virtual Real eval(Real /* x */) const { return m_v; }
  // Defined out of class: c' = 0.
  virtual Function diff() const;
  virtual std::ostream & print(std::ostream & o) const { return o << m_v; }
  virtual int memory_size() const { return sizeof(*this); }

private:
  const Real m_v; // the constant value
};

// Composition with the sine: f(x) = sin(u(x)).
class Sinus : public IFunction {
public:
  Sinus(const IFunctionPtr & f) : m_f(f) { }

  virtual Real eval(Real x) const {
    const Real inner = m_f->eval(x);
    return std::sin(inner);
  }
  // Defined out of class: (sin u)' = cos(u) * u'.
  virtual Function diff() const;
  virtual std::ostream & print(std::ostream & o) const {
    o << "sin(";
    m_f->print(o);
    return o << ')';
  }
  virtual int memory_size() const {
    // Own footprint plus the footprint of the argument sub-expression.
    return sizeof(*this) + m_f->memory_size();
  }

private:
  const IFunctionPtr m_f; // argument u
};

// Composition with the cosine: f(x) = cos(u(x)).
class Cosinus : public IFunction {
public:
  Cosinus(const IFunctionPtr & f) : m_f(f) { }

  virtual Real eval(Real x) const {
    const Real inner = m_f->eval(x);
    return std::cos(inner);
  }
  // Defined out of class: (cos u)' = -sin(u) * u'.
  virtual Function diff() const;
  virtual std::ostream & print(std::ostream & o) const {
    o << "cos(";
    m_f->print(o);
    return o << ')';
  }
  virtual int memory_size() const {
    // Own footprint plus the footprint of the argument sub-expression.
    return sizeof(*this) + m_f->memory_size();
  }

private:
  const IFunctionPtr m_f; // argument u
};

// Sum of two functions: f(x) = u(x) + v(x).
class Add : public IFunction {
public:
  Add(const IFunctionPtr & f, const IFunctionPtr & g) : m_f(f), m_g(g) { }

  virtual Real eval(Real x) const {
    const Real lhs = m_f->eval(x);
    const Real rhs = m_g->eval(x);
    return lhs + rhs;
  }
  // Defined out of class: (u + v)' = u' + v'.
  virtual Function diff() const;
  virtual std::ostream & print(std::ostream & o) const {
    o << '(';
    m_f->print(o);
    o << ") + (";
    m_g->print(o);
    return o << ')';
  }
  virtual int memory_size() const {
    // Own footprint plus both operand sub-expressions.
    return sizeof(*this) + m_f->memory_size() + m_g->memory_size();
  }

private:
  const IFunctionPtr m_f; // left operand u
  const IFunctionPtr m_g; // right operand v
};

// Product of two functions: f(x) = u(x) * v(x).
class Mult : public IFunction {
public:
  Mult(const IFunctionPtr & f, const IFunctionPtr & g) : m_f(f), m_g(g) { }

  virtual Real eval(Real x) const {
    const Real lhs = m_f->eval(x);
    const Real rhs = m_g->eval(x);
    return lhs * rhs;
  }
  // Defined out of class: (u * v)' = u' * v + u * v'.
  virtual Function diff() const;
  virtual std::ostream & print(std::ostream & o) const {
    o << '(';
    m_f->print(o);
    o << ") * (";
    m_g->print(o);
    return o << ')';
  }
  virtual int memory_size() const {
    // Own footprint plus both operand sub-expressions.
    return sizeof(*this) + m_f->memory_size() + m_g->memory_size();
  }

private:
  const IFunctionPtr m_f; // left factor u
  const IFunctionPtr m_g; // right factor v
};

// Reciprocal of a function: f(x) = 1 / u(x).
class Inverse : public IFunction {
public:
  Inverse(const IFunctionPtr & f) : m_f(f) { }

  virtual Real eval(Real x) const {
    const Real denominator = m_f->eval(x);
    return 1. / denominator;
  }
  // Defined out of class: (1/u)' = -u' / u^2.
  virtual Function diff() const;
  virtual std::ostream & print(std::ostream & o) const {
    o << "inv(";
    m_f->print(o);
    return o << ")";
  }
  virtual int memory_size() const {
    // Own footprint plus the footprint of the argument sub-expression.
    return sizeof(*this) + m_f->memory_size();
  }

private:
  const IFunctionPtr m_f; // argument u
};

// Integer power: f(x) = u(x)^n, for any non-zero integer n (may be negative).
class Pow : public IFunction {
public:
  // Throws std::invalid_argument (derived from std::exception, so existing
  // catch (std::exception&) handlers still work) when n == 0; that case is
  // represented by a Constant(1) instead (see operator^).
  Pow(const IFunctionPtr & f, const int n) : m_f(f), m_n(n) {
    if (n == 0)
      throw std::invalid_argument("Pow: exponent must be non-zero");
  }
  Real eval(Real x) const { return std::pow(m_f->eval(x),m_n); }
  // Defined out of class: (u^n)' = n * u^(n-1) * u'.
  Function diff() const;
  // Alternative infix rendering, kept for reference:
  // std::ostream & print(std::ostream & o) const { o << '('; m_f->print(o); o << ") ^" << m_n; return o; }
  // Rendered in C++ syntax, as a std::pow(u,n) call.
  std::ostream & print(std::ostream & o) const { o << "std::pow("; m_f->print(o); o << "," << m_n << ")"; return o; }
  int memory_size() const { return sizeof(*this) + m_f->memory_size(); }
private:
  const IFunctionPtr m_f; // base u
  const int m_n;          // exponent n (never 0)
};

// Evaluation operator; defined here because it needs the complete IFunction type.
Real Function::operator()(const Real x) const
{
  return m_f->eval(x);
}
// Size of the handle itself plus the size reported by the wrapped expression.
int Function::memory_size() const
{
  return sizeof(*this) + m_f->memory_size();
}

// User-facing helpers: wrap an expression in a sine / cosine node.
Function sin(const Function & f)
{
  Sinus * const node = new Sinus(f);
  return Function(node);
}
Function cos(const Function & f)
{
  Cosinus * const node = new Cosinus(f);
  return Function(node);
}
// Primary binary operators: each one builds a new node over its two operands.
Function operator+(const Function & f, const Function & g)
{
  return Function(new Add(f, g));
}
Function operator*(const Function & f, const Function & g)
{
  return Function(new Mult(f, g));
}
// Division is expressed as f * inv(g).
Function operator/(const Function & f, const Function & g)
{
  const Function inverse_of_g(new Inverse(g));
  return f * inverse_of_g;
}
// Mixed operators with a real scalar: the scalar becomes a Constant node.
Function operator+(const Function & f, const Real     & a)
{
  const Function constant(new Constant(a));
  return Function(new Add(f, constant));
}
Function operator*(const Real     & a, const Function & f)
{
  const Function constant(new Constant(a));
  return Function(new Mult(constant, f));
}
// a / f is expressed as a * inv(f).
Function operator/(const Real     & a, const Function & f)
{
  const Function inverse_of_f(new Inverse(f));
  return a * inverse_of_f;
}
// f - a is stored as f + (-a).
Function operator-(const Function & f, const Real     & a)
{
  const Function constant(new Constant(-a));
  return Function(new Add(f, constant));
}
// Operators derived from the primary and mixed-scalar ones above.
Function operator+(const Real     & a, const Function & f) { return f+a; }
Function operator*(const Function & f, const Real     & a) { return a*f; }
Function operator/(const Function & f, const Real     & a) { return f * (1./a); }
// Unary minus, as (-1) * f. Declared before the binary minus overloads below,
// which are written in terms of it.
Function operator-(const Function & f)                     { return -1 * f; }
// a - f == (-f) + a  (note the operand order: this is NOT f - a).
Function operator-(const Real     & a, const Function & f) { return (-f) + a; }
Function operator-(const Function & f, const Function & g) { return f + (-g); }
// Integer power, with shortcuts for the trivial exponents 0 and 1.
// WARNING: ^ has lower precedence than + - * /, so always parenthesize
// (e.g. write (x^2) + 1; x^2 + 1 parses as x^(2 + 1)).
Function operator^(const Function & f, const int     & n)  {
  switch (n) {
    case 0:
      return Function(new Constant(1)); // f^0 == 1 (Pow rejects n == 0)
    case 1:
      return f;                         // f^1 == f, no new node needed
    default:
      return Function(new Pow(f, n));
  }
}

// Symbolic derivatives (d/dx) of each node type.
Function Identity::diff() const
{
  return Function(new Constant(1));   // x' = 1
}
Function Constant::diff() const
{
  return Function(new Constant(0));   // c' = 0
}
Function Sinus::diff() const
{
  // (sin u)' = cos(u) * u'
  const Function du = m_f->diff();
  return cos(m_f) * du;
}
Function Cosinus::diff() const
{
  // (cos u)' = -sin(u) * u'
  const Function du = m_f->diff();
  const Function minus_sin_u = -sin(m_f);
  return minus_sin_u * du;
}
Function Add::diff() const
{
  // (u + v)' = u' + v'
  const Function du = m_f->diff();
  const Function dv = m_g->diff();
  return du + dv;
}
Function Mult::diff() const
{
  // (u * v)' = u' * v + u * v'
  const Function u(m_f);
  const Function v(m_g);
  return m_f->diff() * v + u * m_g->diff();
}
Function Inverse::diff() const
{
  // (1/u)' = -u' / u^2
  const Function u_squared(new Pow(m_f, 2));
  return -m_f->diff() / u_squared;
}
Function Pow::diff() const
{
  // (u^n)' = n * u^(n-1) * u'; m_n != 0 is guaranteed by the constructor.
  if (m_n == 1)
    return m_f->diff();
  const Function u(m_f);
  const Function scaled = m_n * (u ^ (m_n - 1));
  return scaled * m_f->diff();
}

// Derivative operator: d(f) is the symbolic derivative of f with respect to x.
Function d(const Function & f)
{
  const IFunctionPtr node = f;
  return node->diff();
}

// Affichage d'une expression
std::ostream & operator<<(std::ostream & o, const Function & f) { return static_cast<IFunctionPtr>(f)->print(o); }

void plot(const Function & f, const char * filename)
{
  std::ofstream o(filename);
  const Real xmin = -1;
  const Real xmax = +1;
  const int  n    = (2<<22);

  Real sum = 0;
  const Real dx = (xmax-xmin)/n;
  for(int i=0;i<=n;++i)
    {
      const Real x = xmin + i*dx;
      sum += f(x);
      // o << x << ' ' << f(x) << '\n';
    }
  std::cout << "Sum is " << sum << std::endl;
  // auto-close when out of scope
}

// Demo driver: builds a few expressions, prints their values and derivatives,
// their textual form and memory footprint, then samples f and d(f) via plot().
int main() {
  Function x(new Identity());

  Function g = sin(x);
  Function dg = d(g);
  std::cout << "g(1) = " << std::setw(8) << g(1) << ";\t g'(1) = " << std::setw(8) << dg(1) << std::endl;
  Function h = x^4;
  // The value printed is the SECOND derivative d(d(h)), hence the h'' label.
  std::cout << "h(2) = " << std::setw(8) << h(2) << ";\t h''(2) = " << std::setw(8) << d(d(h))(2.) << std::endl;
  // operator^ binds more loosely than + - * /, hence the explicit parentheses.
  // Function f = (sin(8*x-1)^3.) / cos(x^2.);
  Function f = (((2*x+1)^4)-((x-1)^3)/(1+x+((x/2.)^3)));
  std::cout << "f(2) = " << std::setw(8) << f(2) << ";\t f'(2) = " << std::setw(8) << d(f)(2.) << std::endl;
  std::cout << "  f  is " << f << std::endl;
  std::cout << "d(f) is " << d(f) << std::endl;

  std::cout << "Sizeofs : "
            << "\n\t  x=" << (x).memory_size()
            << "\n\t  g=" << (g).memory_size()
            << "\n\t dg=" << (dg).memory_size()
            << "\n\t  h=" << (h).memory_size()
            << "\n\tddh=" << (d(d(h))).memory_size()
            << "\n\t  f=" << (f).memory_size()
            << "\n\t df=" << (d(f)).memory_size() << "\n";

  plot(f,"f.dat");
  plot(d(f),"df.dat");
  return 0;
}
