diff --git a/Mesh/search.hpp b/Mesh/search.hpp index 727dcc5495be1ad326dc88110f9162b5bf9e2d9a..754fd303a1bb990cd020b189f6ecae4d281bab2a 100644 --- a/Mesh/search.hpp +++ b/Mesh/search.hpp @@ -152,20 +152,26 @@ void greedy_search(State &state, Successor &fn, Evaluator &eval) { } std::vector<delta> changes(actions.size()); - std::transform(actions.begin(), actions.end(), - changes.begin(), [&](const action &a) { - return a(state); - }); + + typename std::vector<action>::const_iterator actions_it; + typename std::vector<delta>::iterator changes_it; + for (actions_it = actions.begin(), changes_it = changes.begin(); + actions_it != actions.end(); actions_it++, changes_it++) { + const action &a = *actions_it; + *changes_it = a(state); + } std::vector<score> scores(actions.size()); - std::transform(changes.begin(), changes.end(), - scores.begin(), [&](const delta &delta) { - delta.apply(state); - auto result = eval(state); - delta.reverse(state); - return result; - }); + typename std::vector<score>::iterator scores_it; + for (changes_it = changes.begin(), scores_it = scores.begin(); + changes_it != changes.end(); changes_it++, scores_it++) { + const delta &delta = *changes_it; + + delta.apply(state); + *scores_it = eval(state); + delta.reverse(state); + } auto it = std::max_element(scores.begin(), scores.end()); @@ -214,21 +220,32 @@ struct mcts_node { return (x_1/visit_count) + std::sqrt(2*std::log(Score(n))/visit_count); } + class compare_by_ucb1 { + std::size_t _n; + public: + compare_by_ucb1(std::size_t n): _n(n) {} + + bool operator()(const mcts_node<Score> &a, const mcts_node<Score> &b) const { + return a.ucb1(_n) < b.ucb1(_n); + } + }; + + struct compare_by_mean { + bool operator()(const mcts_node<Score> &a, const mcts_node<Score> &b) const { + return a.x_1 < b.x_1; + } + }; + std::size_t best_child(std::size_t n) const { auto it = std::max_element(children.begin(), children.end(), - [&](const mcts_node<Score> &a, const mcts_node<Score> &b) { - return a.ucb1(n) < b.ucb1(n); - }); + compare_by_ucb1(n)); return std::distance(children.begin(), it); } std::size_t best_move(std::size_t n) const { auto it = std::max_element(children.begin(), children.end(), - [&](const mcts_node<Score> &a, const mcts_node<Score> &b) { - return a.x_1 < b.x_1; - }); - + compare_by_mean()); return std::distance(children.begin(), it); } };