-rw-r--r--  GameState.py  144
-rw-r--r--  MCTS.py       150
2 files changed, 294 insertions, 0 deletions
diff --git a/GameState.py b/GameState.py
new file mode 100644
index 0000000..eed8f36
--- /dev/null
+++ b/GameState.py
@@ -0,0 +1,144 @@
+from random import choice
+
+# Based on https://github.com/DieterBuys/mcts-player/
+
+class GameState(object):
+
+ def __init__(self):
+ self.next_turn_player = 1
+ self.player = None
+
+ @property
+ def game_result(self):
+ return None
+
+ def get_moves(self):
+ return set()
+
+ def get_random_move(self):
+ moves = self.get_moves()
+ return choice(tuple(moves)) if moves != set() else None
+
+ def play_move(self, move):
+ pass
+
+
+class DotsAndBoxesState(GameState):
+ def __init__(self, nb_rows, nb_cols, player):
+ super(DotsAndBoxesState, self).__init__()
+
+ self.nb_rows = nb_rows
+ self.nb_cols = nb_cols
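+        # board[r][c] holds the two lines leaving dot (r, c): "h" is the edge
+        # going right, "v" the edge going down. A value of 0 means the line is
+        # not drawn yet; otherwise it stores the number of the player who drew it.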
+ rows = []
+ for ri in range(nb_rows + 1):
+ columns = []
+ for ci in range(nb_cols + 1):
+ columns.append({"v": 0, "h": 0})
+ rows.append(columns)
+ self.board = rows
+
+ self.score = {1: 0, 2: 0}
+ self.player = player
+ print("Player: ", player)
+
+ @property
+ def game_result(self):
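+        # None while the outcome is still open, otherwise 1 (win for
+        # self.player), 0 (loss) or 0.5 (draw).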
+        def game_decided(nb_cols, nb_rows, scoreP, scoreO):
+            # The outcome is already decided before the board is full once a
+            # player holds more than half of the boxes that can be earned.
+            total_points = nb_rows * nb_cols
+            return scoreP > total_points // 2 or scoreO > total_points // 2
+
+ # check if the board is full, then decide based on score
+ free_lines = self.get_moves()
+ player = self.player
+ opponent = self.player % 2 + 1
+
+ if not game_decided(self.nb_cols, self.nb_rows, self.score[player], self.score[opponent]) and len(free_lines) > 0:
+ return None
+ elif self.score[player] > self.score[opponent]:
+ return 1
+ elif self.score[player] < self.score[opponent]:
+ return 0
+ else:
+ return 0.5
+
+ def get_moves(self):
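+        # A move is a (row, col, orientation) tuple identifying an undrawn line.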
+ free_lines = []
+ for ri in range(len(self.board)):
+ row = self.board[ri]
+ for ci in range(len(row)):
+ cell = row[ci]
+ if ri < (len(self.board) - 1) and cell["v"] == 0:
+ free_lines.append((ri, ci, "v"))
+ if ci < (len(row) - 1) and cell["h"] == 0:
+ free_lines.append((ri, ci, "h"))
+ return set(free_lines)
+
+ def play_move(self, move):
+ r, c, o = move
+ assert move in self.get_moves()
+
+ # check if this move makes a box
+ makes_box = False
+ if o == "h":
+ if r - 1 >= 0:
+ # check above
+ if self.board[r-1][c]["h"] != 0 and self.board[r-1][c]["v"] != 0 and self.board[r-1][c+1]["v"] != 0:
+ makes_box = True
+ self.score[self.next_turn_player] += 1
+ if r + 1 <= self.nb_rows:
+ # check below
+ if self.board[r+1][c]["h"] != 0 and self.board[r][c]["v"] != 0 and self.board[r][c+1]["v"] != 0:
+ makes_box = True
+ self.score[self.next_turn_player] += 1
+
+ elif o == "v":
+ if c - 1 >= 0:
+ # check left
+ if self.board[r][c-1]["v"] != 0 and self.board[r][c-1]["h"] != 0 and self.board[r+1][c-1]["h"] != 0:
+ makes_box = True
+ self.score[self.next_turn_player] += 1
+
+ if c + 1 <= self.nb_cols:
+ # check right
+ if self.board[r][c+1]["v"] != 0 and self.board[r][c]["h"] != 0 and self.board[r+1][c]["h"] != 0:
+ makes_box = True
+ self.score[self.next_turn_player] += 1
+
+
+ # register move
+ self.board[r][c][o] = self.next_turn_player
+
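+        # Completing a box grants the same player another turn.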
+ if not makes_box:
+ # switch turns
+ self.next_turn_player = self.next_turn_player % 2 + 1
+
+ def __repr__(self):
+ str = ""
+ for r in range(self.nb_rows + 1):
+ for o in ["h", "v"]:
+ for c in range(self.nb_cols + 1):
+ if o == "h":
+ str += "."
+ if c != self.nb_cols:
+ if self.board[r][c][o] == 0:
+ str += " "
+ else:
+ str += "__"
+ else:
+ str += "\n"
+ elif o == "v":
+ if r != self.nb_rows:
+ if self.board[r][c][o] == 0:
+ str += " "
+ else:
+ str += "|"
+ if c != self.nb_cols:
+ str += " "
+ else:
+ str += "\n"
+ return str
diff --git a/MCTS.py b/MCTS.py
new file mode 100644
index 0000000..7e81ac6
--- /dev/null
+++ b/MCTS.py
@@ -0,0 +1,150 @@
+import math
+from copy import deepcopy
+from time import perf_counter
+from random import choice
+
+from GameState import GameState
+
+# Based on https://github.com/DieterBuys/mcts-player/
+
+class GameController(object):
+ def get_next_move(self, state):
+        # A next move should only be requested while the game is still in progress.
+ assert state.get_moves()
+
+class MCTSNode(object):
+ """Monte Carlo Tree Node.
+ Each node encapsulates a particular game state, the moves that
+ are possible from that state and the strategic information accumulated
+ by the tree search as it progressively samples the game space.
+ """
+
+ def __init__(self, state, parent=None, move=None):
+ self.parent = parent
+ self.move = move
+ self.state = state
+
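+        # plays counts the simulations that passed through this node; score
+        # accumulates the rewards credited to it during backpropagation.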
+ self.plays = 0
+ self.score = 0
+
+ self.pending_moves = state.get_moves()
+ self.children = []
+
+ def select_child_ucb(self):
+ # Note that each node's plays count is equal
+ # to the sum of its children's plays
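+        # UCB1: mean score (exploitation) plus an exploration bonus that is
+        # larger for children visited less often than their siblings.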
+ def ucb(child):
+ win_ratio = child.score / child.plays \
+ + math.sqrt(2 * math.log(self.plays) / child.plays)
+ return win_ratio
+
+ return max(self.children, key=ucb)
+
+ def expand_move(self, move):
+        self.pending_moves.remove(move)  # raises KeyError if move is not a pending move
+
+ child_state = deepcopy(self.state)
+ child_state.play_move(move)
+
+ child = MCTSNode(state=child_state, parent=self, move=move)
+ self.children.append(child)
+ return child
+
+    def get_score(self, result):
+        # result comes from game_result: 1 is a win for state.player,
+        # 0 a loss and 0.5 a draw.
+        if result == 0.5:
+            return result
+
+        if self.state.player == 2:
+            return 0.0 if self.state.next_turn_player == result else 1.0
+        else:
+            return 1.0 if self.state.next_turn_player == result else 0.0
+
+ def __repr__(self):
+ s = 'ROOT\n' if self.parent is None else ''
+
+ children_moves = [c.move for c in self.children]
+
+ s += """Score ratio: {score} / {plays}
+Pending moves: {pending_moves}
+Children's moves: {children_moves}
+State:
+{state}\n""".format(children_moves=children_moves, **self.__dict__)
+
+ return s
+
+
+class MCTSGameController(GameController):
+ """Game controller that uses MCTS to determine the next move.
+ This is the class which implements the Monte Carlo Tree Search algorithm.
+ It builds a game tree of MCTSNodes and samples the game space until a set
+ time has elapsed.
+ """
+
+ def select(self):
+ node = self.root_node
+
+ # Descend until we find a node that has pending moves, or is terminal
+ while node.pending_moves == set() and node.children != []:
+ node = node.select_child_ucb()
+
+ return node
+
+ def expand(self, node):
+ assert node.pending_moves != set()
+
+ move = choice(tuple(node.pending_moves))
+ return node.expand_move(move)
+
+ def simulate(self, state, max_iterations=1000):
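+        # Random playout: play uniformly random moves on a copy of the state
+        # until the game ends, then report the result for state.player.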
+ state = deepcopy(state)
+
+ move = state.get_random_move()
+ while move is not None:
+ state.play_move(move)
+ move = state.get_random_move()
+
+ max_iterations -= 1
+ if max_iterations <= 0:
+ return 0.5 # raise exception? (game too deep to simulate)
+
+ return state.game_result
+
+ def update(self, node, result):
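+        # Backpropagation: walk from the expanded node up to the root,
+        # incrementing each node's play count and adding its score for the
+        # simulated result.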
+ while node is not None:
+ node.plays += 1
+ node.score += node.get_score(result)
+ node = node.parent
+
+ def get_next_move(self, state, time_allowed=1.0):
+ super(MCTSGameController, self).get_next_move(state)
+
+ # Create new tree (TODO: Preserve some state for better performance?)
+ self.root_node = MCTSNode(state)
+ iterations = 0
+
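+        # Classic MCTS loop: selection, expansion, simulation and
+        # backpropagation, repeated until the time budget runs out.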
+        start_time = perf_counter()
+        while perf_counter() < start_time + time_allowed:
+ node = self.select()
+
+ if node.pending_moves != set():
+ node = self.expand(node)
+
+ result = self.simulate(node.state)
+ self.update(node, result)
+
+ iterations += 1
+
+ # Return most visited node's move
+ return max(self.root_node.children, key=lambda n: n.plays).move
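
A minimal usage sketch for the classes added above (not part of the commit; the 3x3 board size, player number and one-second budget are arbitrary illustration values):

from GameState import DotsAndBoxesState
from MCTS import MCTSGameController

# Let the MCTS controller pick one line on an empty 3x3 dots-and-boxes board,
# playing as player 1 with a one-second search budget.
state = DotsAndBoxesState(nb_rows=3, nb_cols=3, player=1)
controller = MCTSGameController()
move = controller.get_next_move(state, time_allowed=1.0)
print(move)  # a (row, col, orientation) tuple such as (2, 1, "h")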