Modularity and Attention as Key Ingredients for Generalization
Anirudh Goyal, Alex Lamb
Deep neural networks have achieved excellent results on perceptual tasks, yet they still struggle to match the adaptivity and generalization of humans. For example, humans can easily recognize familiar objects placed against a new background, while neural networks often cannot. Likewise, neural networks can be fooled by small background or local perturbations to which humans are robust. A central cause of this problem is that neural networks generally represent their world with a single large hidden state (or recurrent state), which entangles many different aspects of the world and may therefore fail to generalize when some of these aspects change.
A key ingredient for addressing this challenge is building deep networks that process information dynamically and selectively, which makes the model more robust to changes in the data distribution. Our tutorial will focus on the critical ideas developed in this space and discuss how they improve generalization. We explore advances in attention and modularity that allow models to keep information well separated. We give an overview of ways of introducing sparsity and of its connection to causality. The tutorial is also concrete: we will discuss real-world tasks that suffer from these problems and how they can be addressed with
Session A (1h35m)
Inductive biases for higher-level cognition (15 slides, 20m)
Role of causality (9 slides, 15m)
Independent Mechanisms and Modularity (17 slides, 25m)
Recurrent Independent Mechanisms (26 slides, 35m)
Session B (1h35m)
RIMs and Meta-Learning (8 slides, 15m)
Schemata and Object Files (13 slides, 20m)
Fast and Slow thinking (5 slides, 5m)
Top-Down and Bottom-Up (21 slides, 25m)
Credit Assignment and Time (5 slides, 10m)
New Objective Functions (5 slides, 10m)
Practical Considerations (5 slides, 10m)
Guide: Setting up RIMs in a GRU/LSTM Codebase
GitHub: one-file RIMs model.
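Using RIMs instead of an LSTM or GRU can be as simple as adding the new class and changing
model = nn.LSTM(600, 300)            # standard LSTM: 600 input features, 300 hidden units
to
model = RIM_LSTM(600, 300, 6, 4)     # same sizes, with 6 RIMs of which 4 are active per step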
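Why the swap leaves the rest of the model untouched can be sketched as follows. This assumes (our reading, not confirmed by the one-file model's documentation) that the third and fourth arguments are the number of RIMs and the number active per step, and that the 300 hidden units are divided evenly among the RIMs. The NumPy sketch only illustrates the state layout; `split_into_rims` and `merge_rims` are hypothetical helpers, not part of any RIMs codebase.

```python
import numpy as np

def split_into_rims(h, num_rims):
    """Reshape a flat hidden state (batch, hidden) into (batch, num_rims, hidden // num_rims)."""
    batch, hidden = h.shape
    # In this sketch each RIM owns an equal, contiguous slice of the hidden vector.
    assert hidden % num_rims == 0, "hidden size must be divisible by the number of RIMs"
    return h.reshape(batch, num_rims, hidden // num_rims)

def merge_rims(blocks):
    """Inverse of split_into_rims: (batch, num_rims, block) -> (batch, num_rims * block)."""
    batch, num_rims, block = blocks.shape
    return blocks.reshape(batch, num_rims * block)

h = np.zeros((32, 300))          # same shape an LSTM with hidden_size=300 carries per layer
blocks = split_into_rims(h, 6)   # six RIMs with 50 units each
print(blocks.shape)              # (32, 6, 50)
print(merge_rims(blocks).shape)  # (32, 300)
```

Because the flat state seen by the surrounding code keeps its LSTM shape, the modular structure stays internal to the cell.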
In general, using about 6 RIMs with 4 or 5 of them active per step is a fairly safe choice.
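The k-of-n activation can be sketched in NumPy: each RIM queries both the current input and a null (all-zero) input, and only the k RIMs that put the least attention on the null input are updated, while the rest carry their state forward unchanged. This is a simplified sketch of the RIMs input-attention step under stated assumptions, not the official implementation: a single tanh layer stands in for each RIM's LSTM/GRU, communication attention between RIMs is omitted, and the weight names and sizes (`Wq`, `Wk`, `d_att`, and so on) are illustrative.

```python
import numpy as np

def rims_step(x, h, params, k):
    """One simplified RIMs step for a single example.

    x: (input_dim,) current input
    h: (num_rims, d) hidden state of each RIM
    Returns the updated hidden states and the sorted indices of the k active RIMs.
    """
    # Keys and values are computed from a null (all-zero) input and the real input.
    inputs = np.stack([np.zeros_like(x), x])                  # (2, input_dim)
    keys = inputs @ params["Wk"]                              # (2, d_att)
    values = inputs @ params["Wv"]                            # (2, d_val)
    # Each RIM computes its own query from its own hidden state.
    queries = np.einsum("nd,nda->na", h, params["Wq"])        # (num_rims, d_att)
    scores = queries @ keys.T / np.sqrt(keys.shape[1])        # (num_rims, 2)
    scores -= scores.max(axis=1, keepdims=True)               # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=1, keepdims=True)                   # softmax over {null, input}
    # Activate the k RIMs that put the least attention on the null input.
    active = np.sort(np.argsort(attn[:, 0])[:k])
    reads = attn @ values                                     # (num_rims, d_val)
    new_h = h.copy()
    # Simplified update (a tanh layer stands in for the per-RIM LSTM/GRU);
    # inactive RIMs keep their previous state unchanged.
    for i in active:
        new_h[i] = np.tanh(h[i] @ params["Wh"][i] + reads[i] @ params["Wx"][i])
    return new_h, active

# Illustrative sizes: 6 RIMs of 50 units each (300 total), 600-d input, 4 active per step.
rng = np.random.default_rng(0)
num_rims, d, input_dim, d_att, d_val = 6, 50, 600, 32, 32
params = {
    "Wq": rng.normal(scale=0.1, size=(num_rims, d, d_att)),  # per-RIM query weights
    "Wk": rng.normal(scale=0.1, size=(input_dim, d_att)),    # key weights (shared in this sketch)
    "Wv": rng.normal(scale=0.1, size=(input_dim, d_val)),    # value weights (shared in this sketch)
    "Wh": rng.normal(scale=0.1, size=(num_rims, d, d)),      # per-RIM recurrent weights
    "Wx": rng.normal(scale=0.1, size=(num_rims, d_val, d)),  # per-RIM input weights
}
h = rng.normal(size=(num_rims, d))
x = rng.normal(size=input_dim)
new_h, active = rims_step(x, h, params, k=4)
print("active RIMs:", active)
```

Setting k equal to the number of RIMs removes the sparsity entirely, so every RIM updates at every step.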
Several GitHub repositories implement RIMs:
(1) The author's official GitHub repository:
Within it, the following interface matches the GRU interface exactly:
(2) Code reproducing RIMs in PyTorch.
This code is set up for the MiniGrid generalization tasks.
(3) TensorFlow code for RIMs.
Future work: we are not aware of a JAX implementation of RIMs; if you find one (or would like to write one), that would be extremely useful to us!