#include <iostream>
#include <fstream>
#include <iomanip>
#include <string>
#include <exception>
#include <sstream>

#include <stdio.h>
#include <fcntl.h>

using namespace std;

#define STRATEGY_DEBUG

#include "othello.h"
#include "evaluator.h"
#include "strategy.h"
#include "cirno.h"

class Board_180Rotated : public Othello::Board {
  public:
    typedef Othello::Board super_t;
    typedef Othello::Stone stone_t;
    Board_180Rotated(const super_t &original)
        : super_t(original) {}
    stone_t &operator()(const int i, const int j){
      return super_t::operator()(super_t::_size - (i + 1), super_t::_size - (j + 1));
    }
    stone_t *get(const int i, const int j) const {
      return super_t::get(super_t::_size - (i + 1), super_t::_size - (j + 1));
    }
    stone_t *set(const int i, const int j, const stone_t *stone){
      return super_t::set(super_t::_size - (i + 1), super_t::_size - (j + 1), stone);
    }
};

class Board_LineSymmetric1 : public Othello::Board {
  public:
    typedef Othello::Board super_t;
    typedef Othello::Stone stone_t;
    Board_LineSymmetric1(const super_t &original)
        : super_t(original) {}
    stone_t &operator()(const int i, const int j){
      return super_t::operator()(j, i);
    }
    stone_t *get(const int i, const int j) const {
      return super_t::get(j, i);
    }
    stone_t *set(const int i, const int j, const stone_t *stone){
      return super_t::set(j, i, stone);
    }
};

class Board_LineSymmetric2 : public Othello::Board {
  public:
    typedef Othello::Board super_t;
    typedef Othello::Stone stone_t;
    Board_LineSymmetric2(const super_t &original)
        : super_t(original) {}
    stone_t &operator()(const int i, const int j){
      return super_t::operator()(super_t::_size - (j + 1), super_t::_size - (i + 1));
    }
    stone_t *get(const int i, const int j) const {
      return super_t::get(super_t::_size - (j + 1), super_t::_size - (i + 1));
    }
    stone_t *set(const int i, const int j, const stone_t *stone){
      return super_t::set(super_t::_size - (j + 1), super_t::_size - (i + 1), stone);
    }
};

/**
 * おバカをがんばって鍛えましょう
 * 
 * というのは置いておいて、とりあえず強化学習のSarsaっぽい実装をしてみることにする
 * Sarsaは方策on型であるから、使う方策をpi(s,a) > 0なるような方策にする必要がある
 * ここではソフトマックスを用いたepsilon-greedy方策で関数を鍛えてみることにする
 * 
 * といってもゲームにおいては途中報酬が全くないので、
 * モンテカルロと同じような気もしないでもない。
 */
class Cirno_Trained_by_Sarsa_SM_e_greedy : public Cirno {
  public:
    typedef Cirno super_t;
    typedef super_t::Tiler tiler_t;
    Cirno_Trained_by_Sarsa_SM_e_greedy(const Othello &othello)
        : super_t(othello){
      
    }
    ~Cirno_Trained_by_Sarsa_SM_e_greedy(){}
  
  protected:
    /**
     * 1差分分だけ学習させる
     * 途中の報酬が0なので、非常に単純
     * 
     * @param alpha ステップサイズ・パラメータ
     */
    void train_TDn_board(
        super_t::Tiler::encoded_t &old_encoded_board,
        super_t::Tiler::encoded_t &new_encoded_board,
        double alpha){
      
      super_t::trained_tiles_t::iterator it(super_t::trained_data.begin());
      int index(0);
      while(it != super_t::trained_data.end()){
        int tile_index_new(new_encoded_board[index]); 
        int tile_index_old(old_encoded_board[index]);
        trained_item_t &item_old((*it)->trained_items[tile_index_old]);
        trained_item_t &item_new((*it)->trained_items[tile_index_old]);
        
        // Sarsaによる更新
        item_old.score += (item_new.score - item_old.score) * alpha;
        
        // その他パラメータの更新
        item_old.score_sum += item_old.score; 
        item_old.score_pow2 += pow(item_old.score, 2);
        (item_old.run_count)++;
        
        ((*it)->total_score_sum) += item_old.score;
        ((*it)->total_score_pow2) += pow(item_old.score, 2);
        ((*it)->total_run_count)++;
        
        ++it;
        index++;
      }
    }
    
    void train_TDn(
        const Othello::State &old_state,
        const Othello::State &new_state,
        double alpha){
      
      train_TDn_board(
          super_t::Tiler::encode(old_state.board()),
          super_t::Tiler::encode(new_state.board()),
          alpha);
      
      // その他対象性があるものも同時に鍛える
      train_TDn_board(
          super_t::Tiler::encode(Board_180Rotated(old_state.board())),
          super_t::Tiler::encode(Board_180Rotated(new_state.board())),
          alpha);
      train_TDn_board(
          super_t::Tiler::encode(Board_LineSymmetric1(old_state.board())),
          super_t::Tiler::encode(Board_LineSymmetric1(new_state.board())),
          alpha);
      train_TDn_board(
          super_t::Tiler::encode(Board_LineSymmetric2(old_state.board())),
          super_t::Tiler::encode(Board_LineSymmetric2(new_state.board())),
          alpha);
    }
    
    
  public:
    /**
     * 現在の状態と関係がある部分のみ評価関数の更新を行う
     * (すなわち、オンライン更新用)
     * 
     * @param othello 対極の途中でもOK
     * @param alpha ステップサイズ・パラメータ
     * @param n_step nステップSarsa, デフォルト1
     */
    void train_current(const Othello &othello, double alpha, int n_step = 1){
      if(n_step <= othello.transit_count()){return;}
      train_TDn(
          othello.previous(n_step), 
          othello.previous(n_step - 1),
          alpha);
    }
    
