M3PC: Test-Time Model Predictive Control for Pretrained Masked Trajectory Model
Kehan Wen¹†, Yutong Hu¹,²†, Yao Mu³*, Lei Ke⁴*
¹ETH Zurich, ²KU Leuven, ³Hong Kong University, ⁴Carnegie Mellon University
† Equal contribution
* Corresponding author
Abstract:
Recent work in Offline Reinforcement Learning (RL) has shown that a unified Transformer trained under a masked auto-encoding objective can effectively capture the relationships between different modalities (e.g., states, actions, rewards) within given trajectory datasets. However, this information has not been fully exploited during the inference phase, where the agent needs to generate an optimal policy instead of just reconstructing masked components from unmasked ones. Given that a pretrained trajectory model can act as both a Policy Model and a World Model with appropriate mask patterns, we propose using Model Predictive Control (MPC) at test time to leverage the model's own predictive capability to guide its action selection. Empirical results on D4RL and RoboMimic show that our inference-phase MPC significantly improves the decision-making performance of a pretrained trajectory model without any additional parameter training. Furthermore, our framework can be adapted to Offline to Online (O2O) RL and Goal Reaching RL, resulting in more substantial performance gains when an additional online interaction budget is provided, and better generalization capabilities when different task targets are specified.
Overview:
Bidirectional Trajectory Model
The bidirectional trajectory model is pretrained with an MAE loss that reconstructs the whole MDP trajectory from a randomly masked one. After pretraining, the model exhibits multiple capabilities simply by applying different test-time masks. E.g., Return-Conditioned Behaviour Cloning [RCBC] Mask: predict actions given states, an expected return, and the context trajectory. Reward and Return Prediction [RP] Mask: predict rewards and the future return given states and actions. Forward Dynamics [FD] Mask: predict future states given the current state and future actions. Inverse Dynamics [ID] Mask: infer the actions needed to realize a given state path.
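To make these mask patterns concrete, here is a minimal sketch (not the authors' released code) of how visibility masks over the four token types per timestep realize the capabilities listed above; the horizon length, token layout, and helper names are illustrative assumptions.

import numpy as np

# Minimal illustration: each timestep carries four token types --
# state (S), action (A), reward (RW), return-to-go (RTG).
# A mask value of 1 means the token is visible to the model;
# 0 means it must be reconstructed.
T = 4
S, A, RW, RTG = 0, 1, 2, 3

def make_mask(visible):
    """Build a (T, 4) visibility mask over [S, A, RW, RTG] tokens."""
    m = np.zeros((T, 4), dtype=int)
    for t, token in visible:
        m[t, token] = 1
    return m

# [RCBC]: states and expected returns visible, actions to be predicted.
rcbc_mask = make_mask([(t, S) for t in range(T)] + [(t, RTG) for t in range(T)])

# [RP]: states and actions visible, rewards and future return predicted.
rp_mask = make_mask([(t, S) for t in range(T)] + [(t, A) for t in range(T)])

# [FD]: current state and future actions visible, future states predicted.
fd_mask = make_mask([(0, S)] + [(t, A) for t in range(T)])

# [ID]: the state path visible, intermediate actions predicted.
id_mask = make_mask([(t, S) for t in range(T)])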
M3PC
M3PC utilizes the pretrained bidirectional trajectory model's versatile inference capabilities to enhance decision making. Forward M3PC, shown in block (a), employs the [RCBC], [FD], and [RP] masks to build an MPC pipeline for planning, prediction, and action resampling. (b) Backward M3PC: given a goal state that we ultimately want to reach, we first use the Path Inference [PI] mask to infer waypoint states, then an Inverse Dynamics [ID] mask to obtain the action sequence conditioned on those waypoints, and finally execute the first action.
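As a rough illustration of the forward M3PC pipeline in (a), the following Python sketch implements the plan / predict / evaluate / resample loop; the model.predict(mask=...) interface and all parameter names are assumptions for exposition, not the released API.

import torch

def forward_m3pc_step(model, context, num_samples=64, horizon=8, temperature=1.0):
    # 1. Plan: sample candidate action sequences with the [RCBC] mask,
    #    conditioned on recent states and a target return.
    actions = model.predict(mask="RCBC", context=context,
                            num_samples=num_samples, horizon=horizon)

    # 2. Predict: roll each candidate forward with the [FD] mask,
    #    using the trajectory model as a world model.
    states = model.predict(mask="FD", context=context, actions=actions)

    # 3. Evaluate: estimate rewards and future returns for each imagined
    #    rollout with the [RP] mask.
    returns = model.predict(mask="RP", context=context,
                            states=states, actions=actions)

    # 4. Resample: weight candidates by predicted return and execute only
    #    the first action of the chosen sequence (receding-horizon control).
    weights = torch.softmax(returns / temperature, dim=0)
    idx = torch.multinomial(weights, num_samples=1).item()
    return actions[idx, 0]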
Highlights:
Offline and Offline-to-Online Results on MuJoCo
Forward M3PC achieves promising results in both offline RL (left) and offline-to-online RL (right) settings. Notably, M3PC achieves 123% larger performance improvements during the online finetuning phase compared to ODT.
Goal Reaching Results on MuJoCo
Backward M3PC drives the agent to follow a predefined trajectory instead of merely imitating the behaviors in the offline dataset.
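For reference, a comparable sketch of the backward M3PC step used for goal reaching follows; again, the model.predict(mask=...) calls are an assumed interface rather than the released implementation.

def backward_m3pc_step(model, current_state, goal_state, horizon=8):
    # 1. Infer waypoints: fill in intermediate states between the current
    #    state and the goal with the Path Inference [PI] mask.
    waypoints = model.predict(mask="PI", start=current_state,
                              goal=goal_state, horizon=horizon)

    # 2. Infer actions: recover the action sequence that realizes this
    #    state path with the Inverse Dynamics [ID] mask.
    actions = model.predict(mask="ID", states=waypoints)

    # 3. Execute only the first action, then replan at the next step.
    return actions[0]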
Results of Manipulation Tasks
Backward M3PC can also be applied to long-horizon tasks (e.g., manipulation). After pretraining the model on the RoboMimic Can-Pair dataset (50% can pick-and-place, 50% throwing the can away), we can not only reproduce the two behaviors seen in the pretraining data (left, middle) but also generate an unseen one (right).
Citation:
@article{M3PC,
title={M^3PC: Test-Time Model Predictive Control for Pretrained Masked Trajectory Model},
author={Kehan Wen and Yutong Hu and Yao Mu and Lei Ke},
journal={arXiv preprint arXiv:2412.05675},
year={2024},
}