#ifndef __EEM_COMMON_H__ #define __EEM_COMMON_H__ #include #include #include #include #include #include #include "param/matrix.h" #include "param/complex.h" #include "algorithm/least_square.h" #include "algorithm/wavelet.h" typedef double float_t; typedef Matrix matrix_t; struct add_coef_item_t { const char *name; int id, lhs_id, rhs_id; float_t init, init_P; add_coef_item_t() : init(0), init_P(1E-8) {} ~add_coef_item_t() {} friend std::ostream& operator<<(std::ostream &out, const add_coef_item_t &item){ out << item.name << "(" << item.id << "): L" << item.lhs_id << ", R" << item.rhs_id << ", " << item.init << ", " << item.init_P; return out; } }; typedef std::vector add_coefs_t; struct EquationErrorMethod { typedef std::map matrix_list_t; typedef std::map > lhs_rhs_indexes_t; ///< (左辺のID, 右辺のインデックスの集合) matrix_list_t &_Ks; matrix_list_t &_Ps; EquationErrorMethod( matrix_list_t &Ks, matrix_list_t &Ps) : _Ks(Ks), _Ps(Ps){} virtual ~EquationErrorMethod(){} /** * * @return (bool) 値の更新があった場合にtrueを返す */ virtual bool update( int lhs_id, matrix_t &lhs, matrix_t &rhs, matrix_t &lhs_R) = 0; virtual void prepare_next_update(const float_t &deltaT){} /** * 状態を出力する * * @param out 出力先 * @param t 時刻 * @param states 状態量 * @param dynamics 運動方程式 * @param table (係数のID, pair<左辺のID、右辺のID>)が記録されたテーブル */ template void dump( std::ostream &out, const float_t &t, const States &states, const Dynamics &dynamics, const std::map > &table){ typedef std::map > table_t; out << t << "," << states.join(","); for(unsigned i(0); i < Dynamics::num_of_coefs; i++){ table_t::const_iterator it(table.find(i)); if(it == table.end()){ out << "," << const_cast(dynamics).coefficient(i); }else{ out << "," << _Ks[(*it).second.first]((*it).second.second, 0); } } for(int i(0); i < states.variables(); i++){ out << "," << 0; } for(unsigned i(0); i < Dynamics::num_of_coefs; i++){ table_t::const_iterator it(table.find(i)); if(it == table.end()){ out << "," << 0; }else{ out << "," << _Ps[(*it).second.first]((*it).second.second, (*it).second.second); } } out << std::endl; } }; struct RLS : public EquationErrorMethod { RecursiveWeightedLeastSquare wls; using EquationErrorMethod::_Ks; using EquationErrorMethod::_Ps; RLS(matrix_list_t &Ks, matrix_list_t &Ps) : EquationErrorMethod(Ks, Ps), wls() {} ~RLS(){} bool update( int lhs_id, matrix_t &lhs, matrix_t &rhs, matrix_t &lhs_R){ wls.getP() = _Ps[lhs_id]; wls(lhs, rhs, lhs_R, _Ks[lhs_id]); _Ps[lhs_id] = wls.getP(); return true; } }; struct FTR : public EquationErrorMethod { using EquationErrorMethod::_Ks; using EquationErrorMethod::_Ps; typedef Complex complex_t; typedef Matrix cmatrix_t; template void matrix_elm_eq(matrix_src_t &src, matrix_dst_t &dst){ for(int i(0); i < dst.rows(); i++){ for(int j(0); j < dst.columns(); j++){ dst(i, j) = src(i, j); } } } /** * FTR用の最小二乗法、行列の要素が複素数になるので、そこから実数を求める * * @param y 左辺 * @param K 右辺の係数行列に相当する部分 * @param P 推定したxの共分散行列の格納先 */ static matrix_t ls( const cmatrix_t &y, const cmatrix_t &K, matrix_t &P){ //adjoint matrix、随伴行列 cmatrix_t K_adj(K.columns(), K.rows()); for(int i(0); i < K_adj.rows(); i++){ for(int j(0); j < K_adj.columns(); j++){ K_adj(i, j) = const_cast(K)(j, i).conjugate(); } } cmatrix_t first_c(K_adj * K), second_c(K_adj * y); matrix_t first_r(first_c.rows(), first_c.columns()), second_r(second_c.rows(), second_c.columns()); #define copy_real(src, dist) \ for(int i(0); i < src.rows(); i++){ \ for(int j(0); j < src.columns(); j++){ \ dist(i, j) = src(i, j).real(); \ } \ } copy_real(first_c, first_r); copy_real(second_c, second_r); #undef copy_op matrix_t first_r_inv(first_r.inverse()); P = first_r_inv; return first_r_inv * second_r; } typedef std::vector omega_table_t; omega_table_t omega_table; ///< FTR対象となる角速度の表 typedef std::vector exp_table_t; exp_table_t exp_table; ///< e^{j \omega \sum \DeltaT}の表 typedef std::map index_dwt_t; index_dwt_t dwt_lhs_table; ///< これまでに累算した結果の表 index_dwt_t dwt_rhs_table; ///< これまでに累算した結果の表 int samples; bool autodiff; FTR( matrix_list_t &Ks, matrix_list_t &Ps, bool _autodiff = false) : EquationErrorMethod(Ks, Ps), samples(0), autodiff(false) {} ~FTR(){} void add_convert_table( const float_t &omega){ omega_table.push_back(omega); exp_table.push_back(complex_t::exp(0)); } FTR *init_convert_table( const float_t &omega_min, const float_t &omega_max, const int steps){ if(steps >= 1){ // ステップ数は正の数 float_t delta_omega( (omega_max - omega_min) / ((steps > 1) ? (steps - 1) : 1)); // ゼロ割対策 float_t omega(omega_min); // DWT用のテーブルを作成する for(int i(0); i < steps; i++, omega += delta_omega){ add_convert_table(omega); } } return this; } FTR *init_convert_table_freq( const float_t &freq_min, const float_t &freq_max, const int steps){ return init_convert_table(freq_min * 2 * M_PI, freq_max * 2 * M_PI, steps); } bool update( int lhs_id, matrix_t &lhs, matrix_t &rhs, matrix_t &lhs_R){ //std::cerr << "L0:" << lhs << std::endl; //std::cerr << "R0:" << rhs << std::endl; // 初回の場合 if(dwt_lhs_table.find(lhs_id) == dwt_lhs_table.end()){ // 行列を登録しておく dwt_lhs_table.insert(std::make_pair(lhs_id, cmatrix_t(exp_table.size(), lhs.columns()))); dwt_rhs_table.insert(std::make_pair(lhs_id, cmatrix_t(exp_table.size(), rhs.columns()))); } // 複素数行列へ変換 cmatrix_t lhs_c(lhs.rows(), lhs.columns()); cmatrix_t rhs_c(rhs.rows(), rhs.columns()); matrix_elm_eq(lhs, lhs_c); matrix_elm_eq(rhs, rhs_c); index_dwt_t::mapped_type &lhs_sum(dwt_lhs_table[lhs_id]), &rhs_sum(dwt_rhs_table[lhs_id]); //std::cerr << "L1:" << lhs_sum << std::endl; //std::cerr << "R1:" << rhs_sum << std::endl; int i(0); for(exp_table_t::iterator it(exp_table.begin()); it != exp_table.end(); ++it, ++i){ // 右辺のDWTをする rhs_sum.pivotMerge(i, 0, (rhs_c * (*it))); // 左辺のDWTをする if(autodiff){ // フーリエ変換の公式を利用して微分する // この場合、左辺には微分量(差分でできている)ではなく状態量が送られてくることに注意 // TODO: なぜかうまく動かないのでよく検証すること!! lhs_sum.pivotMerge(i, 0, (lhs_c * ((*it) * complex_t(0, omega_table[i])))); }else{ // それ以外は普通に差分量が入ってきている lhs_sum.pivotMerge(i, 0, (lhs_c * (*it))); } } //std::cerr << "L2:" << lhs_sum << std::endl; //std::cerr << "R2:" << rhs_sum << std::endl; // 最小二乗法の適用 if(samples > 20){ try{ matrix_elm_eq(ls(lhs_sum, rhs_sum, _Ps[lhs_id]), _Ks[lhs_id]); }catch(std::exception &e){ // 操舵開始前で逆行列の計算が破綻することもあるのではじいておく std::cerr << e.what() << std::endl; } } return true; } void prepare_next_update(const float_t &deltaT){ if(!(samples++)){return;} int i(0); for(exp_table_t::iterator it(exp_table.begin()); it != exp_table.end(); ++it, ++i){ //std::cerr << *it << std::endl; // 変換表(e^{j \sum \DeltaT}の更新)の更新をしておく *it *= complex_t::exp(-omega_table[i] * deltaT); } } }; struct WFR : public EquationErrorMethod { using EquationErrorMethod::_Ks; using EquationErrorMethod::_Ps; typedef DaubechiesCascade<6, matrix_t> cascade_t; const static int default_cascade_depth = 8; struct Filter { matrix_t &_K, &_P; RecursiveWeightedLeastSquare wls; struct Hook { int cascade_depth; cascade_t *cascade; typedef std::map spool_t; spool_t spool; Hook(const int &_cascade_depth) : cascade_depth(_cascade_depth <= 0 ? default_cascade_depth : _cascade_depth), cascade(new cascade_t [cascade_depth]), spool() { for(int i(cascade_depth - 2); i >= 0; --i){ new(&cascade[i]) cascade_t(cascade[i + 1]); } } ~Hook() {delete [] cascade;} void operator()(const int &depth, const matrix_t &value){ spool_t::iterator it(spool.find(depth)); if(it == spool.end()){ spool.insert(std::make_pair(depth, value)); }else{ (it->second) += value; } } void update(matrix_t &new_item){ cascade->propagate(new_item, *this); } void terminate(){ cascade->terminate_propagate(*this); } } hook_lhs, hook_rhs, hook_u; float_t threshold, input_fading; Filter( matrix_t &K, matrix_t &P, const int &_cascade_depth = 0, const float_t &_threshold = 0, const float_t &_input_fading = 0) : _K(K), _P(P), wls(P), hook_lhs(_cascade_depth), hook_rhs(_cascade_depth), hook_u(_cascade_depth), threshold(_threshold), input_fading(_input_fading) {} ~Filter(){} bool update( matrix_t &lhs, matrix_t &rhs, matrix_t &u, matrix_t &lhs_R){ hook_lhs.update(lhs); hook_rhs.update(rhs); hook_u.update(u); bool updated(false); for(Hook::spool_t::iterator it(hook_u.spool.begin()); it != hook_u.spool.end(); ++it){ Hook::spool_t::iterator it_lhs(hook_lhs.spool.find(it->first)); Hook::spool_t::iterator it_rhs(hook_rhs.spool.find(it->first)); if((it_lhs == hook_lhs.spool.end()) || (it_rhs == hook_rhs.spool.end())){ continue; } bool is_use(true); float_t abs2(0); // uは列ベクトル for(int i(0); i < it->second.rows(); i++){ abs2 += pow(it->second(i, 0), 2); } if((it->first < 0) || (abs2 < threshold)){ is_use = false; } //std::cerr << "ABS2(" << it->first << "): " // << (is_use ? "OK, " : "NG, ") << abs2 << std::endl; // 復元用 /*rhook_lhs.spool[it->first].push_back( is_use ? it_lhs->second : zero_matrix_lhs); rhook_rhs.spool[it->first].push_back( is_use ? it_rhs->second : zero_matrix_rhs);*/ if(!is_use){continue;} //srd::cerr << it->first << ", " << it->second(0, 0) << srd::endl; wls(it_lhs->second, it_rhs->second, lhs_R, _K); updated = true; } _P = wls.getP(); hook_lhs.spool.clear(); hook_rhs.spool.clear(); // 入力の時間周波数情報をコントロール if(input_fading == 0){ hook_u.spool.clear(); }else{ for(Hook::spool_t::iterator it(hook_u.spool.begin()); it != hook_u.spool.end(); ++it){ (it->second) *= input_fading; } } return updated; } }; typedef std::map filter_list_t; filter_list_t filter_list; lhs_rhs_indexes_t &input_indexes; WFR( matrix_list_t &Ks, matrix_list_t &Ps, lhs_rhs_indexes_t &rhs_input_indexes, const int &_cascade_depth, const float_t &_threshold, const float_t &_input_fading) : EquationErrorMethod(Ks, Ps), filter_list(), input_indexes(rhs_input_indexes) { for(matrix_list_t::iterator it(Ks.begin()); it != Ks.end(); ++it){ filter_list.insert( std::make_pair(it->first, new Filter(Ks[it->first], Ps[it->first], _cascade_depth, _threshold, _input_fading))); } } ~WFR(){ for(filter_list_t::iterator it(filter_list.begin()); it != filter_list.end(); ++it){ delete it->second; } } bool update( int lhs_id, matrix_t &lhs, matrix_t &rhs, matrix_t &lhs_R){ if(filter_list.find(lhs_id) == filter_list.end()){return false;} matrix_t u(input_indexes[lhs_id].size(), 1); // uは列ベクトル int i(0); for(lhs_rhs_indexes_t::mapped_type::iterator it(input_indexes[lhs_id].begin()); it != input_indexes[lhs_id].end(); ++it){ u(i++, 0) = rhs(0, *it); } // 重み一定 return filter_list[lhs_id]->update(lhs, rhs, u, matrix_t::getI(1)); // 重みをかえる //return filter_list[lhs_id]->update(lhs, rhs, u, lhs_R); } }; #endif /* __EEM_COMMON_H__ */