about | summary | refs | log | tree | commit | diff | stats
path: root/System_Python/system_swingup_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'System_Python/system_swingup_test.py')
-rw-r--r-- System_Python/system_swingup_test.py | 8
1 file changed, 4 insertions, 4 deletions
diff --git a/System_Python/system_swingup_test.py b/System_Python/system_swingup_test.py
index e4d3a72..7a9b5bd 100644
--- a/System_Python/system_swingup_test.py
+++ b/System_Python/system_swingup_test.py
@@ -52,7 +52,7 @@ class SwingUpEnv():
}
def __init__(self):
- self.sys = System()
+ self.sys = System(angular_units='Radians')
self.force_mag = 10.0
self.tau = 0.02 # seconds between state updates
@@ -95,10 +95,10 @@ class SwingUpEnv():
self.up_time = 0
new_theta, new_x = self.sys.measure()
- new_theta = radians(new_theta)
theta_dot = (new_theta - theta) / self.tau
x_dot = (new_x - x) / self.tau
self.state = (new_x, x_dot, new_theta, theta_dot)
+ self.sys.add_results(new_theta, new_x, force)
done = x < -self.x_threshold \
or x > self.x_threshold \
@@ -159,7 +159,7 @@ class nnQ(pt.nn.Module):
def forward(self,x,a):
x = pt.tensor(x, dtype = pt.float32)
- b = pt.nn.functional.one_hot(pt.tensor(a), self.numActions)
+ b = pt.nn.functional.one_hot(pt.tensor(a).long(), self.numActions)
c = b.float().detach()
y = pt.cat([x, c])
@@ -179,6 +179,7 @@ class sarsaAgent:
def action(self, x):
# This is an epsilon greedy selection
+ a = 0
if rnd.rand() < self.epsilon:
a = rnd.randint(numActions)
else:
@@ -252,7 +253,6 @@ while step < maxSteps:
y = x_to_y(x)
a = agent.action(y)
u = Actions[a:a+1]
- env.render()
x_next, c, done, info = env.step(u)
max_up_time = info['max_up_time']