3D Vision Transformer: Challenges


Challenges for designing 3D Vision Transformers

Published on January 10, 2022 by Hongyi WANG

vision transformer 3d segmentation


Recently I have been working on designing 3D Vision Transformers for medical images. I have run into many problems in my experiments, and in this post I will list some of them.

Challenge 1: Computational Cost

Computational cost can be divided into three main parts: the number of parameters, FLOPs, and memory cost. More parameters lead to a larger model size, while higher FLOPs mean that more floating-point operations are needed to compute the model output. Most previous work on lightweight models focuses on these two metrics, and GPU memory cost is often neglected. However, memory cost is also a very important issue: it determines which devices are capable of training and testing a given model. Many models today need high-end GPUs with 24 to 48 GB of memory to train, and if you want a larger batch size, you have to spread each batch across several GPUs, for example with “torch.nn.DataParallel”. These three metrics are not always correlated. During training, a large share of memory goes to the intermediate activations kept for backpropagation, and these grow with the input size rather than with the parameter count, so a model with fewer parameters may still consume more GPU memory.
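To make the gap between parameters and memory concrete, here is a small back-of-the-envelope sketch in plain Python. It counts the parameters of a single 3×3×3 convolution and the bytes needed to store its float32 output. The helper names, channel widths, batch size, and volume sizes are hypothetical examples, not taken from any specific model.

```python
# Illustrative arithmetic only: compares the parameter count of one 3D
# convolution with the memory needed to store its output activations.
# Layer sizes and input shapes below are hypothetical examples.

BYTES_PER_FLOAT32 = 4


def conv3d_params(in_ch: int, out_ch: int, k: int) -> int:
    """Weights plus biases of a k x k x k 3D convolution."""
    return in_ch * out_ch * k ** 3 + out_ch


def activation_bytes(batch: int, ch: int, d: int, h: int, w: int) -> int:
    """Bytes needed to hold one float32 tensor of shape (B, C, D, H, W)."""
    return batch * ch * d * h * w * BYTES_PER_FLOAT32


params = conv3d_params(in_ch=32, out_ch=32, k=3)
for side in (64, 128):
    mib = activation_bytes(batch=2, ch=32, d=side, h=side, w=side) / 2 ** 20
    print(f"{side}^3 input: {params} params, {mib:.0f} MiB of output activations")
```

The parameter count stays the same for both inputs, while doubling the side of the volume multiplies the activation memory of this single layer by 8. A real network stores such tensors for many layers during training.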

High computational cost is the primary challenge for 3D vision transformers (it is also a big challenge for 2D vision transformers, which is why many people are working on lightweight transformers). In 3D the problem is worse: the number of tokens grows with the cube of the volume's side length, and the cost of self-attention (SA) grows quadratically with the number of tokens. We have to find cheaper operations to replace SA and, at the same time, design efficient structures that reduce the memory cost.
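The quadratic growth of SA can be checked with a little arithmetic. The sketch below assumes a cubic volume cut into non-overlapping cubic patches and computes the size of one float32 attention matrix (per head, per sample). The 128³ volume, the patch sizes, and the helper names are hypothetical choices for illustration.

```python
# Illustrative arithmetic only: size of one self-attention score matrix
# when a cubic 3D volume is split into non-overlapping cubic patches.
# Volume side, patch sizes, and head count are hypothetical examples.

BYTES_PER_FLOAT32 = 4


def num_tokens(side: int, patch: int) -> int:
    """Tokens produced by cutting a side^3 volume into patch^3 cubes."""
    assert side % patch == 0, "patch size must divide the volume side"
    return (side // patch) ** 3


def attention_matrix_bytes(tokens: int, heads: int = 1) -> int:
    """Memory for the float32 (heads, N, N) attention scores of one sample."""
    return heads * tokens * tokens * BYTES_PER_FLOAT32


for patch in (16, 8, 4):
    n = num_tokens(side=128, patch=patch)
    gib = attention_matrix_bytes(n) / 2 ** 30
    print(f"patch {patch}: {n} tokens, {gib:.3f} GiB per head")
```

Halving the patch size multiplies the token count by 8 and the attention memory by 64; with 4×4×4 patches a single head already needs 4 GiB just for its score matrix.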

Challenge 2: Pure Transformer or CNN + Transformer?

Pure transformers have been shown to match the modeling ability of CNNs, which is why pure transformer structures are becoming more and more popular. However, this modeling ability only emerges when there are enough training samples. Compared with pure transformers, CNNs have stronger structural priors (such as locality and weight sharing), which lets them work without large-scale pretraining.

Therefore, to build an EFFICIENT vision transformer, CNN + Transformer may be the better option. Convolutions can be used in the shallow layers to process large feature maps (since the number of convolution parameters does not depend on the feature size), and transformer blocks can be used to process the smaller feature maps in deeper layers. In this way, fine-grained local details are efficiently exploited by convolution, while global modeling only needs to be performed on features after pooling.

On the other hand, with sufficient computational resources and large-scale datasets, many works show that pure transformers with some modifications, such as Swin Transformer, can achieve better performance. A growing number of works suggest that we may need to step outside the canonical forms of network structures and find new backbones.
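The benefit of running attention only after pooling can be estimated with a hypothetical encoder layout in which each stage halves every spatial axis (as a stride-2 convolution or pooling layer would). The sketch below treats every remaining voxel as one token and counts the entries in one head's attention matrix at each stage; the 128³ input and the helper names are assumptions for illustration.

```python
# Illustrative arithmetic only: attention cost at each stage of a
# hypothetical encoder where every stage halves D, H and W.


def tokens_at_stage(side: int, stage: int) -> int:
    """Tokens (one per voxel) after halving each axis `stage` times."""
    return (side >> stage) ** 3


def attention_scores(tokens: int) -> int:
    """Entries in one head's N x N attention matrix."""
    return tokens * tokens


side = 128
for stage in range(5):
    n = tokens_at_stage(side, stage)
    print(f"stage {stage}: {n:>9} tokens, {attention_scores(n):>16} scores/head")
```

After four downsamplings there are 4096 times fewer tokens and 4096² times fewer attention scores than at full resolution, which is why a hybrid design leaves the early, high-resolution stages to convolutions.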

Challenge 3: Modifications for Visual Tasks

This is also a critical issue for 2D vision transformers. Transformers were first proposed for NLP tasks, so applying them directly to visual tasks may cause incompatibilities. For example, patches are widely used in vision transformers to convert an image into a sequence of tokens. But are they really necessary? Do they harm model performance, especially on dense prediction tasks such as segmentation?
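For reference, here is what patching does to a 3D volume, written as a framework-free sketch: the volume is cut into non-overlapping p×p×p cubes and each cube is flattened into one token. The function name `patchify_3d` and the nested-list volume format are hypothetical; in practice this step is commonly implemented as a strided convolution whose kernel size and stride both equal p, which also applies the learned projection omitted here.

```python
# A minimal, framework-free sketch of turning a 3D volume into a token
# sequence: cut it into non-overlapping p x p x p cubes and flatten each.
# The learned linear projection that usually follows is omitted.
from typing import List

Volume = List[List[List[float]]]  # indexed as volume[z][y][x]


def patchify_3d(volume: Volume, p: int) -> List[List[float]]:
    """Return one flattened token (length p**3) per patch, in z, y, x order."""
    d, h, w = len(volume), len(volume[0]), len(volume[0][0])
    if d % p or h % p or w % p:
        raise ValueError("patch size must divide every spatial dimension")
    tokens = []
    for z0 in range(0, d, p):
        for y0 in range(0, h, p):
            for x0 in range(0, w, p):
                tokens.append([
                    volume[z][y][x]
                    for z in range(z0, z0 + p)
                    for y in range(y0, y0 + p)
                    for x in range(x0, x0 + p)
                ])
    return tokens


# A 4 x 4 x 4 volume whose voxel value encodes its position.
vol = [[[float(16 * z + 4 * y + x) for x in range(4)] for y in range(4)] for z in range(4)]
seq = patchify_3d(vol, p=2)
print(len(seq), len(seq[0]))
```

With p=2 the 64 voxels become 8 tokens of 8 values each, so every token mixes the contents of a whole 2×2×2 region. With p=1 the same function returns one token per voxel, which is the patch-free setting.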

Personally, I believe vision transformers should abandon patches (in other words, use 1×1 patches, or 1×1×1 in 3D) to further improve segmentation performance. Patches in vision transformers are mainly a way to reduce the number of tokens, embedding the local information of a small region into one single token. Such an operation lacks flexibility, since the patch grid is fixed regardless of the image content, and it may hurt performance, for example on fine structures smaller than a single patch.
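Dropping patches makes every voxel a token, which runs straight back into the cost problem of Challenge 1. The sketch below compares global attention with attention restricted to non-overlapping local windows, the kind of restriction Swin Transformer uses in 2D. Windowed attention is an assumption chosen here to illustrate one possible mitigation, not a method proposed in this post; the 64³ volume, the 4³ window, and the helper names are hypothetical.

```python
# Illustrative arithmetic only: with 1 x 1 x 1 patches every voxel is a token.
# Local-window attention (Swin-style) is used as an assumed mitigation.


def global_scores(side: int) -> int:
    """Attention scores per head when all side^3 voxels attend to each other."""
    n = side ** 3
    return n * n


def window_scores(side: int, window: int) -> int:
    """Scores per head when each voxel attends only within its window^3 cube."""
    assert side % window == 0, "window must divide the volume side"
    n = side ** 3
    return n * window ** 3


side, window = 64, 4
print(global_scores(side), window_scores(side, window))
print(global_scores(side) // window_scores(side, window))
```

Global attention grows quadratically with the number of voxels, while windowed attention grows linearly (each voxel attends to a fixed window³ neighbors); for this example the windowed version needs 4096 times fewer scores, at the price of losing direct global interaction within a single layer.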