aboutsummaryrefslogtreecommitdiffstats
path: root/dotsandboxes/agents
diff options
context:
space:
mode:
authorMatt Strapp <strap012@umn.edu>2021-04-26 17:12:01 -0500
committerMatt Strapp <strap012@umn.edu>2021-04-26 17:12:01 -0500
commita093060b0e8a787e51212b5f2879dc839605da65 (patch)
tree7ec2d69219d41ae6447efc41ebaaac34c696984b /dotsandboxes/agents
parentRefactor jsut about everything (diff)
downloadcsci4511w-a093060b0e8a787e51212b5f2879dc839605da65.tar
csci4511w-a093060b0e8a787e51212b5f2879dc839605da65.tar.gz
csci4511w-a093060b0e8a787e51212b5f2879dc839605da65.tar.bz2
csci4511w-a093060b0e8a787e51212b5f2879dc839605da65.tar.lz
csci4511w-a093060b0e8a787e51212b5f2879dc839605da65.tar.xz
csci4511w-a093060b0e8a787e51212b5f2879dc839605da65.tar.zst
csci4511w-a093060b0e8a787e51212b5f2879dc839605da65.zip
Revert "Refactor jsut about everything"
This reverts commit e58a60ed18bde5db28ba96910df518a61b3999b2.
Diffstat (limited to 'dotsandboxes/agents')
-rw-r--r--dotsandboxes/agents/agent_AB.py57
-rw-r--r--dotsandboxes/agents/agent_MCTS.py55
-rw-r--r--dotsandboxes/agents/agent_random.py212
-rw-r--r--dotsandboxes/agents/algorithms/MCTS.py151
-rw-r--r--dotsandboxes/agents/algorithms/alphaBeta.py105
-rw-r--r--dotsandboxes/agents/algorithms/ann.py170
6 files changed, 0 insertions, 750 deletions
diff --git a/dotsandboxes/agents/agent_AB.py b/dotsandboxes/agents/agent_AB.py
deleted file mode 100644
index 5564f11..0000000
--- a/dotsandboxes/agents/agent_AB.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from python.alphaBeta import AlphaBeta
-import dotsandboxes.dotsandboxesagent as dba
-
-import sys
-import argparse
-import logging
-from GameState import GameState, DotsAndBoxesState
-import alphaBeta
-
-
-logger = logging.getLogger(__name__)
-games = {}
-agentclass = dba
-
-
-class Agent(dba.DotsAndBoxesAgent):
- def __init__(self, player, nb_rows, nb_cols, timelimit):
- super(Agent, self).__init__(player, nb_rows, nb_cols, timelimit)
- self.GameStateClass = DotsAndBoxesState
- self.game_state = self.GameStateClass(nb_rows, nb_cols, player)
- self.controller = AlphaBeta()
-
- def register_action(self, row, column, orientation, player):
- super(Agent, self).register_action(row, column, orientation, player)
- # adjust agent specific board representation
- move = (row, column, orientation)
- self.game_state.play_move(move)
-
- def next_action(self):
- r, c, o = self.controller.get_next_move(self.game_state, time_allowed=self.timelimit)
- return r, c, o
-
- def end_game(self):
- super(Agent, self).end_game()
-
-
-# Adapted from provided code
-def main(argv=None):
- global agentclass
- parser = argparse.ArgumentParser(description='Start agent to play Dots and Boxes')
- parser.add_argument('--verbose', '-v', action='count', default=0, help='Verbose output')
- parser.add_argument('--quiet', '-q', action='count', default=0, help='Quiet output')
- parser.add_argument('port', metavar='PORT', type=int, help='Port to use for server')
- args = parser.parse_args(argv)
- print(args)
-
- logger.setLevel(max(logging.INFO - 10 * (args.verbose - args.quiet), logging.DEBUG))
- logger.addHandler(logging.StreamHandler(sys.stdout))
-
- agentclass = Agent
- dba.agentclass = Agent
- dba.start_server(args.port)
-
-
-if __name__ == "__main__":
- sys.exit(main())
-
diff --git a/dotsandboxes/agents/agent_MCTS.py b/dotsandboxes/agents/agent_MCTS.py
deleted file mode 100644
index b60f5ec..0000000
--- a/dotsandboxes/agents/agent_MCTS.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import dotsandboxes.dotsandboxesagent as dba
-
-import sys
-import argparse
-import logging
-from GameState import GameState, DotsAndBoxesState
-from MCTS import MCTSNode, MCTSGameController
-
-logger = logging.getLogger(__name__)
-games = {}
-agentclass = dba
-
-
-class Agent(dba.DotsAndBoxesAgent):
- def __init__(self, player, nb_rows, nb_cols, timelimit):
- super(Agent, self).__init__(player, nb_rows, nb_cols, timelimit)
- self.GameStateClass = DotsAndBoxesState
- self.game_state = self.GameStateClass(nb_rows, nb_cols, player)
- self.controller = MCTSGameController()
-
- def register_action(self, row, column, orientation, player):
- super(Agent, self).register_action(row, column, orientation, player)
- # adjust agent specific board representation
- move = (row, column, orientation)
- self.game_state.play_move(move)
-
- def next_action(self):
- r, c, o = self.controller.get_next_move(self.game_state, time_allowed=self.timelimit)
- return r, c, o
-
- def end_game(self):
- super(Agent, self).end_game()
-
-
-# Adapted from provided code
-def main(argv=None):
- global agentclass
- parser = argparse.ArgumentParser(description='Start agent to play Dots and Boxes')
- parser.add_argument('--verbose', '-v', action='count', default=0, help='Verbose output')
- parser.add_argument('--quiet', '-q', action='count', default=0, help='Quiet output')
- parser.add_argument('port', metavar='PORT', type=int, help='Port to use for server')
- args = parser.parse_args(argv)
- print(args)
-
- logger.setLevel(max(logging.INFO - 10 * (args.verbose - args.quiet), logging.DEBUG))
- logger.addHandler(logging.StreamHandler(sys.stdout))
-
- agentclass = Agent
- dba.agentclass = Agent
- dba.start_server(args.port)
-
-
-if __name__ == "__main__":
- sys.exit(main())
-
diff --git a/dotsandboxes/agents/agent_random.py b/dotsandboxes/agents/agent_random.py
deleted file mode 100644
index abf677b..0000000
--- a/dotsandboxes/agents/agent_random.py
+++ /dev/null
@@ -1,212 +0,0 @@
-#!/usr/bin/env python3
-# encoding: utf-8
-"""
-dotsandboxesagent.py
-
-Template for the Machine Learning Project course at KU Leuven (2017-2018)
-of Hendrik Blockeel and Wannes Meert.
-
-Copyright (c) 2018 KU Leuven. All rights reserved.
-"""
-import sys
-import argparse
-import logging
-import asyncio
-import websockets
-import json
-from collections import defaultdict
-import random
-
-logger = logging.getLogger(__name__)
-games = {}
-agentclass = None
-
-
-class DotsAndBoxesAgent:
- """Example Dots and Boxes agent implementation base class.
- It returns a random next move.
-
- A DotsAndBoxesAgent object should implement the following methods:
- - __init__
- - add_player
- - register_action
- - next_action
- - end_game
-
- This class does not necessarily use the best data structures for the
- approach you want to use.
- """
- def __init__(self, player, nb_rows, nb_cols, timelimit):
- """Create Dots and Boxes agent.
-
- :param player: Player number, 1 or 2
- :param nb_rows: Rows in grid
- :param nb_cols: Columns in grid
- :param timelimit: Maximum time allowed to send a next action.
- """
- self.player = {player}
- self.timelimit = timelimit
- self.ended = False
- self.nb_rows = nb_rows
- self.nb_cols = nb_cols
- rows = []
- for ri in range(nb_rows + 1):
- columns = []
- for ci in range(nb_cols + 1):
- columns.append({"v": 0, "h": 0})
- rows.append(columns)
- self.cells = rows
-
- def add_player(self, player):
- """Use the same agent for multiple players."""
- self.player.add(player)
-
- def register_action(self, row, column, orientation, player):
- """Register action played in game.
-
- :param row:
- :param columns:
- :param orientation: "v" or "h"
- :param player: 1 or 2
- """
- self.cells[row][column][orientation] = player
-
- def next_action(self):
- """Return the next action this agent wants to perform.
-
- In this example, the function implements a random move. Replace this
- function with your own approach.
-
- :return: (row, column, orientation)
- """
- logger.info("Computing next move (grid={}x{}, player={})"\
- .format(self.nb_rows, self.nb_cols, self.player))
- # Random move
- free_lines = []
- for ri in range(len(self.cells)):
- row = self.cells[ri]
- for ci in range(len(row)):
- cell = row[ci]
- if ri < (len(self.cells) - 1) and cell["v"] == 0:
- free_lines.append((ri, ci, "v"))
- if ci < (len(row) - 1) and cell["h"] == 0:
- free_lines.append((ri, ci, "h"))
- if len(free_lines) == 0:
- # Board full
- return None
- movei = random.randint(0, len(free_lines) - 1)
- r, c, o = free_lines[movei]
- return r, c, o
-
- def end_game(self):
- self.ended = True
-
-
-## MAIN EVENT LOOP
-
-async def handler(websocket, path):
- logger.info("Start listening")
- game = None
- # msg = await websocket.recv()
- try:
- async for msg in websocket:
- logger.info("< {}".format(msg))
- try:
- msg = json.loads(msg)
- except json.decoder.JSONDecodeError as err:
- logger.error(err)
- return False
- game = msg["game"]
- answer = None
- if msg["type"] == "start":
- # Initialize game
- if msg["game"] in games:
- games[msg["game"]].add_player(msg["player"])
- else:
- nb_rows, nb_cols = msg["grid"]
- games[msg["game"]] = agentclass(msg["player"],
- nb_rows,
- nb_cols,
- msg["timelimit"])
- if msg["player"] == 1:
- # Start the game
- nm = games[game].next_action()
- print('nm = {}'.format(nm))
- if nm is None:
- # Game over
- logger.info("Game over")
- continue
- r, c, o = nm
- answer = {
- 'type': 'action',
- 'location': [r, c],
- 'orientation': o
- }
- else:
- # Wait for the opponent
- answer = None
-
- elif msg["type"] == "action":
- # An action has been played
- r, c = msg["location"]
- o = msg["orientation"]
- games[game].register_action(r, c, o, msg["player"])
- if msg["nextplayer"] in games[game].player:
- # Compute your move
- nm = games[game].next_action()
- if nm is None:
- # Game over
- logger.info("Game over")
- continue
- nr, nc, no = nm
- answer = {
- 'type': 'action',
- 'location': [nr, nc],
- 'orientation': no
- }
- else:
- answer = None
-
- elif msg["type"] == "end":
- # End the game
- games[msg["game"]].end_game()
- answer = None
- else:
- logger.error("Unknown message type:\n{}".format(msg))
-
- if answer is not None:
- print(answer)
- await websocket.send(json.dumps(answer))
- logger.info("> {}".format(answer))
- except websockets.exceptions.ConnectionClosed as err:
- logger.info("Connection closed")
- logger.info("Exit handler")
-
-
-def start_server(port):
- server = websockets.serve(handler, 'localhost', port)
- print("Running on ws://127.0.0.1:{}".format(port))
- asyncio.get_event_loop().run_until_complete(server)
- asyncio.get_event_loop().run_forever()
-
-
-## COMMAND LINE INTERFACE
-
-def main(argv=None):
- global agentclass
- parser = argparse.ArgumentParser(description='Start agent to play Dots and Boxes')
- parser.add_argument('--verbose', '-v', action='count', default=0, help='Verbose output')
- parser.add_argument('--quiet', '-q', action='count', default=0, help='Quiet output')
- parser.add_argument('port', metavar='PORT', type=int, help='Port to use for server')
- args = parser.parse_args(argv)
-
- logger.setLevel(max(logging.INFO - 10 * (args.verbose - args.quiet), logging.DEBUG))
- logger.addHandler(logging.StreamHandler(sys.stdout))
-
- agentclass = DotsAndBoxesAgent
- start_server(args.port)
-
-
-if __name__ == "__main__":
- sys.exit(main())
-
diff --git a/dotsandboxes/agents/algorithms/MCTS.py b/dotsandboxes/agents/algorithms/MCTS.py
deleted file mode 100644
index 6c71ba9..0000000
--- a/dotsandboxes/agents/algorithms/MCTS.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import math
-from copy import deepcopy
-from time import perf_counter
-from random import choice
-
-from GameState import GameState
-
-# Based on https://github.com/DieterBuys/mcts-player/
-
-class GameController(object):
- def get_next_move(self, state):
- # when you get a new move, it is assumed that the game is not ended yet
- assert state.get_moves()
-
-
-class MCTSNode(object):
- """Monte Carlo Tree Node.
- Each node encapsulates a particular game state, the moves that
- are possible from that state and the strategic information accumulated
- by the tree search as it progressively samples the game space.
- """
-
- def __init__(self, state, parent=None, move=None):
- self.parent = parent
- self.move = move
- self.state = state
-
- self.plays = 0
- self.score = 0
-
- self.pending_moves = state.get_moves()
- self.children = []
-
- def select_child_ucb(self):
- # Note that each node's plays count is equal
- # to the sum of its children's plays
- def ucb(child):
- win_ratio = child.score / child.plays \
- + math.sqrt(2 * math.log(self.plays) / child.plays)
- return win_ratio
-
- return max(self.children, key=ucb)
-
- def expand_move(self, move):
- self.pending_moves.remove(move) # raises KeyError
-
- child_state = deepcopy(self.state)
- child_state.play_move(move)
-
- child = MCTSNode(state=child_state, parent=self, move=move)
- self.children.append(child)
- return child
-
- def get_score(self, result):
- # return result
- if result == 0.5:
- return result
-
- if self.state.player == 2:
- if self.state.next_turn_player == result:
- return 0.0
- else:
- return 1.0
- else:
- if self.state.next_turn_player == result:
- return 1.0
- else:
- return 0.0
-
- if self.state.next_turn_player == result:
- return 0.0
- else:
- return 1.0
-
- def __repr__(self):
- s = 'ROOT\n' if self.parent is None else ''
-
- children_moves = [c.move for c in self.children]
-
- s += """Score ratio: {score} / {plays}
-Pending moves: {pending_moves}
-Children's moves: {children_moves}
-State:
-{state}\n""".format(children_moves=children_moves, **self.__dict__)
-
- return s
-
-
-class MCTSGameController(GameController):
- """Game controller that uses MCTS to determine the next move.
- This is the class which implements the Monte Carlo Tree Search algorithm.
- It builds a game tree of MCTSNodes and samples the game space until a set
- time has elapsed.
- """
-
- def select(self):
- node = self.root_node
-
- # Descend until we find a node that has pending moves, or is terminal
- while node.pending_moves == set() and node.children != []:
- node = node.select_child_ucb()
-
- return node
-
- def expand(self, node):
- assert node.pending_moves != set()
-
- move = choice(tuple(node.pending_moves))
- return node.expand_move(move)
-
- def simulate(self, state, max_iterations=1000):
- state = deepcopy(state)
-
- move = state.get_random_move()
- while move is not None:
- state.play_move(move)
- move = state.get_random_move()
-
- max_iterations -= 1
- if max_iterations <= 0:
- return 0.5 # raise exception? (game too deep to simulate)
-
- return state.game_result
-
- def update(self, node, result):
- while node is not None:
- node.plays += 1
- node.score += node.get_score(result)
- node = node.parent
-
- def get_next_move(self, state, time_allowed=1.0):
- super(MCTSGameController, self).get_next_move(state)
-
- # Create new tree (TODO: Preserve some state for better performance?)
- self.root_node = MCTSNode(state)
- iterations = 0
-
- start_time = perf_counter()
- while perf_counter() < start_time + time_allowed:
- node = self.select()
-
- if node.pending_moves != set():
- node = self.expand(node)
-
- result = self.simulate(node.state)
- self.update(node, result)
-
- iterations += 1
-
- # Return most visited node's move
- return max(self.root_node.children, key=lambda n: n.plays).move
diff --git a/dotsandboxes/agents/algorithms/alphaBeta.py b/dotsandboxes/agents/algorithms/alphaBeta.py
deleted file mode 100644
index 8e041fe..0000000
--- a/dotsandboxes/agents/algorithms/alphaBeta.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from GameState import GameState
-class GameController(object):
- def get_next_move(self, state):
- # when you get a new move, it is assumed that the game is not ended yet
- assert state.get_moves()
-
-
-def alpha_beta(node, alpha, beta):
-
- # Based on https://en.wikipedia.org/wiki/Alpha%E2%80%93beta_pruning#Pseudocode
- # node needs to support three operations: isTerminal(), value(), getChildren(), maximizingPlayer()
-
- if node.isTerminal():
- return node.value()
-
- if node.maximizingPlayer():
-
- v = float("-inf")
- for child in node.getChildren():
-
- v = max(v, alpha_beta(child, alpha, beta))
- alpha = max(alpha, v)
- if beta <= alpha:
- break
-
- else:
-
- v = float("inf")
- for child in node.getChildren():
-
- v = min(v, alpha_beta(child, alpha, beta))
- beta = min(beta, v)
- if beta <= alpha:
- break
-
- return v
-
-
-# A class for defining algorithms used (minimax and alpha-beta pruning)
-class AlphaBeta:
-
- def miniMax(State, Ply_num): # Function for the minimax algorithm
-
- for i in range(State.Current.dimY):
- for j in range(State.Current.dimX):
- if State.Current.Mat[i][j] == ' ' and (j, i) not in State.children:
- State.Make(j, i, True)
- if Ply_num < 2:
- return (i, j)
-
- Minimum_Score = 1000
- i = 0
-
- j = 0
- for k, z in State.children.items():
- Result = Algo.Maximum(z, Ply_num - 1, Minimum_Score)
- if Minimum_Score > Result:
- Minimum_Score = Result
- i = k[0]
- j = k[1]
-
- return (i, j)
-
- # Alpha-beta pruning function for taking care of Alpha values
- def Maximum(State, Ply_num, Alpha):
- if Ply_num == 0:
- return State.CurrentScore
-
- for i in range(State.Current.dimY):
- for j in range(State.Current.dimX):
- if State.Current.Mat[i][j] == ' ' and (j, i) not in State.children:
- State.Make(j, i, False)
-
- Maximum_Score = -1000
- i = 0
- j = 0
- for k, z in State.children.items():
- Result = Algo.Minimum(z, Ply_num - 1, Maximum_Score)
- if Maximum_Score < Result:
- Maximum_Score = Result
- if Result > Alpha:
- return Result
-
- return Maximum_Score
-
- def Minimum(State, Ply_num, Beta): # Alpha-beta pruning function for taking care of Beta values
- if Ply_num == 0:
- return State.CurrentScore
-
- for i in range(State.Current.dimY):
- for j in range(State.Current.dimX):
- if State.Current.Mat[i][j] == ' ' and (j, i) not in State.children:
- State.Make(j, i, True)
-
- Minimum_Score = 1000
- i = 0
- j = 0
- for k, z in State.children.items():
- Result = Algo.Maximum(z, Ply_num - 1, Minimum_Score)
- if Minimum_Score > Result:
- Minimum_Score = Result
- if Result < Beta:
- return Result
-
- return Minimum_Score
diff --git a/dotsandboxes/agents/algorithms/ann.py b/dotsandboxes/agents/algorithms/ann.py
deleted file mode 100644
index 05ae647..0000000
--- a/dotsandboxes/agents/algorithms/ann.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from numpy import *
-from math import sqrt
-from copy import deepcopy
-from time import time
-
-class ANN:
-
- """ANN with one hidden layer, one output and full connections in between consecutive layers.
- Initial weights are chosen from a normal distribution.
- Activation function is tanh."""
-
- INIT_SIGMA = 0.02
- REL_STOP_MARGIN = 0.01
- MAX_ITERATIONS = 1000000
- ACTIVATION = tanh
- D_ACTIVATION = lambda x: 1 - tanh(x)**2 # Derivative of tanh
- VEC_ACTIVATION = vectorize(ACTIVATION)
- VEC_D_ACTIVATION = vectorize(D_ACTIVATION)
- STEP_SIZE = 0.1
-
- def __init__(self, input_size, hidden_size):
-
- #self.input_size = input_size
- #self.hidden_size = hidden_size
- self.hidden_weights = random.normal(0, ANN.INIT_SIGMA, (hidden_size, input_size))
- self.output_weights = random.normal(0, ANN.INIT_SIGMA, hidden_size)
-
- def get_weights(self):
- return self.hidden_weights, self.output_weights
-
- def predict(self, input_vector):
-
- # Predicts the output for this input vector
- # input_vector will be normalized
-
- input_vector = input_vector/linalg.norm(input_vector)
- return ANN.ACTIVATION(dot(self.output_weights, ANN.VEC_ACTIVATION(dot(self.hidden_weights, input_vector))))
-
- @staticmethod
- def frob_norm(a, b):
-
- # Calculates the total Frobenius norm of both matrices A and B
- return sqrt(linalg.norm(a)**2 + linalg.norm(b)**2)
-
- def train(self, examples):
-
- #print("Training")
- start = time()
-
- # examples is a list of (input, output)-tuples
- # input will be normalized
- # We stop when the weights have converged within some relative margin
-
- for example in examples:
- example[0] = example[0]/linalg.norm(example[0])
-
- iteration = 0
- while True:
-
-
- # Store old weights to check for convergence later
- prev_hidden_weights = deepcopy(self.hidden_weights)
- prev_output_weights = deepcopy(self.output_weights)
-
- for k in range(len(examples)):
-
- input_vector, output = examples[k]
-
- # Calculate outputs
- hidden_input = dot(self.hidden_weights, input_vector)
- hidden_output = ANN.VEC_ACTIVATION(hidden_input)
- final_input = dot(self.output_weights, hidden_output)
- predicted_output = ANN.ACTIVATION(final_input)
-
- #print("Output:", output)
- #print("Predicted output:", predicted_output)
-
- # Used in calculations
- prediction_error = output - predicted_output
- output_derivative = ANN.D_ACTIVATION(final_input)
-
- # Adjust output weights and calculate requested hidden change
- requested_hidden_change = prediction_error*output_derivative*self.output_weights
- self.output_weights = self.output_weights + ANN.STEP_SIZE*prediction_error*hidden_output
-
- #print("After adjusting output weights:", ANN.ACTIVATION(dot(self.output_weights, hidden_output)))
-
- # Backpropagate requested hidden change to adjust hidden weights
- self.hidden_weights = self.hidden_weights + ANN.STEP_SIZE*outer(requested_hidden_change*(ANN.VEC_D_ACTIVATION(hidden_input)), input_vector)
-
- #print("After adjusting hidden weights:", ANN.ACTIVATION(dot(self.output_weights, ANN.VEC_ACTIVATION(dot(self.hidden_weights, input_vector)))))
-
- # Check stop criteria
- iteration += 1
- if iteration >= ANN.MAX_ITERATIONS:
- break
-
- # Check stop criteria
- if iteration >= ANN.MAX_ITERATIONS:
- break
- diff = ANN.frob_norm(self.hidden_weights - prev_hidden_weights, self.output_weights - prev_output_weights)
- base = ANN.frob_norm(self.hidden_weights, self.output_weights)
- #if base > 0 and diff/base < ANN.REL_STOP_MARGIN:
- # break
-
- print(time() - start)
- print("Stopped training after %s iterations."%iteration)
-
-# TESTING
-
-def print_difference(ann1, ann2):
-
- # Prints the differences in weights in between two ANN's with identical topology
-
- hidden_weights1, output_weights1 = ann1.get_weights()
- hidden_weights2, output_weights2 = ann2.get_weights()
- hidden_diff = hidden_weights1 - hidden_weights2
- output_diff = output_weights1 - output_weights2
-
- print(hidden_diff)
- print(output_diff)
- print("Frobenius norms:")
- print("Hidden weights difference:", linalg.norm(hidden_diff))
- print("Output weights difference:", linalg.norm(output_diff))
- print("Both:", ANN.frob_norm(hidden_diff, output_diff))
-
-def RMSE(ann, examples):
-
- total = 0
- for input_vector, output in examples:
- total += (output - ann.predict(input_vector))**2
- return sqrt(total/len(examples))
-
-def generate_examples(amount, input_size, evaluate):
- # evaluate is a function mapping an input vector onto a numerical value
- examples = []
- inputs = random.normal(0, 100, (amount, input_size))
- for i in range(amount):
- input_vector = inputs[i]
- examples.append([input_vector, evaluate(input_vector)])
- return examples
-
-def test():
-
- # Test the ANN by having it model another ANN with identical topology but unknown weights
-
- input_size = 5
- hidden_size = 3
- real = ANN(input_size, hidden_size)
- model = ANN(input_size, hidden_size)
-
- # Generate training data
- training_data = generate_examples(10000, input_size, real.predict)
- validation_data = generate_examples(10000, input_size, real.predict)
-
- # Print initial difference, train, then print new difference
- print("Initial difference:")
- print_difference(real, model)
- print("Initial RMSE (on training data):", RMSE(model, training_data))
- print("Initial RMSE (on validation data):", RMSE(model, validation_data))
- model.train(training_data)
- print("After training:")
- print_difference(real, model)
- print("After training RMSE (on training data):", RMSE(model, training_data))
- print("After training RMSE (on validation data):", RMSE(model, validation_data))
-
-if __name__ == "__main__":
- test()
-
-