diff options
-rw-r--r-- | GameState.py | 144 | ||||
-rw-r--r-- | MCTS.py | 150 |
2 files changed, 294 insertions, 0 deletions
# ---------------------------------------------------------------------------
# Combined listing of the two modules added by this change:
#   * GameState.py - game state base class and the Dots and Boxes state
#   * MCTS.py      - Monte Carlo Tree Search node and game controller
#
# Based on https://github.com/DieterBuys/mcts-player/
# ---------------------------------------------------------------------------
import math
from copy import deepcopy
from random import choice

try:
    # time.clock() was deprecated in Python 3.3 and removed in Python 3.8.
    from time import perf_counter as clock
except ImportError:  # Python 2 has no perf_counter
    from time import clock

# NOTE: when the modules are kept as separate files, MCTS.py additionally has
#     from GameState import GameState
# The name is not used by the MCTS code and is defined below in this listing.


# ---------------------------------------------------------------------------
# GameState.py
# ---------------------------------------------------------------------------

class GameState(object):
    """Minimal game-state interface used by the tree search.

    Subclasses override ``game_result``, ``get_moves`` and ``play_move``.
    The base implementation describes a game with no moves and no result.
    """

    def __init__(self):
        # Number (1 or 2) of the player who makes the next move.
        self.next_turn_player = 1
        # Set by subclasses; DotsAndBoxesState reports results relative to it.
        self.player = None

    @property
    def game_result(self):
        """Return the result of the game, or ``None`` while undecided."""
        return None

    def get_moves(self):
        """Return the set of moves that can currently be played."""
        return set()

    def get_random_move(self):
        """Return a uniformly chosen available move, or ``None`` if none."""
        moves = self.get_moves()
        return choice(tuple(moves)) if moves else None

    def play_move(self, move):
        """Apply ``move`` to this state in place."""
        pass


