// ------------------------------------
//   invert.cpp
// ------------------------------------

#include "w.h"

// A small function-inversion machine: given a function f and a value y,
// it searches for an x with f(x) = y.
// It offers only minimal protection against errors.

// declaration of the class
// Function inverter: for a function fp and a target value y, inverse(y)
// searches for an x with fp(x) = y, starting the search at x0.
class invert
{
public:
	// fpi is the function to invert; when omitted, fid is used
	// (fid is not defined in this file -- presumably supplied by w.h).
	invert(fptype1 fpi = fid);

	// setters for the function and the search parameters
	void setfp(fptype1 fpi);
	void setx0(double xi);
	void setdx(double dxi);
	void setytol(double ytoli);
	void setnmax(int nmaxi);

	// iteration count of the latest inverse() call
	int getn();

	// search for x with fp(x) = y
	double inverse(double y);

	// search for x with fp(x) = 0
	double zero();

private:
	fptype1 fp;    // function being inverted
	double dx;     // half-width of the central difference used in inverse()
	double x0;     // starting point of the iteration
	double ytol;   // tolerance on |fp(x) - y|
	int nmax;      // cap on the number of iterations
	int n;         // iterations used by the latest inverse() call
};

// definitions
// Build an inverter for fpi with the default search settings:
//   dx   = 0.01   half-width of the central difference
//   x0   = 1      starting point
//   ytol = 1e-10  tolerance on |fp(x) - y|
//   nmax = 20     iteration cap
//   n    = 0      no iterations performed yet
// The initializers are written in member declaration order (dx before x0),
// because that is the order in which they are actually executed.
invert::invert(fptype1 fpi)
  : fp(fpi),
    dx(0.01),
    x0(1),
    ytol(1e-10),
    nmax(20),
    n(0)
{ }

// Replace the function that inverse() and zero() operate on.
void invert::setfp(fptype1 fpi)
{
  this->fp = fpi;
}

void invert::setx0(double x0i) { x0 = x0i; }

void invert::setdx(double dxi) { dx = dxi; }

void invert::setytol(double ytoli) { ytol = ytoli; }

void invert::setnmax(int nmaxi) { nmax = nmaxi; }

int invert::getn() { return n; }


double invert::inverse(double y)
{
  n = 0;
  double x = x0;
  double z = fp(x)-y;
  double dx2 = 2*dx;
  if (fabs(z) < ytol) return x0;

  while ((fabs(z) > ytol)&&(n < nmax))
  {
	  n++;
	  x = x - dx2*z/(fp(x+dx) - fp(x-dx));
	  z = fp(x)-y;
  }

  return x;
}

// Search for a zero of fp: equivalent to inverse() with target value 0.
double invert::zero()
{
  const double target = 0;
  return inverse(target);
}

// -----------------------------------------------------------
  // some test functions, defined directly in this file

  // Test function f(x) = x * e^x.
  double f(double x)
  {
    const double e_to_x = exp(x);
    return x*e_to_x;
  }

  // Test function g(x) = sin(x) * x^5.
  double g(double x)
  {
    const double sine = sin(x);
    const double fifth_power = pow(x,5);
    return sine*fifth_power;
  }

  // Test function h(x) = x - cos(x).
  double h(double x)
  {
    const double cosine = cos(x);
    return x - cosine;
  }

 // ----------------------------------------------------------

 // Demo driver: inverts tan, f and g at sample values and solves h(x) = 0,
 // printing each result, the iteration count and a check value.
 // main is declared int because standard C++ requires that return type.
 int main()
{
  nl(0);
  banner("invert.cpp");

  // set the output precision to 15 digits
  cout.precision(15);

  // create an instance of invert with fp = tan
  invert inv = invert(tan);

  // use inv with default settings

  p("atan(1) = ", inv.inverse(1));
  pl(",   n = ", inv.getn());
  pl("check:  atan(1) = ", atan(1));

  pl("------------------------------------------- ");
  inv.setfp(f); // reset the function
  double y = 30;
  pl("f(x) = xexp(x) = y = ",y);
  double x = inv.inverse(y);
  // print the stored result instead of running the same inversion again
  p("f^-1(y) = x = ", x);
  pl(",   n = ", inv.getn());
  pl("check:  f(x) = ", f(x));

  pl("------------------------------------------- ");
  inv.setfp(g); // reset the function
  inv.setx0(2); // reset x0
  y = 5;
  pl("g(x) = sin(x)x^5 = y = ",y);
  x = inv.inverse(y);
  // print the stored result instead of running the same inversion again
  p("g^-1(y) = x = ", x);
  pl(",   n = ", inv.getn());
  pl("check:  g(x) = ", g(x));

  pl("------------------------------------------- ");
  inv.setfp(h);
  inv.setx0(0.5);
  x = inv.zero();
  pl("h(x) = x - cos(x) = 0,  x = ",x);
  pl("n = ",inv.getn());
  pl("check h(x) = ", h(x));

  pl("------------------------------------------- ");

  return 0;
}


 /* output

	------------
    invert.cpp
   ------------
   atan(1) = 0.785398163397603,   n = 5
   check:  atan(1) = 0.785398163397448
   -------------------------------------------
   f(x) = xexp(x) = y = 30
   f^-1(y) = x = 2.48922568815786,   n = 10
   check:  f(x) = 30.0000000000078
   -------------------------------------------
   g(x) = sin(x)x^5 = y = 5
   g^-1(y) = x = 1.38455186359695,   n = 6
   check:  g(x) = 5.0000000000013
   -------------------------------------------
   h(x) = x - cos(x) = 0,  x = 0.739085133215158
   n = 4
   check h(x) = -3.66487439354413e-15
   -------------------------------------------

 */