CODE WILL BE RELEASED UPON ACCEPTANCE
With the increasing availability of open-source robotic data, imitation learning has emerged as a viable approach for both robot manipulation and locomotion. Currently, large generalist policies are trained to predict controls or trajectories using diffusion models, which have the desirable property of learning multimodal action distributions. However, generalizability comes at a cost: larger model size and slower inference. This is especially an issue for robotic tasks that require high control frequencies. Further, Diffusion Policy (DP), a popular model for generating trajectories, has a known trade-off between performance and action horizon: fewer diffusion queries accumulate greater trajectory-tracking errors. For these reasons, it is common practice to run these models at high inference frequency, subject to the robot's computational constraints. To address these limitations, we propose Latent Weight Diffusion (LWD), a method that uses diffusion to generate closed-loop policies (weights for neural policies) for robotic tasks, rather than trajectories. Learning the behavior distribution in parameter space rather than trajectory space offers two key advantages: (i) longer action horizons (fewer diffusion queries) and robustness to perturbations while retaining high performance, and (ii) a lower inference compute cost. Accordingly, we show that LWD achieves higher success rates than DP when the action horizon is longer and when stochastic perturbations exist in the environment. Furthermore, LWD achieves multitask performance comparable to DP while requiring only ~1/45th of the inference-time FLOPs per step.
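To make the mechanism concrete, below is a minimal sketch of the weight-generation step, assuming a PyTorch setup, a DDPM-style denoiser exposing a hypothetical denoise_step(z, t, cond) method, and a hypothetical decoder mapping latents to flat weight vectors. All module names, dimensions, and interfaces are illustrative assumptions, not the released implementation.

```python
import torch
import torch.nn as nn

class SmallPolicy(nn.Module):
    """Tiny closed-loop policy whose weights are generated by diffusion
    rather than trained directly (sizes are illustrative assumptions)."""
    def __init__(self, obs_dim=20, act_dim=4, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, act_dim),
        )

    def forward(self, obs):
        return self.net(obs)

def load_flat_weights(policy: nn.Module, flat: torch.Tensor) -> None:
    """Copy a flat parameter vector into the policy's parameters, in order."""
    offset = 0
    for p in policy.parameters():
        n = p.numel()
        p.data.copy_(flat[offset:offset + n].view_as(p))
        offset += n
    assert offset == flat.numel(), "decoder output must match the policy size"

@torch.no_grad()
def sample_policy(denoiser, decoder, policy, cond, latent_dim=128, steps=50):
    """One diffusion query: denoise a latent, decode it to a weight vector,
    and load those weights into the small policy."""
    z = torch.randn(1, latent_dim)
    for t in reversed(range(steps)):
        z = denoiser.denoise_step(z, t, cond)  # assumed denoiser interface
    load_flat_weights(policy, decoder(z).squeeze(0))
    return policy
```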
The generated policy can run closed-loop over longer action horizons, thereby allowing for fewer diffusion model queries.
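As an illustration of the query savings, here is a hedged rollout sketch that re-invokes sample_policy from the block above only once per action horizon, while the generated policy reacts to every observation in between. An older gym-style env API is assumed, and the horizon and episode length below are placeholder values.

```python
@torch.no_grad()
def rollout_lwd(env, denoiser, decoder, policy,
                action_horizon=128, episode_len=512):
    """Queries the diffusion model once per action_horizon steps; the
    decoded policy closes the loop at every environment step."""
    obs, queries = env.reset(), 0
    for step in range(episode_len):
        if step % action_horizon == 0:
            sample_policy(denoiser, decoder, policy, cond=obs)
            queries += 1
        action = policy(torch.as_tensor(obs, dtype=torch.float32))
        obs, reward, done, info = env.step(action.numpy())
        if done:
            break
    return queries  # e.g. 512 / 128 = 4 queries instead of one per step
```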
Below are demonstrations of LWD trained on PushT trajectory data at different action horizons.
Action Horizon 64
Action Horizon 128
Action Horizon 256
Below are demonstrations of LWD on ten manipulation tasks:
Drawer Close
Drawer Open
Button Press
Peg Insert Side
Push
Pick and Place
Window Close
Door Open
Reach
Window Open
Below are demonstrations under stochastic perturbations:
PushT task: LWD, perturbation 50
PushT task: DP, perturbation 50
Can task: LWD, perturbation 3
Can task: DP, perturbation 3
Lift task: LWD, perturbation 3
Lift task: DP, perturbation 3
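For reference, one possible way to inject disturbances like those shown above into an evaluation loop. The perturbation numbers in the captions (50 for PushT, 3 for Can and Lift) are environment-specific; the injection point, probability, and units below are assumptions for illustration only, not the paper's exact perturbation mechanism.

```python
import numpy as np

_rng = np.random.default_rng(0)

def perturb(action, magnitude, prob=0.1):
    """Occasionally add a uniform random offset to the commanded action;
    an illustrative stand-in for the environment-specific perturbations."""
    if _rng.random() < prob:
        return action + _rng.uniform(-magnitude, magnitude,
                                     size=np.shape(action))
    return action
```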