#ifndef __STRATEGY_H__
#define __STRATEGY_H__

#include "othello.h"
#include "evaluator.h"
#include "util.h"

#include <vector>

using namespace std;

/**
 * Abstract move-selection policy for an Othello game.
 *
 * Concrete strategies implement select(), which receives the game object and
 * returns a reference to an Othello object (in this file: the result of
 * othello.next(...) / othello.mount(...), or the game itself when no move
 * could be chosen).
 */
class Strategy {
  public:
    /**
     * Virtual destructor so that a concrete strategy can be safely destroyed
     * through a pointer or reference to Strategy.
     */
    virtual ~Strategy() {}

    /**
     * Choose and play one move.
     *
     * @param othello game to act on
     * @return reference to the game after the strategy has acted
     */
    virtual Othello &select(Othello &othello) = 0;
};

/**
 * Strategy decorator that handles the position with exactly four stones on
 * the board itself and forwards every other position to a wrapped strategy.
 *
 * NOTE(review): four stones is presumably the standard opening position;
 * the meaning of init_lower() / mount() is defined in othello.h - confirm.
 */
class BootLoader : public Strategy {
  protected:
    Strategy &_another;  // strategy used for every position except the four-stone one
  public:
    /**
     * @param another strategy to delegate to; held by reference, so it must
     *                outlive this object
     */
    BootLoader(Strategy &another) : _another(another) {}

    /**
     * @param othello game to act on
     * @return result of othello.mount(lower, lower - 1) when the board holds
     *         exactly four stones, otherwise whatever the wrapped strategy returns
     */
    virtual Othello &select(Othello &othello){
      // Anything other than the four-stone position is delegated.
      if(othello.current().black_stones() + othello.current().white_stones() != 4){
        return _another.select(othello);
      }
      const int base(othello.current().init_lower());
      return othello.mount(base, base - 1);
    }
};

/**
 * Strategy that plays one of the currently available candidate moves chosen
 * at random with rand().
 *
 * Despite the class name, no temperature or score weighting is applied here:
 * every candidate is picked with (approximately) equal probability.
 */
class Boltzman : public Strategy {
  public:
    /**
     * @param othello game to act on
     * @return othello.next(move) for the randomly chosen candidate, or the
     *         unchanged game when there are no candidates
     */
    virtual Othello &select(Othello &othello){
      Othello::State::candidates_t candidates(othello.current().candidates());
      // Without this guard, rand() % 0 below would be undefined behaviour
      // whenever the side to move has no candidate.
      if(candidates.empty()){
        return othello;
      }
      int index(rand() % candidates.size());
      // Walk to the index-th candidate; plain iteration keeps this working
      // whatever iterator category candidates_t provides.
      for(Othello::State::candidates_t::iterator it = candidates.begin();
            it != candidates.end();
            ++it){
        if(!(index--)){
          return othello.next(it->first);
        }
      }
      return othello;
    }
};

/**
 * Base class for strategies that score each top-level candidate move with an
 * Evaluator and then pick one of them.
 *
 * Subclasses implement toplevel_evaluate(), which must call the supplied
 * callback once per candidate with (move index, score). Three selection
 * policies are built on top of that callback:
 *   - select()          : greedy (highest score)
 *   - select_softmax()  : probability proportional to exp(score / tau)
 *   - select_egreedy()  : greedy, but with probability epsilon a uniformly
 *                         chosen candidate is played instead
 */
class UseEvaluatorStrategy : public Strategy {
  protected:
    Evaluator &_evaluator;  // held by reference; must outlive this object
  public:
    UseEvaluatorStrategy(Evaluator &evaluator)
        : _evaluator(evaluator){
      
    }
    virtual ~UseEvaluatorStrategy() {}
    Evaluator &evaluator() const {return _evaluator;}
    
  protected:
    /**
     * Callback interface invoked by toplevel_evaluate() once per candidate.
     */
    class Invoked_Each_TopLevel_Evaluate{
      public:
        virtual ~Invoked_Each_TopLevel_Evaluate() {}
        /**
         * @param index move index of the candidate
         * @param score score assigned to that candidate
         */
        virtual void operator()(int index, double score) = 0;
    };
  public:
    /**
     * Score every candidate of the current position and report each
     * (index, score) pair to `invoked`.
     */
    virtual void toplevel_evaluate(Othello &othello, Invoked_Each_TopLevel_Evaluate &invoked) = 0;
    