    /**
     * これまでの全状態について更新を行う
     * (すなわち、オフライン更新用)
     * 
     * @param othello 現在までの状態についてoffline更新を行う
     * @param alpha ステップサイズ・パラメータ
     * @param n_step nステップSarsa, デフォルト1
     */
    void train_all(const Othello &othello, double alpha, int n_step = 1){
      for(int step = n_step; step <= othello.transit_count(); step++){
        train_TDn(
            othello.previous(step), 
            othello.previous(step - 1),
            alpha);
      }
    }
  protected:
    void train_final_board(
        super_t::Tiler::encoded_t &encoded_board, 
        double alpha,
        double score){
      
      super_t::trained_tiles_t::iterator it(super_t::trained_data.begin());
      int index(0);
      while(it != super_t::trained_data.end()){
        int tile_index(encoded_board[index]); 

        trained_item_t &item((*it)->trained_items[tile_index]);
        
        // Sarsaによる更新
        item.score += (score - item.score) * alpha;
        
        // その他パラメータの更新
        item.score_sum += item.score; 
        item.score_pow2 += pow(item.score, 2);
        (item.run_count)++;
        
        ((*it)->total_score_sum) += item.score;
        ((*it)->total_score_pow2) += pow(item.score, 2);
        ((*it)->total_run_count)++;
        
        ++it;
        index++;
      }
    }
  
  public:
    /**
     * 最終状態のQ値を学習させる
     * 評価関数の元となるテーブルは(黒石)-(白石)で作ること
     * 
     * @param othello 
     * @param alpha ステップサイズ・パラメータ
     */
    void train_final_state(const Othello &othello, double alpha){
      // 終局常態かのチェック
      //if(othello.current().has_next()){return;}
      
      const Othello::State &current(othello.current());
      double score(current.black_stones() - current.white_stones());
      
      train_final_board(
          super_t::Tiler::encode(current.board()),
          alpha,
          score);
      
      // 対象性があるものも同時に鍛える
      train_final_board(
          super_t::Tiler::encode(Board_180Rotated(current.board())),
          alpha,
          score);
      train_final_board(
          super_t::Tiler::encode(Board_LineSymmetric1(current.board())),
          alpha,
          score);
      train_final_board(
          super_t::Tiler::encode(Board_LineSymmetric2(current.board())),
          alpha,
          score);
    }
};


#define BOARD_SIZE 6

#define TAU_INIT 10.0
#define TAU_DISCOUNT (1.0 - 1E-4)
#define TAU_DISCOUNT_ON_EPISODE (1.0 - 1E-2)

#define ALPHA_INIT 0.5
#define ALPHA_MIN 0.1
#define ALPHA_DISCOUNT (1.0 - 1E-4)


int main(int argc, char *argv[]){
  Othello othello(BOARD_SIZE);
  
  Cirno_Trained_by_Sarsa_SM_e_greedy cirno(othello);
  
  //NegaAlphaBetaSearch selector(cirno, 50000, 12);
  AlphaBetaSearch selector(cirno, 30000, 12);
  //NegaMaxSearch selector(cirno, 10000, 12);
  //MinMaxSearch selector(cirno, 30000, 10);
  //Boltzman selector;
  BootLoader bootloader(selector);
   
  int loop_count(0);
  int loop_count_init(0);
  
  if(argc > 1){
    stringstream ss;
    loop_count_init = atoi(argv[1]);
    ss << "cirno_" << setw(8) << setfill('0') << loop_count_init << ".log";
    cout << "init_log: " << ss.str() << endl;
    
    ifstream logfs(ss.str().c_str(), ios::in);
    logfs >> cirno;
    //cout << setprecision(10) << cirno;
  }
  
  double tau(TAU_INIT);
  double alpha(ALPHA_INIT);
  
  while(true){
    
    loop_count++;
    tau *= TAU_DISCOUNT;
    alpha *= ALPHA_DISCOUNT;
    if(loop_count_init >= loop_count){continue;}
    double tau_on_episode(tau);
    
    bootloader.select(othello);
    while(true){
      tau_on_episode *= TAU_DISCOUNT_ON_EPISODE; 
      cirno.set_evaluate_mode(othello.current());
      selector.select_softmax(othello, max(tau_on_episode, 0.1));
      //selector.select(othello);
      cout << othello.current() << endl;
      if(!othello.current().has_next()){break;}
      cirno.train_current(othello, max(alpha, ALPHA_MIN), 1);
    }
    cirno.train_final_state(othello, alpha);
    cout << "試行回数: " << loop_count 
         << ", 温度: " << tau 
         << ", alpha: " << alpha
         << endl;
    cout << "黒: " << othello.current().black_stones() << ", " 
         << "白: " << othello.current().white_stones() << endl;
    cout << "評価関数値(" << cirno.get_evaluate_mode()->to_s()
         << "): " << cirno.evaluate(othello.current()) << endl;
    cout << endl << endl;
    
    // 初期状態に戻す
    while(othello.transit_count()){
      othello.rollback();
    }
    
    if(loop_count % 100 == 0){
      stringstream ss;
      ss << "cirno_" << setw(8) << setfill('0') << loop_count << ".log"; 
      ofstream logfs(ss.str().c_str(), ios::out); 
      logfs << setprecision(10) << cirno;
    }
  }
}

