I completed this in an afternoon. It was very easy to train, using a simple reward function. The trained policy keeps the ball bouncing for about 38 minutes before failing.
Here is the full script, in case anyone is interested:
import os
import asyncio
import stable_baselines3.ppo
import torch
import prototwin
import prototwin_gymnasium
import gymnasium
import numpy as np
import math
import stable_baselines3
import stable_baselines3.common
import stable_baselines3.common.vec_env
from stable_baselines3.common.callbacks import CheckpointCallback
# Signal addresses (copy/paste from ProtoTwin).
# Each constant is an integer address that PingPongEnv passes to self.get() /
# self.set(). The groups below appear in the order they were pasted, which is
# not sorted by link number or by address.

# Link 5 motor signals.
link5_motor_state = 7
link5_motor_target_position = 8
link5_motor_target_velocity = 9
link5_motor_force_limit = 10
link5_motor_current_position = 11
link5_motor_current_velocity = 12
link5_motor_current_force = 13

# Link 1 motor signals.
link1_motor_state = 14
link1_motor_target_position = 15
link1_motor_target_velocity = 16
link1_motor_force_limit = 17
link1_motor_current_position = 18
link1_motor_current_velocity = 19
link1_motor_current_force = 20

# Link 6 motor signals.
link6_motor_state = 21
link6_motor_target_position = 22
link6_motor_target_velocity = 23
link6_motor_force_limit = 24
link6_motor_current_position = 25
link6_motor_current_velocity = 26
link6_motor_current_force = 27

# Surface sensors; the environment code treats the surface as the bat.
# Position (x, y, z).
surface_position_sensor_x = 49
surface_position_sensor_y = 50
surface_position_sensor_z = 51
# Orientation (x, y, z, w) -- presumably a quaternion; confirm in ProtoTwin.
surface_orientation_sensor_x = 52
surface_orientation_sensor_y = 53
surface_orientation_sensor_z = 54
surface_orientation_sensor_w = 55
# Linear velocity (x, y, z).
surface_linear_velocity_sensor_x = 56
surface_linear_velocity_sensor_y = 57
surface_linear_velocity_sensor_z = 58
# Angular velocity (x, y, z).
surface_angular_velocity_sensor_x = 59
surface_angular_velocity_sensor_y = 60
surface_angular_velocity_sensor_z = 61

# Link 3 motor signals.
link3_motor_state = 28
link3_motor_target_position = 29
link3_motor_target_velocity = 30
link3_motor_force_limit = 31
link3_motor_current_position = 32
link3_motor_current_velocity = 33
link3_motor_current_force = 34

# Link 2 motor signals.
link2_motor_state = 35
link2_motor_target_position = 36
link2_motor_target_velocity = 37
link2_motor_force_limit = 38
link2_motor_current_position = 39
link2_motor_current_velocity = 40
link2_motor_current_force = 41

# Link 4 motor signals.
link4_motor_state = 42
link4_motor_target_position = 43
link4_motor_target_velocity = 44
link4_motor_force_limit = 45
link4_motor_current_position = 46
link4_motor_current_velocity = 47
link4_motor_current_force = 48

# Ball sensors: position (x, y, z) and linear velocity (x, y, z).
ball_position_sensor_x = 1
ball_position_sensor_y = 2
ball_position_sensor_z = 3
ball_linear_velocity_sensor_x = 4
ball_linear_velocity_sensor_y = 5
ball_linear_velocity_sensor_z = 6
# Measured motor forces for links 1..6; summed (squared) by the torque penalty.
torques = [
    link1_motor_current_force,
    link2_motor_current_force,
    link3_motor_current_force,
    link4_motor_current_force,
    link5_motor_current_force,
    link6_motor_current_force,
]

