Diffstat (limited to 'dotsandboxes/agents')
-rw-r--r--    dotsandboxes/agents/agent_AB.py                 57
-rw-r--r--    dotsandboxes/agents/agent_MCTS.py               55
-rw-r--r--    dotsandboxes/agents/agent_random.py            212
-rw-r--r--    dotsandboxes/agents/algorithms/MCTS.py         151
-rw-r--r--    dotsandboxes/agents/algorithms/alphaBeta.py    105
-rw-r--r--    dotsandboxes/agents/algorithms/ann.py          170
6 files changed, 750 insertions, 0 deletions
diff --git a/dotsandboxes/agents/agent_AB.py b/dotsandboxes/agents/agent_AB.py
new file mode 100644
index 0000000..5564f11
--- /dev/null
+++ b/dotsandboxes/agents/agent_AB.py
@@ -0,0 +1,57 @@
+from python.alphaBeta import AlphaBeta
+import dotsandboxes.dotsandboxesagent as dba
+
+import sys
+import argparse
+import logging
+from GameState import GameState, DotsAndBoxesState
+import alphaBeta
+
+
+logger = logging.getLogger(__name__)
+games = {}
+agentclass = dba
+
+
+class Agent(dba.DotsAndBoxesAgent):
+ def __init__(self, player, nb_rows, nb_cols, timelimit):
+ super(Agent, self).__init__(player, nb_rows, nb_cols, timelimit)
+ self.GameStateClass = DotsAndBoxesState
+ self.game_state = self.GameStateClass(nb_rows, nb_cols, player)
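+        # Move selection is delegated to an alpha-beta controller
+        # (the AlphaBeta class added in algorithms/alphaBeta.py).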
+ self.controller = AlphaBeta()
+
+ def register_action(self, row, column, orientation, player):
+ super(Agent, self).register_action(row, column, orientation, player)
+ # adjust agent specific board representation
+ move = (row, column, orientation)
+ self.game_state.play_move(move)
+
+ def next_action(self):
+ r, c, o = self.controller.get_next_move(self.game_state, time_allowed=self.timelimit)
+ return r, c, o
+
+ def end_game(self):
+ super(Agent, self).end_game()
+
+
+# Adapted from provided code
+def main(argv=None):
+ global agentclass
+ parser = argparse.ArgumentParser(description='Start agent to play Dots and Boxes')
+ parser.add_argument('--verbose', '-v', action='count', default=0, help='Verbose output')
+ parser.add_argument('--quiet', '-q', action='count', default=0, help='Quiet output')
+ parser.add_argument('port', metavar='PORT', type=int, help='Port to use for server')
+ args = parser.parse_args(argv)
+ print(args)
+
+ logger.setLevel(max(logging.INFO - 10 * (args.verbose - args.quiet), logging.DEBUG))
+ logger.addHandler(logging.StreamHandler(sys.stdout))
+
+ agentclass = Agent
+ dba.agentclass = Agent
+ dba.start_server(args.port)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
+
diff --git a/dotsandboxes/agents/agent_MCTS.py b/dotsandboxes/agents/agent_MCTS.py
new file mode 100644
index 0000000..b60f5ec
--- /dev/null
+++ b/dotsandboxes/agents/agent_MCTS.py
@@ -0,0 +1,55 @@
+import dotsandboxes.dotsandboxesagent as dba
+
+import sys
+import argparse
+import logging
+from GameState import GameState, DotsAndBoxesState
+from MCTS import MCTSNode, MCTSGameController
+
+logger = logging.getLogger(__name__)
+games = {}
+agentclass = dba
+
+
+class Agent(dba.DotsAndBoxesAgent):
+ def __init__(self, player, nb_rows, nb_cols, timelimit):
+ super(Agent, self).__init__(player, nb_rows, nb_cols, timelimit)
+ self.GameStateClass = DotsAndBoxesState
+ self.game_state = self.GameStateClass(nb_rows, nb_cols, player)
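+        # Move selection is delegated to the time-limited Monte Carlo Tree Search
+        # controller (MCTSGameController, added in algorithms/MCTS.py).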
+ self.controller = MCTSGameController()
+
+ def register_action(self, row, column, orientation, player):
+ super(Agent, self).register_action(row, column, orientation, player)
+ # adjust agent specific board representation
+ move = (row, column, orientation)
+ self.game_state.play_move(move)
+
+ def next_action(self):
+ r, c, o = self.controller.get_next_move(self.game_state, time_allowed=self.timelimit)
+ return r, c, o
+
+ def end_game(self):
+ super(Agent, self).end_game()
+
+
+# Adapted from provided code
+def main(argv=None):
+ global agentclass
+ parser = argparse.ArgumentParser(description='Start agent to play Dots and Boxes')
+ parser.add_argument('--verbose', '-v', action='count', default=0, help='Verbose output')
+ parser.add_argument('--quiet', '-q', action='count', default=0, help='Quiet output')
+ parser.add_argument('port', metavar='PORT', type=int, help='Port to use for server')
+ args = parser.parse_args(argv)
+ print(args)
+
+ logger.setLevel(max(logging.INFO - 10 * (args.verbose - args.quiet), logging.DEBUG))
+ logger.addHandler(logging.StreamHandler(sys.stdout))
+
+ agentclass = Agent
+ dba.agentclass = Agent
+ dba.start_server(args.port)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
+
diff --git a/dotsandboxes/agents/agent_random.py b/dotsandboxes/agents/agent_random.py
new file mode 100644
index 0000000..abf677b
--- /dev/null
+++ b/dotsandboxes/agents/agent_random.py
@@ -0,0 +1,212 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+"""
+agent_random.py
+
+Template for the Machine Learning Project course at KU Leuven (2017-2018)
+of Hendrik Blockeel and Wannes Meert.
+
+Copyright (c) 2018 KU Leuven. All rights reserved.
+"""
+import sys
+import argparse
+import logging
+import asyncio
+import websockets
+import json
+from collections import defaultdict
+import random
+
+logger = logging.getLogger(__name__)
+games = {}
+agentclass = None
+
+
+class DotsAndBoxesAgent:
+ """Example Dots and Boxes agent implementation base class.
+ It returns a random next move.
+
+ A DotsAndBoxesAgent object should implement the following methods:
+ - __init__
+ - add_player
+ - register_action
+ - next_action
+ - end_game
+
+ This class does not necessarily use the best data structures for the
+ approach you want to use.
+ """
+ def __init__(self, player, nb_rows, nb_cols, timelimit):
+ """Create Dots and Boxes agent.
+
+ :param player: Player number, 1 or 2
+ :param nb_rows: Rows in grid
+ :param nb_cols: Columns in grid
+ :param timelimit: Maximum time allowed to send a next action.
+ """
+ self.player = {player}
+ self.timelimit = timelimit
+ self.ended = False
+ self.nb_rows = nb_rows
+ self.nb_cols = nb_cols
+ rows = []
+ for ri in range(nb_rows + 1):
+ columns = []
+ for ci in range(nb_cols + 1):
+ columns.append({"v": 0, "h": 0})
+ rows.append(columns)
+ self.cells = rows
+
+ def add_player(self, player):
+ """Use the same agent for multiple players."""
+ self.player.add(player)
+
+ def register_action(self, row, column, orientation, player):
+ """Register action played in game.
+
+ :param row:
+        :param column:
+ :param orientation: "v" or "h"
+ :param player: 1 or 2
+ """
+ self.cells[row][column][orientation] = player
+
+ def next_action(self):
+ """Return the next action this agent wants to perform.
+
+ In this example, the function implements a random move. Replace this
+ function with your own approach.
+
+ :return: (row, column, orientation)
+ """
+ logger.info("Computing next move (grid={}x{}, player={})"\
+ .format(self.nb_rows, self.nb_cols, self.player))
+ # Random move
+ free_lines = []
+ for ri in range(len(self.cells)):
+ row = self.cells[ri]
+ for ci in range(len(row)):
+ cell = row[ci]
+ if ri < (len(self.cells) - 1) and cell["v"] == 0:
+ free_lines.append((ri, ci, "v"))
+ if ci < (len(row) - 1) and cell["h"] == 0:
+ free_lines.append((ri, ci, "h"))
+ if len(free_lines) == 0:
+ # Board full
+ return None
+ movei = random.randint(0, len(free_lines) - 1)
+ r, c, o = free_lines[movei]
+ return r, c, o
+
+ def end_game(self):
+ self.ended = True
+
+
+## MAIN EVENT LOOP
+
+async def handler(websocket, path):
+ logger.info("Start listening")
+ game = None
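+    # Message protocol: a "start" message creates or joins a game, an "action"
+    # message reports a played line (and we answer with our own move when it is
+    # this agent's turn), and an "end" message closes the game.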
+ # msg = await websocket.recv()
+ try:
+ async for msg in websocket:
+ logger.info("< {}".format(msg))
+ try:
+ msg = json.loads(msg)
+ except json.decoder.JSONDecodeError as err:
+ logger.error(err)
+ return False
+ game = msg["game"]
+ answer = None
+ if msg["type"] == "start":
+ # Initialize game
+ if msg["game"] in games:
+ games[msg["game"]].add_player(msg["player"])
+ else:
+ nb_rows, nb_cols = msg["grid"]
+ games[msg["game"]] = agentclass(msg["player"],
+ nb_rows,
+ nb_cols,
+ msg["timelimit"])
+ if msg["player"] == 1:
+ # Start the game
+ nm = games[game].next_action()
+ print('nm = {}'.format(nm))
+ if nm is None:
+ # Game over
+ logger.info("Game over")
+ continue
+ r, c, o = nm
+ answer = {
+ 'type': 'action',
+ 'location': [r, c],
+ 'orientation': o
+ }
+ else:
+ # Wait for the opponent
+ answer = None
+
+ elif msg["type"] == "action":
+ # An action has been played
+ r, c = msg["location"]
+ o = msg["orientation"]
+ games[game].register_action(r, c, o, msg["player"])
+ if msg["nextplayer"] in games[game].player:
+ # Compute your move
+ nm = games[game].next_action()
+ if nm is None:
+ # Game over
+ logger.info("Game over")
+ continue
+ nr, nc, no = nm
+ answer = {
+ 'type': 'action',
+ 'location': [nr, nc],
+ 'orientation': no
+ }
+ else:
+ answer = None
+
+ elif msg["type"] == "end":
+ # End the game
+ games[msg["game"]].end_game()
+ answer = None
+ else:
+ logger.error("Unknown message type:\n{}".format(msg))
+
+ if answer is not None:
+ print(answer)
+ await websocket.send(json.dumps(answer))
+ logger.info("> {}".format(answer))
+ except websockets.exceptions.ConnectionClosed as err:
+ logger.info("Connection closed")
+ logger.info("Exit handler")
+
+
+def start_server(port):
+ server = websockets.serve(handler, 'localhost', port)
+ print("Running on ws://127.0.0.1:{}".format(port))
+ asyncio.get_event_loop().run_until_complete(server)
+ asyncio.get_event_loop().run_forever()
+
+
+## COMMAND LINE INTERFACE
+
+def main(argv=None):
+ global agentclass
+ parser = argparse.ArgumentParser(description='Start agent to play Dots and Boxes')
+ parser.add_argument('--verbose', '-v', action='count', default=0, help='Verbose output')
+ parser.add_argument('--quiet', '-q', action='count', default=0, help='Quiet output')
+ parser.add_argument('port', metavar='PORT', type=int, help='Port to use for server')
+ args = parser.parse_args(argv)
+
+ logger.setLevel(max(logging.INFO - 10 * (args.verbose - args.quiet), logging.DEBUG))
+ logger.addHandler(logging.StreamHandler(sys.stdout))
+
+ agentclass = DotsAndBoxesAgent
+ start_server(args.port)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
+
diff --git a/dotsandboxes/agents/algorithms/MCTS.py b/dotsandboxes/agents/algorithms/MCTS.py
new file mode 100644
index 0000000..6c71ba9
--- /dev/null
+++ b/dotsandboxes/agents/algorithms/MCTS.py
@@ -0,0 +1,151 @@
+import math
+from copy import deepcopy
+from time import perf_counter
+from random import choice
+
+from GameState import GameState
+
+# Based on https://github.com/DieterBuys/mcts-player/
+
+class GameController(object):
+ def get_next_move(self, state):
+ # when you get a new move, it is assumed that the game is not ended yet
+ assert state.get_moves()
+
+
+class MCTSNode(object):
+ """Monte Carlo Tree Node.
+ Each node encapsulates a particular game state, the moves that
+ are possible from that state and the strategic information accumulated
+ by the tree search as it progressively samples the game space.
+ """
+
+ def __init__(self, state, parent=None, move=None):
+ self.parent = parent
+ self.move = move
+ self.state = state
+
+ self.plays = 0
+ self.score = 0
+
+ self.pending_moves = state.get_moves()
+ self.children = []
+
+ def select_child_ucb(self):
+ # Note that each node's plays count is equal
+ # to the sum of its children's plays
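+        # UCB1: the first term (average score) favours exploitation, the second
+        # term, sqrt(2 * ln(parent plays) / child plays), favours exploration.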
+ def ucb(child):
+ win_ratio = child.score / child.plays \
+ + math.sqrt(2 * math.log(self.plays) / child.plays)
+ return win_ratio
+
+ return max(self.children, key=ucb)
+
+ def expand_move(self, move):
+ self.pending_moves.remove(move) # raises KeyError
+
+ child_state = deepcopy(self.state)
+ child_state.play_move(move)
+
+ child = MCTSNode(state=child_state, parent=self, move=move)
+ self.children.append(child)
+ return child
+
+ def get_score(self, result):
+ # return result
+ if result == 0.5:
+ return result
+
+ if self.state.player == 2:
+ if self.state.next_turn_player == result:
+ return 0.0
+ else:
+ return 1.0
+ else:
+ if self.state.next_turn_player == result:
+ return 1.0
+ else:
+ return 0.0
+
+ def __repr__(self):
+ s = 'ROOT\n' if self.parent is None else ''
+
+ children_moves = [c.move for c in self.children]
+
+ s += """Score ratio: {score} / {plays}
+Pending moves: {pending_moves}
+Children's moves: {children_moves}
+State:
+{state}\n""".format(children_moves=children_moves, **self.__dict__)
+
+ return s
+
+
+class MCTSGameController(GameController):
+ """Game controller that uses MCTS to determine the next move.
+ This is the class which implements the Monte Carlo Tree Search algorithm.
+ It builds a game tree of MCTSNodes and samples the game space until a set
+ time has elapsed.
+ """
+
+ def select(self):
+ node = self.root_node
+
+ # Descend until we find a node that has pending moves, or is terminal
+ while node.pending_moves == set() and node.children != []:
+ node = node.select_child_ucb()
+
+ return node
+
+ def expand(self, node):
+ assert node.pending_moves != set()
+
+ move = choice(tuple(node.pending_moves))
+ return node.expand_move(move)
+
+ def simulate(self, state, max_iterations=1000):
+ state = deepcopy(state)
+
+ move = state.get_random_move()
+ while move is not None:
+ state.play_move(move)
+ move = state.get_random_move()
+
+ max_iterations -= 1
+ if max_iterations <= 0:
+ return 0.5 # raise exception? (game too deep to simulate)
+
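+        # game_result is expected to be the winning player (1 or 2), or 0.5 for
+        # a draw; MCTSNode.get_score() converts it into a score for each node.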
+ return state.game_result
+
+ def update(self, node, result):
+ while node is not None:
+ node.plays += 1
+ node.score += node.get_score(result)
+ node = node.parent
+
+ def get_next_move(self, state, time_allowed=1.0):
+ super(MCTSGameController, self).get_next_move(state)
+
+ # Create new tree (TODO: Preserve some state for better performance?)
+ self.root_node = MCTSNode(state)
+ iterations = 0
+
+ start_time = perf_counter()
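+        # Each iteration runs the four MCTS phases: selection (UCB descent),
+        # expansion of one untried move, a random-playout simulation, and
+        # backpropagation of the result up to the root.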
+ while perf_counter() < start_time + time_allowed:
+ node = self.select()
+
+ if node.pending_moves != set():
+ node = self.expand(node)
+
+ result = self.simulate(node.state)
+ self.update(node, result)
+
+ iterations += 1
+
+ # Return most visited node's move
+ return max(self.root_node.children, key=lambda n: n.plays).move
diff --git a/dotsandboxes/agents/algorithms/alphaBeta.py b/dotsandboxes/agents/algorithms/alphaBeta.py
new file mode 100644
index 0000000..8e041fe
--- /dev/null
+++ b/dotsandboxes/agents/algorithms/alphaBeta.py
@@ -0,0 +1,105 @@
+from GameState import GameState
+class GameController(object):
+ def get_next_move(self, state):
+ # when you get a new move, it is assumed that the game is not ended yet
+ assert state.get_moves()
+
+
+def alpha_beta(node, alpha, beta):
+
+ # Based on https://en.wikipedia.org/wiki/Alpha%E2%80%93beta_pruning#Pseudocode
+    # node needs to support four operations: isTerminal(), value(), getChildren(), maximizingPlayer()
+
+ if node.isTerminal():
+ return node.value()
+
+ if node.maximizingPlayer():
+
+ v = float("-inf")
+ for child in node.getChildren():
+
+ v = max(v, alpha_beta(child, alpha, beta))
+ alpha = max(alpha, v)
+ if beta <= alpha:
+ break
+
+ else:
+
+ v = float("inf")
+ for child in node.getChildren():
+
+ v = min(v, alpha_beta(child, alpha, beta))
+ beta = min(beta, v)
+ if beta <= alpha:
+ break
+
+ return v
+
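+# Hypothetical usage sketch (not part of this repository): any object exposing
+# the four methods above can be searched, e.g.
+#     best_value = alpha_beta(root_node, float("-inf"), float("inf"))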
+
+# A class for defining algorithms used (minimax and alpha-beta pruning)
+class AlphaBeta:
+
+    @staticmethod
+    def miniMax(State, Ply_num):  # Function for the minimax algorithm
+
+ for i in range(State.Current.dimY):
+ for j in range(State.Current.dimX):
+ if State.Current.Mat[i][j] == ' ' and (j, i) not in State.children:
+ State.Make(j, i, True)
+ if Ply_num < 2:
+ return (i, j)
+
+ Minimum_Score = 1000
+ i = 0
+
+ j = 0
+ for k, z in State.children.items():
+            Result = AlphaBeta.Maximum(z, Ply_num - 1, Minimum_Score)
+ if Minimum_Score > Result:
+ Minimum_Score = Result
+ i = k[0]
+ j = k[1]
+
+ return (i, j)
+
+ # Alpha-beta pruning function for taking care of Alpha values
+    @staticmethod
+    def Maximum(State, Ply_num, Alpha):
+ if Ply_num == 0:
+ return State.CurrentScore
+
+ for i in range(State.Current.dimY):
+ for j in range(State.Current.dimX):
+ if State.Current.Mat[i][j] == ' ' and (j, i) not in State.children:
+ State.Make(j, i, False)
+
+ Maximum_Score = -1000
+ i = 0
+ j = 0
+ for k, z in State.children.items():
+            Result = AlphaBeta.Minimum(z, Ply_num - 1, Maximum_Score)
+ if Maximum_Score < Result:
+ Maximum_Score = Result
+ if Result > Alpha:
+ return Result
+
+ return Maximum_Score
+
+    @staticmethod
+    def Minimum(State, Ply_num, Beta):  # Alpha-beta pruning function for taking care of Beta values
+ if Ply_num == 0:
+ return State.CurrentScore
+
+ for i in range(State.Current.dimY):
+ for j in range(State.Current.dimX):
+ if State.Current.Mat[i][j] == ' ' and (j, i) not in State.children:
+ State.Make(j, i, True)
+
+ Minimum_Score = 1000
+ i = 0
+ j = 0
+ for k, z in State.children.items():
+            Result = AlphaBeta.Maximum(z, Ply_num - 1, Minimum_Score)
+ if Minimum_Score > Result:
+ Minimum_Score = Result
+ if Result < Beta:
+ return Result
+
+ return Minimum_Score
diff --git a/dotsandboxes/agents/algorithms/ann.py b/dotsandboxes/agents/algorithms/ann.py
new file mode 100644
index 0000000..05ae647
--- /dev/null
+++ b/dotsandboxes/agents/algorithms/ann.py
@@ -0,0 +1,170 @@
+from numpy import *
+from math import sqrt
+from copy import deepcopy
+from time import time
+
+class ANN:
+
+ """ANN with one hidden layer, one output and full connections in between consecutive layers.
+ Initial weights are chosen from a normal distribution.
+ Activation function is tanh."""
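+    # Forward pass (see predict): output = tanh(dot(output_weights, tanh(dot(hidden_weights, x)))),
+    # with the input vector x normalized to unit length first.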
+
+ INIT_SIGMA = 0.02
+ REL_STOP_MARGIN = 0.01
+ MAX_ITERATIONS = 1000000
+ ACTIVATION = tanh
+ D_ACTIVATION = lambda x: 1 - tanh(x)**2 # Derivative of tanh
+ VEC_ACTIVATION = vectorize(ACTIVATION)
+ VEC_D_ACTIVATION = vectorize(D_ACTIVATION)
+ STEP_SIZE = 0.1
+
+ def __init__(self, input_size, hidden_size):
+
+ #self.input_size = input_size
+ #self.hidden_size = hidden_size
+ self.hidden_weights = random.normal(0, ANN.INIT_SIGMA, (hidden_size, input_size))
+ self.output_weights = random.normal(0, ANN.INIT_SIGMA, hidden_size)
+
+ def get_weights(self):
+ return self.hidden_weights, self.output_weights
+
+ def predict(self, input_vector):
+
+ # Predicts the output for this input vector
+ # input_vector will be normalized
+
+ input_vector = input_vector/linalg.norm(input_vector)
+ return ANN.ACTIVATION(dot(self.output_weights, ANN.VEC_ACTIVATION(dot(self.hidden_weights, input_vector))))
+
+ @staticmethod
+ def frob_norm(a, b):
+
+ # Calculates the total Frobenius norm of both matrices A and B
+ return sqrt(linalg.norm(a)**2 + linalg.norm(b)**2)
+
+ def train(self, examples):
+
+ #print("Training")
+ start = time()
+
+        # examples is a list of [input, output] pairs (lists, so the input can be normalized in place)
+ # input will be normalized
+ # We stop when the weights have converged within some relative margin
+
+ for example in examples:
+ example[0] = example[0]/linalg.norm(example[0])
+
+ iteration = 0
+ while True:
+
+
+ # Store old weights to check for convergence later
+ prev_hidden_weights = deepcopy(self.hidden_weights)
+ prev_output_weights = deepcopy(self.output_weights)
+
+ for k in range(len(examples)):
+
+ input_vector, output = examples[k]
+
+ # Calculate outputs
+ hidden_input = dot(self.hidden_weights, input_vector)
+ hidden_output = ANN.VEC_ACTIVATION(hidden_input)
+ final_input = dot(self.output_weights, hidden_output)
+ predicted_output = ANN.ACTIVATION(final_input)
+
+ #print("Output:", output)
+ #print("Predicted output:", predicted_output)
+
+ # Used in calculations
+ prediction_error = output - predicted_output
+ output_derivative = ANN.D_ACTIVATION(final_input)
+
+ # Adjust output weights and calculate requested hidden change
+ requested_hidden_change = prediction_error*output_derivative*self.output_weights
+ self.output_weights = self.output_weights + ANN.STEP_SIZE*prediction_error*hidden_output
+
+ #print("After adjusting output weights:", ANN.ACTIVATION(dot(self.output_weights, hidden_output)))
+
+ # Backpropagate requested hidden change to adjust hidden weights
+ self.hidden_weights = self.hidden_weights + ANN.STEP_SIZE*outer(requested_hidden_change*(ANN.VEC_D_ACTIVATION(hidden_input)), input_vector)
+
+ #print("After adjusting hidden weights:", ANN.ACTIVATION(dot(self.output_weights, ANN.VEC_ACTIVATION(dot(self.hidden_weights, input_vector)))))
+
+ # Check stop criteria
+ iteration += 1
+ if iteration >= ANN.MAX_ITERATIONS:
+ break
+
+ # Check stop criteria
+ if iteration >= ANN.MAX_ITERATIONS:
+ break
+ diff = ANN.frob_norm(self.hidden_weights - prev_hidden_weights, self.output_weights - prev_output_weights)
+ base = ANN.frob_norm(self.hidden_weights, self.output_weights)
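+            # diff / base is the relative change in the weights; the convergence
+            # test against REL_STOP_MARGIN below is currently commented out.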
+ #if base > 0 and diff/base < ANN.REL_STOP_MARGIN:
+ # break
+
+ print(time() - start)
+ print("Stopped training after %s iterations."%iteration)
+
+# TESTING
+
+def print_difference(ann1, ann2):
+
+ # Prints the differences in weights in between two ANN's with identical topology
+
+ hidden_weights1, output_weights1 = ann1.get_weights()
+ hidden_weights2, output_weights2 = ann2.get_weights()
+ hidden_diff = hidden_weights1 - hidden_weights2
+ output_diff = output_weights1 - output_weights2
+
+ print(hidden_diff)
+ print(output_diff)
+ print("Frobenius norms:")
+ print("Hidden weights difference:", linalg.norm(hidden_diff))
+ print("Output weights difference:", linalg.norm(output_diff))
+ print("Both:", ANN.frob_norm(hidden_diff, output_diff))
+
+def RMSE(ann, examples):
+
+ total = 0
+ for input_vector, output in examples:
+ total += (output - ann.predict(input_vector))**2
+ return sqrt(total/len(examples))
+
+def generate_examples(amount, input_size, evaluate):
+ # evaluate is a function mapping an input vector onto a numerical value
+ examples = []
+ inputs = random.normal(0, 100, (amount, input_size))
+ for i in range(amount):
+ input_vector = inputs[i]
+ examples.append([input_vector, evaluate(input_vector)])
+ return examples
+
+def test():
+
+ # Test the ANN by having it model another ANN with identical topology but unknown weights
+
+ input_size = 5
+ hidden_size = 3
+ real = ANN(input_size, hidden_size)
+ model = ANN(input_size, hidden_size)
+
+ # Generate training data
+ training_data = generate_examples(10000, input_size, real.predict)
+ validation_data = generate_examples(10000, input_size, real.predict)
+
+ # Print initial difference, train, then print new difference
+ print("Initial difference:")
+ print_difference(real, model)
+ print("Initial RMSE (on training data):", RMSE(model, training_data))
+ print("Initial RMSE (on validation data):", RMSE(model, validation_data))
+ model.train(training_data)
+ print("After training:")
+ print_difference(real, model)
+ print("After training RMSE (on training data):", RMSE(model, training_data))
+ print("After training RMSE (on validation data):", RMSE(model, validation_data))
+
+if __name__ == "__main__":
+ test()
+
+