  protected:
    /**
     * Callback that remembers the candidate with the highest score.
     * On equal scores the candidate reported first is kept.
     */
    class SelectGreedy : public Invoked_Each_TopLevel_Evaluate{
      protected:
        double max_score;
        int index_maximized;  // -1 until the first candidate has been reported
      public:
        SelectGreedy() : max_score(0), index_maximized(-1) {}
        ~SelectGreedy(){}
        void operator()(int index, double score){
          if((score > max_score) || (index_maximized == -1)){
            max_score = score;
            index_maximized = index;
          }
#ifdef STRATEGY_DEBUG
          cout << "Q(s, " << index << ") => " << score << endl;
#endif
        }
        /**
         * @return othello.next(best index), or the unchanged game when no
         *         candidate was reported
         */
        Othello &result(Othello &othello){
#ifdef STRATEGY_DEBUG
          cout << "selected: " << index_maximized << " "
               << "("  << (index_maximized / othello.current().board().size()) 
               << ", " << (index_maximized % othello.current().board().size()) 
               << ")"  << endl;
#endif
          return (index_maximized >= 0) ? othello.next(index_maximized) : othello;          
        }
    };
  public:
    /**
     * Greedy selection: play the candidate with the highest score.
     */
    Othello &select(Othello &othello){
      SelectGreedy selector;
      toplevel_evaluate(othello, selector);
      return selector.result(othello);
    }
    
  protected:
    /**
     * Callback that collects all candidates and later samples one of them
     * with probability proportional to exp(score / tau).
     */
    class SelectSoftMax : public Invoked_Each_TopLevel_Evaluate{
      protected:
        struct candidate_t{
          double selectable_coef;  // raw score first, turned into a weight by result()
          int index;
        };
        vector<candidate_t> candidates;
      public:
        SelectSoftMax() : candidates() {}
        ~SelectSoftMax(){}
        void operator()(int index, double score){
          candidate_t candidate;
          candidate.selectable_coef = score;
          candidate.index = index;
          candidates.push_back(candidate);
#ifdef STRATEGY_DEBUG
          cout << "Q(s, " << index << ") => " << score << endl;
#endif
        }
        /**
         * Sample one candidate and play it.
         *
         * @param tau temperature; must be non-zero (scores are divided by it)
         * @return othello.next(sampled index), or the unchanged game when no
         *         candidate was reported
         */
        Othello &result(Othello &othello, double tau){
          if(candidates.empty()){
            return othello;
          }
          // Turn scores into logits and find the largest one. Subtracting it
          // before exp() leaves the selection probabilities unchanged but
          // keeps every weight in (0, 1], so large scores cannot overflow to
          // infinity and very negative scores cannot make all weights
          // underflow to zero.
          double max_logit(0);
          bool is_first(true);
          for(vector<candidate_t>::iterator it = candidates.begin();
              it != candidates.end();
              ++it){
            it->selectable_coef /= tau;
            if(is_first || (it->selectable_coef > max_logit)){
              max_logit = it->selectable_coef;
              is_first = false;
            }
          }
          double selected_coef(0);
          for(vector<candidate_t>::iterator it = candidates.begin();
              it != candidates.end();
              ++it){
            selected_coef += ((it->selectable_coef) = exp((it->selectable_coef) - max_logit));
          }
          // Pick a point in [0, total weight) and find the candidate whose
          // cumulative weight covers it.
          // NOTE(review): assumes drand() (from util.h) returns a value in [0, 1) - confirm.
          selected_coef *= drand();
          for(vector<candidate_t>::iterator it = candidates.begin();
              it != candidates.end();
              ++it){
            if((selected_coef -= (it->selectable_coef)) <= 0){
#ifdef STRATEGY_DEBUG
              cout << "selected: " << it->index
                   << "("  << (it->index / othello.current().board().size()) 
                   << ", " << (it->index % othello.current().board().size()) 
                   << ")"  << endl;
#endif
              return othello.next(it->index);
            }
          }
          // Only reachable through floating-point rounding (or NaN weights).
          return othello.next(candidates.front().index);
        }
    };
  public:
    /**
     * Soft-max selection.
     *
     * @param othello game to act on
     * @param tau temperature (non-zero); scores are divided by it before exponentiation
     */
    Othello &select_softmax(Othello &othello, double tau = 1.0){
      SelectSoftMax selector;
      toplevel_evaluate(othello, selector);
      return selector.result(othello, tau);
    }
    
