How to run the code:
Name this file check_equivariant.py and put it in the source folder of the supplementary material code.
Run "python check_equivariant.py"
"""This example creates an NFN to process the weight space of a MLP and output another weight space of MLP,
then verifies that the NFN is equivariant under the group action g (a random permutation and scaling of the hidden neurons).
"""
import os
from tqdm import tqdm
import random
import torch
from einops import rearrange
from torch import nn
import numpy as np
from nfn.common import WeightSpaceFeatures, network_spec_from_wsfeat
from nfn.layers import HNPSMixerLinear, TupleOp
def make_nfn(network_spec, nfn_channels=4):
    """Build the NFN under test: one ``HNPSMixerLinear`` layer in an ``nn.Sequential``.

    Args:
        network_spec: Specification of the input network's weight space
            (this script obtains it from ``network_spec_from_wsfeat``).
        nfn_channels: Number of output channels of the layer.

    Returns:
        An ``nn.Sequential`` holding a single ``HNPSMixerLinear`` that maps
        1 input channel to ``nfn_channels`` output channels.
    """
    mixer = HNPSMixerLinear(network_spec, in_channels=1, out_channels=nfn_channels)
    return nn.Sequential(mixer)
def set_seed(manualSeed=3):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA generators),
    puts cuDNN in deterministic mode with benchmarking disabled, and stores
    the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        manualSeed: Seed value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(manualSeed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: setting this at runtime does not change hashing in the already
    # running interpreter; it is only inherited by child processes.
    os.environ['PYTHONHASHSEED'] = str(manualSeed)
def check_params_eq(params1: "WeightSpaceFeatures", params2: "WeightSpaceFeatures", atol: float = 1e-2):
    """Return True when two weight-space features match element-wise within ``atol``.

    Weights are compared with weights and biases with biases, position by
    position, using ``torch.allclose``.

    Unlike a bare ``zip`` over the tensors, a different number of weights or
    biases, or a pair of tensors with different shapes, counts as "not equal"
    instead of being silently truncated, broadcast, or raising.

    Args:
        params1: First object exposing ``weights`` and ``biases`` sequences of tensors.
        params2: Second object with the same attributes.
        atol: Absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        ``True`` if every corresponding pair of tensors is close, else ``False``.
    """
    groups = (
        (tuple(params1.weights), tuple(params2.weights)),
        (tuple(params1.biases), tuple(params2.biases)),
    )
    for tensors1, tensors2 in groups:
        # zip() would stop at the shorter side and hide missing layers.
        if len(tensors1) != len(tensors2):
            return False
        for p1, p2 in zip(tensors1, tensors2):
            # Guard the shape first: allclose would broadcast or raise otherwise.
            if p1.shape != p2.shape or not torch.allclose(p1, p2, atol=atol):
                return False
    return True
def sample_perm_scale(layer_sizes):
    """Draw a random permutation matrix and a random diagonal scale matrix per hidden layer.

    Hidden layers are every entry of ``layer_sizes`` except the first (input)
    and the last (output).

    Args:
        layer_sizes: Sequence of layer widths, input first and output last.

    Returns:
        A tuple ``(perms, scales)`` of two lists with one square matrix per
        hidden layer: ``perms`` holds row-shuffled identity matrices and
        ``scales`` holds diagonal matrices whose diagonal comes from
        ``torch.rand``.
    """
    hidden_widths = layer_sizes[1:-1]
    # Tuple elements evaluate left to right, so each layer draws its
    # permutation before its scale, layer by layer.
    draws = [
        (
            torch.eye(width)[torch.randperm(width)],
            torch.diag_embed(torch.rand(width).abs()),
        )
        for width in hidden_widths
    ]
    perms = [perm for perm, _ in draws]
    scales = [scale for _, scale in draws]
    return (perms, scales)
def apply_perm_scale_to_params(params, perm_scales):
    """Apply per-hidden-layer (coupled) permutations and scales to ``params``.

    The coupled description is first expanded into one (row, column) pair per
    weight matrix, and those pairs are then applied to the weights and biases.

    Args:
        params: Weight-space features to transform.
        perm_scales: ``(perms, scales)`` with one matrix per hidden layer,
            as returned by ``sample_perm_scale``.

    Returns:
        The transformed weight-space features.
    """
    row_col_actions = convert_coupled_perm_scale_to_uncoupled(perm_scales, params)
    transformed = apply_uncoupled_perm_scale_to_params(params, row_col_actions)
    return transformed
def convert_coupled_perm_scale_to_uncoupled(perm_scales, params: WeightSpaceFeatures):
    """Expand per-hidden-layer matrices into one (row, column) pair per weight.

    For each weight the row-side matrix is the transform of the layer it feeds
    into, and the column-side matrix undoes the transform of the previous
    layer: the transpose for permutations, the inverse for scales. The first
    weight gets identity column matrices and the last weight gets identity
    row matrices.

    Args:
        perm_scales: ``(perms, scales)``; each is a sequence with one square
            matrix per hidden layer.
        params: Weight-space features. Only ``weights[0].shape[3]`` and
            ``weights[-1].shape[2]`` are read, to size the identities.

    Returns:
        ``(out_perms, out_scales)``: two lists of 2-tuples in
        ``(row_matrix, col_matrix)`` format, one entry per weight.
    """
    def _invert_diagonal(diag_matrix):
        # Invert the diagonal entry-wise (not a general matrix inverse)
        # to reduce floating point error.
        return torch.diag_embed(torch.diag(diag_matrix) ** (-1))

    in_features = params.weights[0].shape[3]
    out_features = params.weights[-1].shape[2]
    hidden = list(zip(perm_scales[0], perm_scales[1]))

    # Row side: each hidden layer's own transform, identity for the output.
    row_perms = [perm for perm, _ in hidden] + [torch.eye(out_features)]
    row_scales = [scale for _, scale in hidden] + [torch.eye(out_features)]

    # Column side: the transform of the preceding layer, identity for the input.
    prev_perms = [torch.eye(in_features)] + [perm for perm, _ in hidden]
    prev_scales = [torch.eye(in_features)] + [scale for _, scale in hidden]

    out_perms = [(row, prev.T) for row, prev in zip(row_perms, prev_perms)]
    out_scales = [
        (row, _invert_diagonal(prev)) for row, prev in zip(row_scales, prev_scales)
    ]
    return out_perms, out_scales
def apply_uncoupled_perm_scale_to_params(params: WeightSpaceFeatures, perm_scales):
    """Apply per-weight (row, column) permutation and scale matrices to ``params``.

    Each weight is multiplied on the left by ``row_scale @ row_perm`` and on
    the right by ``col_perm`` and then ``col_scale``. Each bias is multiplied
    on the left by the same ``row_scale @ row_perm``. Six-dimensional weights
    (conv filter banks) have their two trailing dimensions folded into the
    channel dimension for the multiplication and restored afterwards.

    Args:
        params: Weight-space features providing ``weights`` and ``biases``.
        perm_scales: ``(perms, scales)``; each is a list of 2-tuples
            ``(row_matrix, col_matrix)``, one per weight.

    Returns:
        A new ``WeightSpaceFeatures`` with the transformed weights and biases.
    """
    perm_pairs, scale_pairs = perm_scales[0], perm_scales[1]
    new_weights, new_biases = [], []
    layers = zip(perm_pairs, scale_pairs, params.weights, params.biases)
    for (row_perm, col_perm), (row_scale, col_scale), weight, bias in layers:
        # Left factor on the output axis, shared by the weight and its bias;
        # the leading singleton dims broadcast over batch and channel.
        left = row_scale[None, None] @ row_perm[None, None]

        kernel_hw = None
        if weight.dim() == 6:  # conv filter bank
            kernel_hw = weight.shape[-2:]
            weight = rearrange(weight, 'b c i j k l -> b (c k l) i j')

        transformed = left @ weight @ col_perm[None, None] @ col_scale[None, None]

        if kernel_hw is not None:
            kernel_h, kernel_w = kernel_hw
            transformed = rearrange(
                transformed, 'b (c k l) i j -> b c i j k l', k=kernel_h, l=kernel_w
            )

        new_weights.append(transformed)
        # Treat the bias as a column vector for the matmul, then drop that axis.
        new_biases.append((left @ bias.unsqueeze(-1)).squeeze(-1))
    return WeightSpaceFeatures(new_weights, new_biases)
def sample_params(bs, layer_sizes):
    """Sample standard-normal weights and biases for a batch of MLPs.

    Args:
        bs: Batch size (number of networks sampled).
        layer_sizes: Layer widths, input first and output last.

    Returns:
        ``WeightSpaceFeatures`` with, per layer, a weight of shape
        ``(bs, 1, fan_out, fan_in)`` and a bias of shape ``(bs, 1, fan_out)``.
    """
    weights = []
    biases = []
    # Walk over adjacent (input width, output width) pairs; the weight is
    # drawn before the bias within each layer.
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        weights.append(torch.randn(bs, 1, fan_out, fan_in))
        biases.append(torch.randn(bs, 1, fan_out))
    return WeightSpaceFeatures(weights, biases)
def test_layer_equivariance_permutation_scale(test_samples=2500):
    """Check that the NFN commutes with random permutation-and-scale group actions.

    One MLP weight sample of layer sizes ``[1, 32, 32, 3]`` is drawn and passed
    through the NFN once. For each of ``test_samples`` random group actions
    ``g`` the function compares ``nfn(g(params))`` with ``g(nfn(params))`` and
    shows the running pass count in the progress bar.

    Args:
        test_samples: Number of random group actions to test.

    Returns:
        ``True`` if every sampled group action passed the check, else
        ``False``. Previously this result was computed but discarded.
    """
    input_network_size = [1, 32, 32, 3]
    spec = network_spec_from_wsfeat(sample_params(1, input_network_size))
    nfn = make_nfn(spec)
    params = sample_params(1, input_network_size)
    all_equivariance = True
    equivariance_count = 0
    # Only forward values are compared, so skip building autograd graphs.
    with torch.no_grad():
        out = nfn(params)
        for i in (pbar := tqdm(range(test_samples))):
            perm_scales = sample_perm_scale(input_network_size)
            # Left-hand side: transform the input, then run the NFN.
            permed_params = apply_perm_scale_to_params(params, perm_scales)
            out_of_permed = nfn(permed_params)
            # Right-hand side: apply the same transform to the NFN output.
            permed_out = apply_perm_scale_to_params(out, perm_scales)
            equiv = check_params_eq(out_of_permed, permed_out)
            all_equivariance = all_equivariance and equiv
            equivariance_count += 1 if equiv else 0
            pbar.set_description(f"Equivariant check passed: {equivariance_count}/{(i+1)} models")
    return all_equivariance
if __name__ == "__main__":
    # Fix all RNG seeds so the sampled weights and group actions are reproducible.
    set_seed(manualSeed=4)
    # Run everything in float64 for more precision in the comparison.
    torch.set_default_dtype(torch.float64)
    test_layer_equivariance_permutation_scale()
How to run the code:
Name this file check_invariant.py and put it in the source folder of the supplementary material code.
Run "python check_invariant.py"
"""This example creates an NFN to process the weight space of a MLP and output a vector,
then verifies that the NFN is invariant under the permutation-and-scale group action.
"""
import os
import random
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data.dataloader import default_collate
from nfn.common import state_dict_to_tensors, WeightSpaceFeatures, network_spec_from_wsfeat
from nfn.layers import TupleOp, HNPSMixerLinear, HNPSMixerInv
from examples.basic_cnn.helpers import sample_perm_scale
def set_seed(manualSeed=3):
    """Make the run reproducible by seeding all random number generators.

    Covers Python's ``random``, NumPy, and PyTorch (CPU plus CUDA devices).
    It also turns off cuDNN benchmarking, requests deterministic cuDNN
    kernels, and records the seed in ``PYTHONHASHSEED``.

    Args:
        manualSeed: Seed value used for every generator.
    """
    # Python-level and NumPy generators.
    random.seed(manualSeed)
    np.random.seed(manualSeed)

    # PyTorch generators: CPU, current CUDA device, all CUDA devices.
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_torch in torch_seeders:
        seed_torch(manualSeed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.environ['PYTHONHASHSEED'] = str(manualSeed)
def make_MLP():
    """Create the small MLP whose weights are fed to the NFN.

    Returns:
        An ``nn.Sequential`` of three linear layers (2 -> 32 -> 32 -> 3) with
        a ReLU after each of the first two.
    """
    hidden = 32
    # Layers are instantiated in forward order, so their random
    # initialisation consumes the RNG in that order.
    layers = [
        nn.Linear(2, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, 3),
    ]
    return nn.Sequential(*layers)
def make_nfn(network_spec, nfn_channels=4):
    """Build the NFN under test: one ``HNPSMixerInv`` layer in an ``nn.Sequential``.

    Args:
        network_spec: Specification of the input network's weight space
            (this script obtains it from ``network_spec_from_wsfeat``).
        nfn_channels: Passed as the third positional argument of
            ``HNPSMixerInv``; presumably the output channel count (confirm
            against the layer's signature).

    Returns:
        An ``nn.Sequential`` holding the single ``HNPSMixerInv`` layer.
    """
    invariant_layer = HNPSMixerInv(network_spec, 1, nfn_channels)
    return nn.Sequential(invariant_layer)
def test_layer_invariance_permutation_scale(num_random_group_actions):
    """Check that the NFN output is unchanged by random group actions on MLP weights.

    For each of ``num_random_group_actions`` freshly initialised MLPs, the
    state dict ``U`` and its transformed counterpart ``g(U)`` are collected
    and batched into two weight-space features. Both batches go through one
    NFN and the outputs are compared.

    Args:
        num_random_group_actions: Number of (network, group action) pairs in
            the batch.

    Returns:
        ``True`` if ``NFN(U)`` and ``NFN(g(U))`` agree within an absolute
        tolerance of 1e-2, as judged by ``torch.allclose``.
    """
    original_tensors = []
    transformed_tensors = []
    for _ in range(num_random_group_actions):
        sd = make_MLP().state_dict()
        original_tensors.append(state_dict_to_tensors(sd))
        # sample_perm_scale here is the helper from examples.basic_cnn.helpers;
        # presumably it returns a randomly permuted and scaled state dict
        # (TODO confirm against that module).
        transformed_sd = sample_perm_scale(sd)
        transformed_tensors.append(state_dict_to_tensors(transformed_sd))

    # Batch the per-network tensors into the two feature maps U and g(U).
    wtfeat = WeightSpaceFeatures(*default_collate(original_tensors))
    wtfeat_perm = WeightSpaceFeatures(*default_collate(transformed_tensors))

    # The NFN is specified from the untransformed features.
    nfn = make_nfn(network_spec_from_wsfeat(wtfeat))

    reference_out = nfn(wtfeat)
    transformed_out = nfn(wtfeat_perm)
    return torch.allclose(reference_out, transformed_out, atol=1e-2)
if __name__ == "__main__":
    # Fix all RNG seeds so the sampled networks and group actions are reproducible.
    set_seed(4)
    torch.set_default_dtype(torch.float64)  # for more precision

    num_random_networks = 2500
    num_random_group_actions = 1

    all_invariant = True
    invariant_cout = 0  # running count of models that passed the check
    pbar = tqdm(range(num_random_networks))
    for i in pbar:
        invariant = test_layer_invariance_permutation_scale(num_random_group_actions)
        if invariant:
            invariant_cout += num_random_group_actions
        else:
            all_invariant = False
        pbar.set_description(f"Invariant check passed: {invariant_cout}/{(i+1)*num_random_group_actions} models")