MoCoDA: Model-based Counterfactual Data Augmentation

Silviu Pitis, Elliot Creager, Ajay Mandlekar, Animesh Garg

36th Conference on Neural Information Processing Systems (NeurIPS 2022)
(Poster Session 5 on Thursday, December 1)

OpenReview | Code: https://github.com/spitis/mocoda

arXiv: https://arxiv.org/abs/2210.11287


Abstract: The number of states in a dynamic process is exponential in the number of objects, making reinforcement learning (RL) difficult in complex, multi-object domains. For agents to scale to the real world, they will need to react to and reason about unseen combinations of objects. We argue that the ability to recognize and use local factorization in transition dynamics is a key element in unlocking the power of multi-object reasoning. To this end, we show that (1) known local structure in the environment transitions is sufficient for an exponential reduction in the sample complexity of training a dynamics model, and (2) a locally factored dynamics model provably generalizes out-of-distribution to unseen states and actions. Knowing the local structure also allows us to predict which unseen states and actions this dynamics model will generalize to. We propose to leverage these observations in a novel Model-based Counterfactual Data Augmentation (MoCoDA) framework. MoCoDA applies a learned locally factored dynamics model to an augmented distribution of states and actions to generate counterfactual transitions for RL. MoCoDA works with a broader set of local structures than prior work and allows for direct control over the augmented training distribution. We show that MoCoDA enables RL agents to learn policies that generalize to unseen states and actions. We use MoCoDA to train an offline RL agent to solve an out-of-distribution robotics manipulation task on which standard offline RL algorithms fail.


MoCoDA in 3 Steps


Assuming:

    • Empirical dataset

    • Object-oriented state

    • Known local causal transition structure, i.e., which objects interact to cause each object at the next state (a minimal encoding is sketched below)
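For concreteness, here is one minimal way such structure might be written down; this is an illustrative sketch, and the component names are hypothetical rather than taken from the code release.

```python
# Illustrative encoding of a local causal transition structure: each next-state
# component is listed with the state/action components (its parent set) that can
# influence it in the current region of state space. Names are hypothetical.
parent_sets = {
    "gripper": ["gripper", "action"],            # gripper moves under its own action
    "block1":  ["gripper", "block1", "action"],  # block1 can be pushed by the gripper
    "block2":  ["block2"],                       # block2 is untouched in this region
}
```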


Steps:

1. Generate parent distribution: Using our knowledge of which groups of objects collectively "cause" an object at the next state (the "parent sets"), we train a generator on the empirical dataset to produce a parent distribution over states and actions. This parent distribution differs from the joint empirical state-action distribution, but the marginal of each parent set is trained to match its empirical marginal.
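As a simplified illustration of Step 1, assuming the parent sets are disjoint, a distribution with matching parent-set marginals can be obtained by resampling each parent set jointly from the data but independently of the other sets; the actual method trains a learned generator, which also handles overlapping parent sets. The data layout below is an assumption for illustration.

```python
import numpy as np

def sample_parent_distribution(data, parent_sets, n_samples, rng=None):
    """Simplified sketch, assuming *disjoint* parent sets: draw each parent set
    jointly (so its empirical marginal is preserved) but independently of the
    other parent sets, producing novel (s, a) combinations.
    `data` maps component name -> array of shape (N, d_component)."""
    rng = np.random.default_rng() if rng is None else rng
    n = next(iter(data.values())).shape[0]
    out = {}
    for pset in parent_sets:                      # e.g. ("gripper", "block1", "action")
        idx = rng.integers(0, n, size=n_samples)  # one independent draw per parent set
        for comp in pset:
            out[comp] = data[comp][idx]           # components within a set stay coupled
    return out
```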

2. Generate dataset: Using the empirical data, we train a locally factored dynamics model that respects the local causal structure. We apply this dynamics model to the parent distribution from Step 1 to generate an augmented dataset of (s, a, s') tuples. Because each causal mechanism's parent set is well supported in the empirical dataset and the parent distribution matches those marginals, the dynamics model generalizes well on the parent distribution, so the augmented dataset is relatively accurate.
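A minimal sketch of such a locally factored dynamics model, assuming one small PyTorch MLP per mechanism; the class and argument names are illustrative, not the release's actual interface.

```python
import torch
import torch.nn as nn

class LocallyFactoredDynamics(nn.Module):
    """Illustrative locally factored dynamics model: one MLP per next-state
    component, each reading only that component's parent set."""

    def __init__(self, parent_sets, comp_dims, hidden=128):
        # parent_sets: dict, state component name -> list of parent component names
        # comp_dims:   dict, component name -> dimensionality
        super().__init__()
        self.parent_sets = parent_sets
        self.nets = nn.ModuleDict({
            child: nn.Sequential(
                nn.Linear(sum(comp_dims[p] for p in parents), hidden),
                nn.ReLU(),
                nn.Linear(hidden, comp_dims[child]),
            )
            for child, parents in parent_sets.items()
        })

    def forward(self, components):
        # components: dict, name -> tensor of shape (batch, d). Each next-state
        # component is predicted from its own parent set only.
        return {
            child: net(torch.cat([components[p] for p in self.parent_sets[child]], dim=-1))
            for child, net in self.nets.items()
        }
```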

3. Train agent: We label the augmented data with the target task reward and train a reinforcement learning agent on it. This agent generalizes on the parent distribution, whose support may be strictly larger than that of the empirical distribution.
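A minimal sketch of the labeling step, assuming a known reward function and a flat dict dataset layout (both are illustrative assumptions, not the release's actual format):

```python
import numpy as np

def label_augmented_data(aug_s, aug_a, aug_next_s, reward_fn):
    """Attach the target task reward to each augmented (s, a, s') tuple to form
    an offline RL dataset. `reward_fn` maps (s, a, s') to a scalar reward."""
    return {
        "obs":      aug_s,
        "act":      aug_a,
        "next_obs": aug_next_s,
        "rew":      np.array([reward_fn(s, a, ns) for s, a, ns in zip(aug_s, aug_a, aug_next_s)]),
        "done":     np.zeros(len(aug_s), dtype=bool),  # one-step transitions, no terminals here
    }
```

The labeled dataset can then be handed to an off-the-shelf offline RL learner (the experiments below use TD3-BC), optionally mixed with the original transitions.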

Hook Sweep Visualization


Empirical Data:

The agent sees only trajectories from a noisy expert that pushes exactly one block (never both) to one side of the table.

Video: sweep1_expert_with_bonus_reward.mp4


Trained MoCoDA Agent (TD3-BC with Prioritized MoCoDA):

With Prioritized MoCoDA (which rebalances the MoCoDA distribution to be uniform in goal space), the agent learns to push both blocks to one side of the table. With unprioritized MoCoDA, the agent achieves approximately 40% success. Without MoCoDA, the agent learns nothing useful (<5% success).
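One simple way to realize this kind of rebalancing is inverse-frequency resampling over goal-space bins; the sketch below assumes a 1-D goal coordinate and a hypothetical `goal_of` helper for illustration.

```python
import numpy as np

def rebalance_uniform_in_goal_space(dataset, goal_of, n_bins=20, rng=None):
    """Illustrative resampling: up-weight transitions in rare goal-space bins so
    the resulting distribution is roughly uniform in goal space.
    `goal_of` maps a next state to its scalar goal coordinate (an assumption)."""
    rng = np.random.default_rng() if rng is None else rng
    goals = np.array([goal_of(ns) for ns in dataset["next_obs"]])
    edges = np.linspace(goals.min(), goals.max(), n_bins + 1)
    bin_ids = np.clip(np.digitize(goals, edges) - 1, 0, n_bins - 1)
    counts = np.bincount(bin_ids, minlength=n_bins).astype(float)
    weights = 1.0 / counts[bin_ids]          # rare goal regions get up-weighted
    weights /= weights.sum()
    idx = rng.choice(len(goals), size=len(goals), p=weights, replace=True)
    return {k: v[idx] for k, v in dataset.items()}
```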


Video: td3_bc_mocoda.mp4