class DotsAndBoxesState(GameState):
    """State of a Dots and Boxes game on ``nb_rows`` x ``nb_cols`` boxes.

    ``board[r][c]`` belongs to the dot at row ``r``, column ``c`` and stores
    two lines starting at that dot: ``"h"`` runs to the dot on the right,
    ``"v"`` runs to the dot below.  A value of 0 means the line is free,
    otherwise it holds the number of the player who drew it.
    """

    def __init__(self, nb_rows, nb_cols, player):
        super(DotsAndBoxesState, self).__init__()

        self.nb_rows = nb_rows
        self.nb_cols = nb_cols
        # One entry per dot, hence one more row/column than there are boxes.
        self.board = [
            [{"v": 0, "h": 0} for _ in range(nb_cols + 1)]
            for _ in range(nb_rows + 1)
        ]

        self.score = {1: 0, 2: 0}
        self.player = player
        print("Player: ", player)

    @property
    def game_result(self):
        """Return the outcome from the point of view of ``self.player``.

        Returns ``None`` while the game is undecided, ``1`` if ``self.player``
        has the higher score, ``0`` if the opponent has, ``0.5`` for a draw.
        The game counts as decided as soon as either player owns more than
        half of all boxes, or when no free lines are left.
        """
        player = self.player
        opponent = player % 2 + 1
        own_score = self.score[player]
        opponent_score = self.score[opponent]

        # You are guaranteed to win once you hold more than half of the
        # total points that can be earned.
        majority = (self.nb_rows * self.nb_cols) // 2
        decided = own_score > majority or opponent_score > majority

        if not decided and self.get_moves():
            return None
        if own_score > opponent_score:
            return 1
        if own_score < opponent_score:
            return 0
        return 0.5

    def get_moves(self):
        """Return the set of free lines as ``(row, col, orientation)``."""
        last_row = len(self.board) - 1
        free_lines = []
        for ri, row in enumerate(self.board):
            last_col = len(row) - 1
            for ci, cell in enumerate(row):
                # Dots on the bottom edge have no line going down ...
                if ri < last_row and cell["v"] == 0:
                    free_lines.append((ri, ci, "v"))
                # ... and dots on the right edge have none going right.
                if ci < last_col and cell["h"] == 0:
                    free_lines.append((ri, ci, "h"))
        return set(free_lines)

    def _is_free_line(self, r, c, o):
        """Return True if ``(r, c, o)`` is on the board and not yet drawn."""
        if o == "v":
            on_board = 0 <= r < self.nb_rows and 0 <= c <= self.nb_cols
        elif o == "h":
            on_board = 0 <= r <= self.nb_rows and 0 <= c < self.nb_cols
        else:
            return False
        return on_board and self.board[r][c][o] == 0

    def play_move(self, move):
        """Draw the line ``move`` for the player whose turn it is.

        Every box completed by the line adds one point to that player's
        score.  The turn passes to the other player only when no box was
        completed.

        Raises:
            AssertionError: if ``move`` is not a currently free line.
        """
        r, c, o = move
        # Constant-time legality check instead of rebuilding the move set.
        assert self._is_free_line(r, c, o), "illegal move: %r" % (move,)

        board = self.board
        completed = 0
        if o == "h":
            # Box above: its top line and both of its sides.
            if (r - 1 >= 0
                    and board[r - 1][c]["h"] != 0
                    and board[r - 1][c]["v"] != 0
                    and board[r - 1][c + 1]["v"] != 0):
                completed += 1
            # Box below: its bottom line and both of its sides.
            if (r + 1 <= self.nb_rows
                    and board[r + 1][c]["h"] != 0
                    and board[r][c]["v"] != 0
                    and board[r][c + 1]["v"] != 0):
                completed += 1
        elif o == "v":
            # Box to the left: its left side, top line and bottom line.
            if (c - 1 >= 0
                    and board[r][c - 1]["v"] != 0
                    and board[r][c - 1]["h"] != 0
                    and board[r + 1][c - 1]["h"] != 0):
                completed += 1
            # Box to the right: its right side, top line and bottom line.
            if (c + 1 <= self.nb_cols
                    and board[r][c + 1]["v"] != 0
                    and board[r][c]["h"] != 0
                    and board[r + 1][c]["h"] != 0):
                completed += 1

        self.score[self.next_turn_player] += completed

        # register move
        board[r][c][o] = self.next_turn_player

        if not completed:
            # switch turns
            self.next_turn_player = self.next_turn_player % 2 + 1

    def __repr__(self):
        """Return an ASCII drawing of the board.

        Each dot row is followed by a row holding the vertical lines that
        start at those dots.  Free segments are padded to the width of
        drawn ones so that the columns stay aligned.
        """
        parts = []
        for r in range(self.nb_rows + 1):
            # Row of dots and the horizontal lines between them.
            for c in range(self.nb_cols + 1):
                parts.append(".")
                if c != self.nb_cols:
                    parts.append("  " if self.board[r][c]["h"] == 0 else "__")
                else:
                    parts.append("\n")
            # Row of vertical lines going down from the dots above.
            for c in range(self.nb_cols + 1):
                if r != self.nb_rows:
                    parts.append(" " if self.board[r][c]["v"] == 0 else "|")
                if c != self.nb_cols:
                    parts.append("  ")
                else:
                    parts.append("\n")
        return "".join(parts)


# ---------------------------------------------------------------------------
# MCTS.py
# ---------------------------------------------------------------------------

class GameController(object):
    """Base class for objects that choose the next move for a state."""

    def get_next_move(self, state):
        # when you get a new move, it is assumed that the game is not ended yet
        assert state.get_moves()