  protected:
    /**
     * Callback that tracks the best candidate and also records every
     * candidate index so that a random one can be substituted.
     */
    class SelectEpsilonGreedy : public Invoked_Each_TopLevel_Evaluate{
      protected:
        double max_score;
        int index_maximized;  // -1 until the first candidate has been reported
        vector<int> candidates;
      public:
        SelectEpsilonGreedy() : max_score(0), index_maximized(-1), candidates() {}
        ~SelectEpsilonGreedy(){}
        void operator()(int index, double score){
#ifdef STRATEGY_DEBUG
          cout << "Q(s, " << index << ") => " << score << endl;
#endif
          if((score > max_score) || (index_maximized == -1)){
            max_score = score;
            index_maximized = index;
          }
          candidates.push_back(index);
        }
        /**
         * @param epsilon total probability mass given to uniform random choice
         * @return othello.next(chosen index), or the unchanged game when no
         *         candidate was reported
         */
        Othello &result(Othello &othello, double epsilon){
          if(!candidates.empty()){
            double selected_coef(drand());
            // Each candidate owns a slice of width epsilon / N at the start
            // of the unit interval; a draw landing in a slice replaces the
            // greedy choice with that candidate, a draw beyond all slices
            // keeps the greedy choice.
            epsilon /= candidates.size();
            for(vector<int>::iterator it = candidates.begin();
                it != candidates.end();
                ++it){
              if((selected_coef -= epsilon) <= 0){
                index_maximized = *it;
                break;
              }
            }
#ifdef STRATEGY_DEBUG
            cout << "selected: " << index_maximized
                 << "("  << (index_maximized / othello.current().board().size()) 
                 << ", " << (index_maximized % othello.current().board().size()) 
                 << ")"  << endl;
#endif
            return othello.next(index_maximized);
          }
          return othello;
        }
    };
  public:
    /**
     * Epsilon-greedy selection.
     *
     * @param othello game to act on
     * @param epsilon probability of playing a uniformly chosen candidate
     *                instead of the best-scored one
     */
    Othello &select_egreedy(Othello &othello, double epsilon = 0.5){
      SelectEpsilonGreedy selector;
      toplevel_evaluate(othello, selector);
      return selector.result(othello, epsilon);
    }
};

/**
 * Plain min-max search over the game tree.
 *
 * The search below a top-level candidate stops when either the estimated
 * node count reaches max_nodes or the depth reaches max_depth; the position
 * is then scored with the evaluator.
 */
class MinMaxSearch : public UseEvaluatorStrategy {
  protected:
    const unsigned int _max_nodes;
    const unsigned int _max_depth;
  public:
    /**
     * @param evaluator scoring function for leaf positions
     * @param max_nodes limit on the estimated node count (product of the
     *                  candidate counts along the current path)
     * @param max_depth limit on the search depth below a top-level candidate
     */
    MinMaxSearch(
          Evaluator &evaluator,
          const unsigned int max_nodes,
          const unsigned int max_depth)
        : UseEvaluatorStrategy(evaluator),
          _max_nodes(max_nodes),
          _max_depth(max_depth){
          
    }
    
