# # Load the model ============== if the whole network structure definition is unchanged, we can use this directly
# model = SS_Style(**init_params).to(device)
# ckpt = torch.load(ckpt_path)
# model.load_state_dict(ckpt["model"])
# style_encoder = model.style_encoder
# Load the model ============== if we are pretty sure the style_encoder is unchanged, we can use this instead
# Build the full model, then restore only the style encoder's weights from
# the checkpoint; the rest of the model keeps its freshly initialised values.
model = SS_Style(**init_params).to(device)
style_encoder = model.style_encoder
style_encoder_dict = style_encoder.state_dict()
# map_location lets a checkpoint saved on another device (e.g. a GPU) be
# loaded onto whatever `device` this run uses.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location=device)
pretrained_dict = ckpt["model"]
# 1. keep only the style encoder's entries and strip the "style_encoder."
#    prefix so the keys line up with style_encoder.state_dict(). Matching on
#    the trailing dot avoids picking up siblings such as "style_encoder_x.*".
style_encoder_prefix = "style_encoder."
pretrained_dict = {
    k[len(style_encoder_prefix):]: v
    for k, v in pretrained_dict.items()
    if k.startswith(style_encoder_prefix)
}
if not pretrained_dict:
    # Without this guard the merge below would silently leave the encoder
    # at its random initialisation.
    raise KeyError(
        f"no '{style_encoder_prefix}*' entries found in checkpoint {ckpt_path!r}"
    )
# 2. overwrite entries in the existing state dict
style_encoder_dict.update(pretrained_dict)
# 3. load the merged state dict (not just the filtered checkpoint entries)
style_encoder.load_state_dict(style_encoder_dict)