0013_best_vision_models_for_fine_tuning

Introduction

best vision models for training from scratch vs for fine tuning

In a recent notebook I tried to answer the question “Which image models are best?” This showed which models in Ross Wightman’s PyTorch Image Models (timm) were the fastest and most accurate for training from scratch with Imagenet.

However, this is not what most of us use models for. Most of us fine-tune pretrained models. Therefore, what most of us really want to know is which models are the fastest and most accurate for fine-tuning. Yet, to my knowledge, no such analysis has previously been done.

Therefore I teamed up with Thomas Capelle of Weights and Biases to answer this question. In this notebook, I present our results.

The analysis

how to evaluate or compare models for fine tuning

There are two key dimensions on which datasets can vary when it comes to how well models fine-tune on them:

  1. How similar they are to the pre-trained model’s dataset
  2. How large they are.

Therefore, we decided to test on two datasets that were very different on both of these axes. We took models pre-trained on Imagenet, and tested fine-tuning them on two different datasets:

  1. The Oxford-IIIT Pet Dataset, which is very similar to Imagenet: Imagenet contains many pictures of animals, each a photo in which the animal is the main subject. IIIT-Pet contains nearly 15,000 images that are also of this type.
  2. The Kaggle Planet sample contains 1,000 satellite images of Earth. There are no images of this kind in Imagenet.

So these two datasets are of very different sizes, and very different in terms of their similarity to Imagenet. Furthermore, they have different types of labels - Planet is a multi-label problem, whereas IIIT-Pet is a single-label problem.

how to use Weights and Biases with fastai

To test the fine-tuning accuracy of different models, Thomas put together this script. The basic script contains the standard 4 lines of code needed for fastai image recognition models, plus some code to handle various configuration options, such as learning rate and batch size. It was particularly easy to handle in fastai since fastai supports all timm models directly.
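
For reference, here’s a minimal sketch of that standard pattern, assuming the Pets dataset downloaded via fastai’s untar_data; the labelling pattern, model name, and number of epochs are illustrative rather than the exact contents of Thomas’s script:

from fastai.vision.all import *

# A minimal sketch of the standard fastai image-recognition pattern; the sweep script wraps
# something like this with code to handle configuration (model name, learning rate, batch size, etc.).
path = untar_data(URLs.PETS)
files = get_image_files(path/'images')
dls = ImageDataLoaders.from_name_re(path, files, pat=r'(.+)_\d+.jpg$',   # breed label from the filename
                                    item_tfms=Resize(224), seed=42)
learn = vision_learner(dls, 'convnext_tiny', metrics=error_rate)         # any timm model name works here
learn.fine_tune(3)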

Then, to allow us to easily try different configuration options, Thomas created Weights and Biases (wandb) YAML files such as this one. This takes advantage of the convenient wandb “sweeps” feature, which tries out a range of different values for the model inputs and tracks the results.
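
I haven’t reproduced the YAML here, but a grid-search sweep of the same general shape can also be defined from Python; in the sketch below the parameter names, values, and project name are illustrative, not the exact contents of those files:

import wandb

# A hedged sketch of a wandb sweep configuration: a grid over a few hyperparameters,
# minimising the reported error rate. The real YAML files cover the full list of models.
sweep_config = {
    'method': 'grid',                                      # try every combination
    'metric': {'name': 'error_rate', 'goal': 'minimize'},
    'parameters': {
        'learning_rate': {'values': [0.008, 0.02]},
        'resize_method': {'values': ['squish']},
        'pool':          {'values': ['concat', 'avg']},
        'model_name':    {'values': ['convnext_tiny', 'vit_small_patch16_224']},
    },
}
sweep_id = wandb.sweep(sweep_config, project='fine-tune-timm')   # hypothetical project name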

wandb makes it really easy for a group of people to run these kinds of analyses on whatever GPUs they have access to. When you create a sweep using the command-line wandb client, it gives you a command to run to have a computer run experiments for the project. You run that same command on each computer where you want to run experiments. The wandb client automatically ensures that each computer runs different parts of the sweep, and has each one report back its results to the wandb server. You can look at the progress in the wandb web GUI at any time during or after the run. I’ve got three GPUs in my PC at home, so I ran three copies of the client, with each using a different GPU. Thomas also ran the client on a Paperspace Gradient server.
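
Concretely, each GPU just runs one agent for the same sweep. Here’s a hedged sketch of that from the Python API; the sweep ID, project name, and the parameter names pulled from the config are hypothetical, and the CLI equivalent is simply wandb agent <entity>/<project>/<sweep_id>:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'           # pin this agent to GPU 0 (set before CUDA is initialised)

import wandb
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

sweep_id = 'my-entity/fine-tune-timm/abc123'       # hypothetical sweep identifier

def train_one_run():
    # wandb injects this run's sweep parameters (e.g. model_name, learning_rate) into the config
    with wandb.init() as run:
        path = untar_data(URLs.PETS)
        files = get_image_files(path/'images')
        dls = ImageDataLoaders.from_name_re(path, files, pat=r'(.+)_\d+.jpg$',
                                            item_tfms=Resize(224), seed=42)
        # WandbCallback has fastai log metrics, hyperparameters and model details automatically
        learn = vision_learner(dls, run.config.model_name, metrics=error_rate, cbs=WandbCallback())
        learn.fine_tune(5, base_lr=run.config.learning_rate)

# Keeps pulling and running sweep configurations from the server until the sweep is done
wandb.agent(sweep_id, function=train_one_run)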

I liked this approach because I could start and stop the clients any time I wanted, and wandb would automatically handle keeping all the results in sync. When I restarted a client, it would automatically grab from the server whichever set of sweep settings needed to be run next. Furthermore, the integration in fastai is really exceptional, thanks particularly to Boris Dayma, who worked tirelessly to ensure that wandb automatically tracks every aspect of all fastai data processing, model architectures, and optimisation.

Hyperparameters

how to decide hyperparameters to create all the possible and meaningful models for testing

We decided to try out all the timm models which had reasonable reported performance, and which are capable of working with 224x224 px images. We ended up with a list of 86 models and variants to try.
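
As a hedged illustration, timm makes it easy to browse the candidate architectures; the 86 variants in our sweep were curated by hand rather than produced by a filter like the one below:

import timm

# List every timm model that ships with pretrained weights, then just the convnext family as an example
print(len(timm.list_models(pretrained=True)))
print(timm.list_models('convnext*', pretrained=True))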

Our first step was to find a good set of hyperparameters for each model variant and each dataset. Our experience at fast.ai has been that there’s generally not much difference between models and datasets in terms of what hyperparameter settings work well – and that experience was repeated in this project. Based on some initial sweeps across a small number of representative models, which showed little variation in optimal hyperparameters, our final sweep included all combinations of the following options:

  • Learning rate (AdamW): 0.008 and 0.02
  • Resize method: Squish
  • Pooling type: Concat and Average Pooling

For other parameters, we used defaults that we’ve previously found at fast.ai to be reliable across a range of models and datasets (see the fastai docs for details).

Analysis

how to analyse the sweep results from W&B

Let’s take a look at the data. I’ve put a CSV of the results into a gist:

from fastai.vision.all import *
import plotly.express as px

url = 'https://gist.githubusercontent.com/jph00/959aaf8695e723246b5e21f3cd5deb02/raw/sweep.csv'

For each model variant and dataset, for each hyperparameter setting, we did three runs. For the final sweep, we just used the hyperparameter settings listed above.

For each model variant and dataset, I create a group and take the minimum error rate, fit time, and GPU memory use. I use the minimum because there might be some reason that a particular run didn’t do so well (e.g. maybe there was some resource contention), and I’m mainly interested in knowing what the best-case results for a model can be.

I create a “score” which, somewhat arbitrarily, combines the accuracy and speed into a single number. I tried a few options until I came up with something that closely matched my own opinions about the tradeoffs between the two. (Feel free of course to fork this notebook and adjust how that’s calculated.)

df = pd.read_csv(url)
# Extract the model family from the model name, e.g. 'convnext_tiny_in22k' -> 'convnext'
df['family'] = df.model_name.str.extract(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')
df.loc[df.family=='swinv2', 'family'] = 'swin'  # treat swinv2 as part of the swin family
# For each model variant on each dataset, keep the best (minimum) error rate, fit time and GPU memory
pt_all = df.pivot_table(values=['error_rate','fit_time','GPU_mem'], index=['dataset', 'family', 'model_name'],
                        aggfunc=np.min).reset_index()
# Combined score (lower is better): error rate weighted by fit time plus a constant offset
pt_all['score'] = pt_all.error_rate*(pt_all.fit_time+80)

IIIT Pet

Here are the top 15 models on the IIIT Pet dataset, ordered by score:

pt = pt_all[pt_all.dataset=='pets'].sort_values('score').reset_index(drop=True)
pt.head(15)
dataset family model_name GPU_mem error_rate fit_time score
0 pets convnext convnext_tiny_in22k 2.660156 0.044655 94.557838 7.794874
1 pets swin swin_s3_tiny_224 3.126953 0.041949 112.282200 8.065961
2 pets convnext convnext_tiny 2.660156 0.047361 92.761599 8.182216
3 pets vit vit_small_r26_s32_224 3.367188 0.045332 103.240067 8.306554
4 pets mobilevit mobilevit_s 2.781250 0.046685 100.770686 8.439222
5 pets resnetv2 resnetv2_50x1_bit_distilled 3.892578 0.047361 105.952172 8.806939
6 pets vit vit_small_patch16_224 2.111328 0.054804 80.739517 8.809135
7 pets swin swin_tiny_patch4_window7_224 2.796875 0.048038 105.797015 8.925296
8 pets swin swinv2_cr_tiny_ns_224 3.302734 0.042625 129.435368 8.927222
9 pets resnetrs resnetrs50 2.419922 0.047361 109.549398 8.977309
10 pets levit levit_384 1.699219 0.054127 86.199098 8.995895
11 pets resnet resnet26d 1.412109 0.060216 69.395598 8.996078
12 pets convnext convnext_tiny_hnf 2.970703 0.049391 103.014163 9.039269
13 pets regnety regnety_006 0.914062 0.052097 93.912189 9.060380
14 pets levit levit_256 1.031250 0.056157 82.682410 9.135755

As you can see, the convnext, swin, and vit families are fairly dominant. The excellent showing of convnext_tiny matches my view that we should think of this as our default baseline for image recognition today. It’s fast, accurate, and not too much of a memory hog. (And according to Ross Wightman, it could be even faster if NVIDIA and PyTorch make some changes to better optimise the operations it relies on!)

vit_small_patch16 is also a good option – it’s faster and leaner on memory than convnext_tiny, although there is some performance cost too.

Interestingly, resnets are still a great option – especially the resnet26d variant, which is the fastest in our top 15.

Here’s a quick visual representation of the model families which look best in the above analysis (the “fit lines” are just there to help visually show where the different families are – they don’t necessarily actually follow a linear fit):

w,h = 900,700
faves = ['vit','convnext','resnet','levit', 'regnetx', 'swin']
pt2 = pt[pt.family.isin(faves)]
px.scatter(pt2, width=w, height=h, x='fit_time', y='error_rate', color='family', hover_name='model_name', trendline="ols",)

This chart shows that there’s a big drop-off in performance towards the far left. It seems like there’s a big compromise if we want the fastest possible model. It also seems that the best models in terms of accuracy, convnext and swin, aren’t able to make great use of the larger capacity of larger models. So an ensemble of smaller models may be effective in some situations.

Note that vit doesn’t include any larger/slower models, since they only work with larger images. We would recommend trying larger models on your dataset if you have larger images and the resources to handle them.

I particularly like using small, fast models, since I want to be able to iterate rapidly to try lots of ideas (see this notebook for more on this). Here are the top models (based on accuracy) that are smaller and faster than the median model:

pt.query("(GPU_mem<2.7) & (fit_time<110)").sort_values("error_rate").head(15).reset_index(drop=True)
dataset family model_name GPU_mem error_rate fit_time score
0 pets convnext convnext_tiny_in22k 2.660156 0.044655 94.557838 7.794874
1 pets convnext convnext_tiny 2.660156 0.047361 92.761599 8.182216
2 pets resnetrs resnetrs50 2.419922 0.047361 109.549398 8.977309
3 pets regnety regnety_006 0.914062 0.052097 93.912189 9.060380
4 pets levit levit_384 1.699219 0.054127 86.199098 8.995895
5 pets vit vit_small_patch16_224 2.111328 0.054804 80.739517 8.809135
6 pets resnet resnet50d 2.037109 0.055480 92.989515 9.597521
7 pets levit levit_256 1.031250 0.056157 82.682410 9.135755
8 pets regnetx regnetx_016 1.369141 0.059540 88.658087 10.041888
9 pets resnet resnet26d 1.412109 0.060216 69.395598 8.996078
10 pets levit levit_192 0.781250 0.060893 82.385787 9.888177
11 pets resnetblur resnetblur50 2.195312 0.061570 96.008735 10.836803
12 pets mobilevit mobilevit_xs 2.349609 0.062923 98.758011 11.247972
13 pets vit vit_tiny_patch16_224 1.074219 0.064276 65.670202 9.363104
14 pets regnety regnety_008 1.044922 0.064953 94.741903 11.349943

…and here are the top 15 models that are the very fastest and most memory-efficient:

pt.query("(GPU_mem<1.6) & (fit_time<90)").sort_values("error_rate").head(15).reset_index(drop=True)
dataset family model_name GPU_mem error_rate fit_time score
0 pets levit levit_256 1.031250 0.056157 82.682410 9.135755
1 pets regnetx regnetx_016 1.369141 0.059540 88.658087 10.041888
2 pets resnet resnet26d 1.412109 0.060216 69.395598 8.996078
3 pets levit levit_192 0.781250 0.060893 82.385787 9.888177
4 pets vit vit_tiny_patch16_224 1.074219 0.064276 65.670202 9.363104
5 pets vit vit_small_patch32_224 0.775391 0.065629 68.478869 9.744556
6 pets efficientnet efficientnet_es_pruned 1.507812 0.066306 69.601242 9.919432
7 pets efficientnet efficientnet_es 1.507812 0.066306 69.822634 9.934112
8 pets resnet resnet26 1.291016 0.067659 64.398096 9.769834
9 pets resnet resnet34 0.951172 0.070365 66.932345 10.338949
10 pets resnet resnet34d 1.056641 0.070365 71.631269 10.669590
11 pets regnetx regnetx_008 0.976562 0.070365 81.937185 11.394770
12 pets regnetx regnetx_006 0.730469 0.071042 78.592555 11.266723
13 pets mobilevit mobilevit_xxs 1.152344 0.073072 88.449456 12.308891
14 pets levit levit_128 0.650391 0.077808 82.819645 12.668646

ResNet-RS performs well here, with lower memory use than convnext but nonetheless high accuracy. A version trained on the larger Imagenet-22k dataset (like convnext_tiny_in22k) would presumably do even better, and might top the charts!

RegNet-y is impressively miserly in terms of memory use, whilst still achieving high accuracy.

Planet

Here’s the top-15 for Planet:

pt = pt_all[pt_all.dataset=='planet'].sort_values('score').reset_index(drop=True)
pt.head(15)
dataset family model_name GPU_mem error_rate fit_time score
0 planet vit vit_small_patch16_224 2.121094 0.035000 20.075387 3.502641
1 planet swin swin_base_patch4_window7_224_in22k 6.283203 0.031177 37.115593 3.651255
2 planet vit vit_small_patch32_224 0.775391 0.038529 17.817797 3.768855
3 planet convnext convnext_tiny_in22k 2.660156 0.037647 22.014424 3.840538
4 planet vit vit_base_patch32_224 2.755859 0.038823 19.060116 3.845859
5 planet swin swinv2_cr_tiny_ns_224 3.302734 0.036176 26.547731 3.854518
6 planet vit vit_base_patch32_224_sam 2.755859 0.039412 18.567447 3.884713
7 planet swin swin_tiny_patch4_window7_224 2.796875 0.036765 25.790094 3.889339
8 planet vit vit_base_patch16_224_miil 4.853516 0.036471 28.131062 3.943604
9 planet vit vit_base_patch16_224 4.853516 0.036176 29.274090 3.953148
10 planet convnext convnext_small_in22k 4.210938 0.036471 28.446879 3.955122
11 planet vit vit_small_r26_s32_224 3.367188 0.038529 23.008444 3.968847
12 planet vit vit_tiny_patch16_224 1.070312 0.040588 18.103888 3.981860
13 planet swin swin_small_patch4_window7_224 4.486328 0.035588 31.928643 3.983339
14 planet swin swin_s3_tiny_224 3.126953 0.038235 24.459997 3.994054

Interestingly, the results look quite different: vit and swin take most of the top positions in terms of the combination of accuracy and speed. vit_small_patch32 is a particular standout, with extremely low memory use; it’s also the fastest model in this top 15.

Because this dataset is so different to Imagenet, what we’re testing here is more about how quickly and data-efficiently a model can learn new features that it hasn’t seen before. We can see that the transformers-based architectures are able to do that better than the other models. convnext_tiny still puts in a good performance, but it’s a bit let down by its relatively poor speed – hopefully we’ll see NVIDIA speed it up in the future, because in theory it’s a lightweight architecture which should be able to do better.

The downside of vit and swin models, like most transformers-based models, is that they can only handle one input image size. Of course, we can always squish, crop, or pad our input images to the required size, but this can have a significant impact on performance. For instance, when recently looking at the Kaggle Paddy Disease competition, I found the ability of convnext models to handle dynamically sized inputs very convenient.
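
To illustrate, here’s a hedged sketch of that flexibility, again using the Pets data as a stand-in: the same convnext learner can be built on, say, 320 px inputs, whereas the vit and swin variants tested here expect exactly 224x224 px.

from fastai.vision.all import *

# A hedged sketch: convolutional families such as convnext accept other input sizes at fine-tuning
# time, whereas most of the vit/swin variants used above are tied to 224x224 px inputs.
path = untar_data(URLs.PETS)
files = get_image_files(path/'images')
dls = ImageDataLoaders.from_name_re(path, files, pat=r'(.+)_\d+.jpg$',
                                    item_tfms=Resize(320), seed=42)           # a non-224 input size
learn = vision_learner(dls, 'convnext_tiny', metrics=error_rate)              # fine at 320 px
# learn = vision_learner(dls, 'vit_small_patch16_224', metrics=error_rate)    # would expect 224x224 inputs
learn.fine_tune(3)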

Here’s a chart of the top families, this time for the Planet dataset:

pt2 = pt[pt.family.isin(faves)]
px.scatter(pt2, width=w, height=h, x='fit_time', y='error_rate', color='family', hover_name='model_name', trendline="ols")

One striking feature is that for this dataset, there’s little correlation between model size and performance. Regnetx and vit are the only families that show much of a relationship here. This suggests that if you have data that’s very different to your pretrained model’s data, you might want to focus on smaller models. This makes intuitive sense, since these models have more new features to learn, and if they’re too big they’re either going to overfit, or fail to utilise their capacity effectively.

Here are the most accurate small, fast models on the Planet dataset:

pt.query("(GPU_mem<2.7) & (fit_time<25)").sort_values("error_rate").head(15).reset_index(drop=True)
dataset family model_name GPU_mem error_rate fit_time score
0 planet vit vit_small_patch16_224 2.121094 0.035000 20.075387 3.502641
1 planet convnext convnext_tiny_in22k 2.660156 0.037647 22.014424 3.840538
2 planet vit vit_small_patch32_224 0.775391 0.038529 17.817797 3.768855
3 planet convnext convnext_tiny 2.660156 0.039706 23.180807 4.096878
4 planet vit vit_tiny_patch16_224 1.070312 0.040588 18.103888 3.981860
5 planet mobilevit mobilevit_xxs 1.152344 0.041471 20.329964 4.160743
6 planet vit vit_tiny_r_s16_p8_224 0.785156 0.041765 20.520312 4.198198
7 planet resnetblur resnetblur50 2.195312 0.042353 21.530770 4.300124
8 planet resnet resnet18 0.634766 0.042647 17.189185 4.144828
9 planet resnetrs resnetrs50 2.419922 0.043823 23.490568 4.535317
10 planet resnet resnet26 1.289062 0.044118 17.832233 4.316120
11 planet regnetx regnetx_016 1.367188 0.044118 22.212399 4.509375
12 planet resnet resnet26d 1.412109 0.044412 20.341083 4.456320
13 planet regnety regnety_006 0.914062 0.045000 22.715365 4.622193
14 planet levit levit_384 1.699219 0.045588 21.410115 4.623098

convnext_tiny is still the most accurate option amongst architectures that don’t have a fixed resolution. Resnet 18 has very low memory use, is fast, and is still quite accurate.

Here’s the subset of ultra lean/fast models on the Planet dataset:

pt.query("(GPU_mem<1.6) & (fit_time<21)").sort_values("error_rate").head(15).reset_index(drop=True)
dataset family model_name GPU_mem error_rate fit_time score
0 planet vit vit_small_patch32_224 0.775391 0.038529 17.817797 3.768855
1 planet vit vit_tiny_patch16_224 1.070312 0.040588 18.103888 3.981860
2 planet mobilevit mobilevit_xxs 1.152344 0.041471 20.329964 4.160743
3 planet vit vit_tiny_r_s16_p8_224 0.785156 0.041765 20.520312 4.198198
4 planet resnet resnet18 0.634766 0.042647 17.189185 4.144828
5 planet resnet resnet26 1.289062 0.044118 17.832233 4.316120
6 planet resnet resnet26d 1.412109 0.044412 20.341083 4.456320
7 planet efficientnet efficientnet_es 1.507812 0.046176 17.470632 4.500840
8 planet regnetx regnetx_008 0.974609 0.048235 20.212098 4.833754
9 planet resnet resnet34 0.949219 0.048823 19.884937 4.876730
10 planet efficientnet efficientnet_es_pruned 1.507812 0.050294 17.644619 4.910943
11 planet regnety regnety_002 0.490234 0.050882 20.417092 5.109463
12 planet regnetx regnetx_002 0.462891 0.051176 18.394935 5.035501
13 planet regnetx regnetx_006 0.730469 0.051765 19.354445 5.143050
14 planet efficientnet efficientnet_lite0 1.494141 0.052059 16.381403 5.017507

Conclusions

It really seems like it’s time for a changing of the guard when it comes to computer vision models. As of the time of writing (June 2022), there are three very clear winners when it comes to fine-tuning pretrained models:

  • convnext
  • vit
  • swin

Tanishq Abraham studied the top results of a recent Kaggle computer vision competition and found that the above three approaches did indeed appear to be the best. However, there were two other architectures which were also very strong in that competition, but which aren’t in our top models above:

  • BEiT
  • EfficientNet

BEiT isn’t there because it’s too big to fit on my GPU (even the smallest BEiT model is too big!). This is fixable with gradient accumulation (see the sketch below), so perhaps in a future iteration we’ll add it in. EfficientNet didn’t have any variants that were fast and accurate enough to appear in the top 15 on either dataset. However, it’s notoriously fiddly to train, so there might well be some set of hyperparameters that would work well for these datasets. Having said that, I’m mainly interested in knowing which architectures can be trained quickly and easily without too much mucking around, so perhaps EfficientNet doesn’t really fit here anyway!
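
For the record, here’s a hedged sketch of that gradient-accumulation workaround using fastai’s GradientAccumulation callback; the model name, batch sizes, and dataset are illustrative:

from fastai.vision.all import *
from fastai.callback.training import GradientAccumulation

# A hedged sketch: train with a small per-step batch, but only step the optimiser after 64 samples
# have been accumulated, so a larger model such as a BEiT variant can fit in limited GPU memory.
path = untar_data(URLs.PETS)
files = get_image_files(path/'images')
dls = ImageDataLoaders.from_name_re(path, files, pat=r'(.+)_\d+.jpg$',
                                    item_tfms=Resize(224), seed=42, bs=16)   # small per-step batch
learn = vision_learner(dls, 'beit_base_patch16_224', metrics=error_rate,
                       cbs=GradientAccumulation(64))                         # effective batch size of 64
learn.fine_tune(3)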

Thankfully, it’s easy to try lots of different models, especially if you use fastai and timm, because it’s literally as easy as changing the model name in one place in your code. Your existing hyperparameters are most likely going to continue to work fine regardless of what model you try. And it’s particularly easy if you use wandb, since you can start and stop experiments at any time and they’ll all be automatically tracked and managed for you.

If you found this notebook useful, please remember to click the little up-arrow at the top to upvote it, since I like to know when people have found my work useful, and it helps others find it too. And if you have any questions or comments, please pop them below – I read every comment I receive!