  protected:
    /**
     * Score `state` by recursive min-max.
     *
     * @param state position to score
     * @param nodes estimated node count accumulated along the path so far
     * @param depth depth of `state` below the top-level candidate (0 for the
     *              candidate itself)
     * @return evaluator score of `state` when a limit is hit or no move is
     *         available, otherwise the min/max of the children's scores
     */
    virtual double recursive_evaluate(
        Othello::State *state, 
        unsigned int nodes, 
        unsigned int depth){
      if((nodes >= _max_nodes) || (depth >= _max_depth)){
        return evaluator().evaluate(*state);
      }
      Othello::State::candidates_t candidates(state->candidates());
      if(candidates.empty()){
        return evaluator().evaluate(*state);
      }
      nodes *= candidates.size();
      depth++;
      
      double minmax_score(0);
      // Tracks whether minmax_score already holds a child's score. It is
      // tested first so that minmax_score is never compared before a real
      // score has been stored in it.
      bool has_score(false);
      
      for(Othello::State::candidates_t::iterator it = candidates.begin();
            it != candidates.end();
            ++it){
        double score(recursive_evaluate(it->second, nodes, depth));
        // depth % 2 == 0: own turn      => maximize
        // depth % 2 == 1: opponent turn => minimize
        if(!has_score
            || ((depth % 2 == 0) ? (score > minmax_score) : (score < minmax_score))){
          minmax_score = score;
          has_score = true;
        }
      }
      return minmax_score;
    }
    
  public:
    /**
     * Score every candidate of the current position with recursive_evaluate()
     * and report each (move index, score) pair to `invoked`.
     */
    void toplevel_evaluate(
        Othello &othello, 
        UseEvaluatorStrategy::Invoked_Each_TopLevel_Evaluate &invoked){
      Othello::State::candidates_t candidates(othello.current().candidates());
      for(Othello::State::candidates_t::iterator it = candidates.begin();
            it != candidates.end();
            ++it){
        double score(recursive_evaluate(it->second, candidates.size(), 0));
        invoked(it->first, score);
      }
    }
};

class NegaMaxSearch : public MinMaxSearch {
  public:
    NegaMaxSearch(
          Evaluator &evaluator,
          const unsigned int max_nodes,
          const unsigned int max_depth)
        : MinMaxSearch(evaluator, max_nodes, max_depth){
          
    }
    
  protected:
    virtual double recursive_evaluate(
        Othello::State *state, 
        unsigned int nodes, 
        unsigned int depth){
      
      while(true){
        if((nodes >= MinMaxSearch::_max_nodes) 
            || (depth >= MinMaxSearch::_max_depth)){
          break;
        }
        Othello::State::candidates_t candidates(state->candidates());
        if(candidates.empty()){break;}
          
        depth++;
        double max_score;
        nodes *= candidates.size();
          
        max_score = MinMaxSearch::evaluator().min_value();
        for(Othello::State::candidates_t::iterator it = candidates.begin();
              it != candidates.end();
              ++it){
          
          double score(recursive_evaluate(it->second, nodes, depth));
          
          if(score > max_score){max_score = score;}
        }
        
        return max_score * -1;
      }
      
      // depth % 2 == 0 when 自分 => 最大化 
      // depth % 2 == 1 when 相手 => 最小化
      return MinMaxSearch::evaluator().evaluate(*state) 
                * ((depth % 2 == 0) ? 1 : -1); 
    }
};

/**
 * Min-max search with alpha/beta pruning.
 *
 * The node limit works differently from MinMaxSearch: the first time the
 * estimated node count reaches _max_nodes, the current depth becomes the new
 * depth limit and the node limit is switched off for the rest of the
 * top-level evaluation. toplevel_evaluate() restores both limits afterwards.
 */
class AlphaBetaSearch : public UseEvaluatorStrategy {
  protected:
    unsigned int _max_nodes;  // temporarily modified during a search
    unsigned int _max_depth;  // temporarily modified during a search
  public:
    /**
     * @param evaluator scoring function for leaf positions
     * @param max_nodes limit on the estimated node count
     * @param max_depth limit on the search depth below a top-level candidate
     */
    AlphaBetaSearch(
          Evaluator &evaluator,
          unsigned int max_nodes,
          unsigned int max_depth)
        : UseEvaluatorStrategy(evaluator),
          _max_nodes(max_nodes),
          _max_depth(max_depth){
          
    }
    
