aboutsummaryrefslogtreecommitdiffstats
path: root/System/system_swingup_test_2.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--System/system_swingup_test_2.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/System/system_swingup_test_2.py b/System/system_swingup_test_2.py
index 81d5419..e240b01 100644
--- a/System/system_swingup_test_2.py
+++ b/System/system_swingup_test_2.py
@@ -53,7 +53,7 @@ class SwingUpEnv():
}
def __init__(self):
- self.sys = System(angular_units='Radians')
+ self.sys = System(angular_units='Radians', positive_limit=10., negative_limit=-10., sw_limit_routine=self.x_threshold_routine)
self.force_mag = 10.
self.last_time = time.time() # time for seconds between updates
@@ -73,6 +73,7 @@ class SwingUpEnv():
self.state = None
self.steps_beyond_done = None
+ self.done = False
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
@@ -83,7 +84,9 @@ class SwingUpEnv():
state = self.state
x, x_dot, theta, theta_dot = state
force = self.force_mag * action[0]
- self.sys.adjust(force)
+ # Do not adjust the motor further if the x_threshold has been triggered by the SW limit
+ if self.done == False:
+ self.sys.adjust(force)
costheta = math.cos(theta)
sintheta = math.sin(theta)
@@ -105,10 +108,15 @@ class SwingUpEnv():
self.state = (new_x, x_dot, new_theta, theta_dot)
self.sys.add_results(new_theta, new_x, force)
- done = x < -self.x_threshold \
+ done = theta_dot < -self.theta_dot_threshold \
+ or theta_dot > self.theta_dot_threshold \
+ or self.done == True
+
+ '''done = x < -self.x_threshold \
or x > self.x_threshold \
or theta_dot < -self.theta_dot_threshold \
- or theta_dot > self.theta_dot_threshold
+ or theta_dot > self.theta_dot_threshold \
+ or self.done == True'''
done = bool(done)
if not done:
@@ -125,6 +133,10 @@ class SwingUpEnv():
return np.array(self.state), reward, done, {'max_up_time' : self.max_up_time}
+ def x_threshold_routine(self):
+ self.done = True
+ self.sys.adjust(0)
+
def reset(self, home = True):
if home == True:
self.sys.return_home()
@@ -138,6 +150,7 @@ class SwingUpEnv():
self.max_up_time = 0
self.up = False
self.steps_beyond_done = None
+ self.done = False
return np.array(self.state)
def end(self):