We tackle the Multi-task Batch Reinforcement Learning problem. Given multiple datasets collected from different tasks, we train a multi-task policy to perform well in unseen tasks sampled from the same distribution. The task identities of the unseen tasks are not provided. To perform well, the policy must infer the task identity from collected transitions by modelling its dependency on states, actions and rewards. Because the state-action distributions of the different datasets may diverge significantly, the task inference module can learn to ignore the rewards and spuriously correlate the task identity with the state-action pairs alone, leading to poor test-time performance. To robustify task inference, we propose a novel application of the triplet loss. To mine hard negative examples, we relabel the transitions from the training tasks by approximating their reward functions. When we allow further training on the unseen tasks, using the trained policy as an initialization leads to significantly faster convergence compared to randomly initialized policies (up to $80\%$ improvement across 5 different MuJoCo task distributions). We name our method MBML (Multi-task Batch RL with Metric Learning).
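For concreteness, one way to instantiate this metric-learning objective (the notation $z$, $d$ and margin $m$ is ours and serves only as a sketch of the construction described above) is a triplet loss on the task embeddings produced by the inference module:
$$
\mathcal{L}_{\text{triplet}}^{(i,j)} = \max\Big(0,\; d\big(z_{i \leftarrow j},\, z_i\big) - d\big(z_{i \leftarrow j},\, z_j\big) + m\Big),
$$
where $z_i$ is the embedding of transitions from training task $i$, $z_{i \leftarrow j}$ is the embedding of transitions from task $j$ whose rewards have been relabelled with an approximation of task $i$'s reward function, $d$ is a distance on the embedding space and $m$ is a margin. Since $z_{i \leftarrow j}$ and $z_j$ share the same state-action pairs and differ only in their rewards, minimizing such a loss pressures the inference module to attend to the rewards rather than to the state-action distribution.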
We measure performance by the average return over unseen tasks sampled from the same task distribution. We compare our model against two natural baselines. The first modifies PEARL to train from the batch instead of allowing it to collect more transitions. The PEARL results demonstrate that conventional algorithms requiring environment interaction during training do not perform well in the Multi-task Batch RL setting, which motivates our work. The second baseline modifies BCQ to incorporate a task inference module. Comparison against this baseline shows that the problem we face cannot be solved by simply combining a current Batch RL algorithm with a simple task inference module.
While the multi-task policy generalizes to unseen tasks, its performance is not optimal. If we allow further training, initializing the networks with our multi-task policy significantly speeds up convergence to optimal performance.