class MCTSNode(object):
    """Monte Carlo Tree Node.
    Each node encapsulates a particular game state, the moves that
    are possible from that state and the strategic information accumulated
    by the tree search as it progressively samples the game space.
    """

    def __init__(self, state, parent=None, move=None):
        self.parent = parent
        self.move = move
        self.state = state

        self.plays = 0
        self.score = 0

        # Moves that have not been expanded into child nodes yet.
        self.pending_moves = state.get_moves()
        self.children = []

    def select_child_ucb(self):
        """Return the child with the highest UCB1 value."""
        # Note that each node's plays count is equal
        # to the sum of its children's plays
        def ucb(child):
            exploitation = child.score / child.plays
            exploration = math.sqrt(2 * math.log(self.plays) / child.plays)
            return exploitation + exploration

        return max(self.children, key=ucb)

    def expand_move(self, move):
        """Create, register and return the child reached by ``move``.

        Raises:
            KeyError: if ``move`` is not among the pending moves.
        """
        self.pending_moves.remove(move)  # raises KeyError

        # Play the move on a copy so this node's state stays untouched.
        child_state = deepcopy(self.state)
        child_state.play_move(move)

        child = MCTSNode(state=child_state, parent=self, move=move)
        self.children.append(child)
        return child

    def get_score(self, result):
        """Translate a simulation ``result`` into this node's reward.

        A draw (0.5) is passed through unchanged; every other result is
        mapped to 0.0 or 1.0.
        """
        if result == 0.5:
            return result

        # NOTE(review): DotsAndBoxesState.game_result returns 1 / 0 / 0.5
        # relative to state.player rather than a player number, so this
        # comparison can only be true when both values are 1.  Confirm that
        # this is the intended reward scheme.
        next_player_matches = self.state.next_turn_player == result
        if self.state.player == 2:
            return 0.0 if next_player_matches else 1.0
        return 1.0 if next_player_matches else 0.0

    def __repr__(self):
        s = 'ROOT\n' if self.parent is None else ''

        children_moves = [c.move for c in self.children]

        s += """Score ratio: {score} / {plays}
Pending moves: {pending_moves}
Children's moves: {children_moves}
State:
{state}\n""".format(children_moves=children_moves, **self.__dict__)

        return s


class MCTSGameController(GameController):
    """Game controller that uses MCTS to determine the next move.
    This is the class which implements the Monte Carlo Tree Search algorithm.
    It builds a game tree of MCTSNodes and samples the game space until a set
    time has elapsed.
    """

    def select(self):
        """Walk down from the root to a node that can still be expanded."""
        node = self.root_node

        # Descend until we find a node that has pending moves, or is terminal
        while not node.pending_moves and node.children:
            node = node.select_child_ucb()

        return node

    def expand(self, node):
        """Expand ``node`` by one randomly chosen pending move."""
        assert node.pending_moves

        move = choice(tuple(node.pending_moves))
        return node.expand_move(move)

    def simulate(self, state, max_iterations=1000):
        """Play random moves on a copy of ``state`` and return the result.

        Returns ``state.game_result`` once no moves are left, or 0.5 when
        ``max_iterations`` moves have been played.
        """
        state = deepcopy(state)

        move = state.get_random_move()
        while move is not None:
            state.play_move(move)
            move = state.get_random_move()

            max_iterations -= 1
            if max_iterations <= 0:
                return 0.5  # raise exception? (game too deep to simulate)

        return state.game_result

    def update(self, node, result):
        """Propagate ``result`` from ``node`` up to the root."""
        while node is not None:
            node.plays += 1
            node.score += node.get_score(result)
            node = node.parent

    def get_next_move(self, state, time_allowed=1.0):
        """Search for about ``time_allowed`` seconds and return a move.

        At least one search iteration is always performed, so a move is
        returned even when ``time_allowed`` is zero or negative.
        """
        super(MCTSGameController, self).get_next_move(state)

        # Create new tree (TODO: Preserve some state for better performance?)
        self.root_node = MCTSNode(state)

        deadline = clock() + time_allowed
        while True:
            node = self.select()

            if node.pending_moves:
                node = self.expand(node)

            result = self.simulate(node.state)
            self.update(node, result)

            # Checked after the iteration so the root always has a child.
            if clock() >= deadline:
                break

        # Return most visited node's move
        return max(self.root_node.children, key=lambda n: n.plays).move