aboutsummaryrefslogtreecommitdiffstats
path: root/python/MCTS.py
blob: a65e2d476ddf29d2dace1c68626c01559549290a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import math
from copy import deepcopy
from random import choice

try:
    from time import perf_counter as clock
except ImportError:  # Python < 3.3 has no perf_counter (time.clock was removed in 3.8)
    from time import clock

from GameState import GameState

# Based on https://github.com/DieterBuys/mcts-player/

class GameController(object):
    """Base class for objects that pick the next move for a game state."""

    def get_next_move(self, state):
        # A move is only ever requested while the game is still running,
        # so the state is expected to offer at least one legal move.
        available = state.get_moves()
        assert available


class MCTSNode(object):
    """Monte Carlo Tree Node.

    Each node encapsulates a particular game state, the moves that
    are possible from that state and the strategic information accumulated
    by the tree search as it progressively samples the game space.

    Attributes (documented here rather than inline):
        parent: node this one was expanded from, or None for the root.
        move: move that led from ``parent.state`` to ``state`` (None at root).
        state: game state represented by this node; each child receives its
            own deep copy (see ``expand_move``).
        plays: number of simulation results back-propagated through this node.
        score: running sum of ``get_score`` values for those results.
        pending_moves: moves from ``state`` not yet expanded into children.
            ``expand_move`` removes from it and expects KeyError for unknown
            moves, so ``state.get_moves()`` presumably returns a set.
        children: expanded child nodes, in expansion order.
    """

    def __init__(self, state, parent=None, move=None):
        self.parent = parent
        self.move = move
        self.state = state

        self.plays = 0
        self.score = 0

        self.pending_moves = state.get_moves()
        self.children = []

    def select_child_ucb(self):
        """Return the child maximising UCB1.

        UCB1(child) = child.score / child.plays
                      + sqrt(2 * ln(self.plays) / child.plays)

        Every result counted in a child is also counted in its parent, so
        ``self.plays >= child.plays``. The search only calls this on fully
        expanded nodes whose children have each been updated at least once,
        so neither the logarithm nor the divisions see zero.

        Raises:
            ValueError: if there are no children or ``self.plays`` is 0.
        """
        # The parent's visit count is the same for every child: compute once.
        log_parent_plays = math.log(self.plays)

        def ucb(child):
            exploitation = child.score / child.plays
            exploration = math.sqrt(2 * log_parent_plays / child.plays)
            return exploitation + exploration

        return max(self.children, key=ucb)

    def expand_move(self, move):
        """Create, attach and return the child reached by playing ``move``.

        Raises:
            KeyError: if ``move`` is not among the pending moves (propagated
                from ``set.remove``).
        """
        self.pending_moves.remove(move)

        # Copy so the child's state can be mutated without touching ours.
        child_state = deepcopy(self.state)
        child_state.play_move(move)

        child = MCTSNode(state=child_state, parent=self, move=move)
        self.children.append(child)
        return child

    def get_score(self, result):
        """Convert a simulation result into this node's score contribution.

        Args:
            result: value produced by a simulation. 0.5 is passed through
                unchanged. Any other value is compared with
                ``state.next_turn_player`` (presumably a player number).

        Returns:
            ``result`` when it equals 0.5, otherwise 1.0 or 0.0.
        """
        if result == 0.5:
            return result

        next_player_won = self.state.next_turn_player == result
        if self.state.player == 2:
            return 0.0 if next_player_won else 1.0
        # NOTE(review): this mapping is the inverse of the player-2 branch.
        # The parent maximises this score in select_child_ucb, so confirm
        # against GameState what ``player`` denotes and that the asymmetry
        # is intentional.
        return 1.0 if next_player_won else 0.0

    def __repr__(self):
        """Return a multi-line debug dump of the node (prefixed by ROOT at the root)."""
        s = 'ROOT\n' if self.parent is None else ''

        children_moves = [c.move for c in self.children]

        s += """Score ratio: {score} / {plays}
Pending moves: {pending_moves}
Children's moves: {children_moves}
State:
{state}\n""".format(children_moves=children_moves, **self.__dict__)

        return s


class MCTSGameController(GameController):
    """Game controller that uses MCTS to determine the next move.

    This is the class which implements the Monte Carlo Tree Search algorithm.
    It builds a game tree of MCTSNodes and samples the game space until a set
    time has elapsed.
    """

    def select(self):
        """Walk down from the root to the node the next iteration should grow.

        Descends via UCB1 while the current node is fully expanded (no
        pending moves) and has children. Stops at the first node that still
        has pending moves, or at one with neither pending moves nor children.
        """
        node = self.root_node

        while not node.pending_moves and node.children:
            node = node.select_child_ucb()

        return node

    def expand(self, node):
        """Expand a randomly chosen pending move of ``node``; return the new child."""
        assert node.pending_moves

        move = choice(tuple(node.pending_moves))
        return node.expand_move(move)

    def simulate(self, state, max_iterations=1000):
        """Play random moves on a copy of ``state`` until the game ends.

        Args:
            state: state to roll out from. It is deep-copied, so the
                caller's object is not modified.
            max_iterations: maximum number of random moves to play.

        Returns:
            ``game_result`` of the finished copy. Returns 0.5 if another move
            is still available after ``max_iterations`` moves have been played.
        """
        state = deepcopy(state)

        moves_played = 0
        move = state.get_random_move()
        # get_random_move() returns None once no move is available.
        while move is not None:
            # Check the depth limit only when another move is pending, so a
            # game ending on exactly the last allowed move keeps its result.
            if moves_played >= max_iterations:
                return 0.5  # raise exception? (game too deep to simulate)
            state.play_move(move)
            moves_played += 1
            move = state.get_random_move()

        return state.game_result

    def update(self, node, result):
        """Back-propagate ``result`` from ``node`` up to and including the root."""
        while node is not None:
            node.plays += 1
            node.score += node.get_score(result)
            node = node.parent

    def get_next_move(self, state, time_allowed=1.0):
        """Search from ``state`` for about ``time_allowed`` seconds; return a move.

        At least one select/expand/simulate/update iteration is always run.
        The root therefore has a child even when ``time_allowed`` is not
        positive or the clock does not advance.

        Returns:
            The move of the root child with the most plays.

        Raises:
            AssertionError: (via GameController, unless assertions are
                disabled) if ``state`` offers no moves.
        """
        super(MCTSGameController, self).get_next_move(state)

        # Create new tree (TODO: Preserve some state for better performance?)
        self.root_node = MCTSNode(state)

        deadline = clock() + time_allowed
        while True:
            node = self.select()

            if node.pending_moves:
                node = self.expand(node)

            result = self.simulate(node.state)
            self.update(node, result)

            if clock() >= deadline:
                break

        # Return most visited child's move
        return max(self.root_node.children, key=lambda n: n.plays).move