Modularity and Attention as Key Ingredients for Generalization

Anirudh Goyal, Alex Lamb

Description

Deep neural networks have achieved excellent results on perceptual tasks, yet they still struggle to match the adaptivity and generalization performance of humans. For example, humans can easily recognize a familiar object placed against a new background, while neural networks often fail to; likewise, neural networks can be fooled by small background or local perturbations to which humans are robust. A central cause of this problem is that neural networks typically represent the world with a single large hidden state (or recurrent state), which entangles many different aspects of the world and thus may fail to generalize when some of these aspects change.


A key ingredient for addressing this challenge is building deep networks that process information in a dynamic and selective way, allowing the model to remain robust to changes in the data distribution. Our tutorial focuses on the critical ideas that have been developed in this space and discusses how they improve generalization. We explore advances in attention and modularity that allow models to keep information well-separated, and we overview ways of introducing sparsity and its connection to causality. The tutorial is also concrete: we discuss real-world tasks that suffer from these problems and show how they can be addressed with current techniques.
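To make this concrete, below is a minimal, illustrative PyTorch sketch of a recurrent step with a modular hidden state: rather than one large vector, the state is split across several small GRU modules, and a simple attention score selects the top-k modules to update at each step while the others keep their previous state. This is our own simplification for building intuition (the class and parameter names here are made up), not the actual RIMs architecture covered in the tutorial.

import torch
import torch.nn as nn

class ModularStep(nn.Module):
    """Toy modular recurrent step: only the k most relevant modules update."""
    def __init__(self, input_size, hidden_size, num_modules=6, k_active=4):
        super().__init__()
        self.k_active = k_active
        self.cells = nn.ModuleList(
            [nn.GRUCell(input_size, hidden_size) for _ in range(num_modules)])
        self.query = nn.Linear(hidden_size, hidden_size)  # query from each module's state
        self.key = nn.Linear(input_size, hidden_size)     # key from the current input

    def forward(self, x, hs):
        # x: (batch, input_size); hs: (batch, num_modules, hidden_size)
        scores = torch.einsum('bmh,bh->bm', self.query(hs), self.key(x))
        top = scores.topk(self.k_active, dim=1).indices   # modules most relevant to this input
        mask = torch.zeros_like(scores).scatter_(1, top, 1.0).unsqueeze(-1)
        new_hs = torch.stack(
            [cell(x, hs[:, m]) for m, cell in enumerate(self.cells)], dim=1)
        # Selected modules take their new state; the rest are left untouched.
        return mask * new_hs + (1.0 - mask) * hs

step = ModularStep(input_size=600, hidden_size=300)
h = torch.zeros(32, 6, 300)                               # a separate state per module
h = step(torch.randn(32, 600), h)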

Schedule

Session A (1h35m)

  • Inductive biases for higher-level cognition (15 slides, 20m)

  • Role of causality (9 slides, 15m)

  • Independent Mechanisms and Modularity (17 slides, 25m)

  • Recurrent Independent Mechanisms (26 slides, 35m)

Session B (1h35m)

  • RIMs and Meta-Learning (8 slides, 15m)

  • Schemata and Object Files (13 slides, 20m)

  • Fast and Slow Thinking (5 slides, 5m)

  • Top-Down and Bottom-Up (21 slides, 25m)

  • Credit Assignment and Time (5 slides, 10m)

  • New Objective Functions (5 slides, 10m)

  • Practical Considerations (5 slides, 10m)

Tutorial Lectures

Slides covering the tutorial content (may become slightly out of date over time).

https://drive.google.com/file/d/1Gl-s5_lmBGnyCkwSIm3usj1Rlwqc_yvm/view?usp=sharing

Guide: Setting up RIMs in a GRU/LSTM Codebase

GitHub: a one-file RIMs model.

Using RIMs instead of an LSTM or GRU can be as simple as adding the new class and changing:

model = nn.LSTM(600, 300)

to:

model = RIM_LSTM(600, 300, 6, 4)

In general, using ~6 RIMs with 4-5 of them active per step is a safe choice.
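For reference, here is a hedged sketch of what the swap looks like in a script, assuming RIM_LSTM mirrors nn.LSTM's calling convention of (seq_len, batch, features) inputs and an (output, hidden) return value; check this against the class you actually copy in, since the exact signature may differ. The extra arguments 6 and 4 are the number of RIMs and the number active per step.

import torch
import torch.nn as nn

x = torch.randn(50, 32, 600)        # (seq_len, batch, input_size)

model = nn.LSTM(600, 300)           # baseline recurrent model
out, hidden = model(x)              # out: (50, 32, 300)

# Drop-in replacement (assumed argument order: input size, hidden size,
# number of RIMs, RIMs active per step):
# model = RIM_LSTM(600, 300, 6, 4)
# out, hidden = model(x)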


There are a few GitHub repos for RIMs:

(1) The authors' official GitHub repo:

https://github.com/anirudh9119/RIMs

Within that repo, the following wrapper exposes the same interface as a GRU:

https://github.com/anirudh9119/RIMs/blob/master/event_based/block_wrapper.py

(2) Code reproducing RIMs in PyTorch.

This code is set up for the MiniGrid generalization tasks.

https://github.com/dido1998/Recurrent-Independent-Mechanisms

(3) TensorFlow code for RIMs:

https://github.com/fuyuan-li/tensorflow-RIMs


Future work: we're not aware of a JAX implementation of RIMs, but if you can find one (or want to produce one), that would be extremely useful for us!