diff --git a/Mesh/search.hpp b/Mesh/search.hpp index b4c6e054d9f4a5fcbb33283fc1322b991ae13acb..4a954c583ff73f29004eb68baa2faeda7ead1f9c 100644 --- a/Mesh/search.hpp +++ b/Mesh/search.hpp @@ -213,13 +213,22 @@ struct mcts_node { return std::distance(children.begin(), it); } + + /* Returns the index within children of the child whose x_1 is greatest; std::max_element keeps the first such child on ties. The argument n is never read here. NOTE(review): presumably n is kept so the signature mirrors best_child(n) -- confirm before removing it. */ std::size_t best_move(std::size_t n) const { + auto it = std::max_element(children.begin(), children.end(), + [&](const mcts_node<Score> &a, const mcts_node<Score> &b) { + return a.x_1 < b.x_1; + }); + + /* When children is empty, it equals children.end() and the result is 0, which is not a valid index. NOTE(review): confirm callers never reach this with no children. */ return std::distance(children.begin(), it); + } }; /** * Runs a single Monte Carlo Simulation as part of a Monte Carlo Tree Search. * * This operation is done as follows: - * 1. Explore the seach tree of statistics (described by node), playing + * 1. Explore the tree of statistics (node being a pointer to its root), playing * optimially according to those statistics, until reaching * one of its leaves. * 2. Perform a completely random simulation, starting from the leaf found at @@ -244,12 +253,12 @@ void mcts_simulation(mcts_node<Score> *node, ancestors.push_back(node); - bool explore = true; + bool selection = true; while (!actions.empty()) { std::size_t i; - if (explore && node->children.size() == 0) { + if (selection && node->children.size() == 0) { node->children.resize(actions.size()); i = rand() % actions.size(); @@ -257,9 +266,9 @@ void mcts_simulation(mcts_node<Score> *node, node = &node->children[i]; ancestors.push_back(node); - explore = false; + selection = false; } - else if (explore) { + else if (selection) { i = node->best_child(n); node = &node->children[i]; ancestors.push_back(node); @@ -321,7 +330,7 @@ void monte_carlo_tree_search(State &state, Successor &fn, Evaluator &eval) { actions.begin(), actions.end()); } - actions[node.best_child(i)](state).apply(state); + actions[node.best_move(i)](state).apply(state); } } }