RetinaNET
RetinaNet: Un Modello One-Stage per l'Object Detection
RetinaNet è un'architettura per il riconoscimento di oggetti introdotta da Facebook AI Research (FAIR) nel 2017. Il modello ha rivoluzionato l'object detection one-stage, superando le prestazioni dei modelli two-stage come Faster R-CNN grazie all'uso della Focal Loss.
1. Perché RetinaNet?
Nei modelli di rilevazione di oggetti, esistono due principali approcci:
- Two-stage (Faster R-CNN) → Genera regioni candidate (Region Proposal Network, RPN), poi le classifica. Accurato ma lento.
- One-stage (SSD, YOLO) → Predice direttamente classi e bounding box. Molto veloce ma meno accurato.
🔥 Problema: I modelli one-stage soffrono lo sbilanciamento tra classi positive (oggetti) e negative (sfondo), portando la rete a ignorare gli esempi difficili.
✅ Soluzione di RetinaNet: la Focal Loss, che bilancia meglio gli esempi facili e difficili.
2. Architettura di RetinaNet
2.1. Feature Pyramid Network (FPN)
- Utilizza una Feature Pyramid Network (FPN) per gestire oggetti di diverse dimensioni.
- Estrae feature map da vari livelli della rete backbone (es. ResNet50, ResNet101).
2.2. Sub-network per la Classificazione
- Predice la probabilità di ciascun oggetto.
- Usa la Focal Loss per migliorare il training.
2.3. Sub-network per la Regressione dei Bounding Box
- Stima le coordinate del bounding box.
- Usa la Smooth L1 Loss per affinare le predizioni.
3. La Focal Loss: Il Cuore di RetinaNet
RetinaNet introduce la Focal Loss, una modifica della Cross-Entropy Loss per gestire lo squilibrio tra classi:
FL(p_t) = -α_t (1 - p_t)ᵞ log(p_t)
Dove:
- (1 - p_t)ᵞ → Dà più peso agli esempi difficili.
- γ (gamma) → Controlla l'intensità del bilanciamento (tipicamente 2 o 3).
- α_t → Parametro per bilanciare classi positive e negative.
✅ Effetti:
- Riduce il contributo degli esempi facili.
- Permette alla rete di imparare meglio dagli esempi difficili.
4. Implementazione di RetinaNet in PyTorch
Ecco un'implementazione semplificata di RetinaNet in PyTorch:
import torch
import torch.nn as nn
import torchvision.models as models
class RetinaNet(nn.Module):
def __init__(self, num_classes):
super(RetinaNet, self).__init__()
self.backbone = models.resnet50(pretrained=True)
self.fpn = FeaturePyramidNetwork()
self.classification_head = ClassificationHead(num_classes)
self.regression_head = RegressionHead()
def forward(self, x):
features = self.backbone(x)
pyramid_features = self.fpn(features)
cls_preds = self.classification_head(pyramid_features)
bbox_preds = self.regression_head(pyramid_features)
return cls_preds, bbox_preds
class FeaturePyramidNetwork(nn.Module):
def forward(self, x):
return x
class ClassificationHead(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.conv = nn.Conv2d(256, num_classes, kernel_size=3, padding=1)
def forward(self, x):
return self.conv(x)
class RegressionHead(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(256, 4, kernel_size=3, padding=1)
def forward(self, x):
return self.conv(x)
# Creazione del modello RetinaNet
model = RetinaNet(num_classes=80)
print(model)
🔹 Nota: Questa è una versione semplificata. Framework come Detectron2 forniscono implementazioni complete di RetinaNet.
5. Confronto con Altri Modelli
| Modello | Tipo | Velocità | Accuratezza |
|---|---|---|---|
| Faster R-CNN | Two-stage | ❌ Lento | ✅ Alta |
| SSD | One-stage | ✅ Veloce | ❌ Inferiore |
| YOLO | One-stage | 🚀 Super veloce | 🔄 Bilanciato |
| RetinaNet | One-stage | ✅ Medio | ✅ Alta (grazie alla Focal Loss) |
6. Conclusione
✅ RetinaNet è stato un punto di svolta nei modelli di object detection one-stage, risolvendo il problema dello sbilanciamento tra classi grazie alla Focal Loss.
✅ Combina la velocità dei modelli one-stage con l'accuratezza dei two-stage.
✅ Ancora oggi è utilizzato in molte applicazioni di computer vision.
📌 Se vuoi implementare RetinaNet in un progetto specifico o confrontarlo con YOLO, fammelo sapere! 🚀
Commenti
Posta un commento