Inducing, Detecting and Characterising Neural Modules:
A Pipeline for Functionally Interpretable Reinforcement Learning
Abstract
Interpretability is crucial for ensuring RL systems align with human values, but remains extremely challenging to achieve in complex decision-making domains. Existing methods frequently attempt interpretability at the level of fundamental model units, such as neurons or decision nodes: an approach which scales poorly to large models.
We instead propose to approach interpretability at the level of functional modularity. We demonstrate how training modifications can encourage the emergence of modularity in RL policy networks, and develop a neural-network-specific method for detecting these modules.
Applying these methods to 2D and 3D Minigrid environments, we consistently find navigational modules which act along different axes. To validate this functionality, we directly intervene on module weights and analyse the resulting behavioural changes.
By interpreting RL decision-making at the level of functional modules rather than fundamental model units, we offer a promising level of abstraction for model understanding which balances explanation accuracy and tractability for a human interpreter.
Inducing Modularity
We induce modularity by encouraging locality and sparsity in network connections, introducing a connection cost loss (CC loss) during training. Each neuron in an MLP policy network is assigned a 2D coordinate, and the CC loss is a log-based sparsity penalty on weight magnitudes, scaled by the 'distance' between the connected neurons. Neurons are also relocated within their layers during training by swapping positions so as to minimise the total CC of the network.
As shown above for three Minigrid environments, structurally independent modules emerge within MLP-based policy networks as the scaling λ of the CC loss is increased.
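To make the CC term concrete, below is a minimal sketch of how it could be computed for a single layer. The specific log penalty (here log(1 + |w|)) and the use of Euclidean distance are assumptions rather than the exact form used in the paper, `lambda_cc` stands in for the scaling λ above, and the neuron-swapping step is omitted.

```python
import numpy as np

def connection_cost(W, coords_in, coords_out):
    """Layer-wise connection cost (CC): a log-based sparsity penalty on
    weight magnitudes, scaled by the distance between the 2D coordinates
    of the connected neurons.

    W          : (n_out, n_in) weight matrix of the layer
    coords_in  : (n_in, 2) 2D positions assigned to the input neurons
    coords_out : (n_out, 2) 2D positions assigned to the output neurons
    """
    # Pairwise Euclidean distance between every output and input neuron.
    dists = np.linalg.norm(coords_out[:, None, :] - coords_in[None, :, :], axis=-1)
    # log(1 + |w|) sparsity term, scaled by distance (exact form assumed).
    return np.sum(dists * np.log1p(np.abs(W)))

# The training objective then becomes (sketch):
#   loss = task_loss + lambda_cc * sum(connection_cost(W_l, coords_l, coords_l1)
#                                      for each layer l)
# where lambda_cc is the CC scaling coefficient swept in the figure above.
```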
Detecting Modules
To automatically detect these emergent modules, we develop an extended version of the Louvain community detection algorithm which accounts for the constrained architecture of neural networks and for our interpretability objectives. We optimise modules for isolation and for correlation alignment: the extent to which the activations of neurons within a specific module are correlated. The resulting modules are shown below for three Minigrid environments: dynamic obstacles (left), where the agent must avoid three moving balls; go to key (centre), where the agent must navigate to the correct target key; and 3D dynamic obstacles (right), which extends dynamic obstacles to a 3D grid.
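As a rough illustration of this step, the sketch below builds a weighted neuron graph from the MLP weights and scores a partition by its correlation alignment. It uses the standard Louvain implementation from networkx as a stand-in for our extended variant, and the activation-recording interface (`activations`) is assumed.

```python
import networkx as nx
import numpy as np

def neuron_graph(layer_weights):
    """Undirected neuron graph for community detection.

    Nodes are (layer, index) pairs and edge weights are |w|, so strongly
    connected neurons tend to fall into the same community.
    """
    G = nx.Graph()
    for l, W in enumerate(layer_weights):          # W maps layer l -> layer l+1
        n_out, n_in = W.shape
        for i in range(n_out):
            for j in range(n_in):
                if W[i, j] != 0.0:
                    G.add_edge((l, j), (l + 1, i), weight=float(abs(W[i, j])))
    return G

def correlation_alignment(modules, activations):
    """Mean within-module activation correlation ('correlation alignment').

    `activations` maps each node to its activation trace over a batch of
    observations (interface assumed)."""
    scores = []
    for module in modules:
        nodes = list(module)
        if len(nodes) < 2:
            continue
        corr = np.corrcoef(np.stack([activations[n] for n in nodes]))
        n = len(nodes)
        scores.append((corr.sum() - n) / (n * (n - 1)))  # mean off-diagonal correlation
    return float(np.mean(scores))

# Baseline partition via standard Louvain; the extended variant described
# above additionally respects the layered MLP structure and this score.
# modules = nx.community.louvain_communities(neuron_graph(weights), weight="weight", seed=0)
```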
Characterising Functionality
The detected modules appear to separately control navigation along different axes. However, interpreting neural networks from visualisation alone risks subjectivity, so we propose to empirically evaluate module functionality by directly modifying module parameters prior to inference. We find that 'disabling' a module, by setting all of its weights to a large negative number, significantly reduces the frequency of actions along the corresponding axes, while not increasing the failure rate along the remaining axes. This evidences a high level of functional independence between decision-making along different axes.
The right animation shows behaviour when community 1 (the left/right module) is disabled. The agent retains the ability to avoid the obstacles by moving up and down, but is unable to move left and right. Tap on a graph node (left) to trace its connectivity through the network, or on the graph background to reset.
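A minimal sketch of such an intervention is shown below, reusing the (layer, index) node representation from the detection sketch above. Which weights are overwritten and the exact negative constant are assumptions for illustration.

```python
import numpy as np

def disable_module(layer_weights, module, value=-100.0):
    """'Disable' a detected module by setting every weight incident to its
    neurons to a large negative value before inference.

    layer_weights : list of (n_out, n_in) matrices, W_l mapping layer l -> l+1
    module        : set of (layer, index) nodes from the detection step
    value         : the large negative constant (exact value assumed)
    """
    ablated = [W.copy() for W in layer_weights]
    for (l, i) in module:
        if l < len(ablated):        # outgoing weights of neuron i in layer l
            ablated[l][:, i] = value
        if l > 0:                   # incoming weights of neuron i
            ablated[l - 1][i, :] = value
    return ablated

# Roll out the policy with the ablated weights and compare per-axis action
# frequencies (e.g. left/right vs. up/down) against the unmodified network.
```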
Application to Pong
We also present initial results from applying the framework to Pong. We train PPO agents on the Gymnax implementation, adapted to return a symbolic observation of paddle coordinates, ball coordinates and ball velocity. We find that the simplicity of the task results in a single module, which learns to ignore the 'stay' action, the ball velocities and the player's x coordinate. Interestingly, the agent still relies on the opponent paddle's x and y positions (our trained agent is P1 and the opponent, which implements a simple 'follow ball' policy, is P2).
We note that these are initial results. We have not performed hyperparameter tuning for performance (average scores range from 3 to 6) and have not implemented the pruning and finetuning stages described in the paper. Training has been run on a single seed.
Immediate future work will complete the implementation and quantify the performance changes observed over a larger number of experiments. We will also explore the impact of ablating input connections, particularly the opponent positions (P2), to evaluate their significance.
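One way such an input ablation could be implemented is sketched below: zero the first-layer weights that read from the relevant observation features and re-evaluate the agent. The feature indices shown are hypothetical and depend on the actual observation layout.

```python
import numpy as np

def ablate_input_features(layer_weights, feature_idxs):
    """Ablate input connections by zeroing the first-layer weights that read
    from the given observation features.

    feature_idxs : positions of the features in the symbolic observation
                   vector (e.g. the P2 paddle coordinates); the indices used
                   here are hypothetical.
    """
    ablated = [W.copy() for W in layer_weights]
    ablated[0][:, list(feature_idxs)] = 0.0
    return ablated

# Hypothetical usage, if the P2 x/y coordinates occupied positions 2 and 3:
# ablated = ablate_input_features(layer_weights, [2, 3])
```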
Conclusions
By addressing interpretability at a higher level of abstraction, we present an approach to RL interpretability which can remain tractable when applied to complex model architectures and decision-making domains. To enable this, we demonstrate how modules can be induced, detected and characterised in an automated, and thus scalable, manner. Future work should explore the broad potential applicability of this approach to different model architectures, particularly in complex, real-world domains.