SARA: Self-Adaptive Robust Attention

From regular Transformers to light Transformers for Robotics & Beyond

[PAPER]

Best Robotic Manipulation Paper Award at ICRA 2024

Isabel Leal*, Krzysztof Choromanski*, Deepali Jain*, Avinava Dubey*, Jake Varley*, Michael Ryoo, Yao Lu, Frederick Liu, Vikas Sindhwani, 

Quan Vuong, Tamas Sarlos, Ken Oslund, Karol Hausman, Kanishka Rao

*equal contribution

Transformer architectures are also revolutionizing the field of Robotics. Large Language Models (LLMs) excel at high-level planning (see: SayCan, PaLM-E, RT-1), and more recent findings show that multi-modal architectures, such as vision-language models (VLMs), can be successfully used to learn natural-language-conditioned policies leveraging other modalities (e.g. vision). However, a notoriously difficult problem that substantially limits practical applications of Transformers in Robotics is the prohibitively high latency of these massive models when they are run on real hardware.

To address this challenge, we propose a Transformer adaptation procedure in which the regular attention module (with time complexity quadratic in the input's sequence length) is replaced by a light (computationally efficient) attention block (e.g. linear attention from Performers). We call this the up-training strategy: during pre-training or fine-tuning of the original model, the attention module is swapped out and training simply continues with the efficient attention mechanism. We observe that after a short transition phase, the model adapts quality-wise to the new attention module. Since this strategy works particularly well with the linear-attention mechanism from Performers, that mechanism becomes our default choice.
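To make the attention swap concrete, here is a minimal single-head sketch (our own illustrative code, not the released implementation) contrasting regular softmax attention with the Performer-ReLU linear attention that SARA defaults to: the ReLU feature map is applied to queries and keys, and the matrix products are reordered so that no L × L attention matrix is ever formed.

```python
import jax.numpy as jnp
from jax import nn, random

def softmax_attention(q, k, v):
    # Regular attention: materializes an L x L score matrix, so the cost is
    # quadratic in the sequence length L.
    scores = q @ k.T / jnp.sqrt(q.shape[-1])            # [L, L]
    return nn.softmax(scores, axis=-1) @ v              # [L, d_v]

def performer_relu_attention(q, k, v, eps=1e-6):
    # Linear attention: apply the ReLU feature map to queries and keys and
    # reorder the matrix products so that no L x L matrix is ever formed.
    q_prime, k_prime = nn.relu(q), nn.relu(k)           # [L, d]
    kv = k_prime.T @ v                                  # [d, d_v]
    normalizer = q_prime @ k_prime.sum(axis=0) + eps    # [L]
    return (q_prime @ kv) / normalizer[:, None]         # [L, d_v]

# Both functions take the same inputs and return outputs of the same shape,
# which is what makes the in-place swap (and subsequent up-training) possible.
key = random.PRNGKey(0)
L, d = 1024, 64
q, k, v = (random.normal(k_i, (L, d)) for k_i in random.split(key, 3))
out = performer_relu_attention(q, k, v)                 # shape (1024, 64)
```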

The up-training technique is general enough to be applied beyond the scope of Robotics; however, we find Robotics a particularly attractive application since fast inference is of critical importance there. We call the resulting mechanism Self-Adaptive Robust Attention (or SARA). So far, we have applied SARA in two settings: (a) the recently introduced Robotic Transformer architecture (RT-2), which provides a natural-language-conditioned policy and leverages the PaLI-5B VLM model, and (b) Point Cloud Transformer (PCT) architectures taking point clouds as input. In both settings, the policies are designed for robotic manipulation.

In the RT-2 setting, efficient attention replaces regular attention in the ViT-encoder tower of the PaLI model. Our results show that even for relatively short input sequences to the ViT tower (i.e. lower-resolution images and/or images transformed into relatively short sequences of patches), SARA provides substantial speed-ups (14%), importantly with no loss of quality in the adapted policy. When combined with other methods, such as keeping a short history of frames and applying new action-tokenization techniques, SARA provides 10%+ accuracy gains.
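The sketch below illustrates where the swap happens: a simplified, single-head ViT-style encoder block (illustrative code only; the parameter names wq, wk, wv, wo, w1, w2 are our placeholders, not PaLI's) in which the attention kernel is passed in as an argument. Up-training changes only that argument, while the pretrained projection, MLP and normalization weights are carried over unchanged.

```python
import jax.numpy as jnp
from jax import nn

def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

def encoder_block(x, params, attend):
    # x: [num_patches, d_model]; params holds the pretrained weights.
    h = layer_norm(x)
    q, k, v = h @ params["wq"], h @ params["wk"], h @ params["wv"]
    x = x + attend(q, k, v) @ params["wo"]                          # attention sub-block
    x = x + nn.gelu(layer_norm(x) @ params["w1"]) @ params["w2"]    # MLP sub-block
    return x

# Up-training the encoder amounts to a one-argument change per block:
#   y = encoder_block(x, params, attend=softmax_attention)          # before
#   y = encoder_block(x, params, attend=performer_relu_attention)   # after
```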

Fig. 1: Performer RT-2 VLM policy conditioned on natural language commands in action. Performer's attention replaces the regular attention module in the ViT encoder of the PaLI-5B model, which is the algorithmic backbone of the entire policy. The natural language commands are as follows (from upper left to bottom right, rows from left to right): pick water bottle, pick brown chip bag, pick rxbar blueberry, pick orange can, pick pepsi can, pick octopus toy, pick apple, pick cold brew can, pick catnip toy. SARA applies Performer-ReLU attention.

Fig. 2:  The same setting as above, but this time the natural language commands are more complex and often require certain high-level planning capabilities. The natural language commands are as follows (from upper left to bottom right, rows from left to right): pick 7up can upright, open bottom drawer, move redbull can near rxbar blueberry, move orange can near green rice chip bag, knock water bottle over, close bottom drawer.

Fig. 3: A scheme of some of the key elements of the PaLI-X backbone of RT-2 from the computational viewpoint, accompanied by the real robot performing a text instruction using a SARA model. This example uses a three-frame history, with each frame partitioned into four patches (in practice the number of patches is of course much larger). Frames are encoded via SARA variants of ViTs (sViT). The text instruction is separately pre-processed by the text Transformer (TT). In the fuser, all resulting embeddings are concatenated and interact with each other via self-attention. This self-attention block is yet another good candidate for injecting SARA variants; we leave it to future work.
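The sketch below traces the token bookkeeping from the figure (the names sara_vit, text_transformer and build_fuser_input are hypothetical placeholders, not the real RT-2 modules): with H history frames of P patches each and T instruction tokens, the fuser's self-attention runs over H·P + T tokens, which is why that block is another natural target for a SARA swap.

```python
import jax.numpy as jnp

def build_fuser_input(frames, text_tokens, sara_vit, text_transformer):
    # frames: H images; sara_vit maps each to [P, d] patch embeddings,
    # text_transformer maps the instruction to [T, d] token embeddings.
    frame_embeddings = [sara_vit(frame) for frame in frames]              # H tensors of shape [P, d]
    text_embeddings = text_transformer(text_tokens)                       # [T, d]
    # The fuser's self-attention then runs over H * P + T tokens.
    return jnp.concatenate(frame_embeddings + [text_embeddings], axis=0)  # [H * P + T, d]
```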

In the PCT setting, SARA is applied to handle point cloud inputs, which even at a smaller scale can easily reach a few thousand points. As in the RT-2 case, we do not see any degradation of the quality of the adapted policy, but benefit from substantial inference gains (see: Video 1). In fact, the accuracy of grasping policies went up by 11%, and 2x+ speed-ups were obtained for point cloud sizes >2K.

[VIDEO]

Video 1: A timelapse of the SARA-uptrained policy using the PCT encoder. The performance of the policy is not harmed (in fact we observe quality gains), yet speed-ups in inference are obtained. As for RT-2, SARA applies Performer-ReLU attention here.

Fig. 4: SARA-RT in action for the picking task with point cloud input.

Fig. 5: The same setting as in Fig. 4. The position of the object requires accurate manipulation.

Fig. 6: Left: Speed tests for SARA-PCT and regular PCT. Reported are mean inference times (averaged over 10 random seeds) for the PCT encoders, together with the corresponding standard deviations (shaded regions), as functions of the point cloud size. Right: Speed tests (on a CPU). Reported numbers are as in the left sub-figure, but for PaLI-ViT encoders as functions of the image resolution (x × x, where x is the value on the horizontal axis) for the default 16 × 16 patch size.