RobLAX: A Differentiable Robotics Framework for Physics Augmented Reinforcement Learning

Guo Ye, Qinjie Lin, Tim Tse-Kit Lau, Wanxin Jin, Haozheng Luo, Cheng Zhou, Zhuoran Yang, Zhaoran Wang, Han Liu

Northwestern University

Paper, Code

Summary:

We propose RobLAX, a JAX-implemented framework that augments model-based reinforcement learning with a fully differentiable robotics simulation. All physical parameters in the simulation are learnable, which allows fast alignment of the model to external interaction data and reduces the sim-to-real gap. RobLAX backpropagates through a small number of simulation steps (i.e., a short unrolling horizon), combined with a value function learned from augmented data that is also generated by the simulation. This process yields more accurate gradient signals and circumvents exploding and vanishing gradients in long-duration tasks. We also provide a monotonic performance improvement guarantee for policy training in RobLAX. The framework is implemented with the JAX library to enable high-performance numerical computing. We demonstrate RobLAX on multiple robotic control tasks and show its efficiency in comparison to state-of-the-art model-based and model-free baselines.
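As a minimal sketch of how learnable physical parameters can be aligned to interaction data, the JAX snippet below fits the length and mass of a simulated pendulum to observed one-step transitions by gradient descent through the simulator. The pendulum model, the explicit-Euler integrator, the parameter names `length` and `mass`, and all numeric values are illustrative assumptions, not RobLAX's actual API; synthetic transitions generated from the "true" parameters stand in for real experience.

```python
import jax
import jax.numpy as jnp

G = 9.81    # gravity (m/s^2), assumed known
DT = 0.05   # integration step (s), illustrative
LR = 0.005  # gradient-descent step size, illustrative

def step(params, state, u):
    """Hypothetical differentiable pendulum: one explicit-Euler step.

    state[..., 0] is the angle theta, state[..., 1] the angular velocity;
    params["length"] and params["mass"] are the learnable physical parameters.
    """
    theta, omega = state[..., 0], state[..., 1]
    l, m = params["length"], params["mass"]
    alpha = -(G / l) * jnp.sin(theta) + u / (m * l**2)  # angular acceleration
    return jnp.stack([theta + DT * omega, omega + DT * alpha], axis=-1)

def model_loss(params, states, actions, next_states):
    # Mean squared one-step prediction error, divided by DT**2 so the loss is
    # on the scale of accelerations (keeps plain gradient descent stable).
    pred = step(params, states, actions)
    return jnp.mean(jnp.sum((pred - next_states) ** 2, axis=-1)) / DT**2

@jax.jit
def update(params, states, actions, next_states):
    # Differentiate the loss through the simulator w.r.t. the physical parameters.
    grads = jax.grad(model_loss)(params, states, actions, next_states)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

# Synthetic "real" transitions generated with the true parameters.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n = 256
states = jnp.stack([
    jax.random.uniform(k1, (n,), minval=-jnp.pi, maxval=jnp.pi),
    jax.random.uniform(k2, (n,), minval=-1.0, maxval=1.0),
], axis=-1)
actions = jax.random.uniform(k3, (n,), minval=-2.0, maxval=2.0)
true_params = {"length": jnp.array(1.0), "mass": jnp.array(1.0)}
next_states = step(true_params, states, actions)

# Start from a mismatched simulator and align it to the data.
params = {"length": jnp.array(0.9), "mass": jnp.array(1.2)}
for _ in range(4000):
    params = update(params, states, actions, next_states)
print({k: float(v) for k, v in params.items()})
```

Because the data here are noise-free and the regressors (sin θ and the torque) vary independently, the loss is minimized exactly at the true parameters; with real data, one would expect a best fit rather than exact recovery.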

Illustration of our algorithm. The model learning part (left) illustrates how trajectories are used to learn a differentiable physics engine: we use real experience to update physical parameters of the simulator, such as the mass or length of a link. The policy learning part shows the computational graph of the policy optimization objective. Taking advantage of the learned differentiable physics engine, we obtain the gradient by backpropagation through time (BPTT) over H time steps, where the horizon H is the number of unrolling steps we define. A terminal-state value function at the end of the partial trajectory reduces variance and avoids exploding/vanishing gradients.
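Under our reading of the caption, the policy objective is an H-step discounted return plus a discounted terminal value, J(θ) = Σ_{t=0}^{H-1} γ^t r(s_t, a_t) + γ^H V(s_H), with a_t = π_θ(s_t) and s_{t+1} = f(s_t, a_t); the discounting is an assumption. The JAX sketch below unrolls a differentiable simulator with `jax.lax.scan` and differentiates the objective with `jax.grad`, which is exactly BPTT over H steps. The pendulum dynamics, reward, linear-tanh policy, quadratic critic, and all constants are hypothetical stand-ins; the critic is held fixed here, whereas RobLAX learns its value function from simulated data.

```python
import jax
import jax.numpy as jnp

G, DT = 9.81, 0.05      # illustrative physics constants
GAMMA, H = 0.99, 10     # discount and unrolling horizon (illustrative)

def dynamics(phys, s, a):
    # Hypothetical differentiable pendulum step standing in for the simulator.
    theta, omega = s[0], s[1]
    l, m = phys["length"], phys["mass"]
    alpha = -(G / l) * jnp.sin(theta) + a / (m * l**2)
    return jnp.stack([theta + DT * omega, omega + DT * alpha])

def reward(s, a):
    # Hypothetical reward: swing up to theta = pi with little effort.
    return -((s[0] - jnp.pi) ** 2 + 0.1 * s[1] ** 2 + 0.001 * a ** 2)

def policy(w, s):
    # Hypothetical linear policy on (cos, sin, omega), squashed to [-2, 2].
    feats = jnp.stack([jnp.cos(s[0]), jnp.sin(s[0]), s[1]])
    return 2.0 * jnp.tanh(feats @ w)

def terminal_value(v, s):
    # Hypothetical quadratic critic V(s_H); fixed in this sketch.
    return -v * ((s[0] - jnp.pi) ** 2 + 0.1 * s[1] ** 2)

def h_step_objective(w, v, phys, s0):
    """Sum_{t<H} GAMMA^t r(s_t, a_t) + GAMMA^H V(s_H) along a simulated rollout."""
    def body(carry, t):
        s, ret = carry
        a = policy(w, s)
        ret = ret + GAMMA ** t * reward(s, a)
        return (dynamics(phys, s, a), ret), None

    (s_H, ret), _ = jax.lax.scan(body, (s0, jnp.zeros(())), jnp.arange(H))
    return ret + GAMMA ** H * terminal_value(v, s_H)

# Gradient w.r.t. the policy weights only (argnums=0): BPTT through H steps.
policy_grad = jax.jit(jax.grad(h_step_objective))

w = jnp.zeros(3)
phys = {"length": 1.0, "mass": 1.0}
s0 = jnp.array([0.1, 0.0])
g = policy_grad(w, 1.0, phys, s0)
print(g)
```

A policy update would then be a gradient-ascent step such as `w + lr * g`; the terminal value replaces the untruncated remainder of the episode, so the gradient never has to flow through more than H simulator steps.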

Video

Performance on Long Duration Tasks

In model-based reinforcement learning, a common issue is the compounding error that arises when backpropagating through many timesteps. In this section, we choose three classic control problems in a long-duration setting, meaning that each episode has hundreds of timesteps.
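A toy calculation illustrates why long unrolls are problematic and why truncating at a short horizon helps. For the scalar linear system s_{t+1} = a·s_t, the BPTT sensitivity ∂s_T/∂s_0 equals a^T, a product of per-step Jacobians, so it explodes for |a| > 1 and vanishes for |a| < 1 as T grows. The JAX snippet below computes this derivative by differentiating through `jax.lax.scan`; the system and the values of a and T are illustrative assumptions, not one of the paper's tasks.

```python
import jax

def unroll(s0, a, T):
    # Roll out s_{t+1} = a * s_t for T steps and return s_T.
    def body(s, _):
        return a * s, None
    s_T, _ = jax.lax.scan(body, s0, None, length=T)
    return s_T

def sensitivity(a, T):
    # d s_T / d s_0 via reverse-mode differentiation (BPTT); equals a**T.
    return float(jax.grad(lambda s0: unroll(s0, a, T))(1.0))

for T in (10, 500):  # short horizon vs. long episode
    print(T, sensitivity(1.05, T), sensitivity(0.95, T))
# T = 500: roughly 3.9e10 (exploding) and 7.3e-12 (vanishing)
# T = 10:  roughly 1.63 and 0.60
```

In a nonlinear simulator the per-step factors are Jacobian matrices rather than a scalar, but the same product structure applies, which is one way to see why a short unrolling horizon keeps gradients well scaled.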

We first compare our algorithm against five popular model-based and model-free methods and show that it outperforms them. We then investigate the effect of the unrolling horizon length on performance. Finally, a series of experiments with different episode durations demonstrates our algorithm's generality.

The term duration here refers to the total number of timesteps in one episode. All experiments share the same unrolling horizon length, reward function, and neural network size. The same environment with different durations can yield different performance: together with the unrolling horizon, the duration strongly affects the compounded model error, and the longer the duration, the harder it is for a learning algorithm to reach good performance.

We use the cartpole swing-up task to test our algorithm's generality. The following figure shows that our method works well across different durations; as the duration increases, the learning curve shows more ripples.

Duration 50

Duration 100

Duration 200

Duration 500

Performance on High-Dimensional Environments

In our experiments, we find that popular model-free methods such as PPO and SAC struggle considerably on these two environments (quadrotor and rocket). We therefore compare against classic trajectory optimization and control algorithms such as iLQR, Guided Policy Search (GPS), and Pontryagin Differentiable Programming (PDP). As the figure below shows, our method outperforms GPS and PDP: it learns the optimal control policy quickly and converges to a low cost.

Quadrotor Loss


Rocket Loss