# Signals read into every observation, in this exact order.
states = (
    # Motor positions, links 1..6.
    [
        link1_motor_current_position,
        link2_motor_current_position,
        link3_motor_current_position,
        link4_motor_current_position,
        link5_motor_current_position,
        link6_motor_current_position,
    ]
    # Motor velocities, links 1..6.
    + [
        link1_motor_current_velocity,
        link2_motor_current_velocity,
        link3_motor_current_velocity,
        link4_motor_current_velocity,
        link5_motor_current_velocity,
        link6_motor_current_velocity,
    ]
    # Surface (bat) position, orientation, linear and angular velocity.
    + [
        surface_position_sensor_x,
        surface_position_sensor_y,
        surface_position_sensor_z,
    ]
    + [
        surface_orientation_sensor_x,
        surface_orientation_sensor_y,
        surface_orientation_sensor_z,
        surface_orientation_sensor_w,
    ]
    + [
        surface_linear_velocity_sensor_x,
        surface_linear_velocity_sensor_y,
        surface_linear_velocity_sensor_z,
    ]
    + [
        surface_angular_velocity_sensor_x,
        surface_angular_velocity_sensor_y,
        surface_angular_velocity_sensor_z,
    ]
    # Ball position and linear velocity.
    + [
        ball_position_sensor_x,
        ball_position_sensor_y,
        ball_position_sensor_z,
    ]
    + [
        ball_linear_velocity_sensor_x,
        ball_linear_velocity_sensor_y,
        ball_linear_velocity_sensor_z,
    ]
)

# Signals written by the policy: one target position per link motor, links 1..6.
actions = [
    link1_motor_target_position,
    link2_motor_target_position,
    link3_motor_target_position,
    link4_motor_target_position,
    link5_motor_target_position,
    link6_motor_target_position,
]

torque_size, state_size, action_size = len(torques), len(states), len(actions)
# An observation is every state signal followed by the previous action.
obs_size = state_size + action_size

# Reference ("zero") positions. The ball values feed the height reward and the
# terminal bounds; the bat values feed the position penalty and terminal bounds.
ball_x_zero, ball_y_zero, ball_z_zero = 0.0003, 0.5828, 0.2957
bat_x_zero, bat_y_zero, bat_z_zero = 0.0003, -0.2247, 0.3017
class PingPongEnv(prototwin_gymnasium.VecEnvInstance):
    """One environment instance that keeps a ball bouncing on the bat surface.

    An observation is the value of every signal in ``states`` followed by the
    action that was in effect on the previous step. An action holds one value
    per entry of ``actions`` (the six motor target positions).
    """

    def __init__(self, client: prototwin.Client, instance: int):
        """Initialise the instance and its per-episode bookkeeping."""
        super().__init__(client, instance)
        # NOTE(review): nothing in this class advances self.time, so the
        # truncation check in step() presumably relies on the base class or
        # the VecEnv updating it -- confirm against prototwin_gymnasium.
        self.time = 0
        self.previous_action = [0] * action_size
        # Defined up front so step() cannot raise AttributeError when it runs
        # before the first apply().
        self.action = [0] * action_size

    def _reward_torque_penalty(self):
        """Return the sum of the squared motor forces listed in ``torques``."""
        forces = (self.get(address) for address in torques)
        return sum(force * force for force in forces)

    def _reward_position_penalty(self):
        """Return the squared distance of the bat from its zero position."""
        dx = self.get(surface_position_sensor_x) - bat_x_zero
        dy = self.get(surface_position_sensor_y) - bat_y_zero
        dz = self.get(surface_position_sensor_z) - bat_z_zero
        return dx * dx + dy * dy + dz * dz

    def _reward_ball_height(self):
        """Return a reward in (0, 1] that peaks when the ball is at its zero height.

        The reward is 0 unless the ball is more than 0.05 above the bat, so a
        ball that is merely resting on the bat earns nothing.
        """
        ball_y = self.get(ball_position_sensor_y)
        bat_y = self.get(surface_position_sensor_y)
        if ball_y - bat_y <= 0.05:
            return 0
        ball_y_distance = ball_y - ball_y_zero
        return math.exp(-5 * ball_y_distance * ball_y_distance)

    def _reward_fn(self, obs):
        """Return the reward for the current step.

        Args:
            obs: The current observation. Unused; every term is read directly
                from the signals.

        Returns:
            Ball-height reward minus the bat position penalty and a lightly
            weighted torque penalty, each scaled by the step size ``dt``.
        """
        dt = 0.005
        reward = 0
        reward += self._reward_ball_height() * dt
        reward -= self._reward_position_penalty() * dt
        reward -= self._reward_torque_penalty() * dt * 0.0002
        return reward

    @staticmethod
    def _out_of_bounds(value, zero, below, above):
        """Return True when value is strictly below zero - below or above zero + above."""
        return value < zero - below or value > zero + above

    def _terminal(self):
        """Return True when the ball or the bat has left its allowed box."""
        # Read every sensor before testing so the reads do not depend on
        # short-circuit evaluation.
        ball_x = self.get(ball_position_sensor_x)
        ball_y = self.get(ball_position_sensor_y)
        ball_z = self.get(ball_position_sensor_z)
        bat_x = self.get(surface_position_sensor_x)
        bat_y = self.get(surface_position_sensor_y)
        bat_z = self.get(surface_position_sensor_z)
        outside = self._out_of_bounds
        return (
            outside(ball_x, ball_x_zero, 0.3, 0.3)
            # The ball may fall up to 1.0 below its zero height but rise only 0.4 above it.
            or outside(ball_y, ball_y_zero, 1.0, 0.4)
            or outside(ball_z, ball_z_zero, 0.3, 0.3)
            or outside(bat_x, bat_x_zero, 0.3, 0.3)
            or outside(bat_y, bat_y_zero, 0.3, 0.3)
            or outside(bat_z, bat_z_zero, 0.2, 0.3)
        )

    def reset(self, seed=None):
        """Reset the instance and return an all-zero observation with an empty info dict."""
        super().reset(seed=seed)
        self.previous_action = [0] * action_size
        # Drop the action left over from the previous episode.
        self.action = [0] * action_size
        return np.array([0] * obs_size), {}

    def apply(self, action):
        """Remember ``action`` and write each component to its target-position signal."""
        self.action = action
        for address, value in zip(actions, (action[i] for i in range(action_size))):
            self.set(address, value)

    def step(self):
        """Read the observation and score the current simulation state.

        Returns:
            A tuple ``(obs, reward, done, truncated, info)``: ``obs`` is a list
            of the ``states`` signal values followed by the previous action,
            ``done`` is the result of ``_terminal()``, ``truncated`` is True
            once ``self.time`` exceeds 60, and ``info`` is an empty dict.
        """
        obs = [self.get(address) for address in states]
        obs.extend(self.previous_action[i] for i in range(action_size))
        reward = self._reward_fn(obs)
        # The action just scored becomes "previous" for the next observation.
        self.previous_action = self.action
        done = self._terminal()
        truncated = self.time > 60
        return obs, reward, done, truncated, {}
