Implementations of various AI papers for image classification
Model Architectures
- TResNet
- MobileNetV2
- MobileNetV3
- ResNetV2
- ResNetV2 + Stochastic Depth
- ResNeXt
- SeNet
- DenseNet
Other Features
- Step Learning Rate (LR) decay schedule
- HTD (Hyperbolic-Tangent LR Decay schedule)
- Cosine LR decay schedule
- Cutout
- Mixup
- Cutmix
- Mish
- AntiAliasDownsampling
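
The step and HTD schedules listed above can be sketched in a few lines of plain Python. This is an illustrative sketch, not this repository's implementation: the function names, `drop_every`, `gamma`, and the HTD bounds `lower=-6` / `upper=3` are assumptions (the bounds are the commonly quoted defaults for hyperbolic-tangent decay), and the base rate of 0.5 and 150 epochs are borrowed from the training setup in this README.

```python
import math


def step_lr(epoch, lr0=0.5, drop_every=50, gamma=0.1):
    """Step decay: multiply the LR by gamma once every drop_every epochs.

    drop_every and gamma are hypothetical values chosen for illustration.
    """
    return lr0 * gamma ** (epoch // drop_every)


def htd_lr(epoch, total_epochs=150, lr0=0.5, lower=-6.0, upper=3.0):
    """Hyperbolic-tangent decay.

    lr = lr0 / 2 * (1 - tanh(lower + (upper - lower) * epoch / total_epochs))
    lower and upper are assumed defaults, not values taken from this repo.
    """
    t = epoch / total_epochs
    return 0.5 * lr0 * (1.0 - math.tanh(lower + (upper - lower) * t))
```

With these assumed bounds, HTD stays close to `lr0` for most of the run and drops off sharply near the end, whereas step decay changes the rate only at fixed epoch boundaries.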
GPU: RTX 3090 @ 1800 MHz | FP16 + XLA auto-clustering
Epochs: 150
Batch size: 1024 (512 where a row is marked b=512)
Augmentation: random left/right flip -> 4 px shift in x/y -> Cutmix
LR schedule: cosine decay 0.5 -> 0.001, 10-epoch warmup
Optimizer: SGD with Nesterov momentum, m=0.9
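
The learning rate schedule in the setup above (cosine decay from 0.5 to 0.001 over 150 epochs with a 10-epoch warmup) could look roughly like the sketch below. The function name is hypothetical, and two details are assumptions because the README does not state them: the warmup ramps linearly from 0, and the cosine phase starts once warmup ends.

```python
import math


def cosine_lr(epoch, total_epochs=150, warmup_epochs=10,
              lr_max=0.5, lr_min=0.001):
    """LR at a (possibly fractional) epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Assumption: warmup ramps linearly from 0 up to lr_max.
        return lr_max * epoch / warmup_epochs
    # Assumption: the cosine phase covers the epochs that remain after warmup.
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    progress = min(progress, 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))
```

Passing a fractional epoch (for example `step / steps_per_epoch`) gives a per-step schedule instead of a per-epoch one.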
Model \ Augmentation | Basic | Mixup | Cutout | Cutmix |
---|---|---|---|---|
ResNet50 | 93.46% | 94.64% | 94.70% | 94.77% |
MobileNetV3S 192px w=2 | 94.35% | 95.14% | 95.44% | 95.85% |
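
For readers unfamiliar with the augmentations compared in the table above, here is a minimal pure-Python sketch of Mixup and Cutmix on toy data. It is not the code used to produce these numbers: the function names are hypothetical, images are plain 2D lists rather than tensors, and the mixing ratio `lam` is passed in by the caller (it is commonly drawn from a Beta distribution).

```python
import math
import random


def mixup_pair(x_a, x_b, lam):
    """Mixup: blend two flat pixel lists as lam * x_a + (1 - lam) * x_b."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]


def cutmix_box(height, width, lam, rng=random):
    """Pick a box covering roughly (1 - lam) of the image, clipped to bounds."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
    cy, cx = rng.randrange(height), rng.randrange(width)
    y0, y1 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    x0, x1 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    return y0, y1, x0, x1


def cutmix(img_a, img_b, lam, rng=random):
    """Cutmix: paste a box from img_b into a copy of img_a (2D lists).

    Returns the mixed image and lam adjusted to the area that was
    actually kept from img_a after clipping the box.
    """
    height, width = len(img_a), len(img_a[0])
    y0, y1, x0, x1 = cutmix_box(height, width, lam, rng)
    mixed = [row[:] for row in img_a]
    for y in range(y0, y1):
        mixed[y][x0:x1] = img_b[y][x0:x1]
    lam_adj = 1.0 - (y1 - y0) * (x1 - x0) / (height * width)
    return mixed, lam_adj
```

In both cases the training target is mixed with the same ratio as the pixels, for example `lam_adj * label_a + (1 - lam_adj) * label_b` for one-hot labels. Cutout is the special case where the box is filled with a constant instead of a patch from another image and the label is left unchanged.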
Model | Top-1 accuracy | Param count | Training (imgs/sec) | Inference (imgs/sec) |
---|---|---|---|---|
TResNet | | | | |
TResNet M | | | | |
64px | 92.51% | 3.2 M | 17 540 | 44 435 |
96px | 95.03% | 3.2 M | 9 969 | 26 356 |
128px | 95.84% | 3.2 M | 5 882 | 16 937 |
160px | 95.84% | 3.2 M | 4 161 | 12 046 |
192px | 95.89% | 3.2 M | 3 087 | 8 645 |
TResNet L | Overfit, no improvements over TResNet M | | | |
TResNet XL | | | | |
MobileNetV3 | | | | |
MobileNetV3S | | | | |
128px | 93.72% | 0.95 M | 11 845 | 66 137 |
160px | 94.41% | 0.95 M | 9 177 | 55 245 |
192px | 94.86% | 0.95 M | 10 675 | 43 226 |
224px | 95.53% | 0.95 M | 8 040 | 35 209 |
128px w=2 | 95.10% | 3.6 M | 7 254 | 44 722 |
128px w=4 b=512 | 95.99% | 13.9 M | 2 608 | 23 516 |
160px w=2 | 95.56% | 3.6 M | 5 512 | 31 467 |
192px w=2 | 96.02% | 3.6 M | 7 652 | 26 653 |
224px w=2 | 96.22% | 3.6 M | 5 711 | 20 760 |
224px w=2 b=512 | 96.30% | 3.6 M | 5 379 | 19 635 |
MobileNetV3L | | | | |
128px | 95.57% | 3.0 M | 5 765 | 34 980 |
160px | 96.07% | 3.0 M | 4 303 | 25 000 |
192px b=512 | 96.58% | 3.0 M | 4 531 | 17 142 |
224px b=512 | 96.52% | 3.0 M | 3 494 | 13 591 |
128px w=2 | 96.06% | 11.7 M | 3 286 | 20 087 |
192px w=2 b=512 | 96.95% | 11.7 M | 2 509 | 9 733 |
MobileNetV2 | | | | |
96px | 94.45% | 2.3 M | 5 201 | 42 184 |
128px | 95.10% | 2.3 M | 7 739 | 27 789 |
160px | 95.52% | 2.3 M | 5 377 | 19 118 |
192px | 95.78% | 2.3 M | 4 057 | 15 478 |
224px b=512 | 96.20% | 2.3 M | 2 963 | 11 179 |
128px w=2 | 96.27% | 7.95 M | 4 510 | 16 414 |
ResNetV2 | | | | |
ResNet18 / +mish | 92.81% / 93.53% | 0.69 M | 39 127 / -4% | 99 028 / -4% |
ResNet34 / +mish | 93.69% / 94.26% | 1.3 M | 25 534 / -4% | 75 071 / -4% |
ResNet35 / +mish | 94.09% / 94.42% | 0.87 M | 17 304 / -5% | 58 520 / -4% |
ResNet50 / +mish | 94.57% / 95.05% | 1.3 M | 12 939 / -5% | 45 775 / -3% |
ResNet101 / +mish | 95.15% / 95.57% | 2.5 M | 8 469 / -6% | 31 813 / -5% |
ResNet152 / +mish | 95.62% / 95.99% | 3.5 M | 5 954 / -7% | 23 211 / -3% |
ResNet170 / +mish | 95.68% / 96.18% | 4.2 M | 5 113 / -8% | 20 246 / -5% |
ResNet170 +mish +lr=.75 | 96.44% | | | |
WideResNet34 w=4 | 96.40% | 21.1 M | 5 605 | 20 382 |
WideResNet34 w=8 | 96.91% | 84.5 M | 1 773 | 6 539 |
WideResNet170 +mish w=2 | 97.18% | 16.6 M | 2 511 | 9 392 |
SeNet | | | | |
SeNet35 / +mish | 94.33% / 94.70% | 0.98 M | 15 162 / -8% | 52 390 / -9% |
SeNet50 / +mish | 94.76% / 95.17% | 1.5 M | 11 277 / -8% | 39 142 / -5% |
SeNet101 / +mish | 95.43% / 96.03% | 2.8 M | 7 223 / -9% | 25 303 / -3% |
SeNet101 +mish w=2 | 96.69% | 11.2 M | 3 985 | 13 820 |
SeNet152 / +mish | 95.78% / 96.49% | 3.95 M | 4 976 / -8% | 18 747 / -6% |
SeNet170 +mish w=2 b=768 | 97.07% | 18.6 M | 2 253 | 8 258 |
ResNeXt | | | | |
ResNeXt35_16x4d / +mish | 95.87% / 96.37% | 3.6 M | 1 893 / -1% | 20 215 / -1% |
ResNeXt50_16x4d / +mish | 96.26% / 96.45% | 5.5 M | 1 436 / -1% | 15 064 / -1% |
ResNeXt101_16x4d / +mish | 96.39% / 96.74% | 10.6 M | 990 / -1% | 11 063 / -2% |
DenseNet | | | | |
DenseNet52k12 | 93.75% | 0.27 M | 7 209 | 31 956 |
DenseNet100k12 | 95.40% | 0.79 M | 2 734 | 12 119 |
DenseNet100k16 | 95.87% | 1.4 M | 2 394 | 11 114 |
DenseNet160k12 b=512 | 96.43% | 1.8 M | 1 212 | 4 860 |
- Cos = Cosine learning rate decay schedule. From "Stochastic Gradient Descent with Warm Restarts".
- Mish = Self-regularized non-monotonic activation function, f(x) = x*tanh(softplus(x)). From "Mish: A Self Regularized Non-Monotonic Activation Function".
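
The Mish formula above, f(x) = x*tanh(softplus(x)), translates directly into code. The scalar sketch below uses only the standard library; the function names are illustrative, and the numerically stable form of softplus is an implementation choice made here, not something specified by this README.

```python
import math


def softplus(x):
    """Numerically stable softplus: log(1 + exp(x)) without overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    return x * math.tanh(softplus(x))
```

For large positive inputs Mish approaches the identity, and for large negative inputs it approaches zero from below, which is what makes it non-monotonic.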