TL;DR

Masked autoencoders (MAE) have emerged as a scalable and effective self-supervised learning technique. Can MAE also be effective for visual model-based RL? Yes! With a recipe of convolutional feature masking and reward prediction to capture fine-grained and task-relevant information.

Abstract

Visual model-based reinforcement learning (RL) has the potential to enable sample-efficient robot learning from visual observations. Yet current approaches typically train a single model end-to-end to learn both visual representations and dynamics, making it difficult to accurately model the interaction between robots and small objects. In this work, we introduce a visual model-based RL framework that decouples visual representation learning and dynamics learning. Specifically, we train an autoencoder with convolutional layers and vision transformers (ViT) to reconstruct pixels given masked convolutional features, and learn a latent dynamics model that operates on the representations from the autoencoder. Moreover, to encode task-relevant information, we introduce an auxiliary reward prediction objective for the autoencoder. We continually update both the autoencoder and the dynamics model using online samples collected from environment interaction. We demonstrate that our decoupled approach achieves state-of-the-art performance on a variety of visual robotic tasks from Meta-world and RLBench; e.g., we achieve an 81.7% success rate on 50 visual robotic manipulation tasks from Meta-world, while the baseline achieves 67.9%.

Masked World Models

We present Masked World Models (MWM), a visual model-based RL algorithm that decouples visual representation learning and dynamics learning. The key idea of MWM is to train an autoencoder that reconstructs visual observations with convolutional feature masking, and a latent dynamics model on top of the autoencoder. By introducing early convolutional layers and masking out convolutional features instead of pixel patches, our approach enables the world model to capture fine-grained visual details from complex visual observations. Moreover, in order to learn task-relevant information that might not be captured solely by the reconstruction objective, we introduce an auxiliary reward prediction task for the autoencoder.
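Below is a minimal PyTorch sketch of such an autoencoder: an early convolutional stem turns the image into a grid of feature tokens, a random subset of those tokens is masked, and ViT encoder/decoder blocks reconstruct pixels and predict the reward. All layer sizes, the masking ratio, and the helper names (`ConvStem`, `MaskedViTAutoencoder`) are illustrative assumptions, not the exact architecture used in the paper.

```python
# Illustrative sketch only; sizes and module names are assumptions, not the paper's exact design.
import torch
import torch.nn as nn


class ConvStem(nn.Module):
    """Early convolutional layers that map 64x64 pixels to an 8x8 grid of feature tokens."""
    def __init__(self, embed_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.ReLU(),    # 64 -> 32
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU(),  # 32 -> 16
            nn.Conv2d(128, embed_dim, 4, stride=2, padding=1),      # 16 -> 8
        )

    def forward(self, x):                       # x: (B, 3, 64, 64)
        feat = self.net(x)                      # (B, D, 8, 8)
        return feat.flatten(2).transpose(1, 2)  # (B, 64, D) feature tokens


class MaskedViTAutoencoder(nn.Module):
    def __init__(self, embed_dim=256, depth=4, num_heads=4, mask_ratio=0.75):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.stem = ConvStem(embed_dim)
        self.pos = nn.Parameter(torch.zeros(1, 64, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True), depth)
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True), depth)
        self.to_pixels = nn.Linear(embed_dim, 8 * 8 * 3)  # each token reconstructs an 8x8 pixel patch
        self.reward_head = nn.Linear(embed_dim, 1)        # auxiliary reward prediction

    def forward(self, images):
        tokens = self.stem(images) + self.pos                      # (B, 64, D)
        B, N, D = tokens.shape
        keep = int(N * (1 - self.mask_ratio))
        order = torch.rand(B, N, device=tokens.device).argsort(1)  # random token shuffle
        keep_idx = order[:, :keep, None].expand(-1, -1, D)
        visible = torch.gather(tokens, 1, keep_idx)                # keep visible conv features only
        latent = self.encoder(visible)
        # Fill masked positions with a learned mask token before decoding.
        full = self.mask_token.expand(B, N, D).clone()
        full.scatter_(1, keep_idx, latent)
        decoded = self.decoder(full + self.pos)
        pixels = self.to_pixels(decoded)                  # (B, 64, 192) per-token pixel targets
        reward = self.reward_head(decoded.mean(1))        # (B, 1) predicted reward
        return pixels, reward
```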

Specifically, we separately update visual representations and dynamics by repeating an iterative process of (i) training the autoencoder with convolutional feature masking and reward prediction, and (ii) learning a latent dynamics model that predicts the visual representations from the autoencoder. We emphasize that our framework is not a pre-training and fine-tuning scheme; instead, it continually updates both the autoencoder and the latent dynamics model using samples collected from environment interaction.
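A rough sketch of this alternating update loop is shown below. It assumes a replay buffer that yields batches of images, actions, and rewards, and a latent dynamics model with its own update routine (an RSSM-style model in practice); `replay`, `dynamics`, `patchify`, and the unweighted sum of losses are placeholders rather than the paper's exact objective.

```python
# Sketch of the alternating training loop; relies on the MaskedViTAutoencoder sketch above.
import torch
import torch.nn.functional as F


def train_mwm(ae, ae_opt, dynamics, replay, patchify, num_steps):
    """Alternate (i) autoencoder updates and (ii) latent dynamics updates on online data."""
    for step in range(num_steps):
        batch = replay.sample()  # expects 'image' (B, 3, 64, 64), 'action', 'reward' (B,)

        # (i) Train the autoencoder with masked reconstruction + auxiliary reward prediction.
        pixels, reward_pred = ae(batch['image'])
        recon_loss = F.mse_loss(pixels, patchify(batch['image']))       # per-token pixel targets
        reward_loss = F.mse_loss(reward_pred.squeeze(-1), batch['reward'])
        ae_loss = recon_loss + reward_loss                              # loss weighting omitted
        ae_opt.zero_grad(); ae_loss.backward(); ae_opt.step()

        # (ii) Train the latent dynamics model on frozen autoencoder representations
        #      (no masking when producing representations for dynamics learning).
        with torch.no_grad():
            z = ae.encoder(ae.stem(batch['image']) + ae.pos)
        dynamics.update(z, batch['action'], batch['reward'])            # placeholder RSSM-style update
```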

Why Conv Masking?

One might wonder why we use convolutional feature masking. The choice comes from our observation that it is sometimes difficult to learn fine-grained details within masked pixel patches. See how object positions within masked patches can be wrong in the MAE reconstructions in the figures below; a shape-level comparison of the two masking schemes follows the figures.

1st row: Ground-truth / 2nd row: Masked images / 3rd row: MAE Reconstructions

1st row: Ground-truth / 2nd row: Masked images / 3rd row: MAE Reconstructions
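To make the distinction concrete, here is a shape-level comparison of the two schemes, reusing the `ConvStem` from the sketch above; the 16x16 patch size is an illustrative choice. One intuition is that dropping a pixel patch removes an entire image region, so small objects inside it must be reconstructed from scratch, whereas masked convolutional features have receptive fields that overlap those of their visible neighbors.

```python
# Shape-level illustration only; patch size and stem come from the sketch above.
import torch

images = torch.randn(1, 3, 64, 64)

# MAE-style pixel patch masking: 16x16 patches -> a 4x4 grid of 16 tokens.
# Dropping one token hides a full 16x16 pixel region of the image.
pixel_patches = images.unfold(2, 16, 16).unfold(3, 16, 16)  # (1, 3, 4, 4, 16, 16)

# MWM-style convolutional feature masking: the conv stem yields an 8x8 grid of
# 64 feature tokens with overlapping receptive fields, so visible tokens still
# carry information about regions under the masked ones.
stem = ConvStem(embed_dim=256)        # illustrative stem from the sketch above
conv_tokens = stem(images)            # (1, 64, 256)
print(pixel_patches.shape, conv_tokens.shape)
```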

Experimental Setups

We consider Meta-world, RLBench, and the DeepMind Control Suite as our experimental benchmarks.

Visual Robotic Manipulation Experiments

We find that MWM significantly outperforms DreamerV2 on various robotic manipulation tasks from Meta-world and RLBench!

DeepMind Control Suite Experiments

We also evaluate MWM on DeepMind Control Suite tasks, where we find that the gain from MWM is clearest on manipulation tasks, in which capturing fine-grained details is important.

Ablation Studies

  • (a) We show that convolutional feature masking is much more effective than pixel patch masking, i.e., MAE.

  • (b) We find that a high masking ratio (75%, but not as high as 90%) is more effective.

  • (c) We find that the auxiliary reward prediction objective makes a significant difference.

Prediction Analysis

We provide randomly sampled visualizations supporting the results of Figure 7 in the original draft. We observe that predictions from our latent dynamics model are better at capturing the positions of the red blocks (i.e., the target the robot arm should reach) than the predictions from DreamerV2. We emphasize that we do not cherry-pick the results, so predictions can sometimes be wrong (e.g., wrong colors or positions).

  • 1st row: Ground-Truth Images

  • 2nd row: Reconstructions (MWM-Ours)

  • 3rd row: Predictions (MWM-Ours)

  • 4th row: Predictions (DreamerV2)

  • 1st row: Ground-Truth Images

  • 2nd row: Reconstructions (MWM-Ours)

  • 3rd row: Predictions (MWM-Ours)

  • 4th row: Predictions (DreamerV2)