BECAUSE: Bilinear Causal Representation for

Generalizable Offline Model-based RL

NeurIPS 2024 Poster

Abstract

Offline model-based reinforcement learning (MBRL) enhances data efficiency by utilizing pre-collected datasets to learn models and policies, especially in scenarios where exploration is costly or infeasible. Nevertheless, its performance often suffers from the objective mismatch between model and policy learning, resulting in inferior performance despite accurate model predictions. This paper first identifies that the primary source of this mismatch is the underlying confounders present in offline data for MBRL. Subsequently, we introduce Bilinear Causal Representation (BECAUSE), an algorithm that captures causal representations for both states and actions to reduce the influence of distribution shift, thus mitigating the objective mismatch problem. Comprehensive evaluations on 18 tasks that vary in data quality and environment context demonstrate the superior performance of BECAUSE over existing offline RL algorithms. We show the generalizability and robustness of BECAUSE with fewer samples or larger numbers of confounders. Additionally, we offer a theoretical analysis of BECAUSE, proving its error bound and sample efficiency when integrating causal representation into offline MBRL.

Motivation: Objective Mismatch

Model Errors vs. Planning Reward

Small errors in dynamics may still result in task failure!

What is BECAUSE?

In the presence of hidden confounders $u$, we model the confounder behind the transition dynamics as a linear confounded MDP, and we model the policy confounder, induced by sub-optimal behavior policies, as acting between the current state and action. Under the resulting Action-State Confounded MDP (ASC-MDP) formulation, the future state $s'$ is conditionally independent of the policy confounder given the dynamics confounder and the current state-action pair $(s, a)$. We parameterize the transition function with a bilinear structure, using two feature encoders $\mu, \phi$ and a core matrix $M$.
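Concretely, a minimal sketch of this bilinear factorization (the exact conditioning follows the paper's ASC-MDP definition) is

$$T(s' \mid s, a, u_c) \;=\; \phi(s, a)^\top M(u_c)\, \mu(s'),$$

where $\phi(s, a)$ embeds the current state-action pair, $\mu(s')$ embeds the next state, and the core matrix $M(u_c)$ absorbs the dynamics confounder.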

Formulation: ASC-MDP

Structured Confounders: (i) Policy Confounders; (ii) Dynamics Confounders
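A hedged sketch of the corresponding data-generating process (the behavior-policy symbol $\pi_\beta$ is our notation, not necessarily the paper's):

$$a \;\sim\; \pi_\beta(a \mid s, u_\pi), \qquad s' \;\sim\; T(s' \mid s, a, u_c),$$

so the policy confounder $u_\pi$ biases which actions appear in the offline data, the dynamics confounder $u_c$ shifts the transition kernel, and $s'$ is conditionally independent of $u_\pi$ given $(s, a, u_c)$.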

How do we learn BECAUSE?

We first learn the causal world model $T(s'|s, a)$ in the presence of the confounders $u$ in the offline datasets. As formulated in the ASC-MDP above, there are two sets of confounders: $u_\pi$ and $u_c$. To estimate an unconfounded transition model, we first remove the impact of $u_c$, which arises from the dynamics shift, by estimating a batch-wise transition matrix $M(u_c)$; we then apply a reweighting formula to deconfound $u_\pi$, which is induced by the behavior policies, thereby mitigating the model objective mismatch.
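One hedged way to write the resulting model-learning objective (the exact form of the reweighting term is given in the paper and not reproduced here) is a weighted fit of the bilinear model within each batch that shares a dynamics confounder:

$$\max_{\phi,\, \mu,\, M(u_c)} \;\; \sum_{(s,\, a,\, s') \in \mathcal{D}_{u_c}} w(s, a)\, \log\!\big(\phi(s, a)^\top M(u_c)\, \mu(s')\big),$$

where $\mathcal{D}_{u_c}$ denotes transitions collected under the same dynamics confounder and $w(s, a)$ is the reweighting term that corrects for the policy confounder $u_\pi$.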

Learning Process of BECAUSE

Alternating optimization (a minimal code sketch follows this list):

(1) Estimating the causal mask $M$ with sparsity regularization

(2) Optimizing feature encoders $\phi, \mu$

(3) Mask reweighting
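A minimal PyTorch sketch of this alternating loop, assuming an in-batch contrastive surrogate for the transition likelihood and illustrative names throughout (none of these are the paper's exact implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BilinearWorldModel(nn.Module):
    def __init__(self, state_dim, action_dim, feat_dim):
        super().__init__()
        # phi embeds the current state-action pair; mu embeds the next state
        self.phi = nn.Sequential(nn.Linear(state_dim + action_dim, feat_dim),
                                 nn.ReLU(), nn.Linear(feat_dim, feat_dim))
        self.mu = nn.Sequential(nn.Linear(state_dim, feat_dim),
                                nn.ReLU(), nn.Linear(feat_dim, feat_dim))
        self.M = nn.Parameter(torch.eye(feat_dim))  # core / causal matrix

def weighted_nce_loss(model, s, a, s_next, w):
    """In-batch contrastive surrogate for the transition likelihood: each
    (s, a) should score its own s' above other next states in the batch;
    w carries the deconfounding weights for the policy confounder."""
    f = model.phi(torch.cat([s, a], dim=-1))   # (B, d)
    g = model.mu(s_next)                       # (B, d)
    logits = f @ model.M @ g.t()               # (B, B) bilinear scores
    targets = torch.arange(s.size(0), device=s.device)
    return (w * F.cross_entropy(logits, targets, reduction="none")).mean()

def alternating_step(model, batch, opt_M, opt_enc, sparsity_coef=1e-3):
    s, a, s_next, w = batch
    # (1) update the causal mask / core matrix M with an L1 sparsity penalty
    opt_M.zero_grad()
    loss_M = weighted_nce_loss(model, s, a, s_next, w) + sparsity_coef * model.M.abs().sum()
    loss_M.backward()
    opt_M.step()
    # (2) update the feature encoders phi and mu with M held fixed
    opt_enc.zero_grad()
    loss_enc = weighted_nce_loss(model, s, a, s_next, w)
    loss_enc.backward()
    opt_enc.step()
    # (3) mask reweighting: w is refreshed from the updated model between
    #     epochs using the paper's reweighting formula (not reproduced here)
    return loss_M.item(), loss_enc.item()
```

Here `opt_M` would optimize only `model.M` (for instance `torch.optim.Adam([model.M])`) and `opt_enc` only the two encoders, so each phase updates one parameter group while the other stays fixed.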

How do we use BECAUSE?

To avoid entering OOD states during online deployment, we further design a pessimistic planner that accounts for the uncertainty of the predicted trajectories in the imagination rollout step, mitigating objective mismatch. We use the feature embedding from the bilinear causal representation to help quantify this uncertainty, denoted $E_\theta(s,a)$. Since we have access to the offline dataset, we learn an Energy-based Model (EBM) to quantify the causal uncertainty of the state-action pairs along the planning trajectories and take it into account during the online planning phase. At deployment, we plan a path to the goal that is both rewarding and causally confident.
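A hedged sketch of how such a pessimistic planning score could look (the trajectory container, the energy-model interface, and the penalty weight `lam` are illustrative assumptions):

```python
# Minimal sketch: score imagined rollouts by return minus an energy-based
# uncertainty penalty, then pick the best-scoring trajectory.
from collections import namedtuple
import torch

Step = namedtuple("Step", ["state", "action", "reward"])  # one imagined transition

def pessimistic_score(trajectory, energy_model, lam=1.0):
    """Predicted return minus an uncertainty penalty from E_theta(s, a):
    higher energy means lower causal confidence, hence a larger penalty.
    Assumes rewards and energies are scalar tensors."""
    rewards = torch.stack([step.reward for step in trajectory])
    energies = torch.stack([energy_model(step.state, step.action) for step in trajectory])
    return rewards.sum() - lam * energies.sum()

def select_plan(candidate_trajectories, energy_model, lam=1.0):
    # choose the rewarding and most causally confident path among the rollouts
    scores = torch.stack([pessimistic_score(t, energy_model, lam)
                          for t in candidate_trajectories])
    return candidate_trajectories[int(torch.argmax(scores))]
```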

Theoretical Analysis

Key differences from prior bounds: (1) bilinear MDP in the offline setting; (2) no closed-form solution under sparsity regularization.

Experiments

Lift: An object manipulation environment in RoboSuite. The agent must lift an object with a specific color configuration on the table to a desired height. In the OOD setting, a spurious correlation between the color of the cube and its position is injected during the training phase; during the testing phase, the correlation between color and position differs from training.

Unlock: A Minigrid environment in which the agent must collect a key to open doors. In the OOD environment Unlock-O, the testing environments contain a different number of goals (doors to be opened) than the training environments.

Crash: Safety is critical in autonomous driving, as reflected by collision-avoidance capability. We consider a risky scenario in which an AV collides with a jaywalker because its view is blocked by another car. We design such a crash scenario based on highway-env, where the goal is to create crashes between a pedestrian and AVs. In the OOD setting, the reward distribution (the number of pedestrians) differs in the online testing environments.

Key Findings

Quantitative Comparison in the Unlock Environment

(Figure: comparison of BECAUSE and MOPO under IID and OOD settings.)

Alleviation of Objective Mismatch