diff options
Diffstat (limited to 'System_Python/system_swingup_test.py')
-rw-r--r-- | System_Python/system_swingup_test.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/System_Python/system_swingup_test.py b/System_Python/system_swingup_test.py index e4d3a72..7a9b5bd 100644 --- a/System_Python/system_swingup_test.py +++ b/System_Python/system_swingup_test.py @@ -52,7 +52,7 @@ class SwingUpEnv(): }
def __init__(self):
- self.sys = System()
+ self.sys = System(angular_units='Radians')
self.force_mag = 10.0
self.tau = 0.02 # seconds between state updates
@@ -95,10 +95,10 @@ class SwingUpEnv(): self.up_time = 0
new_theta, new_x = self.sys.measure()
- new_theta = radians(new_theta)
theta_dot = (new_theta - theta) / self.tau
x_dot = (new_x - x) / self.tau
self.state = (new_x, x_dot, new_theta, theta_dot)
+ self.sys.add_results(new_theta, new_x, force)
done = x < -self.x_threshold \
or x > self.x_threshold \
@@ -159,7 +159,7 @@ class nnQ(pt.nn.Module): def forward(self,x,a):
x = pt.tensor(x, dtype = pt.float32)
- b = pt.nn.functional.one_hot(pt.tensor(a), self.numActions)
+ b = pt.nn.functional.one_hot(pt.tensor(a).long(), self.numActions)
c = b.float().detach()
y = pt.cat([x, c])
@@ -179,6 +179,7 @@ class sarsaAgent: def action(self, x):
# This is an epsilon greedy selection
+ a = 0
if rnd.rand() < self.epsilon:
a = rnd.randint(numActions)
else:
@@ -252,7 +253,6 @@ while step < maxSteps: y = x_to_y(x)
a = agent.action(y)
u = Actions[a:a+1]
- env.render()
x_next, c, done, info = env.step(u)
max_up_time = info['max_up_time']
|