/*
  function.cc, copyright (c) 2006 by Vincent Fourmond: 
  The implementation of (all) functions.
  
  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details (in the COPYING file).
  
*/

#include <calc_internals.hh>

// The C math library: its functions are the ones wrapped and registered below.
#include <math.h>

namespace SCalc {
  // Textual representation of the definition: returns the fixed
  // placeholder string "biniou".
  std::string FuncDef::pretty_print()
  {
    std::string text("biniou");
    return text;
  }

  // Registers this definition with its session.
  //
  // Returns whatever the session's register_func_def() returns, or 0
  // without contacting the session when the definition has no name.
  int FuncDef::register_self()
  {
    if(! name().empty())
      return session()->register_func_def(this);
    return 0; // nameless definitions are never handed to the session
  }


  // Builds a function named n that is evaluated through the C function
  // f and whose derivative is d, then registers it with the session s.
  // The 1 handed to the FuncDef constructor is presumably the number
  // of arguments of the function -- confirm against the header.
  CFunc::CFunc(Session * s, const char * n,
               c_function_t f,
               FuncDef * d) : FuncDef(s, 1)
  {
    func = f;
    deriv = d;

    // The name must be in place before registering: register_self()
    // does nothing for a definition whose name is empty.
    set_name(n);
    register_self();
  }

  // Evaluates the function: the wrapped C function is applied to the
  // first element of args. vars is not used.
  double CFunc::evaluate(const double * vars, const double * args)
  {
    const double x = args[0];
    return func(x);
  }

  // The Heaviside step function, to allow for conditional stuff:
  // yields 1 when arg >= 0 and 0 in every other case (a NaN argument
  // therefore gives 0, as the comparison is false).
  static double heavy(double arg)
  {
    return (arg >= 0) ? 1.0 : 0.0;
  }

  // Registers a whole bunch of common mathematical functions for the
  // session s to use, along with their derivatives where one is set.
  //
  // Every CFunc registers itself with the session from its constructor,
  // so the objects created here are reachable through the session even
  // when no local pointer to them is kept.
  //
  // A derivative is either another function created here or an
  // expression evaluated by the session. The statements are ordered so
  // that a function is created before any expression mentioning it is
  // evaluated (sin before "-sin(x)", sqrt before the asin/acos
  // derivatives, cosh before the tanh derivative, exp before the erf
  // derivative): keep that order when adding functions.
  //
  // All powers are written with "**", the one operator used by every
  // expression of this function.
  void FuncDef::register_common_functions(Session * s)
  {
    CFunc * def;
    CFunc * def2;

    // Exponential and logarithm
    def = new CFunc(s, "exp", exp);
    def->set_derivative(def); // exp is its own derivative

    def = new CFunc(s, "ln", log);
    def->set_derivative(s->eval("x->1/x")
                        ->to_func_def());

    // Trigonometric functions
    def = new CFunc(s, "sin", sin);
    def2 = new CFunc(s, "cos", cos);
    def->set_derivative(def2);
    def2->set_derivative(s->eval("x->-sin(x)")
                         ->to_func_def());

    def = new CFunc(s, "tan", tan);
    def->set_derivative(s->eval("x->1 + tan(x)**2")
                        ->to_func_def());

    // Define sqrt before using it in the derivatives below.
    def = new CFunc(s, "sqrt", sqrt);
    def->set_derivative(s->eval("x->0.5*x**(-0.5)")
                        ->to_func_def());

    // Inverse trigonometric functions
    def = new CFunc(s, "asin", asin);
    def->set_derivative(s->eval("x->1/sqrt(1 - x**2)")
                        ->to_func_def());

    def = new CFunc(s, "acos", acos);
    def->set_derivative(s->eval("x->-1/sqrt(1 - x**2)")
                        ->to_func_def());

    def = new CFunc(s, "atan", atan);
    def->set_derivative(s->eval("x->1/(1 + x**2)")
                        ->to_func_def());

    // Hyperbolic functions: cosh and sinh are each other's derivative,
    // so no expression needs to be evaluated for them.
    def = new CFunc(s, "cosh", cosh);
    def2 = new CFunc(s, "sinh", sinh);
    def2->set_derivative(def);
    def->set_derivative(def2);

    def = new CFunc(s, "tanh", tanh);
    def->set_derivative(s->eval("x->1/(cosh(x)**2)")
                        ->to_func_def());

    // Error function: d/dx erf(x) = 2/sqrt(pi) * exp(-x**2). The
    // constant factor 2/sqrt(pi) is written out numerically.
    def = new CFunc(s, "erf", erf);
    def->set_derivative(s->eval("x->1.1283791670955126*exp(-x**2)")
                        ->to_func_def());

    // Gamma function: no derivative is set, as there is no simple one.
    new CFunc(s, "gamma", tgamma);

    // The Heaviside function - to allow conditional behavior.
    // No derivative is set: do not derive it !!
    new CFunc(s, "heavy", heavy);
  }


  // Deletes the derivative if there is one and it has no name, and
  // resets the pointer to NULL. A derivative that has a name is left
  // untouched.
  void CFunc::destroy_anonymous_derivatives()
  {
    if(! deriv)
      return; // no derivative at all
    if(! deriv->name().empty())
      return; // named derivative: not deleted here

    delete deriv;
    deriv = NULL;
  }

  void ExprFunc::destroy_anonymous_derivatives()
  {
    for(std::map<int, FuncDef*>::iterator i = 
	  cached_derivatives.begin(); 
	i  != cached_derivatives.end(); i++)
      {
	FuncDef * f = i->second;
	if(f->name().empty())
	  delete f;
      }
  }

  // Returns the derivative associated with the index nb.
  //
  // Derivatives are kept in cached_derivatives: the first request for
  // a given nb builds a new ExprFunc from the derived expression and
  // stores it, later requests return the stored object. The expression
  // is derived with respect to -(nb+1), presumably the way argument
  // number nb is designated inside expressions -- confirm against
  // derive().
  FuncDef * ExprFunc::derivative(int nb)
  {
    std::map<int, FuncDef*>::iterator cached = cached_derivatives.find(nb);
    if(cached == cached_derivatives.end())
      {
        // Not computed yet: build it and remember it.
        FuncDef * fresh = new ExprFunc(session(),
                                       exp->derive(-(nb+1)),
                                       nb_params());
        cached_derivatives[nb] = fresh;
        return fresh;
      }
    return cached->second;
  }

  // Textual representation of the definition: the prefix "fundef: "
  // followed by the pretty-printed form of the underlying expression.
  std::string ExprFunc::pretty_print()
  {
    return std::string("fundef: ") + exp->pretty_print();
  }


  // Builds a function named n in session s, with derivative d, that is
  // evaluated through the C function f and the opaque pointer p.
  // The CFunc base is constructed with NULL in place of a C function;
  // the function and its parameter are stored here.
  CFuncParam::CFuncParam(Session * s, const char * n,
                         c_function_t f, void * p,
                         FuncDef * d) : CFunc(s, n, NULL, d)
  {
    _param = p;
    func = f;
  }

  // Evaluates the function: the stored C function is called with the
  // stored parameter pointer and the first element of args. vars is
  // not used.
  double CFuncParam::evaluate(const double * vars, const double * args)
  {
    const double x = args[0];
    return func(_param, x);
  }


};