  protected:
    /**
     * Score `state` by recursive alpha/beta search.
     *
     * @param state     position to score
     * @param nodes     estimated node count accumulated along the path so far
     * @param depth     depth of `state` below the top-level candidate
     * @param threshold best score the parent has found so far; the
     *                  evaluator's min_value()/max_value() mean "no bound yet"
     * @return score of `state`; when a cutoff happens this is only a bound
     *         that is sufficient for the parent to reject the node
     */
    virtual double recursive_evaluate(
        Othello::State *state, 
        unsigned int nodes, 
        unsigned int depth,
        double threshold){
      
      if(nodes >= _max_nodes){
        // Node budget exhausted: freeze the depth limit at the current depth
        // and disable the node limit so this happens only once per search.
        _max_depth = depth;
#ifdef STRATEGY_DEBUG
        cout << "depth_limit: " << _max_depth
             << " @ node = " << nodes << " > " << _max_nodes << endl;
#endif
        // The parentheses around the function name stop a function-like
        // `max` macro (e.g. from Windows headers) from being expanded here,
        // without touching the macro itself.
        _max_nodes = (std::numeric_limits<unsigned int>::max)();
      }
      if(depth >= _max_depth){
        return evaluator().evaluate(*state);
      }
      
      depth++;
      double minmax_score(
          (depth % 2 == 0) 
            ? UseEvaluatorStrategy::evaluator().min_value()
            : UseEvaluatorStrategy::evaluator().max_value());
      bool has_next(false);
      
      if(UseEvaluatorStrategy::evaluator().min_value() == threshold){
        // First search on the opponent's turn: the parent has no bound yet
        // (no score can fall below min), so nothing can be pruned.
        Othello::State::candidates_t candidates(state->candidates());
        if(nodes *= candidates.size()){
          has_next = true;
          for(Othello::State::candidates_t::iterator it = candidates.begin();
                it != candidates.end();
                ++it){
            double score(recursive_evaluate(it->second, nodes, depth, minmax_score));
            if(score < minmax_score){minmax_score = score;}
          }
        }
      }else if(UseEvaluatorStrategy::evaluator().max_value() == threshold){
        // First search on our own turn: the parent has no bound yet
        // (no score can exceed max), so nothing can be pruned.
        Othello::State::candidates_t candidates(state->candidates());
        if(nodes *= candidates.size()){
          has_next = true;
          for(Othello::State::candidates_t::iterator it = candidates.begin();
                it != candidates.end();
                ++it){
            double score(recursive_evaluate(it->second, nodes, depth, minmax_score));
            if(score > minmax_score){minmax_score = score;}
          }
        }
      }else{
        // The parent already has a bound, so alpha/beta pruning is possible.
        // Candidates are visited through the state's iterator and `nodes` is
        // passed on unchanged in this branch.
        for(Othello::State::candidates_iterator_t it = state->candidates_begin();
              it != state->candidates_end();
              ++it, has_next = true){
          double score(recursive_evaluate(it->state, nodes, depth, minmax_score));
          // depth % 2 == 0: own turn      => maximize
          // depth % 2 == 1: opponent turn => minimize
          if(depth % 2 == 0){
            if(score > threshold){return score;} // beta cutoff: already above the parent's bound, stop searching
            if(score > minmax_score){minmax_score = score;}
          }else{
            if(score < threshold){return score;} // alpha cutoff: already below the parent's bound, stop searching
            if(score < minmax_score){minmax_score = score;}
          }
        }
      }
      
      // No candidate at all: fall back to the static evaluation.
      return has_next ? minmax_score : evaluator().evaluate(*state);
    }
    
  public:
    /**
     * Score every candidate of the current position and report each
     * (move index, score) pair to `invoked`. The best score found so far is
     * passed down as the pruning threshold for the following candidates.
     * _max_depth and _max_nodes are restored to their configured values
     * before returning.
     */
    void toplevel_evaluate(
        Othello &othello, 
        UseEvaluatorStrategy::Invoked_Each_TopLevel_Evaluate &invoked){
      Othello::State::candidates_t candidates(othello.current().candidates());
      unsigned int backup_max_depth(_max_depth);
      unsigned int backup_max_nodes(_max_nodes);
      double max_score(UseEvaluatorStrategy::evaluator().min_value());
      for(Othello::State::candidates_t::iterator it = candidates.begin();
            it != candidates.end();
            ++it){
        double score(recursive_evaluate(it->second, candidates.size(), 0, max_score));
        invoked(it->first, score);
        if(max_score < score){max_score = score;}
      }
      _max_depth = backup_max_depth;
      _max_nodes = backup_max_nodes;
    }
};