async def main():
    """Start ProtoTwin, build the vectorised environment and train a PPO policy."""
    # Start the ProtoTwin client and load the model stored next to this script.
    client = await prototwin.start(dev=True)
    model_path = os.path.join(os.path.dirname(__file__), "pingpong.ptm")
    await client.load(model_path)

    # Every action component is bounded to [-0.6, 0.6].
    action_limit = np.full(action_size, 0.6, dtype=np.float32)
    action_space = gymnasium.spaces.Box(-action_limit, action_limit, dtype=np.float32)

    # Observations are bounded only by the float32 range.
    obs_limit = np.full(obs_size, np.finfo(np.float32).max, dtype=np.float32)
    observation_space = gymnasium.spaces.Box(-obs_limit, obs_limit, dtype=np.float32)

    checkpoint_every = 4000  # Number of timesteps per instance

    def lr_schedule(progress_remaining):
        # Scale the base learning rate by the remaining-progress value that
        # the trainer passes in.
        return 0.0003 * progress_remaining

    # Policy and value networks use the same layer sizes.
    network_config = {
        "activation_fn": torch.nn.ReLU,
        "net_arch": {
            "pi": [256, 128, 64, 32],
            "vf": [256, 128, 64, 32],
        },
    }

    num_instances = 10 * 10
    rollout_steps = 4000
    minibatch_size = 8000
    entropy_coefficient = 0.0003

    vec_env = prototwin_gymnasium.VecEnv(
        PingPongEnv,
        client,
        "Main",
        num_instances,
        observation_space,
        action_space,
        pattern=prototwin.Pattern.GRID,
        spacing=1,
    )
    monitored_env = stable_baselines3.common.vec_env.VecMonitor(vec_env)

    checkpoints = CheckpointCallback(
        save_freq=checkpoint_every,
        save_path="./logs/checkpoints/",
        name_prefix="checkpoint",
        save_replay_buffer=True,
        save_vecnormalize=True,
    )

    agent = stable_baselines3.PPO(
        stable_baselines3.ppo.MlpPolicy,
        monitored_env,
        verbose=1,
        ent_coef=entropy_coefficient,
        learning_rate=lr_schedule,
        batch_size=minibatch_size,
        n_steps=rollout_steps,
        policy_kwargs=network_config,
        tensorboard_log="./tensorboard/",
    )
    agent.learn(total_timesteps=100_000_000, callback=checkpoints)
# Run training only when this file is executed as a script, so that importing
# the module (for example to reuse PingPongEnv) does not start ProtoTwin.
if __name__ == "__main__":
    asyncio.run(main())