from fastai.vision.all import *
import plotly.express as px
url = 'https://gist.githubusercontent.com/jph00/959aaf8695e723246b5e21f3cd5deb02/raw/sweep.csv'
The best vision models for fine tuning
Introduction
best vision models for training from scratch vs for fine tuning
In a recent notebook I tried to answer the question “Which image models are best?” This showed which models in Ross Wightman’s PyTorch Image Models (timm) were the fastest and most accurate for training from scratch with Imagenet.
However, this is not what most of us use models for. Most of us fine-tune pretrained models. Therefore, what most of us really want to know is which models are the fastest and most accurate for fine-tuning. However, this analysis has not, to my knowledge, previously existed.
Therefore I teamed up with Thomas Capelle of Weights and Biases to answer this question. In this notebook, I present our results.
The analysis
how to evaluate or compare models for fine tuning
There are two key dimensions on which datasets can vary when it comes to how well they fine-tune a model:
- How similar they are to the pre-trained model’s dataset
- How large they are.
Therefore, we decided to test on two datasets that were very different on both of these axes. We tested pre-trained models that were trained on Imagenet, and tested fine-tuning on two different datasets:
- The Oxford IIIT-Pet Dataset, which is very similar to Imagenet. Imagenet contains many pictures of animals, and each picture is a photo in which the animal is the main subject. IIIT-Pet contains nearly 15,000 images that are also of this type.
- The Kaggle Planet sample contains 1,000 satellite images of Earth. There are no images of this kind in Imagenet.
So these two datasets are of very different sizes, and very different in terms of their similarity to Imagenet. Furthermore, they have different types of labels - Planet is a multi-label problem, whereas IIIT-Pet is a single-label problem.
how to use Weights and Biases with fastai
To test the fine-tuning accuracy of different models, Thomas put together this script. The basic script contains the standard 4 lines of code needed for fastai image recognition models, plus some code to handle various configuration options, such as learning rate and batch size. It was particularly easy to handle in fastai since fastai supports all timm models directly.
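To give a flavour of what that looks like, here's a minimal sketch of those standard lines, fine-tuning a timm model on the Pets dataset with fastai (the model name, epoch count, and transforms below are illustrative choices, not the exact settings from Thomas's script):

```python
from fastai.vision.all import *

# The "standard 4 lines": grab the data, build DataLoaders, create a Learner, fine-tune.
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_re(
    path, get_image_files(path), pat=r'^(.+)_\d+.jpg$',   # pet breed is encoded in the filename
    valid_pct=0.2, seed=42, item_tfms=Resize(224))
learn = vision_learner(dls, 'convnext_tiny', metrics=error_rate)  # any timm model name works here
learn.fine_tune(3)
```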
Then, to allow us to easily try different configuration options, Thomas created Weights and Biases (wandb) YAML files such as this one. This takes advantage of the convenient wandb “sweeps” feature which tries a range of different levels of a model input and tracks the results.
wandb makes it really easy for a group of people to run these kinds of analyses on whatever GPUs they have access to. When you create a sweep using the command-line wandb client, it gives you a command to run to have a computer run experiments for the project. You run that same command on each computer where you want to run experiments. The wandb client automatically ensures that each computer runs different parts of the sweep, and has each one report back its results to the wandb server. You can look at the progress in the wandb web GUI at any time during or after the run. I’ve got three GPUs in my PC at home, so I ran three copies of the client, with each using a different GPU. Thomas also ran the client on a Paperspace Gradient server.
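The same workflow can be sketched with wandb's Python API (Thomas actually used a YAML file and the CLI; the project name, parameter grid, and train() body below are purely illustrative):

```python
import wandb

# Illustrative sweep config -- the real sweep was defined in a YAML file and covered far more models.
sweep_config = {
    'method': 'grid',
    'metric': {'name': 'error_rate', 'goal': 'minimize'},
    'parameters': {
        'model_name':    {'values': ['convnext_tiny', 'vit_small_patch16_224', 'resnet26d']},
        'learning_rate': {'values': [0.008, 0.02]},
    },
}

def train():
    with wandb.init():
        cfg = wandb.config
        # ... build the DataLoaders, create vision_learner(dls, cfg.model_name), fine_tune ...
        wandb.log({'error_rate': 0.05})  # placeholder -- the real script logs the actual metric

sweep_id = wandb.sweep(sweep_config, project='fine-tune-timm')   # hypothetical project name
wandb.agent(sweep_id, function=train)  # run this same line on every GPU/machine you want to use
```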
I liked this approach because I could start and stop the clients any time I wanted, and wandb would automatically handle keeping all the results in sync. When I restarted a client, it would automatically grab from the server whatever the next set of sweep settings were needed. Furthermore, the integration in fastai is really exceptional, thanks particularly to Boris Dayma, who worked tirelessly to ensure that wandb automatically tracks every aspect of all fastai data processing, model architectures, and optimisation.
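That integration is exposed as a fastai callback; here's a minimal sketch (reusing the dls from the earlier snippet, with a hypothetical project name):

```python
import wandb
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

wandb.init(project='fine-tune-timm')  # hypothetical project name
# WandbCallback automatically logs losses, metrics, hyperparameters and (optionally) the model
learn = vision_learner(dls, 'convnext_tiny', metrics=error_rate, cbs=WandbCallback())
learn.fine_tune(3)
wandb.finish()
```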
Hyperparameters
how to decide hyperparameters to create all the possible and meaningful models for testing
We decided to try out all the timm models which had reasonable performance on timm, and which are capable of working with 224x224 px images. We ended up with a list of 86 models and variants to try.
Our first step was to find a good set of hyper-parameters for each model variant and for each dataset. Our experience at fast.ai has been that there’s generally not much difference between models and datasets in terms of what hyperparameter settings work well – and that experience was repeated in this project. Based on some initial sweeps across a smaller number of representative models, on which we found little variation in optimal hyperparameters, in our final sweep we included all combinations of the following options:
- Learning rate (AdamW): 0.008 and 0.02
- Resize method: Squish
- Pooling type: Concat and Average Pooling
For other parameters, we used defaults that we’ve previously found at fast.ai to be reliable across a range of models and datasets (see the fastai docs for details).
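Roughly speaking, those options map onto fastai arguments as follows (a sketch only, reusing path from the earlier snippet; the exact wiring in the sweep script differs):

```python
from fastai.vision.all import *

# learning rate  -> base_lr passed to fine_tune (0.008 or 0.02)
# resize method  -> Resize(..., method=ResizeMethod.Squish)
# pooling type   -> concat_pool=True (concat pooling) or False (average pooling)
dls = ImageDataLoaders.from_name_re(
    path, get_image_files(path), pat=r'^(.+)_\d+.jpg$',
    valid_pct=0.2, seed=42, item_tfms=Resize(224, method=ResizeMethod.Squish))
learn = vision_learner(dls, 'convnext_tiny', metrics=error_rate, concat_pool=True)
learn.fine_tune(3, base_lr=0.008)
```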
Analysis
how to analyse the sweep results from W&B
Let’s take a look at the data. I’ve put a CSV of the results into a gist:
For each model variant and dataset, for each hyperparameter setting, we did three runs. For the final sweep, we just used the hyperparameter settings listed above.
For each model variant and dataset, I group the runs and take the minimum error rate, fit time, and GPU memory used. I use the minimum because there might be some reason that a particular run didn’t do so well (e.g. maybe there was some resource contention), and I’m mainly interested in knowing what the best case results for a model can be.
I create a “score” which, somewhat arbitrarily, combines the accuracy and speed into a single number. I tried a few options until I came up with something that closely matched my own opinions about the tradeoffs between the two. (Feel free of course to fork this notebook and adjust how that’s calculated.)
df = pd.read_csv(url)
# extract the model family from the model name (treating swinv2 as part of the swin family)
df['family'] = df.model_name.str.extract(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')
df.loc[df.family=='swinv2', 'family'] = 'swin'
# take the best (minimum) error rate, fit time, and GPU memory for each model variant and dataset
pt_all = df.pivot_table(values=['error_rate','fit_time','GPU_mem'], index=['dataset', 'family', 'model_name'],
                        aggfunc=np.min).reset_index()
pt_all['score'] = pt_all.error_rate*(pt_all.fit_time+80)
IIIT Pet
Here are the top 15 models on the IIIT Pet dataset, ordered by score:
pt = pt_all[pt_all.dataset=='pets'].sort_values('score').reset_index(drop=True)
pt.head(15)
| | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
---|---|---|---|---|---|---|---|
0 | pets | convnext | convnext_tiny_in22k | 2.660156 | 0.044655 | 94.557838 | 7.794874 |
1 | pets | swin | swin_s3_tiny_224 | 3.126953 | 0.041949 | 112.282200 | 8.065961 |
2 | pets | convnext | convnext_tiny | 2.660156 | 0.047361 | 92.761599 | 8.182216 |
3 | pets | vit | vit_small_r26_s32_224 | 3.367188 | 0.045332 | 103.240067 | 8.306554 |
4 | pets | mobilevit | mobilevit_s | 2.781250 | 0.046685 | 100.770686 | 8.439222 |
5 | pets | resnetv2 | resnetv2_50x1_bit_distilled | 3.892578 | 0.047361 | 105.952172 | 8.806939 |
6 | pets | vit | vit_small_patch16_224 | 2.111328 | 0.054804 | 80.739517 | 8.809135 |
7 | pets | swin | swin_tiny_patch4_window7_224 | 2.796875 | 0.048038 | 105.797015 | 8.925296 |
8 | pets | swin | swinv2_cr_tiny_ns_224 | 3.302734 | 0.042625 | 129.435368 | 8.927222 |
9 | pets | resnetrs | resnetrs50 | 2.419922 | 0.047361 | 109.549398 | 8.977309 |
10 | pets | levit | levit_384 | 1.699219 | 0.054127 | 86.199098 | 8.995895 |
11 | pets | resnet | resnet26d | 1.412109 | 0.060216 | 69.395598 | 8.996078 |
12 | pets | convnext | convnext_tiny_hnf | 2.970703 | 0.049391 | 103.014163 | 9.039269 |
13 | pets | regnety | regnety_006 | 0.914062 | 0.052097 | 93.912189 | 9.060380 |
14 | pets | levit | levit_256 | 1.031250 | 0.056157 | 82.682410 | 9.135755 |
As you can see, the convnext, swin, and vit families are fairly dominant. The excellent showing of convnext_tiny matches my view that we should think of this as our default baseline for image recognition today. It’s fast, accurate, and not too much of a memory hog. (And according to Ross Wightman, it could be even faster if NVIDIA and PyTorch make some changes to better optimise the operations it relies on!)

vit_small_patch16 is also a good option – it’s faster and leaner on memory than convnext_tiny, although there is some performance cost too.

Interestingly, resnets are still a great option – especially the resnet26d variant, which is the fastest in our top 15.
Here’s a quick visual representation of the six model families which look best in the above analysis (the “fit lines” are just there to help visually show where the different families are – they don’t necessarily follow a linear fit):
w,h = 900,700
faves = ['vit','convnext','resnet','levit', 'regnetx', 'swin']
pt2 = pt[pt.family.isin(faves)]
px.scatter(pt2, width=w, height=h, x='fit_time', y='error_rate', color='family', hover_name='model_name', trendline="ols")
This chart shows that there’s a big drop-off in performance towards the far left. It seems like there’s a big compromise if we want the fastest possible model. It also seems that the best models in terms of accuracy, convnext and swin, aren’t able to make great use of the larger capacity of larger models. So an ensemble of smaller models may be effective in some situations.
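If you want to experiment with that ensembling idea, the simplest version is just to average the predicted probabilities of two already fine-tuned learners (learn1 and learn2 below are assumed to be trained fastai Learners sharing the same validation set):

```python
from fastai.vision.all import *

# Average the class probabilities from two small models and measure the combined error rate.
preds1, targs = learn1.get_preds()   # predictions on the (shared) validation set
preds2, _     = learn2.get_preds()
ens_preds = (preds1 + preds2) / 2
print(f'ensemble error rate: {1 - accuracy(ens_preds, targs).item():.4f}')
```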
Note that vit doesn’t include any larger/slower models, since they only work with larger images. We would recommend trying larger models on your dataset if you have larger images and the resources to handle them.
I particularly like using fast and small models, since I want to be able to iterate rapidly to try lots of ideas (see this notebook for more on this). Here are the top models (based on accuracy) that are smaller and faster than the median model:
"(GPU_mem<2.7) & (fit_time<110)").sort_values("error_rate").head(15).reset_index(drop=True) pt.query(
| | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
---|---|---|---|---|---|---|---|
0 | pets | convnext | convnext_tiny_in22k | 2.660156 | 0.044655 | 94.557838 | 7.794874 |
1 | pets | convnext | convnext_tiny | 2.660156 | 0.047361 | 92.761599 | 8.182216 |
2 | pets | resnetrs | resnetrs50 | 2.419922 | 0.047361 | 109.549398 | 8.977309 |
3 | pets | regnety | regnety_006 | 0.914062 | 0.052097 | 93.912189 | 9.060380 |
4 | pets | levit | levit_384 | 1.699219 | 0.054127 | 86.199098 | 8.995895 |
5 | pets | vit | vit_small_patch16_224 | 2.111328 | 0.054804 | 80.739517 | 8.809135 |
6 | pets | resnet | resnet50d | 2.037109 | 0.055480 | 92.989515 | 9.597521 |
7 | pets | levit | levit_256 | 1.031250 | 0.056157 | 82.682410 | 9.135755 |
8 | pets | regnetx | regnetx_016 | 1.369141 | 0.059540 | 88.658087 | 10.041888 |
9 | pets | resnet | resnet26d | 1.412109 | 0.060216 | 69.395598 | 8.996078 |
10 | pets | levit | levit_192 | 0.781250 | 0.060893 | 82.385787 | 9.888177 |
11 | pets | resnetblur | resnetblur50 | 2.195312 | 0.061570 | 96.008735 | 10.836803 |
12 | pets | mobilevit | mobilevit_xs | 2.349609 | 0.062923 | 98.758011 | 11.247972 |
13 | pets | vit | vit_tiny_patch16_224 | 1.074219 | 0.064276 | 65.670202 | 9.363104 |
14 | pets | regnety | regnety_008 | 1.044922 | 0.064953 | 94.741903 | 11.349943 |
…and here are the top 15 models that are the very fastest and most memory-efficient:
"(GPU_mem<1.6) & (fit_time<90)").sort_values("error_rate").head(15).reset_index(drop=True) pt.query(
| | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
---|---|---|---|---|---|---|---|
0 | pets | levit | levit_256 | 1.031250 | 0.056157 | 82.682410 | 9.135755 |
1 | pets | regnetx | regnetx_016 | 1.369141 | 0.059540 | 88.658087 | 10.041888 |
2 | pets | resnet | resnet26d | 1.412109 | 0.060216 | 69.395598 | 8.996078 |
3 | pets | levit | levit_192 | 0.781250 | 0.060893 | 82.385787 | 9.888177 |
4 | pets | vit | vit_tiny_patch16_224 | 1.074219 | 0.064276 | 65.670202 | 9.363104 |
5 | pets | vit | vit_small_patch32_224 | 0.775391 | 0.065629 | 68.478869 | 9.744556 |
6 | pets | efficientnet | efficientnet_es_pruned | 1.507812 | 0.066306 | 69.601242 | 9.919432 |
7 | pets | efficientnet | efficientnet_es | 1.507812 | 0.066306 | 69.822634 | 9.934112 |
8 | pets | resnet | resnet26 | 1.291016 | 0.067659 | 64.398096 | 9.769834 |
9 | pets | resnet | resnet34 | 0.951172 | 0.070365 | 66.932345 | 10.338949 |
10 | pets | resnet | resnet34d | 1.056641 | 0.070365 | 71.631269 | 10.669590 |
11 | pets | regnetx | regnetx_008 | 0.976562 | 0.070365 | 81.937185 | 11.394770 |
12 | pets | regnetx | regnetx_006 | 0.730469 | 0.071042 | 78.592555 | 11.266723 |
13 | pets | mobilevit | mobilevit_xxs | 1.152344 | 0.073072 | 88.449456 | 12.308891 |
14 | pets | levit | levit_128 | 0.650391 | 0.077808 | 82.819645 | 12.668646 |
ResNet-RS performs well here, with lower memory use than convnext but nonetheless high accuracy. A version trained on the larger Imagenet-22k dataset (like convnext_tiny_in22k) would presumably do even better, and may top the charts!
RegNet-y is impressively miserly in terms of memory use, whilst still achieving high accuracy.
Planet
Here are the top 15 models for Planet:
pt = pt_all[pt_all.dataset=='planet'].sort_values('score').reset_index(drop=True)
pt.head(15)
| | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
---|---|---|---|---|---|---|---|
0 | planet | vit | vit_small_patch16_224 | 2.121094 | 0.035000 | 20.075387 | 3.502641 |
1 | planet | swin | swin_base_patch4_window7_224_in22k | 6.283203 | 0.031177 | 37.115593 | 3.651255 |
2 | planet | vit | vit_small_patch32_224 | 0.775391 | 0.038529 | 17.817797 | 3.768855 |
3 | planet | convnext | convnext_tiny_in22k | 2.660156 | 0.037647 | 22.014424 | 3.840538 |
4 | planet | vit | vit_base_patch32_224 | 2.755859 | 0.038823 | 19.060116 | 3.845859 |
5 | planet | swin | swinv2_cr_tiny_ns_224 | 3.302734 | 0.036176 | 26.547731 | 3.854518 |
6 | planet | vit | vit_base_patch32_224_sam | 2.755859 | 0.039412 | 18.567447 | 3.884713 |
7 | planet | swin | swin_tiny_patch4_window7_224 | 2.796875 | 0.036765 | 25.790094 | 3.889339 |
8 | planet | vit | vit_base_patch16_224_miil | 4.853516 | 0.036471 | 28.131062 | 3.943604 |
9 | planet | vit | vit_base_patch16_224 | 4.853516 | 0.036176 | 29.274090 | 3.953148 |
10 | planet | convnext | convnext_small_in22k | 4.210938 | 0.036471 | 28.446879 | 3.955122 |
11 | planet | vit | vit_small_r26_s32_224 | 3.367188 | 0.038529 | 23.008444 | 3.968847 |
12 | planet | vit | vit_tiny_patch16_224 | 1.070312 | 0.040588 | 18.103888 | 3.981860 |
13 | planet | swin | swin_small_patch4_window7_224 | 4.486328 | 0.035588 | 31.928643 | 3.983339 |
14 | planet | swin | swin_s3_tiny_224 | 3.126953 | 0.038235 | 24.459997 | 3.994054 |
Interestingly, the results look quite different: vit and swin take most of the top positions in terms of the combination of accuracy and speed. vit_small_patch32 is a particular standout, with extremely low memory use, and it’s also the fastest model in the top 15.

Because this dataset is so different to Imagenet, what we’re testing here is more about how quickly and data-efficiently a model can learn new features that it hasn’t seen before. We can see that the transformer-based architectures are able to do that better than any other model. convnext_tiny still puts in a good performance, but it’s a bit let down by its relatively poor speed – hopefully we’ll see NVIDIA speed it up in the future, because in theory it’s a light-weight architecture which should be able to do better.
The downside of vit and swin models, like most transformer-based models, is that they can only handle one input image size. Of course, we can always squish or crop or pad our input images to the required size, but this can have a significant impact on performance. For instance, while recently working on the Kaggle Paddy Disease competition, I’ve found the ability of convnext models to handle dynamically sized inputs to be very convenient.
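For reference, here's how those three resizing strategies are spelled in fastai's item transforms (just the transforms; which one works best depends on your data):

```python
from fastai.vision.all import *

squish = Resize(224, method=ResizeMethod.Squish)                       # distort the aspect ratio
crop   = Resize(224, method=ResizeMethod.Crop)                         # cut away the edges
pad    = Resize(224, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros)  # letterbox with black borders
```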
Here’s a chart of the same six families, this time for the Planet dataset:
pt2 = pt[pt.family.isin(faves)]
px.scatter(pt2, width=w, height=h, x='fit_time', y='error_rate', color='family', hover_name='model_name', trendline="ols")
One striking feature is that for this dataset, there’s little correlation between model size and performance. Regnetx and vit are the only families that show much of a relationship here. This suggests that if you have data that’s very different to your pretrained model’s data, you might want to focus on smaller models. This makes intuitive sense, since these models have more new features to learn, and if they’re too big they’re either going to overfit, or fail to utilise their capacity effectively.
Here are the most accurate small, fast models on the Planet dataset:
"(GPU_mem<2.7) & (fit_time<25)").sort_values("error_rate").head(15).reset_index(drop=True) pt.query(
| | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
---|---|---|---|---|---|---|---|
0 | planet | vit | vit_small_patch16_224 | 2.121094 | 0.035000 | 20.075387 | 3.502641 |
1 | planet | convnext | convnext_tiny_in22k | 2.660156 | 0.037647 | 22.014424 | 3.840538 |
2 | planet | vit | vit_small_patch32_224 | 0.775391 | 0.038529 | 17.817797 | 3.768855 |
3 | planet | convnext | convnext_tiny | 2.660156 | 0.039706 | 23.180807 | 4.096878 |
4 | planet | vit | vit_tiny_patch16_224 | 1.070312 | 0.040588 | 18.103888 | 3.981860 |
5 | planet | mobilevit | mobilevit_xxs | 1.152344 | 0.041471 | 20.329964 | 4.160743 |
6 | planet | vit | vit_tiny_r_s16_p8_224 | 0.785156 | 0.041765 | 20.520312 | 4.198198 |
7 | planet | resnetblur | resnetblur50 | 2.195312 | 0.042353 | 21.530770 | 4.300124 |
8 | planet | resnet | resnet18 | 0.634766 | 0.042647 | 17.189185 | 4.144828 |
9 | planet | resnetrs | resnetrs50 | 2.419922 | 0.043823 | 23.490568 | 4.535317 |
10 | planet | resnet | resnet26 | 1.289062 | 0.044118 | 17.832233 | 4.316120 |
11 | planet | regnetx | regnetx_016 | 1.367188 | 0.044118 | 22.212399 | 4.509375 |
12 | planet | resnet | resnet26d | 1.412109 | 0.044412 | 20.341083 | 4.456320 |
13 | planet | regnety | regnety_006 | 0.914062 | 0.045000 | 22.715365 | 4.622193 |
14 | planet | levit | levit_384 | 1.699219 | 0.045588 | 21.410115 | 4.623098 |
convnext_tiny is still the most accurate option amongst architectures that don’t have a fixed resolution. Resnet 18 has very low memory use, is fast, and is still quite accurate.
Here’s the subset of ultra lean/fast models on the Planet dataset:
"(GPU_mem<1.6) & (fit_time<21)").sort_values("error_rate").head(15).reset_index(drop=True) pt.query(
| | dataset | family | model_name | GPU_mem | error_rate | fit_time | score |
---|---|---|---|---|---|---|---|
0 | planet | vit | vit_small_patch32_224 | 0.775391 | 0.038529 | 17.817797 | 3.768855 |
1 | planet | vit | vit_tiny_patch16_224 | 1.070312 | 0.040588 | 18.103888 | 3.981860 |
2 | planet | mobilevit | mobilevit_xxs | 1.152344 | 0.041471 | 20.329964 | 4.160743 |
3 | planet | vit | vit_tiny_r_s16_p8_224 | 0.785156 | 0.041765 | 20.520312 | 4.198198 |
4 | planet | resnet | resnet18 | 0.634766 | 0.042647 | 17.189185 | 4.144828 |
5 | planet | resnet | resnet26 | 1.289062 | 0.044118 | 17.832233 | 4.316120 |
6 | planet | resnet | resnet26d | 1.412109 | 0.044412 | 20.341083 | 4.456320 |
7 | planet | efficientnet | efficientnet_es | 1.507812 | 0.046176 | 17.470632 | 4.500840 |
8 | planet | regnetx | regnetx_008 | 0.974609 | 0.048235 | 20.212098 | 4.833754 |
9 | planet | resnet | resnet34 | 0.949219 | 0.048823 | 19.884937 | 4.876730 |
10 | planet | efficientnet | efficientnet_es_pruned | 1.507812 | 0.050294 | 17.644619 | 4.910943 |
11 | planet | regnety | regnety_002 | 0.490234 | 0.050882 | 20.417092 | 5.109463 |
12 | planet | regnetx | regnetx_002 | 0.462891 | 0.051176 | 18.394935 | 5.035501 |
13 | planet | regnetx | regnetx_006 | 0.730469 | 0.051765 | 19.354445 | 5.143050 |
14 | planet | efficientnet | efficientnet_lite0 | 1.494141 | 0.052059 | 16.381403 | 5.017507 |
Conclusions
It really seems like it’s time for a changing of the guard when it comes to computer vision models. There are, as at the time of writing (June 2022), three very clear winners when it comes to fine-tuning pretrained models:
- convnext
- swin
- vit
Tanishq Abraham studied the top results of a recent Kaggle computer vision competition and found that the above three approaches did indeed appear to be the best approaches. However, there were two other architectures which were also very strong in that competition, but which aren’t in our top models above:
- EfficientNet (and EfficientNet v2)
- BEiT.
BEiT isn’t there because it’s too big to fit on my GPU (even the smallest BEiT model is too big!). This is fixable with gradient accumulation, so perhaps in a future iteration we’ll add it in. EfficientNet didn’t have any variants that were fast and accurate enough to appear in the top 15 on either dataset. However, it’s notoriously fiddly to train, so there might well be some set of hyperparameters that would work for these datasets. Having said that, I’m mainly interested in knowing which architectures can be trained quickly and easily without too much mucking around, so perhaps EfficientNet doesn’t really fit here anyway!
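For what it's worth, fastai's GradientAccumulation callback is one way to squeeze a model like BEiT into limited GPU memory; here's a hedged sketch (the model name, batch size, and accumulation target are illustrative, and path is reused from the earlier snippet):

```python
from fastai.vision.all import *

# Keep only 8 images on the GPU at a time, but accumulate gradients until 64 samples
# have been seen before each optimiser step -- an effective batch size of 64.
dls = ImageDataLoaders.from_name_re(
    path, get_image_files(path), pat=r'^(.+)_\d+.jpg$',
    valid_pct=0.2, seed=42, item_tfms=Resize(224), bs=8)
learn = vision_learner(dls, 'beit_base_patch16_224', metrics=error_rate,
                       cbs=GradientAccumulation(n_acc=64))
learn.fine_tune(3)
```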
Thankfully, it’s easy to try lots of different models, especially if you use fastai and timm, because it’s literally as easy as changing the model name in one place in your code. Your existing hyperparameters are most likely going to continue to work fine regardless of what model you try. And it’s particularly easy if you use wandb, since you can start and stop experiments at any time and they’ll all be automatically tracked and managed for you.
If you found this notebook useful, please remember to click the little up-arrow at the top to upvote it, since I like to know when people have found my work useful, and it helps others find it too. And if you have any questions or comments, please pop them below – I read every comment I receive!