Context-aware Dynamics Model
for Generalization in Model-Based Reinforcement Learning

Kimin Lee*, Younggyo Seo*, Seunghyun Lee, Honglak Lee, Jinwoo Shin

*Equal contribution

UC Berkeley, KAIST, University of Michigan, Google Brain

[GitHub Code] [Paper] [Talk]

Abstract

Model-based reinforcement learning (RL) enjoys several benefits, such as data-efficiency and planning, by learning a model of the environment’s dynamics. However, learning a global model that can generalize across different dynamics is a challenging task. To tackle this problem, we decompose the task of learning a global dynamics model into two stages: (a) learning a context latent vector that captures the local dynamics, then (b) predicting the next state conditioned on it. In order to encode dynamics-specific information into the context latent vector, we introduce a novel loss function that encourages the context latent vector to be useful for predicting both forward and backward dynamics. The proposed method achieves superior generalization ability across various simulated robotics and control tasks, compared to existing RL schemes.

(a) CartPole with varying pole lengths

(b) Pendulum with varying pendulum lengths

(c) HalfCheetah with varying body masses

(d) Ant with varying leg masses

Method

In our paper, we show how a context-aware dynamics model improves the generalization performance of model-based RL methods. We combine the context-aware dynamics model (CaDM) with (i) a vanilla dynamics model (Vanilla DM) and (ii) PE-TS to solve various control tasks from OpenAI Gym. Our method can incorporate any dynamics model simply by conditioning it on the output of the context encoder.
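The two-stage decomposition above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the context encoder and forward model are stand-in linear maps (the paper uses MLPs), and all dimensions and the history length K are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions; K is the number of past transitions the encoder sees.
STATE_DIM, ACTION_DIM, CONTEXT_DIM, K = 4, 1, 8, 10

# Linear stand-ins for the paper's neural networks.
W_enc = rng.normal(scale=0.1, size=(K * (STATE_DIM + ACTION_DIM), CONTEXT_DIM))
W_fwd = rng.normal(scale=0.1, size=(STATE_DIM + ACTION_DIM + CONTEXT_DIM, STATE_DIM))

def encode_context(past_states, past_actions):
    """Context encoder: stack the K most recent (state, action) pairs and
    map them to a context latent vector summarizing the local dynamics."""
    history = np.concatenate([past_states, past_actions], axis=-1).reshape(-1)
    return np.tanh(history @ W_enc)

def predict_next_state(state, action, context):
    """Forward dynamics model conditioned on the context latent, so one
    shared model can serve many dynamics by varying only the context."""
    x = np.concatenate([state, action, context])
    return state + x @ W_fwd  # predict a residual on the current state

past_s = rng.normal(size=(K, STATE_DIM))
past_a = rng.normal(size=(K, ACTION_DIM))
z = encode_context(past_s, past_a)
s_next = predict_next_state(past_s[-1], past_a[-1], z)
print(z.shape, s_next.shape)  # (8,) (4,)
```

The key design point is that the dynamics model itself is unchanged; generalization comes entirely from appending the context latent to its input.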

(a) Forward prediction

(b) Backward prediction

(c) Future-step prediction
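The three prediction objectives listed above can be combined into one training loss. The sketch below assumes squared-error terms and a hypothetical weight `beta` on the backward term; the exact losses and weighting in the paper may differ.

```python
import numpy as np

def cadm_losses(pred_next, true_next, pred_prev, true_prev,
                pred_future, true_future, beta=0.5):
    """Combined objective: the context latent is trained to be useful for
    (a) forward prediction of the next state, (b) backward prediction of
    the previous state, and (c) prediction over a longer future horizon.
    All terms are mean-squared errors in this sketch; beta is hypothetical."""
    l_fwd = np.mean((pred_next - true_next) ** 2)
    l_bwd = np.mean((pred_prev - true_prev) ** 2)
    l_fut = np.mean((pred_future - true_future) ** 2)
    return l_fwd + beta * l_bwd + l_fut

# Toy example with a unit error on every term:
total = cadm_losses(np.ones(3), np.zeros(3),
                    np.ones(3), np.zeros(3),
                    np.ones(3), np.zeros(3))
print(total)  # 2.5
```

The backward term is what prevents the encoder from encoding only information that happens to help one-step forward prediction: the context must explain the dynamics in both directions.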

Results

(1) CaDM significantly improves the generalization performance of baseline model-based methods in all environments.

(a) HalfCheetah

(b) Ant

(c) CrippledHalfCheetah

(d) SlimHumanoid

(2) CaDM can also improve the generalization performance of model-free RL methods, by conditioning the policy on the context latent vector learned by CaDM.
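Conditioning a model-free policy on the learned context is a small change to the policy's input. The sketch below uses a single linear layer as a hypothetical stand-in for the policy network; dimensions are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
STATE_DIM, CONTEXT_DIM, ACTION_DIM = 4, 8, 1

# Linear stand-in for a policy network.
W_pi = rng.normal(scale=0.1, size=(STATE_DIM + CONTEXT_DIM, ACTION_DIM))

def context_conditional_policy(state, context):
    """Policy that receives the context latent alongside the state, so it
    can adapt its behavior to the current (possibly unseen) dynamics."""
    return np.tanh(np.concatenate([state, context]) @ W_pi)

action = context_conditional_policy(rng.normal(size=STATE_DIM),
                                    rng.normal(size=CONTEXT_DIM))
print(action.shape)  # (1,)
```

At test time the same pretrained context encoder produces the latent from recent transitions, and the policy consumes it without any further adaptation.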

(a) HalfCheetah

(b) Ant

(c) CrippledHalfCheetah

(d) SlimHumanoid

Prediction Visualization

We visualize future state predictions in test CartPole and Pendulum environments with unseen environment parameters (i.e., force magnitude and pole mass). Given 10 past states and actions, we generate 20 future state predictions from the vanilla dynamics model (Vanilla DM), the stacked dynamics model (Stacked DM), and Vanilla + CaDM (ours). CaDM consistently gives accurate predictions across future timesteps, which shows that it captures contextual information about the transition dynamics.
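The visualization procedure described above is an open-loop rollout: the context is encoded once from the past transitions, and the model's own predictions are fed back in for each future step. A minimal sketch, with a hypothetical linear dynamics model standing in for the learned one:

```python
import numpy as np

rng = np.random.default_rng(2)
STATE_DIM, ACTION_DIM, CONTEXT_DIM, HORIZON = 4, 1, 8, 20

def open_loop_rollout(forward_model, context, init_state, actions):
    """Generate multi-step predictions by repeatedly feeding the model's
    previous prediction back in, with the context latent held fixed."""
    state, preds = init_state, []
    for a in actions:
        state = forward_model(state, a, context)
        preds.append(state)
    return np.stack(preds)

# Hypothetical stand-in for a learned forward model: a fixed linear map.
W = rng.normal(scale=0.1, size=(STATE_DIM + ACTION_DIM + CONTEXT_DIM, STATE_DIM))
f = lambda s, a, z: s + np.concatenate([s, a, z]) @ W

traj = open_loop_rollout(f,
                         rng.normal(size=CONTEXT_DIM),     # context latent
                         rng.normal(size=STATE_DIM),       # last observed state
                         rng.normal(size=(HORIZON, ACTION_DIM)))
print(traj.shape)  # (20, 4)
```

Because errors compound over the 20 steps, this setting exposes whether the context latent actually captures the unseen dynamics parameters.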

BibTeX

@inproceedings{lee2020context,
  title={Context-aware Dynamics Model for Generalization in Model-Based Reinforcement Learning},
  author={Lee, Kimin and Seo, Younggyo and Lee, Seunghyun and Lee, Honglak and Shin, Jinwoo},
  booktitle={International Conference on Machine Learning},
  year={2020}
}