CVPR 2025
Tengda Han†, Dilara Gokay†, Joseph Heyward†, Chuhan Zhang†,
Daniel Zoran†, Viorica Patraucean†, João Carreira†, Dima Damen† ‡, Andrew Zisserman† ♢
† Google DeepMind, ‡ University of Bristol, ♢ University of Oxford
Corresponding author: tengda@google.com
Here we provide two example implementations of Orthogonal-AdamW, in PyTorch and in JAX (optax).
These examples are built on top of the existing AdamW implementations in the PyTorch and optax codebases, with small changes (<100 lines) that implement the orthogonal-gradient update proposed in our paper.
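Both implementations follow the same recipe: per parameter, keep an exponential moving average (EMA) of past gradients, remove from the current gradient its projection onto this EMA, and feed only the orthogonal component to the usual AdamW update. The short sketch below illustrates this core operation on a single tensor; it is illustrative only, and the helper name orthogonal_component is ours rather than part of either listing.
# Illustrative sketch of the core orthogonal-gradient operation (not part of the
# optimizer code below); it mirrors the projection used in both implementations.
import torch
import torch.nn.functional as F

def orthogonal_component(grad: torch.Tensor, hbuf: torch.Tensor) -> torch.Tensor:
    # Proj_A(B) = (A . B) / |A|^2 * A, computed as cos(A, B) * |B| * A / |A| for stability.
    cos_sim = F.cosine_similarity(grad, hbuf, dim=0).unsqueeze(0)
    proj = cos_sim * torch.norm(grad, dim=0, keepdim=True) * F.normalize(hbuf, dim=0)
    return grad - proj

beta = 0.9                                    # EMA coefficient of the history buffer
grad = torch.randn(64, 32)                    # current gradient, [out_channel, in_channel]
hbuf = torch.randn(64, 32)                    # EMA of past gradients
orth_grad = orthogonal_component(grad, hbuf)  # this is what the AdamW update sees
hbuf = beta * hbuf + (1 - beta) * grad        # history buffer tracks raw gradients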
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.optim.optimizer import _use_grad_for_differentiable, _get_value, _get_scalar_dtype
import torch.nn.functional as F
def _single_tensor_orthogonal_adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
    history_buffer_list: List[Optional[Tensor]],
*,
amsgrad: bool,
    beta: float,  # EMA coefficient of the gradient history buffer (new in Orthogonal-AdamW)
beta1: float,
beta2: float,
lr: Union[Tensor, float],
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
):
assert grad_scale is None and found_inf is None
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
# have overloads to handle both float and Tensor lrs, so we just assert it's
# a float since most people using JIT are using floats
assert isinstance(lr, float)
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
# update step
step_t += 1
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
        # history buffer: EMA of past raw gradients, used as the reference direction for the projection
hbuf = history_buffer_list[i]
if hbuf is None:
hbuf = torch.clone(grad).detach()
history_buffer_list[i] = hbuf
lr_multiplier = 1
else:
# [out_channel, in_channel, *_]
new_hbuf = torch.clone(grad).detach()
# get grad A (past) and grad B (current), find component of B that is orthogonal to A
# (A * B) / (|A|*|A|) * A
# stable version: cos(A, B) * (|B| / |A|) * A
cos_sim = F.cosine_similarity(new_hbuf, hbuf, dim=0).unsqueeze(0)
normalized_a = F.normalize(hbuf, dim=0)
norm_b = torch.norm(new_hbuf, dim=0, keepdim=True)
proj_b_on_a = cos_sim * norm_b * normalized_a
grad = grad.add(proj_b_on_a, alpha=-1)
lr_multiplier = 1
hbuf.mul_(beta).add_(new_hbuf, alpha=1-beta)
# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
if capturable or differentiable:
step = step_t
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
step_size = lr / bias_correction1
step_size_neg = step_size.neg()
bias_correction2_sqrt = bias_correction2.sqrt()
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
if differentiable:
max_exp_avg_sq = max_exp_avg_sqs[i].clone()
else:
max_exp_avg_sq = max_exp_avg_sqs[i]
max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
# Uses the max. for normalizing running avg. of gradient
# Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
# (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
denom = (
max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)
else:
denom = (
exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)
param.addcdiv_(exp_avg, denom, value=lr_multiplier)
else:
step = _get_value(step_t)
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
step_size = lr / bias_correction1
bias_correction2_sqrt = bias_correction2**0.5
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
# Use the max. for normalizing running avg. of gradient
denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
else:
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
param.addcdiv_(exp_avg, denom, value=-step_size * lr_multiplier)
# The code below is copied & slightly modified from
# https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py
class OrthogonalAdamW(torch.optim.AdamW):
    """AdamW variant that applies the update along the component of the gradient
    orthogonal to an EMA of past gradients (the history buffer)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
def _init_group(
self,
group,
params_with_grad,
grads,
amsgrad,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
history_buffer_list,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is None:
continue
has_complex |= torch.is_complex(p)
if has_complex:
raise NotImplementedError
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("AdamW does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
# note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros(
(),
dtype=_get_scalar_dtype(is_fused=group["fused"]),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if group["amsgrad"]:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError(
"`requires_grad` is not supported for `step` in differentiable mode"
)
# Foreach without capturable does not support a tensor lr
if (
group["foreach"]
and isinstance(group["lr"], Tensor)
and not group["capturable"]
):
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
history_buffer_list.append(state.get("history_buffer"))
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
amsgrad: bool = group["amsgrad"]
beta1, beta2 = cast(Tuple[float, float], group["betas"])
history_buffer_list: List[Optional[Tensor]] = []
_ = self._init_group(
group,
params_with_grad,
grads,
amsgrad,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
history_buffer_list,
state_steps,
)
_single_tensor_orthogonal_adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
history_buffer_list=history_buffer_list,
amsgrad=amsgrad,
beta=0.9,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=group["maximize"],
capturable=group["capturable"],
differentiable=group["differentiable"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
for p, history_buffer in zip(params_with_grad, history_buffer_list):
state = self.state[p]
state["history_buffer"] = history_buffer
return loss
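The class above can be used as a drop-in replacement for torch.optim.AdamW. Below is a minimal usage sketch with a placeholder model, data, and hyper-parameters (not taken from our training setup):
# Minimal usage sketch (placeholder model, data and hyper-parameters).
import torch

model = torch.nn.Linear(128, 10)
optimizer = OrthogonalAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

for _ in range(10):
    x, y = torch.randn(32, 128), torch.randint(0, 10, (32,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # applies AdamW to the orthogonal gradient component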
from collections.abc import Callable
from typing import Any, NamedTuple, Optional, Union
import chex
from jax import tree_util as jtu
import jax.numpy as jnp
import numpy as np
import optax
from optax import tree_utils as otu
def _get_orth_component(grad, v):
# first find the projection of grad (B) onto v (A)
# Proj_A{B} = (A * B) / (|A|*|A|) * A
# A numerically more stable version: cos(A, B) * (|B| / |A|) * A
eps = 1e-8
dot = jnp.sum(grad * v, axis=0, keepdims=True)
denom = jnp.maximum(
jnp.sqrt(
jnp.sum(grad**2, axis=0, keepdims=True)
* jnp.sum(v**2, axis=0, keepdims=True)
),
eps,
)
cos_sim = jnp.divide(dot, denom)
normalized_a = jnp.divide(v, jnp.maximum(jnp.linalg.norm(v, axis=0, keepdims=True), eps))
norm_b = jnp.linalg.norm(grad, axis=0, keepdims=True)
proj_b_on_a = jnp.multiply(jnp.multiply(cos_sim, norm_b), normalized_a)
# then get the orthogonal component
orth_grad = jnp.subtract(grad, proj_b_on_a)
return orth_grad
def scale_by_orth_adam(
b: float = 0.9,
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
*,
nesterov: bool = False,
) -> optax.GradientTransformation:
  """Scale updates with Adam statistics computed on the orthogonal gradient component.

  `b` is the EMA coefficient of the gradient history (velocity) buffer; the remaining
  arguments follow `optax.scale_by_adam`.
  """
  mu_dtype = optax._src.utils.canonicalize_dtype(mu_dtype)
def init_fn(params):
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
nu = otu.tree_zeros_like(params) # Second moment
    c = otu.tree_zeros_like(params, dtype=mu_dtype)  # EMA of past raw gradients (velocity / history buffer)
return ScaleByOrthogonalAdamState(
count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, velocity=c
)
def update_fn(updates, state, params=None):
del params
    # Keep only the component of the current gradient that is orthogonal to the
    # velocity (EMA of past raw gradients), then update the velocity with the raw
    # gradients before the usual Adam moment updates.
    orth_updates = jtu.tree_map(
        _get_orth_component,
        updates,
        state.velocity,
    )
    c = otu.tree_update_moment(updates, state.velocity, b, 1)
    updates = orth_updates
mu = otu.tree_update_moment(updates, state.mu, b1, 1)
nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, b2, 2)
count_inc = optax.safe_increment(state.count)
if nesterov:
mu_hat = jtu.tree_map(
lambda m, g: b1 * m + (1 - b1) * g,
otu.tree_bias_correction(mu, b1, optax.safe_increment(count_inc)),
otu.tree_bias_correction(updates, b1, count_inc),
)
else:
mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
updates = jtu.tree_map(
lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat
)
mu = otu.tree_cast(mu, mu_dtype)
return updates, ScaleByOrthogonalAdamState(
count=count_inc, mu=mu, nu=nu, velocity=c
)
return optax.GradientTransformation(init_fn, update_fn)
# The code below is copied & slightly modified from
# https://github.com/google-deepmind/optax/blob/main/optax/_src/alias.py#L599
class ScaleByOrthogonalAdamState(NamedTuple):
"""State for the Orthogonal Adam algorithm."""
count: chex.Array
mu: optax.Updates
nu: optax.Updates
velocity: optax.Updates
def orthogonal_adamw(
learning_rate: optax.ScalarOrSchedule,
b: float = 0.9,
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[Any] = None,
weight_decay: float = 1e-4,
mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
*,
nesterov: bool = False,
) -> optax.GradientTransformation:
"""Orthogonal AdamW."""
return optax.chain(
scale_by_orth_adam(
b=b,
b1=b1,
b2=b2,
eps=eps,
eps_root=eps_root,
mu_dtype=mu_dtype,
nesterov=nesterov,
),
optax.add_decayed_weights(weight_decay, mask),
optax.scale_by_learning_rate(learning_rate),
)
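The optax transformation plugs into the usual init/update loop. Below is a minimal usage sketch using the orthogonal_adamw function defined above, with a toy parameter pytree and loss; loss_fn and the shapes are placeholders, not from our experiments:
# Minimal usage sketch (toy parameter pytree and loss).
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((128, 10)), "b": jnp.zeros((10,))}
tx = orthogonal_adamw(learning_rate=1e-3, weight_decay=1e-2)
opt_state = tx.init(params)

def loss_fn(p, x):
    return jnp.mean((x @ p["w"] + p["b"]) ** 2)

x = jnp.ones((32, 128))
grads = jax.grad(loss_fn)(params, x)
updates, opt_state = tx.update(grads, opt_state, params)  # params needed for weight decay
params = optax.apply_updates(params, updates)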