/**
 * Nega-max formulation of AlphaBetaSearch.
 *
 * Sign convention: the value returned for a node is expressed from the point
 * of view of the player who moved INTO that node, i.e. it is the negated
 * maximum of the children's returned values. Leaf scores are sign-adjusted
 * by depth parity.
 *
 * NOTE(review): like NegaMaxSearch this assumes evaluator().min_value() is
 * lower than every negated score - confirm against the evaluators in use.
 */
class NegaAlphaBetaSearch : public AlphaBetaSearch {
  public:
    /**
     * Parameters are forwarded unchanged to AlphaBetaSearch.
     */
    NegaAlphaBetaSearch(
          Evaluator &evaluator,
          const unsigned int max_nodes,
          unsigned int max_depth)
        : AlphaBetaSearch(evaluator, max_nodes, max_depth){
          
    }
    
  protected:
    // Declared only; not used in this header. Any definition lives elsewhere.
    static const double min_value;
    static const double max_value;

    /**
     * Score `state` by recursive nega-max with pruning.
     *
     * @param state     position to score
     * @param nodes     estimated node count accumulated along the path so far
     * @param depth     depth of `state` below the top-level candidate
     * @param threshold best value the parent has found so far, in the
     *                  parent's sign convention; the evaluator's min_value()
     *                  means "no bound yet"
     * @return negated maximum of the children's values (only a bound when a
     *         cutoff happened); when a limit is hit or no move is available,
     *         the evaluator score, negated if `depth` is odd
     */
    double recursive_evaluate(
        Othello::State *state, 
        unsigned int nodes, 
        unsigned int depth,
        double threshold){
      
      if(nodes >= AlphaBetaSearch::_max_nodes){
        // Node budget exhausted: freeze the depth limit at the current depth
        // and disable the node limit so this happens only once per search.
        AlphaBetaSearch::_max_depth = depth;
#ifdef STRATEGY_DEBUG
        cout << "depth_limit: " << AlphaBetaSearch::_max_depth
             << " @ node = " << nodes << " > " << AlphaBetaSearch::_max_nodes << endl;
#endif
        // The parentheses around the function name stop a function-like
        // `max` macro (e.g. from Windows headers) from being expanded here,
        // without touching the macro itself.
        AlphaBetaSearch::_max_nodes = (std::numeric_limits<unsigned int>::max)();
      }

      if(depth < AlphaBetaSearch::_max_depth){
        const unsigned int child_depth(depth + 1);
        double max_score(AlphaBetaSearch::evaluator().min_value());
        bool has_next(false);
        
        if(AlphaBetaSearch::evaluator().min_value() == threshold){
          // First search: the parent has no bound yet, so nothing can be
          // pruned and every candidate is examined.
          Othello::State::candidates_t candidates(state->candidates());
          if(nodes *= candidates.size()){
            has_next = true;
            for(Othello::State::candidates_t::iterator it = candidates.begin();
                  it != candidates.end();
                  ++it){
              double score(recursive_evaluate(it->second, nodes, child_depth, max_score));
              
              if(score > max_score){max_score = score;}
            }
          }
        }else{
          // The parent already has a bound, so pruning is possible.
          for(Othello::State::candidates_iterator_t it = state->candidates_begin();
                it != state->candidates_end();
                ++it){
            // Set here rather than in the loop increment: a cutoff on the
            // very first child must still count as "this node has moves".
            has_next = true;
            
            double score(recursive_evaluate(it->state, nodes, child_depth, max_score));
            
            // Cutoff. This node returns -max(score) and the parent only
            // keeps it if that exceeds `threshold`. Once one child reaches
            // score > -threshold the returned value is already below
            // `threshold`, so the remaining children cannot matter.
            if(score > -threshold){
              max_score = score; 
              break;
            } 
            if(score > max_score){max_score = score;}
          }
        }
        
        if(has_next){
          return max_score * -1;
        }
        // No candidate at all: fall through to the static evaluation below.
      }
      
      // Leaf: even depth (own turn) keeps the evaluator's sign,
      //       odd depth (opponent's turn) flips it.
      return evaluator().evaluate(*state) * ((depth % 2 == 0) ? 1 : -1);
    }
};

#endif

